为什么TRANSFORMER学不会乘法?逆向工程揭示了长程依赖的陷阱

作者: Xiaoyan Bai∗1, Itamar Pres∗2, Yuntian Deng3, Chenhao Tan1, Stuart Shieber4, Fernanda Viégas4,5†, Martin Wattenberg4,5†, Andrew Lee4

机构: 1芝加哥大学 2麻省理工学院 3滑铁卢大学 4哈佛大学 5Google DeepMind


摘要

语言模型的能力日益增强,但仍然在看似简单的多位数乘法任务上失败。在这项工作中,我们通过逆向工程一个通过隐式思维链成功学会乘法的模型,来研究其原因,并报告了三个发现:(1)长程依赖结构的证据:Logit归因和线性探针表明,该模型编码了多位数乘法所必需的长程依赖关系。(2)机制:模型利用注意力机制构建一个有向无环图来“缓存”和“检索”成对的部分积,从而编码长程依赖关系。(3)几何学:模型通过在注意力头中形成数字对之间的闵可夫斯基和来实现部分积,并且数字使用傅里叶基进行表示,这两种表示都是直观且高效的,而标准的微调模型缺乏这些表示。基于这些见解,我们重新审视了标准微调的学习动态,发现模型收敛到了一个缺乏所需长程依赖的局部最优解。我们通过引入一个辅助损失来进一步验证这一理解,该损失通过线性回归探针预测“累加和”,这提供了一种归纳偏置,使模型能够成功学习多位数乘法。总之,通过逆向工程隐式思维链模型的机制,我们揭示了Transformer学习长程依赖的一个陷阱,并提供了一个例子,说明正确的归纳偏置如何解决这个问题。

摘要解读(高三版):

想象一下,你正在教一个非常聪明的AI(语言模型)做数学题。这个AI能写诗、能聊天,但你让它算个三位数乘以三位数,它就算不对了。这篇论文就在研究这件怪事:为什么AI连小学的乘法都学不会?

为了找到答案,研究者们找来了一个“学霸AI”,这个AI用一种叫做“隐式思维链(ICoT)”的特殊方法,竟然学会了乘法。他们把这个“学霸AI”给“解剖”了(也就是“逆向工程”),想看看它的大脑里到底发生了什么。结果有三大发现:

  1. “学霸AI”会放长线,钓大鱼: 它懂得做乘法需要“全局观”。比如算 $123 \times 456$,个位数 $3 \times 6$ 的结果会影响到十位数的计算(因为有进位)。这种前后关联,就叫“长程依赖”。“学霸AI”成功地建立了这种关联。
  2. “学霸AI”有个聪明的草稿系统: 它使用了一种叫“注意力机制”的工具,在脑子里画了一张“流程图”(有向无环图)。这张图能帮它把每两个数字相乘的“部分积”(比如 $3 \times 6=18, 2 \times 6=12$ 等)先算出来存着(“缓存”),等需要的时候再取出来用(“检索”)。
  3. “学霸AI”对数字有独特的几何理解: 它看待数字的方式很特别。它把计算两个数的部分积,想象成在几何空间里把两个图形叠加起来(“闵可夫斯基和”)。而且,它还用一种类似声波分解的方式(“傅里叶基”)来表示数字0到9。这两种方法都非常高效,而那个学不会乘法的“学渣AI”完全没掌握。

搞懂了“学霸”的秘诀后,研究者们回头去看那个“学渣AI”,发现它之所以学不会,是因为它陷入了“局部最优”的陷阱——就像一个学生只顾着眼前的一步,没有长远规划,导致学得一塌糊涂。最后,研究者们给“学渣AI”加了一个“辅助轮”(辅助损失),强制它在计算过程中随时汇报当前的“累加和”,这等于给了它一个正确的引导(“归纳偏置”),结果“学渣AI”也奇迹般地学会了乘法!

核心结论:这篇论文告诉我们,不是Transformer笨,而是常规的训练方法很容易让它在需要“长远规划”的任务上“跑偏”。只要给它一点正确的引导,它就能克服这个困难。

1. 引言

