NEE's Blog

使用 PPO 进行语言模型的树搜索蒸馏

March 15, 2026

本文翻译自 Tree Search Distillation for Language Models using PPO,原载于 Hacker News。


像 AlphaZero 这样的博弈类神经网络,通过将原始策略与测试时搜索框架(test-time search harness)相结合,再将增强后的更强策略蒸馏回网络,从而在棋类游戏中取得超越人类的表现。那么,为什么类似的技术在当今的语言建模中没有得到广泛应用呢?DeepSeek-R1 的作者提到,他们在 MCTS 方面收效甚微;Finbarr Timbers 有一篇精彩的文章分析了他们可能遇到的问题——即选择了 UCT 而非 pUCT。

本文旨在探讨两个问题:

  1. 搜索蒸馏是否真的能提升语言模型的推理能力?
  2. 与标准语言 RL 方法(如 GRPO)相比,它的表现如何?

为了探索这些问题,我将 MCTS 应用于 Qwen-2.5-1.5B-Instruct 模型的推理步骤,搜索更强的推理轨迹,并通过在线 PPO 循环将这些轨迹蒸馏回模型。在 Countdown(一个组合算术游戏)任务上,蒸馏后的模型(不使用搜索框架进行评估)达到了 11.3% 的渐近 mean@16 评估分数,而 CISPO 为 8.4%,Best-of-N 为 7.7%。相对于预训练的 instruct 模型(3.1%),这是一个 8.2 个百分点的提升。

这些较低的绝对分数反映了这些只是在 1.5B 模型上进行的小规模实验。我希望将这篇文章作为系列的开篇,并期待在后续文章中使用更大的模型和计算预算来提高这些分数。

Countdown 游戏环境

我最初尝试使用 GSM8K 作为测试环境,但发现 GRPO 和 MCTS 之间的差异微乎其微,难以做出强有力的结论。因此,我决定使用 Countdown 游戏作为我们的环境。其规则很简单:给定 N 个正整数,使用标准运算(+、-、/、*)计算特定的目标值。

为什么选择 Countdown? 假设是,组合问题更能受益于树搜索所提供的并行自适应推理能力,而像 GSM8K 这样的任务,顺序推理也能取得良好效果。我们在包含 20,000 个样本的数据集上训练,在 820 个样本的测试集上评估。每个样本由 1 到 13 之间的四个输入整数组成。

我发现,在训练过程中使用稀疏奖励(正确性为 0/1)会导致训练不稳定。于是我改用密集奖励函数:

\[1.0 - 2 \cdot \min\left(\frac{|t - p|}{t}, 1.0\right)\]

其中 $t$ 是真实目标值,$p$ 是预测目标值。如果格式正确,则使用上述公式;否则为 -1.0。

不过,评估仍然使用稀疏奖励函数,因为我们希望能够直观理解分数(例如通过率 %)。

蒙特卡洛树搜索(MCTS)

MCTS 算法已经被其他人深入讨论过,所以我将跳过详细描述。本文将重点介绍经典 MCTS 与我尝试的方法之间的差异。简而言之,MCTS 迭代地构建搜索树,在价值函数的指导下智能地探索动作空间。

棋类游戏有一个相对有意义的动作空间,即国际象棋中的每一步棋都对玩家是否能获胜有实质性影响。相比之下,在语言建模中,推理轨迹中的许多 token 只是填充词或语法糖,从 top-k logits 分支(或基于熵阈值的条件)并不总能带来搜索多样性。想象一个状态,下一个可能的 token 是 “but”、”however”、”yet” 等;我们最终会花费计算资源构建过大的搜索树,但在每个 token 的基础上收益微乎其微。

我更倾向于 Tree-of-Thoughts(Yao et al., 2023)引入的方法,在可能的下一个推理步骤上进行搜索。在这个公式中,每个节点状态是一个连续 token 序列:

  • 根节点对应输入 prompt
  • 中间节点对应推理步骤:<step>...</step>
  • 终端节点对应答案:<answer>...</answer>

为了探索更多的扩展”旋钮”,我的实现使用了并行 MCTS,其中 N 个 agent 共享同一个样本级搜索树,并使用虚拟损失(virtual losses)来鼓励搜索多样性。

从每个叶节点开始,我们生成 K 个补全,直到遇到停止标签 </answer>。这 K 个序列构成了该特定节点的动作空间。

由于 pUCT 需要动作级别的先验概率,我们计算序列级别的累计对数概率,并应用 softmax 函数来获得相对先验。这样做很有效,因为原始的累计序列概率会变得非常微小且数值不稳定。

