# Convolution "hello world"
from PIL import Image
import torchvision.transforms as transforms
import torch
# Load the source image from disk.
image = Image.open('./data/cat.png')
# Convert to grayscale. The transform object is created once and reused,
# rather than instantiating a second Grayscale() just for the call.
gray_transform = transforms.Grayscale()
gray_image = gray_transform(image)
# Convert the grayscale PIL image to a tensor.
tensor_transform = transforms.ToTensor()
tensor_image = tensor_transform(gray_image)
# Print the shape and dtype of tensor_image.
print(tensor_image.shape)
print(tensor_image.dtype)
# Output:
# torch.Size([1, 561, 915])
# torch.float32
import torch.nn as nn
# Convolution layer: 1 input channel (grayscale), 4 output channels,
# 3x3 kernel, 1 pixel of zero padding on each side; stride is left at
# its default.
conv_layer = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1)
# Conv2d is fed a (batch_size, channels, height, width) tensor here, so
# prepend a batch dimension at index 0; the other dimensions are unchanged.
tensor_image = tensor_image[None, ...]
# Apply the convolution.
output = conv_layer(tensor_image)
# Print the shape and dtype of the result, one per line.
print(output.shape, output.dtype, sep='\n')
# Output:
# torch.Size([1, 4, 561, 915])
# torch.float32
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
# Converter from a tensor to a PIL image.
to_pil = transforms.ToPILImage()
# Display every output channel of the first (only) batch item as its own
# grayscale figure.
for i in range(output.shape[1]):
    # Detach from the autograd graph: only the values are needed for display.
    channel_image = output[0, i, :, :].detach()
    # Convolution responses are unbounded (they can be negative or exceed 1),
    # while ToPILImage scales float input by 255 and casts to uint8.
    # Min-max rescale to [0, 1] first so the cast cannot wrap around.
    lowest = channel_image.min()
    # Guard against a constant channel, which would otherwise divide by zero.
    span = (channel_image.max() - lowest).clamp(min=1e-12)
    channel_image = (channel_image - lowest) / span
    pil_image = to_pil(channel_image)
    plt.imshow(pil_image, cmap='gray')
    plt.show()





# Site footer (scraped): 浙公网安备 33010602011771号