tensorflow2——自定义回调函数

class LossHistory(keras.callbacks.Callback):
    """Log train/val loss to a file each epoch and checkpoint model weights
    whenever the validation loss reaches a new minimum.

    NOTE(review): relies on module-level globals `my_model`, `train_high0_img`,
    `train_rain`, and `log_path` — confirm they are defined before fit() runs.
    """

    def on_train_begin(self, logs=None):
        # None default instead of a mutable `{}` shared across calls.
        self.losses = []
        # float('inf') guarantees the first epoch always checkpoints;
        # the original sentinel 999 would never save if val_loss > 999.
        self.min_val_loss = float('inf')

    def on_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        # Sanity-check predictions on a few training samples each epoch.
        test_train_data = train_high0_img[:5]
        test_train_label = train_rain[:5]
        pre = my_model(test_train_data, training=False)
        print('pre:', pre)
        print('test_train_label:', test_train_label)

        print_loss = logs.get('loss')
        val_loss = logs.get('val_loss')
        print('\ncallback调用!!!!!!!!!\nloss:{}'.format(print_loss))

        # Append "loss<TAB>val_loss" to the running log file.
        with open(log_path, 'a+') as f:
            f.write('{}\t{}\n'.format(print_loss, val_loss))

        # Checkpoint on improvement. Guard against a missing val_loss and
        # compare numerically (no str -> float round-trip as before).
        if val_loss is not None and float(val_loss) <= self.min_val_loss:
            print('保存模型,val_loss:{}'.format(val_loss))
            self.min_val_loss = float(val_loss)
            # ./model_save/minloss_val_model/ holds weights + a summary file.
            callback_savemodel = os.path.join('./model_save', 'minloss_val_model')
            # makedirs creates both directory levels and, with exist_ok=True,
            # is race-free compared to exists()+mkdir() pairs.
            os.makedirs(callback_savemodel, exist_ok=True)
            minval_log_path = os.path.join(callback_savemodel, 'minval_loss.txt')
            # Weight-file prefix repeats the model name, matching the original layout.
            my_model.save_weights(os.path.join(callback_savemodel, 'minloss_val_model'))

            with open(minval_log_path, 'w+') as f:
                f.write('min_val_loss:{:.2f}'.format(float(val_loss)))
# Attach the custom loss-logging/checkpointing callback to training.
history2 = LossHistory()

history = my_model.fit(
    train_high0_img,
    train_rain,
    validation_data=(test_high0_img, test_rain),
    epochs=epochs,
    validation_freq=1,
    batch_size=bat,
    callbacks=[history2],
)

 

posted @ 2022-04-26 18:05  山…隹  阅读(53)  评论(0)    收藏  举报