import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)

# Load CIFAR-10 and scale pixel values to [0, 1]
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
class ConvBNRelu(Model):
    def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
        super(ConvBNRelu, self).__init__()
        self.model = tf.keras.models.Sequential([
            Conv2D(ch, kernelsz, strides=strides, padding=padding),
            BatchNormalization(),
            Activation('relu')
        ])

    def call(self, x):
        # With training=False, BN normalizes with the moving mean/variance estimated
        # over the whole training set; with training=True it uses the current batch's
        # mean/variance. training=False usually works better at inference time.
        x = self.model(x, training=False)
        return x

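# Illustrative sanity check (a sketch with an assumed dummy NHWC batch): with
# padding='same' and the default strides=1, ConvBNRelu keeps the spatial size
# and only changes the channel count.
_cbr_demo = ConvBNRelu(16)(tf.random.normal([2, 32, 32, 3]))
print('ConvBNRelu demo output shape:', _cbr_demo.shape)  # expected (2, 32, 32, 16)
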
class InceptionBlk(Model):
    def __init__(self, ch, strides=1):
        super(InceptionBlk, self).__init__()
        self.ch = ch
        self.strides = strides
        # Four parallel branches: 1x1 conv, 1x1->3x3 conv, 1x1->5x5 conv, 3x3 max-pool->1x1 conv
        self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
        self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
        self.p4_1 = MaxPooling2D(3, strides=1, padding='same')
        self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)

    def call(self, x):
        x1 = self.c1(x)
        x2_1 = self.c2_1(x)
        x2_2 = self.c2_2(x2_1)
        x3_1 = self.c3_1(x)
        x3_2 = self.c3_2(x3_1)
        x4_1 = self.p4_1(x)
        x4_2 = self.c4_2(x4_1)
        # concat along the channel axis (last axis for the default channels_last format)
        x = tf.concat([x1, x2_2, x3_2, x4_2], axis=-1)
        return x

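# Illustrative sketch (assumed dummy shapes): each branch outputs ch channels and,
# with strides=1, keeps the spatial size, so the channel-axis concat yields 4 * ch channels.
_blk_demo = InceptionBlk(16, strides=1)(tf.random.normal([2, 32, 32, 3]))
print('InceptionBlk demo output shape:', _blk_demo.shape)  # expected (2, 32, 32, 64)
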
class Inception10(Model):
    def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
        super(Inception10, self).__init__(**kwargs)
        self.in_channels = init_ch
        self.out_channels = init_ch
        self.num_blocks = num_blocks
        self.init_ch = init_ch
        self.c1 = ConvBNRelu(init_ch)
        self.blocks = tf.keras.models.Sequential()
        for block_id in range(num_blocks):
            for layer_id in range(2):
                if layer_id == 0:
                    # the first Inception block of each stage downsamples the feature map
                    block = InceptionBlk(self.out_channels, strides=2)
                else:
                    block = InceptionBlk(self.out_channels, strides=1)
                self.blocks.add(block)
            # double out_channels for the next stage
            self.out_channels *= 2
        self.p1 = GlobalAveragePooling2D()
        self.f1 = Dense(num_classes, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.blocks(x)
        x = self.p1(x)
        y = self.f1(x)
        return y

model = Inception10(num_blocks=2, num_classes=10)

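# Quick forward-pass sketch on an assumed dummy CIFAR-10-sized image: the network
# should map it to a softmax vector over the 10 classes.
print('Inception10 demo output shape:', model(tf.random.normal([1, 32, 32, 3])).shape)  # expected (1, 10)
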
# CIFAR-10 labels are integer class indices, so the sparse loss/metric variants are used;
# from_logits=False because the final Dense layer already applies softmax.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

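# Equivalent compile call spelled out with explicit objects (a sketch; learning_rate=0.001
# is just Keras's default behind the 'adam' string, not a tuned value):
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
#               metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
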
checkpoint_save_path = "./checkpoint/Inception10.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model---------------')
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when validation performance improves
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

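# Optional follow-up sketch: reload the best weights written by the checkpoint callback
# and evaluate them on the test set (standard Keras APIs only).
if os.path.exists(checkpoint_save_path + '.index'):
    model.load_weights(checkpoint_save_path)
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    print('best-checkpoint test accuracy:', test_acc)
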
# Dump the names, shapes and values of all trainable variables to a text file
with open('./weights.txt', 'w') as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + '\n')
        f.write(str(v.shape) + '\n')
        f.write(str(v.numpy()) + '\n')

def plot_acc_loss_curve(history):
    # Plot the training/validation accuracy and loss curves
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid()
    plt.show()

plot_acc_loss_curve(history)