小土堆 PyTorch Learning: P21 - Linear Layers

Some layer types are not covered here: normalization layers (one paper notes that normalization can speed up training), Recurrent Layers (use them when you need them; they rarely come up in everyday work), Transformer Layers, and Dropout Layers (mainly used to prevent overfitting).

Linear layers are used a lot, so this post focuses on them.

Linear Layers

(figures: linear layer illustration, with in_features \(x_1, x_2, \dots, x_d\) and outputs \(o_1, \dots, o_m\))

In the figure, in_features corresponds to \(x_1, x_2, \dots, x_d\), and the output corresponds to \(o_1, \dots, o_m\).
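
For reference (this formula is the standard definition of torch.nn.Linear rather than something from the original post), a linear layer maps the \(d\) inputs to the \(m\) outputs through a weight matrix plus a bias:

\[
o_j = \sum_{i=1}^{d} w_{ji}\, x_i + b_j , \qquad j = 1, \dots, m ,
\]

which is what nn.Linear(in_features, out_features) computes as \(y = x A^\top + b\).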

Below is example code for the idea of converting a \(5\times 5\) image into a \(1\times 25\) row (here applied to whole CIFAR10 batches) 👇

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

test_data = torchvision.datasets.CIFAR10(root = "hymenoptera_data/val/CIFAR10" , train = False
                                        ,transform = torchvision.transforms.ToTensor() , download = True)
# drop_last = True discards the final incomplete batch: CIFAR10's test set has 10000 images,
# so the last batch would only hold 16 and its flattened length would not match the linear layer.
data_loader = DataLoader(dataset = test_data , batch_size = 64 , drop_last = True)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui , self).__init__()
        self.linear1 = nn.Linear(196608 , 10)   # 196608 = 64 * 3 * 32 * 32, one flattened batch of CIFAR10 images

    def forward(self , input):
        output = self.linear1(input)
        return output

tudui = Tudui()

for data in data_loader:
    imgs , targets = data
    print(f"imgs shape is {imgs.shape}")
    print("\n")
    # The next two lines have a similar effect: both stretch the whole batch into one long row of values.
    # torch.reshape keeps four dimensions, giving shape (1, 1, 1, 196608), while torch.flatten returns a 1-D tensor of shape (196608,).
    # output = torch.reshape(imgs , (1 , 1 , 1 , -1))
    output = torch.flatten(imgs)
    print(F"output_1 shape is {output.shape}")
    print("\n")

    output = tudui(output)
    print(F"output_2 shape is {output.shape}")
    print("\n")

    print("="*111)

What torch.reshape does here can also be done with torch.flatten: both stretch the batch into a single row of values, the only difference being that flatten returns a 1-D tensor while this reshape keeps four dimensions.
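
Below is a minimal sketch (added for illustration, with a random tensor standing in for one CIFAR10 batch) that makes the shape difference concrete:

import torch
from torch import nn

imgs = torch.rand(64 , 3 , 32 , 32)                     # stand-in for one batch of 64 CIFAR10 images

reshaped = torch.reshape(imgs , (1 , 1 , 1 , -1))
flattened = torch.flatten(imgs)

print(reshaped.shape)     # torch.Size([1, 1, 1, 196608])
print(flattened.shape)    # torch.Size([196608])

# Either tensor can be fed to the linear layer, because nn.Linear only requires
# the last dimension to equal in_features.
linear = nn.Linear(196608 , 10)
print(linear(reshaped).shape)     # torch.Size([1, 1, 1, 10])
print(linear(flattened).shape)    # torch.Size([10])

This is also why drop_last = True is set on the DataLoader above: the last test batch would contain only 16 images, its flattened length would be 16 * 3 * 32 * 32 = 49152, and nn.Linear(196608, 10) would raise a shape-mismatch error on it.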
