class LossHistory(keras.callbacks.Callback):
    """Keras callback that logs per-epoch losses and checkpoints the best weights.

    At the end of every epoch it:
      * runs the model on the first 5 training samples and prints the
        predictions next to their labels (a quick sanity check);
      * appends ``"<loss>\\t<val_loss>\\n"`` to the file at ``log_path``;
      * when ``val_loss`` is available and no worse than the best value seen so
        far, saves the weights to
        ``./model_save/minloss_val_model/minloss_val_model`` and writes the
        value to ``./model_save/minloss_val_model/minval_loss.txt``.

    NOTE(review): relies on the module-level globals ``train_high0_img``,
    ``train_rain`` and ``log_path`` defined elsewhere in the script.
    """

    def on_train_begin(self, logs=None):
        """Reset per-run state before training starts."""
        self.losses = []  # kept for callers that may read it; not filled here
        # Start at +inf so the first real val_loss always counts as an
        # improvement (a finite sentinel such as 999 silently skips
        # checkpoints whenever validation losses are larger than it).
        self.min_val_loss = float('inf')

    def on_batch_end(self, batch, logs=None):
        """No per-batch work; batch-level loss tracking is disabled."""

    def on_epoch_end(self, epoch, logs=None):
        """Print a prediction sample, log losses and checkpoint on improvement.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metrics dict supplied by Keras; ``'loss'`` and, on epochs
                where validation ran, ``'val_loss'`` are read from it.
        """
        logs = logs or {}

        # Sanity check: show model output for a few training samples.
        test_train_data = train_high0_img[:5]
        test_train_label = train_rain[:5]
        pre = self.model(test_train_data, training=False)
        print('pre:', pre)
        print('test_train_label:', test_train_label)

        print_loss = logs.get('loss')
        val_loss = logs.get('val_loss')
        print('\ncallback调用!!!!!!!!!\nloss:{}'.format(print_loss))
        with open(log_path, 'a+') as f:
            f.write(str(print_loss) + '\t' + str(val_loss) + '\n')

        # val_loss is absent on epochs without validation; there is nothing to
        # compare against, so skip checkpointing instead of crashing.
        if val_loss is None:
            return
        if float(val_loss) <= self.min_val_loss:
            print('保存模型,val_loss:{}'.format(val_loss))
            self.min_val_loss = float(val_loss)
            self._save_best_weights(float(val_loss))

    def _save_best_weights(self, val_loss):
        """Save current weights and record *val_loss* under ./model_save/minloss_val_model/."""
        model_name = 'minloss_val_model'
        save_dir = os.path.join('./model_save', model_name)
        os.makedirs(save_dir, exist_ok=True)
        self.model.save_weights(os.path.join(save_dir, model_name))
        minval_log_path = os.path.join(save_dir, 'minval_loss.txt')
        with open(minval_log_path, 'w+') as f:
            f.write('min_val_loss:{:.2f}'.format(val_loss))
# Train with validation after every epoch; the LossHistory callback logs the
# losses and checkpoints the weights with the lowest validation loss.
history2 = LossHistory()
history = my_model.fit(
    train_high0_img,
    train_rain,
    validation_data=(test_high0_img, test_rain),
    epochs=epochs,
    validation_freq=1,
    batch_size=bat,
    callbacks=[history2],
)