JavaCnn项目注解

JavaCnn项目注解

该JavaCnn项目旨在用Java语言构造一个完整的卷积神经网络,实现训练一个手写字符识别模型,并预测。该项目可以帮助我们深入到Cnn的底层原理实现,通过阅读分析该项目代码,既可以提高对Java语言的掌握,也加深了对卷积神经网络的认识。

虽然项目的功能是“识别”,但其本质上,是一个分类的过程。

项目的入口是RunCnn类,Main()函数里开了个定时器,并根据CPU核数分了线程数。// Todo

项目分训练和预测两个模块:

训练步骤:

  1. 创建新模型
  2. 载入训练集
  3. 载入测试集
  4. 调用cnn对象的train()方法

预测步骤:

  1. 调用cnn对象的loadModels()方法载入模型,该方法返回Cnn对象
  2. 载入测试集
  3. 测试集初始化 // Todo
  4. 调用cnn对象的predict()方法

predict()方法详解

  1. 保存模型每一层的是否开启Dropout状态,存到save[]数组中;
  2. 关掉每一层的dropout,确保预测的时候所有权重都参与计算;
  3. 初始化batch的记录为0。 // Todo
  4. 对于每一张图片,都进行一次forward()正向传播计算。
  5. 对于每一张图片,正向传播的输出是 x个数(x是样本类别数),将这10个数字存入分类预测结果的数组中。
  6. 从分类预测结果的数组(样本一共有x类,数组的长度就是x)中取出数组最大值对应的下标,将其和图片对应的label对比,若值相等,则正确个数 +1 。
  7. 计算正常率: 正确个数/总测试图片数*100%。
  8. 将每一层的dropout状态变回原来的状态,使训练过程得以继续。

Record类注解

一张图片对应了一个record实例对象,Record类由两个属性组成,数组attrs[]保存一张图片的所有像素值,像素值进行了归一化处理,范围为0~1。

	public Record(double[] data)
        {
	  lable = data[data.length-1];
	  attrs = Arrays.copyOfRange(data, 0, data.length-1);
	}
posted @ 2021-05-28 19:03  梁君牧  阅读(155)  评论(0编辑  收藏  举报