搭建卷积层（Building a 2D convolution layer with F.conv2d）

# Demo: apply a 2-D convolution directly through the functional API.
# NOTE: torch.Tensor(*sizes) returns *uninitialized* memory, so weights built
# that way can contain arbitrary values (even NaN/inf); draw them randomly.
x = torch.randn(1, 1, 30, 30)  # one single-channel 30x30 input
weight = torch.randn(3, 1, 3, 3)  # (out_channels=3, in_channels=1, 3, 3)
# conv2d requires one bias entry per output channel (3 here); a 1-element
# bias would be rejected, so size it from the weight's output dimension.
b = torch.zeros(weight.shape[0])
# 30 - 3 + 1 = 28 -> prints torch.Size([1, 3, 28, 28])
print(F.conv2d(x, weight, b, stride=1, padding=0).shape)
class CondConv2D(nn.Module):
    """A minimal 2-D convolution layer implemented on top of ``F.conv2d``.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size; an int (square kernel) or a ``(kh, kw)`` pair.
        stride: Stride forwarded to ``F.conv2d`` (int or pair). Default 1.
        padding: Zero padding forwarded to ``F.conv2d`` (int or pair). Default 0.
        dilation: Dilation forwarded to ``F.conv2d`` (int or pair). Default 1.
        bias: If True, learn a per-output-channel bias. Default False, which
            keeps the layer bias-free and its ``state_dict`` limited to
            ``weight``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=False):
        super().__init__()
        if isinstance(kernel_size, (tuple, list)):
            kernel_size = tuple(kernel_size)
        else:
            kernel_size = (kernel_size, kernel_size)
        # Previously these were accepted but silently ignored by forward().
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # torch.Tensor(*sizes) would be uninitialized memory; allocate with
        # empty() and initialise explicitly in reset_parameters().
        self.weight = Parameter(
            torch.empty(out_channels, in_channels, *kernel_size))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            # Registers the name with value None so forward() can always pass
            # self.bias; None parameters are not included in the state_dict.
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise parameters, mirroring ``nn.Conv2d``'s defaults."""
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            _, in_ch, kh, kw = self.weight.shape
            fan_in = in_ch * kh * kw
            bound = fan_in ** -0.5 if fan_in > 0 else 0.0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, inputs):
        """Convolve ``inputs`` with the layer's weight (and bias, if any)."""
        return F.conv2d(inputs, self.weight, self.bias, stride=self.stride,
                        padding=self.padding, dilation=self.dilation)
posted @ 2022-02-17 09:49  祥瑞哈哈哈  阅读(22)  评论(0)    收藏  举报