
Tensorflow2 (Pre-course) --- 9.2. Using a Recurrent Neural Network to Predict the Next Letter from a Single Input Letter (Embedding)


1. Summary

One-sentence summary:

Add an Embedding layer and adjust the input dimensions of the data accordingly: Embedding takes integer ids of shape [samples, timesteps] rather than one-hot vectors.
print(x_train)
print(y_train)
[[0]
 [3]
 [2]
 [1]
 [4]]
[1 4 3 2 0]
model = tf.keras.Sequential([
    Embedding(5, 2), 
    SimpleRNN(3),
    Dense(5, activation='softmax')
])
# Embedding(5, 2): input_dim=5 (vocabulary size), output_dim=2 (embedding dimension)
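
For contrast, here is a minimal sketch of the one-hot variant that the Embedding layer replaces (an assumption about the previous lesson's setup, which is not shown in this post): without Embedding, the RNN consumes 3-D one-hot input of shape [samples, timesteps, features].

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, SimpleRNN

# Hypothetical one-hot setup (not from this post): each id becomes a 5-dim
# one-hot row, reshaped to [samples=5, timesteps=1, features=5].
x_onehot = np.eye(5)[[0, 1, 2, 3, 4]].reshape(5, 1, 5)

model_onehot = tf.keras.Sequential([
    SimpleRNN(3),                     # no Embedding: one-hot vectors feed the RNN directly
    Dense(5, activation='softmax')
])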

 

 

2. Using a recurrent neural network to predict the next letter from a single input letter (Embedding)

Location of this post's corresponding course video:

 

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, SimpleRNN, Embedding
import matplotlib.pyplot as plt
import os
In [2]:
input_word = "abcde"
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}  # dictionary mapping each letter to a numeric id

x_train = [w_to_id['a'], w_to_id['b'], w_to_id['c'], w_to_id['d'], w_to_id['e']]
y_train = [w_to_id['b'], w_to_id['c'], w_to_id['d'], w_to_id['e'], w_to_id['a']]
In [3]:
print(x_train)
print(y_train)
[0, 1, 2, 3, 4]
[1, 2, 3, 4, 0]
In [4]:
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
In [5]:
# Reshape x_train to match the Embedding input requirement: [number of samples, number of RNN time steps].
# The whole dataset is fed in at once, so the number of samples is len(x_train); one letter in, one
# prediction out, so the number of time steps is 1.
x_train = np.reshape(x_train, (len(x_train), 1))
y_train = np.array(y_train)
In [6]:
print(x_train)
print(y_train)
[[0]
 [3]
 [2]
 [1]
 [4]]
[1 4 3 2 0]
In [7]:
model = tf.keras.Sequential([
    Embedding(5, 2), 
    SimpleRNN(3),
    Dense(5, activation='softmax')
])
# Embedding(5, 2): input_dim=5 (vocabulary size), output_dim=2 (embedding dimension)
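
To make the layer concrete: Embedding(5, 2) is just a trainable lookup table of shape (5, 2) that maps each integer id in [0, 5) to a dense 2-dimensional vector. A quick illustration (an added sketch, not part of the original notebook):

demo = Embedding(5, 2)
out = demo(np.array([[0], [3]]))  # input shape (2, 1): 2 samples, 1 time step each
print(out.shape)                  # (2, 1, 2): [samples, timesteps, embedding_dim]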
In [8]:
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/run_embedding_1pre1.ckpt"

if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='loss')  # fit() receives no validation set, so no validation accuracy is computed; the best model is saved according to training loss
In [9]:
history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])

model.summary()
Epoch 1/100
1/1 [==============================] - 0s 220ms/step - loss: 1.6253 - sparse_categorical_accuracy: 0.0000e+00
Epoch 2/100
1/1 [==============================] - 0s 231ms/step - loss: 1.6180 - sparse_categorical_accuracy: 0.0000e+00
Epoch 3/100
1/1 [==============================] - 0s 192ms/step - loss: 1.6113 - sparse_categorical_accuracy: 0.0000e+00
Epoch 4/100
1/1 [==============================] - 0s 186ms/step - loss: 1.6049 - sparse_categorical_accuracy: 0.2000
Epoch 5/100
1/1 [==============================] - 0s 194ms/step - loss: 1.5985 - sparse_categorical_accuracy: 0.2000
Epoch 6/100
1/1 [==============================] - 0s 204ms/step - loss: 1.5919 - sparse_categorical_accuracy: 0.2000
Epoch 7/100
1/1 [==============================] - 0s 177ms/step - loss: 1.5851 - sparse_categorical_accuracy: 0.4000
Epoch 8/100
1/1 [==============================] - 0s 185ms/step - loss: 1.5778 - sparse_categorical_accuracy: 0.4000
Epoch 9/100
1/1 [==============================] - 0s 188ms/step - loss: 1.5699 - sparse_categorical_accuracy: 0.4000
Epoch 10/100
1/1 [==============================] - 0s 178ms/step - loss: 1.5615 - sparse_categorical_accuracy: 0.4000
Epoch 11/100
1/1 [==============================] - 0s 161ms/step - loss: 1.5524 - sparse_categorical_accuracy: 0.2000
Epoch 12/100
1/1 [==============================] - 0s 193ms/step - loss: 1.5426 - sparse_categorical_accuracy: 0.2000
Epoch 13/100
1/1 [==============================] - 0s 194ms/step - loss: 1.5320 - sparse_categorical_accuracy: 0.4000
Epoch 14/100
1/1 [==============================] - 0s 168ms/step - loss: 1.5207 - sparse_categorical_accuracy: 0.4000
Epoch 15/100
1/1 [==============================] - 0s 195ms/step - loss: 1.5087 - sparse_categorical_accuracy: 0.4000
Epoch 16/100
1/1 [==============================] - 0s 178ms/step - loss: 1.4958 - sparse_categorical_accuracy: 0.2000
Epoch 17/100
1/1 [==============================] - 0s 173ms/step - loss: 1.4821 - sparse_categorical_accuracy: 0.2000
Epoch 18/100
1/1 [==============================] - 0s 152ms/step - loss: 1.4676 - sparse_categorical_accuracy: 0.2000
Epoch 19/100
1/1 [==============================] - 0s 158ms/step - loss: 1.4524 - sparse_categorical_accuracy: 0.2000
Epoch 20/100
1/1 [==============================] - 0s 193ms/step - loss: 1.4363 - sparse_categorical_accuracy: 0.4000
Epoch 21/100
1/1 [==============================] - 0s 185ms/step - loss: 1.4194 - sparse_categorical_accuracy: 0.4000
Epoch 22/100
1/1 [==============================] - 0s 188ms/step - loss: 1.4019 - sparse_categorical_accuracy: 0.4000
Epoch 23/100
1/1 [==============================] - 0s 184ms/step - loss: 1.3837 - sparse_categorical_accuracy: 0.4000
Epoch 24/100
1/1 [==============================] - 0s 302ms/step - loss: 1.3649 - sparse_categorical_accuracy: 0.4000
Epoch 25/100
1/1 [==============================] - 0s 299ms/step - loss: 1.3457 - sparse_categorical_accuracy: 0.4000
Epoch 26/100
1/1 [==============================] - 0s 190ms/step - loss: 1.3260 - sparse_categorical_accuracy: 0.4000
Epoch 27/100
1/1 [==============================] - 0s 179ms/step - loss: 1.3061 - sparse_categorical_accuracy: 0.4000
Epoch 28/100
1/1 [==============================] - 0s 178ms/step - loss: 1.2861 - sparse_categorical_accuracy: 0.4000
Epoch 29/100
1/1 [==============================] - 0s 184ms/step - loss: 1.2660 - sparse_categorical_accuracy: 0.4000
Epoch 30/100
1/1 [==============================] - 0s 160ms/step - loss: 1.2460 - sparse_categorical_accuracy: 0.4000
Epoch 31/100
1/1 [==============================] - 0s 187ms/step - loss: 1.2262 - sparse_categorical_accuracy: 0.4000
Epoch 32/100
1/1 [==============================] - 0s 199ms/step - loss: 1.2068 - sparse_categorical_accuracy: 0.4000
Epoch 33/100
1/1 [==============================] - 0s 193ms/step - loss: 1.1878 - sparse_categorical_accuracy: 0.4000
Epoch 34/100
1/1 [==============================] - 0s 194ms/step - loss: 1.1694 - sparse_categorical_accuracy: 0.4000
Epoch 35/100
1/1 [==============================] - 0s 193ms/step - loss: 1.1517 - sparse_categorical_accuracy: 0.4000
Epoch 36/100
1/1 [==============================] - 0s 189ms/step - loss: 1.1347 - sparse_categorical_accuracy: 0.4000
Epoch 37/100
1/1 [==============================] - 0s 184ms/step - loss: 1.1185 - sparse_categorical_accuracy: 0.4000
Epoch 38/100
1/1 [==============================] - 0s 187ms/step - loss: 1.1031 - sparse_categorical_accuracy: 0.4000
Epoch 39/100
1/1 [==============================] - 0s 178ms/step - loss: 1.0885 - sparse_categorical_accuracy: 0.4000
Epoch 40/100
1/1 [==============================] - 0s 182ms/step - loss: 1.0748 - sparse_categorical_accuracy: 0.4000
Epoch 41/100
1/1 [==============================] - 0s 176ms/step - loss: 1.0618 - sparse_categorical_accuracy: 0.4000
Epoch 42/100
1/1 [==============================] - 0s 188ms/step - loss: 1.0497 - sparse_categorical_accuracy: 0.4000
Epoch 43/100
1/1 [==============================] - 0s 169ms/step - loss: 1.0383 - sparse_categorical_accuracy: 0.4000
Epoch 44/100
1/1 [==============================] - 0s 178ms/step - loss: 1.0276 - sparse_categorical_accuracy: 0.4000
Epoch 45/100
1/1 [==============================] - 0s 170ms/step - loss: 1.0176 - sparse_categorical_accuracy: 0.6000
Epoch 46/100
1/1 [==============================] - 0s 160ms/step - loss: 1.0082 - sparse_categorical_accuracy: 0.6000
Epoch 47/100
1/1 [==============================] - 0s 174ms/step - loss: 0.9994 - sparse_categorical_accuracy: 0.6000
Epoch 48/100
1/1 [==============================] - 0s 243ms/step - loss: 0.9911 - sparse_categorical_accuracy: 0.6000
Epoch 49/100
1/1 [==============================] - 0s 169ms/step - loss: 0.9833 - sparse_categorical_accuracy: 0.6000
Epoch 50/100
1/1 [==============================] - 0s 170ms/step - loss: 0.9759 - sparse_categorical_accuracy: 0.6000
Epoch 51/100
1/1 [==============================] - 0s 212ms/step - loss: 0.9688 - sparse_categorical_accuracy: 0.6000
Epoch 52/100
1/1 [==============================] - 0s 167ms/step - loss: 0.9620 - sparse_categorical_accuracy: 0.6000
Epoch 53/100
1/1 [==============================] - 0s 167ms/step - loss: 0.9555 - sparse_categorical_accuracy: 0.8000
Epoch 54/100
1/1 [==============================] - 0s 179ms/step - loss: 0.9492 - sparse_categorical_accuracy: 0.8000
Epoch 55/100
1/1 [==============================] - 0s 186ms/step - loss: 0.9431 - sparse_categorical_accuracy: 0.8000
Epoch 56/100
1/1 [==============================] - 0s 185ms/step - loss: 0.9371 - sparse_categorical_accuracy: 0.8000
Epoch 57/100
1/1 [==============================] - 0s 192ms/step - loss: 0.9312 - sparse_categorical_accuracy: 0.8000
Epoch 58/100
1/1 [==============================] - 0s 184ms/step - loss: 0.9254 - sparse_categorical_accuracy: 0.8000
Epoch 59/100
1/1 [==============================] - 0s 154ms/step - loss: 0.9197 - sparse_categorical_accuracy: 0.8000
Epoch 60/100
1/1 [==============================] - 0s 195ms/step - loss: 0.9139 - sparse_categorical_accuracy: 0.8000
Epoch 61/100
1/1 [==============================] - 0s 185ms/step - loss: 0.9082 - sparse_categorical_accuracy: 0.8000
Epoch 62/100
1/1 [==============================] - 0s 187ms/step - loss: 0.9025 - sparse_categorical_accuracy: 0.8000
Epoch 63/100
1/1 [==============================] - 0s 193ms/step - loss: 0.8967 - sparse_categorical_accuracy: 0.8000
Epoch 64/100
1/1 [==============================] - 0s 184ms/step - loss: 0.8909 - sparse_categorical_accuracy: 0.8000
Epoch 65/100
1/1 [==============================] - 0s 361ms/step - loss: 0.8851 - sparse_categorical_accuracy: 0.8000
Epoch 66/100
1/1 [==============================] - 0s 194ms/step - loss: 0.8792 - sparse_categorical_accuracy: 0.8000
Epoch 67/100
1/1 [==============================] - 0s 178ms/step - loss: 0.8733 - sparse_categorical_accuracy: 0.8000
Epoch 68/100
1/1 [==============================] - 0s 194ms/step - loss: 0.8673 - sparse_categorical_accuracy: 0.8000
Epoch 69/100
1/1 [==============================] - 0s 238ms/step - loss: 0.8612 - sparse_categorical_accuracy: 0.8000
Epoch 70/100
1/1 [==============================] - 0s 234ms/step - loss: 0.8551 - sparse_categorical_accuracy: 0.8000
Epoch 71/100
1/1 [==============================] - 0s 278ms/step - loss: 0.8490 - sparse_categorical_accuracy: 0.8000
Epoch 72/100
1/1 [==============================] - 0s 170ms/step - loss: 0.8428 - sparse_categorical_accuracy: 0.8000
Epoch 73/100
1/1 [==============================] - 0s 145ms/step - loss: 0.8365 - sparse_categorical_accuracy: 0.8000
Epoch 74/100
1/1 [==============================] - 0s 177ms/step - loss: 0.8301 - sparse_categorical_accuracy: 0.8000
Epoch 75/100
1/1 [==============================] - 0s 160ms/step - loss: 0.8237 - sparse_categorical_accuracy: 0.8000
Epoch 76/100
1/1 [==============================] - 0s 151ms/step - loss: 0.8172 - sparse_categorical_accuracy: 0.8000
Epoch 77/100
1/1 [==============================] - 0s 177ms/step - loss: 0.8106 - sparse_categorical_accuracy: 0.8000
Epoch 78/100
1/1 [==============================] - 0s 187ms/step - loss: 0.8039 - sparse_categorical_accuracy: 0.8000
Epoch 79/100
1/1 [==============================] - 0s 167ms/step - loss: 0.7971 - sparse_categorical_accuracy: 0.8000
Epoch 80/100
1/1 [==============================] - 0s 171ms/step - loss: 0.7901 - sparse_categorical_accuracy: 0.8000
Epoch 81/100
1/1 [==============================] - 0s 179ms/step - loss: 0.7831 - sparse_categorical_accuracy: 0.8000
Epoch 82/100
1/1 [==============================] - 0s 160ms/step - loss: 0.7760 - sparse_categorical_accuracy: 0.8000
Epoch 83/100
1/1 [==============================] - 0s 169ms/step - loss: 0.7687 - sparse_categorical_accuracy: 0.8000
Epoch 84/100
1/1 [==============================] - 0s 188ms/step - loss: 0.7613 - sparse_categorical_accuracy: 0.8000
Epoch 85/100
1/1 [==============================] - 0s 166ms/step - loss: 0.7538 - sparse_categorical_accuracy: 0.8000
Epoch 86/100
1/1 [==============================] - 0s 177ms/step - loss: 0.7461 - sparse_categorical_accuracy: 1.0000
Epoch 87/100
1/1 [==============================] - 0s 188ms/step - loss: 0.7384 - sparse_categorical_accuracy: 1.0000
Epoch 88/100
1/1 [==============================] - 0s 178ms/step - loss: 0.7305 - sparse_categorical_accuracy: 1.0000
Epoch 89/100
1/1 [==============================] - 0s 184ms/step - loss: 0.7224 - sparse_categorical_accuracy: 1.0000
Epoch 90/100
1/1 [==============================] - 0s 170ms/step - loss: 0.7143 - sparse_categorical_accuracy: 1.0000
Epoch 91/100
1/1 [==============================] - 0s 242ms/step - loss: 0.7061 - sparse_categorical_accuracy: 1.0000
Epoch 92/100
1/1 [==============================] - 0s 226ms/step - loss: 0.6978 - sparse_categorical_accuracy: 1.0000
Epoch 93/100
1/1 [==============================] - 0s 245ms/step - loss: 0.6895 - sparse_categorical_accuracy: 1.0000
Epoch 94/100
1/1 [==============================] - 0s 151ms/step - loss: 0.6811 - sparse_categorical_accuracy: 1.0000
Epoch 95/100
1/1 [==============================] - 0s 219ms/step - loss: 0.6726 - sparse_categorical_accuracy: 1.0000
Epoch 96/100
1/1 [==============================] - 0s 200ms/step - loss: 0.6641 - sparse_categorical_accuracy: 1.0000
Epoch 97/100
1/1 [==============================] - 0s 194ms/step - loss: 0.6557 - sparse_categorical_accuracy: 1.0000
Epoch 98/100
1/1 [==============================] - 0s 171ms/step - loss: 0.6472 - sparse_categorical_accuracy: 1.0000
Epoch 99/100
1/1 [==============================] - 0s 230ms/step - loss: 0.6388 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
1/1 [==============================] - 0s 178ms/step - loss: 0.6305 - sparse_categorical_accuracy: 1.0000
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 2)           10        
_________________________________________________________________
simple_rnn (SimpleRNN)       (None, 3)                 18        
_________________________________________________________________
dense (Dense)                (None, 5)                 20        
=================================================================
Total params: 48
Trainable params: 48
Non-trainable params: 0
_________________________________________________________________
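
The parameter counts in the summary can be verified by hand (arithmetic added here for clarity):

embedding_params = 5 * 2        # lookup table: vocab_size * embedding_dim = 10
rnn_params = 2 * 3 + 3 * 3 + 3  # input-to-hidden + hidden-to-hidden + bias = 18
dense_params = 3 * 5 + 5        # hidden-to-output weights + bias = 20
print(embedding_params + rnn_params + dense_params)  # 48, matching the summary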
In [10]:
# print(model.trainable_variables)
file = open('./weights.txt', 'w')  # extract and save the trainable parameters
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()
In [11]:
###############################################    show   ###############################################

# Plot the acc and loss curves of the training set (no validation set was used here)
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()
In [12]:
############### predict #############

preNum = int(input("input the number of test alphabet:"))
for i in range(preNum):
    alphabet1 = input("input test alphabet:")
    alphabet = [w_to_id[alphabet1]]
    # Reshape alphabet to match the Embedding input requirement: [number of samples, number of RNN time steps].
    # One sample is fed in for this check, so the number of samples is 1; one letter in, one prediction out,
    # so the number of time steps is 1.
    alphabet = np.reshape(alphabet, (1, 1))
    result = model.predict(alphabet)
    pred = tf.argmax(result, axis=1)
    pred = int(pred)
    tf.print(alphabet1 + '->' + input_word[pred])
input the number of test alphabet:3
input test alphabet:a
a->b
input test alphabet:b
b->c
input test alphabet:c
c->d
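
Beyond the argmax letter, one can also inspect the full softmax distribution to see how confident the model is (a small added sketch reusing the names defined above):

probs = model.predict(np.reshape([w_to_id['a']], (1, 1)))[0]
for ch, p in zip(input_word, probs):
    print('P(%s | a) = %.3f' % (ch, p))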
posted @ 2020-09-24 11:05  范仁义