大型语言模型在推理、规划和工具使用方面展现出惊人的能力。然而,它们在一些出奇简单的算法任务上却会失败(Nye et al., 2021; Lee et al., 2023)。为什么Transformer在某些任务上表现出色,却学不会另一些任务?多位数乘法就是这样一个例子。尽管拥有数十亿参数,像Llama-3.2 90B或GPT4这样的模型在4x4位数乘法上仍然会失败(Gambardella et al., 2024),即使在对该任务进行显式微调时也是如此(Yang et al., 2023)。为什么Transformer学不会乘法?

引言解读 (第一段):

这里开门见山地指出了一个矛盾现象:现在的AI模型(比如大家熟知的GPT-4)非常强大,能帮你写代码、做计划,能力超群。但奇怪的是,你让它做一个小学生都会的四位数乘法,它却可能算错。这就好比一个大学教授,能解微积分方程,却总是搞错 $15 \times 15$ 等于多少。这个问题很反常,即使你专门拿一大堆乘法题给它“补课”(微调),效果也不好。作者们因此提出了本文的核心疑问:这到底是为什么?

我们通过对比一个标准的微调模型(SFT),它在乘法上失败了,和一个用隐式思维链(ICoT)训练的模型,它成功了,来研究这些问题(Deng et al., 2024; 2023)。ICoT的工作原理是在训练期间提供显式的思维链通证,但逐渐将它们移除,从而迫使模型在其潜在状态中内化中间步骤。

引言解读 (第二段):

为了找到答案,研究者设计了一个对比实验,就像在学校里比较“学霸”和“学渣”的学习方法一样。

通过比较这两个模型的内部差异,作者们希望能揭示学会乘法的关键所在。

我们部分逆向工程了ICoT模型,并揭示了几个见解。首先,与SFT模型不同,ICoT模型学习到了多位数乘法所需的正确的长程结构。我们使用logit归因和线性回归探针为此提供了证据。在机制上,ICoT模型通过将其注意力组织成一个稀疏的、类似二叉树的图来编码长程依赖,该图(i)选择正确的数字对来计算部分积,以及(ii)将这些中间计算“缓存”到较早的通证中以便后续检索。最后,在几何学上,注意力头将部分积实现为数字嵌入的闵可夫斯基和,并用傅里叶基表示数字,形成一个五角棱镜结构——这两种表示都是直观且高效的,而SFT模型缺乏这些表示。

引言解读 (第三段):

这段是对“解剖学霸”过程的总结,也就是摘要里提到的三大发现的浓缩版。首先,“学霸”的大脑里形成了正确的“全局观”(长程结构)。其次,从运行机制上看,“学霸”脑中有一套高效的“草稿系统”(注意力图),能精确地计算和存储每一步的小结果(部分积),就像你在草稿纸上列竖式一样。最后,从更抽象的数学层面看,“学霸”对数字的表示方法本身就很高级,它把数字和计算过程看作是几何图形的变换,这种方式非常适合解决乘法问题。而“学渣”在这些方面都是一片空白。

基于这些见解,我们重新审视了标准微调的动态过程:在梯度下降和自回归损失下,模型从未学会这些长程依赖,因此损失在中间数字上停滞不前。为了证实我们的理解,我们引入了一个简单的修复方法,即引入一个辅助损失,通过一个轻量级的线性回归探针来监督模型预测“运行中的部分和”。这提供了一个归纳偏置来学习适当的长程依赖,使其能够达到完美的准确率,而无需任何来自思维链的监督。

引言解读 (第四段):

搞清楚了“学霸”的秘诀后,作者们再去看“学渣”的学习过程,发现它在学习时总是卡在计算结果的中间几位数上,因为那里最需要“全局观”。于是,他们想出了一个“补救措施”:在训练“学渣”时,不仅要求它给出最终答案,还要求它每算一步,都汇报一下当前的“累加和”(就像老师检查你的竖式计算,让你念出每一步的结果)。这个额外的小要求,就像一个强制性的“学习引导”(归纳偏置),迫使“学渣”不得不去考虑前后步骤的关联。神奇的是,这个简单的方法非常有效,“学渣”最终也考了100分,而且全程没再给他看过详细的解题步骤!

