用Java玩转深度学习:DJL实战指南

深度学习模型大多用Python开发,而服务端却多用Java,导致许多开发者不得不使用Java调用Python接口,效率低下且不够优雅。更糟糕的是,如果想在Android上进行推理,就必须使用Java。

别担心!现在,我们可以用Java直接进行深度学习了!DJL(Deep Java Library)是一个强大的开源深度学习框架,它支持模型构建、训练、推理,甚至在Android上运行。本文将带你深入了解DJL,并通过一个实战案例,教你用Java加载PyTorch模型进行图片分类。

DJL:Java深度学习的利器

DJL 的出现,为Java开发者打开了深度学习的大门。它提供了一套简洁易用的API,让Java开发者能够轻松地构建、训练和部署深度学习模型。

DJL 的优势:

  • Java 开发: 使用熟悉的 Java 语言进行深度学习开发,无需学习其他语言。
  • 跨平台支持: 支持 Windows、Linux、macOS 和 Android 等多种平台。
  • GPU 加速: 支持 GPU 加速,提升模型训练和推理速度。
  • 模型兼容性: 支持多种深度学习框架,包括 PyTorch、TensorFlow 和 MXNet。

DJL 核心 API 解密

DJL 的核心 API 包括 Criteria、Translator 和 NDArray,它们共同构成了深度学习模型的构建和操作基础。

1. Criteria:模型的定义

Criteria 类对象定义了模型的属性,例如模型路径、输入和输出类型等。

Criteria<Input, Output> criteria = Criteria.builder()
        .setTypes(Input.class, Output.class) // 定义输入和输出数据类型
        .optTranslator(new InputOutputTranslator()) // 设置输入输出转换器
        .optModelPath(Paths.get("/var/models/my_resnet50")) // 指定模型路径
        .optModelName("model/resnet50") // 指定模型文件前缀
        .build();

ZooModel<Image, Classifications> model = criteria.loadModel();

这段代码定义了一个名为 “resnet50” 的模型,并加载了它。

2. Translator:数据转换桥梁

Translator 接口定义了如何将自定义的输入输出类转换为 Tensor 类型。

private Translator<Input, Output> translator = new Translator<Input, Output>() {

    @Override
    public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
        return null;
    }

    @Override
    public Output processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
        return null;
    }
};

Translator 接口包含两个方法:

  • processInput: 将输入类对象转换为 Tensor。
  • processOutput: 将模型输出的 Tensor 转换为自定义类。

3. NDArray:Tensor 操作的利器

NDArray 类类似于 Python 中的 NumPy 数组,它提供了丰富的 Tensor 操作功能。

NDManager ndManager = NDManager.newBaseManager(); // 创建 NDManager 对象
NDArray ndArray = ndManager.create(new Shape(1, 2, 3, 4)); // 创建一个 Shape 为 (1, 2, 3, 4) 的 Tensor

DJL 提供了多种 NDArray 操作,例如:

  • 创建 NDArray
  • 变更数据类型
  • 运算(加减乘除)
  • 切片
  • 赋值
  • 翻转

实战:用 DJL 加载 PyTorch 模型进行图片分类

下面,我们将使用 PyTorch 提供的 ResNet18 模型进行图片分类。

步骤:

  1. 引入依赖: 在项目的 pom.xml 文件中添加 DJL 的依赖。
  2. 导出 PyTorch 模型: 使用 Python 将 ResNet18 模型保存为 TorchScript 模型。
  3. 创建 Translator: 定义输入为图片路径,输出为类别。
  4. 定义 Criteria: 定义模型路径、输入输出类型和 Translator。
  5. 实例化模型: 使用 Criteria 加载模型。
  6. 创建 Predictor: 使用模型创建 Predictor 对象。
  7. 进行预测: 使用 Predictor 对图片进行分类。

代码示例:

// ... (引入依赖)

// 创建 Translator
Translator<String, String> translator = new Translator<String, String>() {

    @Override
    public NDList processInput(TranslatorContext ctx, String input) throws Exception {
        // ... (读取图片,进行预处理)
        return new NDList(ndArray);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws Exception {
        // ... (获取预测结果)
        return index + "";
    }
};

// 定义 Criteria
Criteria<String, String> criteria = Criteria.builder()
        .setTypes(String.class, String.class)
        .optModelPath(Paths.get("model/traced_resnet_model.pt"))
        .optOption("mapLocation", "true")
        .optTranslator(translator)
        .build();

// 实例化模型
ZooModel model = criteria.loadModel();

// 创建 Predictor
Predictor predictor = model.newPredictor();

// 进行预测
System.out.println(predictor.predict("test/test.jpg"));

最终输出:

258

258 对应的类别为 Samoyed(萨摩耶),说明预测成功。

总结

DJL 为 Java 开发者提供了强大的深度学习能力,让我们能够使用 Java 语言进行模型构建、训练和推理。本文通过一个简单的图片分类案例,展示了如何使用 DJL 加载 PyTorch 模型进行预测。

参考文献:

希望本文能够帮助你快速入门 DJL,并开始你的 Java 深度学习之旅!

Leave a Comment