用pytorch训练模型-入门
python基础
列表 list
可以理解为变量数组。
-
定义:
name = ['Alan', 'Bishop', 'He', 'Xavier'] -
修改元素:
name[2] = 'Srivastava' -
末尾添加元素:
name.append('Yann') -
任意位置添加元素:
name.insert(2, 'Ren') -
删除元素:
del name[0]
元组 tuple
可以理解为常量数组。
-
定义:
bet = ('alpha', 'beta', 'gamma') -
不允许增删改
字典 dictionary
可以理解为key-value pair组,或者说hashmap。
-
定义:
tech = {'name':'ResNet', 'author':'He', 'year':2016} -
修改元素:
tech['author'] = 'He et al.' -
添加元素:
tech['category'] = 'CNN' -
删除元素:
del tech['year']
切片 slice
取出list/tuple的一部分。
有三个参数 x[start:end:stride]。其中start和end是左闭右开的,默认为0和n,可以是负数,表示倒数第几个;stride默认为1且可省略,也可以是负数,表示倒着选取。
类 class
class A(B):
# initialization method
def __init__ (self, a, b, c, d):
super().__init__()
# do something
# other method
def foo:
#do something
# instantiate sample
x = A(a, b, c, d)
其中,__init__是初始化方法,在实例化时会自动调用。__init__的第一个参数固定为self,代表Student类本身,其余的参数则是自己定义的在初始化时需要用到的参数。
class A(B)表示A继承自B,在A的初始化方法中使用super().__init__()执行父类B的初始化函数(或者按照python2的习惯写作super(A, self).__init__())。需要指出的是,python处理菱形继承问题的方法是将所有类置于MRO链上,super()指代的是MRO链的下一个(不一定是父类),比如继承关系A(B, C) B(D) C(D)的MRO链可能是:A->B->C->D->object,此时super().__init__()或super(A, self).__init__()会沿着MRO链自然地执行上述父类的初始化,但B.__init__(),C.__init__()是错误的行为,会导致D被重复初始化。
torch基础
tensor操作参见 d2l学习笔记。
所有网络继承自nn.Module,必须实现两个方法:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
# 初始化方法 声明所需的层
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5, 1, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# 前向传播
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
# instantiation
net = Net().to(DEVICE)
# using instance
outputs = net(inputs)
其中nn.Conv2d是卷积层,前五个参数依次是in_channels, out_channels, kernel_size, stride, padding,满足关系 \(O = \frac{I - K + 2P}{S} +1\)(注:这里的O和I是输入和输出的高宽,不是通道)。nn.Linear是全连接层,前两个参数依次是in_features, out_features。
在应用示例时,前向传播方法forward会自动使用(这一效果在nn.Module中实现)。
torchvision基础
transforms
transform是传入数据的预处理,包括归一化、数据增强等操作都在这一部分实现。
使用transforms.Compose()接收一个操作list,可以串联依次执行多个操作。
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
Dataset
Dataset是数据集类,继承自torch.utils.data.Dataset,必须实现三个方法:
from torchvision import datasets
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data_path, transform):
# 初始化方法
super().__init__()
self.data = ...
self.transform = ...
def __getitem__(self, idx):
# 返回第 idx 个样本
# 比如分类问题的返回格式应为 data, label
image = load_image(self.data[idx])
if self.transform:
image = self.transform(image)
return image, label
def __len__(self):
# 返回数据集总长度
return len(self.data)
trainset = MyDataset(data_path=train_path, transform=train_transform)
testset = MyDataset(data_path=test_path, transform=test_transform)
部分经典数据集在Dataset中已经内置,如MNIST数据集:
trainset = datasets.MNIST(
root = data_path,
train = True,
download = False,
transform = transform
)
testset = datasets.MNIST(
root = data_path,
train = False,
download = False,
transform = transform
)
Dataloader
DataLoader是Dataset的可迭代对象,每次调用时返回一个batch的数据。
trainloader = DataLoader(
dataset = trainset,
batch_size = 64,
shuffle = True,
num_workers = 2
)
testloader = DataLoader(
dataset = testset,
batch_size = 64,
shuffle = False,
num_workers = 2
)
DataLoader使用collate_fn将多个样本堆叠成一个batch,比如在batch_size=64的情况下,64张[1, 28, 28]的图片组[(img_0, label_0), (img_1, label_1), ..., (img_63, label_63)]经collate_fn堆叠后输出形式为(tensor([64, 1, 28, 28]), tensor([64]))。
训练时只需按如下方式解析即可得到一个batch的图片数据和类别数据。
for i, data in enumerate(trainloader):
inputs, labels = data
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
# do something
当每个样本格式不统一时(比如一张图片目标检测有多个锚框和类别),需要自己写collate函数,将多个样本堆叠为一个batch。
def detection_collate_fn(batch):
"""
batch: [(img_0, target_0), (img_1, target_1), ...]
每个 target 包含 bbox[N, 4] 和 labels[N],N对于每张图可能不同
"""
images = []
targets = []
for img, target in batch:
images.append(img)
targets.append(target)
# images: list of [3, H, W] with different H, W
# targets: list of dict {'boxes': [N, 4], 'labels': [N]}
return images, targets
loader = DataLoader(dataset, ..., collate_fn=detection_collate_fn)

如何使用torch, torchvision等库写代码、训练模型。适合对象:初学者,预先学习过基础的理论知识,有阅读代码或LLM辅助写过简单模型代码的经历,想要训练自己模型代码能力
浙公网安备 33010602011771号