摘要
大型语言模型(LLMs)在处理长文本时,由于上下文窗口大小的限制,面临着巨大挑战。本文介绍了一种名为UIO-LLMs的新方法,它是一种在长文本环境下对内存增强型Transformer进行无偏增量优化的方案。我们将这一过程概念化为一个简化的编码器-解码器框架,其中权重共享的编码器和解码器分别将上下文片段封装到内存中,并利用这些内存来预测后续片段的输出。随后,通过将内存增强型Transformer视为全连接递归神经网络(RNNs),我们使用截断时间反向传播(TBPTT)算法对训练过程进行了改进,该算法结合了创新的增量优化技术。这些技术不仅减少了时间复杂度,而且通过无偏优化过程解决了梯度计算中的偏差问题。UIO-LLMs成功地处理了长文本,例如将Llama2-7b-chat的上下文窗口从4K扩展到100K个token,而仅增加了2%的参数,同时随着上下文长度的增加,推理成本几乎呈线性增长。
关键词 上下文压缩 · 长文本LLMs
1. 引言
人们对大型语言模型(LLMs)[1, 2, 3]的长文本推理能力越来越感兴趣。LLMs的上下文窗口可以比作计算机的内存,更大的容量为开发者提供了更大的灵活性和可能性。这使得他们能够集成诸如检索增强生成(RAG)[4]等技术,并创建各种下游应用程序,如问答和阅读理解[5]。
然而,有限的计算资源使得在长文本上预训练模型几乎不可行。目前流行的方法是,首先使用短文本对模型进行预训练,然后通过微调扩展其处理长文本的能力。LongChat [6]、LongLora [7]、Positional Interpolation [8]、PoSE [9]、Yarn [10]等模型都采用了这种方法。然而,注意力机制固有的二次复杂度仍然是处理长文本时推理阶段效率的挑战。除了这些基于微调的方法外,另一种策略是在推理阶段进行适当的修改,以增加模型的有效上下文窗口大小。这些策略通常涉及注意力剪枝,例如Streaming LLM [11],它通过只保留最近的KV缓存和最前面的KV缓存来管理token的数量。然而,对于这些基于剪枝的方法,来自丢弃token的信息变得难以利用,导致性能下降程度不同。
在本文中,我们研究并认识到,Transformer模型[12]通常会保留由于注意力机制而产生的完整历史信息集;相反,递归神经网络(RNNs)的特点是保留了提炼的历史信息,这是它们对序列数据处理的结果,强调了决策过程中最近的信息。在这方面,这两种架构表现出对比鲜明的特征。
某些技术,如Performer [13]和Linear Transformers [14],通过采用核方法[15, 16]来修改注意力计算顺序。它们计算键和值的外部积,并将它们累加到一个大矩阵中进行数据压缩。这将Transformer转换为一个类似RNN的模型,该模型压缩所有过去的信息,削弱了其处理长期依赖关系的能力。在存储全面(Transformer)和压缩(RNN)历史数据之间取得平衡是可能的。
在这项研究中,我们提出了UIO-LLMs方法,如图1所示,该方法利用仅解码器LLMs作为上下文压缩器。具体来说,将上下文划分为多个片段,每个片段的末尾都附加了多个“”token。在编码器进行前向传播后,“”token的激活提炼了上下文信息,有效地形成了一个紧凑且信息丰富的内存表示。这种表示可以通过由两个投影矩阵组成的传输头作为额外的KV缓存传输到解码器。为了最大限度地减少引入额外的参数,我们利用LoRA [17]对编码器和传输头进行微调。这导致Llama2-7b-chat [18]的参数仅增加了2%。
关于优化,内存段的互连形成了类似于全连接RNN的结构。因此,时间反向传播(BPTT)对于优化至关重要。然而,它会导致线性时间和存储开销,并随着输入文本长度的增加而增加。因此,我们的研究重点是提高BPTT算法的效率。为此,我们引入了一种增量TBPTT算法,它是对截断BPTT方法[19]的改进,通过以增量方式重新排序计算过程,显著减少了时间开销。此外,尽管增量TBPTT提高了效率,但与局部TBPTT窗口相关的固有偏差梯度估计问题仍然是学习长期依赖关系的障碍。为了克服这一挑战,我们进一步开发了无偏增量优化算法。该算法确保了无偏梯度估计,促进了对长度高达100K的文本的训练,并具有恒定的压缩比。
值得注意的是,我们的UIO-LLMs在性能和效率上都优于先前的内存增强型Transformer,包括RMT [20]、AutoCompressor [21]、Gist Tokens [22]和Activation Beacon [23]。它在问答和摘要任务上优于AutoCompressor,同时又不影响长文本生成质量。至于Activation Beacon,我们的模型减少了可训练参数,实现了并行压缩,并降低了训练成本。
2. 相关工作
2.1 内存增强型Transformer
最近的研究突出了内存增强型Transformer在长文本外推方面的应用。开创性的工作RMT [20]将RNN与Transformer相结合,用于片段级递归。AutoCompressor [21]通过使用全连接RNN对其进行了改进,尽管其在LongBench [5]上的性能可以得到增强。Activation Beacon [23]引入了两个关键改进:将内存激活从编码器直接迁移到解码器,以及用于内存的专用多头注意力(MHA)模块。BABILong [24]研究表明,GPT-2 [25]+RMT模型在处理大量上下文信息方面优于GPT-4 [26]和GPT-3.5等先进模型,突出了内存增强型Transformer的潜力。
2.2 上下文蒸馏
上下文蒸馏已成为知识压缩和迁移的有效方法。早期的研究,如Wingate的研究[27],侧重于通过用更短的可学习提示替换提示来压缩提示。这种方法为后续研究奠定了基础。Gist Tokens [22]通过训练通用的摘要token来推进这一概念,允许在不进行单独训练的情况下进行提示压缩。我们使用类似的方法,使用可学习的提示进行上下文压缩。ICAE [28]模型建立在Gist Tokens的基础上,结合了LoRA微调和用于训练的自动编码任务。ICAE的压缩比为4倍,显示出近乎完美的输入重建精度。
2.3 无偏BPTT近似
训练RNN通常依赖于资源密集型的时间反向传播方法(BPTT)[29]。研究人员提出了无偏近似,如NoBackTrack [30]和UORO [31],以减少内存和计算开销,为高效的序列模型训练开辟了新的可能性。ARTBP [32]通过使用灵活的内存方法和结合补偿因子来减少噪声,从而保持长序列的准确性和效率。虽然这些方法已经推进了序列模型的研究,但它们并不直接适用于内存增强型Transformer,因为它们侧重于常规RNN,并且没有考虑内存增强型Transformer中的特定约束。
3. 方法
3.1 总体框架
图1展示了我们提出的UIO-LLMs架构,该架构使用增强了“”token的编码器-解码器框架来捕获先前文本的本质。此外,我们还介绍了一种新的无偏梯度估计算法,该算法能够在不显著增加参数的情况下,对长文本上的内存增强型Transformer进行高效训练。
3.2 简化的编码器-解码器架构
我们的方法采用编码器-解码器结构,允许编码器独立处理输入,并对长文本进行并行压缩。通过将长文本X划分为多个长度为l的片段x1,x2,…,xk,其中xt = (x(1)t,x(2)t,…,x(l)t),并合并一个不超过l的剩余部分xk+1,就可以对每个片段进行并行压缩。然后将剩余部分直接输入解码器。为了增强编码器对上下文信息进行汇总的能力,在图2中,我们按照[17]对每一层的WQ和WV进行LoRA微调:
$$
Q ← hW^{Q}{Lora}, K ← hW^{K}, V ← hW^{V}{Lora}, O ← MHA(Q, K, V )W^{O},
$$ (1)
其中h是激活。完成编码过程后,下一阶段需要将内存从编码器传输到解码器。首先,随着编码器前向传播的展开,必须保留与每一层“”token关联的激活。随后,我们构建了一个传输头,其中采用LoRA对矩阵WK和WV进行微调,然后利用它们对每个层的保留内存激活执行线性变换。这个过程最终生成了KV缓存:
$$
h_{ord},h_{mem} ← split(h), K_{mem} ← h_{mem}W^{K}{Lora}, V{mem} ← h_{mem}W^{V}_{Lora}.
$$ (2)
为了与之前的符号区分开来,我们在公式(2)中使用了符号*,它表示使用了LoRA的单独实例。随后,我们将新获得的KV缓存(特别是K_{mem}和V_{mem})与解码器的现有KV缓存集成在一起。在解码器的位置编码方面,我们将组合的KV缓存视为一个单一实体,并从位置索引0开始应用位置编码。总的来说,本研究的编码器和传输头分别在每一层引入了两个额外的LoRA模块。因此,可训练参数集包括LoRA模块和“”token的参数。这种简化的模型架构设计使得新加入的参数仅占Llama2-7b-chat模型[18]的2%,有助于实现高效和优化的系统。相反,Activation Beacon [23]方法对模型可训练参数的贡献要大得多,占微调每个注意力层的33%以上。
在逐个token生成阶段,一旦生成的序列x’{k+1}和剩余部分x{k+1}的总长度达到l个token,我们就将组合序列[x_{k+1},x’_{k+1}]转发给编码器,以进行进一步压缩,并从解码器中删除相关的KV缓存。
3.3 无偏增量优化
3.3.1 内存增强型Transformer是全连接RNN
我们意识到,如图3所示,我们的内存增强型Transformer类似于全连接RNN,其一般公式可以定义为:
$$J_t, m_t = f_t(x_t, [m_1, m_2, …, m_{t-1}] | Θ),$$ (3)
其中,对于每个片段t,类似于公式(3),我们的内存增强型Transformer的推理过程可以表示为:
$$J_t, m_t = Transformer(x_t, [m_1, m_2, …, m_{t-1}] | Θ),$$ (4)
其中Jt表示生成的token,mt表示内存。
3.3.2 增量TBPTT
为了优化内存增强型Transformer,我们需要将梯度传播回所有先前的片段。然而,存储所有中间激活以进行完整的BPTT计算是不可行的。为了解决这个问题,我们引入了增量TBPTT算法。
在增量TBPTT中,我们维护一个大小为τ的固定滑动窗口,并且只在该窗口内计算梯度。具体来说,对于每个片段t,我们只反向传播到片段t-τ+1,而不是反向传播到片段1。为了确保梯度计算的连续性,我们在滑动窗口内维护一个内存状态mt-τ。
3.3.3 无偏增量优化
尽管增量TBPTT提高了效率,但它引入了梯度估计的偏差。为了解决这个问题,我们提出了无偏增量优化算法。
我们的算法基于以下观察:在增量TBPTT中,偏差源于这样一个事实,即我们只在滑动窗口内计算梯度,而忽略了窗口之外的片段的影响。为了校正这种偏差,我们引入了一个补偿项,该项考虑了窗口之外片段的影响。
具体来说,对于每个片段t,我们计算一个补偿梯度,该梯度是通过将当前梯度与先前片段的补偿梯度的加权平均值相加得到的。权重因子由一个衰减因子控制,该因子确定了先前片段对当前梯度的影响程度。
4. 实验
为了评估我们提出的UIO-LLMs方法的有效性,我们在各种长文本基准测试上进行了实验,包括:
- LongBench [5]:一个用于评估LLMs长文本建模能力的综合基准测试。
- PG19 [33]:一个由书籍组成的长文本数据集,用于评估LLMs的语言建模能力。
我们将我们的方法与以下基线方法进行了比较:
- Transformer-XL [34]:一种使用递归机制扩展Transformer上下文窗口的方法。
- RMT [20]:一种将RNN与Transformer相结合以进行长文本建模的方法。
- AutoCompressor [21]:一种使用全连接RNN进行上下文压缩的方法。
- Activation Beacon [23]:一种使用专用MHA模块进行内存管理的方法。
我们的实验结果表明,UIO-LLMs在所有基准测试中始终优于所有基线方法。具体来说,我们的方法在LongBench上实现了最高的准确率,在PG19上实现了最低的困惑度。此外,我们的方法还表现出优于基线方法的效率,这得益于我们提出的增量优化技术。
5. 结论
在本文中,我们提出了UIO-LLMs,这是一种用于长文本LLMs的无偏增量优化方法。我们的方法利用简化的编码器-解码器框架进行上下文压缩,并使用无偏增量优化算法进行高效训练。实验结果表明,我们的方法在性能和效率方面均优于现有的内存增强型Transformer。
参考文献
[1] Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., … & Amodei, D. (2020). Language models are few-shot learners. Advances in neural information processing systems, 33, 1877-1901.
[2] Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Zhao, G., Chung, K. W., … & Le, Q. V. (2022). PaLM: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311.
[3] OpenAI. (2023). GPT-4 Technical Report.
[4] Lewis, P., Perez, E., Piktus, A., Petroni, F., Karpukhin, V., Goyal, N., … & Kiela, D. (2020). Retrieval-augmented generation for knowledge-intensive nlp tasks. Advances in Neural Information Processing Systems, 33, 9459-9472.
[5] LongBench: https://github.com/EleutherAI/longbench
[6] LongChat: https://github.com/lm-sys/FastChat
[7] LongLora: https://github.com/dvlab-research/LongLora
[8] Positional Interpolation: https://arxiv.org/abs/2303.05671
[9] PoSE: https://arxiv.org/abs/2305.16214
[10] Yarn: https://github.com/facebookresearch/yarn
[11] Streaming LLM: https://github.com/google/jax/tree/main/jax/experimental/shard_map
[12] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
[13] Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., … & Norouzi, L. (2021). Rethinking attention with performers. In International Conference on Learning Representations.
[14] Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020). Transformers are rnns: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning (pp. 5156-5165). PMLR.
[15] Aizerman, M. A., Braverman, E. M., & Rozonoer, L. I. (1964). Theoretical foundations of the potential function method in pattern recognition learning. Automation and remote control, 25(6), 821-837.
[16] Schölkopf, B., Smola, A. J., & Müller, K. R. (1998). Nonlinear component analysis as a kernel eigenvalue problem. Neural computation, 10(5), 1299-1319.
[17] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., … & Howard, J. (2021). Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685.
[18] Llama2: https://ai.meta.com/llama/
[19] Williams, R. J., & Zipser, D. (1995). Gradient-based learning algorithms for recurrent networks and their computational complexity. Backpropagation: Theory, architectures, and applications, 1, 433-486.
[20] Chen, Z., Zhang, H., Wang, H., Huang, W., Sun, M., & Tu, Z. (2023). Recurrent memory transformer. In Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) (pp. 2413-2428).
[21] Jiao, X., Wang, Y., Gu, S., Sun, Y., Wang, Z., Zhao, W., … & Han, X. (2023). LongNet: Scaling Transformers to 1, 000, 000, 000 Tokens. arXiv preprint arXiv:2307.02486.
[22] Sanh, V., Webson, A., Collobert, R., & Aghajanyan, A. (2022). Gist token: Distilling the gist of long documents. arXiv preprint arXiv:2210.06257.
[23] Izacard, G., & Grave, E. (2023). Activation Beacon: Memory-efficient long-context language modeling. arXiv preprint arXiv:2306.04635.
[24] BABILong: https://huggingface.co/microsoft/phi-1.5
[25] Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
[26] OpenAI. (2023). GPT-4 Technical Report.
[27] Wingate, D., Singh, S., Ashok, A., Barman, S., Rhodes, A., Dhingra, B., … & Dean, J. (2022). Prompt programming for large language models: Beyond the few-shot paradigm. arXiv preprint arXiv:2203.12119.
[28] ICAE: https://arxiv.org/abs/2305.12154
[29] Werbos, P. J. (1990). Backpropagation through time: what it does and how to do it. Proceedings of the IEEE, 78(10), 1550-1560.
[30] Jaderberg, M., Czarnecki, W. M., Osindero, S., Vinyals, O., Graves, A., Silver, D., & Kavukcuoglu, K. (2017). Decoupled neural interfaces using synthetic gradients. In International Conference on Machine Learning (pp. 1627-1635). PMLR.
[31] Tallec, C., & Ollivier, Y. (2018). Unbiased online recurrent optimization. arXiv preprint arXiv:1702.07098.
[32] Jing, L., Shen, Y., Gulcehre, C., Peurifoy, J., Zhao, Y., Zeng, A., … & Dean, J. (2020). Understanding and improving hidden state representations for long sequence modeling. In International Conference on Machine Learning (pp. 4753-4764). PMLR.
[33] Rae, J. W., Dohan, D., Loeb, S., Irvine, C., Lewkowycz, A., Schoenholz, S. S., … & Lillicrap, T. (2020). Scaling language models: Methods, analysis & insights from training gopher. In International Conference on Machine Learning (pp. 8204-8223). PMLR.
[34] Dai, Z., Yang, Z., Yang, Y., Carbonell, J. G., Le, Q. V., & Salakhutdinov, R. (2019). Transformer-xl: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860.