大模型Memory如何改进以适应需求?
摘要:上一章畅想里面我们重点提及了大模型的记忆模块,包括模型能否持续更新记忆模块,模型能否把持续对记忆模块进行压缩更新在有限的参数中存储更高密度的知识信息,从而解决有限context和无限知识之间的矛盾。这一章我们分别介绍两种方案,一种是基于模型
上一章畅想里面我们重点提及了大模型的记忆模块,包括模型能否持续更新记忆模块,模型能否把持续对记忆模块进行压缩更新在有限的参数中存储更高密度的知识信息,从而解决有限context和无限知识之间的矛盾。这一章我们分别介绍两种方案,一种是基于模型结构的Google提出的Titan模型结构,另一种是基于外挂知识库表征对齐的Kbalm
Titan
Titans: Learning to Memorize at Test Time
Titan的出发点有两个
传统Transformer的局限性:虽然Transformer在序列建模中表现优异,但其注意力机制是平方级的复杂度,因此限制了输入的上文的长度。
改良线性Transformer的不足:虽然线性Transformer通过核函数近似注意力,降低了复杂度,但对历史记忆的处理是把记忆压缩成固定的向量矩阵,本质还是上一章我们闲聊提到的线性表达,没有更深层次的压缩和抽象。
所以Titan主要探索了基于多层MLP的记忆存储模块和当前模型的融合。核心目标就似乎构建动态、可学习的记忆系统,结合短期(注意力)和长期记忆(神经记忆),模拟人脑的多层次记忆机制。
这里我们重点看下记忆模块的实现,分成3个部分,记忆存储结构、记忆更新方式和记忆检索。
核心
首先是记忆存储结构,论文没有在这个部分做深入的探索,只考虑了简单的多层MLP,更多对于存储网络结构的探索留给后人完成。这里多层MLP本身结构足够简单,同时对比矩阵又提供了多层压缩和抽象的能力。类似之前人们发现Transformer的FFN层中知识是按键值对方式进行存储,这里Titan也使用键值对来存储记忆,Key,Val分别使用多层MLP映射得到,而映射的MLP通过以下记忆更新进行持续学习。
其次是记忆更新方式,这里论文采用在线学习来对记忆模块的参数进行持续更新,损失函数就是存储记忆的Key,Value之间的空间距离,而梯度更新时,Titan同时考虑到了动态遗忘机制和动量机制,其中
\(\theta_t\)后面的部分是序列当前输入的梯度大小,Titan认为这里的梯度大小反映了信息的重要程度,梯度越大对记忆的更新幅度也就越大
\(\alpha_t\)是类似RNN的遗忘门,支持在输入超长上文时动态选择对历史信息进行遗忘
\(\eta_t\)类似momentun,保证记忆更新的方向对持续性,同时也可以理解为历史记忆的时间衰减参数
再就是记忆获取方式,既然存储是KV所以信息获取也是一样,也就是直接冻结参数,把输入喂进MLP得到的就是该信息相关的记忆存储了\(M(q_t)\)
如何使用记忆模块
Titans给出了三种不同的记忆模块的使用方案
MAC(Memory as Context)
长期记忆的输出作为上下文与当前输入拼接 ,显式引入历史记忆作为全局上下文。 也就是把超长上文分段,使用最后一个chunk作为query,其余为历史记忆,使用query去获取历史记忆,并和当前chunk拼接作为上文。
公式:\(\hat{x}_t = \text{Concat}(h_t, p, x_t)\),其中\(h_t = \mathcal{M}^*(x_t)\)
适合需要全局检索例如如多文档问答,以及依赖全局高度一致性的任务例如长文本写作任务像小说写作。
MAG(Memory as Gate)
输入通过滑动窗口处理短期信息,同时从记忆模块提取长期偏好,通过门控机制动态融合注意力输出(短期依赖)与记忆输出(长期偏好)
公式:\(o_t = \sigma(W_g) \odot \text{Attn}(x_t) + (1-\sigma(W_g)) \odot \mathcal{M}(x_t)\)。
MAL(Memory as Layer)
将记忆模块作为独立层,与注意力模块串联(如Memory → Attention或Attention → Memory),有点类似Mamba等hybrid模型。
公式:\(y_t = \text{Attn}(\mathcal{M}(x_t))\) 或 \(y_t = \mathcal{M}(\text{Attn}(x_t))\)。
Kbalm
KBLAM: KNOWLEDGE BASE AUGMENTED LANGUAGE MODEL
kbalm则是从知识压缩入手,上一章我们提到当前RAG的一个问题在于,检索回来的内容因为是平铺拼接的因此上文很长,但信息量却不高,虽然当前的大模型已经能支持越来越长的上文输入,但是越长的上文确实带来更差的推理效果,更慢的首Token延时,更高的内存占用等等问题。
