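Image CAPTCHA Recognition with Rust and tch-rs
This project shows how to use Rust to run a CAPTCHA recognition model that has already been trained with PyTorch, recognizing images from the command line.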
1. Model Preparation
We train the model with Python + PyTorch and export it as a .pt file:
# Save the model weights (state_dict only)
torch.save(model.state_dict(), "captcha_model.pt")
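We recommend writing the model structure (CNN + LSTM) as a class and then rebuilding a network with exactly the same structure in Rust. tch-rs's VarStore::load looks tensors up by name, so the Rust layer paths (c1, c2 and fc in build_model below) must produce the same names as the keys of the saved state_dict. Depending on the tch version, a file written by a plain torch.save may also need to be converted before VarStore::load can read it; the tch-rs README describes converting Python weights, for example via an .npz file.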
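Because the weights are matched by name, it can help to compare the names the Rust model expects with the keys stored in the exported file before calling VarStore::load. The sketch below does this with plain sets. It assumes, as in build_model, that every conv2d/linear layer registers a weight and a bias and that tch-rs joins path components with a dot (so p / "c1" yields c1.weight and c1.bias). The helpers expected_names and diff_names are hypothetical, and the key list in main is made up for illustration; in practice you might obtain it by printing model.state_dict().keys() in Python.

```rust
use std::collections::BTreeSet;

// Hypothetical helper: the variable names tch-rs would register for the given
// layer paths, assuming each layer (default config) has a weight and a bias.
fn expected_names(layers: &[&str]) -> BTreeSet<String> {
    layers
        .iter()
        .flat_map(|layer| {
            ["weight", "bias"]
                .into_iter()
                .map(move |param| format!("{layer}.{param}"))
        })
        .collect()
}

// Hypothetical helper: returns (names missing from the file, extra names in the file).
fn diff_names(expected: &BTreeSet<String>, in_file: &BTreeSet<String>) -> (Vec<String>, Vec<String>) {
    let missing = expected.difference(in_file).cloned().collect();
    let extra = in_file.difference(expected).cloned().collect();
    (missing, extra)
}

fn main() {
    let expected = expected_names(&["c1", "c2", "fc"]);
    // Illustrative keys, e.g. from a Python model whose convs sit in an nn.Sequential named "cnn"
    let in_file: BTreeSet<String> = [
        "cnn.0.weight", "cnn.0.bias", "cnn.3.weight", "cnn.3.bias", "fc.weight", "fc.bias",
    ]
    .iter()
    .map(|s| s.to_string())
    .collect();
    let (missing, extra) = diff_names(&expected, &in_file);
    println!("missing in file: {:?}", missing);
    println!("unexpected in file: {:?}", extra);
}
```

Any name reported as missing means the Rust and Python layer names disagree; renaming the Python attributes (or the Rust paths) so that both sides match is usually simpler than remapping keys afterwards.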
2. Project Dependencies (Cargo.toml)
[dependencies]
tch = "0.13.0"
image = "0.24"
3. Loading the Model and Image (main.rs)
use std::env;
use tch::{nn, nn::ModuleT, Device, Kind, Tensor};

const WIDTH: i64 = 160;
const HEIGHT: i64 = 60;
const CHARSET: &str = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const N_CLASSES: i64 = 36; // must equal the number of characters in CHARSET
const N_LEN: usize = 4; // number of characters per CAPTCHA
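// Load an image, resize it to WIDTH x HEIGHT and normalize it to [-1, 1].
// The resize filter and the normalization must match the training pipeline.
fn preprocess_image(path: &str) -> Tensor {
    let img = image::open(path).expect("Cannot open image");
    let resized = img.resize_exact(WIDTH as u32, HEIGHT as u32, image::imageops::FilterType::Nearest);
    let rgb = resized.to_rgb8();
    // Pixels come row by row, each as [r, g, b], i.e. in HWC layout
    let vec: Vec<f32> = rgb
        .pixels()
        .flat_map(|p| p.0.iter().map(|&x| (x as f32 / 255.0 - 0.5) / 0.5))
        .collect();
    Tensor::of_slice(&vec)
        .reshape(&[1, HEIGHT, WIDTH, 3]) // NHWC
        .permute(&[0, 3, 1, 2]) // NCHW, as expected by conv2d
}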
// Turn the [1, N_LEN, N_CLASSES] network output into a string.
fn decode_output(output: Tensor) -> String {
    let output = output.softmax(-1, Kind::Float); // optional: softmax does not change the argmax
    let indices = output.argmax(-1, false).squeeze(); // shape [N_LEN]
    let chars: Vec<char> = CHARSET.chars().collect();
    (0..N_LEN)
        .map(|i| chars[indices.int64_value(&[i as i64]) as usize])
        .collect()
}
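fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: cargo run --release -- <image path>");
        std::process::exit(1);
    }
    let image_path = &args[1];
    let device = Device::Cpu;
    // VarStore::load takes &mut self, so the store must be mutable
    let mut vs = nn::VarStore::new(device);
    let net = build_model(&vs.root()); // define the network structure first
    // Fill the variables by name from the weight file (names must match)
    vs.load("captcha_model.pt").expect("Cannot load model weights");
    let input = preprocess_image(image_path);
    // Inference only, so no gradient tracking is needed
    let output = tch::no_grad(|| net.forward_t(&input, false));
    let pred = decode_output(output);
    println!("Predicted CAPTCHA: {}", pred);
}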
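To make the tensor manipulation in preprocess_image and decode_output easier to follow, here is the same logic written with plain Rust slices instead of tch tensors. It is only an illustrative sketch: hwc_to_chw_normalized and decode_logits are hypothetical helper names, and the code assumes the model emits N_LEN rows of N_CLASSES logits, as build_model does. The first function does what reshape + permute do to the interleaved RGB bytes; the second picks the highest-scoring class at each character position, which is what argmax(-1) computes, so the softmax is not needed for the prediction itself.

```rust
const CHARSET: &str = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const N_LEN: usize = 4;

// Interleaved RGB bytes in row-major HWC order (as from image's to_rgb8())
// -> CHW floats normalized with (x / 255 - 0.5) / 0.5, i.e. into [-1, 1].
fn hwc_to_chw_normalized(rgb: &[u8], height: usize, width: usize) -> Vec<f32> {
    assert_eq!(rgb.len(), height * width * 3, "expected height * width * 3 bytes");
    let mut out = vec![0.0f32; rgb.len()];
    for y in 0..height {
        for x in 0..width {
            for c in 0..3 {
                let src = (y * width + x) * 3 + c; // index in HWC layout
                let dst = c * height * width + y * width + x; // index in CHW layout
                out[dst] = (rgb[src] as f32 / 255.0 - 0.5) / 0.5;
            }
        }
    }
    out
}

// Greedy decoding: for each of the n_len positions take the index of the largest
// logit (the first one on ties, like argmax) and map it to a character of charset.
fn decode_logits(logits: &[f32], n_len: usize, charset: &str) -> String {
    let chars: Vec<char> = charset.chars().collect();
    let n_classes = chars.len();
    assert_eq!(logits.len(), n_len * n_classes, "expected n_len * n_classes logits");
    logits
        .chunks(n_classes)
        .map(|row| {
            let mut best = 0;
            for (i, &v) in row.iter().enumerate() {
                if v > row[best] {
                    best = i;
                }
            }
            chars[best]
        })
        .collect()
}

fn main() {
    // A 1x2 image: a green pixel followed by a magenta pixel
    let chw = hwc_to_chw_normalized(&[0, 255, 0, 255, 0, 255], 1, 2);
    println!("{:?}", chw); // [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0]: R plane, G plane, B plane

    // Fake logits with one clear winner per position
    let n_classes = CHARSET.chars().count();
    let mut logits = vec![0.0f32; N_LEN * n_classes];
    for (pos, class) in [1usize, 10, 35, 0].into_iter().enumerate() {
        logits[pos * n_classes + class] = 5.0;
    }
    println!("{}", decode_logits(&logits, N_LEN, CHARSET)); // 1AZ0
}
```

Decoding each position independently like this only works because the network in this article has a fixed output length; a model trained with CTC loss would need a different decoder.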
4. Building the Network Structure (build_model)
You have to implement by hand, in Rust, a model structure identical to the PyTorch one:
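// Layer names ("c1", "c2", "fc") become the variable names ("c1.weight", "c1.bias", ...)
// that VarStore::load looks up in the weight file.
fn build_model(p: &nn::Path) -> nn::SequentialT {
    let cnn = nn::seq_t()
        .add(nn::conv2d(p / "c1", 3, 32, 3, Default::default())) // 3x60x160 -> 32x58x158
        .add_fn(|x| x.relu())
        .add_fn(|x| x.max_pool2d_default(2)) // -> 32x29x79
        .add(nn::conv2d(p / "c2", 32, 64, 3, Default::default())) // -> 64x27x77
        .add_fn(|x| x.relu())
        .add_fn(|x| x.max_pool2d_default(2)); // -> 64x13x38
    // For simplicity a linear layer replaces the LSTM here. tch-rs does provide nn::lstm,
    // but the Rust layers must match the trained PyTorch model exactly; alternatively,
    // export the model as TorchScript and load it with tch::CModule.
    // The input size follows from the shapes above: 64 * 13 * 38 for a 160x60 image.
    let fc = nn::linear(p / "fc", 64 * 13 * 38, N_LEN as i64 * N_CLASSES, Default::default());
    nn::seq_t().add(cnn).add_fn(move |x| {
        let b = x.size()[0];
        x.view([b, -1]) // flatten to [batch, 64 * 13 * 38]
            .apply(&fc)
            .view([b, N_LEN as i64, N_CLASSES]) // one row of N_CLASSES logits per character
    })
}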
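The input size of the fc layer follows from the shapes of the convolution and pooling stages. As a quick check, the sketch below computes the flattened feature count with the usual output-size formula floor((n + 2 * padding - kernel) / stride) + 1. It assumes, as in build_model, 3x3 convolutions with stride 1 and no padding, and 2x2 max pooling whose stride equals its kernel size; the function names out_len and flattened_features are hypothetical.

```rust
// Output length along one axis for a conv or pooling layer (no dilation):
// floor((len + 2 * pad - kernel) / stride) + 1
fn out_len(len: i64, kernel: i64, stride: i64, pad: i64) -> i64 {
    (len + 2 * pad - kernel) / stride + 1
}

// Flattened feature count after conv3x3 -> pool2 -> conv3x3 -> pool2,
// mirroring the layer sequence of build_model.
fn flattened_features(height: i64, width: i64, channels: i64) -> i64 {
    let stage = |len: i64| {
        let len = out_len(len, 3, 1, 0); // 3x3 conv, stride 1, no padding
        out_len(len, 2, 2, 0) // 2x2 max pool, stride 2
    };
    let h = stage(stage(height));
    let w = stage(stage(width));
    channels * h * w
}

fn main() {
    // 60x160 input, 64 channels after the second conv: 64 * 13 * 38
    println!("{}", flattened_features(60, 160, 64)); // 31616
}
```

If you change the input size or the layers, recompute this value (or print x.size() once inside the forward closure) instead of carrying over a hard-coded number; a mismatch shows up as a shape error in the fc layer at runtime.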
5. Running
cargo run --release -- path/to/captcha.png