使用八股搭建手写数据集神经网络

写在前面

今天是初五,好好的玩了几天后还是回归到了学习的正轨上。今天主要学习了神经网络的搭建八股,使用这种模型搭建了一个训练手写数据集的神经网络

搭建网络八股

六步法:

import

train,test

model = tf.keras.models.Sequential

model.compile

model.fit

model.summary

总的来说,首先导包,然后指定出训练集和测试集。使用tensorflow提供的API搭建好每层神经网络结构,进行compile,指定优化器损失函数和衡量标准。使用fit函数来训练神经网络,最后使用summary来输出训练结果。

训练手写数据集

先来看代码:

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)

model.summary()

代码不长,我们是严格按照六步法来搭建神经网络,可以看到十分简单。核心部分就是指定神经网络结构。

总结

总的来说,使用这种方法搭建神经网络还是十分简单的,但其中的原理一定要好好理解清楚。

posted @ 2021-02-16 19:23  武神酱丶  阅读(128)  评论(0编辑  收藏  举报