Calling an .h5 model from Java via onnxruntime
Preface
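A few days ago I used Mirai to rebuild my earlier NSFW-detection bot (TensorFlow NSFW Detection API). Mirai's native interface supports Java, so I rebuilt the bot in Java. The result was clunky, though, because it relied on a roundabout workaround: a separate WebAPI had to be running so that Java could call it (sending a POST request whenever an image arrived) and let Node.js (or Python) do the prediction.
Yesterday I decided that was too clunky and looked for a way to run model inference directly in Java. I tried TensorFlow for Java, but there were too few examples. Then I turned to Microsoft's onnxruntime, a cross-platform inference library for ONNX models (DJL can also use it as one of its engines). Java examples are still scarce (I only wanted to copy and paste), but I'd rather not be a programmer who can only Ctrl+C and Ctrl+V (mostly, of course, I'm just not that good).
And that is how this post came about.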
Model conversion
The keras2onnx library makes it easy to convert a Keras .h5 model into an ONNX model.
For example:
import keras2onnx
import onnx
from keras.models import load_model

model = load_model('path to h5 model')
onnx_model = keras2onnx.convert_keras(model, model.name)
temp_model_file = 'path to save onnx model'
onnx.save_model(onnx_model, temp_model_file)
You may need to install keras2onnx first:
pip install keras2onnx
If you run into the following error:
AttributeError: 'KerasTensor' object has no attribute 'graph'
The fix:
1. Install tf2onnx
pip install tf2onnx
2. Save the Keras (.h5) model in TensorFlow's SavedModel (pb) format (this creates a tmp_model folder in the project directory)
from keras.models import load_model
import tensorflow as tf

model = load_model("path to h5 model")
tf.saved_model.save(model, "tmp_model")
3. Convert it to an ONNX model by running the following command in a console
python -m tf2onnx.convert --saved-model .\tmp_model\ --output "model.onnx"
OK, the conversion is done. You will find a model.onnx in the project directory.
Code
I can't be bothered to explain much; this is mostly a note to myself anyway, so I'll just paste the code and go to bed. Too tired.
Here I call a five-class NSFW model built on the mobilenet_v2 network. See my earlier article, TensorFlow NSFW Detection API, although that model was trained with resnet50.
Every model's input is a little different: the input name may differ, and so may the required dimensions. Mine is named input_1 and expects a 4-D tensor (batch, height, width, channel). Adjust this to match your model (or change the model).
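To make the 4-D input concrete, here is a minimal standard-library sketch of one way to turn an image into a (batch, height, width, channel) float array. The class and method names (`NhwcSketch`, `resize`, `toNhwc`) are made up for illustration, and the division by 255 assumes the model was trained on pixels scaled to [0, 1], as in the code in this post; your model may expect a different normalization.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

// Hypothetical helper, not part of onnxruntime.
final class NhwcSketch {

    // Resize to size x size by drawing straight into an RGB buffer.
    static BufferedImage resize(BufferedImage src, int size) {
        BufferedImage dst = new BufferedImage(size, size, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = dst.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, size, size, null);
        g.dispose();
        return dst;
    }

    // Build a [1][height][width][3] array with channels in R, G, B order.
    static float[][][][] toNhwc(BufferedImage img) {
        int h = img.getHeight();
        int w = img.getWidth();
        float[][][][] out = new float[1][h][w][3];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = img.getRGB(x, y); // x is the column, y is the row
                out[0][y][x][0] = ((rgb >> 16) & 0xFF) / 255f; // red
                out[0][y][x][1] = ((rgb >> 8) & 0xFF) / 255f;  // green
                out[0][y][x][2] = (rgb & 0xFF) / 255f;         // blue
            }
        }
        return out;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(640, 480, BufferedImage.TYPE_INT_RGB);
        float[][][][] t = toNhwc(resize(img, 224));
        System.out.println(t.length + " x " + t[0].length + " x "
                + t[0][0].length + " x " + t[0][0][0].length);
    }
}
```

Note that the row index comes before the column index in the array, while `getRGB` takes the column first; mixing the two up is an easy way to feed the model a transposed image.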
package com.julym;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;

import javax.imageio.ImageIO;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.PixelGrabber;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Map;

class AntiModel {
    public byte[] model;

    AntiModel() {
        try {
            // Load the ONNX model from the classpath resources
            model = getClass().getClassLoader().getResourceAsStream("model.onnx").readAllBytes();
            // model = new FileInputStream("M:/IdeaProjects/nsfw/src/main/resources/model.onnx").readAllBytes();
        } catch (IOException | NullPointerException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        try {
            BufferedImage image = ImageIO.read(new FileInputStream("D:/Desktop/Fun/test.jpg"));
            AntiModel antiModel = new AntiModel();
            System.out.println(antiModel.detector(image));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public String detector(BufferedImage content) {
        BufferedImage scaled = scaleImg(content);
        float[][][][] inputArray = imageToMatrix(scaled);
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // try-with-resources closes the result, tensor and session even on failure
        try (OrtSession session = env.createSession(model);
             OnnxTensor onnxTensor = OnnxTensor.createTensor(env, inputArray);
             Result result = session.run(Map.of("input_1", onnxTensor))) {
            return pred(result, 1);
        } catch (OrtException e) {
            e.printStackTrace();
        }
        return null;
    }

    // Scale the image to the 224x224 input size the model expects
    public static BufferedImage scaleImg(BufferedImage image) {
        Image scaledImg = image.getScaledInstance(224, 224, Image.SCALE_FAST);
        BufferedImage img = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
        Graphics g = img.getGraphics();
        g.drawImage(scaledImg, 0, 0, null);
        g.dispose();
        return img;
    }

    // retType == 0 returns a JSON object; any other value returns a JSON array.
    // The caller owns tensorResult and is responsible for closing it.
    public static String pred(Result tensorResult, int retType) {
        try {
            float[][] outputProbs = (float[][]) tensorResult.get(0).getValue();
            float[] probabilities = outputProbs[0];
            String[] label = {"drawings", "hentai", "neutral", "porn", "sexy"};
            // Bubble sort in ascending order, swapping the labels alongside
            for (int i = 0; i < probabilities.length; i++) {
                for (int j = 0; j < probabilities.length - 1 - i; j++) {
                    if (probabilities[j] > probabilities[j + 1]) {
                        float proTemp = probabilities[j];
                        probabilities[j] = probabilities[j + 1];
                        probabilities[j + 1] = proTemp;
                        String labelTemp = label[j];
                        label[j] = label[j + 1];
                        label[j + 1] = labelTemp;
                    }
                }
            }
            // Emit the entries from the most probable to the least probable
            StringBuilder result = new StringBuilder(retType == 0 ? "{" : "[");
            for (int i = probabilities.length - 1; i >= 0; i--) {
                if (retType == 0) {
                    result.append("\"").append(label[i]).append("\":").append(probabilities[i]);
                } else {
                    result.append("{\"className\":\"").append(label[i])
                          .append("\", \"probability\":").append(probabilities[i]).append("}");
                }
                if (i > 0) {
                    result.append(", ");
                }
            }
            return result.append(retType == 0 ? "}" : "]").toString();
        } catch (OrtException e) {
            e.printStackTrace();
        }
        return null;
    }

    // Convert the image to a [1][height][width][3] array of RGB values in [0, 1]
    public static float[][][][] imageToMatrix(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        int[] pixels = new int[width * height];
        PixelGrabber pg = new PixelGrabber(image, 0, 0, width, height, pixels, 0, width);
        try {
            pg.grabPixels();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        float[][][][] ret = new float[1][height][width][3];
        // Iterate over every column; the original wrapped at width - 1,
        // which left the last column of each row unfilled.
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                int pixel = pixels[row * width + col];
                ret[0][row][col][0] = (pixel >> 16 & 0xFF) / 255f; // red
                ret[0][row][col][1] = (pixel >> 8 & 0xFF) / 255f;  // green
                ret[0][row][col][2] = (pixel & 0xFF) / 255f;       // blue
            }
        }
        return ret;
        // Another way to write it: read each pixel with image.getRGB(col, row)
        // instead of using PixelGrabber.
    }
}
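The hand-rolled bubble sort in pred works, but the same ranking can be sketched by sorting an array of indices, which leaves the probabilities and labels untouched and does not hard-code five classes. This is only an illustration using the standard library; the names `RankSketch`, `rankDescending` and `toJsonArray` are made up, and the label strings are assumed to contain no characters that need JSON escaping.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical helper, not part of onnxruntime.
final class RankSketch {

    // Return class indices ordered from the highest to the lowest probability.
    static Integer[] rankDescending(float[] probs) {
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        return order;
    }

    // Build the same array-style JSON as the post's pred(..., 1).
    // Assumes labels contain no quotes or backslashes (no escaping is done).
    static String toJsonArray(String[] labels, float[] probs) {
        StringBuilder sb = new StringBuilder("[");
        Integer[] order = rankDescending(probs);
        for (int k = 0; k < order.length; k++) {
            int i = order[k];
            if (k > 0) {
                sb.append(", ");
            }
            sb.append("{\"className\":\"").append(labels[i])
              .append("\", \"probability\":").append(probs[i]).append('}');
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        String[] labels = {"drawings", "hentai", "neutral", "porn", "sexy"};
        float[] probs = {0.125f, 0.0625f, 0.5f, 0.0625f, 0.25f}; // made-up values
        System.out.println(toJsonArray(labels, probs));
    }
}
```

Because only the index array is reordered, the caller can still look up the probability of a given class by its original position after ranking.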
You also need to add the onnxruntime dependency: Maven Repository: com.microsoft.onnxruntime » onnxruntime » 1.9.0 (mvnrepository.com).
Gradle (kts):
// https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime
implementation("com.microsoft.onnxruntime:onnxruntime:1.9.0")
Result: