用 Java 和 DL4J 实现图像验证码识别系统
本教程介绍如何使用 Java 和开源深度学习框架 DL4J 构建一个图像验证码识别模型,从数据准备到模型训练与测试,完整实现全过程。
- 环境准备
首先配置 Maven 工程,添加 DL4J 和 ND4J 的依赖:
- 构建数据加载器
使用 ImageRecordReader + ParentPathLabelGenerator 加载图像和标签。
int height = 60;
int width = 160;
int channels = 1;
int outputNum = 36; // 字符集长度(0-9A-Z)
File trainData = new File("data/train");
FileSplit fileSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, new Random(123));
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
recordReader.initialize(fileSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(recordReader, 32, 1, outputNum);
你可以扩展 RecordReader 来按图片名中的字符位置提取四个标签。
- 构建神经网络模型
这里使用简单的 CNN 架构处理验证码图像。
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new Adam(0.001))
.list()
.layer(new ConvolutionLayer.Builder(5, 5)
.nIn(channels)
.nOut(32)
.activation(Activation.RELU)
.build())
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.build())
.layer(new DenseLayer.Builder().nOut(128).activation(Activation.RELU).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nOut(outputNum * 4) // 每个验证码有4个字符
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutional(height, width, channels))
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(10));
5. 模型训练
for (int i = 0; i < 10; i++) {
model.fit(trainIter);
}
你可以使用 EarlyStoppingTrainer 进一步控制训练策略。
- 模型测试与预测
File testImage = new File("data/test/A3B9_10.png");
NativeImageLoader loader = new NativeImageLoader(height, width, channels);
INDArray image = loader.asMatrix(testImage);
NormalizerStandardize scaler = new NormalizerStandardize();
scaler.transform(image);
INDArray output = model.output(image);
int[] predictions = new int[4];
for (int i = 0; i < 4; i++) {
INDArray slice = output.get(NDArrayIndex.interval(i * outputNum, (i + 1) * outputNum));
predictions[i] = Nd4j.argMax(slice, 1).getInt(0);
}
System.out.println("预测结果:" + Arrays.stream(predictions).mapToObj(i -> characters[i]).collect(Collectors.joining()));
浙公网安备 33010602011771号