import tensorflow as tf
from numpy.random import RandomState
# 定义训练数据batch的大小
batch_size = 8
# 定义神经网络的参数
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
# 训练集的输入数据和对应的标签
x = tf.placeholder(tf.float32, shape=(None,2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None,1), name='y-input')
# 定义神经网络前向传播的过程
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
# 定义损失函数和反向传播的算法
cross_entropy = -tf.reduce_mean(
y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_start = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
# 通过随机数生成一个模拟数据集
rdm = RandomState(1) # 1为随机种子
dataset_size = 128
X = rdm.rand(dataset_size, 2)
# X = rdm.uniform(3, 4, (2, 3))
'''
X的输出结果
array([[3.26865012, 3.80827801, 3.29528879],
[3.54412138, 3.48792149, 3.85535641]])
'''
Y = [[int(x1+x2 < 1)] for (x1, x2) in X]
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print(sess.run(w1))
print(sess.run(w2))
'''
在训练之前神经网络的参数的值:
w1 = [[-0.8113182 1.4845988 0.06532937]
[-2.4427042 0.0992484 0.5912243 ]]
w2 = [[-0.8113182 ], [ 1.4845988 ], [ 0.06532937]]
'''
trainStep = 5000
for i in range(trainStep):
# 每次选取 batch_size 个样本进行训练。
start = (i * batch_size) % dataset_size
end = min(start+batch_size, dataset_size)
# 通过选取的样本训练神经网络并更新参数
sess.run(train_start, feed_dict={x: X[start: end], y_: Y[start: end]})
if i % 1000 == 0:
total_cross_entropy = sess.run(
cross_entropy, feed_dict= {x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" %(i, total_cross_entropy))
'''
输出结果:
After 0 training step(s), cross entropy on all data is 0.0674925
After 1000 training step(s), cross entropy on all data is 0.0163385
After 2000 training step(s), cross entropy on all data is 0.00907547
After 3000 training step(s), cross entropy on all data is 0.00714436
After 4000 training step(s), cross entropy on all data is 0.00578471
'''
print(sess.run(w1))
print(sess.run(w2))
'''
在训练之后神经网络的参数的值:
w1 = [[-1.9618274 2.582354 1.6820377]
[-3.4681718 1.0698233 2.11789 ]]
w2 = [[-1.8247149], [ 2.6854665], [ 1.418195 ]]
'''