Tensorflow2.0含有LSTM的模型保存

在做LSTM相关网络模型保存的时候,报了WARNING:absl:Found untraced functions such as lstm_cell_1_layer_call_fn, lstm_cell_1_layer_call_and_return_conditional_losses, lstm_cell_2_layer_call_fn, lstm_cell_2_layer_call_and_return_conditional_losses, lstm_cell_4_layer_call_fn while saving (showing 5 of 5). These functions will not be directly callable after loading.的警告,我没注意,但是在验证的时候,加载模型还原不出训练的精度了。

试了特别多不同的保存模型方法,最后用了.h5 文件方式来保存,这样出来的模型精度就没有影响了。最后还碰见了一个比较奇怪的事情,使用model.fit训练模型时,保存文件的回调函数ModelCheckpoint,其中save_ferq的参数不起作用,不知道是不是因为我的batchsize太大了,我设了2048,每过一个epoch就会保存一次。

cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        verbose=1,
        save_weights_only=False,
        monitor='val_loss',
        mode='min',
        save_best_only=False,
        save_ferq=20
    )

还好有一个自定义回调函数可以写一下。on_epoch_end 这个方法是在epoch结束的时候调用,里面可以自由发挥,比较好。

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        
        y_ = y[len_d:]
        pred = self.model.predict(x[len_d:])
        y_pred = (pred>0.5).astype(int)
        confusion_mtx = confusion_matrix(y_, y_pred)
        print(confusion_mtx)
        print(f'acc{accuracy_score(y_, y_pred):.3f} ,'
              f'precision{precision_score(y_, y_pred, average="macro"):.4f} ,'
              f'recal{recall_score(y_, y_pred, average="macro"):.4f} ,'
              f'f1{f1_score(y_, y_pred, average="macro"):.4f}')
        self.model.save(f'ck_pt/cp-{epoch:04d}-{f1_score(y_, y_pred, average="macro"):.4f}.h5')

mycallback = CustomCallback()

model.fit(x[:len_d], y[:len_d],
          epochs=501,
          batch_size=2048,
          shuffle=True,
          validation_data=(x[len_d:], y[len_d:]),
          callbacks= mycallback
         )
posted @ 2022-09-06 17:37  赫凯  阅读(186)  评论(0)    收藏  举报