PyTorch：将卷积展开成二维（矩阵乘法）形式

import torch
import torch.nn as nn
def get_inp(inp, stride, kernel_size):
    """Unfold a 3-D input into one flattened row band per output row.

    Args:
        inp: tensor indexed as ``(channels, rows, cols)``; bands are taken
            along dimension 1 and each band keeps the full width (dim 2).
        stride: step, in rows, between consecutive bands.
        kernel_size: number of input rows in each band.

    Returns:
        Tensor of shape ``(out_rows, channels * kernel_size * cols)`` with
        ``out_rows = (rows - kernel_size) // stride + 1``. Row ``i`` is the
        concatenation over channels ``c`` of
        ``inp[c, i*stride : i*stride + kernel_size, :]`` flattened row-major,
        which is the layout ``get_weight`` builds its columns for.

    Raises:
        ValueError: if ``kernel_size`` exceeds the number of rows.
    """
    rows = inp.shape[1]
    if kernel_size > rows:
        raise ValueError(f"kernel_size={kernel_size} exceeds input rows={rows}")
    # unfold along dim 1 -> (channels, out_rows, cols, kernel_size): the window
    # index takes dim 1's place and the window contents are appended last.
    # It also applies the correct (rows - k) // stride + 1 window count.
    bands = inp.unfold(1, kernel_size, stride)
    # Reorder to (out_rows, channels, kernel_size, cols) so that flattening the
    # trailing dims reproduces the channel-major, row-major concatenation.
    return bands.permute(1, 0, 3, 2).reshape(bands.shape[1], -1)
def get_weight(weight, stride, kernel_size, x_len):
    """Expand conv kernels into a banded matrix that acts on ``get_inp`` rows.

    Args:
        weight: kernel tensor indexed as
            ``(out_channels, in_channels, kernel_size, kernel_size)``,
            e.g. ``nn.Conv2d.weight`` for a square kernel.
        stride: step, in columns, between consecutive output positions.
        kernel_size: spatial size of the square kernel.
        x_len: length of one input row, i.e. the size of the input's last
            dimension (the width of each band produced by ``get_inp``).

    Returns:
        Tensor of shape ``(in_channels * kernel_size * x_len,
        out_channels * out_cols)`` with
        ``out_cols = (x_len - kernel_size) // stride + 1``. Column
        ``o * out_cols + j`` holds kernel ``o`` placed at horizontal offset
        ``j * stride`` inside a flattened band, so ``bands @ result`` gives
        every output value of one output row.

    Raises:
        ValueError: if ``kernel_size`` exceeds ``x_len``.
    """
    if kernel_size > x_len:
        raise ValueError(f"kernel_size={kernel_size} exceeds row length x_len={x_len}")
    out_channels, inp_channels = weight.shape[0], weight.shape[1]
    x_step = (x_len - kernel_size) // stride + 1
    # new_zeros follows the kernel's dtype/device (torch.zeros always gave
    # float32 on the default device).
    weight_cat = weight.new_zeros(out_channels, x_step, inp_channels, kernel_size, x_len)
    for j in range(x_step):
        # The window for output column j starts at j * stride, not j.
        start = j * stride
        weight_cat[:, j, :, :, start:start + kernel_size] = weight
    # Flatten (channel, band row, column) to match get_inp's band layout, then
    # transpose to (band_len, out_channels * x_step) for `bands @ result`.
    return weight_cat.reshape(out_channels * x_step, -1).permute(1, 0)
def test(in_channels=1, out_channels=6, kernel_size=5, stride=1, inp_xlen=32, inp_ylen=32):
    """Compare the unfolded-matmul convolution with ``nn.Conv2d`` on random data.

    The input is built as ``(in_channels, inp_xlen, inp_ylen)``, so
    ``inp_xlen`` is the number of rows (the dimension ``get_inp`` windows
    over) and ``inp_ylen`` is the row width (the length ``get_weight`` needs).
    No padding, no bias.

    Prints the maximum absolute difference and both output shapes, and
    returns ``True`` when the two results agree within float32 tolerance.
    """
    # No-padding conv output size per spatial dim: (n - k) // stride + 1.
    out_rows = (inp_xlen - kernel_size) // stride + 1
    out_cols = (inp_ylen - kernel_size) // stride + 1
    inp = torch.randn(in_channels, inp_xlen, inp_ylen)
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, bias=False)
    with torch.no_grad():
        # Unbatched 3-D input -> output (out_channels, out_rows, out_cols).
        out = conv(inp)
        inp_tensor = get_inp(inp, stride, kernel_size)
        # The width argument is the length of one input row (dim 2), i.e. inp_ylen.
        weight_tensor = get_weight(conv.weight, stride, kernel_size, inp_ylen)
        # (out_rows, out_channels * out_cols) -> (out_channels, out_rows, out_cols).
        out2 = (inp_tensor @ weight_tensor).reshape(out_rows, out_channels, out_cols)
        out2 = out2.permute(1, 0, 2)
    print((out - out2).abs().max())
    print(out.shape)
    print(out2.shape)
    return torch.allclose(out, out2, atol=1e-4)
if __name__ == "__main__":
    # Run the self-check only when executed as a script, not on import.
    test(in_channels=6, out_channels=16, kernel_size=5, stride=1, inp_xlen=14, inp_ylen=14)

posted @ 2025-02-25 15:52  心比天高xzh  阅读(12)  评论(0)    收藏  举报