总之,通过部分逆向工程一个成功实现多位数乘法的网络,我们揭示了它如何实现长程依赖,这是不成功的模型所缺乏的一种机制。我们的工作突显了Transformer在使用梯度下降和自回归损失学习长程依赖方面面临的挑战。虽然我们展示了一种特定于任务的归纳偏置来解决这个问题,但我们预计会有通用的改进来解决这一限制。

引言解读 (第五段):

最后是一个总结。这项研究的核心贡献是,通过“解剖”一个成功的模型,我们找到了Transformer学习乘法这类需要长远规划任务的“命门”所在。目前的标准学习方法(梯度下降和自回归损失)存在天然的缺陷,容易让模型“鼠目寸光”。虽然本文提出的“汇报累加和”的方法是针对乘法这个特定任务的,但它揭示了一个更深层的问题。作者们希望,未来能有更通用的方法,让AI模型天生就具备这种“长远眼光”,而不需要针对每个任务都设计特殊的“辅助轮”。

2. 实验设置,训练ICoT,符号说明

任务,模型。 我们感兴趣的是理解使用标准微调和ICoT训练的模型之间的差异。从实验中,我们发现标准微调失败但ICoT成功的最简单的多位数乘法是4×4位数乘法。同样,ICoT能成功的最小的架构是一个有4个注意力头的2层模型。因此,我们仔细研究了一个2层4头的ICoT模型和一个在4×4乘法上训练的标准微调模型。

实验设置解读 (第一段):

为了让问题尽可能简单清晰,研究者们选择了“最简单”的失败案例来进行研究。他们发现,对于4位数乘4位数这个问题,普通的训练方法就不行了,而ICoT方法还能搞定。同时,他们也用了最精简的模型配置(2层,4个注意力头)。这就像在物理实验中控制变量,通过研究这个“临界点”,可以最有效地发现问题的本质,避免被其他复杂的因素干扰。

训练过程。 我们的ICoT设置与Deng等人(2024)的相同。这里我们提供一个ICoT的非正式概述,细节在附录A.1。即,假设有两个操作数 $a = (a_3, a_2, a_1, a_0)$,$b = (b_3, b_2, b_1, b_0)$ 和它们的积 $c = (c_7 \dots c_0)$。操作数是最低有效位优先写入的,与其他算法设置类似(Deng et al., 2024; 2023; Lee et al., 2023)。对于ICoT,训练数据包括中间的思维链(CoT)通证 $q_i$,这些通证明确记录了逐步的计算过程。举一个简单的例子,考虑12×34。两个等号之间的通证遵循我们4×4位数乘法任务中使用的相同CoT格式:

12 * 34 = 4812*4 + 36012*30 (408)running sum = 408

在每个训练周期,从链的左侧移除固定数量的CoT通证。具体来说,每个周期的训练样本可能具有以下形式:

(周期 1) a0a1a2a3 * b0b1b2b3 %%% q0 ... qi ... qj ... qk ... qτ #### c0 ... c7
(周期 2) a0a1a2a3 * b0b1b2b3 %%% qi ... qj ... qk ... qτ #### c0 ... c7
(周期 3) a0a1a2a3 * b0b1b2b3 %%% qj ... qk ... qτ #### c0 ... c7
...
(周期 N) a0a1a2a3 * b0b1b2b3 %%% #### c0 ... c7

其中 $q_i$ 是CoT通证,而 % 和 # 是特殊的分隔符。注意,在每个周期之后,模型会通过截断一些通证看到一个更短的链,到最后,只剩下操作数和最终答案。作为对比,标准微调只在操作数上训练:a0a1a2a3 * b0b1b2b3 %%% #### c0 ... c7。

有趣的是,ICoT模型能够在4×4位数乘法上达到100%的准确率,而标准微调仅达到不到1%的准确率。注意,扩大模型规模没有帮助——扩展到12层8头的模型也达到了同样的<1%的准确率,并且Yang等人(2023)表明,微调一个2B参数的模型仍然在95%的准确率上停滞不前。

