
Tensorflow2 (Pre-course) --- 9.1. Using a recurrent neural network to predict the next letter from one input letter


1. Summary

One-sentence summary:

model = tf.keras.Sequential([SimpleRNN(3),Dense(5, activation='softmax')])

 

 

2. Using a recurrent neural network to predict the next letter from one input letter

Location of the corresponding course video:

 

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, SimpleRNN
import matplotlib.pyplot as plt
import os
In [2]:
input_word = "abcde"
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}  # dictionary mapping each letter to a numeric id
id_to_onehot = {0: [1., 0., 0., 0., 0.], 1: [0., 1., 0., 0., 0.], 2: [0., 0., 1., 0., 0.], 3: [0., 0., 0., 1., 0.],
                4: [0., 0., 0., 0., 1.]}  # one-hot encoding of each id
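The two dictionaries above are written out by hand. A minimal sketch, assuming the vocabulary is exactly the (unique) letters of `input_word`, that derives both from the string instead; the helper name `build_vocab` is hypothetical:

```python
def build_vocab(word):
    """Map each letter to an id and each id to a one-hot list (hypothetical helper)."""
    w_to_id = {ch: i for i, ch in enumerate(word)}
    n = len(word)
    # Row i has a 1.0 in position i and 0.0 everywhere else.
    id_to_onehot = {
        i: [1.0 if j == i else 0.0 for j in range(n)] for i in range(n)
    }
    return w_to_id, id_to_onehot

w_to_id, id_to_onehot = build_vocab("abcde")
print(w_to_id)          # {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}
print(id_to_onehot[2])  # [0.0, 0.0, 1.0, 0.0, 0.0]
```

Generating the tables this way keeps them consistent if the alphabet changes.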
In [3]:
x_train = [id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']],
           id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]]
y_train = [w_to_id['b'], w_to_id['c'], w_to_id['d'], w_to_id['e'], w_to_id['a']]
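Each label is the id of the following letter, wrapping from `e` back to `a`. A sketch of the same pairing written with modular arithmetic, assuming the `w_to_id` / `id_to_onehot` layout above (redefined here so the block runs on its own):

```python
input_word = "abcde"
w_to_id = {ch: i for i, ch in enumerate(input_word)}
n = len(input_word)
id_to_onehot = {i: [1.0 if j == i else 0.0 for j in range(n)] for i in range(n)}

# Input: one-hot vector of each letter; label: id of the next letter (cyclic).
x_train = [id_to_onehot[w_to_id[ch]] for ch in input_word]
y_train = [(w_to_id[ch] + 1) % n for ch in input_word]
print(y_train)  # [1, 2, 3, 4, 0]
```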
In [4]:
print(x_train)
print(y_train)
[[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]]
[1, 2, 3, 4, 0]
In [5]:
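# Re-seed before each shuffle so x_train and y_train get the same permutation and stay paired
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)  # fix TensorFlow's global seed (e.g. for weight initialization)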
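Re-seeding NumPy with the same value before each `shuffle` makes both calls apply the same permutation, so every input keeps its label. An alternative that does not depend on re-seeding is to shuffle a single index list and apply it to both sequences. A pure-Python sketch of that idea; it uses the standard `random` module, so the resulting order will differ from the NumPy order printed after the reshape:

```python
import random

x = [[1.0 if j == i else 0.0 for j in range(5)] for i in range(5)]  # one-hot a..e
y = [1, 2, 3, 4, 0]  # id of the next letter

# Shuffle one index list, then reorder x and y with it (same permutation for both).
idx = list(range(len(x)))
random.Random(7).shuffle(idx)
x_shuffled = [x[i] for i in idx]
y_shuffled = [y[i] for i in idx]

# Every pair still maps letter k to letter (k + 1) % 5.
for xs, ys in zip(x_shuffled, y_shuffled):
    assert (xs.index(1.0) + 1) % 5 == ys
```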
In [6]:
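# Make x_train match the SimpleRNN input format: [number of samples, number of time steps the recurrent cell is unrolled, number of input features per time step].
# The whole dataset is fed at once, so the number of samples is len(x_train); one letter in gives one result, so the cell is unrolled for 1 time step; the one-hot code has 5 features, so each time step has 5 input features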
x_train = np.reshape(x_train, (len(x_train), 1, 5))
y_train = np.array(y_train)
In [7]:
# After the reshape, x_train has shape (5, 1, 5)
print(x_train)
print(y_train)
[[[1. 0. 0. 0. 0.]]

 [[0. 0. 0. 1. 0.]]

 [[0. 0. 1. 0. 0.]]

 [[0. 1. 0. 0. 0.]]

 [[0. 0. 0. 0. 1.]]]
[1 4 3 2 0]
In [8]:
model = tf.keras.Sequential([
    SimpleRNN(3),
    Dense(5, activation='softmax')
])
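For reference, with Keras's default `tanh` activation for SimpleRNN and a zero initial state, the model above computes the following at each step. The notation is ours: W_xh is the RNN `kernel` (5×3), W_hh the `recurrent_kernel` (3×3), b_h its `bias`, and W_hy, b_y the Dense kernel and bias.

```latex
\begin{aligned}
h_t &= \tanh\left(x_t W_{xh} + h_{t-1} W_{hh} + b_h\right), \qquad h_0 = \mathbf{0} \\
\hat{y} &= \operatorname{softmax}\left(h_T W_{hy} + b_y\right)
\end{aligned}
```

Here only one time step is used (T = 1), so the recurrent term vanishes and h_1 = tanh(x_1 W_xh + b_h); with a one-hot x_1 this simply selects one row of W_xh.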
In [9]:
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/rnn_onehot_1pre1.ckpt"
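`SparseCategoricalCrossentropy` takes integer labels (the ids in `y_train`) rather than one-hot targets. With `from_logits=False` the per-sample loss is minus the log of the probability assigned to the true class, and `sparse_categorical_accuracy` checks whether the arg-max equals the label; Keras reports the mean over the batch. A pure-Python sketch of both quantities for one sample (the function names are ours and the probabilities are made-up illustration values, not model output):

```python
import math

def sparse_ce(probs, label):
    """Cross-entropy for one sample with an integer label: -log(p[label])."""
    return -math.log(probs[label])

def sparse_acc(probs, label):
    """1.0 if the most probable class equals the label, else 0.0."""
    predicted = max(range(len(probs)), key=probs.__getitem__)
    return 1.0 if predicted == label else 0.0

probs = [0.01, 0.96, 0.01, 0.01, 0.01]  # illustrative softmax output, sums to 1
print(round(sparse_ce(probs, 1), 4))  # 0.0408
print(sparse_acc(probs, 1))           # 1.0
```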
In [10]:
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='loss')  # fit() gets no validation set, so no test accuracy is computed; keep the best model by training loss
-------------load the model-----------------
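The `load the model` line shows that a checkpoint from an earlier run already existed, so training resumed from those weights; this is why the first epoch already reports a loss of 0.0397. With `save_best_only=True` and `monitor='loss'`, the callback writes weights only in epochs where the loss beats the best value seen so far. A simplified pure-Python model of that rule, assuming the monitored quantity is minimized (hypothetical helper, not Keras code; the loss values are illustrative):

```python
def epochs_that_save(losses):
    """Return the 1-based epochs in which a 'save best only' rule would save."""
    best = float("inf")  # nothing saved yet, so the first epoch always improves
    saved = []
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:  # improvement on the monitored 'loss'
            best = loss
            saved.append(epoch)
    return saved

print(epochs_that_save([0.0397, 0.0395, 0.0399, 0.0393]))  # [1, 2, 4]
```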
In [11]:
history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])

model.summary()
Epoch 1/100
1/1 [==============================] - 0s 203ms/step - loss: 0.0397 - sparse_categorical_accuracy: 1.0000
Epoch 2/100
1/1 [==============================] - 0s 175ms/step - loss: 0.0395 - sparse_categorical_accuracy: 1.0000
... (epochs 3 to 98 omitted: the loss falls steadily from 0.0393 to 0.0248 and sparse_categorical_accuracy stays at 1.0000 throughout)
Epoch 99/100
1/1 [==============================] - 0s 153ms/step - loss: 0.0247 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
1/1 [==============================] - 0s 175ms/step - loss: 0.0246 - sparse_categorical_accuracy: 1.0000
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 3)                 27        
_________________________________________________________________
dense (Dense)                (None, 5)                 20        
=================================================================
Total params: 47
Trainable params: 47
Non-trainable params: 0
_________________________________________________________________
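The parameter counts in the summary follow from the layer shapes: a SimpleRNN has a kernel (input_dim × units), a recurrent kernel (units × units) and a bias (units); a Dense layer has a kernel (in × out) and a bias (out). A quick arithmetic check (helper names are ours):

```python
def simple_rnn_params(input_dim, units):
    # kernel + recurrent_kernel + bias
    return input_dim * units + units * units + units

def dense_params(in_dim, out_dim):
    # kernel + bias
    return in_dim * out_dim + out_dim

rnn = simple_rnn_params(5, 3)   # 15 + 9 + 3
dense = dense_params(3, 5)      # 15 + 5
print(rnn, dense, rnn + dense)  # 27 20 47
```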
In [12]:
# print(model.trainable_variables)
# Extract the trained parameters to a text file; `with` closes the file automatically
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
In [14]:
###############################################    show   ###############################################

# Plot the training accuracy and loss curves (no validation set was used, so there are no validation curves)
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.figure()
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.figure()
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

For prediction, simply call model.predict.

In [15]:
############### predict #############

preNum = int(input("input the number of test alphabet:"))
for i in range(preNum):
    alphabet1 = input("input test alphabet:")
    alphabet = [id_to_onehot[w_to_id[alphabet1]]]
    # Make alphabet match the SimpleRNN input format: [number of samples, number of time steps, number of input features per time step]. One sample is fed for this check, so the number of samples is 1; one letter in gives one result, so the cell is unrolled for 1 time step; the one-hot code has 5 features, so each time step has 5 input features
    alphabet = np.reshape(alphabet, (1, 1, 5))
    # Pass the array directly: wrapping it in a list makes Keras treat it as a tuple of inputs and warn
    result = model.predict(alphabet)
    pred = tf.argmax(result, axis=1)  # index of the most probable next letter
    pred = int(pred[0])
    tf.print(alphabet1 + '->' + input_word[pred])
input the number of test alphabet:5
input test alphabet:a
a->b
input test alphabet:b
b->c
input test alphabet:c
c->d
input test alphabet:d
d->e
input test alphabet:e
e->a

Corresponding weights (contents of weights.txt):

simple_rnn/simple_rnn_cell/kernel:0
(5, 3)
[[ 0.06645484 -2.1043136   1.8370255 ]
 [-2.5618525   1.0372659   1.0856687 ]
 [ 1.462596    1.962184    1.8522223 ]
 [ 0.45230418  0.807459   -1.2672907 ]
 [-1.9447203  -2.139078   -1.4054455 ]]
simple_rnn/simple_rnn_cell/recurrent_kernel:0
(3, 3)
[[-0.87483037 -0.18346582 -0.44834363]
 [-0.311412    0.9219252   0.23038337]
 [ 0.37107185  0.34116596 -0.86366165]]
simple_rnn/simple_rnn_cell/bias:0
(3,)
[ 0.5690928   0.33914855 -0.24334261]
dense/kernel:0
(3, 5)
[[-1.8691967  0.7345443 -2.1057754  2.1691592  1.2987496]
 [-1.8359054 -2.2348397  1.0776519  1.7597938  1.5372584]
 [-2.0739214  2.0273297  1.3114356  2.4741762 -2.2947042]]
dense/bias:0
(5,)
[ 0.07764927  0.34175265  0.44885948 -1.1139487  -0.02294221]
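As a sanity check of the recurrence, these saved weights can reproduce the predictions by hand. A pure-Python sketch, assuming one time step and Keras's zero initial state (so the recurrent kernel contributes nothing) and using the values printed above:

```python
import math

letters = "abcde"
# simple_rnn/simple_rnn_cell/kernel:0, shape (5, 3): row i is selected by letter i
kernel = [
    [0.06645484, -2.1043136, 1.8370255],
    [-2.5618525, 1.0372659, 1.0856687],
    [1.462596, 1.962184, 1.8522223],
    [0.45230418, 0.807459, -1.2672907],
    [-1.9447203, -2.139078, -1.4054455],
]
rnn_bias = [0.5690928, 0.33914855, -0.24334261]
# dense/kernel:0, shape (3, 5), and dense/bias:0, shape (5,)
dense_kernel = [
    [-1.8691967, 0.7345443, -2.1057754, 2.1691592, 1.2987496],
    [-1.8359054, -2.2348397, 1.0776519, 1.7597938, 1.5372584],
    [-2.0739214, 2.0273297, 1.3114356, 2.4741762, -2.2947042],
]
dense_bias = [0.07764927, 0.34175265, 0.44885948, -1.1139487, -0.02294221]

def predict_next(letter):
    i = letters.index(letter)
    # One-hot input times kernel selects row i; h0 = 0 removes the recurrent term.
    h = [math.tanh(kernel[i][k] + rnn_bias[k]) for k in range(3)]
    logits = [sum(h[k] * dense_kernel[k][j] for k in range(3)) + dense_bias[j]
              for j in range(5)]
    # Softmax preserves order, so the arg-max of the logits is the predicted id.
    return letters[max(range(5), key=logits.__getitem__)]

for ch in letters:
    print(ch + '->' + predict_next(ch))  # a->b, b->c, c->d, d->e, e->a
```

The result matches the `model.predict` output above, which confirms that the printed kernel, recurrent kernel and Dense weights are combined as in the recurrence h = tanh(x W_xh + h_prev W_hh + b_h).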

 
posted @ 2020-09-23 10:44  范仁义  views (491)  comments (0)