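Preface
A few days ago I used Mirai to rebuild my earlier NSFW-detection bot (from my post "TensorFlow 鉴黄 API"). Mirai's native interface is Java, so I wrote the bot in Java. The result was clumsy because it was a roundabout workaround. I had to run a separate web API for the Java side to call, sending a POST request whenever an image arrived, and let Node.js (or Python) do the actual prediction.
Yesterday I decided this was too clumsy and looked for a way to run model inference directly in Java. I tried TensorFlow for Java, but it had too few examples... Then I turned to Microsoft's onnxruntime, a cross-platform inference library for ONNX models that is also usable through DJL. Java examples were still scarce (I really just wanted something to copy), but I'd rather not be a programmer who can only copy and paste. (Mostly, of course, because I'm still a beginner.)
And that is how this post came about.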
Model Conversion
You can easily convert a Keras .h5 model to an ONNX model with the keras2onnx library.
For example:
import keras2onnx
import onnx
from keras.models import load_model
model = load_model('path to h5 model')
onnx_model = keras2onnx.convert_keras(model, model.name)
temp_model_file = 'path to save onnx model'
onnx.save_model(onnx_model, temp_model_file)
You may need to install keras2onnx first:
pip install keras2onnx
If you get the following error:
AttributeError: 'KerasTensor' object has no attribute 'graph'
Fix:
1. Install tf2onnx:
pip install tf2onnx
2. Save the Keras model in TensorFlow's SavedModel (.pb) format. This creates a tmp_model folder in the project directory:
from keras.models import load_model
import tensorflow as tf
model = load_model("path to h5 model")
tf.saved_model.save(model, "tmp_model")
3. Convert the SavedModel to ONNX from the command line:
python -m tf2onnx.convert --saved-model .\tmp_model\ --output "model.onnx"
OK, the conversion is done: you will find model.onnx in the project directory.
Code
I won't explain much, since this post is mainly a note for myself. Here's the code, and then I'm off to bed; I'm exhausted.
Here I'm calling a five-class NSFW model based on MobileNetV2, similar to the one in my earlier post "TensorFlow 鉴黄 API", except that one was trained with ResNet50.
Every model's input is a little different: the input name may differ, and so may the expected dimensions. Mine is named input_1 and expects a 4-D array of shape (batch, height, width, channel), here [1][224][224][3]. Adjust this to match your model (or change the model).
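The code below feeds the model a channel-last (NHWC) array, which is what this Keras-exported model expects. If your model expects channel-first input instead (NCHW, common for PyTorch exports), the array has to be rearranged before it is passed in. Here is a minimal sketch of that conversion; the class `LayoutUtil` and the method `nhwcToNchw` are hypothetical helpers, not part of onnxruntime.

```java
// Hypothetical helper class, not part of the onnxruntime API.
final class LayoutUtil {
    // Converts a [N][H][W][C] (channel-last) array into a
    // [N][C][H][W] (channel-first) array by copying every element.
    static float[][][][] nhwcToNchw(float[][][][] nhwc) {
        int n = nhwc.length;
        int h = nhwc[0].length;
        int w = nhwc[0][0].length;
        int c = nhwc[0][0][0].length;
        float[][][][] nchw = new float[n][c][h][w];
        for (int b = 0; b < n; b++) {
            for (int y = 0; y < h; y++) {
                for (int x = 0; x < w; x++) {
                    for (int ch = 0; ch < c; ch++) {
                        nchw[b][ch][y][x] = nhwc[b][y][x][ch];
                    }
                }
            }
        }
        return nchw;
    }
}
```

Only the index order changes; pixel values and normalization stay the same, so this helper would sit between `imageToMatrix` and `OnnxTensor.createTensor`.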
package com.julym;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.PixelGrabber;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Map;
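class AntiModel {
    public byte[] model;

    AntiModel() {
        // load model.onnx from the resources folder (on the classpath); readAllBytes() needs Java 9+
        try (java.io.InputStream in = getClass().getClassLoader().getResourceAsStream("model.onnx")) {
            if (in == null) {
                throw new IOException("model.onnx not found on the classpath");
            }
            model = in.readAllBytes();
            // or load it from an absolute path instead:
            // model = new FileInputStream("M:/IdeaProjects/nsfw/src/main/resources/model.onnx").readAllBytes();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }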
    public static void main(String[] args) {
        // try-with-resources closes the file stream after the image is decoded
        try (FileInputStream in = new FileInputStream("D:/Desktop/Fun/test.jpg")) {
            BufferedImage image = ImageIO.read(in);
            AntiModel antiModel = new AntiModel();
            System.out.println(antiModel.detector(image));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    public String detector(BufferedImage content) {
        BufferedImage scaled = scaleImg(content);
        float[][][][] inputArray = imageToMatrix(scaled);
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // try-with-resources closes the session and the input tensor even if run() throws.
        // Creating a session on every call is simple but slow; a long-running bot could create it once and reuse it.
        try (OrtSession session = env.createSession(model);
             OnnxTensor onnxTensor = OnnxTensor.createTensor(env, inputArray)) {
            // "input_1" is this model's input name; change it to match your model
            Result result = session.run(Map.of("input_1", onnxTensor));
            return pred(result, 1); // pred() reads the output and then closes the result
        } catch (OrtException e) {
            e.printStackTrace();
        }
        return null;
    }
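    // Resizes the image to the 224x224 input the model expects (TYPE_INT_RGB drops any alpha channel).
    public static BufferedImage scaleImg(BufferedImage image) {
        Image scaledImg = image.getScaledInstance(224, 224, Image.SCALE_FAST);
        BufferedImage img = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
        Graphics g = img.getGraphics();
        g.drawImage(scaledImg, 0, 0, null);
        g.dispose(); // release the graphics context
        return img;
    }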
    // retType == 0 returns a JSON object ({"label": probability, ...});
    // any other value returns a JSON array ([{"className": ..., "probability": ...}, ...]).
    // Classes are listed from highest to lowest probability. The result is closed before returning.
    public static String pred(Result tensorResult, int retType) {
        try {
            // the output has shape [1][5]: one image, five class probabilities
            float[][] outputProbs = (float[][]) tensorResult.get(0).getValue();
            float[] probabilities = outputProbs[0];
            String[] label = {"drawings", "hentai", "neutral", "porn", "sexy"};
            // sort the class indices by probability, highest first
            Integer[] order = new Integer[probabilities.length];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            java.util.Arrays.sort(order, (a, b) -> Float.compare(probabilities[b], probabilities[a]));
            StringBuilder sb = new StringBuilder(retType == 0 ? "{" : "[");
            for (int k = 0; k < order.length; k++) {
                int i = order[k];
                if (k > 0) {
                    sb.append(", ");
                }
                if (retType == 0) {
                    sb.append('"').append(label[i]).append("\": ").append(probabilities[i]);
                } else {
                    sb.append("{\"className\": \"").append(label[i])
                      .append("\", \"probability\": ").append(probabilities[i]).append('}');
                }
            }
            return sb.append(retType == 0 ? '}' : ']').toString();
        } catch (OrtException e) {
            e.printStackTrace();
        } finally {
            tensorResult.close();
        }
        return null;
    }
    // Converts an image to a [1][height][width][3] float array (NHWC, RGB order, values in [0, 1]).
    public static float[][][][] imageToMatrix(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        int[] pixels = new int[width * height];
        // grab all pixels row by row as packed 0xAARRGGBB ints
        PixelGrabber pg = new PixelGrabber(image, 0, 0, width, height, pixels, 0, width);
        try {
            pg.grabPixels();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt flag
            e.printStackTrace();
        }
        float[][][][] ret = new float[1][height][width][3];
        // visit every row and every column (including the last one)
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                int pixel = pixels[row * width + col];
                ret[0][row][col][0] = (pixel >> 16 & 0xFF) / 255f; // red
                ret[0][row][col][1] = (pixel >> 8 & 0xFF) / 255f;  // green
                ret[0][row][col][2] = (pixel & 0xFF) / 255f;       // blue
            }
        }
        return ret;
        // another way: read each pixel with getRGB(x, y), where x is the column and y the row
        // for (int y = 0; y < height; y++) {
        //     for (int x = 0; x < width; x++) {
        //         int rgb = image.getRGB(x, y);
        //         ret[0][y][x][0] = (rgb >> 16 & 0xFF) / 255f; // red
        //         ret[0][y][x][1] = (rgb >> 8 & 0xFF) / 255f;  // green
        //         ret[0][y][x][2] = (rgb & 0xFF) / 255f;       // blue
        //     }
        // }
    }
}
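As an alternative to the nested float[1][224][224][3] array, onnxruntime's Java API can also take a flat `FloatBuffer` plus an explicit shape via `OnnxTensor.createTensor(env, buffer, shape)`. Below is a minimal sketch of the same preprocessing (RGB, scaled to [0, 1], NHWC order) written into one flat array, where element (y, x, c) sits at index (y * width + x) * 3 + c. The class `FlatPreprocess` and its methods are hypothetical, and the sketch assumes the image has already been resized to the model's input size.

```java
import java.awt.image.BufferedImage;

// Hypothetical helper class; it only prepares data, it does not call onnxruntime.
final class FlatPreprocess {
    // Packs an RGB image into a flat NHWC float array for a batch of one:
    // element (y, x, c) is stored at index (y * width + x) * 3 + c, values in [0, 1].
    static float[] toNhwc(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        float[] data = new float[height * width * 3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y); // packed 0xAARRGGBB
                int base = (y * width + x) * 3;
                data[base] = ((rgb >> 16) & 0xFF) / 255f;    // red
                data[base + 1] = ((rgb >> 8) & 0xFF) / 255f; // green
                data[base + 2] = (rgb & 0xFF) / 255f;        // blue
            }
        }
        return data;
    }

    // Shape matching the flat array above: [1][height][width][3].
    static long[] shapeOf(BufferedImage image) {
        return new long[] {1, image.getHeight(), image.getWidth(), 3};
    }
}
```

With onnxruntime on the classpath, the array would be wrapped as `OnnxTensor.createTensor(env, FloatBuffer.wrap(data), shape)`. A single flat buffer avoids allocating tens of thousands of small inner arrays for every image.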
You'll need to add the onnxruntime dependency: Maven Repository: com.microsoft.onnxruntime » onnxruntime » 1.9.0 (mvnrepository.com).
Gradle (Kotlin DSL):
// https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime
implementation("com.microsoft.onnxruntime:onnxruntime:1.9.0")
Result: