本文提出PAMM方法,通过随机选择代表性token近似输入张量,大幅减少注意力机制中Q、K、V投影的内存占用(高达512倍),同时在预训练和微调中基本维持模型性能。
Large Language Model, Efficiency, Pre-training, Fine-tuning, Transformer
Malik Khalf, Yara Shamshoum, Nitzan Hodos, Yuval Sieradzki, Assaf Schuster
Technion, Israel Institute of Technology
Generated by grok-3
Background Problem
大型语言模型(LLMs)在训练过程中面临显著的内存挑战,尤其是在注意力机制中,线性投影层生成Q、K、V张量所需的输入激活值在反向传播时需保存,占用了高达20%的GPU内存峰值。随着模型规模、批次大小和序列长度的增加,这一问题愈发严重。现有研究多集中于优化注意力计算本身(如scaled dot product),而忽略了投影层的内存开销。本文提出了一种新方法,旨在通过压缩输入激活张量来大幅减少Q、K、V投影的内存占用,同时尽量维持模型性能。
Method
PAMM(Point-Approximate Matrix Multiplication)是一种张量压缩技术,核心思想是利用输入张量在序列维度上的冗余性,通过一小部分代表性点(生成点)近似表示整个张量。具体步骤如下:
- 压缩阶段:将输入张量X的行视为高维空间中的点,随机选择一小部分行作为生成点C(数量为k,远小于总行数b),然后对每个行A^i,找到在生成点C^j张成的直线上最接近A^i的点作为其代表点A˜^i,并通过邻域条件(误差容忍度ε)决定是否保留该近似。
- 近似矩阵乘法:在反向传播时,使用压缩后的张量A˜替代原始张量A,计算近似梯度,通过归一化因子β校正偏差,确保期望上近似结果接近真实值。
- 关键参数:压缩比r(决定生成点数量k=r·b)和容忍度ε(控制近似精度)。
批判性思考:虽然PAMM在理论上利用了序列维度的冗余性,但随机选择生成点可能无法充分捕捉输入张量的结构特性,尤其在token分布不均匀时可能导致较大误差。此外,实验中设置ε=∞(即不限制误差)可能导致部分token的近似严重偏离真实值,潜在影响训练稳定性。
Experiment
实验主要在预训练和微调两个场景下评估PAMM的效果:
- 预训练:在LLaMA系列模型(60M到1B参数)上,使用C4数据集进行预训练,压缩比r低至1/512。结果显示,Q、K、V投影的内存占用减少高达512倍,而模型困惑度(perplexity)仅上升3%,某些中间压缩比下甚至优于全内存训练,表明冗余token可能对训练有负面影响。
- 微调:在RoBERTa-base模型上,使用GLUE基准进行微调,压缩比r为1/128和1/256时,内存占用减少超过97%,性能与全微调几乎一致。
- 消融研究:与CompAct和Uniform-CRS等方法对比,PAMM在内存-质量权衡上表现更优;ε=∞时性能最佳,表明注意力输入已具有一定聚类特性。
批判性思考:实验设置较为有限,仅在中小规模模型上测试,未涉及万亿参数级别模型,泛化性存疑。此外,数据集和任务较为单一,未充分验证PAMM在多样化场景下的稳定性。困惑度小幅改善的现象缺乏深入分析,可能只是噪声或特定数据集特性,而非PAMM的普遍优势。
Further Thoughts
PAMM提供了一个有趣的视角,即通过序列维度的冗余性压缩注意力输入激活值,这与近年来对KV缓存压缩的研究(如DeepSeek-V2的MLA)有一定关联,但其聚焦于训练阶段而非推理阶段,应用场景不同。未来可以探索PAMM是否能与KV缓存压缩技术结合,进一步优化推理内存效率。此外,随机选择生成点的策略可能并非最优,是否可以通过聚类算法(如k-means)或基于注意力分数选择生成点,值得进一步研究。同时,PAMM在更大规模模型上的表现仍需验证,尤其是在长序列任务中,token分布可能更加复杂,随机近似可能导致不可忽视的误差。最后,作者提到的CUDA优化问题提示我们,算法的理论优势需要在工程实现上得到支持,否则实际应用价值有限。