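Building an Image CAPTCHA Recognition System with Java and DL4J
This tutorial shows how to build, train, and test a deep learning model that recognizes image CAPTCHAs in Java with DeepLearning4J (DL4J).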
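1. Environment Setup
Make sure the following tools are installed:
JDK 8 or later
Maven
An IDE (such as IntelliJ IDEA)
Python (used to generate the CAPTCHA data)
Maven dependencies: the project needs DeepLearning4J (deeplearning4j-core) and an ND4J backend such as nd4j-native-platform in pom.xml.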
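The dependency listing is missing from this page, so the following is only a minimal sketch of what the Maven dependencies might look like. The version number is an assumption; use whichever DL4J release matches your setup. If NativeImageLoader cannot be resolved, datavec-data-image may need to be added explicitly.

```xml
<dependencies>
  <!-- DL4J core: network configuration, layers, training -->
  <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version> <!-- assumed version -->
  </dependency>
  <!-- ND4J CPU backend for all supported platforms -->
  <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta7</version> <!-- assumed version, keep equal to the DL4J version -->
  </dependency>
</dependencies>
```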
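2. Generate Training Data (with Python)
Use Python's captcha library (install it with pip install captcha) to generate 10,000 labelled images of 160x60 pixels, each showing four characters drawn from 0-9 and A-Z: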
from captcha.image import ImageCaptcha
import os
import random
import string

characters = string.digits + string.ascii_uppercase  # 36 classes: 0-9 and A-Z
generator = ImageCaptcha(width=160, height=60)
os.makedirs("java_captcha", exist_ok=True)

for i in range(10000):
    label = ''.join(random.choices(characters, k=4))  # random 4-character text
    img = generator.generate_image(label)
    # The label is stored in the file name, e.g. A3B9_0.png
    img.save(f"java_captcha/{label}_{i}.png")
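Since the Java side has to recover the label from names such as A3B9_0.png, here is a minimal sketch of that parsing step in plain Java. The class name CaptchaFileName and the method labelOf are hypothetical helpers introduced for illustration; the pattern assumes exactly the naming scheme produced by the script above.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper, not part of DL4J.
class CaptchaFileName {
    // Four label characters (0-9, A-Z), an underscore, a numeric index, then ".png".
    private static final Pattern NAME = Pattern.compile("([0-9A-Z]{4})_(\\d+)\\.png");

    /** Returns the 4-character label, or null when the name does not follow the scheme. */
    static String labelOf(String fileName) {
        Matcher m = NAME.matcher(fileName);
        return m.matches() ? m.group(1) : null;
    }

    public static void main(String[] args) {
        System.out.println(labelOf("A3B9_0.png")); // prints A3B9
        System.out.println(labelOf("notes.txt"));  // prints null
    }
}
```

Returning null for unexpected names makes it easy to skip stray files in the data folder instead of failing halfway through loading.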
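3. Load the Dataset
All images sit in one folder and the label is part of the file name, so directory-based labelling (ParentPathLabelGenerator) is not an option. Load each image with NativeImageLoader and build the label vector from the file name: four one-hot blocks of 36 values, one block per character. This keeps every image in memory, which is acceptable for 10,000 small grayscale images.
int height = 60, width = 160, channels = 1;
int outputNum = 36; // 0-9 + A-Z
int numChars = 4;   // characters per CAPTCHA
int batchSize = 64;
int seed = 123;
String alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
// File names look like A3B9_0.png, so the first four characters are the label.
File[] files = new File("java_captcha").listFiles((dir, name) -> name.endsWith(".png"));
NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
List<DataSet> samples = new ArrayList<>();
for (File f : files) {
    INDArray features = imageLoader.asMatrix(f); // shape [1, channels, height, width]
    INDArray labels = Nd4j.zeros(1, outputNum * numChars); // four one-hot blocks of 36
    String text = f.getName().substring(0, numChars);
    for (int i = 0; i < numChars; i++) {
        labels.putScalar(0, i * outputNum + alphabet.indexOf(text.charAt(i)), 1.0);
    }
    samples.add(new DataSet(features, labels));
}
Collections.shuffle(samples, new Random(seed));
DataSetIterator trainIter = new ListDataSetIterator<>(samples, batchSize);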
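The network's output has 4 x 36 = 144 units, so each label has to be laid out as four consecutive one-hot blocks. The following plain-Java sketch shows that layout without any DL4J types; CaptchaLabelEncoder is a hypothetical name used only for illustration.

```java
// Hypothetical helper, not part of DL4J.
class CaptchaLabelEncoder {
    static final String ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    static final int NUM_CHARS = 4;

    /** Encodes e.g. "A3B9" as 4 concatenated one-hot blocks of 36 values each. */
    static float[] encode(String label) {
        if (label.length() != NUM_CHARS) {
            throw new IllegalArgumentException("Expected " + NUM_CHARS + " characters: " + label);
        }
        int classes = ALPHABET.length();
        float[] vector = new float[NUM_CHARS * classes];
        for (int i = 0; i < NUM_CHARS; i++) {
            int cls = ALPHABET.indexOf(label.charAt(i));
            if (cls < 0) {
                throw new IllegalArgumentException("Unsupported character: " + label.charAt(i));
            }
            vector[i * classes + cls] = 1f; // block i, class cls
        }
        return vector;
    }

    public static void main(String[] args) {
        float[] v = encode("A3B9");
        // 'A' is class 10 in block 0, '3' is class 3 in block 1 (index 36 + 3 = 39), and so on.
        System.out.println(v[10] + " " + v[39] + " " + v[83] + " " + v[117]);
    }
}
```

Keeping the encoding in one small method makes it straightforward to unit-test the label layout before any training time is spent.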
4. Build the Model (CNN)
The network stacks two convolution and max-pooling stages, a 256-unit dense layer, and an output layer with 4 x 36 = 144 units:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .updater(new Adam(0.001))
        .list()
        .layer(new ConvolutionLayer.Builder(3, 3).nIn(channels).nOut(32).activation(Activation.RELU).build())
        .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build())
        .layer(new ConvolutionLayer.Builder(3, 3).nOut(64).activation(Activation.RELU).build())
        .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build())
        .layer(new DenseLayer.Builder().nOut(256).activation(Activation.RELU).build())
        // Sigmoid with binary cross-entropy scores each of the 144 outputs independently,
        // which matches a label made of four one-hot blocks. A single softmax would
        // force all 144 values to sum to 1.
        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                .activation(Activation.SIGMOID)
                .nOut(outputNum * 4) // 4 characters x 36 classes
                .build())
        // Images arrive as 4D arrays [batch, channels, height, width].
        .setInputType(InputType.convolutional(height, width, channels))
        .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100)); // log the loss every 100 iterations
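To get a feel for how large the dense layer's input is, the feature-map size after the two convolution and pooling stages can be worked out by hand. The sketch below assumes no padding, stride 1 for the 3x3 convolutions, and stride 2 for the 2x2 pooling layers; these are assumptions about the layer defaults, so check them against the DL4J version in use. ConvShape is a hypothetical helper name.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DL4J.
class ConvShape {
    /** Output size along one axis with no padding: floor((in - kernel) / stride) + 1. */
    static int out(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1; // integer division truncates
    }

    /** Returns {height, width} after conv 3x3 -> pool 2x2 -> conv 3x3 -> pool 2x2. */
    static int[] featureMap(int height, int width) {
        int[] size = {height, width};
        int[][] layers = {{3, 1}, {2, 2}, {3, 1}, {2, 2}}; // {kernel, stride} per layer
        for (int[] layer : layers) {
            size[0] = out(size[0], layer[0], layer[1]);
            size[1] = out(size[1], layer[0], layer[1]);
        }
        return size;
    }

    public static void main(String[] args) {
        int[] size = featureMap(60, 160);
        System.out.println(Arrays.toString(size)); // prints [13, 38]
        System.out.println(size[0] * size[1] * 64); // prints 31616, the dense layer's input size
    }
}
```

Under these assumptions the dense layer alone holds 31616 x 256 weights (about 8.1 million), which is where most of the model's parameters sit.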
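5. Train the Model
Train for 10 epochs, then save the model to disk:
int numEpochs = 10;
for (int epoch = 0; epoch < numEpochs; epoch++) {
    model.fit(trainIter); // one full pass over the training data
}
// The final argument (true) also saves the updater state so training can be resumed later.
ModelSerializer.writeModel(model, new File("captcha_model.zip"), true);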
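After training, it helps to measure how often the model reads a whole CAPTCHA correctly, since a single wrong character makes the answer useless. The following is a minimal sketch of two metrics computed on decoded strings, ideally from images that were not used for training. CaptchaAccuracy and its method names are hypothetical and not part of DL4J.

```java
// Hypothetical helper, not part of DL4J.
class CaptchaAccuracy {
    /** Fraction of CAPTCHAs where every character matches. */
    static double exactMatch(String[] predicted, String[] expected) {
        checkSameLength(predicted, expected);
        int hits = 0;
        for (int i = 0; i < expected.length; i++) {
            if (expected[i].equals(predicted[i])) hits++;
        }
        return expected.length == 0 ? 0.0 : (double) hits / expected.length;
    }

    /** Fraction of individual characters that match, compared position by position. */
    static double perCharacter(String[] predicted, String[] expected) {
        checkSameLength(predicted, expected);
        int hits = 0, total = 0;
        for (int i = 0; i < expected.length; i++) {
            int n = Math.min(expected[i].length(), predicted[i].length());
            for (int j = 0; j < n; j++) {
                if (expected[i].charAt(j) == predicted[i].charAt(j)) hits++;
            }
            total += expected[i].length();
        }
        return total == 0 ? 0.0 : (double) hits / total;
    }

    private static void checkSameLength(String[] a, String[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Arrays differ in length");
        }
    }

    public static void main(String[] args) {
        String[] predicted = {"A3B9", "1234", "ZZZZ"};
        String[] expected = {"A3B9", "1235", "AAAA"};
        System.out.println(exactMatch(predicted, expected));   // 1 of 3 CAPTCHAs correct
        System.out.println(perCharacter(predicted, expected)); // 7 of 12 characters correct
    }
}
```

Per-character accuracy is usually much higher than exact-match accuracy, so reporting both gives a clearer picture of where the model stands.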
6. Predict a CAPTCHA
Load an image, run it through the network, and take the highest-scoring class in each block of 36 outputs:
NativeImageLoader loader = new NativeImageLoader(height, width, channels);
// Replace the path with a file that exists in your data set.
INDArray image = loader.asMatrix(new File("java_captcha/A3B9_0.png"));
INDArray output = model.output(image); // shape [1, 144]
int[] predicted = new int[4];
for (int i = 0; i < 4; i++) {
    // Keep the row dimension so the slice is [1, 36] and argMax along dimension 1 is valid.
    INDArray slice = output.get(NDArrayIndex.all(), NDArrayIndex.interval(i * 36, (i + 1) * 36));
    predicted[i] = Nd4j.argMax(slice, 1).getInt(0);
}
String chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
StringBuilder result = new StringBuilder();
for (int idx : predicted) result.append(chars.charAt(idx));
System.out.println("Prediction: " + result);
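The decoding step above can also be written against a plain float array, which makes it easy to test without loading a model. This is a sketch under the assumption that the 144 scores are available as a float[] (ND4J's INDArray offers toFloatVector() for this, as far as I know); CaptchaDecoder is a hypothetical helper name.

```java
// Hypothetical helper, not part of DL4J.
class CaptchaDecoder {
    static final String ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";

    /** Picks the highest-scoring class in each block of 36 scores and maps it to a character. */
    static String decode(float[] scores, int numChars) {
        int classes = ALPHABET.length();
        if (scores.length != numChars * classes) {
            throw new IllegalArgumentException("Expected " + (numChars * classes) + " scores");
        }
        StringBuilder text = new StringBuilder();
        for (int i = 0; i < numChars; i++) {
            int best = 0;
            for (int c = 1; c < classes; c++) {
                if (scores[i * classes + c] > scores[i * classes + best]) best = c;
            }
            text.append(ALPHABET.charAt(best));
        }
        return text.toString();
    }

    public static void main(String[] args) {
        float[] scores = new float[144];
        scores[10] = 0.9f;       // block 0: 'A'
        scores[36 + 3] = 0.8f;   // block 1: '3'
        scores[72 + 11] = 0.7f;  // block 2: 'B'
        scores[108 + 9] = 0.6f;  // block 3: '9'
        System.out.println(decode(scores, 4)); // prints A3B9
    }
}
```

Ties are resolved in favour of the lower class index, so an all-zero score vector decodes to 0000.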