NeverDelay

 

3.2使用PyTorch搭建AlexNet并训练花分类数据集

 

1、搭建AlexNet网络

2、如何使用自己的数据集使用网络

 

3、pytorch卷积操作详解:https://blog.csdn.net/qq_37541097/article/details/102926037?spm=1001.2014.3001.5501 ——————————————————————————————————————————————————————

1、搭建AlexNet网络

     网络结构:

 

 定义网络类:

  • 定义初始化函数:定义网络在正向传播中所需要使用的层结构。

features:神经网络层

使用nn.Sequential模块:网络层次比较多可以使用nn.Sequential简洁代码

kernel_size:卷积核个数,stride:卷积核步长,padding:边缘补0的个数,如果是padding=(1,2)(上下,左右)则为上下方各补一行0,左右两侧各补两列零;如果上下左右补零行数都不同,则使用nn.ZeroPad2d((1,2,1,2))(左右上下)

classifier:全连接层

  使用nn.Sequential模块

  使用nn.Dropout

  初始化判断init_weights(nn.Sequential中的参数)(pytorch会自动初始化)

  • 定义前向传播

     传入神经网络层

使用torch.flatten(x,start_dim=1):展平处理,索引从1开始(由于神经网络层与全连接层结构差异,所以需要展平处理)

传入全连接层

  • 定义初始化判断函数

2、训练模型train.py

  导入头文件

  定义使用CPU设备:

    device = torch.device("cuda:0"if torch.cuda.is_available() else "cpu")

    print(device)

  数据预处理函数:

    训练集操作:torchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起。

      • 随机裁剪(224*224大小)
      • 随机翻转
      • 转化为tensor
      • 标准化处理

    验证集操作

      • Resize为标准大小(224,224)
      • 转化为tensor
      • 标准化处理

    获取数据集所在根目录

3、pytorch卷积操作详解:

          pytorch中Tensor通道排列顺序:[batch,channel,height,width]

          常用的卷积函数:

          torch.nn.Conv2d(

        • in_channels,     #代表输入特征矩阵的深度即channel,比如输入一张RGB彩色图像,那in_channels=3
        • out_channels,  #代表卷积核的个数,使用n个卷积核输出的特征矩阵深度即channel就是n
        • kernel_size,      #代表卷积核的尺寸,输入可以是int类型如3 代表卷积核的height=width=3,也可以是tuple类型如(3, 5)代表卷积核的height=3,width=5
        • stride=1,           #代表卷积核的步距默认为1,和kernel_size一样输入可以是int类型,也可以是tuple类型
        • padding=0,        #参数代表在输入特征矩阵四周补零的情况默认为0,同样输入可以为int型如1 代表上下方向各补一行0元素,左右方向各补一列0像素(即补一圈0)
        • dilation=1,
        • groups=1,
        • bias=True,       #表示是否使用偏置(默认使用)
        • padding_mode='zeros'

                                    )

 

           在卷积操作中,矩阵经卷积后的尺寸由4个因数决定:

    •  输入图片大小W*W
    •  Filter大小F*F
    •  步长S
    • padding的像素数

    经卷积后的矩阵尺寸大小计算公式为:

 

      N = (W − F + 2P ) / S + 1

 

                 

 

N为非整数的情况:在卷积过程中会直接删除多余的行和列来保证卷积的输出尺寸为整数,以保证N为整数

 

posted on 2021-05-30 14:54  NeverDelay  阅读(332)  评论(0编辑  收藏  举报

导航