# coding:utf-8
import io
import os

import numpy as np
import rarfile as rar
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class myDataset(Dataset):
    """Image dataset backed by a single ``.rar`` archive.

    Every entry reported by the archive's ``namelist()`` is treated as one
    sample. Samples are decoded on demand and returned as ``float32``
    tensors whose values are the raw pixel values divided by 255.

    The list of member names is read from the archive once, on first use,
    and then reused; if the archive is modified on disk afterwards, create
    a new dataset instance.
    """

    def __init__(self, inrar):
        """Remember the archive path; the archive itself is opened lazily.

        Args:
            inrar: Path to the ``.rar`` archive that holds the image files.
        """
        self.inrar = inrar
        # Filled in by _member_names() on first access.
        self._fnames = None

    def _member_names(self):
        """Return the archive's member names, reading the index only once."""
        # getattr keeps this working for instances that were created or
        # unpickled without going through __init__.
        names = getattr(self, "_fnames", None)
        if names is None:
            # The context manager closes the archive even if listing fails.
            with rar.RarFile(self.inrar) as orar:
                names = orar.namelist()
            self._fnames = names
        return names

    def __len__(self):
        """Return the number of entries in the archive."""
        return len(self._member_names())

    def __getitem__(self, item):
        """Decode one archive member into a tensor.

        Args:
            item: Index into the archive's member list.

        Returns:
            A ``torch.float32`` tensor holding the pixel data divided by
            255. For images in ``"RGBA"`` mode the alpha channel is dropped
            and only the first three channels are kept; images in any other
            mode are converted as-is.

        Raises:
            IndexError: If ``item`` is outside the member list.
        """
        fname = self._member_names()[item]

        # Read the member straight into memory instead of extracting it to
        # the working directory: nothing is written to disk, so there is no
        # temporary file to clean up and parallel loader workers cannot
        # collide on the same extracted path.
        with rar.RarFile(self.inrar) as orar:
            payload = orar.read(fname)

        # Image.open is lazy; np.array() forces the decode while the image
        # is still open, and the with-block then releases it.
        with Image.open(io.BytesIO(payload)) as img:
            has_alpha = img.mode == "RGBA"
            pixels = np.array(img)

        if has_alpha:
            pixels = pixels[:, :, :3]

        return torch.tensor(pixels / 255, dtype=torch.float32)