使用Firefly在单卡V100上对Qwen1.5进行SFT和DPO训练
引言 大语言模型(LLM)的训练一直是AI领域的热点话题。随着开源模型的不断涌现,如何对这些基础模型进行进一步优化和定制化训练成为了很多研究者和开发者关注的焦点。本文将介绍如何使用Firefly框架在单张V100 GPU上对Qwen1.5-7B模型进行SFT(Supervised Fine-tuning)和DPO(Direct Preference Optimization)训练,并探讨训练过程中的关键技术点和实验结果。 Firefly简介 Firefly是一个开源的大模型一站式训练框架,支持对各种主流大模型进行预训练、指令微调和DPO等训练。它支持全量参数、LoRA、QLoRA等多种训练方式,可以适应不同的硬件条件和训练需求。Firefly框架兼容包括Gemma、Qwen1.5、MiniCPM、Mixtral-8x7B、Mistral、Llama等在内的绝大多数主流大模型。 Qwen1.5模型介绍 Qwen1.5是阿里巴巴在2024年春节前开源的大语言模型,支持32K的上下文长度。该模型可以看作是Qwen2的beta版本,未来还会有Qwen2的正式版本发布。从各项评测结果来看,Qwen1.5各个尺寸的模型都显著优于同量级的Llama2。在2024年2月的SuperCLUE大模型榜单中,Qwen1.5也展现出了非常优秀的表现,在开源模型中处于领先地位。 大模型训练的三个阶段 大模型的训练通常可以分为以下三个主要阶段: DPO简介 在RLHF阶段,传统的方法如PPO(Proximal Policy Optimization)存在流程繁琐、显存需求大等问题。相比之下,DPO(Direct Preference Optimization)方法绕过了奖励模型的构建,可以直接使用人类偏好数据对模型进行训练,且在训练时仅需加载策略网络和参考网络,极大地节省了显存占用。 DPO的训练数据包含三个字段:prompt、chosen和rejected。其损失函数计算过程具有对称性,公式如下: 其中,r_θ表示策略网络,r_θ_ref表示参考网络,β是温度系数,σ是sigmoid函数。 在代码实现中,DPO损失函数的计算过程大致如下: 实验设置 本实验在Qwen1.5-7B的基础上,使用Firefly框架进行了SFT和DPO两阶段的训练。整个训练流程仅使用一张V100 GPU,采用QLoRA技术,在所有Linear层都添加adapter以提升训练效果。两个阶段均使用英文数据进行训练。 对话模板 Firefly与Qwen1.5官方的对话模板保持一致: SFT阶段设置 使用Firefly对Qwen1.5进行SFT的启动命令: SFT阶段的主要参数设置如下: DPO阶段设置 使用Firefly对Qwen1.5进行DPO的启动命令: DPO阶段采用ultrafeedback数据集,主要参数设置如下: 实验结果与分析 模型评测 在Open LLM Leaderboard上对模型进行评测,Firefly训练的模型表现显著优于官方的Qwen1.5-7B-Chat、Gemma-7B-it等模型。具体来说: 经过DPO之后,模型的平均分还有接近1分左右的提升。这说明Firefly框架在单卡V100上通过SFT和DPO训练,成功地提升了Qwen1.5模型的性能。 DPO训练指标分析 在DPO训练过程中,我们关注了几个重要的训练指标: 这些指标的变化趋势都表明,DPO训练确实帮助模型学习到了人类的偏好,提升了模型输出的质量。 结论与展望 通过使用Firefly框架在单卡V100上对Qwen1.5-7B模型进行SFT和DPO训练,我们成功地提升了模型的性能,在Open LLM Leaderboard上取得了优于原始Qwen1.5-7B-Chat和Gemma-7B-it等模型的成绩。这个实验结果表明: 未来的研究方向可以包括: 总的来说,Firefly框架为大模型的定制化训练提供了一个强大而灵活的工具,为AI研究者和开发者开辟了新的可能性。我们期待看到更多基于Firefly的创新应用和研究成果。 参考文献