本文提出Mixture of Sparse Attention (MoSA)方法,通过专家选择路由实现基于内容的稀疏注意力,显著提高了Transformer模型在相同计算预算下的语言建模性能,并优化了资源使用。
Sparse Attention, Transformer, Efficiency, Large Language Model, Pre-training
Piotr Piękos, Róbert Csordás, Jürgen Schmidhuber
KAUST, Stanford University
Generated by grok-3-mini-latest
Background Problem
现代Transformer架构中自注意力机制的计算成本过高,导致序列长度增加时计算和内存复杂度呈二次方增长。尽管有许多研究尝试了亚二次方注意力方法(如状态空间模型和线性注意力),但它们在实际性能上仍逊色于全自注意力。本文假设动态的、基于内容的稀疏性可以带来更有效的注意力机制,从而解决高效处理长序列时性能下降的关键问题,同时减少KV缓存内存占用。
Method
- 核心思想: 受Mixture of Experts(MoE)启发,提出Mixture of Sparse Attention(MoSA),通过专家选择路由动态选择每个注意力头的令牌子集,实现基于内容的稀疏注意力,从而在不牺牲性能的前提下降低计算复杂度。
- 如何实现: 每个注意力头有一个路由器,使用权重矩阵计算令牌分数(其中是sigmoid函数),然后通过TopK操作选择个令牌的索引,提取子集,并仅在这些令牌上计算查询、键、值投影(,,)。注意力计算使用公式,其中掩码确保因果性()。输出通过路由分数缩放并恢复到原序列位置。
- 主要步骤: 路由分数计算、TopK令牌选择、稀疏注意力计算、输出缩放和位置恢复。该方法不修改原始模型,仅在推理时调整采样,允许与优化内核(如Flash Attention)结合。
Experiment
- 实验设置: 使用C4数据集,序列长度1024,训练多个模型规模(参数从28M到516M),包括IsoFLOP和困惑度匹配实验。比较MoSA与密集注意力、固定稀疏注意力、路由注意力的性能。实验设计合理全面,考虑了不同稀疏率、混合模型(4个密集头+MoSA头)、下游任务评估(如LAMBADA、WinoGrande等),并分析了长序列(up to 8192 tokens)下的表现。
- 数据集和原因: 选择C4数据集是因为它是标准语言建模基准,能评估泛化性能;下游任务用于验证零样本性能。实验组织旨在FLOP匹配下测试效率,并在困惑度匹配下测量实际资源使用(如墙钟时间、内存、KV缓存大小)。
- 结果: MoSA在IsoFLOP设置下将困惑度降低高达27%,优于密集基线和其它稀疏方法;在困惑度匹配设置下,MoSA模型更快(墙钟时间减少7.3%~12.9%)、内存更少(减少1.6%~10.0%)、KV缓存显著减小(51.1%~69.5%)。长序列实验中,MoSA保持优势。结果符合预期,证明了MoSA的有效性和高效性,尤其在计算预算有限的场景下。
Further Thoughts
MoSA的专家选择路由机制可能适用于其他模态,如视觉Transformer中动态选择关键特征点以提高效率;与状态空间模型结合可能进一步优化长序列建模;未来可以探索在推理阶段的自适应路由减少KV缓存,或与其他稀疏方法(如Hash Attention)结合以实现协同效应;此外,解决MoSA在短序列下游任务中的性能下降问题,可能通过指令微调或序列长度自适应策略来缓解,从而扩展其在多领域应用中的潜力。