本研究提出 SpargeAttn,一种通用稀疏注意力机制,通过两阶段在线过滤器和量化技术加速各种模型的推理,同时保持端到端性能无损。
Sparse Attention, Efficiency, Transformer, Prediction, Multimodal Systems, Generative AI
Jintao Zhang, Chendong Xiang, Haofeng Huang, Jia Wei, Haocheng Xi, Jun Zhu, Jianfei Chen
清华大学, 加州大学伯克利分校
Generated by grok-3-mini-latest
Background Problem
大型模型中的注意力机制由于二次方时间复杂度,在处理长序列时效率低下。尽管注意力图常表现出稀疏性,许多值接近零,可以跳过相应计算,但现有的稀疏注意力方法多针对特定任务优化,通用性不足,且难以同时实现准确性和高效性。例如,模式-based方法依赖经验观察的固定模式,动态稀疏方法虽更通用但可能丢失重要信息,训练-based方法需要重新训练模型,成本高。本工作的出发点是设计一种无需训练的稀疏注意力机制,能够加速各种模型的推理过程而不损失性能指标,解决注意力计算在不同任务中的通用加速问题。
Method
- 核心思想: SpargeAttn 利用注意力图的固有稀疏性,通过两阶段在线过滤器实现高效计算,而不依赖特定任务模式。
- 工作原理: 第一阶段,快速预测稀疏块:基于块内令牌相似度选择性地压缩查询(Q)和键(K)矩阵,将高相似度块压缩为单个令牌,计算压缩注意力图 ,然后使用累积密度函数(CDF)选择重要位置设置全局掩码 ,跳过不必要的矩阵乘法。第二阶段,稀疏在线 Softmax:利用在线 Softmax 计算过程,基于行最大值与全局最大值的差值跳过部分 计算,如果 则跳过。方法还整合量化技术和 Hilbert 曲线置换以提升相似度和稀疏性。
- 主要步骤: 1. 计算块内余弦相似度,压缩高相似度块;2. 计算压缩注意力图并设置掩码;3. 在 FlashAttention 框架内应用掩码跳过计算;4. 第二阶段进一步优化 Softmax 计算;5. 与 SageAttention 结合进行量化加速。
Experiment
- 实验设置: 实验覆盖多种任务和模型,包括文本生成(Llama3.1,使用 WikiText、Longbench、InfiniteBench 和 Needle-in-A-Haystack 数据集,序列长度达 128K)、视频生成(CogvideoX 和 Mochi,使用 open-sora 提示集,序列长度 17K 和 22K)、图像生成(Flux 和 Stable-Diffusion3.5,使用 COCO 数据集,序列长度 4.5K)。评估指标包括速度(TOPS)、稀疏性、端到端性能(如困惑度、CLIPSIM、VQA-a/t、FScore、FID、Clipscore、ImageReward)。基线包括 MInference 和 FlexPrefill,以不同稀疏度比较。实验设计全面,考虑不同序列长度和任务类型,确保公平性(如使用相同输入测量速度)。
- 结果分析: SpargeAttn 在保持性能指标不变的情况下显著加速推理,例如 Mochi 上实现 1.83 倍加速,Llama3.1 上 TOPS 提升至 708.1,而基线方法在相同稀疏度下性能下降(如 MInference 导致指标恶化)。预测开销低(序列长度增加时开销比例下降),消融实验证实关键组件(如自相似判断和 Hilbert 曲线置换)有效。结果符合预期,证明方法鲁棒性和泛化能力,实验设置合理,覆盖了实际应用场景。
Further Thoughts
SpargeAttn 的动态稀疏预测和在线过滤器设计可能启发其他领域,如卷积神经网络的稀疏化或推荐系统的计算优化,强调了模块化方法的优势;未来可探索与其他高效范式结合,例如在边缘设备实现实时 AI,或与知识蒸馏技术整合以提升模型泛化,同时注意潜在的鲁棒性挑战,如在非结构化数据上的表现。