import tensorflow as tf
def preprocess(x, y):
    """Convert one (image, label) pair into model-ready tensors.

    The image is cast to float32 and rescaled with ``x / 255 - 0.5``
    (pixel values 0..255 end up in [-0.5, 0.5]); the label is cast to int32.
    """
    pixels = tf.cast(x, dtype=tf.float32)
    scaled = pixels / 255 - 0.5
    labels = tf.cast(y, dtype=tf.int32)
    return scaled, labels
batchsz = 128
# Raw shapes: x -> [50k, 32, 32, 3], y -> [50k, 1] (one class id per image).
(x, y), (x_val, y_val) = tf.keras.datasets.cifar10.load_data()
# One-hot encoding keeps the size-1 label axis: [50k, 1] -> [50k, 1, 10].
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print(x.shape, y.shape)
# Remove the size-1 label axis -> [50k, 10]. axis=1 is explicit so the
# batch axis is never squeezed away (e.g. for a single-sample split).
y = tf.squeeze(y, axis=1)
y_val = tf.squeeze(y_val, axis=1)
print('squeeze后:')
print(x.shape, y.shape, x.min(), x.max())
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(1000).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsz)
# Sanity-check the pipeline on one batch; expected shapes are
# (128, 32, 32, 3) and (128, 10).
sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)
# Custom layer
# Used in place of the standard tf.keras.layers.Dense()
class MyDense(tf.keras.layers.Layer):
    """Fully connected layer without bias or activation: ``inputs @ kernel``.

    Used in place of ``tf.keras.layers.Dense``. No bias variable is created,
    so the layer is a pure linear map.
    """

    def __init__(self, inp_dim, oup_dim):
        """Create the kernel variable.

        :param inp_dim: size of the last axis of the inputs
        :param oup_dim: size of the last axis of the outputs
        """
        super(MyDense, self).__init__()
        # add_weight replaces the deprecated add_variable alias (removed in
        # Keras 3). Keywords are required because in Keras 3 the first
        # positional parameter of add_weight is `shape`, not `name`.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, oup_dim])

    def call(self, inputs, training=None):
        """Return ``inputs @ kernel``; `training` is accepted but unused."""
        return inputs @ self.kernel
# Custom network
class MyNetwork(tf.keras.Model):
    """Five-layer MLP classifier built from bias-free ``MyDense`` layers.

    Layer widths: 32*32*3 -> 256 -> 256 -> 256 -> 32 -> 10. ReLU follows
    every layer except the last, so the model returns raw logits.
    """

    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 256)
        self.fc3 = MyDense(256, 256)
        self.fc4 = MyDense(256, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None, mask=None):
        '''
        :param inputs: image batch, [b,32,32,3] (flattened internally)
        :param training: unused
        :param mask: unused
        :return: logits of shape [b,10] (no softmax applied)
        '''
        # [b,32,32,3] -> [b,32*32*3]
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # [b,32*32*3] -> [b,256]
        x = self.fc1(x)
        x = tf.nn.relu(x)
        # [b,256] -> [b,256]
        x = self.fc2(x)
        x = tf.nn.relu(x)
        # [b,256] -> [b,256]
        x = self.fc3(x)
        x = tf.nn.relu(x)
        # [b,256] -> [b,32]
        x = self.fc4(x)
        x = tf.nn.relu(x)
        # [b,32] -> [b,10]
        x = self.fc5(x)
        # No activation on the last layer: the loss is configured with
        # from_logits=True, so raw logits are returned.
        return x
def _compile_model(model):
    """Compile `model` with the shared Adam / categorical-crossentropy setup.

    Both the trained network and the reloaded copy use this helper, so their
    evaluation configuration cannot drift apart.
    """
    # `learning_rate` replaces the deprecated `lr` kwarg (rejected by Keras 3).
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])


network = MyNetwork()
_compile_model(network)
network.fit(train_db, epochs=13, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
network.save_weights('./save_w_model/test1')

# Build a fresh network and restore only the saved weights into it.
network2 = MyNetwork()
_compile_model(network2)
network2.load_weights('./save_w_model/test1')
print('加载仅有参数的模型')
network2.evaluate(test_db)