摘要
现代语言模型(LMs)的能力日益强大,但在处理看似简单的多位数乘法任务时却屡屡碰壁。这项工作旨在揭示这一现象背后的深层原因。我们通过逆向工程一个借助“隐式思维链”(Implicit Chain-of-Thought, ICoT)成功学会乘法的模型,得出了三项关键发现。首先,我们找到了模型掌握长程依赖结构的关键证据:通过分析模型对输入的敏感度(logit attributions)和使用线性探针技术,我们证实了ICoT模型能够成功编码多位数乘法所必需的长距离数字间依赖关系,而标准微调(SFT)模型则不然。其次,我们揭示了其内部机制:ICoT模型利用其注意力头构建了一个有向无环图,这个结构如同一个高效的计算网络,能够“缓存”中间计算的两位数乘积(即“部分积”),并在后续步骤中精确地“检索”它们以合成最终答案。最后,我们探究了其特征的几何表示:模型将数字嵌入到一个傅里叶基构成的空间中,形成了一个独特的“五角棱镜”结构,并且其注意力头通过计算数字嵌入的“闵可夫斯基和”来实现部分积的运算。这些都是标准微调模型所缺乏的、直观且高效的表征方式。基于这些洞见,我们重新审视了标准微tuning的学习动态,发现它会陷入一个局部最优解,无法建立必要的长程依赖。为验证这一理解,我们设计了一种辅助损失函数,通过监督模型预测一个“累加和”来引入归纳偏置,成功地引导标准模型学会了多位数乘法。总而言之,通过解构一个成功的模型,我们不仅阐明了Transformer在学习长程依赖任务时的一个关键陷阱,也提供了一个实例,证明了正确的归纳偏置能够如何克服这一难题。
大家好,我是这篇研究的作者之一。今天,我想带大家进行一次深入的旅行,探索一个看似矛盾却又极其迷人的问题:为什么像GPT-4这样强大的AI,能写诗、能编程,却在小学乘法上“翻车”?这不仅仅是一个关于计算的问题,它更像一扇窗,让我们得以窥见这些复杂模型的“心智”世界,理解它们学习、思考乃至“犯错”的方式。
我们的故事始于一个简单的观察。当你给一个标准的、经过海量文本训练的Transformer模型(我们称之为SFT模型)一堆乘法题,比如 \(1234 \times 5678\),并告诉它正确答案,希望它能学会这个算法时,结果往往令人失望。它的准确率低得可怜,甚至连拥有数百亿参数的庞然大物也难以幸免。然而,我们发现,一种叫做“隐式思维链”(ICoT)的特殊训练方法,却能奇迹般地让模型100%掌握这项技能。这就像两个学生,一个死记硬背却不得要领,另一个则掌握了举一反三的学习方法。我们的核心任务,就是解剖这个“优等生”的大脑,看看它到底做对了什么。
第一幕:乘法的“灵魂”——长程依赖
要理解模型的失败,我们首先要理解乘法本身的挑战。它不是一个简单的“看一眼就出答案”的任务。计算一个多位数乘法的最终结果,尤其是中间的几位数,需要你综合考虑几乎所有输入数字之间的相互作用。让我们以一个简单的 \(a_1a_0 \times b_1b_0\) 为例:
\[ c_1 = ( (a_1 b_0 + a_0 b_1) + \lfloor \frac{a_0 b_0}{10} \rfloor ) \pmod{10} \]看到吗?为了得到结果的第二位数 \(c_1\),你不仅需要计算 \(a_1 \times b_0\) 和 \(a_0 \times b_1\),还需要考虑来自前一位 \(a_0 \times b_0\) 的进位。对于一个4x4的乘法,要计算出结果的中间数字 \(c_3\) 或 \(c_4\),这个依赖网络会变得异常复杂,像一张错综复杂的蜘蛛网。我们把这种必须回顾和整合遥远输入信息才能做出正确决策的特性,称为“长程依赖”。这正是Transformer这类模型的“阿喀琉斯之踵”。
静态图1:4x4乘法算法的内在结构
这张图直观地展示了计算每一位输出(\(c_0\) 到 \(c_7\))所需的部分积(如 \(a_ib_j\))和进位(\(r_k\))。注意看,计算靠后的数字(如 \(c_4\))需要综合前面所有相关计算的结果,形成复杂的依赖链条。
那我们的“优等生”ICoT模型是如何驾驭这张网的呢?我们用了两种方法来“窃听”它的心声。第一种是Logit归因,通俗讲,就是我们微调输入的一个数字,比如把1234的‘3’改成‘2’,然后观察模型对最终答案每一位的预测信心(logit)有多大变化。结果惊人:ICoT模型表现得像一个专业的会计,当输入数字 \(a_i\) 或 \(b_j\) 改变时,所有依赖于它们的输出位 \(c_k\)(其中 \(k \ge i+j\))的信心都会相应调整。而SFT模型则像个近视眼,一个数字的改变,通常只会影响到它附近一两位的结果,远处的依赖关系它完全“看不见”。
动画1:长程依赖 vs. 短视依赖
这个动画模拟了信息流在两种模型中的传播。左侧的ICoT模型中,来自输入(底部)的信息粒子能通过复杂的路径网络到达所有相关的输出(顶部)。右侧的SFT模型中,信息粒子很快就“迷路”或消散了,无法建立远距离连接。点击“开始/暂停”观察动态,点击“重置”重新开始。
状态: 待开始
第二种方法是线性探针。我们假设,如果模型真的理解了乘法,那么在它计算出 \(c_k\) 之前,其内部的某个神经激活状态(我们称之为隐藏态)里,一定已经算好了一个中间值,我们把它记为 \(\hat{c}_k = s_k + r_{k-1}\),它包含了计算 \(c_k\) 所需的全部信息。于是,我们训练了一个简单的线性“读心器”(探针),试图从模型的隐藏态中直接读出这个 \(\hat{c}_k\) 值。结果再次证实了我们的猜想:从ICoT模型的隐藏态中,我们可以非常精确地解码出 \(\hat{c}_k\),而从SFT模型中读出的则是一堆噪声。这雄辩地证明了,ICoT模型确实在内部构建了正确的、长程的计算结构。
第二幕:注意力的“杂技”——构建动态计算图
那么,ICoT模型究竟是怎样用它的“神经元”——也就是Transformer的注意力机制——来搭建这个复杂的计算网络的呢?答案既优雅又巧妙:它学会了构建一个动态的、树状的注意力图。
想象一下,模型在计算的每一步,都需要从输入数字中挑选出正确的一对来进行相乘(例如,计算 \(s_2\) 时需要 \(a_2b_0, a_1b_1, a_0b_2\))。ICoT模型的做法是:
- “缓存”部分积:在第一层网络中,不同的注意力头会各自负责盯住一对输入数字(比如一个头盯住 \(a_i\) 和 \(b_j\))。然后,它会把这两个数字相乘的“想法”(即计算结果)“写入”到序列中某个较早的位置(比如某个分隔符 ‘%’ 的位置)的隐藏态里。这就像在草稿纸上记下了一个个两位数的乘积,以备后用。
- “检索”与求和:在第二层网络中,当需要计算最终结果的某一位 \(c_k\) 时,注意力头们会非常精准地回头去“看”那些之前缓存了相关部分积(所有满足 \(i+j=k\) 的 \(a_ib_j\))的“草稿纸”位置。它把这些信息收集起来,再加上前一步的进位信息,最终计算出 \(c_k\)。
动画2:注意力树的构建
这是一个动态的图可视化。节点代表输入/输出数字和中间的“缓存”位置。点击下方的输出数字(如 c₂),动画会高亮显示为了计算它,模型的注意力(连线)是如何在第一层(紫色)选择输入对进行“缓存”,又如何在第二层(青色)从缓存中“检索”信息。这是一个高度简化的、概念性的演示。
当前计算: c₂
整个过程就像一个训练有素的杂技团,成员们(注意力头)通过精确的抛接(注意力的读写),将信息在正确的时间传递到正确的位置。而SFT模型则像一群毫无章法的演员,注意力分散,无法形成有效的协作,自然也就无法完成复杂的计算任务。
第三幕:特征的“几何学”——从闵可夫斯基和到五角棱镜
如果说注意力机制是模型的“算法”,那么它用来表示数字的“语言”——也就是特征的几何结构——则同样充满了智慧。我们发现,ICoT模型在两个层面展现了惊人的数学之美。
部分积的几何表示:闵可夫斯基和
当一个注意力头只关注两个数字 \(a_i\) 和 \(b_j\) 的嵌入向量时,它的输出向量在几何上会形成一个有趣的结构,称为闵可夫斯基和。简单来说,就是将代表 \(a_i\) 的所有可能向量和代表 \(b_j\) 的所有可能向量进行两两相加。如果我们将所有数字的嵌入向量用PCA降维到三维空间,会发现它们各自形成一簇点云。当模型计算 \(a_i \times b_j\) 时,它实际上是在这个高维空间中,将 \(a_i\) 的点云簇和 \(b_j\) 的点云簇进行“几何叠加”,形成一个新的、更大的点云簇,这个新簇的位置和形状就编码了乘积的信息。更神奇的是,这种表示是嵌套的:在一个代表 \(b_1\) 值的“大簇”内部,又会整齐地排列着代表 \(a_0, a_1, a_2, ...\) 的“小簇”,结构井然有序。
动画3:探索闵可夫斯基和
拖动滑块选择两个数字 \(a_i\) 和 \(b_j\)。左侧和中间的视图分别展示了代表这两个数字的点云簇(为简化,每个数字用一组随机点表示)。右侧视图展示了将这两个点云簇进行闵可夫斯基和运算后得到的结果。观察结果簇的形状和位置如何随着输入数字的变化而变化。
数字的宇宙:傅里叶基与五角棱镜
最令人震撼的发现,来自于我们对模型最终隐藏层输出的几何分析。当我们把模型在输出每一位结果前一刻的“思考”(即隐藏态向量)进行PCA降维可视化时,一幅奇异的景象出现了:代表0到9这十个数字的点,竟然在三维空间中完美地排列成一个五角棱镜!
这并非巧合。通过进一步的数学分析,我们发现模型自发地学会了使用傅里叶基来表示数字。这就像音乐家用不同频率的正弦波和余弦波来合成复杂的音色一样,模型用一组特定的周期性函数来“描绘”数字。具体来说:
- 第一主成分(PC1),也就是最重要的那个轴,编码了数字的奇偶性。所有的偶数(0, 2, 4, 6, 8)被投射到棱镜的一端,所有的奇数(1, 3, 5, 7, 9)则在另一端。
- 第二和第三主成分(PC2, PC3)则由一组频率为2的傅里叶基(即 \(\cos(2\pi n/5)\) 和 \(\sin(2\pi n/5)\))构成。这使得偶数和奇数各自在一个完美的正五边形上排列。
于是,两个平行的、分别由偶数和奇数构成的五边形,沿着代表奇偶性的轴堆叠起来,就形成了这个美丽的五角棱镜。这种表示方式极其高效,因为它利用了数字运算中的周期性和对称性,是模型为了解决乘法问题而自发演化出的最优“编码语言”。而SFT模型的数字表示则是一片混乱,毫无结构可言。
动画4:旋转的数字棱镜
这是一个可交互的3D场景。你可以用鼠标拖动来旋转这个由数字0-9构成的五角棱镜。观察奇数和偶数如何分别构成两个平行的五边形。这种结构揭示了ICoT模型对数字内在属性的深刻理解。
第四幕:打破僵局——用归纳偏置“点化”愚者
既然我们已经洞悉了“优等生”的秘密,那么能否用这些知识去帮助那个“差生”SFT模型呢?答案是肯定的。我们发现SFT模型学习失败的根本原因在于,它的学习过程陷入了局部最优。在训练初期,它很快就学会了计算最简单、依赖关系最近的几位(如 \(c_0, c_1\) 和 \(c_7\))。因为这些“简单题”能快速降低损失函数,模型就满足于此,不再有动力去探索和建立计算中间数字所需要的复杂长程依赖结构。梯度(学习信号)在中间数字上始终无法有效地传播,导致损失居高不下,学习停滞不前。
静态图2:ICoT训练方法的精髓
ICoT的训练并非一步到位。它开始时会给模型展示完整的计算步骤(思维链),然后像一位耐心的老师,在每个阶段(Epoch)逐渐擦掉一些步骤,迫使模型将这些计算过程“内化”到自己的参数中,最终在没有任何外部辅助的情况下独立完成任务。
为了打破这个僵局,我们设计了一个简单的“外挂”:一个辅助损失函数。我们不再仅仅要求模型预测最终的正确答案,而是增加了一个额外的任务:在计算每一步 \(c_k\) 时,都必须通过一个线性探针准确地预测出我们之前提到的那个关键中间值 \(\hat{c}_k\)。这相当于我们给了模型一个明确的“路标”,告诉它:“嘿,别只盯着最终结果,你必须先把这些中间步骤想清楚!”
这个小小的改动,我们称之为引入了正确的归纳偏置,效果立竿见影。被“点化”了的SFT模型,学习动态发生了根本性的改变。它不再短视地卡在局部最优,而是被迫去建立那些对于计算 \(\hat{c}_k\) 至关重要的长程依赖。最终,它成功地学会了4x4乘法,准确率达到了99%!
动画5:学习的僵局与突破
此动画展示了两种训练方式下,模型对各个输出位 \(c_k\) 的损失(错误程度)随时间的变化。左侧是标准SFT,你会看到中间几位(黄色、绿色)的损失一直降不下去。右侧是增加了辅助损失的模型,所有位的损失最终都平稳地收敛到零。这是一个加速和概念化的过程。
SFT 中间位损失: 高 | 辅助损失模型 中间位损失: 高
终章:超越乘法,洞见未来
所以,为什么Transformer学不会乘法?我们的旅程给出了答案:标准的训练方法(自回归损失+梯度下降)并不足以引导模型自发地发现和构建解决此类任务所需的复杂、长程的算法结构。模型容易被“捷径”所迷惑,陷入无法捕捉全局依赖的局部最优解。
我们的研究,通过解剖一个成功的ICoT模型,揭示了模型内部实现复杂算法所需的三大支柱:正确的长程依赖结构、高效的动态计算图(注意力树)、以及优雅的几何特征表示(傅里叶基与五角棱镜)。更重要的是,我们证明了通过引入恰当的归纳偏置(如我们的辅助损失),可以有效地引导模型克服这些学习障碍。
这不仅仅是关于乘法的故事。它揭示了当前AI模型在进行严谨、多步推理时面临的普遍挑战。未来,要让AI真正成为可靠的思考者和问题解决者,我们或许需要更多类似ICoT或者辅助损失这样的方法,去“教”会它们如何思考,而不仅仅是模仿。这趟深入模型“心智”的旅程,让我们对人工智能的未来充满了更多的期待与思考。
静态图3:依赖结构的“X光片”
这张图简化展示了Logit归因的结果。热力图的每个格子代表改变某个输入数字(纵轴)对某个输出数字(横轴)预测信心的影响强度。ICoT模型(右)呈现出清晰的、沿对角线扩展的结构,证明了其建立了正确的长程依赖。SFT模型(左)的影响则非常局部化,结构杂乱。
技术细节附录
关于ICoT训练
“隐式思维链”(ICoT)训练是一种知识蒸馏技术。其核心思想是,先让一个强大的教师模型(或通过规则生成)为每个问题提供详细的、分步骤的解答,即“显式思维链”(CoT)。然后,我们用这些带有详细过程的数据来训练一个学生模型。关键在于,训练过程是分阶段的:在每个阶段,我们都会从思维链的开头移除一部分步骤,迫使学生模型必须在没有这些外部提示的情况下,自己“脑补”出这些计算过程,并将它们“内化”到模型的参数中。经过多个阶段的训练,最终思维链被完全移除,学生模型学会了仅根据原始问题就独立地、在内部完成所有必要的推理步骤,从而得到正确答案。
傅里叶基作为数字表征
模型选择傅里叶基来表示数字0-9是一种非常高效的策略。一个数字 \(n\) 的向量表示 \(v_n\) 可以被看作是多个基础周期函数的线性组合: \[ v_n = \sum_{k} C_k \cdot \Phi_k(n) \] 其中 \(\Phi_k(n)\) 是傅里叶基函数,例如 \(\cos(2\pi kn/10)\) 或 \(\sin(2\pi kn/10)\),而 \(C_k\) 是对应的系数向量。我们发现,模型主要使用了 \(k=0, 1, 2, 5\) 这几个频率。
- \(k=0\) 对应直流分量(一个常数),提供了基础的偏移。
- \(k=5\) 对应 \(\cos(\pi n) = (-1)^n\),完美地编码了奇偶性。
- \(k=2\) 对应 \(\cos(2\pi n/5)\) 和 \(\sin(2\pi n/5)\),这个基函数对模5运算很敏感,这在十进制算术中非常有用,因为它能捕捉到“逢五进一”的规律。
这种分解方式让模型能够轻易地通过线性操作来提取数字的算术属性(如奇偶性、模5的余数等),为后续的乘法和加法运算提供了极大的便利。这证明了模型在无监督的情况下,能够发现并利用深刻的数学结构来优化其内部表征。
辅助损失函数
我们引入的辅助损失函数 \(\mathcal{L}_{aux}\) 与原始的语言模型损失 \(\mathcal{L}_{LM}\) 相结合,形成最终的总损失 \(\mathcal{L}\): \[ \mathcal{L} = \mathcal{L}_{LM} + \lambda \mathcal{L}_{aux} \] 其中 \(\lambda\) 是一个超参数,用于平衡两个任务的重要性。辅助损失的具体形式是一个均方误差(MSE)损失: \[ \mathcal{L}_{aux} = \frac{1}{N} \sum_{i=0}^{N-1} (w^T h_i - \hat{c}_i)^2 \] 这里,\(h_i\) 是模型在预测第 \(i\) 位输出时的隐藏状态,\(w\) 是一个可训练的线性探针权重,而 \(\hat{c}_i\) 是我们期望模型计算出的真实中间值(部分积与前一进位的和)。通过最小化这个损失,我们强制模型在每个时间步的隐藏状态中都必须线性可分地表示出正确的中间计算结果,从而引导它学习到正确的算法路径,而不是陷入只关注表面预测的局部最优解。