import math

from tensorflow import keras          # assumes TensorFlow's bundled Keras
from tensorflow.keras import backend as K


class WarmUpCos(keras.callbacks.Callback):
    """Linear warmup to lr_max, then cosine decay down to lr_min, updated once per batch."""

    def __init__(self, lr_max, lr_min, warm_step, sum_step, bat):
        super(WarmUpCos, self).__init__()
        self.lr_max = lr_max        # peak learning rate reached at the end of warmup
        self.lr_min = lr_min        # floor reached at the last training step
        self.warm_step = warm_step  # number of warmup steps (batches)
        self.sum_step = sum_step    # total number of training steps (batches)
        self.bat = bat              # batch size (kept for reference, not used by the schedule)

    def on_train_begin(self, logs=None):
        self.step = 0

    def on_batch_end(self, batch, logs=None):
        self.step += 1
        # Warmup phase: ramp linearly from 0 up to lr_max over the first warm_step batches.
        warm_lr = self.lr_max * (self.step / self.warm_step)
        # Decay phase: cosine-anneal from lr_max down to lr_min over the remaining batches.
        progress = (self.step - self.warm_step) / (self.sum_step - self.warm_step)
        decay_lr = self.lr_min + (self.lr_max - self.lr_min) * (1 + math.cos(math.pi * progress)) / 2
        lr = warm_lr if self.step < self.warm_step else decay_lr
        K.set_value(self.model.optimizer.lr, lr)
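
# Quick sanity check (a minimal sketch with hypothetical values, independent of any model):
# recompute the same schedule the callback applies and inspect a few steps to confirm the
# linear warmup followed by the cosine decay down to lr_min.
def schedule_preview(lr_max=1e-3, lr_min=1e-5, warm_step=100, sum_step=1000):
    lrs = []
    for step in range(1, sum_step + 1):
        if step < warm_step:
            lr = lr_max * (step / warm_step)
        else:
            progress = (step - warm_step) / (sum_step - warm_step)
            lr = lr_min + (lr_max - lr_min) * (1 + math.cos(math.pi * progress)) / 2
        lrs.append(lr)
    return lrs

# lrs = schedule_preview()
# print(lrs[0], lrs[99], lrs[-1])  # lr_max/warm_step at step 1, lr_max at the end of warmup, lr_min at the last step
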
# One schedule update per batch: warm up for warm_epoch epochs, then cosine-decay for the rest.
steps_per_epoch = int(train_x.shape[0] // bat)
warm_up = WarmUpCos(lr_rate, lr_min,
                    warm_step=warm_epoch * steps_per_epoch,
                    sum_step=epochs * steps_per_epoch,
                    bat=bat)
s_model.fit(train_db, epochs=epochs, validation_data=test_db, callbacks=[warm_up])
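
# Alternative (hedged): recent TensorFlow releases (roughly 2.13 and later) ship a built-in
# cosine schedule with warmup that can replace the callback entirely. The argument names
# below are an assumption about that newer API -- check the CosineDecay docs for your version.
# schedule = keras.optimizers.schedules.CosineDecay(
#     initial_learning_rate=0.0,                             # start of the linear warmup
#     warmup_target=lr_rate,                                 # peak lr reached after warmup
#     warmup_steps=warm_epoch * steps_per_epoch,             # warmup length in update steps
#     decay_steps=(epochs - warm_epoch) * steps_per_epoch,   # cosine phase length
#     alpha=lr_min / lr_rate,                                # floor as a fraction of the peak
# )
# s_model.compile(optimizer=keras.optimizers.Adam(learning_rate=schedule), ...)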