CNN网络感受野计算及matplotlib可视化

参考资料:《ResNet网络感受野计算》(dekiang 的 CSDN 博客)

在线感受野计算工具似乎已经失效:https://fomoro.com/research/article/receptive-field-calculator

下面为自己写的代码实现:

import pandas as pd
import numpy as np

def reception_cal(*args, start_reception=1, start_strides=1, show_bar=False):
    """Print and return the layer-by-layer receptive field of a layer stack.

    Each positional argument describes one layer, ordered from input to
    output, as a ``(kernel_size, stride, padding)`` tuple. For every layer
    the receptive field grows by ``(kernel_size - 1)`` times the cumulative
    stride of all preceding layers. Padding is only recorded in the table;
    it does not enter the calculation.

    Args:
        *args: ``(kernel_size, stride, padding)`` tuples, one per layer.
        start_reception: Receptive field before the first layer.
        start_strides: Cumulative stride before the first layer.
        show_bar: If true, additionally draw a matplotlib bar chart of the
            receptive field after each layer (blocks in ``plt.show()``).

    Returns:
        pandas.DataFrame with columns ``kernel_size``, ``stride``,
        ``padding``, ``stride_cummul`` and ``reception``. Row 0 is the
        starting state (its layer columns hold empty strings); row ``i`` is
        the state after the ``i``-th layer.
    """
    kernel_sizes, strides, paddings = [''], [''], ['']
    cumulative_strides = [start_strides]
    receptions = [start_reception]
    # Sum of kernel_size ** 2 over all layers (channel counts are not known
    # here, so this is only the spatial part of the parameter count).
    standard_params = 0

    for kernel_size, stride, padding in args:
        kernel_sizes.append(kernel_size)
        strides.append(stride)
        paddings.append(padding)
        # The growth uses the cumulative stride *before* this layer, so the
        # receptive field must be updated ahead of the stride.
        receptions.append(
            receptions[-1] + (kernel_size - 1) * cumulative_strides[-1])
        cumulative_strides.append(cumulative_strides[-1] * stride)
        standard_params += kernel_size ** 2

    out_table = pd.DataFrame({
        'kernel_size': kernel_sizes,
        'stride': strides,
        'padding': paddings,
        'stride_cummul': cumulative_strides,
        'reception': receptions,
    })
    pd.set_option('display.colheader_justify', 'center')
    print(out_table)
    print('standard conv params:{}'.format(standard_params))

    if show_bar:
        # Imported lazily so the table can be produced without matplotlib.
        import matplotlib.pyplot as plt

        receptions_show = np.array(receptions[1:]).astype(np.float64)
        positions = np.arange(len(receptions_show))
        plt.bar(positions, receptions_show,
                tick_label=['layer{}'.format(i) for i in range(len(receptions_show))])
        plt.xticks(rotation=30)
        # Annotate every bar with its integer receptive field.
        for position, value in zip(positions, receptions_show):
            plt.text(position, value + 1, '{}'.format(int(value)),
                     size=12, ha='center', family="Times new roman")
        # Maximising the window is purely cosmetic and backend specific:
        # ``window.state('zoomed')`` only exists on the Tk backend and the
        # 'zoomed' state is only accepted on some platforms. A failure here
        # must never prevent the chart from being shown.
        try:
            plt.get_current_fig_manager().window.state('zoomed')
        except Exception:
            pass
        plt.show()

    return out_table

def judge_dwconv_replace(kernel_size=3, in_channels=256, out_channels=256):
    """Print the weight counts of a standard vs. depthwise-separable conv.

    The standard convolution is counted as
    ``in_channels * out_channels * kernel_size ** 2``; the separable variant
    as a depthwise part (``in_channels * kernel_size ** 2``) plus a 1x1
    pointwise part (``in_channels * out_channels``). Biases are not counted.
    The last printed line gives the separable/standard ratio and whether
    that ratio is below 1.
    """
    kernel_area = kernel_size ** 2
    channel_pairs = in_channels * out_channels

    standard_params = channel_pairs * kernel_area
    # depthwise weights + pointwise (1x1) weights
    dp_params = in_channels * kernel_area + channel_pairs * 1 ** 2
    ratio = dp_params / standard_params

    print(f'standrad conv params:{standard_params}')
    print(f'deepwise separable params:{dp_params}')
    print(f'd : s : {ratio:.5}, params_reduce:{ratio < 1.0}')


if __name__ == '__main__':
    # Building blocks, each layer given as (kernel_size, stride, padding).
    stem = [(7, 2, 3), (3, 2, 1)]
    bottleneck = [(1, 1, 0), (3, 1, 1), (1, 1, 0)]
    # Bottleneck whose 3x3 conv uses stride 2 (first block of conv3_x..conv5_x).
    strided_bottleneck = [(1, 1, 0), (3, 2, 1), (1, 1, 0)]

    # ResNet-50 test: the two stride-2 stem layers followed by three bottlenecks.
    resnet_layers = stem + bottleneck * 3
    reception_cal(*resnet_layers)

    # conv3_x
    resnet_layers = resnet_layers + strided_bottleneck + bottleneck * 3
    reception_cal(*resnet_layers)

    # conv4_x
    resnet_layers = resnet_layers + strided_bottleneck + bottleneck * 5
    reception_cal(*resnet_layers)

    # conv5_x: the full stack, also drawn as a bar chart.
    resnet_layers = resnet_layers + strided_bottleneck + bottleneck * 2
    reception_cal(*resnet_layers, show_bar=True)

    # HorNet: a 4x4 stride-4 layer followed by two 7x7 layers, repeated twice.
    hornet_layers = [(4, 4, 0), (7, 1, 3), (7, 1, 3)] * 2
    reception_cal(*hornet_layers)

    # Optional comparison of standard vs. depthwise-separable conv sizes:
    # judge_dwconv_replace(kernel_size=7, in_channels=256, out_channels=256)

resnet50测试结果正常,如下

 

posted @ 2023-02-09 14:56  Anm半夏  阅读(54)  评论(0)    收藏  举报