本文提出RWKV-X,一种线性复杂度的混合语言模型,通过结合RWKV和稀疏注意力机制,提升长上下文建模能力,同时保持高效性和短上下文性能。
Large Language Model, Long Context, Efficiency, Pre-training, Multimodality
Haowen Hou, Zhiyi Huang, Kaifeng Tan, Rongchang Lu, Fei Richard Yu
Guangdong Laboratory of Artificial Intelligence and Digital Economy (SZ), College of Information Science and Engineering, Hohai University, College of Computer Science and Software Engineering, Shenzhen University, School of Ecological and Environmental Engineering, Qinghai University
Generated by grok-3-mini-latest
Background Problem
传统Transformer模型依赖自注意力机制,具有二次方复杂度,这在处理长序列输入时会带来显著限制。尽管线性替代模型如RWKV等在效率上有所提升,但它们在长上下文理解方面仍存在挑战,例如在长序列基准测试中性能急剧下降。本工作旨在解决这一问题,提出一个线性复杂度的混合模型,以高效捕获短距离和长距离依赖,同时避免现有混合模型的二次方复杂度瓶颈。
Method
- 核心思想: RWKV-X是一种混合架构,将RWKV-7用于短距离依赖建模,并引入稀疏注意力机制来捕获长距离上下文,实现训练的线性复杂度和推理的常量复杂度。
- 具体实现: 基于RWKV-7,RWKV-X周期性地插入稀疏注意力块。具体地,采用Top-k Chunk Sparse Attention机制:将输入序列分块,每个查询只关注top-k个最相关的块,计算公式为:,其中是块i的相关性分数;然后选择top-k块索引,并计算注意力:。此外,引入KV缓存管理(受SnapKV启发),通过计算重要性分数来压缩缓存,确保推理时内存恒定。训练过程采用块扩展方法和长上下文持续预训练,使用LongCE损失来动态加权关键token。
- 关键步骤: 不从零开始训练,而是从RWKV-7扩展;两阶段训练:第一阶段冻结RWKV块,只训练稀疏注意力块;第二阶段解冻所有参数,进行长上下文预训练。
Experiment
- 实验设置: RWKV-X通过两阶段训练:对齐预训练使用MiniPile数据集,上下文长度为1024;长上下文预训练使用ProLong-64K数据集,上下文长度为64K。基准测试包括长上下文任务(如S-NIAH基准的Pass-key Retrieval、Number in Haystack、UUID in Haystack)和短上下文任务(如LAMBADA、HellaSwag等)。效率分析比较了预填充和解码延迟。
- 结果与分析: 在长上下文基准上,RWKV-X在64K序列上实现近乎完美准确率(表2),显著优于基线如RWKV-7;短上下文性能保持竞争力(表3),如RWKV-X 3.6B在平均得分上接近RWKV-7 2.9B。效率方面,RWKV-X显示线性训练复杂度和常量解码复杂度(图3、4),在128K长度下比Flash-Attention v3快1.37倍。消融实验验证了LongCE损失的作用(表4),证明其在长序列任务中提升性能;注意力层比例优化显示25%稀疏注意力层最小化损失(图5)。结果符合预期,证明了方法的有效性和鲁棒性。
Further Thoughts
RWKV-X的混合架构启发我们探索更多领域,如多模态模型中结合稀疏机制处理长序列数据,可能提高视频或音频处理的效率;KV缓存管理策略可扩展到实时应用中,优化边缘设备上的内存使用;此外,与Mamba或Jamba等模型比较,RWKV-X的线性复杂度可能在数据稀缺场景下更具优势,未来可研究其在联邦学习中的应用,以保护隐私的同时提升模型泛化。