MCTS 通常还使用一个价值头 $V(s_t)$,它在训练过程中改进,并帮助引导搜索过程找到更好的轨迹。这被实现为一个 MLP,后跟一个 tanh 函数,应用于 transformer 的最终隐藏状态。

这种方法与 TS-LLM(Feng et al., 2023)有相似之处,后者也将 AlphaZero 风格的树搜索与句子级动作上的学习价值函数相结合。

主要区别在于:

  1. 使用在线 RL(CISPO/PPO)而非 SFT 进行蒸馏
  2. 使用带虚拟损失的并行 MCTS 作为额外的扩展维度

轨迹选择

在棋类游戏的 MCTS 中,训练信号通常来自最小化根节点搜索策略与模型原始策略之间的 KL 散度。然而,由于我们的动作空间粒度与原始模型动作空间(推理步骤 vs. token)存在不匹配,我们需要采用其他方法。我使用的方法是:当所有 worker 完成某个样本的 M 次迭代后,它们执行贪婪选择过程:

  • 从根节点开始,按最大访问次数选择一条轨迹
  • 将此轨迹提交到共享缓冲区,用于 PPO 训练

训练过程

被指定为”trainer”的 worker 异步地从共享缓冲区中拉取样本。它们使用 AdamW 优化器,对每批 B 个样本执行一次 PPO 内部步骤,使用 CISPO 作为损失类型。

训练目标是最小化总损失 $L_{total}$:

\[L_{total} = c_{ppo} L_{ppo} + c_{value} L_{value} + c_{KL}\, \mathbb{D}_{KL}(\pi_\theta \mid\mid \pi_{ref})\]

其中 CISPO 损失为:

\[L_{cispo} = -\mathbb{E}\left[sg(\min(\frac{\pi_\theta(a_t \mid s_t)}{\pi_{old}(a_t \mid s_t)}), \epsilon) \cdot A_t \cdot \log \pi_\theta(a_t \mid s_t) \right]\]

这里 $A_t = r_{terminal} - sg!\left(V_{old}(s_t)\right)$ 是 token 级别的优势值(我们将相同的终端奖励分配给每个 token)。我没有使用 GAE,因为推理轨迹可以延伸到数千个 token,而在终端奖励的情况下,早期 token 会以指数方式被折扣到可以忽略不计的值。

价值损失为:

\[L_{value} = \mathbb{E} \left[(V(s_t) - r)^2\right]\]

KL 散度为(来自 DeepSeek-R1 论文):

\[\mathbb{D}_{KL}(\pi_\theta \mid\mid \pi_{ref}) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{ref}(a_t \mid s_t)} - \log \frac{\pi_\theta(a_t \mid s_t)}{\pi_{ref}(a_t \mid s_t)} - 1\]

我们运行训练过程,直到评估分数趋于稳定。

基础设施

所有实验都在 Andromeda 的 8xH100 节点上进行。对于 MCTS,六个 GPU 被指定为生成器,两个为训练器。一个 Rust worker 从数据集中采样问题,并通过 gRPC 向生成器池提交推理请求。它们将选中的轨迹写入 Redis 流;训练器从这里迭代地拉取样本。权重每 8 个梯度步使用 Redis pub/sub 在生成器和训练器之间同步。

基线方法

我运行了一个 CISPO 基线,全局批次大小为 128 个样本,组大小为 16,有效批次大小为 2048。Logits 以 float32 计算,按照 ScaleRL 的做法。同样,训练运行直到评估分数趋于稳定。所有八个 GPU 都用于训练 CISPO,没有训练器/生成器分离。

为了隔离树结构的价值,我还运行了一些实验,其中提交到训练缓冲区的轨迹是通过 “best-of-N”(N=64)而非树搜索选择的。

实验结果

我们使用 mean@16 来评估模型。这意味着对每个评估 prompt 运行 16 次生成,用稀疏的 0/1 奖励对它们进行评分,然后取平均值。

在评估时,不带搜索框架的 MCTS 蒸馏策略达到了 11.3% 的渐近 mean@16 分数,而 CISPO 模型稳定在 8.4%,Best-of-N 表现最差,稳定在 7.7%。

令人惊讶的是,我还发现尽管训练奖励明显更高,但 “best-of-N” 蒸馏在评估套件上的表现不如 CISPO 和 MCTS。虽然原因尚不完全清楚,但我们可以推测:如果我们的模型在思考轨迹中至少犯一个推理错误的概率是 98%,那么选择至少一条正确轨迹的概率仍然是 $1 - 0.98^{64} \approx 72.6\%$。但如果没有每次都产生稳健推理的激励,模型就不太可能学会发展出能提高其单次得分的策略。

