Gokix

一言(ヒトコト)

关注我

用pytorch训练模型-入门

python基础

列表 list

可以理解为变量数组。

  • 定义:name = ['Alan', 'Bishop', 'He', 'Xavier']

  • 修改元素:name[2] = 'Srivastava'

  • 末尾添加元素:name.append('Yann')

  • 任意位置添加元素:name.insert(2, 'Ren')

  • 删除元素:del name[0]

元组 tuple

可以理解为常量数组。

  • 定义:bet = ('alpha', 'beta', 'gamma')

  • 不允许增删改

字典 dictionary

可以理解为key-value pair组,或者说hashmap。

  • 定义:tech = {'name':'ResNet', 'author':'He', 'year':2016}

  • 修改元素:tech['author'] = 'He et al.'

  • 添加元素:tech['category'] = 'CNN'

  • 删除元素:del tech['year']

切片 slice

取出list/tuple的一部分。

有三个参数 x[start:end:stride]。其中startend是左闭右开的,默认为0n,可以是负数,表示倒数第几个;stride默认为1且可省略,也可以是负数,表示倒着选取。

类 class

class A(B):
	# initialization method
	def __init__ (self, a, b, c, d):
		super().__init__()
		# do something
	# other method
	def foo:
		#do something
# instantiate sample
x = A(a, b, c, d)

其中,__init__是初始化方法,在实例化时会自动调用。__init__的第一个参数固定为self,代表Student类本身,其余的参数则是自己定义的在初始化时需要用到的参数。

class A(B)表示A继承自B,在A的初始化方法中使用super().__init__()执行父类B的初始化函数(或者按照python2的习惯写作super(A, self).__init__())。需要指出的是,python处理菱形继承问题的方法是将所有类置于MRO链上,super()指代的是MRO链的下一个(不一定是父类),比如继承关系A(B, C) B(D) C(D)的MRO链可能是:A->B->C->D->object,此时super().__init__()super(A, self).__init__()会沿着MRO链自然地执行上述父类的初始化,但B.__init__(),C.__init__()是错误的行为,会导致D被重复初始化。

torch基础

tensor操作参见 d2l学习笔记

所有网络继承自nn.Module,必须实现两个方法:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
		# 初始化方法 声明所需的层
        super().__init__()

        self.conv1 = nn.Conv2d(1, 6, 5, 1, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
		# 前向传播
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        return x
		
# instantiation
net = Net().to(DEVICE)

# using instance
outputs = net(inputs)

其中nn.Conv2d是卷积层,前五个参数依次是in_channels, out_channels, kernel_size, stride, padding,满足关系 \(O = \frac{I - K + 2P}{S} +1\)(注:这里的O和I是输入和输出的高宽,不是通道)。nn.Linear是全连接层,前两个参数依次是in_features, out_features

在应用示例时,前向传播方法forward会自动使用(这一效果在nn.Module中实现)。

torchvision基础

transforms

transform是传入数据的预处理,包括归一化、数据增强等操作都在这一部分实现。

使用transforms.Compose()接收一个操作list,可以串联依次执行多个操作。

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

Dataset

Dataset是数据集类,继承自torch.utils.data.Dataset,必须实现三个方法:

from torchvision import datasets
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_path, transform):
        # 初始化方法
		super().__init__()
        self.data = ...
        self.transform = ...

    def __getitem__(self, idx):
        # 返回第 idx 个样本
		# 比如分类问题的返回格式应为 data, label
        image = load_image(self.data[idx])
        if self.transform:
            image = self.transform(image)
        return image, label

    def __len__(self):
        # 返回数据集总长度
        return len(self.data)

trainset = MyDataset(data_path=train_path, transform=train_transform)
testset = MyDataset(data_path=test_path, transform=test_transform)

部分经典数据集在Dataset中已经内置,如MNIST数据集:

trainset = datasets.MNIST(
    root = data_path,
    train = True,
    download = False,
    transform = transform
)

testset = datasets.MNIST(
    root = data_path,
    train = False,
    download = False,
    transform = transform
)

Dataloader

DataLoaderDataset的可迭代对象,每次调用时返回一个batch的数据。

trainloader = DataLoader(
    dataset = trainset,
    batch_size = 64,
    shuffle = True,
    num_workers = 2
)

testloader = DataLoader(
    dataset = testset,
    batch_size = 64,
    shuffle = False,
    num_workers = 2
)

DataLoader使用collate_fn将多个样本堆叠成一个batch,比如在batch_size=64的情况下,64张[1, 28, 28]的图片组[(img_0, label_0), (img_1, label_1), ..., (img_63, label_63)]collate_fn堆叠后输出形式为(tensor([64, 1, 28, 28]), tensor([64]))

训练时只需按如下方式解析即可得到一个batch的图片数据和类别数据。

for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
		# do something

当每个样本格式不统一时(比如一张图片目标检测有多个锚框和类别),需要自己写collate函数,将多个样本堆叠为一个batch。

def detection_collate_fn(batch):
    """
    batch: [(img_0, target_0), (img_1, target_1), ...]
    每个 target 包含 bbox[N, 4] 和 labels[N],N对于每张图可能不同
    """
    images = []
    targets = []

    for img, target in batch:
        images.append(img)
        targets.append(target)

    # images: list of [3, H, W] with different H, W
    # targets: list of dict {'boxes': [N, 4], 'labels': [N]}

    return images, targets
	
loader = DataLoader(dataset, ..., collate_fn=detection_collate_fn)
posted @ 2026-01-31 15:40  Gokix  阅读(5)  评论(0)    收藏  举报