【猫狗数据集】从命令行接收参数

数据集下载地址:

链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
提取码:2xq4

创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html

读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html

进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html

保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html

加载保存的模型并测试:https://www.cnblogs.com/xiximayou/p/12459499.html

划分验证集并边训练边验证:https://www.cnblogs.com/xiximayou/p/12464738.html

使用学习率衰减策略并边训练边测试:https://www.cnblogs.com/xiximayou/p/12468010.html

利用tensorboard可视化训练和测试过程:https://www.cnblogs.com/xiximayou/p/12482573.html

epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html

 

本节我们要在命令行接收参数,包括batch_size的值以及网络的类型。

基本上我们只需要修改main.py就行了:

main.py

import sys
sys.path.append("/content/drive/My Drive/colab notebooks")
from utils import rdata
from model import resnet
import torch.nn as nn
import torch
import numpy as np
import torchvision
import train
import torch.optim as optim

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(batch_size,baseline):
  train_loader,val_loader,test_loader=rdata.load_dataset(batch_size)
  if baseline:
    model =torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features,2,bias=False)
  if torch.cuda.is_available():
    model.cuda()

  #定义训练的epochs
  num_epochs=100
  #定义学习率
  learning_rate=0.1
  #定义损失函数
  criterion=nn.CrossEntropyLoss()
  #定义优化方法,简单起见,就是用带动量的随机梯度下降
  optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=1*1e-4)
  scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40,80], 0.1)
  print("训练集有:",len(train_loader.dataset))
  #print("验证集有:",len(val_loader.dataset))
  print("测试集有:",len(test_loader.dataset))
  trainer=train.Trainer(criterion,optimizer,model)
  trainer.loop(num_epochs,train_loader,val_loader,test_loader,scheduler)

if __name__ == "__main__":
  import argparse
  p=argparse.ArgumentParser()
  p.add_argument("--batch_size",type=int,default=64)
  p.add_argument("--baseline",action="store_true")
  args=p.parse_args()
  main(args.batch_size,args.baseline)

说明:我们将读取数据集、定义损失、优化器等代码放入到main()函数中,然后给main传入batch_size和baseline。使用argparse可以从命令行接收参数。add_argument()函数中,第一个参数是参数的名称,第二个是参数的类型,default是默认值,即不在命令行输入--batch_size 具体值,则会使用默认值。需要关注的是action="store_true",该参数的意思是默认baseline为False,如果在命令行中加入了--baseline,则baseline的值就为True。

结果如图所示:

没有加--batch_size,则batch_size默认为64,也就是18255/64约等于286。然后我们使用了--baseline,即默认使用resnet18模型。

 

由于图像分类一般考虑的衡量指标是top1和top5,下一节就是加上计算top5的代码了。

posted @ 2020-03-13 20:06  西西嘛呦  阅读(347)  评论(0编辑  收藏