Image CAPTCHA Recognition with Rust and tch-rs

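This project shows how to use Rust to run a CAPTCHA recognition model that has already been trained with PyTorch, recognizing CAPTCHA images from the command line.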

1. Model preparation
The model is trained with Python + PyTorch, and its weights are exported to a .pt file:

# Save only the trained weights (the state_dict), not the model object
torch.save(model.state_dict(), "captcha_model.pt")
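We recommend defining the model (CNN + LSTM) as a Python class and then rebuilding a network with exactly the same structure in Rust. A state_dict stores only parameter tensors keyed by layer name, not the architecture. The Rust layers and their names must therefore mirror the Python ones.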
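Because the weights are matched by name, it can help to compare the two sets of names before loading. The following is a std-only sketch that does not use tch. It assumes that tch names each `nn::conv2d` / `nn::linear` parameter `<path>.weight` and `<path>.bias`. It uses the layer paths `c1`, `c2` and `fc` from `build_model` in section 4. `expected_keys` and `diff_keys` are hypothetical helpers. The checkpoint key list is made up for illustration; in Python you could print the real one with `list(model.state_dict().keys())`.

```rust
use std::collections::BTreeSet;

/// Parameter names tch-rs is assumed to register for each layer path:
/// `nn::conv2d` and `nn::linear` create `weight` and `bias` (bias is on by
/// default), joined to the path with '.'.
fn expected_keys(layer_paths: &[&str]) -> BTreeSet<String> {
    layer_paths
        .iter()
        .flat_map(|p| [format!("{p}.weight"), format!("{p}.bias")])
        .collect()
}

/// Compares the expected names with the names stored in a checkpoint.
/// Returns (missing from the checkpoint, present only in the checkpoint),
/// each sorted alphabetically.
fn diff_keys(expected: &BTreeSet<String>, checkpoint: &[&str]) -> (Vec<String>, Vec<String>) {
    let ckpt: BTreeSet<String> = checkpoint.iter().map(|s| s.to_string()).collect();
    let missing = expected.difference(&ckpt).cloned().collect();
    let unexpected = ckpt.difference(expected).cloned().collect();
    (missing, unexpected)
}

fn main() {
    // Layer paths used in build_model.
    let expected = expected_keys(&["c1", "c2", "fc"]);
    // Made-up checkpoint keys, for illustration only.
    let checkpoint = [
        "cnn.0.weight", "cnn.0.bias",
        "c2.weight", "c2.bias",
        "fc.weight", "fc.bias",
    ];
    let (missing, unexpected) = diff_keys(&expected, &checkpoint);
    println!("missing from checkpoint: {missing:?}");
    println!("only in checkpoint: {unexpected:?}");
}
```

In this made-up checkpoint, the first convolution is stored as `cnn.0.*`, which is how PyTorch names the first layer inside an `nn.Sequential` attribute called `cnn`. As a result, `c1.*` is reported as missing, and renaming the layer on either side resolves the mismatch.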

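2. Project dependencies (Cargo.toml)
tch-rs is a binding to libtorch, so a libtorch build compatible with the chosen tch version must be available. The tch-rs README describes how to point the LIBTORCH environment variable at it.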

[dependencies]
tch = "0.13.0"
image = "0.24"

3. Loading the model and image (main.rs)
use tch::{nn, nn::ModuleT, Device, Kind, Tensor};
use std::env;

const WIDTH: i64 = 160;
const HEIGHT: i64 = 60;
const CHARSET: &str = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const N_CLASSES: i64 = 36;
const N_LEN: usize = 4;

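fn preprocess_image(path: &str) -> Tensor {
    let img = image::open(path).expect("cannot open image");
    // Ideally use the same resize filter as during training.
    let resized = img.resize_exact(WIDTH as u32, HEIGHT as u32, image::imageops::FilterType::Nearest);
    let rgb = resized.to_rgb8();

    // Scale each byte to [0, 1], then normalize with mean 0.5 / std 0.5 into [-1, 1].
    // This must match the normalization used during training.
    let vec: Vec<f32> = rgb
        .pixels()
        .flat_map(|p| p.0.iter().map(|&x| (x as f32 / 255.0 - 0.5) / 0.5))
        .collect();

    // The image crate yields pixels in HWC order; PyTorch convolutions expect NCHW.
    Tensor::of_slice(&vec)
        .reshape(&[1, HEIGHT, WIDTH, 3])
        .permute(&[0, 3, 1, 2])
}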

fn decode_output(output: Tensor) -> String {
    // output has shape [1, N_LEN, N_CLASSES].
    // Softmax is optional here: it is monotonic, so it does not change the argmax.
    let output = output.softmax(-1, Kind::Float);
    let indices = output.argmax(-1, false).squeeze(); // shape [N_LEN]
    let chars: Vec<char> = CHARSET.chars().collect();
    (0..N_LEN)
        .map(|i| chars[indices.int64_value(&[i as i64]) as usize])
        .collect()
}

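fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("usage: cargo run --release -- <image-path>");
        std::process::exit(1);
    }
    let image_path = &args[1];

    let device = Device::Cpu;
    // VarStore::load takes &mut self, so the store must be mutable.
    let mut vs = nn::VarStore::new(device);
    let net = build_model(&vs.root()); // define the network structure first
    // Parameter names in the file must match the VarStore paths (c1, c2, fc).
    vs.load("captcha_model.pt").expect("failed to load weights");

    let input = preprocess_image(image_path);
    // Inference only: no gradient tracking, and train = false selects eval behavior.
    let output = tch::no_grad(|| net.forward_t(&input, false));

    let pred = decode_output(output);
    println!("Predicted CAPTCHA: {}", pred);
}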
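The arithmetic in `preprocess_image` does not depend on tch. The hypothetical function below is a std-only sketch that reproduces it on a raw RGB byte buffer. Each byte is mapped to `(x / 255 - 0.5) / 0.5`, i.e. into [-1, 1]. The interleaved HWC layout produced by the `image` crate is then rearranged into the channel-first CHW layout that `reshape` followed by `permute(&[0, 3, 1, 2])` yields, without the batch dimension.

```rust
/// Hypothetical std-only mirror of `preprocess_image`: takes interleaved
/// RGB bytes in row-major HWC order and returns channel-first (CHW) floats
/// normalized with mean 0.5 and std 0.5, i.e. into [-1, 1].
fn normalize_hwc_to_chw(rgb: &[u8], height: usize, width: usize) -> Vec<f32> {
    assert_eq!(rgb.len(), height * width * 3, "expected height * width * 3 bytes");
    let plane = height * width;
    let mut out = vec![0.0f32; rgb.len()];
    for y in 0..height {
        for x in 0..width {
            for c in 0..3 {
                let src = (y * width + x) * 3 + c; // position in HWC order
                let dst = c * plane + y * width + x; // position in CHW order
                out[dst] = (rgb[src] as f32 / 255.0 - 0.5) / 0.5;
            }
        }
    }
    out
}

fn main() {
    // A 1x2 image: a red pixel followed by a green pixel.
    let rgb = [255u8, 0, 0, 0, 255, 0];
    let chw = normalize_hwc_to_chw(&rgb, 1, 2);
    // Layout: [R(0,0), R(0,1), G(0,0), G(0,1), B(0,0), B(0,1)]
    println!("{chw:?}"); // prints [1.0, -1.0, -1.0, 1.0, -1.0, -1.0]
}
```

If the two layouts were mixed up, the network would still run, because the tensor size is identical. It would, however, see scrambled channels, so a layout check like this is cheaper than debugging wrong predictions.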
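`decode_output` is a greedy per-position argmax: each of the N_LEN positions independently picks its highest-scoring character. Below is a minimal std-only sketch, assuming the scores arrive as a flat row-major `[n_len, n_classes]` slice. The `decode` function is hypothetical and is not part of tch. It skips the softmax because softmax preserves ordering and cannot change which class wins.

```rust
/// Hypothetical std-only mirror of `decode_output`: greedy decoding of a flat,
/// row-major [n_len, n_classes] score slice, one argmax per character position.
/// Softmax is skipped because it preserves ordering, so the argmax is unchanged.
fn decode(scores: &[f32], n_len: usize, charset: &str) -> String {
    let chars: Vec<char> = charset.chars().collect();
    let n_classes = chars.len();
    assert_eq!(scores.len(), n_len * n_classes, "expected n_len * n_classes scores");
    scores
        .chunks(n_classes)
        .map(|row| {
            // Index of the largest score in this row (the first one wins on ties).
            let mut best = 0;
            for (i, &v) in row.iter().enumerate() {
                if v > row[best] {
                    best = i;
                }
            }
            chars[best]
        })
        .collect()
}

fn main() {
    let charset = "0123456789";
    let mut scores = vec![0.0f32; 2 * 10];
    scores[4] = 3.0; // position 0 -> '4'
    scores[10 + 2] = 5.0; // position 1 -> '2'
    println!("{}", decode(&scores, 2, charset)); // prints 42
}
```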
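4. Building the network (build_model)
You need to implement, by hand in Rust, a network whose layers, shapes and parameter names match the PyTorch model exactly: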

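fn build_model(p: &nn::Path) -> nn::SequentialT {
    // Two conv blocks: 3x3 conv (stride 1, no padding) -> ReLU -> 2x2 max-pool.
    let cnn = nn::seq_t()
        .add(nn::conv2d(p / "c1", 3, 32, 3, Default::default()))
        .add_fn(|x| x.relu())
        .add_fn(|x| x.max_pool2d_default(2))
        .add(nn::conv2d(p / "c2", 32, 64, 3, Default::default()))
        .add_fn(|x| x.relu())
        .add_fn(|x| x.max_pool2d_default(2));

    // tch-rs does provide nn::lstm, but this example uses a single linear layer
    // instead of an LSTM, so the PyTorch model must have the same layers. Otherwise,
    // export the model as TorchScript and load it with tch::CModule instead.

    // 60x160 input -> 58x158 -> 29x79 -> 27x77 -> 13x38, with 64 channels.
    let fc = nn::linear(p / "fc", 64 * 13 * 38, N_LEN as i64 * N_CLASSES, Default::default());

    // `move` is required: add_fn needs a 'static closure, so it must own fc.
    nn::seq_t().add(cnn).add_fn(move |x| {
        let b = x.size()[0];
        let x = x.view([b, -1]).apply(&fc);
        x.view([b, N_LEN as i64, N_CLASSES])
    })
}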
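The input size of the `fc` layer has to equal the number of values the CNN produces per image. The helper below is a std-only sketch that recomputes this number with PyTorch's output-size formula for convolution and pooling, `floor((n + 2*padding - (kernel - 1) - 1) / stride) + 1`, with dilation 1. It assumes the layers in `build_model`: 3×3 convolutions with tch's default config (stride 1, no padding) and `max_pool2d_default(2)` (kernel 2, stride 2). `out_dim` and `flatten_size` are hypothetical names.

```rust
/// Spatial output size of a conv or pooling layer (dilation 1), using
/// PyTorch's formula floor((n + 2*padding - (kernel - 1) - 1) / stride) + 1.
fn out_dim(n: i64, kernel: i64, stride: i64, padding: i64) -> i64 {
    // Integer division floors here because the numerator is non-negative
    // for any valid layer configuration.
    (n + 2 * padding - (kernel - 1) - 1) / stride + 1
}

/// Number of features entering `fc`, assuming two blocks of
/// conv3x3 (stride 1, no padding) followed by max_pool2d_default(2)
/// (kernel 2, stride 2).
fn flatten_size(height: i64, width: i64, channels: i64) -> i64 {
    let (mut h, mut w) = (height, width);
    for _ in 0..2 {
        h = out_dim(out_dim(h, 3, 1, 0), 2, 2, 0);
        w = out_dim(out_dim(w, 3, 1, 0), 2, 2, 0);
    }
    channels * h * w
}

fn main() {
    // 60x160 input, 64 output channels from the second convolution.
    let n = flatten_size(60, 160, 64);
    println!("fc in_features = {n}"); // prints fc in_features = 31616
}
```

For the 60×160 input used here, the result is 64 × 13 × 38 = 31616. If the PyTorch model pads its convolutions, the padding argument in the `out_dim` calls has to change, and so does the result.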
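5. Running

Run the program from the directory that contains captcha_model.pt, because main loads the weights by relative path. Pass the image path as the first argument:

cargo run --release -- path/to/captcha.png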

Posted @ 2025-05-27 19:17  ttocr.com