# 判断某个文件是否是图像
# enswith判断是否以指定的.png,.jpg,.jpeg结尾的字符串
# 可以根据情况扩充图像类型,加入.bmp、.tif等
def is_image_file(filename):
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
# 读取图像转为YCbCr模式,得到Y通道
def load_img(filepath):
img = Image.open(filepath).convert('YCbCr')
y, _, _ = img.split()
return y
# 裁剪大小,宽高一致为300
# 如果想训练自己的数据集,请根据情况修改裁剪大小
CROP_SIZE = 300
# 封装数据集,适配后面的torch.utils.data.DataLoader中的dataset,定义成类似形式
# 类参数为图像文件夹路径和放大倍数
# __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)
#__getitem__(self) 定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
#__iter__(self) 定义当迭代容器中的元素的行为
# 返回输入图像和标签,传入DataLoader的dataset参数
class DatasetFromFolder(Dataset):
def __init__(self, image_dir, zoom_factor):
super(DatasetFromFolder, self).__init__()
self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] # 图像路径列表
crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor) # 处理放大倍数,防止用户瞎设置,本例只能设置为2,3,4,大小不变
# 数据集变换
# 还有一些其他的变换操作,如归一化等,遇到一个积累一个
self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # 从图片中心裁剪成300*300
transforms.Resize(
crop_size // zoom_factor), # Resize, 输入应该是缩放倍数后的图像,因为先缩小后放大
transforms.Resize(
crop_size, interpolation=Image.BICUBIC), # 双三次插值
transforms.ToTensor()]) # 图像转成tensor
# label标签,超分不是分类问题,定义成一样的就行
self.target_transform = transforms.Compose(
[transforms.CenterCrop(crop_size), transforms.ToTensor()])
def __getitem__(self, index):
input = load_img(self.image_filenames[index]) # 输入是图像的Y通道,即亮度通道
target = input.copy()
input = self.input_transform(input)
target = self.target_transform(target)
return input, target
def __len__(self):
return len(self.image_filenames) # 图像个数