通过动态内存稀疏化(DMS)技术,赋予大语言模型更深邃的思考力
大家好,我是这篇研究的发起人之一。在我们的日常工作中,我们常常与大语言模型(LLM)打交道。你有没有过这样的体验:电脑上同时打开了几十个浏览器标签页,系统就变得奇慢无比?其实,LLM在进行复杂推理时,也面临着类似的“内存危机”。这个“内存”就是我们常说的KV缓存(Key-Value Cache)。
想象一下,你在解一道复杂的数学题。你需要记住题目中的所有条件、你尝试过的解法、以及每一步的中间结果。这些信息构成了你的“短期记忆”。对于Transformer架构的LLM来说,KV缓存就是它的短期记忆[1]。模型每生成一个新词(token),就需要回顾之前所有的记忆,这个过程就像你在草稿纸上反复查看之前的计算步骤。当推理链条变得很长,或者需要同时探索多种可能性时,这份“草稿纸”就会变得越来越长、越来越拥挤。
问题来了:存储和读取这份越来越大的“草稿纸”非常耗时,并且会迅速耗尽GPU宝贵的显存[3]。这就成了一个巨大的瓶颈,限制了模型进行更长、更深入思考的能力。我们称之为推理时缩放(Inference-time Scaling)的瓶颈。简单地增加计算预算,就像给你一张无限大的草稿纸,但如果你的读写速度跟不上,效率依然低下。
于是,我们开始思考一个有趣的问题:我们能否教会模型“智慧地遗忘”?不是粗暴地丢掉信息,而是像人一样,有选择地、动态地清理不再那么重要的记忆,从而为新的、更关键的思考腾出空间。这个想法,最终引导我们走向了“推理时超缩放(Inference-Time Hyper-Scaling)”以及我们为此设计的核心技术——动态内存稀疏化(Dynamic Memory Sparsification, DMS)[6]。这不仅仅是一次技术优化,更是一场赋予机器更高效、更深刻推理能力的革命。
首先,让我们直观地感受一下这个瓶颈。当LLM生成文本时,每一步都会产生一个新的Key和Value向量,并追加到KV缓存中。这个缓存的大小与生成序列的长度成正比。当模型需要生成长篇大论或者进行“思想链”(Chain of Thought)推理时,KV缓存会线性增长,导致两个主要问题:
这就像一个城市的交通系统。车辆(tokens)不断进入市中心(GPU),每辆车都需要一个停车位(KV缓存条目)。当停车位越来越多,新来的车要找到一个位置并与所有车通信(Attention计算)就变得异常缓慢,最终导致整个系统瘫痪。下面的动画生动地模拟了这一过程。
观察随着生成的Token(紫色小球)增多,KV缓存(右侧方块)被填满,新Token的生成速度(左上角指示器)如何显著下降。这直观展示了KV缓存如何成为性能瓶颈。
这就像你在准备一场大型考试前的复习。一开始,你只有几页笔记,回顾起来很快。但随着复习深入,你的笔记堆成了一座小山。每次想找一个知识点,你都得翻遍所有的笔记,速度自然就慢下来了。LLM的KV缓存就是这座“笔记山”。
面对这个“交通堵塞”,最直接的想法是“清理”掉一些不再需要的车辆。这就是稀疏化(Sparsification)的基本思想。但问题是,如何判断哪些信息“不再需要”?传统的无训练方法(training-free)通常依赖一些启发式规则,比如丢弃注意力得分最低的token。但这往往会“误伤”一些长期依赖的关键信息,导致模型性能下降[1]。
我们的DMS方法则更加聪明。我们不使用固定的规则,而是让模型在训练中自己学会一个动态的、自适应的“遗忘”策略。最关键的创新在于,我们引入了“延迟驱逐(Delayed Eviction)”机制。
这是什么意思呢?当模型在第 $t$ 步决定要“忘记”某个token时,它并不会立刻将它从缓存中删除。相反,这个token会被标记为“待删除”,并继续在缓存里停留一个固定的时间窗口(比如 $w$ 步)。直到第 $t+w$ 步,这个token才会被真正移除[1]。
这个“缓冲期”至关重要!它给了模型充足的时间来“吸收”这个即将被遗忘的token中的重要信息,并将其整合到后续的token表示中。这就像你决定要扔掉一本旧书,但在扔掉之前,你又快速翻了一遍,把里面的重点摘抄到了新的笔记本上。这样一来,既清理了空间,又没有丢失关键知识。
Token进入一个“滑动窗口”(半透明区域)。模型为每个Token做出决策(红色标记为“待驱逐”)。被标记的Token在窗口内依然可见,离开窗口后才被真正从缓存中移除。观察缓存大小如何被有效控制。
想象一下你的办公桌。你不会一用完文件就立刻扔进碎纸机。相反,你可能会把它放在“待处理”文件夹里。在这个星期内,你随时可以查阅它。一星期后,如果确实没用了,你再把它处理掉。DMS的“滑动窗口”就是这个“待处理”文件夹,它提供了一个安全、无损的整理过程。
“延迟驱逐”和“立即驱逐”的差别有多大?我们的实验给出了惊人的答案。在消融研究中我们发现,如果采用立即驱逐策略(即一旦决定就马上删除),模型的性能会随着压缩率的提高而急剧下降。而我们的延迟驱逐策略,则能在很高的压缩率(比如8倍)下,依然保持优异的性能[1]。
这背后的原因是,Transformer模型尤其关注最近的上下文。立即驱逐一个刚刚生成的token,对模型来说是毁灭性的打击,因为它还没来得及消化其中的信息。而延迟驱逐完美地规避了这个问题。下面的对比动画清晰地展示了这一点。
上方是“立即驱逐”,下方是“延迟驱逐”。观察上方的模型性能(右侧的仪表盘)在驱逐发生时如何剧烈波动和下降,而下方的DMS模型则能保持平稳和高分。这证明了延迟驱逐在保护模型推理能力上的关键作用。
现在,让我们戴上工程师的眼镜,深入DMS背后的数学原理。理解这些公式,能帮助我们更深刻地体会DMS设计的精妙之处[2]。
首先,回顾一下标准的自注意力机制。对于一个输入序列的隐状态 $h_{1:T}$,我们通过线性变换得到查询(Query)、键(Key)和值(Value):
$$ q_{1:T} = W_q h_{1:T}, \quad k_{1:T} = W_k h_{1:T}, \quad v_{1:T} = W_v h_{1:T} $$第 $i$ 个token的输出 $o_i$ 是所有值 $v_j$ 的加权和,权重 $a_{ij}$ 由 $q_i$ 和 $k_j$ 的相似度决定:
$$ a_{ij} = \frac{\exp(q_i^\top k_j / \sqrt{d_k})}{\sum_{t=1}^{i} \exp(q_i^\top k_t / \sqrt{d_k})}, \quad o_i = \sum_{j=1}^{i} a_{ij} v_j $$💡 公式解读 这里的关键是,$o_i$的计算需要访问从1到$i$的所有$k_j$和$v_j$。这就是KV缓存不断增长的根源。
DMS的核心是在每个时间步 $t$ 为新生成的键值对 $(k_t, v_t)$ 预测一个驱逐决策 $\alpha_t$。为了让这个离散的决策(保留或驱逐)在训练中可微分,我们使用了Gumbel-Sigmoid分布进行随机重参数化:
$$ \alpha_t \sim \text{Gumbel-sigmoid}(h_t w^\top + b, \tau) \in [0, 1] $$💡 公式解读
在训练中,我们如何实现“延迟驱逐”呢?我们通过构造一个特殊的注意力掩码 $M_\alpha$ 来模拟这一行为。这个掩码被加到原始的注意力分数 $QK^\top$ 上。如果一个token $j$ 在未来某个时刻被标记为驱逐,那么在它被正式“移除”之前,所有后续的token $i$ ($i > j$) 仍然可以访问它。一旦过了滑动窗口期,掩码就会将对应位置的值设为 $-\infty$,从而在注意力计算中完全屏蔽掉这个token。
这种设计非常巧妙,它允许我们在不改变模型核心架构的情况下,仅通过修改注意力掩码就教会模型复杂的延迟驱逐行为。如论文图2(b)所示,掩码的构建方式使得信息可以在窗口期内流动[1]。
我们的训练目标函数由两部分组成:
$$ \mathcal{L} = \mathcal{L}_D + \mathcal{L}_{\text{aux}} $$💡 公式解读
通过这套组合拳,DMS能够在保持高性能的同时,精准地达到我们期望的压缩目标。更令人兴奋的是,整个“再训练”(retrofitting)过程非常高效,仅需约1000个训练步骤就能达到8倍压缩,数据效率远超之前的DMC等方法[1][4]。
理论再完美,也需要实验来验证。我们在多个高难度的推理基准测试(如AIME 24数学竞赛、GPQA科学问答和LiveCodeBench编程)上,对DMS进行了严苛的考验[1]。结果令人振奋!
在性能评估中,我们最关心的是“帕累托前沿”。你可以把它理解为“性价比曲线”。在给定的计算成本(无论是时间还是内存)下,谁能达到更高的准确率,谁就更优越。我们的实验结果(如下图动画所示)清晰地表明,无论是在“准确率 vs. 运行时”还是“准确率 vs. 峰值内存”的比较中,DMS的帕累托前沿都全面超越了原始的Vanilla模型以及其他先进的稀疏化方法(如Quest和TOVA)[1]。
这意味着,对于任何给定的计算预算,使用DMS的模型都能比原始模型“思考”得更长或更广,从而取得更高的分数。例如,在Qwen-R1 32B模型上,DMS在AIME 24上平均提升了9.1分,在GPQA上提升了7.6分[1]。这证明了推理时超缩放的巨大潜力。
动画展示了在“准确率”与“计算成本”的二维平面上,不同方法(紫色-Vanilla, 黄色-Quest, 绿色-DMS)的性能点。DMS形成的帕累托前沿(绿色实线)明显优于其他方法。将鼠标悬停在DMS的点上可以查看具体的超参数配置。
另一个让我们非常自豪的成果是DMS的数据效率。之前的一些需要训练的压缩方法(如DMC)虽然有效,但需要大量的训练数据和时间,成本高昂。而DMS得益于其更简单的驱逐机制和延迟策略,训练过程极其高效。
我们的实验表明,DMS达到同等甚至更好性能所需的训练数据,比DMC少了一个数量级。具体来说,我们用比DMC少8倍的训练token,就获得了比它更强的性能[1]。这大大降低了将现有LLM改造为高效推理模型的门槛,使其变得非常实用。
动画模拟了训练过程。横轴是训练数据量,纵轴是模型性能。观察DMS(绿色线)如何用更少的数据量,就迅速达到了比DMC(蓝色线)更高的性能水平。
回顾我们的研究历程,从最初对KV缓存瓶颈的困惑,到提出“延迟驱逐”这个核心创意的欣喜,再到看到实验结果中那条漂亮的绿色帕累托曲线时的激动,每一步都充满了挑战与回报。
我们相信,推理时超缩放为提升大语言模型能力提供了一个全新的维度。它告诉我们,除了不断增大模型参数,我们还可以通过优化推理过程,在同样的硬件上榨取出更多的“智能”。而动态内存稀疏化(DMS)正是实现这一目标的一把经济、高效且强大的钥匙。
它不仅仅是一种压缩技术,更是一种让模型学会“专注”与“遗忘”的智慧。通过赋予模型动态管理自己“短期记忆”的能力,我们让它能够在有限的资源下,进行更复杂、更深入的思考。我们由衷地希望,这项工作能够启发更多同行,共同推动LLM向着更高效、更强大的未来迈进,最终将这些前沿技术,转化为每个人都能受益的工具。
感谢您的时间和关注,我们的探索之旅,未完待续……