1. PyTorch：导入图片

需要导入的库

import torch
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
image_dir = 'data/turtle'  # root folder; the loader below treats each sub-folder as a class

# Preprocessing applied to every image, in order.
# Optional steps that are currently disabled:
#   transforms.RandomRotation(30)       - random rotation (up to +/-30 degrees)
#   transforms.RandomHorizontalFlip()   - random horizontal flip
#   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
#                                       - per-channel normalization (the usual
#                                         ImageNet mean/std values)
transform_steps = [
    transforms.Resize((244, 244)),  # resize to (height, width) = (244, 244)
    transforms.CenterCrop(224),     # keep the central 224x224 patch
    transforms.ToTensor(),          # PIL image -> torch tensor, channels first
]
transform = transforms.Compose(transform_steps)

创建两个空列表，分别用来存放图像数据和对应的标签

# Walk image_dir/<class_name>/<file>.jpg and collect one transformed tensor
# per image in `data`, with its class-folder name at the same index in `labels`.
data = []
labels = []
for class_name in sorted(os.listdir(image_dir)):  # sorted for a reproducible order
    class_dir = os.path.join(image_dir, class_name)
    if not os.path.isdir(class_dir):
        continue  # skip stray files (e.g. .DS_Store) next to the class folders
    for file_name in sorted(os.listdir(class_dir)):
        # case-insensitive so .JPG / .jpeg files are not silently skipped
        if not file_name.lower().endswith(('.jpg', '.jpeg')):
            continue
        file_path = os.path.join(class_dir, file_name)
        # `with` closes the file handle; convert('RGB') gives every image the
        # same 3 channels so the tensors can later be stacked into one batch.
        with Image.open(file_path) as image:
            image_tensor = transform(image.convert('RGB'))
        data.append(image_tensor)
        labels.append(class_name)

构建 Dataset 和 DataLoader（批量迭代器）

# Encode the string class names as integers 0..n_classes-1. This must run
# BEFORE the dataset is built, because TensorDataset needs the integer labels.
label_encoder = LabelEncoder()
labels_int = label_encoder.fit_transform(labels)

if not data:
    # torch.stack([]) would otherwise fail with an unhelpful message
    raise RuntimeError(f"no images were loaded from {image_dir!r}")

# Pair every image tensor with its integer class label (as int64 / long).
dataset = torch.utils.data.TensorDataset(
    torch.stack(data),
    torch.as_tensor(labels_int, dtype=torch.long),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Figure the preview below draws into. cols/rows describe a 2-column x 4-row
# grid that the current single-image preview does not use.
figure = plt.figure(figsize=(8, 8))
cols, rows = 2, 4

展示图片和标签

# Fetch one shuffled batch, report its shapes, and preview its first image.
train_features, train_labels = next(iter(loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]  # channels-first tensor from ToTensor: (C, H, W)
label = train_labels[0]
if img.shape[0] == 1:
    # single-channel image: draw the only channel in grayscale
    plt.imshow(img[0], cmap="gray")
else:
    # matplotlib expects (H, W, C); showing img[0] would display only the
    # first channel with a false-colour map
    plt.imshow(img.permute(1, 2, 0))
# map the integer label back to its class-folder name for the title
plt.title(str(label_encoder.inverse_transform([label.item()])[0]))
plt.axis("off")
print(f"Label: {label}")
plt.show()
posted @ 2023-05-17 16:08  paopaocha  阅读(69)  评论(0)    收藏  举报