小心你的字典和样板代码

小心你的字典和样板代码

原始文档:https://www.yuque.com/lart/blog/gbd39h

这篇文章主要讨论了编码过程中,使用字典和样板代码时,所犯的一些低级错误。

字典键值对不匹配

今天在修改代码的时候发现了之前一个非常低级的错误,函数返回字典参数的时候,字典的键值对对应错了。这导致我后面的程序中使用的数据字典实际上与我预期的并不一致。

关键的是后面数据字典使用时,这个对应关系上的错误并不会导致程序出现明显的异常或者直接抛出错误,也就导致了最终过了这么久才被我发现。

在紧急修复后,重新启动程序开始了验证性的运行,看看之前的方法的偏差有多少。

这样的错误检查起来确实困难。

错误案例:

data={
    "image1.5": image_0_5,
    "image1.0": image_1_0,
    "image0.5": image_1_5,
}

正确形式:

data={
    "image1.5": image_1_5,
    "image1.0": image_1_0,
    "image0.5": image_0_5,
}

这一切于缘起于对于IDE补全功能的过度依赖。

实际上由于这些值对应的变量名称上的相似性,输入“image”之后,补全提示列表中出现在选项处的可能并不是最适合的内容。

为了更快编码验证程序,这种情况下,潜意识的影响更加明显。我快速的按下了应用当前选项的“Tab”键。这样的潜意识行为,让“明显”的错误付诸到纸面。也如房间里的大象一样,虽然确实存在,但就是被忽略了。仅能在下次重新实现新功能并需要对原始的代码直接逐行修改时,才可能被怼到眼前,再也无法避开了。

房间里的大象(Elephant in the room)是一个英语熟语,用来隐喻某件虽然明显却被集体视而不见、不做讨论(英语:conflict management)的事情或者风险,抑或是一种不敢反抗争辩(英语:conflict resolution)某些明显的问题的集体迷思。尽管这是一句英文熟语,中文中近年来也有使用或者提及。
这个短语指的是在法律规范相当清楚、像大象一样显眼的事或物,不知为何却好像被忽视了。抑或是指特定的社会背景、社会心理作用于更为宏观的环境中,使得人们对问题故意选择视而不见。
维基百科

那么这样的错误能否避免呢?

直观来看,这种错误就类似于是我们考试时的所谓“失误”一样,只能尽可能减少,很难有着彻底根治的办法。

这一问题目前感觉很难从工具层面进行优化。

编码习惯上需要保持足够的严谨。这种存在多个并行的对应关系在创建时就得万分小心。写完应该仔细检查一下,这不光包括人工阅读形式的检查,还包含防御性编程的形式。

防御性编程(Defensive programming)是防御式设计的一种具体体现,它是为了保证,对程序的不可预见的使用,不会造成程序功能上的损坏。它可以被看作是为了减少或消除墨菲定律效力的想法。防御式编程主要用于可能被滥用,恶作剧或无意地造成灾难性影响的程序上。
百度百科

字典不同的键值对,在实际被使用时,它们之间的顺序是非常重要的,例如为了传递参数时,其各个值需要被对应赋予到不同的参数上。

尽管特殊情况下,例如a+b+c这种不同键值对之间是无序组合的形式,即使是对应关系错误也不会导致什么问题,但是这毕竟是少数,而且这种情况也可以被有序形式正常表示。

如果键值之间的对应关系在后面的代码中有着明显的依赖,我们有必要在代码层面对这种关系进行检查。可以利用断言语句(**assert**)或者对非预期情况抛出异常(**if**+**raise**)来进行约束。

比如前面列举的数据,在返回或者使用这一数据的时候,考虑到三个对应的键值对之间的主要差异在于图像尺寸,所以可以对三者的尺寸进行约束,简单例子为assert data['image1.5'].shape[-1] > data['image1.0'].shape[-1] > data['image0.5'].shape[-1]。这种方式就需要我们对代码的逻辑有着全面的理解,只有这样才能抽离出更一般的关系约束形式。

样板代码(boilerplate code)的遗漏

之前的项目中犯的另外一个错误更是让我哭笑不得。pytorch中存在一些样板代码,典型的是如下一段:

loss = loss_fn(model(X), Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

如果考虑上混合精度的应用,那就是这样的一个形式:

optimizer.zero_grad()
with autocast(enabled=args.use_fp16):
    loss = loss_fn(model(X), Y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

在我的错误中是把用于梯度置0的操作optimizer.zero_grad()给漏了😅。这会导致梯度不断的累加,非常容易出现梯度爆炸。再搭配使用混合精度的时候,由于自带了一步梯度放缩的操作scaler.scale(loss).backward(),所以不会出现异常的日志,但是却导致模型无法正常训练。

这一错误也是非常难以察觉的,尤其是当训练主脚本代码极其多的时候,在代码的浏览过程中,非常容易被忽视掉。

那么如何避免这一问题呢?

我想出了两种策略,自定义code snippets的自动补全或者抽离出作为固定结构,例如类或者函数来调用。总之就是想方设法避免自己去手动一行一行的敲。

自定义snippet

自定义code snippets是目前主流的编辑器中非常常见且强大的功能。有的编辑器甚至提供了在snippets中调用外部程序的机制。不过由于我们只考虑这些样板代码的书写,所以实际上并不会用到太多复杂的机制,只需要考虑光标跳转位置的定义即可。snippets设定好后,直接敲击快捷键,根据提示补全即可。

代码模块化

另一种则是抽离出作为固定的代码结构,例如将这些样板代码包装成独立的类或者是独立的函数来调用,每次不是直接写,而是直接导入需要的内容即可。

一个对上面内容包装的如下代码片段所示,这里功能更加复杂一些,额外包含了梯度剪裁、混合精度,以及对导出和载入参数的设定。由于这里的的优化器不是这个类的核心,所以导入导出仅针对于梯度scaler。

def clip_grad(params, mode, clip_cfg: dict):
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError(f"`clip_cfg` must contain `max_norm`.")
        torch.nn.utils.clip_grad_norm_(
            params, max_norm=clip_cfg.get("max_norm"), norm_type=clip_cfg.get("norm_type", 2.0)
        )
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError(f"`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value"))
    else:
        raise NotImplementedError


class Scaler:
    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)

        if clip_grad:
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        self.scaler.scale(loss).backward()
        if self.grad_clip_ops is not None:
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))

    def update_grad(self):
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`.  It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state.  If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
           state_dict(dict): scaler state.  Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)

在实际使用中,可以使按照如下形式使用。由于在过程中执行的操作缩减为关键的几步操作,所以丢失语句的概率也降低了。

scaler = pipeline.Scaler(
    optimizer=optimizer,
    use_fp16=cfg.train.use_amp,
    set_to_none=cfg.train.optimizer.set_to_none,
    clip_grad=cfg.train.grad_clip.enable,
    clip_mode=cfg.train.grad_clip.mode,
    clip_cfg=cfg.train.grad_clip.cfg,
)

with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
    probs, loss, loss_str = model(
        data=batch_data, iter_percentage=counter.curr_iter / counter.num_total_iters
    )
    loss = loss / cfg.train.grad_acc_step
scaler.calculate_grad(loss=loss)
if counter.every_n_iters(cfg.train.grad_acc_step):  # Accumulates scaled gradients.
    scaler.update_grad()
posted @ 2022-07-30 12:43  lart  阅读(54)  评论(0编辑  收藏  举报