学习进度笔记22

今天通过观看老师分享的TensorFlow教学视频,学习了训练神经网络的简单代码:

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data set from 'data/' with one-hot encoded labels.
mnist = input_data.read_data_sets('data/', one_hot=True)

# Layer sizes: n_input -> h1 -> h2 -> n_class.
h1 = 256  # width of the first hidden layer
h2 = 128  # width of the second hidden layer
# Named n_input (not `input`) so the Python builtin input() is not shadowed.
n_input = 784  # features per example (presumably 28x28 flattened pixels)
n_class = 10  # number of output classes

# Graph inputs; the leading None leaves the batch dimension unspecified.
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_class])

# Weight matrices are drawn from a normal distribution with this stddev.
stddev = 0.1
weights = {
    'w1': tf.Variable(tf.random_normal([n_input, h1], stddev=stddev)),
    'w2': tf.Variable(tf.random_normal([h1, h2], stddev=stddev)),
    'out': tf.Variable(tf.random_normal([h2, n_class], stddev=stddev)),
}
# Biases are also drawn from tf.random_normal, but without an explicit
# stddev argument, so they use that function's default.
biases = {
    'b1': tf.Variable(tf.random_normal([h1])),
    'b2': tf.Variable(tf.random_normal([h2])),
    'out': tf.Variable(tf.random_normal([n_class])),
}
print("NETWORK READY")

def multilayer_perceptron(_X, _weights, _biases):
    """Build the forward pass of a perceptron with two sigmoid hidden layers.

    Args:
        _X: Input tensor; it is matrix-multiplied with ``_weights['w1']``.
        _weights: Mapping with the keys ``'w1'``, ``'w2'`` and ``'out'``.
        _biases: Mapping with the keys ``'b1'``, ``'b2'`` and ``'out'``.

    Returns:
        The output-layer tensor. No activation is applied to it here, so
        callers receive raw scores rather than probabilities.
    """
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
    return tf.matmul(layer_2, _weights['out']) + _biases['out']

# Raw output scores of the network for the placeholder input.
pred = multilayer_perceptron(x, weights, biases)

# Loss: softmax cross-entropy between the scores and the labels, averaged.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)
cost = tf.reduce_mean(per_example_loss)

# One plain gradient-descent step on the loss.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
optm = optimizer.minimize(cost)

# Accuracy: fraction of rows whose arg-max score matches the arg-max label.
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float"))

init = tf.global_variables_initializer()
print("FUNCTIONS READY")

training_epochs = 20
batch_size = 100
display_step = 4  # report metrics every `display_step` epochs

# Number of mini-batches per epoch. It does not change between epochs, so it
# is computed once; floor division keeps it an integer.
total_batch = mnist.train.num_examples // batch_size

# The context manager closes the session once training is done instead of
# leaving it open for the rest of the process.
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost = 0.0
        for _ in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feeds = {x: batch_x, y: batch_y}
            sess.run(optm, feed_dict=feeds)
            # The cost is evaluated in a second run, after the update step
            # above has already been executed on this batch.
            avg_cost += sess.run(cost, feed_dict=feeds)
        avg_cost = avg_cost / total_batch

        if (epoch + 1) % display_step == 0:
            # Print the 1-based epoch so it matches the display condition
            # (the final report reads 020/020 rather than 019/020).
            print("Epoch: %03d/%03d cost: %.9f" % (epoch + 1, training_epochs, avg_cost))
            # Training accuracy is measured on the last mini-batch of the
            # epoch only, not on the whole training set.
            feeds = {x: batch_x, y: batch_y}
            train_acc = sess.run(accr, feed_dict=feeds)
            print("TRAIN ACCURACY: %.3f" % (train_acc))
            feeds = {x: mnist.test.images, y: mnist.test.labels}
            test_acc = sess.run(accr, feed_dict=feeds)
            print("TEST ACCURACY: %.3f" % (test_acc))
print("FINISHED")
训练结果:

 

 

 
posted @ 2021-02-02 12:22  20183602  阅读(34)  评论(0)  编辑  收藏  举报