本文提出CoLA及其内存优化变体CoLA-M,通过用低秩自动编码器替换LLMs的全尺寸MLP和投影层,实现2倍模型大小和计算成本的减少,同时保持全秩性能,并在训练和推理中显著提升吞吐量。
Large Language Model, Pre-training, Efficiency, Transformer, Parameter-Efficient Fine-Tuning
Ziyue Liu, Ruijie Zhang, Zhengyang Wang, Zi Yang, Paul Hovland, Bogdan Nicolae, Franck Cappello, Zheng Zhang
University of California at Santa Barbara, University at Albany, SUNY, Argonne National Laboratory
Generated by grok-3
Background Problem
大型语言模型(LLMs)在预训练阶段由于全尺寸MLP和注意力投影层的巨大参数量和计算需求,对计算资源提出了极高的要求。随着模型规模的持续增长(如GPT-3的175B参数,LLaMA-3的405B参数),训练成本变得不可持续。论文观察到预训练LLMs的激活值呈现低秩特性,提出通过减少激活冗余来提高计算和内存效率,解决预训练过程中的资源瓶颈问题,同时力求维持模型性能。
Method
CoLA(Compute-Efficient Pre-Training of LLMs via Low-Rank Activation)提出了一种全新的架构设计,通过以下方式实现高效预训练:
- 核心思想:基于预训练LLMs激活值的低秩特性,用瓶颈结构的自动编码器替换传统全尺寸MLP和注意力投影层,强制执行低秩激活以减少计算和参数冗余。
- 具体实现:将原始线性层 替换为自动编码器形式 ,其中 和 为低秩矩阵,秩 ,并在中间引入非线性激活 。这种结构应用于Transformer架构中的所有MLP和投影层。
- 内存优化变体CoLA-M:通过梯度检查点技术,仅保存低秩激活值,在反向传播时重新计算部分操作(如上投影和自注意力),进一步减少内存占用。
- 关键优势与质疑:CoLA在理论上减少了计算量(FLOPs约为全秩训练的一半)和参数量,但其低秩假设是否适用于训练初期或更大规模模型尚未明确。此外,CoLA-M的重新计算策略可能在高负载或不同硬件环境下导致吞吐量下降,论文未充分探讨这些场景。
Experiment
实验在LLaMA模型(参数规模从60M到7B)和BERT-Large上进行,使用C4数据集和Wikipedia数据进行预训练,遵循计算最优(compute-optimal)设置,并与全秩训练、ReLoRA、GaLore和SLTrain等基线方法对比:
- 数据集与设置:LLaMA模型在C4数据集上训练,遵循计算最优的token-to-parameter比例(约20:1),BERT-Large在Wikipedia上训练85B tokens。实验设置参考了现有工作(如Zhao et al., 2024),以确保可比性。
- 结果:CoLA在所有规模下实现了约2倍的模型大小和计算成本(FLOPs)减少,同时验证困惑度(perplexity)与全秩训练相当(如LLaMA-1B上CoLA为15.52 vs 全秩的15.56)。CoLA-M进一步将内存占用减少至全秩的三分之二,同时训练吞吐量提升1.86倍,推理吞吐量提升1.64倍。过训练(over-training)实验(如LLaMA-350M训练51B tokens)也显示CoLA优于全秩基线。
- 评估与质疑:实验设置较为全面,涵盖了不同模型规模和训练场景,但主要集中在学术预算下的计算最优设置,未涉及工业级超大规模模型或token量(如LLaMA-3的9T tokens)。此外,低秩激活的有效性在训练初期是否成立未被验证,可能影响从头训练的结果。CoLA-M的内存优化在不同硬件或更大批量大小下的表现也缺乏测试,可能存在隐藏的性能瓶颈。
Further Thoughts
CoLA的低秩激活思想提供了一个有趣的视角,特别是在资源受限环境下的预训练中可能有广泛应用。然而,我认为其核心假设——激活值的低秩特性——需要进一步验证,尤其是在训练初期和超大规模模型上的适用性。未来的研究可以探索CoLA与现有高效训练方法(如LoRA或GaLore)的结合,例如在CoLA的低秩架构上应用梯度压缩技术,以进一步减少优化器内存开销。此外,CoLA是否能适应混合专家(MoE)架构也是一个值得探索的方向,论文中也提到了这一点。如果CoLA能在MoE模型上实现类似效率提升,可能对工业级大规模模型训练产生深远影响。另一个思考点是,CoLA的瓶颈结构是否会限制模型的表达能力,尤其是在处理复杂任务或长上下文时,这需要在下游任务上的更多测试来验证其泛化能力。