1 import numpy as np
2 import tensorflow as tf
3 from tensorflow.keras.layers import Dense, SimpleRNN
4 import matplotlib.pyplot as plt
5 import os
6
7 input_word = "abcde"
8 w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4} # 单词映射到数值id的词典
9 id_to_onehot = {0: [1., 0., 0., 0., 0.], 1: [0., 1., 0., 0., 0.], 2: [0., 0., 1., 0., 0.], 3: [0., 0., 0., 1., 0.],
10 4: [0., 0., 0., 0., 1.]} # id编码为one-hot
11
12 x_train = [id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']],
13 id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]]
14 y_train = [w_to_id['b'], w_to_id['c'], w_to_id['d'], w_to_id['e'], w_to_id['a']]
15
16 np.random.seed(7)
17 np.random.shuffle(x_train)
18 np.random.seed(7)
19 np.random.shuffle(y_train)
20 tf.random.set_seed(7)
21
22 # 使x_train符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。
23 # 此处整个数据集送入,送入样本数为len(x_train);输入1个字母出结果,循环核时间展开步数为1; 表示为独热码有5个输入特征,每个时间步输入特征个数为5
24 x_train = np.reshape(x_train, (len(x_train), 1, 5))
25 y_train = np.array(y_train)
26
27 model = tf.keras.models.Sequential([
28 SimpleRNN(3),
29 Dense(5, activation='softmax')
30 ])
31
32 model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
33 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
34 metrics=['sparse_categorical_accuracy'])
35
36 checkpoint_save_path = "./checkpoint/rnn_onehot_1pre1.ckpt"
37
38 if os.path.exists(checkpoint_save_path + '.index'):
39 print('-------------load the model-----------------')
40 model.load_weights(checkpoint_save_path)
41
42 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
43 save_weights_only=True,
44 save_best_only=True,
45 monitor='loss') # 由于fit没有给出测试集,不计算测试集准确率,根据loss,保存最优模型
46
47 history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])
48
49 model.summary()
50
51 # print(model.trainable_variables)
52 file = open('./weights.txt', 'w') # 参数提取
53 for v in model.trainable_variables:
54 file.write(str(v.name) + '\n')
55 file.write(str(v.shape) + '\n')
56 file.write(str(v.numpy()) + '\n')
57 file.close()
58
59 ############################################### show ###############################################
60
61 # 显示训练集和验证集的acc和loss曲线
62 acc = history.history['sparse_categorical_accuracy']
63 loss = history.history['loss']
64
65 plt.subplot(1, 2, 1)
66 plt.plot(acc, label='Training Accuracy')
67 plt.title('Training Accuracy')
68 plt.legend()
69
70 plt.subplot(1, 2, 2)
71 plt.plot(loss, label='Training Loss')
72 plt.title('Training Loss')
73 plt.legend()
74 plt.show()
75
76 ############### predict #############
77
78 preNum = int(input("input the number of test alphabet:"))
79 for i in range(preNum):
80 alphabet1 = input("input test alphabet:")
81 alphabet = [id_to_onehot[w_to_id[alphabet1]]]
82 # 使alphabet符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。此处验证效果送入了1个样本,送入样本数为1;输入1个字母出结果,所以循环核时间展开步数为1; 表示为独热码有5个输入特征,每个时间步输入特征个数为5
83 alphabet = np.reshape(alphabet, (1, 1, 5))
84 result = model.predict([alphabet])
85 pred = tf.argmax(result, axis=1)
86 pred = int(pred)
87 tf.print(alphabet1 + '->' + input_word[pred])