在现代人工智能领域,推理大型语言模型 (LLM) 已经成为一个重要的应用场景。 GitHub 上的项目 mukel/llama2.java 提供了一种使用纯 Java 代码进行 Llama 2 推理的简洁实现。本文将详细介绍该项目的背景、构建方法及性能表现。
背景介绍
Llama 2 是由 Andrej Karpathy 开发的一个非常简单的 LLM 推理实现。该项目的 Java 版本旨在提供教育价值,并用于在 JVM 上测试和调整编译器优化,特别是针对 Graal 编译器的优化。这一 Java 移植版本最初参考了 llama2.scala 。
构建与运行
要构建和运行该项目,您需要 Java 21+,特别是其中的 MemorySegment mmap-ing 功能。以下是具体的构建步骤:
- 下载必要的文件:
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M. bin✅
- 手动构建与运行:
javac --enable-preview -source 21 --add-modules=jdk.incubator.vector Llama2.java java --enable-preview --add-modules=jdk.incubator.vector Llama2 stories15M. bin✅
- 使用 JBang 直接运行:
jbang Llama2.java stories15M. bin✅
- 使用 Makefile 和 run.sh 脚本:
make # 可选,run.sh 已经包含了 make JAVA_HOME=$GRAALVM_HOME \ JAVA_RUNTIME_OPTIONS=-Djava.util.concurrent.ForkJoinPool.common.parallelism=8 \ ./run.sh stories15M. bin✅
生成本地镜像
使用 GraalVM 可以创建一个独立的本地镜像:
JAVA_HOME=$GRAALVM_HOME NATIVE_IMAGE_OPTIONS="-march=native" make native-image
./llama2 stories15M. bin✅
或者使用 Profile-Guided Optimizations (PGO):
JAVA_HOME=$GRAALVM_HOME \
NATIVE_IMAGE_OPTIONS="--pgo-instrument -march=native --initialize-at-build-time=Llama2 -Dllama2.VectorAPI=false" \
make native-image
# 生成默认的 iprof 配置文件
./llama2 -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 stories15M. bin✅
# 构建优化后的镜像
JAVA_HOME=$GRAALVM_HOME \
NATIVE_IMAGE_OPTIONS="--pgo -march=native --initialize-at-build-time=Llama2 -Dllama2.VectorAPI=false" \
make native-image
# 优化后的运行速度应该比普通镜像快约 2 倍
./llama2 stories15M. bin✅
性能表现
以下是该项目在不同配置下的性能测试结果 (基于 AMD Ryzen 3950X 64GB,Arch Linux):
单线程测试
模型 | 每秒处理 Token | 相对于 llama2.c 的加速 | 实现 |
---|---|---|---|
stories15M. bin✅ | 363 | 1.0 | llama2.c |
stories15M. bin✅ | 237 | 0.65 | llama2.java |
stories110M. bin✅ | 51.71 | 1.0 | llama2.c |
stories110M. bin✅ | 42.20 | 0.81 | llama2.java |
llama2_7B. bin✅ | 0.92 | 1.0 | llama2.c |
llama2_7B. bin✅ | 0.88 | 0.95 | llama2.java |
多线程测试
模型 | 每秒处理 Token | 相对于 llama2.c 的加速 | 实现 |
---|---|---|---|
stories15M. bin✅ | 1233 | 1.0 | llama2.c |
stories15M. bin✅ | 438 | 0.35 | llama2.java |
stories110M. bin✅ | 90 | 1.0 | llama2.c |
stories110M. bin✅ | 80 | 0.88 | llama2.java |
llama2_7B. bin✅ | 1.68 | 1.0 | llama2.c |
llama2_7B. bin✅ | 1.65 | 0.98 | llama2.java |
需要注意的是,Java 版本在多线程情况下的性能提升并不显著,这主要是由于内存带宽限制所致。
结论
mukel/llama2.java 项目展示了如何使用纯 Java 代码实现 Llama 2 推理,并在一定程度上达到了与原始 C 实现相当的性能。尽管当前版本的性能尚未完全优化,但其作为教育工具和编译器优化测试平台已经展现出巨大潜力。