PyTorch 模型导出为 ONNX 并在 Java 中调用的示例
使用 Python（PyTorch）训练一个简单的神经网络，用于 MNIST 手写数字识别，并将训练后的模型保存为 ONNX 格式：
代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
import onnxruntime as ort
import numpy as np
import torch
from torchvision import datasets, transforms
# Hyperparameters (read by train() and by the DataLoaders below).
num_epochs = 1
learning_rate = 0.01
batch_size = 64

# Preprocessing pipeline: image -> float tensor, then standardisation with
# fixed constants (presumably the MNIST training-set mean/std).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# download=True is requested only for the training split; the test split is
# then loaded from the same ./data root.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Only the training order is shuffled; evaluation walks the test set in a
# fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Neural network model definition.
class SimpleNN(nn.Module):
    """Three-layer fully connected classifier for 28x28 inputs.

    Layer widths are 784 -> 128 -> 64 -> 10 with ReLU after the two hidden
    layers. ``forward`` returns raw logits; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Keep the creation order fc1, fc2, fc3: it fixes the state_dict key
        # names/order (used by the saved .pth) and the random-init sequence.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Collapse everything into rows of 784 values; -1 absorbs the batch.
        flat = x.view(-1, 28 * 28)
        hidden = nn.functional.relu(self.fc1(flat))
        hidden = nn.functional.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Module-level instance that train() optimises and exports.
model = SimpleNN()
def _train_one_epoch(net, loader, criterion, optimizer, epoch, total_epochs):
    """Run one optimisation pass over *loader*, printing the loss every 100 batches."""
    net.train()
    for batch_idx, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        loss = criterion(net(data), target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch + 1}/{total_epochs}, Batch {batch_idx}, Loss: {loss.item()}')


def _evaluate(net, loader):
    """Return the accuracy of *net* over *loader* as a percentage.

    Raises ZeroDivisionError if *loader* yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            predicted = net(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total


def _export_onnx(net, path):
    """Export *net* to ONNX at *path* with a dynamic batch dimension."""
    net.eval()
    # The dummy input only drives tracing. Marking axis 0 as dynamic keeps
    # the exported graph from being locked to batch size 1.
    dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(
        net,
        dummy_input,
        path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )


def train():
    """Train the module-level ``model`` on MNIST, evaluate it, then save it.

    Reads the module-level ``model``, ``train_loader``, ``test_loader``,
    ``learning_rate`` and ``num_epochs``. Writes ``simple_nn.pth`` (state
    dict) and ``simple_nn.onnx`` to the current working directory.
    """
    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, epoch, num_epochs)

    accuracy = _evaluate(model, test_loader)
    print(f'Accuracy of the model on the test images: {accuracy}%')

    # Spot-check the first test batch; no_grad avoids building an autograd
    # graph that would never be used.
    sample_data, sample_target = next(iter(test_loader))
    with torch.no_grad():
        sample_predicted = model(sample_data).argmax(dim=1)
    print(f'Predicted: {sample_predicted[:10]}')
    print(f'Actual: {sample_target[:10]}')

    torch.save(model.state_dict(), 'simple_nn.pth')
    print('Model saved to simple_nn.pth')

    _export_onnx(model, 'simple_nn.onnx')
    print('Model converted to ONNX and saved to simple_nn.onnx')
def inference():
    """Run ``simple_nn.onnx`` on the first MNIST test image and print the result.

    Expects ``simple_nn.onnx`` (written by ``train``) in the current working
    directory and the MNIST test split under ``./data``.
    """
    # Same preprocessing as in training; the exported graph does not include it.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # Load the ONNX model. onnxruntime >= 1.9 requires an explicit providers
    # list on builds that include non-CPU execution providers (e.g. the GPU
    # package); pinning the CPU provider works on every build.
    onnx_model_path = 'simple_nn.onnx'
    ort_session = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
    input_name = ort_session.get_inputs()[0].name

    sample_data, sample_target = next(iter(test_loader))
    sample_data = sample_data.view(1, 1, 28, 28)  # (batch, channel, height, width), as exported
    # detach() is a no-op for tensors without grad, so it is safe unconditionally.
    ort_inputs = {input_name: sample_data.detach().cpu().numpy()}
    onnx_output = ort_session.run(None, ort_inputs)
    onnx_predicted = np.argmax(onnx_output[0], axis=1)
    print(f'Predicted: {onnx_predicted[0]}')
    print(f'Actual: {sample_target.item()}')
