TensorFlow 2 (Pre-course) --- 10.1. Using a Recurrent Neural Network to Predict One Letter from Four Letters

1. Summary

One-sentence summary:

The network is the same as before; only the input data has changed: model = tf.keras.Sequential([SimpleRNN(3), Dense(5, activation='softmax')])
print(x_train)
print(y_train)
[[[1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]]

 [[0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]]

 [[0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0.]]

 [[0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 1.]]

 [[0. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]]]
[4 2 1 0 3]
model = tf.keras.Sequential([
    SimpleRNN(3),
    Dense(5, activation='softmax')
])

2. Using a recurrent neural network to predict one letter from four letters

Video location of the course that this blog post corresponds to:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, SimpleRNN
import matplotlib.pyplot as plt
import os
In [2]:
input_word = "abcde"
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}  # dictionary mapping each letter to a numeric id
id_to_onehot = {0: [1., 0., 0., 0., 0.], 1: [0., 1., 0., 0., 0.], 2: [0., 0., 1., 0., 0.], 3: [0., 0., 0., 1., 0.],
                4: [0., 0., 0., 0., 1.]}  # one-hot encoding for each id
In [3]:
x_train = [
    [id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']]],
    [id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]],
    [id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']]],
    [id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']]],
    [id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']]],
]
y_train = [w_to_id['e'], w_to_id['a'], w_to_id['b'], w_to_id['c'], w_to_id['d']]
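The five samples above are the cyclic windows of four letters over "abcde", each labeled with the letter that follows the window. As a minimal sketch, the same lists could be generated in a loop; the helper name `make_windows` is hypothetical and not part of the original post, and the dictionaries are rebuilt here only to keep the example self-contained.

```python
input_word = "abcde"
# Same mappings as in the post, built programmatically
w_to_id = {ch: i for i, ch in enumerate(input_word)}
id_to_onehot = {i: [1.0 if j == i else 0.0 for j in range(len(input_word))]
                for i in range(len(input_word))}


def make_windows(word, steps):
    """Return (x, y): every cyclic window of `steps` letters and the id of the letter after it."""
    n = len(word)
    x, y = [], []
    for start in range(n):
        window = [word[(start + k) % n] for k in range(steps)]
        target = word[(start + steps) % n]
        x.append([id_to_onehot[w_to_id[ch]] for ch in window])
        y.append(w_to_id[target])
    return x, y


x_train, y_train = make_windows(input_word, 4)
print(y_train)  # [4, 0, 1, 2, 3]
```

The modulo arithmetic is what wraps "cdea" and "deab" around the end of the alphabet.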
In [4]:
print(x_train)
print(y_train)
[[[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0]], [[0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]], [[0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0]]]
[4, 0, 1, 2, 3]
In [5]:
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
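Resetting the seed to the same value before each shuffle makes both lists receive the same permutation, so every sample stays paired with its label. An alternative sketch, assuming NumPy arrays and a seeded `default_rng` generator (neither is used in the original post), draws one permutation and applies it to both arrays:

```python
import numpy as np

# Toy stand-ins for x_train / y_train: sample i is paired with label i * 10.
samples = np.arange(5)
labels = samples * 10

rng = np.random.default_rng(7)        # assumption: the post uses np.random.seed instead
perm = rng.permutation(len(samples))  # one shared random ordering

samples_shuffled = samples[perm]
labels_shuffled = labels[perm]

# Every label still belongs to its sample after shuffling.
print(np.array_equal(labels_shuffled, samples_shuffled * 10))
```

Indexing both arrays with one permutation makes the pairing explicit instead of relying on two identical random streams.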
In [6]:
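# Reshape x_train to the input shape SimpleRNN expects: [number of samples, number of time steps, number of features per time step].
# The whole dataset is fed in at once, so the number of samples is len(x_train); four letters produce one result, so the number of time steps is 4; each letter is a one-hot vector with 5 features, so the number of features per time step is 5.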
x_train = np.reshape(x_train, (len(x_train), 4, 5))
y_train = np.array(y_train)
In [7]:
print(x_train)
print(y_train)
[[[1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]]

 [[0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]]

 [[0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0.]]

 [[0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 1.]]

 [[0. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]]]
[4 2 1 0 3]
In [8]:
model = tf.keras.Sequential([
    SimpleRNN(3),
    Dense(5, activation='softmax')
])
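To make the two layers concrete, here is a rough NumPy sketch of the forward pass they perform, assuming the Keras defaults for SimpleRNN (tanh activation, zero initial state, only the last hidden state returned). The weights are random placeholders, not the trained ones, and the function names are illustrative.

```python
import numpy as np


def simple_rnn_last_state(x, w_xh, w_hh, b_h):
    """Run a tanh recurrent cell over x of shape (batch, steps, features); return the final hidden state."""
    batch, steps, _ = x.shape
    h = np.zeros((batch, w_hh.shape[0]))  # hidden state starts at zero
    for t in range(steps):
        h = np.tanh(x[:, t, :] @ w_xh + h @ w_hh + b_h)
    return h


def softmax(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
w_xh = rng.normal(size=(5, 3))  # input-to-hidden weights
w_hh = rng.normal(size=(3, 3))  # hidden-to-hidden weights
b_h = np.zeros(3)               # hidden bias
w_hy = rng.normal(size=(3, 5))  # Dense layer weights
b_y = np.zeros(5)               # Dense layer bias

x = np.eye(5)[[0, 1, 2, 3]][None, :, :]  # "abcd" as one sample of shape (1, 4, 5)
h_last = simple_rnn_last_state(x, w_xh, w_hh, b_h)
probs = softmax(h_last @ w_hy + b_y)
print(h_last.shape, probs.shape)  # (1, 3) (1, 5)
```

The same three hidden units are reused at each of the four time steps, which is why the layer needs so few parameters.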
In [9]:
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/rnn_onehot_4pre1.ckpt"
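With `from_logits=False`, the loss treats the model output as probabilities and, for each sample, takes the negative log of the probability assigned to the true class. A small pure-Python sketch of that calculation follows; it ignores the numerical clipping Keras applies, and the probabilities are made-up illustration values, not model output.

```python
import math


def sparse_categorical_crossentropy(y_true, y_prob):
    """Mean of -log(probability assigned to the true class) over the batch."""
    losses = [-math.log(probs[label]) for label, probs in zip(y_true, y_prob)]
    return sum(losses) / len(losses)


# Integer labels (as in y_train) and hypothetical softmax outputs
y_true = [4, 0]
y_prob = [
    [0.05, 0.05, 0.05, 0.05, 0.80],
    [0.70, 0.10, 0.10, 0.05, 0.05],
]
loss = sparse_categorical_crossentropy(y_true, y_prob)
print(loss)
```

Because the labels are integer ids rather than one-hot vectors, the "sparse" variant of the loss and metric is the matching choice.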
In [10]:
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='loss')  # fit is given no test set, so no test accuracy is computed; the best model is saved based on the training loss
-------------load the model-----------------
In [11]:
history = model.fit(x_train, y_train, batch_size=32, epochs=50, callbacks=[cp_callback])

model.summary()
Epoch 1/50
1/1 [==============================] - 0s 216ms/step - loss: 0.2844 - sparse_categorical_accuracy: 1.0000
Epoch 2/50
1/1 [==============================] - 0s 178ms/step - loss: 0.2795 - sparse_categorical_accuracy: 1.0000
Epoch 3/50
1/1 [==============================] - 0s 180ms/step - loss: 0.2748 - sparse_categorical_accuracy: 1.0000
Epoch 4/50
1/1 [==============================] - 0s 201ms/step - loss: 0.2703 - sparse_categorical_accuracy: 1.0000
Epoch 5/50
1/1 [==============================] - 0s 205ms/step - loss: 0.2658 - sparse_categorical_accuracy: 1.0000
Epoch 6/50
1/1 [==============================] - 0s 197ms/step - loss: 0.2614 - sparse_categorical_accuracy: 1.0000
Epoch 7/50
1/1 [==============================] - 0s 187ms/step - loss: 0.2572 - sparse_categorical_accuracy: 1.0000
Epoch 8/50
1/1 [==============================] - 0s 192ms/step - loss: 0.2530 - sparse_categorical_accuracy: 1.0000
Epoch 9/50
1/1 [==============================] - 0s 214ms/step - loss: 0.2490 - sparse_categorical_accuracy: 1.0000
Epoch 10/50
1/1 [==============================] - 0s 172ms/step - loss: 0.2451 - sparse_categorical_accuracy: 1.0000
Epoch 11/50
1/1 [==============================] - 0s 196ms/step - loss: 0.2412 - sparse_categorical_accuracy: 1.0000
Epoch 12/50
1/1 [==============================] - 0s 200ms/step - loss: 0.2375 - sparse_categorical_accuracy: 1.0000
Epoch 13/50
1/1 [==============================] - 0s 205ms/step - loss: 0.2338 - sparse_categorical_accuracy: 1.0000
Epoch 14/50
1/1 [==============================] - 0s 296ms/step - loss: 0.2302 - sparse_categorical_accuracy: 1.0000
Epoch 15/50
1/1 [==============================] - 0s 183ms/step - loss: 0.2268 - sparse_categorical_accuracy: 1.0000
Epoch 16/50
1/1 [==============================] - 0s 204ms/step - loss: 0.2234 - sparse_categorical_accuracy: 1.0000
Epoch 17/50
1/1 [==============================] - 0s 178ms/step - loss: 0.2200 - sparse_categorical_accuracy: 1.0000
Epoch 18/50
1/1 [==============================] - 0s 181ms/step - loss: 0.2168 - sparse_categorical_accuracy: 1.0000
Epoch 19/50
1/1 [==============================] - 0s 179ms/step - loss: 0.2136 - sparse_categorical_accuracy: 1.0000
Epoch 20/50
1/1 [==============================] - 0s 243ms/step - loss: 0.2106 - sparse_categorical_accuracy: 1.0000
Epoch 21/50
1/1 [==============================] - 0s 192ms/step - loss: 0.2075 - sparse_categorical_accuracy: 1.0000
Epoch 22/50
1/1 [==============================] - 0s 272ms/step - loss: 0.2046 - sparse_categorical_accuracy: 1.0000
Epoch 23/50
1/1 [==============================] - 0s 226ms/step - loss: 0.2017 - sparse_categorical_accuracy: 1.0000
Epoch 24/50
1/1 [==============================] - 0s 209ms/step - loss: 0.1989 - sparse_categorical_accuracy: 1.0000
Epoch 25/50
1/1 [==============================] - 0s 210ms/step - loss: 0.1962 - sparse_categorical_accuracy: 1.0000
Epoch 26/50
1/1 [==============================] - 0s 151ms/step - loss: 0.1935 - sparse_categorical_accuracy: 1.0000
Epoch 27/50
1/1 [==============================] - 0s 187ms/step - loss: 0.1909 - sparse_categorical_accuracy: 1.0000
Epoch 28/50
1/1 [==============================] - 0s 189ms/step - loss: 0.1883 - sparse_categorical_accuracy: 1.0000
Epoch 29/50
1/1 [==============================] - 0s 188ms/step - loss: 0.1858 - sparse_categorical_accuracy: 1.0000
Epoch 30/50
1/1 [==============================] - 0s 189ms/step - loss: 0.1833 - sparse_categorical_accuracy: 1.0000
Epoch 31/50
1/1 [==============================] - 0s 189ms/step - loss: 0.1809 - sparse_categorical_accuracy: 1.0000
Epoch 32/50
1/1 [==============================] - 0s 171ms/step - loss: 0.1786 - sparse_categorical_accuracy: 1.0000
Epoch 33/50
1/1 [==============================] - 0s 163ms/step - loss: 0.1763 - sparse_categorical_accuracy: 1.0000
Epoch 34/50
1/1 [==============================] - 0s 174ms/step - loss: 0.1740 - sparse_categorical_accuracy: 1.0000
Epoch 35/50
1/1 [==============================] - 0s 184ms/step - loss: 0.1718 - sparse_categorical_accuracy: 1.0000
Epoch 36/50
1/1 [==============================] - 0s 173ms/step - loss: 0.1697 - sparse_categorical_accuracy: 1.0000
Epoch 37/50
1/1 [==============================] - 0s 184ms/step - loss: 0.1676 - sparse_categorical_accuracy: 1.0000
Epoch 38/50
1/1 [==============================] - 0s 191ms/step - loss: 0.1655 - sparse_categorical_accuracy: 1.0000
Epoch 39/50
1/1 [==============================] - 0s 180ms/step - loss: 0.1635 - sparse_categorical_accuracy: 1.0000
Epoch 40/50
1/1 [==============================] - 0s 242ms/step - loss: 0.1615 - sparse_categorical_accuracy: 1.0000
Epoch 41/50
1/1 [==============================] - 0s 222ms/step - loss: 0.1595 - sparse_categorical_accuracy: 1.0000
Epoch 42/50
1/1 [==============================] - 0s 194ms/step - loss: 0.1576 - sparse_categorical_accuracy: 1.0000
Epoch 43/50
1/1 [==============================] - 0s 166ms/step - loss: 0.1558 - sparse_categorical_accuracy: 1.0000
Epoch 44/50
1/1 [==============================] - 0s 187ms/step - loss: 0.1540 - sparse_categorical_accuracy: 1.0000
Epoch 45/50
1/1 [==============================] - 0s 213ms/step - loss: 0.1522 - sparse_categorical_accuracy: 1.0000
Epoch 46/50
1/1 [==============================] - 0s 185ms/step - loss: 0.1504 - sparse_categorical_accuracy: 1.0000
Epoch 47/50
1/1 [==============================] - 0s 195ms/step - loss: 0.1487 - sparse_categorical_accuracy: 1.0000
Epoch 48/50
1/1 [==============================] - 0s 168ms/step - loss: 0.1470 - sparse_categorical_accuracy: 1.0000
Epoch 49/50
1/1 [==============================] - 0s 163ms/step - loss: 0.1453 - sparse_categorical_accuracy: 1.0000
Epoch 50/50
1/1 [==============================] - 0s 183ms/step - loss: 0.1437 - sparse_categorical_accuracy: 1.0000
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 3)                 27        
_________________________________________________________________
dense (Dense)                (None, 5)                 20        
=================================================================
Total params: 47
Trainable params: 47
Non-trainable params: 0
_________________________________________________________________
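The parameter counts in the summary can be checked by hand: a simple recurrent layer has input-to-hidden weights, hidden-to-hidden weights, and a bias, while a dense layer has a weight matrix and a bias. The function names below are illustrative only.

```python
def simple_rnn_params(input_dim, units):
    # input-to-hidden weights + hidden-to-hidden weights + bias
    return input_dim * units + units * units + units


def dense_params(input_dim, units):
    # weight matrix + bias
    return input_dim * units + units


rnn = simple_rnn_params(5, 3)   # 5 one-hot features, 3 hidden units
dense = dense_params(3, 5)      # 3 hidden units, 5 output classes
print(rnn, dense, rnn + dense)  # 27 20 47
```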
In [12]:
# print(model.trainable_variables)
# Extract the trained parameters and write them to a text file
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
In [13]:
# Plot the training accuracy and loss curves (no validation set is used here)
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()
In [14]:
############### predict #############

preNum = int(input("input the number of test alphabet:"))
for i in range(preNum):
    alphabet1 = input("input test alphabet:")
    alphabet = [id_to_onehot[w_to_id[a]] for a in alphabet1]
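    # Reshape alphabet to the input shape SimpleRNN expects:
    # [number of samples, number of time steps, number of features per time step].
    # One sample is fed in for this check, so the number of samples is 1; four letters
    # produce one result, so the number of time steps is 4; each letter is a one-hot
    # vector with 5 features, so the number of features per time step is 5.
    alphabet = np.reshape(alphabet, (1, 4, 5))
    # Note: wrapping the array in a list is what triggers the Sequential-input warning
    # in the output below; model.predict(alphabet) avoids it.
    result = model.predict([alphabet])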
    pred = tf.argmax(result, axis=1)
    pred = int(pred)
    tf.print(alphabet1 + '->' + input_word[pred])
input the number of test alphabet:5
input test alphabet:abcd
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'tuple'> input: (<tf.Tensor 'IteratorGetNext:0' shape=(None, 4, 5) dtype=float32>,)
Consider rewriting this model with the Functional API.
abcd->e
input test alphabet:bcde
bcde->a
input test alphabet:cdea
cdea->b
input test alphabet:deab
deab->c
input test alphabet:eabc
eabc->d
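The prediction loop above assumes every input is exactly four letters drawn from "abcde"; anything else fails in the dictionary lookup or in the reshape. A minimal sketch of validating the input first, using a hypothetical helper `encode_input` that is not part of the original post:

```python
input_word = "abcde"
w_to_id = {ch: i for i, ch in enumerate(input_word)}


def encode_input(text, steps=4):
    """Return one-hot rows for `text`, or raise ValueError if it cannot be fed to the model."""
    if len(text) != steps:
        raise ValueError(f"expected {steps} letters, got {len(text)}")
    unknown = [ch for ch in text if ch not in w_to_id]
    if unknown:
        raise ValueError(f"letters outside {input_word!r}: {unknown}")
    return [[1.0 if j == w_to_id[ch] else 0.0 for j in range(len(input_word))]
            for ch in text]


rows = encode_input("bcde")
print(rows[0])  # [0.0, 1.0, 0.0, 0.0, 0.0]
```

The returned rows can then be reshaped to (1, 4, 5) exactly as in the loop above.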
posted @ 2020-09-24 08:27  范仁义