在长上下文建模领域,如何突破全注意力机制的二次复杂度瓶颈一直是核心挑战。EverMind 团队开源的 Memory Sparse Attention(MSA)实现给出了一条端到端可训练的稀疏注意力路径,支持最高达 1 亿 token 的超长上下文推理。本文将从代码架构层面剖析 MSA 的可学习路由层设计、稀疏策略的选择逻辑以及工程实现的关键细节,为希望在项目中复现或借鉴该技术的开发者提供可落地的参数参考。

可学习记忆路由层的核心架构

MSA 的核心创新在于引入了一个可学习的记忆路由层,该层负责在海量文档中动态筛选与当前查询最相关的记忆片段。从架构层面来看,记忆路由的实现包含三个关键组件:文档潜在状态压缩、路由器投影器以及 Top‑k 选择机制。

文档潜在状态的压缩采用 chunk‑mean pooling 策略。具体而言,模型首先将每个文档划分为固定大小的 chunk,随后对每个 chunk 内的 token 进行平均池化,生成压缩后的键向量(K̄)、值向量(V̄)以及路由键向量(K̄ᵣ)。这种设计使得即便原始文档包含数万 token,经过压缩后也仅需保留少量的向量表示,极大降低了后续路由计算的存储开销。

路由器投影器负责将查询向量映射到与路由键相同的语义空间,以便通过相似度计算进行匹配。其实现采用 cosine similarity 作为相似度度量,计算过程中对注意力头维度进行 mean‑pooled 处理后,再在 token 维度取 max 操作,最终得到每个文档的关联分数。这一设计确保了模型能够捕捉到文档中最具区分性的特征,同时保持对局部细节的敏感性。

Top‑k 选择机制则根据路由分数从全局记忆库中挑选出得分最高的 k 个文档,将这些文档的压缩 K̄/V̄ 与查询自身的本地 K/V 进行拼接,作为后续稀疏注意力的输入。值得注意的是,路由层仅在模型的上层应用,下层仍保持独立文档处理,以确保层级间的层次化对齐。

稀疏策略的选择逻辑与实现考量

MSA 的稀疏策略并非简单的固定稀疏模式,而是基于内容可学习的动态选择。这一设计背后蕴含着对效率与效果的双重考量。

从效率角度审视,传统的全注意力机制在处理 1 亿 token 上下文时,计算量将膨胀至不可承受的规模。MSA 通过将注意力计算限定在 Top‑k 个相关文档上,将复杂度从 O (L²) 降低至接近 O (L),其中 L 为序列长度。这一改进使得在 2 张 A800 GPU 上运行 1 亿 token 推理成为现实。

从效果角度考量,可学习的路由机制相比固定稀疏模式具有更强的适应性。固定稀疏模式往往需要对稀疏比例、稀疏位置进行繁琐的手工调参,且在不同任务上表现波动较大。而 MSA 的路由层通过端到端训练,能够自动学习到与下游任务相关的记忆选择策略。实验结果显示,在 16K 到 1 亿 token 的跨幅内,模型准确率降幅始终控制在 9% 以内,展现了优异的扩展稳定性。

路由策略的另一关键设计点在于其与生成过程的深度耦合。MSA 采用自适应机制,路由选择并非一次性完成,而是与自回归生成过程交替进行。在每个生成步骤中,模型可以根据已生成的部分动态调整检索策略,这一机制被称为 Memory Interleave。该设计显著提升了模型在多跳推理任务上的表现,在 2WikiMultiHopQA 和 HotpotQA 等数据集上取得了显著超越基线 RAG 方法的成绩。

工程实现的关键细节

将理论设计转化为生产级代码需要处理诸多工程挑战,MSA 的实现在这方面提供了值得借鉴的范例。

首先是分层存储架构的设计。为了在有限 GPU 显存中支撑超大规模记忆库,MSA 采用了 GPU‑CPU 分层存储策略:路由键(K̄ᵣ)常驻 GPU 显存以加速评分计算,而实际的键值内容(K̄/V̄)则存放在主机 DRAM 中,仅在路由确定目标文档后通过异步传输机制加载到 GPU。这一设计在保持路由低延迟的同时,实现了存储容量的线性扩展。

其次是分布式评分机制的实现。Memory Parallel 模块将路由键分片到多张 GPU 上,每张 GPU 负责对本地存储的键进行本地评分,随后通过全局归约操作汇总得到全局排序。这一并行化策略有效利用了多卡算力,查询向量在广播后,各节点独立计算并返回本地最优结果,最终在根节点聚合为全局 Top‑k。

第三是位置编码的特殊处理。MSA 引入了 document‑wise RoPE 和 global RoPE 两套位置编码机制。前者为每个文档独立重置位置编号,避免了文档间位置信息的泄漏;后者则通过将查询的起始索引偏移 k 个检索块的位置来维护全局因果序。这套组合策略使得模型能够在短序列训练后自然外推到极长上下文场景,实验验证表明模型可从 64K 训练长度外推至 1 亿 token 而不损失显著性能。

在训练层面,MSA 采用了辅助路由损失的联合优化策略。主任务损失与路由辅助损失以加权方式共同监督,确保路由模块学习到有意义的文档选择能力。训练数据规模达到 158.95B token,并采用从 8K 到 64K 的课程学习策略,逐步提升上下文长度。

实践建议与参数参考

对于希望在自己的项目中引入 MSA 设计的开发者,以下参数可作为初始配置的参考。路由 Top‑k 的取值通常在 4 到 16 之间,具体数值可根据任务复杂度调整,多跳推理任务建议取值偏大。Chunk 大小的选择需权衡压缩率与信息保留度,默认配置通常采用 128 或 256 token 为单位。路由器投影的隐藏层维度建议设置为原始注意力头维度的 2 到 4 倍,以提供足够的表达能力。

Memory Interleave 的迭代次数可根据任务难度设置,通常 2 到 3 轮迭代足以覆盖多数多跳场景。每轮迭代后需重新执行路由选择,以利用已更新的上下文信息。

综合而言,MSA 的开源实现为长上下文建模提供了一个兼具理论创新与工程可行性的完整方案。其可学习路由层的架构设计、稀疏策略的动态选择逻辑以及面向大规模部署的存储与并行策略,均为后续相关研究提供了重要的技术参考。

资料来源:EverMind‑AI/MSA GitHub 仓库(https://github.com/EverMind-AI/MSA)