if __name__ == '__main__':
    # Train and export first; inference() then reloads the ONNX file that
    # train() wrote.
    for stage in (train, inference):
        stage()
代码输出:
Epoch 1/1, Batch 0, Loss: 2.3018319606781006
Epoch 1/1, Batch 100, Loss: 1.8857725858688354
Epoch 1/1, Batch 200, Loss: 1.0029046535491943
Epoch 1/1, Batch 300, Loss: 0.6656786203384399
Epoch 1/1, Batch 400, Loss: 0.641338586807251
Epoch 1/1, Batch 500, Loss: 0.5250198841094971
Epoch 1/1, Batch 600, Loss: 0.5605880618095398
Epoch 1/1, Batch 700, Loss: 0.5747233629226685
Epoch 1/1, Batch 800, Loss: 0.49430033564567566
Epoch 1/1, Batch 900, Loss: 0.28630945086479187
Accuracy of the model on the test images: 90.38%
Predicted: tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
Actual: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
Model saved to simple_nn.pth
Model converted to ONNX and saved to simple_nn.onnx
Predicted: 7
Actual: 7
接下来，我们在 Java 中加载该 ONNX 模型进行预测：
代码如下:
import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.Map;
public class App {
    /** Model location used when no path is given on the command line. */
    private static final String DEFAULT_MODEL_PATH = "D:\\source\\pythonProject\\simple_nn.onnx"; // "simple_nn.onnx";

    /** Input shape the model is fed with: (batch, channel, height, width). */
    private static final long[] INPUT_SHAPE = {1, 1, 28, 28};

    /**
     * Loads the ONNX model, runs it once on random input and prints the predicted label.
     *
     * @param args optional; {@code args[0]} overrides the model path
     * @throws OrtException if the model cannot be loaded or inference fails
     */
    public static void main(String[] args) throws OrtException {
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;

        // Random pixels in [0, 1): this only checks that the model loads and runs.
        // NOTE(review): real images should get the same ToTensor + Normalize
        // ((0.1307,), (0.3081,)) preprocessing as the Python training code.
        float[] inputData = new float[1 * 1 * 28 * 28];
        for (int i = 0; i < inputData.length; i++) {
            inputData[i] = (float) Math.random();
        }

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // try-with-resources releases every native handle (options, session,
        // input tensor, result) even when loading or inference throws.
        try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
             OrtSession session = env.createSession(modelPath, opts)) {
            String inputName = session.getInputNames().iterator().next();
            try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), INPUT_SHAPE)) {
                Map<String, OnnxTensor> inputs = Collections.singletonMap(inputName, inputTensor);
                try (OrtSession.Result result = session.run(inputs)) {
                    // Rank-2 output: one row per batch item, one logit per class.
                    float[][] output = (float[][]) result.get(0).getValue();
                    int predictedLabel = argMax(output[0]);
                    System.out.println("Predicted Label: " + predictedLabel);
                }
            }
        } finally {
            env.close();
        }
    }

    /**
     * Returns the index of the largest element; the first index wins on ties.
     * Returns 0 for an empty array.
     */
    private static int argMax(float[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}
运行输出（输入是随机生成的噪声，因此预测的标签本身没有意义，只用于验证模型能在 Java 中正常加载并完成推理）：
Predicted Label: 8
附:
pom.xml文件
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>test_onnx</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <name>test_onnx</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- The ONNX Runtime Java API needs Java 8+. Pinning the level also avoids
             the 1.5/1.6 defaults of older maven-compiler-plugin versions, which
             recent JDKs refuse to compile. -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <!-- ONNX Runtime dependency (CPU build) -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.15.1</version>
        </dependency>
    </dependencies>
</project>

浙公网安备 33010602011771号