本文提出自数据蒸馏微调方法,通过利用未剪枝模型生成蒸馏数据集恢复剪枝后大型语言模型的质量,在HuggingFace OpenLLM Leaderboard v1上显著优于标准监督微调,并通过模型合并和推测解码进一步提升性能和效率。
Large Language Model, Fine-tuning, Efficiency, Reasoning, Pre-training
Vithursan Thangarasa, Ganesh Venkatesh, Mike Lasby, Nish Sinnadurai, Sean Lie
Cerebras Systems, University of Calgary
Generated by grok-3
Background Problem
大型语言模型(LLMs)在自然语言处理中取得了显著进展,但其高计算和内存需求使得部署成本高昂。随着模型规模的增加,压缩技术(如剪枝)成为平衡模型质量和计算效率的关键。然而,结构化剪枝(Structured Pruning)常导致质量下降,尤其是在需要多步推理的任务上。监督微调(SFT)虽能部分恢复质量,但可能引发灾难性遗忘(Catastrophic Forgetting),即模型丢失先前学习到的知识,主要是由于微调数据分布与原始分布的偏移。本文旨在解决剪枝和微调带来的质量下降问题,提出了一种自数据蒸馏方法以保持模型性能。
Method
本文提出了一种结合结构化层剪枝和自数据蒸馏微调的方法,旨在提高大型语言模型的效率并恢复剪枝后的质量。
- 结构化层剪枝:通过计算层间激活输出的角余弦距离(Angular Cosine Distance)来识别冗余层,具体使用公式 ,选择距离最小的层块进行移除,并直接连接前后层以重构模型。然而,论文对角余弦距离作为唯一指标的选择缺乏充分对比分析,可能忽略了其他指标(如Block Influence Score)在不同场景下的潜在优势。
- 自数据蒸馏微调:利用未剪枝的原始模型生成蒸馏数据集,确保微调数据与原始模型分布对齐,从而减少灾难性遗忘。具体步骤包括:首先用原始模型生成蒸馏输出 ,并通过条件选择确保输出质量;然后对剪枝模型进行微调,最小化损失 。这种方法在理论上有效,但在实际操作中生成大规模蒸馏数据的计算成本可能较高,且论文未充分讨论蒸馏数据质量对结果的影响。
- 扩展应用:包括模型合并(通过SLERP方法合并在不同数据集上微调的模型)和推测解码(利用剪枝模型作为草稿模型提升推理效率)。这些扩展虽有创新,但其效果可能依赖于特定模型和任务,普适性存疑。
Experiment
实验主要基于Llama3.1-8B Instruct和Mistral-7B-v0.3 Instruct模型,剪枝块大小从2到10层不等,评估了自数据蒸馏与标准监督微调(SFT)及无微调(No FT)在HuggingFace OpenLLM Leaderboard v1上的表现。
- 数据集:使用GSM8k、Dolly、Alpaca和OpenMathInstruct(50k样本)等数据集,覆盖开放域对话、推理和指令跟随任务。数据集选择合理,但偏重于推理任务,可能未全面反映模型在其他任务上的表现。
- 结果:自数据蒸馏在所有剪枝规模下均优于SFT,尤其在剪枝6层时,Llama3.1-8B模型恢复了91.2%的原始精度(SFT为81.7%),FLOPs减少16.3%。在推理任务(如GSM8k)上,自数据蒸馏表现尤为突出。然而,实验结果显示,随着剪枝规模增大,恢复率仍显著下降,表明方法在极端压缩下的局限性。
- 扩展实验:模型合并(Model Merging)进一步提升恢复率(如Llama3.1-8B在剪枝6层时达93.3%),推测解码(Speculative Decoding)在Spec-Bench上提高了令牌接受率和推理效率。但这些实验的设置较为特定,未充分探讨不同模型架构或任务类型的影响。
- 评估:实验设计较为全面,但对数据集规模的依赖性较强,且未提供足够的多模型、多任务验证,可能存在结果偏倚。此外,角余弦距离作为剪枝依据的效果与其他指标的对比不足,限制了方法的说服力。
Further Thoughts
尽管本文提出的自数据蒸馏方法在恢复剪枝模型质量方面表现出色,但其对大规模蒸馏数据集的依赖可能限制了其在资源受限环境下的应用。未来研究可以探索如何在小规模数据集上实现类似效果,例如通过结合知识蒸馏(Knowledge Distillation)或其他参数高效微调方法(如LoRA的改进版本)来降低计算成本。此外,论文中推测解码的应用让我联想到近期关于模型推理加速的研究,如Eagle(Li et al., 2024),其强调特征不确定性在推测采样中的作用。或许可以将自数据蒸馏与此类方法结合,进一步优化草稿模型与目标模型的对齐,从而在更大规模模型(如Llama3.1-70B)上实现更高效的推理。最后,剪枝指标的选择仍是一个开放问题,是否可以引入多指标融合策略(如结合角余弦距离和Block Influence Score),以提高剪枝决策的鲁棒性,值得进一步探索。