python语言理解

python是一门面向对象的语言,强调的是对象,当我们创建一个类时,必然要给这个类赋予对应的属性去描述它,例如一个动物的类,那么这个类应该有动物种类,颜色,年龄,体重,习性等属性,代码如下:

class Animal:
    def __init__(self, species, color, age, weight, habitat):
        self.species = species
        self.color = color
        self.age = age
        self.weight = weight
        self.habitat = habitat

    def __str__(self):
        return f"{self.species} | Color: {self.color} | Age: {self.age} years | Weight: {self.weight} kg | Habitat: {self.habitat}"

以上的代码非常易懂,但我让AI根据要求{ 自定义dataset,该类可以自义训练和测试的比例 }生成以下代码:
代码的self.images属性初始化委托给了load_images()方法,也就是说python的类初始化__init__()可以调用该类的其他方法

    
import torch
from PIL import Image
from torchvision import transforms
import os

class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, train_ratio=0.8, transform=None):
        super(CustomImageDataset, self).__init__()
        self.data_dir = data_dir
        self.train_ratio = train_ratio
        self.transform = transform
        self.images = self.load_images()
        self.train_images, self.test_images = self.split_images()

    def load_images(self):
        # 加载所有图像文件
        image_files = os.listdir(self.data_dir)
        image_files = [os.path.join(self.data_dir, file) for file in image_files if file.endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
        return image_files

    def split_images(self):
        # 根据训练比例分割图像
        train_images = self.images[:int(len(self.images) * self.train_ratio)]
        test_images = self.images[int(len(self.images) * self.train_ratio):]
        return train_images, test_images

    def __len__(self):
        return len(self.train_images)

    def __getitem__(self, idx):
        image_path = self.train_images[idx]
        image = Image.open(image_path).convert('RGB')  # 假设图像是以RGB格式打开的
        if self.transform:
            image = self.transform(image)
        return image

# 使用自定义数据集
custom_dataset = CustomImageDataset('/path/to/your/image/data', train_ratio=0.8)
dataloader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)

# 创建一个数据加载器,用于迭代训练集
for images in dataloader:
    # 处理图像
    pass




    
posted @ 2024-01-27 19:50  seekwhale13  阅读(16)  评论(0)    收藏  举报