pytorch构建项目

  • 搭建model
  1. 继承于nn.Module,关键在于def __init__()的初始化传参,这里面也可以用来定义一些layer,实例化model的时候被调用,以及def forward()用来传入input,函数内真正将input传入model的layer;
  2. nn.Sequential()的用法;
  3. nn.ReLU()用来增加model的非线性连接;
  4. nn.Softmax()用来scale to[0,1]之间,更加符合0,1分类任务;
  • 指定device
  1. 语句一
# 单卡
device=("cuda" if torch.cuda.is_available() else "cpu")
  • 制作自己的dataset
  1. 继承于torchvision.utils.data里面的Dataset
  2. 包括三个函数:def __init__()传一些文件路径之类的参数,包括transform参数,里面定义好这些参数,def __len__()获取到input的个数,def __getitem__()通过iter的方式利用传入的index返回对应的input和label,返回的input和label是以tuple的形式存储的;
  • 传入Dataloader
  1. 引入Dataloader库,from torch.utils.data import Dataloader
  2. 一行代码
train_dataloader=Dataloader(train_dataset, batch_size=64, shuffle=True) # 遍历完全部的batch再进行shuffle

后期用到的时候

for batch, (data, label) in enumerate(train_dataloader):
  pass

源于这么一个内置的函数,可以取batchfeatures, lables=next(iter(train_dataloader))

  • transform可以自定义
  1. Lambda()
Lambda(lambda y: 操作) # y是对象
  • 反传Backpropagation⭐
  1. 设置参数requires_grad=True,nn.Module里面似乎默认设置为True;
  2. 如果不需要反传,两种方法:Ⅰ. 用with torch.no_grad():进行包裹;Ⅱ. detach掉(不推荐).detach()
  3. 停止反传场景有二:Ⅰ. 冻住部分参数(To mark some parameters in your neural network as frozen parameters.);Ⅱ. 加快forward速度(To speed up computations when you are only doing forward pass, because computations on tensors that do not track gradients would be more efficient.);
  • 最后,save and load model
  1. 两个命令:torch.save()以及torch.load()
  2. 如果是save/load参数,关键字state_dict(),但是要依附于model,即model.state_dict(),如此导致model必须完全一致;
  • 总结
  1. 其实就是四步: 放进dataloader(dataset),实例化model并计算predict=model(input),计算loss(pred,label)并梯度反传.backward(), 优化参数optimizer.step()并紧跟optimizer.zero_grad()
  2. 这其中很重要的就是预先定义一些超参(一般发生在train.py文件):dataloader 要传batch_size, optimizer 要传learning_rate
  3. 对于每个epoch来说,就是执行一个train_per_epoch()以及test_per_epoch() 的过程;
  4. 对于每个train/test_per_epoch(),里面做的事情无非就是:Ⅰ. 从dataloader里面拿出batch,遍历batch;Ⅱ. 计算每一个batchpredict结果与label之间的loss或者评价指标;Ⅲ. loss反传,optimizer优化参数(test无);Ⅳ. print loss(评价指标) 结果。
posted @ 2023-07-05 17:42  Elina-Chang  阅读(41)  评论(0)    收藏  举报