1 import tensorflow as tf
2 from tensorflow.keras import layers, activations
3
4
class Residual(tf.keras.Model):
    """ResNet residual block: two 3x3 convs with batch norm on the main
    path, plus an optional 1x1 conv on the shortcut.

    Args:
        num_channels: number of output channels of both convolutions.
        use_1x1conv: if True, transform the shortcut with a 1x1 conv so its
            channel count / spatial size match the main path.
        strides: stride of the first conv (and of the 1x1 shortcut conv).
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        # Main path: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = layers.Conv2D(num_channels, kernel_size=3,
                                   strides=strides, padding='same')
        self.conv2 = layers.Conv2D(num_channels, kernel_size=3,
                                   padding='same')
        # Shortcut path: 1x1 conv when requested, otherwise identity.
        self.conv3 = (layers.Conv2D(num_channels, kernel_size=1,
                                    strides=strides)
                      if use_1x1conv else None)
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()

    def call(self, X):
        """Return relu(F(X) + shortcut(X))."""
        out = self.conv1(X)
        out = activations.relu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        shortcut = self.conv3(X) if self.conv3 is not None else X
        return activations.relu(out + shortcut)
28
29
# Sanity checks on a single Residual block.
# TensorFlow input layout is NHWC: (n_images, height, width, channels).
X = tf.random.uniform(shape=(4, 6, 6, 3))

# Identity-shortcut case: output shape equals input shape.
blk = Residual(3)
blk(X).shape

# 1x1-conv shortcut with stride 2: channels 3 -> 6, spatial size halved.
blk = Residual(6, use_1x1conv=True, strides=2)
blk(X).shape
38
39
# ResNet stem: the layers that precede the residual stages.
net = tf.keras.models.Sequential()
net.add(layers.Conv2D(64, kernel_size=7, strides=2, padding='same'))
net.add(layers.BatchNormalization())
net.add(layers.Activation('relu'))
net.add(layers.MaxPool2D(pool_size=3, strides=2, padding='same'))
46
47
class ResnetBlock(tf.keras.layers.Layer):
    """One ResNet stage: `num_residuals` Residual blocks stacked.

    Unless this is the first stage after the stem (`first_block=True`), the
    first residual halves the spatial size (strides=2) and uses a 1x1 conv
    on its shortcut to change the channel count.

    Args:
        num_channels: output channels of every residual in this stage.
        num_residuals: how many Residual blocks to stack.
        first_block: True for the stage right after the stem, where no
            downsampling / channel change is needed.
    """

    def __init__(self, num_channels, num_residuals, first_block=False, **kwargs):
        super(ResnetBlock, self).__init__(**kwargs)
        self.listLayers = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                # Downsampling residual at the start of the stage.
                self.listLayers.append(
                    Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                self.listLayers.append(Residual(num_channels))

    def call(self, X):
        # Iterate the list itself. The original iterated
        # `self.listLayers.layers`, which only works because Keras silently
        # wraps list attributes of a Layer in a trackable ListWrapper that
        # happens to expose a `.layers` property; iterating the plain list
        # applies the same sublayers in the same order without depending on
        # that implementation detail.
        for layer in self.listLayers:
            X = layer(X)
        return X
62
63
# Add all the residual blocks to ResNet; each module here uses two residual blocks.
65
class ResNet(tf.keras.Model):
    """ResNet classifier: stem conv -> four residual stages -> global
    average pool -> 10-way softmax Dense head.

    Args:
        num_blocks: sequence of four ints, the residual count per stage.
    """

    def __init__(self, num_blocks, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        # Stem: 7x7 conv (stride 2), BN, ReLU, 3x3 max-pool (stride 2).
        self.conv = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.mp = layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        # Four stages; the channel count doubles at every stage.
        self.resnet_block1 = ResnetBlock(64, num_blocks[0], first_block=True)
        self.resnet_block2 = ResnetBlock(128, num_blocks[1])
        self.resnet_block3 = ResnetBlock(256, num_blocks[2])
        self.resnet_block4 = ResnetBlock(512, num_blocks[3])
        # Head: global average pooling followed by a softmax classifier.
        self.gap = layers.GlobalAvgPool2D()
        self.fc = layers.Dense(units=10, activation=tf.keras.activations.softmax)

    def call(self, x):
        """Run the stem, the four residual stages, then the head."""
        out = self.mp(self.relu(self.bn(self.conv(x))))
        for stage in (self.resnet_block1, self.resnet_block2,
                      self.resnet_block3, self.resnet_block4):
            out = stage(out)
        return self.fc(self.gap(out))
92
# Build the 18-layer configuration: two residuals in each of the 4 stages.
mynet = ResNet([2, 2, 2, 2])

# Push a dummy single-channel 224x224 batch through each top-level layer
# and report the output shape at every step.
X = tf.random.uniform(shape=(1, 224, 224, 1))
for layer in mynet.layers:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)
100
101
# Train ResNet on the Fashion-MNIST dataset.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Add a trailing channel axis and scale pixel values into [0, 1].
x_train = x_train.reshape((-1, 28, 28, 1)).astype('float32') / 255
x_test = x_test.reshape((-1, 28, 28, 1)).astype('float32') / 255

# Sparse labels (integer class ids) + softmax outputs from the Dense head.
mynet.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

# Hold out 20% of the training set for validation.
history = mynet.fit(x_train, y_train, batch_size=64, epochs=5,
                    validation_split=0.2)
test_scores = mynet.evaluate(x_test, y_test, verbose=2)