# Train a custom network model on the CIFAR-10 dataset and save the model weights.
import tensorflow as tf
import os
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
# Set TensorFlow's native log threshold to '2' (per the TensorFlow docs this
# hides INFO and WARNING messages from the C++ backend).
# NOTE(review): this runs after `import tensorflow` above; the variable is
# normally expected to be set before the import to take full effect -- confirm
# the intended ordering.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
def prepeocess(x, y):
    """Rescale image data to [-1, 1] and cast the labels to int32.

    :param x: image data, expected in the 0..255 range.
    :param y: label tensor; only its dtype is changed.
    :return: tuple ``(x, y)`` of the converted tensors.
    """
    pixels = tf.cast(x, dtype=tf.float32)
    pixels = 2 * pixels / 255. - 1.  # [0, 255] => [-1, 1]
    labels = tf.cast(y, dtype=tf.int32)
    return pixels, labels
# Load CIFAR-10; each image is [32, 32, 3].
(x, y), (x_val, y_val) = datasets.cifar10.load_data()

# Labels arrive as [50k, 1] / [10k, 1]: squeeze away the singleton axis, then
# one-hot encode into 10 classes.
y = tf.one_hot(tf.squeeze(y), depth=10)          # x: [50k, 32, 32, 3]  y: [50k, 10]
y_val = tf.one_hot(tf.squeeze(y_val), depth=10)  # x_val: [10k, 32, 32, 3]  y_val: [10k, 10]
print("datasets:", x.shape, y.shape, x_val.shape, y_val.shape, x.min(), x.max())

batchsz = 100

# Training pipeline: preprocess every sample, shuffle with a 10000-element
# buffer, then group into batches.
train_db = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .map(prepeocess)
    .shuffle(10000)
    .batch(batchsz)
)

# Validation pipeline: same preprocessing, no shuffling.
test_db = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .map(prepeocess)
    .batch(batchsz)
)

# Pull a single batch to sanity-check the batch shapes.
sample = next(iter(train_db))
print("batch:", sample[0].shape, sample[1].shape)
class Mydense(layers.Layer):
    """Custom bias-free dense layer, a stand-in for ``layers.Dense``.

    Computes ``inputs @ kernel`` where ``kernel`` has shape
    ``[inp_dim, outp_dim]``. No bias term is created.
    """

    def __init__(self, inp_dim, outp_dim):
        """Create the layer's weight matrix.

        :param inp_dim: size of the last axis of the expected input.
        :param outp_dim: size of the last axis of the produced output.
        """
        super(Mydense, self).__init__()
        # add_weight is the supported API; add_variable was a deprecated alias
        # of it. Keyword arguments keep the call independent of the
        # positional order of add_weight's parameters.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])

    def call(self, inputs, training=None):
        """Apply the linear map to ``inputs``.

        Implemented as ``call`` (not ``__call__``) so that invoking the layer
        goes through the standard Keras ``Layer.__call__`` wrapper.

        :param inputs: tensor whose last axis has size ``inp_dim``.
        :param training: accepted for Keras compatibility; unused here.
        :return: ``inputs @ self.kernel``.
        """
        return inputs @ self.kernel
class MyNetwork(keras.Model):
    """Five-layer fully connected classifier for flattened 32x32x3 images.

    Four hidden layers of width 256 with ReLU activations are followed by a
    10-unit output layer that returns raw logits.
    """

    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = Mydense(32 * 32 * 3, 256)
        self.fc2 = Mydense(256, 256)  # extra parameters to improve accuracy
        self.fc3 = Mydense(256, 256)
        self.fc4 = Mydense(256, 256)
        self.fc5 = Mydense(256, 10)

    def call(self, inputs, training=None):
        """Run the forward pass.

        Implemented as ``call`` (not ``__call__``) so that invoking the model
        goes through the standard Keras ``Model.__call__`` wrapper.

        :param inputs: image batch of shape [b, 32, 32, 3].
        :param training: accepted for Keras compatibility; unused here.
        :return: logits of shape [b, 10].
        """
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = tf.nn.relu(hidden(x))
        # No activation on the output layer: the loss used with this model is
        # built with from_logits=True and expects unconstrained logits. A ReLU
        # here would clamp every logit to be non-negative.
        return self.fc5(x)
def _compiled_network():
    """Return a new MyNetwork compiled with Adam, cross-entropy on logits and accuracy."""
    model = MyNetwork()
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-3),
        loss=tf.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model


# Train for 20 epochs, validating on the test set after every epoch.
network = _compiled_network()
network.fit(train_db, epochs=20, validation_data=test_db, validation_freq=1)

# Save the model.
network.evaluate(test_db)
# save_weights stores only the weights rather than the full model state, so
# the network has to be re-created before the weights can be loaded back.
network.save_weights('ckpt/weights.ckpt')
del network
print('saved to ckpt/weights.ckpt')

# Rebuild an identical network and restore the saved weights into it.
network = _compiled_network()
network.load_weights('ckpt/weights.ckpt')  # the files can be inspected in the ckpt folder
print('loaded weights from file.')
network.evaluate(test_db)