本文提出Learning With Forgetting (LWF)框架,通过自生成知识、Fisher信息矩阵加权的遗忘置信度计算和周期性遗忘策略,在生成式语言模型的微调中实现优雅遗忘,实验表明其在大多数领域特定问答任务上显著提升性能。
Generative AI, Fine-tuning, Transfer Learning, Large Language Model, Continual Learning
Chunyang Jiang, Chi-min Chan, Yiyang Cai, Yulong Liu, Wei Xue, Yike Guo
HKUST
Generated by grok-3
Background Problem
在自然语言处理(NLP)领域,预训练-微调范式已成为主流,通过在大规模语料库上预训练模型,随后在特定任务数据集上微调以适应下游应用。然而,这一范式面临一个长期存在的问题——负面迁移(negative transfer),即预训练知识中某些部分可能对目标微调任务产生不利影响。论文指出,传统的微调方法对预训练知识一视同仁,无法有效区分有益和有害知识,导致性能受限。为解决这一问题,作者引入了’graceful forgetting’(优雅遗忘)的概念,旨在通过有选择地遗忘无关或有害知识来增强模型对新任务的学习可塑性,特别是在生成式语言模型中,这一领域的研究尚属空白。
Method
论文提出了一种名为Learning With Forgetting (LWF)的新框架,用于在生成式语言模型中实现优雅遗忘。其核心思想是通过以下三个步骤有选择地遗忘无关知识:
- 自生成知识表征:由于预训练数据通常不可访问,LWF利用生成式模型的特性,通过输入遗忘任务(forgetting task)的提示(prompt)生成文本,形成遗忘数据集(D_F),以此表征需要遗忘的知识。
- 计算遗忘置信度:为避免遗忘有益知识,LWF为每个数据点计算遗忘置信度(forgetting confidence),通过Fisher信息矩阵(FIM)对参数更新的重要性进行加权,评估自生成数据与学习任务(learning task)之间的冲突程度。具体公式为 ,其中FIM捕捉参数重要性,冲突程度通过参数更新与目标参数的差异来近似。
- 周期性遗忘:基于遗忘置信度选择高置信度数据点进行遗忘,采用梯度上升(gradient ascent)作为遗忘算法,并通过周期性遗忘策略(periodically unlearning)在微调过程中以固定间隔执行遗忘操作,以缓解遗忘的不稳定性。
批判性思考:虽然LWF在方法设计上具有创新性,但遗忘置信度的计算依赖于参数更新的近似,可能无法准确反映知识冲突的本质,尤其是在语义复杂的自然语言任务中。此外,周期性遗忘的间隔和遗忘率参数的选择缺乏理论依据,可能导致遗忘效果的不稳定。
Experiment
论文在多个领域特定问答任务上评估了LWF框架的有效性,数据集包括gsm8k(数学推理)、qasc(基础科学)、sst5(情感分类)、dental(牙科知识)和psychol(心理学知识)。实验基于Llama3.2-1B模型,并扩展到Llama3-8B以验证可扩展性。实验设置包括:自生成数据采用3-shot提示和贪婪解码策略;遗忘置信度计算中单步更新系数设为1e-2;周期性遗忘间隔设为7,遗忘率(β)为0.1或0.05。结果显示:
- 性能提升:在大多数学习-遗忘任务组合中,LWF相较于vanilla微调提升了性能,尤其是在’mixed’遗忘设置(遗忘除学习任务外的所有数据集)下,性能提升稳定(如gsm8k提升6.95%,psychol提升7.93%)。
- 负面结果:部分组合(如学习dental,遗忘qasc或psychol)出现性能下降,作者归因于自生成样本的低遗忘置信度,但未深入分析原因。
- 可扩展性:在更大模型Llama3-8B上,LWF仍能提升性能,但提升幅度因基线性能较高而有所下降。
- 遗忘置信度分析:高置信度遗忘比低置信度遗忘带来更稳定和更高的性能增益,验证了遗忘置信度的有效性。
- 周期性遗忘验证:与提前遗忘(ahead unlearning)和随机遗忘(randomly unlearning)相比,周期性遗忘显著更优,避免了过早遗忘关键知识导致的性能剧降。
批判性思考:实验设计覆盖了多个领域,设置较为全面,但对负面结果的解释不足,未能揭示任务交互的深层机制。此外,遗忘任务性能下降的分析较为表面,仅通过准确率、语义相似度和词汇多样性指标,未探讨模型内部表征变化或遗忘的具体知识内容。参数(如遗忘间隔和遗忘率)的选择缺乏系统性调优,可能影响结果的普适性。
Further Thoughts
LWF框架为生成式语言模型中的优雅遗忘研究开辟了新方向,但其遗忘置信度的计算方式依赖于参数更新的近似,可能无法完全捕捉语义层面的知识冲突。未来研究可以探索结合模型内部表征(如注意力机制或隐藏层激活)来更精确地识别需要遗忘的知识。此外,LWF的周期性遗忘策略虽然有效,但遗忘间隔和遗忘率的优化仍需进一步探索,或许可以引入自适应机制,根据任务难度或模型性能动态调整参数。另一个值得思考的方向是,优雅遗忘是否可以与其他微调技术(如参数高效微调方法PEFT)结合,以在减少计算开销的同时进一步提升性能?例如,是否可以在LoRA(Low-Rank Adaptation)的基础上应用遗忘策略,仅调整部分参数以实现更高效的遗忘?此外,LWF对遗忘任务性能的影响分析提示我们,某些复杂任务(如dental和psychol)对遗忘表现出更强的抗性,这可能与知识的结构化程度有关,未来可以深入研究不同类型知识在模型中的表征方式及其对遗忘的影响,这或许能为构建更具可解释性的AI模型提供启示。