import torch
import torch.nn as nn
def get_inp(inp, stride, kernel_size):
    """Flatten every horizontal strip of ``inp`` that a convolution kernel visits.

    Each output row of an unpadded 2-D convolution reads ``kernel_size``
    consecutive input rows across all channels. This function cuts out those
    strips and flattens each one (channel-major, then row, then column) into
    one row of the returned matrix. Multiplying that matrix by the one built by
    ``get_weight`` yields the convolution output one output row at a time.

    Args:
        inp: Tensor of shape ``(channels, height, width)``.
        stride: Vertical step between consecutive strips.
        kernel_size: Number of input rows in each strip.

    Returns:
        Tensor of shape ``(out_height, channels * kernel_size * width)`` with
        ``out_height = (height - kernel_size) // stride + 1``.

    Raises:
        ValueError: if ``kernel_size`` is larger than the input height.
    """
    height = inp.shape[1]
    # Output-size formula for an unpadded convolution. The number of strips
    # depends on the height (dim 1), the axis the strips slide along.
    out_height = (height - kernel_size) // stride + 1
    if out_height <= 0:
        raise ValueError(
            f"kernel_size={kernel_size} does not fit in input height {height}"
        )
    # inp[:, top:top + kernel_size] has shape (channels, kernel_size, width).
    # Flattening it gives the same ordering as concatenating each channel's
    # flattened strip.
    strips = [
        inp[:, top:top + kernel_size].reshape(1, -1)
        for top in range(0, out_height * stride, stride)
    ]
    return torch.cat(strips, dim=0)
def get_weight(weight, stride, kernel_size, x_len):
    """Expand convolution kernels into a dense matrix acting on ``get_inp`` strips.

    For every output channel and every horizontal kernel position, that
    channel's kernel is written into a zero tensor. The tensor is laid out like
    one flattened strip from ``get_inp`` (channel-major, then row, then column),
    with the kernel shifted right by ``position * stride`` columns.

    Args:
        weight: Kernel tensor of shape
            ``(out_channels, in_channels, kernel_size, kernel_size)``.
        stride: Horizontal step between kernel positions.
        kernel_size: Spatial size of the (square) kernel.
        x_len: Width of the input, i.e. the length of one row inside a strip.

    Returns:
        Tensor of shape ``(in_channels * kernel_size * x_len,
        out_channels * out_width)`` with
        ``out_width = (x_len - kernel_size) // stride + 1``. Columns are
        ordered by output channel first, then by horizontal position.

    Raises:
        ValueError: if ``kernel_size`` is larger than ``x_len``.
    """
    out_channels = weight.shape[0]
    inp_channels = weight.shape[1]
    # Output-size formula for an unpadded convolution.
    out_width = (x_len - kernel_size) // stride + 1
    if out_width <= 0:
        raise ValueError(
            f"kernel_size={kernel_size} does not fit in input width {x_len}"
        )
    columns = []
    for oc in range(out_channels):
        for pos in range(out_width):
            # The leftmost input column covered at this position scales with stride.
            left = pos * stride
            placed = torch.zeros(inp_channels, kernel_size, x_len)
            placed[:, :, left:left + kernel_size] = weight[oc]
            columns.append(placed.reshape(1, -1))
    # (out_channels * out_width, C*k*x_len) -> (C*k*x_len, out_channels * out_width)
    return torch.cat(columns, dim=0).permute(1, 0)
def test(in_channels=1, out_channels=6, kernel_size=5, stride=1, inp_xlen=32, inp_ylen=32):
    """Check that the strip/weight matrix product reproduces ``nn.Conv2d``.

    Builds a random unbatched input of width ``inp_xlen`` and height
    ``inp_ylen`` and runs it through a bias-free ``nn.Conv2d``. It then
    recomputes the same convolution as ``get_inp(...) @ get_weight(...)``.
    The max abs error, the ``allclose`` verdict and both output shapes are
    printed.

    Returns:
        The maximum absolute difference between the two results.
    """
    # Output-size formula for an unpadded convolution along each axis.
    x_step = (inp_xlen - kernel_size) // stride + 1  # output width
    y_step = (inp_ylen - kernel_size) // stride + 1  # output height
    # Conv2d takes (C, H, W): height (y) comes before width (x).
    inp = torch.randn(in_channels, inp_ylen, inp_xlen)
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, bias=False)
    out = conv(inp)
    inp_tensor = get_inp(inp, stride, kernel_size)
    # get_weight needs the row length of each strip, i.e. the input width.
    weight_tensor = get_weight(conv.weight, stride, kernel_size, inp_xlen)
    # (y_step, out_channels * x_step) -> (out_channels, y_step, x_step)
    out2 = (inp_tensor @ weight_tensor).reshape(y_step, out_channels, x_step).permute(1, 0, 2)
    max_err = (out - out2).abs().max().item()
    print(f"max abs error: {max_err:.3e}  allclose: {torch.allclose(out, out2, atol=1e-5)}")
    print(out.shape)
    print(out2.shape)
    return max_err
if __name__ == "__main__":
    test(in_channels=6, out_channels=16, kernel_size=5, stride=1, inp_xlen=14, inp_ylen=14)

# Scratch check kept for reference. nn.Conv2d computes a cross-correlation
# (the kernel is not flipped), so the hand-written sum below should equal
# conv(test1)[0][0][1], the output at row 0, column 1.
#test1 = torch.randn(1,3,3)
#conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, bias=False)
#print(test1)
#print(conv.weight[0][0][0][0]*test1[0][0][1]+conv.weight[0][0][0][1]*test1[0][0][2]+conv.weight[0][0][1][0]*test1[0][1][1]+conv.weight[0][0][1][1]*test1[0][1][2])
#print(conv(test1))