Calling a .h5 Model from Java via onnxruntime


Preface

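A few days ago I used Mirai to rebuild my earlier NSFW-detection bot (from the article "TensorFlow 鉴黄 API", i.e. TensorFlow NSFW-detection API). Mirai's native interface is Java, so I rewrote the bot in Java. The result was clunky because it was a roundabout workaround. I had to run a separate web API for the Java side to call, sending a POST request whenever an image arrives, and let Node.js (or Python) do the actual prediction.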

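Yesterday I decided this was too clunky and looked for a way to run model inference in Java itself. I tried TensorFlow for Java, but there were too few examples... So I turned to Microsoft's onnxruntime. It is a cross-platform inference engine for ONNX models with official Java bindings. DJL can also use it as a backend, but the code below calls the onnxruntime Java API directly. Java examples are still scarce (I just wanted something to copy)... but let's not be a programmer who can only copy and paste *picks nose*. (Mostly, of course, it's because I'm not very good at this.)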

And so, here is this article.

Model Conversion

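You can easily convert a Keras .h5 model to an ONNX model with the keras2onnx library. Note that keras2onnx is no longer actively maintained. If it fails with newer TensorFlow versions, use the tf2onnx route described below.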

For example:

import keras2onnx
import onnx
from keras.models import load_model
model = load_model('path to h5 model')
onnx_model = keras2onnx.convert_keras(model, model.name)
temp_model_file = 'path to save onnx model'
onnx.save_model(onnx_model, temp_model_file)

You may need to install keras2onnx first:

pip install keras2onnx

If you get the following error:

AttributeError: 'KerasTensor' object has no attribute 'graph'

Workaround:

1. Install tf2onnx:

pip install tf2onnx

2. Save the Keras model in TensorFlow's SavedModel (.pb) format. This creates a tmp_model folder in the project directory:

from keras.models import load_model
import tensorflow as tf
model = load_model("path to h5 model")
tf.saved_model.save(model, "tmp_model")

3. Convert it to an ONNX model from the command line (the path below uses Windows separators):

python -m tf2onnx.convert --saved-model .\tmp_model\ --output "model.onnx"

OK, the conversion is done. You'll find model.onnx in the project directory.

Code

I won't explain much. This is mostly a note to myself, so I'm pasting the code and going to bed. Too sleepy.

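Here I call a five-class NSFW-detection model built on the mobilenet_v2 network. See my earlier article "TensorFlow 鉴黄 API" (TensorFlow NSFW-detection API); that one was trained on resnet50 instead.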

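Every model's input is a little different: the input name may differ, and so may the required shape. Mine is named input_1. It expects a 4-D array of shape (batch, height, width, channels), i.e. NHWC, which is [1][224][224][3] here. Adjust this for your model (or change the model).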
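The code below builds the input as a nested float[1][224][224][3] array in NHWC order. Some models expect channels-first (NCHW) input instead. The following is a minimal sketch, under that assumption, of what the two layouts mean in memory. `LayoutDemo`, `toNhwc` and `nhwcToNchw` are hypothetical names, not part of onnxruntime or the original code. The sketch flattens an image into NHWC order, using the same RGB channel order and [0, 1] scaling as imageToMatrix. It then shows the index arithmetic for reordering that array into NCHW.

```java
import java.awt.image.BufferedImage;
import java.util.Arrays;

// Hypothetical helper (not part of onnxruntime or the original code):
// flattens an image into NHWC order and reorders it into NCHW.
public class LayoutDemo {

    // NHWC (batch = 1): element (y, x, c) lives at (y * W + x) * C + c
    public static float[] toNhwc(BufferedImage img) {
        int h = img.getHeight(), w = img.getWidth();
        float[] out = new float[h * w * 3];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int argb = img.getRGB(x, y); // getRGB takes (column, row)
                int base = (y * w + x) * 3;
                out[base]     = (argb >> 16 & 0xFF) / 255f; // red
                out[base + 1] = (argb >> 8 & 0xFF) / 255f;  // green
                out[base + 2] = (argb & 0xFF) / 255f;       // blue
            }
        }
        return out;
    }

    // NCHW (batch = 1): element (c, y, x) lives at (c * H + y) * W + x
    public static float[] nhwcToNchw(float[] nhwc, int h, int w, int c) {
        float[] out = new float[nhwc.length];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                for (int k = 0; k < c; k++) {
                    out[(k * h + y) * w + x] = nhwc[(y * w + x) * c + k];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF0000); // red pixel
        img.setRGB(1, 0, 0x00FF00); // green pixel
        float[] nhwc = toNhwc(img);
        System.out.println(Arrays.toString(nhwc));                      // [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
        System.out.println(Arrays.toString(nhwcToNchw(nhwc, 1, 2, 3))); // [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```

You can also pass a flat array like this to onnxruntime instead of a nested `float[][][][]`. Call `OnnxTensor.createTensor(env, FloatBuffer.wrap(data), new long[]{1, 224, 224, 3})`, or use shape `{1, 3, 224, 224}` for an NCHW model.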

package com.julym;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.PixelGrabber;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;

class AntiModel {
    private final OrtEnvironment env = OrtEnvironment.getEnvironment();
    // The session is created once and reused: building it for every image is expensive
    private OrtSession session;

    AntiModel() {
        // Load model.onnx from the resources folder (src/main/resources)
        try (InputStream in = getClass().getClassLoader().getResourceAsStream("model.onnx")) {
            if (in == null) {
                throw new IOException("model.onnx not found on the classpath");
            }
            session = env.createSession(in.readAllBytes());
            // or load it from a file path instead:
            //session = env.createSession("M:/IdeaProjects/nsfw/src/main/resources/model.onnx");
        } catch (IOException | OrtException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        try {
            BufferedImage image = ImageIO.read(new FileInputStream("D:/Desktop/Fun/test.jpg"));
            AntiModel antiModel = new AntiModel();
            System.out.println(antiModel.detector(image));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Runs the model on one image and returns the class probabilities as a JSON string
    public String detector(BufferedImage content){
        BufferedImage scaled = scaleImg(content);
        float[][][][] inputArray = imageToMatrix(scaled);
        try (OnnxTensor onnxTensor = OnnxTensor.createTensor(env, inputArray)) {
            // "input_1" is this model's input name; session.getInputNames() lists yours
            Result result = session.run(Map.of("input_1", onnxTensor));
            return pred(result, 1); // pred() closes the result
        } catch (OrtException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static BufferedImage scaleImg(BufferedImage image){
        Image scaledImg = image.getScaledInstance(224,224, Image.SCALE_FAST);
        BufferedImage img = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
        img.getGraphics().drawImage(scaledImg, 0, 0, null);
        return img;
    }

    // Converts the model output into JSON, ordered from the most to the least likely class.
    // retType == 0 returns a JSON object; any other value returns a JSON array.
    public static String pred(Result tensorResult, int retType) {
        try {
            // The output has shape [1][5]: one row of class probabilities for the single input image
            float[][] outputProbs = (float[][]) tensorResult.get(0).getValue();
            float[] probabilities = outputProbs[0];
            String[] label = {"drawings", "hentai", "neutral", "porn", "sexy"};
            // Bubble sort into ascending order, swapping each label along with its probability
            for (int i = 0; i < probabilities.length - 1; i++) {
                for (int j = 0; j < probabilities.length - 1 - i; j++) {
                    if (probabilities[j] > probabilities[j + 1]) {
                        float proTemp = probabilities[j];
                        probabilities[j] = probabilities[j + 1];
                        probabilities[j + 1] = proTemp;
                        String labelTemp = label[j];
                        label[j] = label[j + 1];
                        label[j + 1] = labelTemp;
                    }
                }
            }
            // Walk the sorted arrays backwards so the most likely class comes first
            StringBuilder result = new StringBuilder(retType == 0 ? "{" : "[");
            for (int k = probabilities.length - 1; k >= 0; k--) {
                if (retType == 0) {
                    result.append("\"").append(label[k]).append("\":").append(probabilities[k]);
                } else {
                    result.append("{\"className\":\"").append(label[k])
                            .append("\", \"probability\":").append(probabilities[k]).append("}");
                }
                if (k > 0) {
                    result.append(", ");
                }
            }
            result.append(retType == 0 ? "}" : "]");
            tensorResult.close();
            return result.toString();
        } catch (OrtException e) {
            e.printStackTrace();
        }
        return null;
    }
    // Converts the image into a [1][height][width][3] (NHWC) array, RGB values scaled to [0, 1]
    public static float[][][][] imageToMatrix(BufferedImage image){
        int width = image.getWidth();
        int height = image.getHeight();
        int[] pixels = new int[width * height];
        PixelGrabber pg = new PixelGrabber(image, 0, 0, width, height, pixels, 0, width);
        try {
            pg.grabPixels();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt flag
            e.printStackTrace();
        }

        float[][][][] ret = new float[1][height][width][3];
        // Visit every pixel, including the last column of each row
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                int pixel = pixels[row * width + col]; // packed as 0xAARRGGBB
                ret[0][row][col][0] = (pixel >> 16 & 0xFF) / 255f; // red
                ret[0][row][col][1] = (pixel >> 8 & 0xFF) / 255f;  // green
                ret[0][row][col][2] = (pixel & 0xFF) / 255f;       // blue
            }
        }
        return ret;
    // another way, reading pixels with getRGB(x, y) instead of PixelGrabber:
//        for (int i = 0; i < height; i++) {        // i = row (y)
//            for (int j = 0; j < width; j++) {     // j = column (x)
//                int rgb = image.getRGB(j, i);
//                ret[0][i][j][0] = (rgb >> 16 & 0xFF) / 255f; // red
//                ret[0][i][j][1] = (rgb >> 8 & 0xFF) / 255f;  // green
//                ret[0][i][j][2] = (rgb & 0xFF) / 255f;       // blue
//            }
//        }
    }
}
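The pred() method above ranks the five probabilities with a hand-written bubble sort that reorders the label array in place. An alternative is to sort class indices by probability with Arrays.sort and a Comparator, which leaves the model output untouched. The sketch below assumes the output has already been extracted as a plain float[] of class probabilities. `RankDemo` and `toJsonArray` are hypothetical names, and the JSON shape mirrors the retType != 0 format.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical helper (not the original pred()): ranks class labels by probability
// without reordering the model output or the label array.
public class RankDemo {

    public static String toJsonArray(float[] probabilities, String[] labels) {
        // Sort class indices by descending probability; the input arrays stay untouched
        Integer[] order = new Integer[probabilities.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed());

        // Same shape as pred(result, 1): [{"className":..., "probability":...}, ...]
        StringBuilder sb = new StringBuilder("[");
        for (int k = 0; k < order.length; k++) {
            int i = order[k];
            if (k > 0) {
                sb.append(", ");
            }
            sb.append("{\"className\":\"").append(labels[i])
              .append("\", \"probability\":").append(probabilities[i]).append('}');
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        String[] labels = {"drawings", "hentai", "neutral", "porn", "sexy"};
        float[] probs = {0.05f, 0.1f, 0.7f, 0.05f, 0.1f}; // made-up example probabilities
        System.out.println(toJsonArray(probs, labels));
    }
}
```

Arrays.sort on an object array is stable, so classes with equal probabilities keep their original label order.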

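You'll need to add the onnxruntime dependency: Maven Repository: com.microsoft.onnxruntime » onnxruntime » 1.9.0 (mvnrepository.com). The code also uses InputStream.readAllBytes() and Map.of(), so it needs Java 9 or later.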

Gradle (Kotlin DSL):

// https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime
implementation("com.microsoft.onnxruntime:onnxruntime:1.9.0")

Result: