
Learning Rate Schedules && Moving-Average Smoothing of the Training Loss Curve

Piecewise constant decay

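Piecewise constant decay assigns a different constant learning rate to each of several predefined ranges of training steps. The learning rate starts relatively large and becomes smaller in each later range. The range boundaries need to be tuned to the dataset size; as a rule of thumb, the larger the dataset, the smaller the intervals should be. TensorFlow provides tf.train.piecewise_constant, which implements piecewise constant decay.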
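Here is a framework-free sketch of the interval rule in Python. It assumes the boundary convention documented for tf.train.piecewise_constant, where a step exactly equal to a boundary still uses the earlier value. The function name piecewise_constant and the example boundaries and values are illustrative. This is not TensorFlow's implementation.

```python
import bisect


def piecewise_constant(step, boundaries, values):
    """Return the learning rate for `step` under piecewise constant decay.

    Assumed convention (as documented for tf.train.piecewise_constant):
    values[0] while step <= boundaries[0], values[i] while
    boundaries[i-1] < step <= boundaries[i], values[-1] after the last boundary.
    """
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have exactly one more element than boundaries")
    # bisect_left counts the boundaries that are strictly smaller than step
    return values[bisect.bisect_left(boundaries, step)]


# Hypothetical schedule: 0.1 up to step 1000, 0.01 up to step 2000, then 0.001
boundaries = [1000, 2000]
values = [0.1, 0.01, 0.001]
for step in (0, 1000, 1001, 2000, 5000):
    print(step, piecewise_constant(step, boundaries, values))
# prints 0.1, 0.1, 0.01, 0.01, 0.001 for the five steps
```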

Exponential decay

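Exponential decay is one of the most commonly used schedules: the learning rate depends exponentially on the current training step. In TensorFlow it is implemented by tf.train.exponential_decay().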


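Exponential decay is a flexible way to set the learning rate. A relatively large learning rate is used first to reach fairly good parameter values quickly. The learning rate is then lowered gradually as the number of iterations grows. This lets the parameters settle close to an optimum without requiring too many iterations. The exponential_decay function reduces the learning rate exponentially according to the following formula: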

decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

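The parameters are as follows:
- decayed_learning_rate: the learning rate used at the current iteration.
- learning_rate: the initial learning rate.
- decay_rate: the decay factor.
- global_step: the current training step.
- decay_steps: the number of steps over which the learning rate is multiplied by decay_rate once.

With decay_rate between 0 and 1, the learning rate decreases steadily as the number of iterations grows.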

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)

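learning_rate: a scalar float32 or float64 Tensor, or a Python number. The initial learning rate.

global_step: a scalar int32 or int64 Tensor, or a Python number. The global step used in the decay computation; must not be negative.

decay_steps: a scalar int32 or int64 Tensor, or a Python number. Must be positive; used in the decay computation.

decay_rate: a scalar float32 or float64 Tensor, or a Python number. The decay rate.

staircase: Boolean, default False. With False, the learning rate decays continuously. With True, global_step / decay_steps is computed as an integer division, so the learning rate drops at discrete intervals (a staircase).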
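Below is a plain-Python sketch of the formula above that also shows the effect of staircase. It reuses the parameter names of tf.train.exponential_decay but is only an illustration, not TensorFlow code. The decay settings in the example are made up.

```python
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    """learning_rate * decay_rate ** (global_step / decay_steps).

    With staircase=True the exponent uses integer division, so the rate
    changes only once every decay_steps steps instead of smoothly.
    """
    if decay_steps <= 0:
        raise ValueError("decay_steps must be positive")
    if global_step < 0:
        raise ValueError("global_step must not be negative")
    exponent = global_step // decay_steps if staircase else global_step / decay_steps
    return learning_rate * decay_rate ** exponent


# Hypothetical setting: start at 0.1 and multiply by 0.96 every 100 steps
for step in (0, 50, 100, 150, 200):
    smooth = exponential_decay(0.1, step, 100, 0.96)
    stairs = exponential_decay(0.1, step, 100, 0.96, staircase=True)
    print(step, round(smooth, 6), round(stairs, 6))
```

The continuous version changes a little at every step. The staircase version holds 0.1 for steps 0 to 99 and then jumps to 0.096.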

Natural exponential decay

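Natural exponential decay is a special case of exponential decay. The learning rate again depends exponentially on the current training step, but the base is e. In TensorFlow it is implemented by tf.train.natural_exp_decay().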
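Here is a similar sketch for natural exponential decay. It assumes the formula given in the TF 1.x documentation, learning_rate * exp(-decay_rate * global_step / decay_steps). Read this way, it is ordinary exponential decay with a per-period factor of exp(-decay_rate). The example values are illustrative.

```python
import math


def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    """learning_rate * exp(-decay_rate * global_step / decay_steps) (assumed TF 1.x formula)."""
    ratio = global_step // decay_steps if staircase else global_step / decay_steps
    return learning_rate * math.exp(-decay_rate * ratio)


# Same curve as exponential decay with a per-period factor of exp(-0.5), about 0.6065
for step in (0, 100, 200, 300):
    print(step, round(natural_exp_decay(0.1, step, 100, 0.5), 6))
```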

Polynomial decay

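Polynomial decay works as follows. You define an initial learning rate and a minimum (final) learning rate, and the learning rate falls from the initial value to the minimum according to the chosen decay rule. You can also decide what happens once the minimum is reached. The rate can either stay at the minimum, or rise back to some value and decay to the minimum again, repeating this cycle (the cycle argument). In TensorFlow it is implemented by tf.train.polynomial_decay(). Without cycling, the formula is: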

global_step = min(global_step, decay_steps)
decayed_learning_rate = (learning_rate - end_learning_rate) *
                        (1 - global_step / decay_steps) ^ (power) +
                        end_learning_rate
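Below is a plain-Python sketch of the formula above with the cycling behaviour added. For cycle=True it assumes TF's documented rule: decay_steps is stretched to the first multiple of decay_steps that covers global_step, with a multiplier of 1 at step 0. The defaults end_learning_rate=0.0001 and power=1.0 follow tf.train.polynomial_decay. The example values are illustrative.

```python
import math


def polynomial_decay(learning_rate, global_step, decay_steps,
                     end_learning_rate=0.0001, power=1.0, cycle=False):
    """Polynomial decay from learning_rate down to end_learning_rate.

    cycle=False: after decay_steps the rate stays at end_learning_rate.
    cycle=True: decay_steps is stretched to the first multiple of itself that
    is >= global_step, so the rate jumps back up and decays again.
    """
    if cycle:
        # Multiplier 1 at step 0 so the stretched decay_steps never becomes 0
        multiplier = 1.0 if global_step == 0 else math.ceil(global_step / decay_steps)
        decay_steps = decay_steps * multiplier
    else:
        global_step = min(global_step, decay_steps)
    fraction = 1 - global_step / decay_steps
    return (learning_rate - end_learning_rate) * fraction ** power + end_learning_rate


for step in (0, 50, 100, 101, 150, 200):
    plain = polynomial_decay(0.1, step, 100, end_learning_rate=0.001)
    cycled = polynomial_decay(0.1, step, 100, end_learning_rate=0.001, cycle=True)
    print(step, round(plain, 5), round(cycled, 5))
```

With power=1.0 the decay is linear. Larger values of power make the rate drop faster at the start and flatten out near end_learning_rate.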

Cosine decay

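In cosine decay the learning rate is computed from a cosine function, so the curve roughly follows the shape of a cosine. In TensorFlow it is implemented by tf.train.cosine_decay().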

Improved variants of cosine decay include:
- linear cosine decay, tf.train.linear_cosine_decay()
- noisy linear cosine decay, tf.train.noisy_linear_cosine_decay()
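No cosine formula is given here, so the sketch below assumes the one documented for tf.train.cosine_decay. In that formula alpha is the minimum learning rate, expressed as a fraction of the initial one. The step is clamped to decay_steps, and the schedule traces half a cosine period from learning_rate down to alpha * learning_rate. The linear and noisy variants listed above are not covered.

```python
import math


def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0):
    """Half-cosine decay from learning_rate to alpha * learning_rate (assumed TF formula)."""
    global_step = min(global_step, decay_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * global_step / decay_steps))  # goes 1 -> 0
    decayed = (1 - alpha) * cosine + alpha  # rescale so the floor is alpha
    return learning_rate * decayed


for step in (0, 250, 500, 750, 1000, 1500):
    print(step, round(cosine_decay(0.1, step, 1000, alpha=0.1), 6))
```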

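Inverse time decay

Inverse time decay rests on an inverse relationship between two quantities. In neural network training, the learning rate is inversely related to the number of training steps.

In TensorFlow it is implemented by tf.train.inverse_time_decay().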
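Here is a sketch of inverse time decay. It assumes the formula documented for tf.train.inverse_time_decay, learning_rate / (1 + decay_rate * global_step / decay_steps). As with the other schedules, staircase=True switches the ratio to integer division. The example values are made up.

```python
def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    """learning_rate / (1 + decay_rate * global_step / decay_steps) (assumed TF formula)."""
    ratio = global_step // decay_steps if staircase else global_step / decay_steps
    return learning_rate / (1 + decay_rate * ratio)


# With decay_rate=1.0 the rate is halved after 100 steps, cut to a third after 200, ...
for step in (0, 100, 150, 200, 1000):
    print(step, round(inverse_time_decay(0.1, step, 100, 1.0), 6),
          round(inverse_time_decay(0.1, step, 100, 1.0, staircase=True), 6))
```

Compared with exponential decay, the 1/t shape falls quickly at first and then flattens out. Late in training the rate stays noticeably larger.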

Training: moving-average smoothing of the loss curve

- Pure Python (the smoothing classes need only the standard library)

from collections import OrderedDict  # used by DictAccumulator below


def print_loss(config, title, loss_dict, epoch, iters, current_iter, need_plot=False):
    data_str = ''
    for k, v in loss_dict.items():
        if data_str != '':
            data_str += ', '
        data_str += '{}: {:.10f}'.format(k, v)

        if need_plot and config.vis is not None:
            plot_line(config, title, k, (epoch-1)*iters+current_iter, v)

    # step is the progress rate of the whole dataset (split by batchsize)
    print('[{}] [{}] Epoch [{}/{}], Iter [{}/{}]'.format(title, config.experiment_name, epoch, config.epochs, current_iter, iters))
    print('        {}'.format(data_str))

class AverageWithinWindow():
    def __init__(self, win_size):
        self.win_size = win_size
        self.cache = []
        self.average = 0
        self.count = 0

    def update(self, v):
        if self.count < self.win_size:
            self.cache.append(v)
            self.count += 1
            self.average = (self.average * (self.count - 1) + v) / self.count
        else:
            idx = self.count % self.win_size
            self.average += (v - self.cache[idx]) / self.win_size
            self.cache[idx] = v
            self.count += 1


class DictAccumulator():
    def __init__(self, win_size=None):
        self.accumulator = OrderedDict()
        self.total_num = 0 
        self.win_size = win_size

    def update(self, d):
        self.total_num += 1
        for k, v in d.items():
            if not self.win_size:
                self.accumulator[k] = v + self.accumulator.get(k,0)
            else:
                self.accumulator.setdefault(k, AverageWithinWindow(self.win_size)).update(v)

    def get_average(self):
        average = OrderedDict()
        for k, v in self.accumulator.items():
            if not self.win_size:
                average[k] = v*1.0/self.total_num 
            else:
                average[k] = v.average 
        return average

def train(epoch,  train_loader, model):
    loss_accumulator = utils.DictAccumulator(config.loss_average_win_size)
    grad_accumulator = utils.DictAccumulator(config.loss_average_win_size)
    score_accumulator = utils.DictAccumulator(config.loss_average_win_size)
    iters = len(train_loader)

    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.cuda()
        targets = targets.cuda()
        inputs = Variable(inputs)
        targets = Variable(targets)

        net_outputs, loss, grad, lr_dict, score = model.fit(inputs, targets, update=True, epoch=epoch,
                                                            cur_iter=i+1, iter_one_epoch=iters)
        loss_accumulator.update(loss)
        grad_accumulator.update(grad)
        score_accumulator.update(score)

        if (i+1) % config.loss_average_win_size == 0:
            need_plot = True
            if hasattr(config, 'plot_loss_start_iter'):
                need_plot = (i + 1 + (epoch - 1) * iters >= config.plot_loss_start_iter)
            elif hasattr(config, 'plot_loss_start_epoch'):
                need_plot = (epoch >= config.plot_loss_start_epoch)

            utils.print_loss(config, "train_loss", loss_accumulator.get_average(), epoch=epoch, iters=iters, current_iter=i+1, need_plot=need_plot)
            utils.print_loss(config, "grad", grad_accumulator.get_average(), epoch=epoch, iters=iters, current_iter=i+1, need_plot=need_plot)
            utils.print_loss(config, "learning rate", lr_dict, epoch=epoch, iters=iters, current_iter=i+1, need_plot=need_plot)

            utils.print_loss(config, "train_score", score_accumulator.get_average(), epoch=epoch, iters=iters, current_iter=i+1, need_plot=need_plot)

    if epoch % config.save_train_hr_interval_epoch == 0:
        k = random.randint(0, net_outputs['output'].size(0) - 1)
        for name, out in net_outputs.items():
            utils.save_tensor(out.data[k], os.path.join(config.TRAIN_OUT_FOLDER, 'epoch_%d_k_%d_%s.png' % (epoch, k, name)))


def validate(valid_loader, model):
    loss_accumulator = utils.DictAccumulator()
    score_accumulator = utils.DictAccumulator()

    # loss of the whole validation dataset; torch.no_grad() replaces the deprecated
    # Variable(..., volatile=True), which has no effect in PyTorch >= 0.4
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(valid_loader):
            inputs = inputs.cuda()
            targets = targets.cuda()

            loss, score = model.fit(inputs, targets, update=False)

            loss_accumulator.update(loss)
            score_accumulator.update(score)

    return loss_accumulator.get_average(), score_accumulator.get_average()
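AverageWithinWindow keeps a circular buffer and updates the mean incrementally. The same windowed mean can be written more compactly with collections.deque(maxlen=...), which drops the oldest value automatically. WindowMean is a hypothetical name used only in this sketch, and the loss values are made up. In train() above, the log is printed every loss_average_win_size iterations with the same window size. Each printed value is therefore the average over exactly the iterations since the previous print.

```python
from collections import deque


class WindowMean:
    """Hypothetical deque-based equivalent of AverageWithinWindow."""

    def __init__(self, win_size):
        self.values = deque(maxlen=win_size)  # oldest value is evicted automatically
        self.total = 0.0

    def update(self, v):
        if len(self.values) == self.values.maxlen:
            self.total -= self.values[0]  # this value is about to be evicted
        self.values.append(v)
        self.total += v

    @property
    def average(self):
        # Before the window fills, this is the mean of everything seen so far
        return self.total / len(self.values) if self.values else 0.0


# A noisy, slowly decreasing "loss" sequence (made-up numbers)
losses = [1.0, 0.2, 0.9, 0.1, 0.8, 0.05, 0.7, 0.02]
meter = WindowMean(win_size=4)
for step, loss in enumerate(losses, 1):
    meter.update(loss)
    print(step, loss, round(meter.average, 4))
```

Like the original class, this sketch keeps a running total instead of re-summing the window. Tiny floating-point drift can therefore build up over very long runs. Computing sum(self.values) on each read avoids that at O(window) cost.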

- Depends on PyTorch (adapted from maskrcnn-benchmark)

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging  # needed by do_train below
import time
from collections import defaultdict
from collections import deque
from datetime import datetime

import torch

from .comm import is_main_process


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.series = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

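    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails; exposes meters as
        # attributes (e.g. meters.time). object has no __getattr__ to fall back
        # on, so raise AttributeError explicitly for unknown names.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )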

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            )
        return self.delimiter.join(loss_str)


class TensorboardLogger(MetricLogger):
    def __init__(self,
                 log_dir='logs',
                 exp_name='maskrcnn-benchmark',
                 start_iter=0,
                 delimiter='\t'):

        super(TensorboardLogger, self).__init__(delimiter)
        self.iteration = start_iter
        self.writer = self._get_tensorboard_writer(log_dir, exp_name)

    @staticmethod
    def _get_tensorboard_writer(log_dir, exp_name):
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            raise ImportError(
                'To use tensorboard please install tensorboardX '
                '[ pip install tensorflow tensorboardX ].'
            )

        if is_main_process():
            timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H:%M')
            tb_logger = SummaryWriter('{}/{}-{}'.format(log_dir, exp_name, timestamp))
            return tb_logger
        else:
            return None

    def update(self, **kwargs):
        super(TensorboardLogger, self).update(**kwargs)
        if self.writer:
            for k, v in kwargs.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                assert isinstance(v, (float, int))
                self.writer.add_scalar(k, v, self.iteration)
            self.iteration += 1

def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    tb_log_dir,
    tb_exp_name,
    use_tensorboard=False
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")

    meters = TensorboardLogger(log_dir=tb_log_dir,
                               exp_name=tb_exp_name,
                               start_iter=arguments['iteration'],
                               delimiter="  ") \
        if use_tensorboard else MetricLogger(delimiter="  ")

    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
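        # `datetime` is the class imported above (it has no .timedelta), so format H:MM:SS by hand
        eta_hours, eta_rest = divmod(int(eta_seconds), 3600)
        eta_string = "{:d}:{:02d}:{:02d}".format(eta_hours, eta_rest // 60, eta_rest % 60)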

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
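    # `datetime` is the class imported above (it has no .timedelta), so format H:MM:SS by hand
    total_hours, total_rest = divmod(int(total_training_time), 3600)
    total_time_str = "{:d}:{:02d}:{:02d}".format(total_hours, total_rest // 60, total_rest % 60)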
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
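MetricLogger prints each meter as the median over the window, followed by the global average in parentheses. A window median keeps a single loss spike from dominating the log line, whereas a window mean would jump. The torch-free sketch below compares the two on made-up numbers. window_stats is a hypothetical helper that mirrors SmoothedValue.median and SmoothedValue.avg. It uses statistics.median_low because torch.median returns the lower of the two middle values when the count is even.

```python
import statistics
from collections import deque


def window_stats(values, window_size=20):
    """Return (median, mean) of the newest window_size values.

    window_size=20 matches SmoothedValue's default window.
    """
    window = deque(values, maxlen=window_size)  # keeps only the newest items
    return statistics.median_low(window), statistics.mean(window)


# Made-up loss history: steady at 0.5 with one spike in the latest window
history = [0.5] * 19 + [10.0]
median, mean = window_stats(history)
print("median:", median, "mean:", round(mean, 4))
```

Here the median stays at 0.5 while the mean nearly doubles. The global average, which SmoothedValue.global_avg tracks, is still useful for spotting the long-term trend.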

- Depends on PyTorch (a torchnet-style MovingAverageValueMeter)

import math
from . import meter
import torch


class MovingAverageValueMeter(meter.Meter):
    def __init__(self, windowsize):
        super(MovingAverageValueMeter, self).__init__()
        self.windowsize = windowsize
        self.valuequeue = torch.Tensor(windowsize)
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.n = 0
        self.var = 0.0  # running sum of squared values, not the variance itself
        self.valuequeue.fill_(0)

    def add(self, value):
        queueid = (self.n % self.windowsize)
        oldvalue = float(self.valuequeue[queueid])  # plain float, so sum/var do not become 0-dim tensors
        self.sum += value - oldvalue
        self.var += value * value - oldvalue * oldvalue
        self.valuequeue[queueid] = value
        self.n += 1

    def value(self):
        n = min(self.n, self.windowsize)
        mean = self.sum / max(1, n)
        std = math.sqrt(max((self.var - n * mean * mean) / max(1, n - 1), 0))
        return mean, std

def main():
    .....
    # TensorBoard Logger
    writer = SummaryWriter(CONFIG.LOG_DIR)
    loss_meter = MovingAverageValueMeter(20)

    model.train()
    model.module.scale.freeze_bn()

    for iteration in tqdm(
        range(1, CONFIG.ITER_MAX + 1),
        total=CONFIG.ITER_MAX,
        leave=False,
        dynamic_ncols=True,
    ):

        # Set a learning rate
        poly_lr_scheduler(
            optimizer=optimizer,
            init_lr=CONFIG.LR,
            iter=iteration - 1,
            lr_decay_iter=CONFIG.LR_DECAY,
            max_iter=CONFIG.ITER_MAX,
            power=CONFIG.POLY_POWER,
        )

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            try:
                images, labels = next(loader_iter)
            except:
                loader_iter = iter(loader)
                images, labels = next(loader_iter)

            images = images.to(device)
            labels = labels.to(device).unsqueeze(1).float()

            # Propagate forward
            logits = model(images)

            # Loss
            loss = 0
            for logit in logits:
                # Resize labels for {100%, 75%, 50%, Max} logits
                labels_ = F.interpolate(labels, logit.shape[2:], mode="nearest")
                labels_ = labels_.squeeze(1).long()
                # Compute crossentropy loss
                loss += criterion(logit, labels_)

            # Backpropagate (just compute gradients wrt the loss)
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)

        loss_meter.add(iter_loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # TensorBoard
        if iteration % CONFIG.ITER_TB == 0:
            writer.add_scalar("train_loss", loss_meter.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("train_lr_group{}".format(i), o["lr"], iteration)
            if False:  # This produces a large log file
                for name, param in model.named_parameters():
                    name = name.replace(".", "/")
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(
                            name + "/grad", param.grad, iteration, bins="auto"
                        )

        # Save a model
        if iteration % CONFIG.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.SAVE_DIR, "checkpoint_{}.pth".format(iteration)),
            )

        # Save a model (short term)
        if iteration % 100 == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.SAVE_DIR, "checkpoint_current.pth"),
            )

    torch.save(
        model.module.state_dict(), osp.join(CONFIG.SAVE_DIR, "checkpoint_final.pth")
    )
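MovingAverageValueMeter avoids re-scanning the window. It keeps a running sum and a running sum of squares (stored in the attribute named var) and subtracts the value that drops out of the circular buffer. It then derives the sample standard deviation from sum(x^2) - n * mean^2. Below is a torch-free sketch of the same bookkeeping. RunningWindowStats is a hypothetical name, and the numbers are illustrative.

```python
import math


class RunningWindowStats:
    """Hypothetical torch-free version of MovingAverageValueMeter."""

    def __init__(self, windowsize):
        self.windowsize = windowsize
        self.queue = [0.0] * windowsize  # circular buffer, pre-filled with zeros
        self.n = 0
        self.sum = 0.0
        self.sumsq = 0.0  # running sum of squares (the meter above calls it var)

    def add(self, value):
        idx = self.n % self.windowsize
        old = self.queue[idx]  # 0.0 until the window has filled once
        self.sum += value - old
        self.sumsq += value * value - old * old
        self.queue[idx] = value
        self.n += 1

    def value(self):
        n = min(self.n, self.windowsize)
        mean = self.sum / max(1, n)
        # Sample variance (sum(x^2) - n * mean^2) / (n - 1), clipped at 0 against rounding
        var = max((self.sumsq - n * mean * mean) / max(1, n - 1), 0.0)
        return mean, math.sqrt(var)


stats = RunningWindowStats(4)
for x in [1.0, 2.0, 4.0, 8.0, 16.0, 3.0]:
    stats.add(x)
print(stats.value())  # mean and std of the last four values: 4, 8, 16, 3
```

The sum-of-squares shortcut can lose precision when the values are large relative to their spread. The max(..., 0) clip only guards against small negative results, so the windowed mean is usually the more reliable number to plot.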

 

posted @ 2019-01-01 20:07  ranjiewen