TTT 框架
约 1298 字大约 4 分钟
2025-01-20
美东时间周一(7月8日),一种全新的大语言模型(LLM)架构有望代替至今在AI领域如日中天的Transformer,性能也比Mamba更好。
在预印本网站arXiv上发布的一篇论文中,斯坦福大学、加州大学伯克利分校、加州大学圣地亚哥分校和Meta的学者提出了一种全新架构,希望能用机器学习模型取代RNN的隐藏状态。这个架构通过对输入token进行梯度下降来压缩上下文,被称为“测试时间训练层(Test-Time-Training layers,简称TTT层)”。“共同一作”加州大学伯克利分校的Karen Dalal表示,我相信这将从根本上改变语言模型。
但对于该论文,也有人提出质疑,认为只有30亿~70亿参数的可用演示模型才足以了解其实用性。
TTT模型原理 TTT,全称Test-Time Training(测试时训练)层,是一种全新的大语言模型(LLM)架构,其核心原理在于通过机器学习模型替代传统RNN中的隐藏状态,并利用输入token的实际梯度下降来压缩上下文信息。这一创新方法不仅简化了模型结构,更在性能上实现了显著提升。TTT层直接取代了Transformer中的自注意力机制,解锁了线性复杂度架构的潜力,使得在上下文中训练包含数百万甚至数十亿个token的大规模语言模型成为可能。
二、TTT模型核心思想 (1) 线性复杂度架构 TTT模型的关键思想在于使隐藏状态本身成为机器学习模型,更新规则成为自监督学习的一个步骤。
(2) 测试时训练机制 由于隐藏状态甚至在测试序列上也通过训练来更新,因此该层被称为测试时间训练(TTT)层。这种机制允许模型在测试时根据输入数据动态调整其内部状态,从而提高对长上下文信息的利用效率和准确性。
三、TTT模型架构 TTT模型架构主要包括以下几个部分:
TTT层:TTT层是模型的核心,它取代了传统的自注意力层。TTT层通过机器学习模型来压缩和表示上下文信息,同时利用梯度下降来更新隐藏状态。根据不同的实现方式,TTT层可以分为TTT-Linear(线性模型)和TTT-MLP(多层感知机)两种变体。
编码器:类似于Transformer架构,TTT模型也包含编码器部分。编码器负责将输入序列转换为上下文感知的表示,以便后续处理。
解码器(可选):对于需要生成输出序列的任务(如机器翻译),TTT模型还可以包含解码器部分。解码器通常也是由多个TTT层堆叠而成,用于生成目标序列。
位置编码:由于TTT模型中没有使用递归或卷积操作来捕捉位置信息,因此需要一种机制来将位置信息嵌入到输入序列中。位置编码是一种常用的方法,它使用正弦和余弦函数来生成位置编码,并将其与输入序列相结合。
训练与测试:在训练阶段,TTT模型通过标准的有监督学习方法进行训练。在测试阶段,TTT模型则利用测试时训练(TTT)机制来动态更新隐藏状态,从而实现对长上下文信息的有效利用。
四、TTT模型优势 线性复杂度:TTT模型具有线性复杂度,这意味着其计算成本随上下文长度的增加而线性增长,而不是像Transformer那样呈二次方增长。这使得TTT模型在处理长序列任务时更加高效。
高表达能力:TTT模型通过机器学习模型来压缩和表示上下文信息,因此具有更高的表达能力。这使得TTT模型能够更准确地捕捉长距离依赖关系,并在各种任务中表现出色。
动态适应性:TTT模型在测试时能够根据输入数据动态调整其内部状态,从而实现对不同上下文信息的有效适应。这种动态适应性使得TTT模型在处理复杂任务时更加灵活和准确。 综上所述,TTT模型是一种具有创新性和实用性的大语言模型架构,它通过测试时训练机制和线性复杂度架构的结合,为AI语言模型的发展开辟了新的道路。
论文链接:https://arxiv.org/abs/2407.04620 在论文上线后,作者公开了代码与 jax 以供人们训练和测试:https://github.com/test-time-training/ttt-lm-jax
还有 PyTorch 推理代码:https://github.com/test-time-training/ttt-lm-pytorch