LLM 真的能学会算术吗?
受 DeepSeek R1 采用强化学习(RL)训练的成功启发,我决定重新探索 RL 技术。四年前,AlphaGo 战胜人类棋手时,我曾对 RL 产生极大兴趣。然而,尽管 OpenAI 早在几年前就在 ChatGPT 训练中应用了 RL(尤其是 RLHF),我始终未能完全理解 RL 在 LLM 训练中的具体应用方式。
几年前,我曾在树莓派(Raspberry Pi)驱动的 Donkey 机器人小车上尝试过 RL 训练,但效果不佳。尽管我对 RL 的潜力有所领悟,但也深知其局限性。
在阅读了 DeepSeek 相关论文后,我对他们发明的 GRPO(Group Relative Policy Optimization) 深感兴趣。GRPO 通过消除价值网络,大大简化了 RL 方法。我随后阅读了 huggingface/trl 和 tinyzero 等项目的 GRPO 代码。虽然核心代码相对简单,但大量的支撑代码(scaffold code)使实现过程变得繁琐。正如科学家所言,理解世界最好的方式就是亲手构建它。
因此,我决定自己复现 GRPO,并将其构建在我两年前开发的 nanoGPT-rs 项目之上。nanoGPT-rs 是 nanoGPT 的 Rust 复刻版。通过从零实现 LLM,我能够彻底理解基于 Transformer 的架构。此外,我非常喜欢 dfdx 这个 Rust 机器学习框架,它类似于 PyTorch,但能够在编译时静态检查张量形状,确保代码正确。这种强制性约束帮助我更好地理解神经网络组件之间的交互关系。这也是 Rust 与 Python 在开发体验上的一大不同之处。遗憾的是,dfdx 目前的开发已经停止了。
实验目标
我的目标是训练 LLM 学习简单的数字加法。这看似是一个简单的任务,但实际上并不容易。即便是最先进的模型(如 GPT-4),有时仍然会在基本算术运算上出错。而且,我们仍不清楚 LLM 究竟是在学习算术规则,还是仅仅在记忆训练数据中的加法结果。
为了简化任务,我的模型仅需预测加法的运算结果,而不涉及其他文本生成。我采用 Python 脚本生成训练数据,包含:
- 两个 1 位数相加(如 2+1=3)
- 两个 2 位数相加(如 11+12=23)
- 三个 1 位数相加,并加入中间步骤(如 1+1+1=2+1=3)
模型的输入格式如下:
2+1=
11+12=
1+1+1=
实验结果
- 经过预训练(pre-training),模型可以在少量训练周期内迅速达到高准确率,但随后精度停滞。
- 继续采用 RL 训练后,准确率有所提升。
- 即使准确率接近 100%,模型仍可能在未见过的样本上出错。
- 这表明 LLM 并没有真正学会加法规则,而是通过某种近似方法进行推理。但这种方法与人类的数学推理方式完全不同。
我的思考与收获
1. LLM 不会“自动”学会算术规则
LLM 不会自然地学习算术法则,而是通过某种方式近似计算结果。这种方法不一定符合人类的算术规则,因此 LLM 的运算结果并非 100% 正确,且在某些情况下可能不可预测。
2. 过拟合是一个严重问题
对于简单的加法任务,LLM 很容易过拟合训练数据。因此,使用小型模型可以减少过拟合,并且 RL 训练能够促使模型学习底层规律,而不是死记硬背结果。
3. RL 训练依赖大量试错
RL 训练需要随机试错来找到正确答案,因此它通常需要比预训练更多的计算时间。然而,在足够长的训练时间后,RL 能够进一步提高模型的准确性。
4. 最优策略:先预训练,再强化学习
先预训练,再进行 RL 训练是更有效的方法。
- 预训练可以让模型掌握基本规则,例如确保输出为数字,从而减少 RL 训练的探索空间,加快收敛速度。
- 但如果预训练时间过长,模型可能会过拟合训练数据,导致 RL 训练时难以探索新的规则。如果模型在预训练阶段学错了算术规则,它将很难遗忘错误的规则,并学会正确的规则。这是由于神经网络的局部梯度下降特性,使其难以跳出局部最优解。
5. 数据多样性对于泛化至关重要
训练数据的多样性至关重要,它能够防止过拟合。模型越大,对数据的多样性要求也就越高。如果一个模型过早对特定模式产生自信,它会停止探索,从而无法泛化。
好的学习曲线应该在中期有一个突增。如果模型的准确率是线性增长的,那么它很可能只是在过拟合,而无法真正泛化。这也是为什么只有非常大规模的模型,在海量数据训练下才能实现泛化。
这一点也适用于人类:如果一个人接受的知识范围过于狭窄,他们往往会过早地对自己的观点产生极端自信,而难以接受新观点。
6. LLM 能否学会“怀疑”自己?
真正的科学突破往往需要质疑已有知识体系。然而,当前 RL 训练(如 PPO)强调渐进式参数更新,限制了模型对已有知识的彻底推翻。
在科学史上,哥白尼、伽利略、牛顿等人都是质疑旧理论、打破既有框架,才实现了革命性的发现。但 LLM 只能不断累积复杂的近似规则,而无法主动推翻错误的推理方式。这就像历史上天文学家在地心说的框架下,不断添加复杂的数学模型来解释行星轨道,而无法轻易接受日心说的简单理论。
7. LLM 仅靠数学数据能学会数学吗?
数学看似是一个封闭体系,但人类的数学理解依赖于现实经验。比如 9+2=11,理解这一概念需要:
- 知道“9”和“2”是数字
- 理解“+”表示加法
- 理解 11 是 10 之后的下一个数
- 认识到 9+2 = 2+9(交换律)
孩子学习算术时会用手指数数,而 LLM 没有这样的现实经验。它必须依靠完全不同的方式来学习算术,这可能永远无法与人类的方式对齐。
结论:LLM 仍然是一个不确定的黑箱
我的实验表明,无论模型大小,当前 LLM 架构都无法真正学会算术规则。它们只能通过未知的近似方法进行计算,而这种方法的准确性无法保证。
如果我们需要绝对正确的算术计算,最好的解决方案并不是 LLM,而是一个按照人类定义的数学规则编写的计算器。