# TensorFlow 2: learning-rate warm-up + cosine decay (warmup+Cos衰减)

class WarmUpCos(keras.callbacks.Callback):
    """Per-batch learning-rate schedule: linear warm-up, then cosine decay.

    Let ``step`` be the number of training batches completed since the start
    of the current ``fit`` call. The learning rate applied to the next batch is

    * ``lr_max * step / warm_step`` while ``step < warm_step`` (a linear ramp
      that starts at 0), and
    * ``lr_min + (lr_max - lr_min) * (1 + cos(pi * p)) / 2`` afterwards, where
      ``p = (step - warm_step) / (sum_step - warm_step)`` is clamped to
      ``[0, 1]``. The rate therefore falls from ``lr_max`` to ``lr_min`` and
      stays at ``lr_min`` if training runs past ``sum_step``.

    The rate is set *before* every training batch, including the first one,
    so no batch runs at the optimizer's compiled learning rate.

    Args:
        lr_max: Peak learning rate, reached when warm-up ends.
        lr_min: Final (floor) learning rate at and after ``sum_step``.
        warm_step: Number of warm-up batches; 0 disables warm-up.
        sum_step: Total number of scheduled batches (warm-up + decay).
        bat: Batch size. Stored on the instance; the schedule does not use it.
        verbose: If true, print the step counter and learning rate before
            every training batch.
    """

    def __init__(self, lr_max, lr_min, warm_step, sum_step, bat, verbose=False):
        super().__init__()
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.warm_step = warm_step
        self.sum_step = sum_step
        self.bat = bat
        self.verbose = verbose
        # Reset again in on_train_begin; initialised here so the schedule can
        # already be queried before training starts.
        self.step = 0

    def _lr_at(self, step):
        """Return the scheduled learning rate after ``step`` completed batches."""
        # Checked first so warm_step == 0 never reaches the division below.
        if step < self.warm_step:
            return self.lr_max * step / self.warm_step
        decay_steps = self.sum_step - self.warm_step
        if decay_steps <= 0:
            # No decay phase configured: go straight to the floor.
            progress = 1.0
        else:
            # Clamp so the cosine does not climb back up past sum_step, e.g.
            # when the dataset yields more batches than sum_step assumes.
            progress = min((step - self.warm_step) / decay_steps, 1.0)
        cosine = (1 + math.cos(math.pi * progress)) / 2
        return self.lr_min + (self.lr_max - self.lr_min) * cosine

    def on_train_begin(self, logs=None):
        """Restart the schedule at step 0 at the start of every ``fit`` call."""
        self.init_lr = self.lr_max
        self.step = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Record the index of the epoch that is starting."""
        self.epoch = epoch

    def on_batch_begin(self, batch, logs=None):
        """Set the optimizer's learning rate for the batch about to run."""
        lr = self._lr_at(self.step)
        if self.verbose:
            print('step:', self.step, 'lr:', lr)
        # `learning_rate` is the canonical attribute; `lr` is a deprecated alias.
        K.set_value(self.model.optimizer.learning_rate, lr)

    def on_batch_end(self, batch, logs=None):
        """Advance the global step counter after each training batch."""
        self.step += 1
# Batches per epoch, shared by the warm-up length and the total schedule length.
# NOTE(review): assumes train_db yields train_x.shape[0] // bat batches per
# epoch (i.e. the last partial batch is dropped); if it is kept, training runs
# one extra step per epoch beyond sum_step — confirm against how train_db is built.
steps_per_epoch = int(train_x.shape[0] // bat)

warm_up = WarmUpCos(
    lr_rate,
    lr_min,
    warm_step=warm_epoch * steps_per_epoch,
    sum_step=epochs * steps_per_epoch,
    bat=bat,
)

s_model.fit(
    train_db,
    epochs=epochs,
    validation_data=test_db,
    callbacks=[warm_up],
)

 

# posted @ 2022-07-20 00:55  山…隹  阅读(147)  评论(0)  编辑  收藏  举报