有关训练的更多细节(数据格式、样本大小、超参数),请参见附录A。

训练过程解读:

这段详细解释了“学霸”(ICoT)的“渐进式学习法”。首先要注意,为了方便计算,数字都是反着写的,比如123会写成`321`,这和我们列竖式时个位对齐的习惯是一致的。

ICoT的训练过程就像一个逐渐放手的老师:

而“学渣”(SFT)的训练过程就简单粗暴多了,从头到尾都只给他看题目和答案。结果显而易见:“学霸”ICoT最终完全掌握了乘法,达到了100%的准确率;而“学渣”SFT几乎没学会,准确率不到1%。更重要的是,给“学渣”换个更大的脑袋(增加模型参数)也没用,说明问题不出在“脑容量”上,而是出在“学习方法”上。

符号说明。 $h^ℓ_t$ 表示在层 $ℓ$ 时间步 $t$ 的隐藏状态。对于解的通证 $c_k, k = [0, \dots, 7]$ 的时间步记为 $t_{c_k}$。$ATT^ℓ_h(\cdot), MLP^ℓ(\cdot)$ 表示在层 $ℓ$、头索引 $h$ 的注意力头或MLP块的输出。$E, U \in R^{V \times d}$ 表示(非)嵌入权重。

符号说明解读:

这里是论文中会用到的一些数学符号的“词汇表”,帮助你理解后面的公式。你可以把它看作是物理课本开头对 $v$ (速度), $a$ (加速度), $t$ (时间) 等符号的定义。这里面最重要的可能是 $h^ℓ_t$,它代表了模型在计算过程中的“瞬时记忆”或“思考状态”。

3. 比较ICOT与SFT的机制

3.1 多位数乘法中的长程依赖

这里我们讨论如何解决多位数乘法,以及解决乘法所需的必要长程依赖。

一种计算每个数字 $c_k$ 的方法如下:

$$s_k \triangleq \sum_{i+j=k} a_i b_j \quad \text{, (部分积之和)}$$ $$c_k = (s_k + r_{k-1}) \pmod{10}, \quad r_k = \lfloor \frac{s_k + r_{k-1}}{10} \rfloor \quad \text{, (处理进位)}, \quad r_{-1} = 0 \quad (1)$$

1. 直觉目的:这个公式就是我们小学学的列竖式乘法的数学表达。它描述了如何计算出结果的每一位数,以及如何处理进位。

2. 符号释义

3. 逻辑骨架:这个计算分为两步。第一步,计算 $s_k$,也就是所有能凑成第 $k$ 位的部分积的总和。第二步,把这个总和 $s_k$ 加上来自前一位的进位 $r_{k-1}$,然后用这个新总和来确定当前位 $c_k$ (取余数) 和要传给下一位的进位 $r_k$ (取商)。

4. 关系网络

a₀, b₀ a₁, b₀ a₀, b₁ ... s₀ = a₀b₀ s₁ = a₁b₀+a₀b₁ r₋₁ = 0 r₀ + ĉ₀ c₀ c₁ 输入数字对 部分积之和 中间值 & 进位 最终结果

图1: 乘法中的长程依赖。要计算出当前位的结果(如 $c_1$),必须依赖于之前所有的部分积(如 $s_0$, $s_1$)和进位(如 $r_0$)。这个信息流形成了一个复杂的依赖网络。

注意,$c_k$ 和 $r_k$ 都可以用一个中间项 $\hat{c}_k$ 来表示,它封装了来自部分积和进位的相关信息:

$$\hat{c}_k \triangleq s_k + r_{k-1}, \quad c_k = \hat{c}_k \pmod{10}, \quad r_k = \lfloor \frac{\hat{c}_k}{10} \rfloor \quad (2)$$

1. 直觉目的:这个公式是对上面公式(1)的一个简化和提炼。它告诉我们,其实没必要分两步走,我们可以先把“部分积之和”和“上一位的进位”加起来,得到一个“中间总和” $\hat{c}_k$。然后,从这个“中间总和”里,我们可以非常轻松地同时得到“当前位该写几”($c_k$)和“要进到下一位的是几”($r_k$)。

