本文翻译自 Building a Minimal Transformer for 10-digit Addition,原载于 Hacker News。
挑战的由来
我在 Twitter/X 上看到一篇帖子,讨论如何训练一个参数少于 1000 的模型来做 10 位数加法。帖子里提到 Claude 和 Codex 能够独立完成这个任务,认为这是一件非常令人印象深刻的事情。
坦白说,我并没有觉得那么震撼。一方面是因为我见过这些模型能做到的更惊人的事情;另一方面,我觉得自己可以做得更好。粗略估算一下,我认为大概能用 100 个参数甚至更少来完成这个任务。于是我开始思考最佳的实现方式……
核心思路
了解 Transformer 架构的人都知道,注意力机制(Attention)自然是用来做累加操作的不二之选。你有 queries、keys 和 values,理想情况下每个都是一维的,这样就不会增加额外的参数。如果你只有一个 value,最自然的选择就是让它成为累加和。
要实现这个目标,你需要:
- 数字本身
- 10 的幂次
- 而且这些幂次需要以可累加的格式呈现
你需要把数字乘以 10 的幂次。实际上不需要精确的 10 的幂次——只需要保持正确的相对比例,然后再选择一个任意绝对比例即可。
完整代码可以在这里找到,还有一个排行榜记录了各路选手的提交。
设计原则:什么才算”合理的 Transformer”?
我想要的是一个合理的 Transformer——如果有人看到 ONNX 文件,会说”对,这就是个 Transformer”。可能会有一些奇怪的地方,毕竟我们在追求极低的参数量,但不能瞎搞。据我所知,没人会把 RoPE 应用到 values 上。也许某些超冷门的论文里做过这样的消融实验,但这真的不是常规操作。所以我放弃了这条路。
更广泛地说,以下是我遵循的标准:
什么算参数? 这显然是个见仁见智的问题。全零矩阵不算参数——如果一个张量里全是固定的零,那就不应该计数。零不是参数,就像无穷大不是一个数一样。单位矩阵也很自然,我觉得那也不应该算作参数。
什么在架构范围内? 某些稍微越界的做法可能是合理的,比如每层使用不同的维度。但我们不应该做一些临时性的操作,比如层与层之间使用不同的激活函数,或者不同形式的前馈网络/注意力机制。我们要保持在让看 ONNX 文件的人不会完全毛骨悚然的范围内。
我使用了 ReGLU、ALiBi(只有一个 head 有非零斜率,其他 head 相当于零斜率)、双精度浮点数、每层不同的模型维度和 head 数量、无/隐式单位矩阵的 lm_head。
当然,我们的目标是尽可能少的参数。如果用稠密矩阵计算所有非零参数,是 95 个;忽略单位矩阵的话是 36 个。如果复用输入嵌入的一个维度来表示候选数字(包括缩放和偏移),可以降到 28 个。如果进一步追求”标准架构”,用 RoPE 代替 k 值来选择,可以降到约 22 个;如果不计算嵌入参数,甚至可以降到约 12 个。
你大概还可以通过各种技术细节和 hack 手段进一步压缩。我的直觉是,对于这种结构化序列,结合使用 alibi、rope、softmax1,你可以通过让某些东西变成固定模式而不计入参数,从而再减少几个注意力参数。
数据格式
Codex 的解决方案反转了数字顺序,这让进位逻辑更容易,但不够优雅。所以我更喜欢在前面前面补零。格式如下:两个 10 位数字输入,一个 11 位数字输出。示例:
\[7650676663+0149460439=07800137102\]RoPE 的弯路(以及为什么放弃)
为了获取 10 的幂次,我一开始考虑使用旋转位置编码(RoPE)。你可以做旋转,如果考虑非常小的旋转角度,它们之间基本是线性间隔的。有了线性间隔的数字,你可以重新缩放或添加偏置,如果你使用 ReGLU,sigmoid 的某些区域——主要是非常负的区域——近似于指数函数,所以你可以通过那条路到达目标。
但这种方法需要把 RoPE 应用到 values 上,这完全不自然。这带我到了一个关键的设计原则:虽然你也可以用 RoPE 来有效地掩盖位置以便做注意力计算,但我想关注的是参数而不是超参数。
用 ALiBi 实现指数衰减
放弃了对 values 应用 RoPE 之后,我改用 ALiBi。ALiBi 非常适合实现指数衰减,而这正是我们想要的。我们需要递减的 10 的幂次:第一位是 1e9,一直到 1e0。然后在看到加号后重置,在看到等号后再次重置,这时我们开始生成输出。我们需要对前两个数字的值进行适当归一化后求和。所以我们的注意力输出直接就是序列长度上的累加和。
嵌入策略:保持模型维度小
如果我们想方便地获取数字的实际值,可以用 one-hot 嵌入——但这已经是 100 个参数了,超过我们的目标。即使单位矩阵”不应该真的算数”,有一个大的模型维度感觉也挺浪费的。所以相反,我们让第一个维度就是数字的值,对于其他所有东西(运算符、特殊 token),那个维度为零。如果你想让注意力值直接抓取那个值,这是一个很方便的表示。记住:我们的 value 投影里会有一个 1,基本上就是抓取嵌入的那个维度。
使用额外维度处理控制 Token
我们可以自由使用其他维度,所以用它们来表示 BOS(序列开始)、加号和等号——遇到这些 token 时对应维度为 1。理论上我们可以在同一个维度上用 -1 作为标志值,但用独立维度更简单。我们还有一个维度只在遇到 = 时为 1,原因很快就会清楚。
均值 vs. 求和的问题
如果你想做注意力操作,可以做均值,但不能真正做求和。因为 softmax:它被归一化到 0 到 1 的范围内。你可以在这个范围内做不同的加权,但超出这个范围就更困难了。我们要的是和,但得到的是均值。这是个问题。
不过,有一个非常简单的解决方案。如果 queries 和 keys 在所有位置都相等,而且你在 BOS 位置有一个 1,那么每个位置获得相等的权重。你不能直接得到 N(序列中的位置),但你可以很容易得到 $\frac{1}{N}$——或者其他任何东西除以 N。
所以如果我们有均值,我们不能真的乘以 N(我们没有好的乘法机制),但如果我们在操作其他东西,我们可以把它除以 N,这能帮我们前进。
缩放因子和补零格式
我们要做的是有一个缩放因子,就是递减的 10 的幂次,并与 BOS、加号和等号 token 关联的特殊值。方便的是,如果我们使用补零格式,这些都相隔 11 个位置。所以我们甚至不需要用 1/N——我们知道确切想要的缩放。
有了 ALiBi 提供的指数衰减,我们可以用那些值来做重置,并适当地重新归一化 10 的幂次。
计算输出:累加和和数字选择
对于等号后面的部分,我们实际上想要一个累加和,表示我们有多少以及还剩多少。实际上我们会把等号后面的数字位算作负数,因为我们想让任务保持一致:”写出剩余的数字。”
我们要这样做:我们有候选值(数字——我们很容易获取它们),然后用归一化因子(缩放)调整它们。然后我们看我们有的值与数字加 0.5 之间的差值。
为什么加 0.5?如果我们想最小化差值,我们实际上想对每个数字做向下取整操作(最后一个除外,稍后讨论)。我们想找到不会超过该数字的最大数字位。给数字加 0.5 让我们把任务重新表述为找到与这些值的最小差值,这简化了问题。
用 ReLU 计算差值
如果我们能用适当的缩放计算这些差值,有多种方法可以继续。对于 ReLU 网络,一种方法是使用固定的上投影(就设为 1 或任何非零常数),并让门投影对应于误差和误差的负值(权重为 1 和 -1),然后将这些值相加得到绝对误差。
你也可以通过使用 ReLU 的线性区域来做类似平方的操作。有了平方方面,你可以做一种半空间的 N 平方,以及各种其他事情。
从差值到输出 Logits
这就是整体策略。一旦我们有了这些差值,我们要选择差值最小的项。最简单的方法是:不是取差值的绝对值,而是取差值的负绝对值,并直接把它们当作 logits。
你不需要单独的输出投影,因为你已经设置好了,使得最大的 logit 对应你想要的数字——除了最后一个位置需要向上取整。为了最小化残差的影响,我把这些值放大了很多。你的损失函数可能关心大幅度 logits,但我只关心正确的排序。
剩余部分:归一化和累加和
从这里我们只需要归一化因子和数字的累加和。如前所述,累加和不能直接做——我们得到的是累加均值。要用它作为和,我们获取 $\frac{1}{N}$ 并用它作为我们要比较的东西的缩放因子(在这种情况下,是用来计算误差的候选数字)。
为什么用 Softmax1:解决 1/N 问题
用 ALiBi 会更容易,因为你有指数衰减。但对于完全标准的注意力,它表现不太好——你在添加额外的东西并对其做指数衰减,所以你得不到纯粹的指数衰减,因为你要么在添加以不同方式影响不同序列位置的东西,要么不添加,输出就是恒定的注意力权重到该位置。
这实际上是 Softmax1 非常干净地解决的问题。关键观察是标准 Softmax 本身没有正确归一化,而 Softmax1 是:
\[\frac{e^{x_i}}{1 + \sum_j{e^{x_j}}}\]我们的案例有一个很好的例证:我们有一个与第一个元素关联的值和 key,我们希望它的权重是指数衰减的。但它需要有东西作为参照,功能上必须是其他元素的注意力分数,这些分数也受衰减影响。所以最终看起来完全不同——一个更不合理的基础来获得指数衰减。
关于 $\frac{1}{N}$,有一些困难,你得不到干净的 $\frac{1}{N}$,因为 Softmax1 中的那个 1。我们实际上是在除以 N + 1,但因为我们同时对均值和缩放因子都这样做,所以实际上没有影响。
所以我们要做的就是使用 Softmax1。通过让一个 token 的权重为 1,所有其他为 0,并让所有 keys 和 values 等于零,我们可以得到 $\frac{1}{N}$。
然而,对于平均用于重新缩放幂次的特殊 token,我们想要的行为是 1->衰减, 1->衰减, -1->衰减,均匀地。但是,如果我们从大约 0 的分数开始,我们的衰减将开始时近似线性,导致 10 的幂次的值不好,而且由于早期的贡献会非常大,我们承担不起哪怕很小的误差,所以我们使用 -30 的大负偏置来保持衰减非常接近完美的指数。
使用双精度
因为我们想能够直接在一个激活中表示和,我们需要尾数中至少有 11 位十进制数字的精度,32 位浮点数只能处理大约 7 位,所以我们使用双精度。
反思
这是一个有趣的”极客陷阱”。我还没准备好挂冠让 AI 们独享这类乐趣。我自己写了一个完整的实现,几乎没有 AI 参与,基本上只是 Copilot 的样板代码。在某个时候我让 Claude Code 帮我调试,令我惊讶的是,我不记得它实际上解决了任何 bug,它似乎更关心”纠正”我有意做的那些奇怪事情。
当然 Transformer 是很有表达力的。我不认为这是 Transformer 真正能够通过学习获得的东西,部分原因是 logits 在训练中几乎肯定不稳定,更不用说 ALiBi head 和实际上没人在双精度下训练的事实。可能有类似的东西是可以学习的,这个事实本身就很有趣,但差距也同样有趣——只有某些算法可以通过梯度下降学习。我可以手写这个,但写不出写诗的那个。有没有另一种方法从一个到另一个(学习的 vs. 编码的)。
要点总结
- 参数计数是一门艺术:零不算参数,单位矩阵也可以不算,关键在于如何定义”合理”
- ALiBi 比 RoPE 更适合指数衰减:在需要精确控制位置权重时,ALiBi 的数学特性更有优势
- Softmax1 解决了经典 Softmax 的归一化问题:提供了更干净的 1/N 行为
- 手写算法 vs 学习算法之间存在鸿沟:有些东西可以编码但无法通过梯度下降学习
- 双精度浮点数在某些场景下是必要的:当需要 11 位十进制精度时,32 位浮点数不够用