用 Go 和 TensorFlow 实现图像验证码识别
本教程介绍如何在 Go 中加载一个训练好的 TensorFlow 模型,对验证码图像进行识别,适用于构建轻量级识别服务或 CLI 工具。
一、环境准备
- 安装 TensorFlow Go
TensorFlow 官方提供 Go API,但需安装 C 语言版本的 TensorFlow 依赖:
安装 TensorFlow C 库(Ubuntu 示例)
sudo apt-get install libtensorflow
然后安装 Go 包:
更多内容访问ttocr.com或联系1436423940
go get github.com/tensorflow/tensorflow/tensorflow/go
二、模型准备
使用 Python(如 TensorFlow/Keras)训练模型并导出 SavedModel 格式:
model.save("captcha_model")
结构建议输出 [1, 4, 36](4 个字符,每个位置36分类),输入为 [1, 60, 160, 3]。
三、图像预处理(使用标准 Go 图像库)
import (
"image"
"image/jpeg"
_ "image/png"
"os"
"github.com/nfnt/resize"
)
func loadAndPreprocessImage(path string) ([]float32, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
imgRaw, _, err := image.Decode(file)
if err != nil {
return nil, err
}
// Resize to 160x60
img := resize.Resize(160, 60, imgRaw, resize.Lanczos3)
bounds := img.Bounds()
data := make([]float32, 160*60*3)
i := 0
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
r, g, b, _ := img.At(x, y).RGBA()
data[i] = (float32(r>>8)/255.0 - 0.5) / 0.5
data[i+1] = (float32(g>>8)/255.0 - 0.5) / 0.5
data[i+2] = (float32(b>>8)/255.0 - 0.5) / 0.5
i += 3
}
}
return data, nil
}
四、加载模型并预测
import (
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
func predict(imagePath string) {
model, err := tf.LoadSavedModel("captcha_model", []string{"serve"}, nil)
if err != nil {
panic(err)
}
defer model.Session.Close()
inputData, err := loadAndPreprocessImage(imagePath)
if err != nil {
panic(err)
}
// 创建 Tensor
tensor, err := tf.NewTensor([1][60][160][3]float32{}) // 此处需填充数据,略
copy(tensor.Value().([1][60][160][3]float32)[0][:][0][:], inputData)
// 执行模型
results, err := model.Session.Run(
map[tf.Output]*tf.Tensor{
model.Graph.Operation("serving_default_input").Output(0): tensor,
},
[]tf.Output{
model.Graph.Operation("StatefulPartitionedCall").Output(0),
},
nil,
)
if err != nil {
panic(err)
}
output := results[0].Value().([][][]float32)[0] // shape: [4][36]
chars := "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
prediction := ""
for _, pos := range output {
maxIdx := 0
maxVal := pos[0]
for i, val := range pos {
if val > maxVal {
maxVal = val
maxIdx = i
}
}
prediction += string(chars[maxIdx])
}
println("预测结果:", prediction)
}
浙公网安备 33010602011771号