6.4.0 头文件
import torch
from d2l import torch as d2l
6.4.1 多输入通道单输出通道卷积运算
# 定义多输入通道单输出通道卷积运算:将输入图像的第一个通道与卷积核第一个通道进行卷积运算,将输入图像的第二个通道与卷积核第二个通道进行卷积运算,依此类推,最后将所有卷积结果对应元素相加
def corr2d_multi_in(X, K):
return sum(d2l.corr2d(x, k) for x, k in zip(X, K))
# 定义输入图像X(2通道,3行,3列)
X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
# 定义卷积核K(2通道,2行,2列)
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
Y = corr2d_multi_in(X, K)
print(Y)
# 输出:
# tensor([[ 56., 72.],
# [104., 120.]])
print(Y.shape)
# 输出:
# torch.Size([2, 2])
6.4.2 多输入通道多输出通道卷积运算
# 定义多输入多输出通道卷积运算:每个批量的卷积核与图像X进行一次多输入单输出卷积运算,每次卷积运算结果作为一个通道输出
def corr2d_multi_in_out(X, K):
return torch.stack([corr2d_multi_in(X, k) for k in K], 0)
# 定义卷积核K(3批量,2通道,2行,2列)
K = torch.stack((K, K + 1, K + 2), 0)
Y = corr2d_multi_in_out(X, K)
print(Y)
# 输出:
# tensor([[[ 56., 72.],
# [104., 120.]],
#
# [[ 76., 100.],
# [148., 172.]],
#
# [[ 96., 128.],
# [192., 224.]]])
print(Y.shape)
# 输出:
# torch.Size([3, 2, 2])
本小节完整代码如下
import torch
from d2l import torch as d2l
# ------------------------------多输入通道单输出通道卷积运算------------------------------------
# 定义多输入通道单输出通道卷积运算:将输入图像的第一个通道与卷积核第一个通道进行卷积运算,将输入图像的第二个通道与卷积核第二个通道进行卷积运算,依此类推,最后将所有卷积结果对应元素相加
def corr2d_multi_in(X, K):
return sum(d2l.corr2d(x, k) for x, k in zip(X, K))
# 定义输入图像X(2通道,3行,3列)
X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
# 定义卷积核K(2通道,2行,2列)
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
Y = corr2d_multi_in(X, K)
print(Y)
# 输出:
# tensor([[ 56., 72.],
# [104., 120.]])
print(Y.shape)
# 输出:
# torch.Size([2, 2])
# ------------------------------多输入通道多输出通道卷积运算------------------------------------
# 定义多输入多输出通道卷积运算:每个批量的卷积核与图像X进行一次多输入单输出卷积运算,每次卷积运算结果作为一个通道输出
def corr2d_multi_in_out(X, K):
return torch.stack([corr2d_multi_in(X, k) for k in K], 0)
# 定义卷积核K(3批量,2通道,2行,2列)
K = torch.stack((K, K + 1, K + 2), 0)
Y = corr2d_multi_in_out(X, K)
print(Y)
# 输出:
# tensor([[[ 56., 72.],
# [104., 120.]],
#
# [[ 76., 100.],
# [148., 172.]],
#
# [[ 96., 128.],
# [192., 224.]]])
print(Y.shape)
# 输出:
# torch.Size([3, 2, 2])
浙公网安备 33010602011771号