个人感悟:这让我想起中学时解题的场景。为了在数学考试中避免”低级错误”,我会使用各种技巧来跟踪中间步骤。如果我有多次考试机会,我可能永远不会养成这些好习惯!这正是 Best-of-N 的困境——它提供了”多次尝试”的安全网,却削弱了每次都做好的动力。

代码

所有代码都是开源的,可以在这里找到。

未来方向

这意味什么?让我兴奋的是我们可以调整的额外推理旋钮,比如每棵树的并行 worker 数量,或 MCTS 迭代次数。我还没有正确调整这些参数,但初步实验表明增加这两个值会带来显著的性能提升。所以我想进一步探索这个方向!有大量的工作要做,扩展这种方法并绘制经验趋势来评估其在更大模型和计算预算下的潜力。如果你想合作,请联系我!

当然也有一些注意事项:这可能是”小模型现象”,该方法对于更大模型可能不如 GRPO 扩展得好。是否可以通过调整 GRPO(CISPO)基线来匹配 MCTS?也许可以,但 ScaleRL 发现 GRPO 的大多数超参数影响的是计算效率,而非最终奖励上限。

有人可能会指出,MCTS 在每个样本的基础上比 GRPO 使用更多的推理计算:当然它表现更好!然而,这里的目标不是进行完全对等的计算比较;是的,MCTS 确实使用了更多的推理时计算,但它也给了我们额外的杠杆来应用/扩展这些计算并提高奖励上限。而在 GRPO 上投入 100 倍的计算是否能将平稳期变成曲棍球棒式增长,对我来说并不明显。

致谢

我要感谢 Andromeda 团队和 Molly Mielke McCarthy 为这个项目赞助计算资源,以及 Tom McCarthy 和 Joe Melkonian 阅读本文的早期草稿并提供宝贵反馈。我还要感谢 Finbarr Timbers 的博客文章,它是这项工作的催化剂。

超参数表

参数 描述
基础模型 Qwen-2.5-1.5B-Instruct 实验使用的基础模型
训练数据集大小 20,000 样本 训练用的 Countdown 问题数量
评估集大小 820 样本 评估集中的问题数量
N(每题整数数) 4 每个 Countdown 问题中的整数数量
输入范围 [1, 13] 输入整数的范围
MCTS Workers 16 共享一个搜索树的并行 agent 数量
每节点补全数(K) 4 在每个叶节点生成的候选序列数
MCTS 迭代次数(M) 100 每个样本的迭代次数
虚拟损失 1 添加到访问计数的值,以避免并行分支冲突
MCTS/BoN 全局批次大小 32 MCTS/Best-of-N 训练的总批次大小
CISPO 全局批次大小 128 基线训练的总批次大小
CISPO 组大小 16 每组的轨迹数量
CISPO epsilon_high 5.0 CISPO 的裁剪参数
Best-of-N 采样大小 64 N,即 Best-of-N 基线每个样本的生成数
权重同步频率 每 8 个梯度步 生成器和训练器之间权重同步频率
c_puct 0.5 pUCT 系数
c_KL 0.05 KL 散度权重
c_ppo 1.0 PPO 目标权重
c_value 1.0 价值目标权重
N_t 2 训练器进程数
N_g 6 生成器进程数

总结

这篇文章探讨了将 AlphaZero 风格的树搜索与 PPO 在线强化学习相结合,用于语言模型推理任务的知识蒸馏。主要亮点:

  1. 方法创新:将 MCTS 应用于推理步骤级别(而非 token 级别),通过并行 agent 和虚拟损失提升搜索多样性
  2. 实证效果:在 Countdown 组合数学任务上,MCTS 蒸馏(11.3%)显著优于 CISPO(8.4%)和 Best-of-N(7.7%)
  3. 重要洞察:Best-of-N 虽然训练奖励高,但评估效果差——”多次尝试”的机会反而削弱了模型学习稳健推理的动力
  4. 扩展潜力:树搜索提供了额外的扩展维度(worker 数、迭代次数),为进一步提升留有空间

对于关注 LLM 推理能力提升的开发者来说,这是一个值得关注的实验方向。虽然目前只是小规模实验,但它为”推理时计算”(inference-time compute)的有效利用提供了新的思路。

comments powered by Disqus