CNN 网络感受野计算及 matplotlib 可视化
参考资料:《ResNet网络感受野计算》(_dekiang 的博客,CSDN;关键词:residual 感受野)
感受野在线计算工具(似乎已失效?):https://fomoro.com/research/article/receptive-field-calculator
下面为自己写的代码实现:
import pandas as pd
import numpy as np


def _show_reception_bar(receptions):
    """Plot per-layer receptive-field sizes as an annotated bar chart and show it.

    receptions: one receptive-field size per layer (the starting/input value
    is expected to be excluded by the caller).
    """
    import matplotlib.pyplot as plt

    heights = np.asarray(receptions, dtype=np.float64)
    positions = np.arange(len(heights))
    plt.bar(positions, heights,
            tick_label=['layer{}'.format(i) for i in range(len(heights))])
    plt.xticks(rotation=30)
    # Label each bar with its integer value, placed just above the bar top.
    for x, h in zip(positions, heights):
        plt.text(x, h + 1, '{}'.format(int(h)), size=12, ha='center',
                 family="Times new roman")
    # Maximizing the window is purely cosmetic and backend-specific: only a Tk
    # window on Windows accepts state('zoomed'). Other backends expose no
    # `.window`/`.state`, and Tk on other platforms raises TclError, so a
    # failure here must not prevent the figure from being shown.
    window = getattr(plt.get_current_fig_manager(), 'window', None)
    if window is not None and hasattr(window, 'state'):
        try:
            window.state('zoomed')
        except Exception:  # e.g. tkinter.TclError: bad argument "zoomed"
            pass
    plt.show()


def reception_cal(*args, start_reception=1, start_strides=1, show_bar=False):
    """Compute, print and return the receptive field of a stack of layers.

    Each positional argument describes one layer, in forward order, as a
    tuple ``(kernel_size, stride, padding)``. For layer ``l``::

        reception_l     = reception_{l-1} + (kernel_size_l - 1) * stride_cummul_{l-1}
        stride_cummul_l = stride_cummul_{l-1} * stride_l

    ``padding`` is only recorded in the table; it does not enter the formula.

    Args:
        *args: one ``(kernel_size, stride, padding)`` tuple per layer.
        start_reception: receptive field before the first layer (default 1).
        start_strides: cumulative stride before the first layer (default 1).
        show_bar: if True, also draw a matplotlib bar chart of the
            per-layer receptive fields (blocks in ``plt.show()``).

    Returns:
        The printed pandas DataFrame. Row 0 holds the starting values (empty
        strings in the per-layer columns); row ``i`` holds the values after
        layer ``i``.

    Also prints "standard conv params", the sum of ``kernel_size ** 2`` over
    all layers (channel counts are not modelled).
    """
    reception = start_reception
    stride_cummul = start_strides
    standard_params = 0
    record = {'kernel_size': [''], 'stride': [''], 'padding': [''],
              'stride_cummul': [stride_cummul], 'reception': [reception]}
    for kernel_size, stride, padding in args:
        # The kernel covers (k - 1) extra positions of the previous feature
        # map, which are `stride_cummul` input pixels apart from each other.
        reception += (kernel_size - 1) * stride_cummul
        stride_cummul *= stride
        record['kernel_size'].append(kernel_size)
        record['stride'].append(stride)
        record['padding'].append(padding)
        record['stride_cummul'].append(stride_cummul)
        record['reception'].append(reception)
        standard_params += kernel_size ** 2

    out_table = pd.DataFrame(record)
    # Scope the header alignment to this print rather than changing the
    # global pandas display option for the rest of the process.
    with pd.option_context('colheader_justify', 'center'):
        print(out_table)
    print('standard conv params:{}'.format(standard_params))

    if show_bar:
        _show_reception_bar(record['reception'][1:])
    return out_table


def judge_dwconv_replace(kernel_size=3, in_channels=256, out_channels=256):
    """Compare weight counts of a standard conv and a depthwise-separable conv.

    Biases are not counted. Prints both counts, their ratio
    (separable / standard) and whether the separable version is smaller.

    Returns:
        The ratio ``separable_params / standard_params`` as a float.
    """
    standard_params = in_channels * out_channels * kernel_size ** 2
    # Depthwise: one k x k filter per input channel.
    depthwise_params = in_channels * kernel_size ** 2
    # Pointwise: a 1 x 1 conv mapping in_channels -> out_channels.
    pointwise_params = in_channels * out_channels * 1 ** 2
    dp_params = depthwise_params + pointwise_params
    ratio = dp_params / standard_params
    print('standard conv params:{}'.format(standard_params))
    print('depthwise separable params:{}'.format(dp_params))
    print('d : s : {:.5}, params_reduce:{}'.format(ratio, ratio < 1.0))
    return ratio


if __name__ == '__main__':
    # Bottleneck block: 1x1 -> 3x3 -> 1x1; the downsampling variant puts
    # stride 2 on the 3x3 conv.
    bottleneck = [(1, 1, 0), (3, 1, 1), (1, 1, 0)]
    down_bottleneck = [(1, 1, 0), (3, 2, 1), (1, 1, 0)]

    # ResNet-50 test: stem (7x7/2 conv, 3x3/2 pool) + conv2_x (3 blocks)
    kernel_size_list = [(7, 2, 3), (3, 2, 1)] + bottleneck * 3
    reception_cal(*kernel_size_list)
    # conv3_x (4 blocks)
    kernel_size_list += down_bottleneck + bottleneck * 3
    reception_cal(*kernel_size_list)
    # conv4_x (6 blocks)
    kernel_size_list += down_bottleneck + bottleneck * 5
    reception_cal(*kernel_size_list)
    # conv5_x (3 blocks)
    kernel_size_list += down_bottleneck + bottleneck * 2
    reception_cal(*kernel_size_list, show_bar=True)

    # hornet (4x4/4 downsampling followed by two 7x7 convs, repeated twice)
    kernel_size_list = [(4, 4, 0), (7, 1, 3), (7, 1, 3)] * 2
    reception_cal(*kernel_size_list)

    # Example: judge_dwconv_replace(kernel_size=7, in_channels=256, out_channels=256)
ResNet-50 的测试结果正常,如下:


浙公网安备 33010602011771号