F8Net: Fixed-Point 8-bit Only Multiplication for Network Quantization 论文代码梳理

简单记录论文中的某些要点及其对应的实现方式。主要函数都在 fix_quant_ops.py 中实现。

 

1. 激活值 fractional length 的 grid search 实现步骤

def fraclen_gridsearch(input, wl, align_dim, signed):
    """Grid-search the fractional length that best quantizes ``input``.

    Every admissible fractional length ``fl`` is tried with ``fix_quant``.
    Admissible means ``0 .. wl-1`` when ``signed`` and ``0 .. wl`` otherwise.
    Each trial applies the same ``fl`` to every slice along ``align_dim``.
    The candidate with the smallest root-mean-square quantization error
    is returned.

    Args:
        input: tensor to quantize.
        wl: word length in bits.
        align_dim: dimension of ``input`` the fractional-length vector is
            aligned with (one entry per slice along this dim).
        signed: whether to use the signed (``True``) or unsigned range.

    Returns:
        A 0-dim float tensor on ``input.device`` holding the optimal ``fl``.
    """
    num_slices = input.shape[align_dim]
    errors = []
    for fl in range(wl + 1 - int(signed)):  # every admissible fl
        candidate = torch.full((num_slices,), float(fl), device=input.device)
        res, _ = fix_quant(input, wl, candidate, align_dim, signed)
        # RMS quantization error (same argmin as MSE).
        errors.append(torch.mean((input - res) ** 2) ** 0.5)
    # Stack the errors on input's device instead of rebuilding them into a
    # new CPU tensor element by element via torch.tensor(list_of_tensors).
    return torch.argmin(torch.stack(errors)).to(input.device) * 1.0

2. 激活值 fractional length 的更新(基于 momentum 的滑动平均)

# Exponential moving average of the activation fractional length: keep a
# self.momentum share of the current input_fraclen and blend in
# (1 - self.momentum) of self.get_input_fraclen().
# NOTE(review): get_input_fraclen() presumably returns a fresh estimate
# (e.g. from fraclen_gridsearch) — confirm against the full method.
new_input_fraclen = self.momentum * input_fraclen + (
                            1 - self.momentum) * self.get_input_fraclen()

3. fix_quant 的实现方式

def fix_quant(input, wl=8, fl=0, align_dim=0, signed=True, floating=False):
    """Quantize ``input`` to fixed point with word length ``wl``.

    Each value is multiplied by ``2**fl`` and rounded to an integer (unless
    ``floating``). It is then clamped to the representable range and divided
    by ``2**fl`` again.

    Args:
        input: tensor to quantize.
        wl: word length in bits; must be a Python ``int`` >= 0.
        fl: fractional length(s). This is either a scalar (Python number or
            0-dim tensor) used for every element, or a tensor aligned with
            dimension ``align_dim`` of ``input``. It must be integer-valued
            and >= 0. It must be <= ``wl - 1`` if ``signed``, else <= ``wl``.
        align_dim: dimension of ``input`` that ``fl`` is aligned with.
            ``fl`` gets ``input.dim() - align_dim - 1`` trailing singleton
            dims appended so it broadcasts against ``input``.
        signed: if True, use the symmetric signed integer range
            ``[-(2**(wl-1) - 1), 2**(wl-1) - 1]``. Otherwise use the
            unsigned range ``[0, 2**wl - 1]``.
        floating: if True, skip the rounding step (clamp only).

    Returns:
        ``(res, grad_scale)``. ``res`` is the quantized tensor.
        ``grad_scale`` is a boolean mask of the elements that lay strictly
        inside the clamp range before clamping.

    Raises:
        AssertionError: if ``wl`` or ``fl`` violate the constraints above.
    """
    assert type(wl) is int
    assert wl >= 0
    if not torch.is_tensor(fl):
        # Accept plain Python numbers (including the default fl=0). The
        # torch.all() checks and the indexing below require a tensor.
        fl = torch.tensor(float(fl), device=input.device)
    assert torch.all(fl >= 0)
    max_fl = wl - 1 if signed else wl
    assert torch.all(fl <= max_fl)
    assert torch.all(torch.round(fl) == fl)
    # Append trailing singleton dims so fl broadcasts along align_dim.
    expand_dim = input.dim() - align_dim - 1
    fl = fl[(...,) + (None,) * expand_dim]
    scale = 2**fl
    res = input * scale
    if not floating:
        res.round_()
    if signed:
        # Symmetric range: the most negative two's-complement code is unused.
        bound = 2**(wl - 1) - 1
        grad_scale = torch.abs(res) < bound
        res.clamp_(min=-bound, max=bound)
    else:
        bound = 2**wl - 1
        grad_scale = (res > 0) & (res < bound)
        res.clamp_(min=0, max=bound)
    res.div_(scale)
    return res, grad_scale

4. BN 融合中的 double forward:估计 BN 的 running statistics

## estimate running stats
            # Compute a conv output y0 whose batch statistics are then used
            # below (and fed through self.bn).
            if self.floating and self.floating_wo_clip:
                # Floating mode without clipping: convolve input_val directly.
                y0 = nn.functional.conv2d(input_val,
                                          weight,
                                          bias=self.conv.bias,
                                          stride=self.conv.stride,
                                          padding=self.conv.padding,
                                          dilation=self.conv.dilation,
                                          groups=self.conv.groups)
            else:
                # Otherwise rescale input_val by self.fix_scaling first; the two
                # trailing None dims let fix_scaling broadcast over the last two
                # (presumably spatial H, W) dims — confirm its shape.
                y0 = nn.functional.conv2d(
                    self.fix_scaling[(..., ) + (None, None)] * input_val,
                    weight,
                    bias=self.conv.bias,
                    stride=self.conv.stride,
                    padding=self.conv.padding,
                    dilation=self.conv.dilation,
                    groups=self.conv.groups)
            # Output is discarded. NOTE(review): presumably called so that, in
            # train mode, self.bn updates its running mean/var — confirm.
            self.bn(y0)
            # Per-channel statistics over dims 0, 2, 3 (assumes y0 is NCHW).
            bn_mean = y0.mean([0, 2, 3])
            # Biased (population) variance plus bn.eps, then sqrt.
            bn_std = torch.sqrt(
                y0.var([0, 2, 3], unbiased=False) + self.bn.eps)

 

posted @ 2022-06-16 17:26  撬动地球的coder  阅读(116)  评论(0)    收藏  举报