用 Go 和 TensorFlow 实现图像验证码识别

本教程介绍如何在 Go 中加载一个训练好的 TensorFlow 模型,对验证码图像进行识别,适用于构建轻量级识别服务或 CLI 工具。

一、环境准备

  1. 安装 TensorFlow Go
    TensorFlow 官方提供 Go API,但需安装 C 语言版本的 TensorFlow 依赖:

安装 TensorFlow C 库(Ubuntu 示例)

sudo apt-get install libtensorflow
然后安装 Go 包:
更多内容访问ttocr.com或联系1436423940
go get github.com/tensorflow/tensorflow/tensorflow/go
二、模型准备
使用 Python(如 TensorFlow/Keras)训练模型并导出 SavedModel 格式:

model.save("captcha_model")
结构建议输出 [1, 4, 36](4 个字符,每个位置36分类),输入为 [1, 60, 160, 3]。

三、图像预处理(使用标准 Go 图像库)

import (
"image"
"image/jpeg"
_ "image/png"
"os"
"github.com/nfnt/resize"
)

func loadAndPreprocessImage(path string) ([]float32, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()

imgRaw, _, err := image.Decode(file)
if err != nil {
	return nil, err
}

// Resize to 160x60
img := resize.Resize(160, 60, imgRaw, resize.Lanczos3)
bounds := img.Bounds()

data := make([]float32, 160*60*3)
i := 0
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
	for x := bounds.Min.X; x < bounds.Max.X; x++ {
		r, g, b, _ := img.At(x, y).RGBA()
		data[i] = (float32(r>>8)/255.0 - 0.5) / 0.5
		data[i+1] = (float32(g>>8)/255.0 - 0.5) / 0.5
		data[i+2] = (float32(b>>8)/255.0 - 0.5) / 0.5
		i += 3
	}
}

return data, nil

}
四、加载模型并预测

import (
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func predict(imagePath string) {
model, err := tf.LoadSavedModel("captcha_model", []string{"serve"}, nil)
if err != nil {
panic(err)
}
defer model.Session.Close()

inputData, err := loadAndPreprocessImage(imagePath)
if err != nil {
	panic(err)
}

// 创建 Tensor
tensor, err := tf.NewTensor([1][60][160][3]float32{}) // 此处需填充数据,略
copy(tensor.Value().([1][60][160][3]float32)[0][:][0][:], inputData)

// 执行模型
results, err := model.Session.Run(
	map[tf.Output]*tf.Tensor{
		model.Graph.Operation("serving_default_input").Output(0): tensor,
	},
	[]tf.Output{
		model.Graph.Operation("StatefulPartitionedCall").Output(0),
	},
	nil,
)

if err != nil {
	panic(err)
}

output := results[0].Value().([][][]float32)[0] // shape: [4][36]
chars := "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
prediction := ""
for _, pos := range output {
	maxIdx := 0
	maxVal := pos[0]
	for i, val := range pos {
		if val > maxVal {
			maxVal = val
			maxIdx = i
		}
	}
	prediction += string(chars[maxIdx])
}

println("预测结果:", prediction)

}

posted @ 2025-05-27 17:27  ttocr、com  阅读(10)  评论(0)    收藏  举报