本文提出分块训练(CWT)和跳跃思维训练(STT),通过将推理过程分块并跳过非核心块,显著提升小型语言模型在链式思维蒸馏中的推理准确性和速度。
Large Language Model, Reasoning, Fine-tuning, Efficiency, Supervised Learning
Xiao Chen, Sihang Zhou, Ke Liang, Xiaoyu Sun, Xinwang Liu
National University of Defense Technology
Generated by grok-3
Background Problem
大型语言模型(LLM)通过链式思维(CoT)蒸馏可以指导小型语言模型(SLM)提升推理能力,但传统方法要求SLM在一次训练迭代中学习完整的长推理过程(rationale),导致两个关键问题:1)表层理解,由于长推理过程使得token级批次大小过大,核心推理token的梯度被过度平滑,SLM难以掌握推理逻辑,趋向于学习表达模式而非核心逻辑;2)响应速度慢,SLM在测试时需生成完整推理过程才能得出答案,耗时较长。本文旨在通过改进训练策略,解决SLM在CoT蒸馏中的表层理解和响应速度问题。
Method
本文提出了两种主要方法:分块训练(CWT)和跳跃思维训练(STT),以改进SLM在CoT蒸馏中的推理能力与速度。
-
分块训练(CWT):
- 核心思想:通过将LLM生成的完整推理过程(rationale)分割为语义连贯的小块(chunks),在每次训练迭代中仅让SLM学习一个块,从而减小token级批次大小,缓解梯度过度平滑问题,增强SLM对核心推理逻辑的理解。
- 实现步骤:
- 使用分块数据生成器,将推理过程分为固定数量(M)的块。提供两种分块方式:平均分块(AC,将推理步骤均分)和基于搜索的分块(SBC,利用SLM损失作为启发式信息,通过贪婪策略调整分块以提高语义连贯性,详见算法1)。
- 在训练中,将每个样本转换为M+1个训练数据,前M个数据对应逐步推理阶段(添加前缀[m]指示当前阶段),最后一个数据包含完整推理和答案(前缀[answer])。
- 关键点:CWT通过隔离非推理块(如过渡或总结性内容),增加核心推理token在对应迭代中的比例,理论上帮助SLM聚焦于推理逻辑。
-
跳跃思维训练(STT):
- 核心思想:基于CWT,训练SLM自动跳过非核心推理块(通过潜在空间内化),仅显式输出包含核心推理信息的块,从而在保持准确性的同时加速推理。
- 实现步骤:
- 使用跳跃数据生成器,依次移除每个块,并利用CWT训练的SLM预测答案。若移除某块后答案仍正确,则该块被认为非必要,可内化;否则,该块需显式输出。
- 基于上述判断构建训练数据,重新初始化SLM(避免过拟合),结合CWT继续训练,确保SLM仍接触完整推理过程。
- 测试时,使用前缀[skip]提示SLM自适应跳过非必要块。
- 关键点:STT平衡了推理速度和准确性,避免完全内化推理过程导致的信息遗忘问题。
批判性思考:CWT依赖于分块的合理性,但SBC的贪婪搜索可能陷入局部最优,论文未充分探讨全局最优分块策略(如模拟退火,僅在限制部分提及)。此外,STT对答案正确性的依赖可能因任务复杂度或数据集特性而失效,缺乏对判断标准的鲁棒性分析。方法设计中未考虑不同SLM规模对分块和跳跃策略的适应性,可能限制其普适性。
Experiment
本文在多个推理任务和SLM上验证了CWT和STT的效果,实验设计涵盖以下方面:
-
数据集与设置:使用7个推理基准数据集,覆盖算术、符号、常识和其他逻辑推理任务。LLM采用text-davinci-002(175B),SLM包括GPT-2(base到large)和T5(small到large)。实验分为消融研究(验证各模块效果)、与其他方法的对比、CWT对表层理解问题的缓解效果验证、以及STT的速度-准确性权衡分析。
-
消融实验(Q1):表1显示,相比基线(完整推理训练),AC在大多数任务上提升SLM准确性,但部分任务下降(可能因语义连贯性受损);SBC进一步提升所有任务性能,证明其分块更优。STT在SBC基础上进一步提高准确性,尤其在复杂任务(如MA)上提升明显(GPT2-base从17.77%到22.77%)。然而,完全跳跃所有块(SkipALL)的变体在某些任务(如LLC)上性能显著下降,表明完全内化推理不适合并行推理任务。
-
与其他方法对比(Q2):表2表明,本文方法在多个任务上优于其他CoT蒸馏方法(如Scott、Step-by-Step、ICoT-SI),在某些任务(如TSO)接近LLM性能。图3显示推理速度虽不及多任务学习或完全内化方法,但实现了性能与速度的较好平衡。
-
CWT效果验证(Q3):图5证明token级批次大小减小(通过CWT)后SLM性能提升,验证了梯度平滑缓解的假设。表3显示CWT后核心推理token的置信分数提升,表明SLM对推理逻辑理解更深。推理速度也有所提升(图3),归因于更简洁的推理表达。
-
STT效果验证(Q4):表4显示STT相比SBC的推理加速比在不同数据集上为1.08到1.89,简单任务(如SQA、TSO)加速更明显,复杂任务(如LLC)因需保留更多关键块而加速有限。案例研究(附录G.3)表明STT通过跳跃中间步骤减少了模型幻觉。
批判性思考:实验设置较为全面,覆盖多种任务和模型,但对比方法(如ICoT-SI)在部分SLM上未实现,可能导致对比不公平。分块数量M的选择对性能影响较大(图4),但实验仅展示趋势,缺乏系统性分析和理论依据。STT在复杂任务上加速效果有限,是否适用于更广泛场景存疑。此外,实验未充分探讨训练时间和内存消耗的实际影响(仅在附录F提及),对实际应用价值评估不足。
Further Thoughts
本文提出的CWT和STT方法在CoT蒸馏中展现了一定的创新性,尤其是在解决梯度平滑和推理速度问题上的尝试值得关注。然而,其方法设计和实验验证仍存在局限性,例如SBC的局部最优问题和STT对任务复杂度的适应性不足。进一步思考,是否可以引入更先进的搜索算法(如遗传算法或强化学习)来优化分块策略,以避免贪婪搜索的局限?此外,STT跳跃机制是否可以结合注意力机制,动态识别核心推理块,而非依赖静态的答案正确性判断?
从跨领域角度看,CWT的分块思想可能借鉴了自然语言处理中分段式处理(如长文本分割)的理念,但其在推理任务中的应用是否会引入上下文割裂问题,值得与长上下文处理技术(如Transformer的长序列优化)结合探讨。STT的跳跃机制与人类认知中的‘思维捷径’有相似之处,是否可以进一步结合认知科学理论,设计更符合人类推理模式的SLM训练策略?这些方向可能为未来的CoT蒸馏研究提供新的视角,同时也需要在更广泛的任务和模型规模上验证其普适性。