1 import tensorflow as tf
2 import os
3 import numpy as np
4 from matplotlib import pyplot as plt
5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
6 from tensorflow.keras import Model
7
8
# Print arrays in full instead of summarizing them with "..."; this makes the
# str(v.numpy()) dump written to weights.txt contain every value.
np.set_printoptions(threshold=np.inf)

# CIFAR-10 images come back as uint8 pixel arrays with values in [0, 255]
# (per the Keras load_data documentation).
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale pixels to [0, 1]. Both splits MUST use the same divisor so the model
# sees identically scaled inputs during training and validation; dividing the
# training split by 25.0 left its values in [0, ~10.2] while the test split
# was in [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0
14
15
class BaseLine(Model):
    """Small CNN baseline for 10-class image classification.

    Architecture: one convolutional block
    (Conv2D -> BatchNormalization -> ReLU -> MaxPool2D -> Dropout)
    followed by Flatten -> Dense(128, relu) -> Dropout -> Dense(10, softmax).
    The output is a probability distribution over the 10 classes.
    """

    def __init__(self):
        super().__init__()
        # NOTE: keep these attribute names and their assignment order stable.
        # TF-format weight checkpoints match variables by attribute name, and
        # the order determines model.trainable_variables ordering.

        # Convolutional feature extractor.
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # batch-normalization layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer

        # Fully connected classifier head.
        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass and return the softmax class probabilities."""
        stages = (
            self.c1, self.b1, self.a1, self.p1, self.d1,   # feature extractor
            self.flatten, self.f1, self.d2, self.f2,       # classifier head
        )
        for stage in stages:
            x = stage(x)
        return x
42
43
44
# Build and compile the network. The last Dense layer already applies softmax,
# so the loss is told it receives probabilities rather than logits.
model = BaseLine()
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# Resume from a previous run when a TF-format weights checkpoint is present;
# its ".index" companion file is used as the existence marker.
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(f"{checkpoint_save_path}.index"):
    print("--------------------load the model-----------------")
    model.load_weights(checkpoint_save_path)

# Persist weights only, and only when the monitored quantity improves
# (ModelCheckpoint monitors 'val_loss' by default).
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# NOTE(review): the test split doubles as the validation set here.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)

model.summary()
61
62
# Write every trainable variable to weights.txt as three lines:
# its name, its shape, and its full value array.
with open('./weights.txt', 'w') as out:
    for var in model.trainable_variables:
        record = (str(var.name), str(var.shape), str(var.numpy()))
        out.write('\n'.join(record) + '\n')
68
69
def plot_acc_loss_curve(history):
    """Show training vs. validation accuracy and loss curves side by side.

    Args:
        history: The object returned by ``model.fit``. Its ``history`` dict
            must contain the keys 'sparse_categorical_accuracy',
            'val_sparse_categorical_accuracy', 'loss' and 'val_loss'.
    """
    from matplotlib import pyplot as plt

    curves = history.history
    # Fetch every series before drawing, so a missing key fails before any
    # figure is opened.
    acc = curves['sparse_categorical_accuracy']
    val_acc = curves['val_sparse_categorical_accuracy']
    loss = curves['loss']
    val_loss = curves['val_loss']

    # One entry per subplot: (train series, train label, val series,
    # val label, subplot title).
    panels = (
        (acc, 'Training Accuracy', val_acc, 'Validation Accuracy',
         'Training and Validation Accuracy'),
        (loss, 'Training Loss', val_loss, 'Validation Loss',
         'Training and Validation Loss'),
    )

    plt.figure(figsize=(15, 5))
    for position, (train, train_label, val, val_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train, label=train_label)
        plt.plot(val, label=val_label)
        plt.title(title)
        plt.legend()
    plt.show()


plot_acc_loss_curve(history)