本文提出 Think, Prune, Train 框架,通过迭代监督微调和基于正确性的数据修剪,实现模型在不增加规模的情况下提升推理能力,避免模型坍缩。
Large Language Model, Reasoning, Fine-tuning, Synthetic Data, Data Augmentation, Efficiency
Caia Costello, Simon Guo, Anna Goldie, Azalia Mirhoseini
Stanford University, Ceramic AI
Generated by grok-3-mini-latest
Background Problem
大型语言模型(LLMs)在编程和数学推理任务中表现出色,但受限于高质量训练数据的可用性。合成数据可以用于提升微调效果,但涉及多个因素,包括模型大小、合成数据量、修剪策略和微调轮数。研究起始点是探讨模型是否能通过自身生成的推理轨迹实现自我提升,而不依赖外部监督。关键问题解决包括避免模型坍缩(如知识遗忘和幻觉生成),并通过正确性-based 修剪策略稳定训练过程,从而在不增加模型规模的情况下提升推理能力。
Method
- 核心思想: 引入 Think, Prune, Train (TPT) 框架,通过迭代过程让模型在自身生成的正确推理轨迹上进行监督微调 (SFT),而不依赖外部模型或强化学习 (RL)。
- 实现方式: 该方法包括三个主要步骤:1. 提示模型以结构化方式生成推理轨迹(例如使用 Chain-of-Thought 提示);2. 使用 ground-truth 正确性过滤修剪不正确输出,确保数据质量;3. 在当前模型上进行 SFT,使用验证后的解决方案。关键是通过算法 1 描述的迭代过程:从基模型开始,生成多个解决方案,应用修剪策略,随机采样固定大小的数据集进行微调,每个轮次独立进行,以避免数据积累带来的混淆效应。
- 主要步骤: 输入基模型、修剪策略和数据量;迭代 N 轮:使用当前模型生成解决方案,修剪后采样数据,进行 SFT,输出改进模型。相关理论支持包括将 SFT 视为政策梯度的特殊情况,例如在相关工作中提到的公式:,并简化到基于正确性的 SFT。
Experiment
- 实验设置: 使用 GSM8K 和 CodeContests 数据集,实验对象包括 Gemma 和 LLaMA 模型家族(如 Gemma2-2B-it、Gemma2-9B-it、LLaMA-3.1-70B-Instruct)。微调参数:学习率分别为 Gemma 的 1e-6 和 LLaMA 的 1e-5,使用 AdamW 优化器,一次迭代一个周期,10% 热身步骤。数据生成使用温度 0.8 的 CoT 提示,评估使用温度 0.7,指标包括 Pass@1、Pass@20 和 Pass@50。实验设计全面,包含基线比较(如无修剪数据)、消融研究(如数据量、修剪策略的影响)和迭代轮次分析。
- 结果: 在 GSM8K 上,Gemma2-2B 的 Pass@1 从 41.9% 提高到 57.6%,Gemma2-9B 达到 82%,LLaMA-3.1-70B 从 78% 提升到 91%,超过 GPT-4o 的 82%。在 CodeContests 上,Gemma2-2B 的 Pass@1 从 0.90% 提高到 1.14%,Gemma2-9B 从 5.10% 到 7.90%。结果符合预期,证明修剪策略有效防止模型坍缩,同时 Pass@20 和 Pass@50 的稳定表明模型维持了多样性。实验合理性高,通过控制变量(如固定数据量)隔离了自我提升效果,并与相关工作(如 STaR)比较,验证了 TPT 的优越性。
- 为什么这样设计: 实验旨在隔离因素影响,如通过固定数据集大小避免数据积累混淆,焦点在推理能力的真实提升而非数据规模效应。
Further Thoughts
这个框架强调了数据修剪在自我提升中的关键作用,或许可以扩展到其他领域如自然语言理解或决策任务中,与 STaR 或 ReST 等方法相比,其简化了过程仅依赖 SFT,而非复杂的 RL,这可能降低部署成本;然而,递归训练可能引入偏见或减少输出多样性,未来研究可探索结合外部数据或多样性技术以增强鲁棒性,同时考虑到实际应用中,类似方法能促进高效 AI 开发,但需警惕潜在的伦理问题如模型同质化。