F8NET (Fixed-Point 8-bit Only Multiplication for Network Quantization): Paper Code Walkthrough
A quick record of a few key points from the paper and how each one is implemented. The main functions are implemented in fix_quant_ops.py.
1. Grid search implementation for the activation fractional length
```python
def fraclen_gridsearch(input, wl, align_dim, signed):
    err_list = []
    for fl in range(wl + 1 - int(signed)):  # sweep the valid fl range
        res, _ = fix_quant(
            input, wl,
            torch.ones(input.shape[align_dim]).to(input.device) * fl * 1.0,
            align_dim, signed)
        err = torch.mean((input - res)**2)**0.5  # RMSE of the quantization error
        err_list.append(err)
    opt_fl = torch.argmin(torch.tensor(err_list)).to(input.device) * 1.0  # pick the fl with the smallest error
    return opt_fl
```
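A minimal usage sketch (the tensor and its shape are made up for illustration, not from the repo): for post-ReLU activations in [0, 1), the search should land on the largest admissible fl, i.e. the finest grid that still covers the value range.

```python
import torch

# Hypothetical usage: unsigned 8-bit activations, fl aligned along the channel dim.
x = torch.rand(16, 32, 14, 14)  # e.g. post-ReLU activations in [0, 1)
opt_fl = fraclen_gridsearch(x, wl=8, align_dim=1, signed=False)
print(opt_fl)  # expected tensor(8.): scale 2**8 gives the finest step over [0, 1)
```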
2. Updating the activation fractional length
```python
new_input_fraclen = self.momentum * input_fraclen + (
    1 - self.momentum) * self.get_input_fraclen()
```
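This is a momentum (EMA) update, analogous to how BN maintains running statistics. A minimal standalone sketch, assuming input_fraclen is the grid-searched value for the current batch, get_input_fraclen() returns the stored running value, and momentum is 0.9 (the actual value comes from the module config):

```python
import torch

momentum = 0.9                  # assumed value standing in for self.momentum
running_fl = torch.tensor(8.0)  # stored running fractional length
batch_fl = torch.tensor(6.0)    # fraclen_gridsearch result for the current batch
running_fl = momentum * batch_fl + (1 - momentum) * running_fl
print(running_fl)  # tensor(6.2000): the running value drifts toward the batch value
```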
3. Implementation of fix_quant
```python
def fix_quant(input, wl=8, fl=0, align_dim=0, signed=True, floating=False):
    assert wl >= 0
    assert torch.all(fl >= 0)
    if signed:
        assert torch.all(fl <= wl - 1)  # one bit is reserved for the sign
    else:
        assert torch.all(fl <= wl)
    assert type(wl) == int
    assert torch.all(torch.round(fl) == fl)  # fl must hold integer values
    # broadcast the per-channel fl across the dims trailing align_dim
    expand_dim = input.dim() - align_dim - 1
    fl = fl[(..., ) + (None, ) * expand_dim]
    res = input * (2**fl)
    if not floating:
        res.round_()  # snap to the fixed-point grid
    if signed:
        bound = 2**(wl - 1) - 1
        grad_scale = torch.abs(res) < bound  # mask of entries that were not clipped
        res.clamp_(max=bound, min=-bound)
    else:
        bound = 2**wl - 1
        grad_scale = (res > 0) * (res < bound)
        res.clamp_(max=bound, min=0)
    res.div_(2**fl)  # back to the real-valued scale
    return res, grad_scale
```
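A minimal usage sketch (the values are made up for illustration): each row gets its own fl via align_dim, values are snapped to multiples of 2**-fl, and grad_scale flags which entries stayed inside the clipping bound.

```python
import torch

w = torch.tensor([[0.30, -0.72],
                  [2.50, 0.015]])
fl = torch.tensor([6.0, 6.0])  # one fl per row (align_dim = 0), scale 2**6 = 64
q, grad_scale = fix_quant(w, wl=8, fl=fl, align_dim=0, signed=True)
print(q)           # multiples of 1/64; 2.50 is clipped to 127/64 = 1.984375
print(grad_scale)  # False for the clipped entry, True elsewhere
```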
4. Double forward for BN fusion: estimating the BN running statistics
```python
## estimate running stats
if self.floating and self.floating_wo_clip:
    y0 = nn.functional.conv2d(input_val,
                              weight,
                              bias=self.conv.bias,
                              stride=self.conv.stride,
                              padding=self.conv.padding,
                              dilation=self.conv.dilation,
                              groups=self.conv.groups)
else:
    y0 = nn.functional.conv2d(
        self.fix_scaling[(..., ) + (None, None)] * input_val,
        weight,
        bias=self.conv.bias,
        stride=self.conv.stride,
        padding=self.conv.padding,
        dilation=self.conv.dilation,
        groups=self.conv.groups)
self.bn(y0)  # first forward: lets BN update its running statistics
bn_mean = y0.mean([0, 2, 3])  # per-channel batch mean
bn_std = torch.sqrt(
    y0.var([0, 2, 3], unbiased=False) + self.bn.eps)  # per-channel batch std
```
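The batch statistics then feed the second forward pass, where BN is folded into the convolution. A sketch of the standard folding algebra, reusing the names above (fused_weight and fused_bias are hypothetical; the actual F8NET layer additionally quantizes the fused weight and applies the fixed-point scaling):

```python
# BN(y) = gamma * (y - mean) / std + beta
#       = conv(x, w * gamma / std) + (beta - gamma * mean / std)
# (assumes self.conv.bias is None, as is typical for a conv followed by BN)
fused_weight = weight * (self.bn.weight / bn_std)[:, None, None, None]
fused_bias = self.bn.bias - self.bn.weight * bn_mean / bn_std
y = nn.functional.conv2d(input_val,
                         fused_weight,
                         bias=fused_bias,
                         stride=self.conv.stride,
                         padding=self.conv.padding,
                         dilation=self.conv.dilation,
                         groups=self.conv.groups)
```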