2. 符号释义

3. 逻辑骨架:这个公式的逻辑更清晰了。第一步:计算中间总和 $\hat{c}_k$。第二步:对 $\hat{c}_k$ 进行取余和取商操作,一步到位,直接得到 $c_k$ 和 $r_k$。这就像你在心算时,会先算出“这一列加起来总共是18”,然后再想“写8,进1”。这个“18”就是 $\hat{c}_k$。

4. 关系网络

重要的是,注意多位数乘法所需的长程依赖。具体来说,我们强调两个观察结果:(i)要确定 $c_k$,必须使用所有的部分积 $\{a_i b_j | i+j \le k\}$,因为所有这些项都对 $c_k$ 有贡献。(ii)知道中间项 $\hat{c}_k$ 就足以计算 $c_k$ 并为后续数字传播必要的信息。因此,我们使用 $\hat{c}_k$ 作为每个时间步 $t_{c_k}$ 的探测签名(第3.2节),以检查模型是否正在利用所有必要的长程信息来预测正确的通证 $c_k$。

长程依赖解读:

这段话是对上面两个公式的总结,并引出了核心观点。这里的“长程依赖”是理解本文的关键。

它到底是什么意思呢?想象一下计算 $123 \times 456$。当你要计算结果的百位数时,你需要考虑哪些东西?

你看,为了确定“百位数”这一个数字,你需要追溯到最开始“个位数”的计算。信息需要从头传到尾,这种跨越很长距离的依赖关系,就是“长程依赖”。乘法是这种依赖关系的典型例子。

作者指出,中间值 $\hat{c}_k$ 是一个完美的“检验指标”。如果模型在预测第 $k$ 位时,它脑子里已经算出了正确的 $\hat{c}_k$,那就证明它确实考虑了所有应该考虑的历史信息。反之,如果它连 $\hat{c}_k$ 都算错了,那它肯定没有建立起正确的长程依赖,只是在瞎猜。

在接下来的章节中,我们将展示ICoT模型如何满足这种长程依赖,而标准微调模型则没有。

后续内容预告:

这是一个承上启下的段落。作者已经把舞台搭建好了:定义了问题(长程依赖),找到了检验标准($\hat{c}_k$)。接下来,就要开始真正的“对比实验”,看看“学霸”和“学渣”在处理这个核心问题上到底有什么不同。

3.2 ICOT中长程依赖的证据

我们首先展示两条证据,证明ICoT模型满足多位数乘法中的长程依赖,而标准微调模型则不满足。

Logit归因。 从图1中注意到,数字 $a_i, b_i$ 只能影响 $k \ge i$ 的 $c_k$ 项。另请注意,在时间步 $t_{c_k}$,成对的乘积 $\{a_i b_j | i+j=k\}$ 对最终预测 $c_k$ 的影响最大。“更早”的成对乘积 $\{a_i b_j | i+j < k\}$ 仍然可以影响 $c_k$,但随着 $i+j$ 变小,影响会减弱。

长程依赖证据解读:

这里开始上第一条证据了,叫做“Logit归因”。这个词听起来很专业,但它的思想很简单,就是分析“功劳分配”。当模型要输出结果的第 $k$ 位数字 $c_k$ 时,我们想知道,输入的数字 $a_i, b_j$ 中,哪些对这个决策的“贡献”最大。

根据我们对乘法的理解,应该有两条规律:

  1. 相关性规律: 只有那些下标 $i, j$ 满足 $i+j \le k$ 的输入数字,才可能对 $c_k$ 有影响。比如算百位数 $c_2$,你用不到千位数 $a_3$。
  2. 影响递减规律: 对 $c_k$ 影响最大的一定是那些 $i+j=k$ 的数字对(直接贡献者)。那些 $i+j$ 更小的数字对,它们的影响只能通过“进位”来间接实现,所以影响会小一些。

接下来,作者们就要去检查,“学霸”ICoT模型是否遵循了这两条规律,而“学渣”SFT是否没有。