PyTorch torch.nn 模块 Layer 总结

Pytorch 中,对 Layer 的总结,即 torch.nn 模块,官网.

0. 概述

0.1 torch.nn 模块中的类

0.1.1 torch.nn 模块中类的划分

为了理清 torch.nn 模块中的内容,可以将 torch.nn 模块中的类分为 4 种:

容器类,nn.modules.container 类,是一系列的组合

  • 主要为 nn.Sequential()

模块类,nn.modules 类下除了容器类(nn.modules.container)和模型类(nn.modules.module.Module)均被划为此类,也就是所谓的

  • 诸如含有学习参数的类 nn.Linear();不含学习参数的类 nn.Sigmoid()nn.MSELoss()

  • 使用 type() 函数查看实例化对象,可以得到 <class 'torch.nn.modules....'> 的信息

  • 容器类(nn.modules.container)和模型类(nn.modules.module.Module)也是 nn.modules 类,这里为了方便解释,所以做这样的划分。

学习参数类,nn.parameter

  • 这里为了与函数或实例的传入参数区分,加上了限定词学习,即指网络模型需要学习(即,训练,或优化)的参数

  • 一般为 nn.Parameter() 类初始化的实例

  • 使用 type() 函数查看实例化对象,可以得到 <class 'torch.nn.parameter.Parameter'> 的信息

模型类,继承 nn.Module 类自行构建的神经网络模型,具体为 nn.modules.module.Module

0.1.2 获取模型的学习参数值

学习参数类,nn.parameter 类:

  • 直接通过 .data 属性(或 .detach() 方法),得到学习参数值

  • 返回 tensor 数据类型,且 autograd 启动

模块类,nn.modules 类:

  • 通过特定的属性,返回 nn.parameter 类,再学习参数类的方法(即 .data 属性)得到

  • 如:对于 nn.Linear() 层,包括 2 种属性:.weight.bias

  • 可以在对应官方说明文档中 "Variables" 一栏中找到

  • 如果不包含参数的模块类,则没有这些属性

容器类,nn.modules.container 类:

  • 先通过索引先获取模块类,再按照模块类的方法

模型类:

  • 通过 .modules() 方法得到一个迭代器,获得模型中的所有模块类(也可能有容器类)

    • .modules()方法的第一个模型本身。对于嵌套的模块(即容器类),会以此往其子模块遍历

0.2 模块类与函数

对于一些函数层(如:激活函数层,维度和尺寸改变层,损失函数层),可以通过 torch 模块中或 torch.nn.functional 模块中的函数实现;不需要先定义层,直接当做函数使用。

如:

  • 激活函数层 nn.Sigmoid()

    • torch.nn.functional.sigmoid()torch.sigmoid(),但是 PyTorch 官方推荐使用后者
  • 展平层 torch.nn.Flatten()

    • torch.flatten()
  • 损失函数层 nn.MSELoss():

    • nn.functional.mse_loss()

实例:以下两种使用激活函数的方法等价

import torch.nn as nn
import torch.nn.functional as F
input = torch.rand((3, 5))
# 方法 1:
layer = nn.Sigmoid()
output1 = layer(input)
# 方法 2:
output2 = F.sigmoid(input)
# 方法 3:
output3 = torch.sigmoid(input)

# 输出
print(ouput1.size())
print(torch.sum(torch.abs(output2 - output1)))

0.3 引入与使用

0.3.1 引入

一般使用如下的代码引入 torch.nn 模块和 torch.nn.functional 函数模块

import torch.nn as nn
import torch.nn.functional as F

0.3.2 使用

torch.nn 模块中类(即,类)的使用:先定义,再输入

layer = nn.Linear(params)        # 先定义,设置层的参数
# 一般在 nn.Module 模型的初始化函数 __init__() 中定义
output_data = layer(input_data)  # 输入数据,得到输出数据
# 一般在函数 forward() 中定义

1. 通用层和容器层(Containers)

1.1 torch.nn.Parameter()

nn.Parameter():用于 wrap 待优化的模型参数 Tensor

参数

  • data
  • requires_grad:默认 True

常用属性方法

  • .data 属性:返回 nn.Parameter() wrap 的 tensor 的数值;返回的 tensor 与原 tensor 共用内存

  • .grad 属性:返回 nn.Parameter() wrap 的 tensor 的梯度值;返回的 tensor 与原 tensor 共用内存

  • .detach() 方法:返回一个新的 tensor,将其从计算图分离

实例:

w = nn.Parameter(torch.Tensor([1., 1.,])) 
print(w.detach())
print(w.data)
print(w.grad)

类似的函数有 nn.ParameterList()nn.ParameterDict()

1.2 torch.nn.Sequential()

将多个网络层,按照先后顺序串联起来

# 构建输入层为5,隐含层为3,输出层为1,激活函数为sigmoid的神经网络
model = nn.Sequential(
    nn.Linear(5, 3), 
    nn.Sigmoid(),
    nn.Linear(3, 1), 
    nn.Sigmoid(),
    nn.Flatten(0, -1)
)

1.3 torch.nn.Module()

torch.nn.Module() :用来管理模型待优化的参数。具体介绍参考这篇博客:PyTorch torch.nn.Module 类 构建模型

实例: nn.Module() 的基本框架

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
		# 放入需要学习的参数
    # 正向传播
    def forward(self, x):
        return y
    
    # 损失函数
    def loss_func(self, y_pred, y_true):
        return loss
    
    # 评估函数(准确率)
    def metric_func(self, y_pred, y_true):
        return metric
    
    # 优化器
    @property
    def optimizer(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

2. 线性层

2.1 torch.nn.Linear()

torch.nn.Linear():对输入数据最后的一个维度上实现 \(y = xA^{\top} + b\)

  • 主要参数:

    • in_features:输入维度
    • out_features:输出维度
    • bias:bool 类型,默认为 True,表示是否使用偏置
  • 输入与输出:

    • 输入\((*, H_{\text{in}})\)
    • 输出\((*, H_{\text{out}})\)
  • 训练参数:

    • 训练参数的数量为:\(H_{\text{in}} \cdot H_{\text{out}} + H_{\text{out}}\)

例子

layer = nn.Linear(20, 30)  # 输入维度为 30,输出维度为 20
input = torch.randn(128, 20)  # 输入数据的尺寸 (128,20)
output = layer(input)
print(output.size())  # 输出数据的尺寸 (128,30)

2.2 torch.nn.LazyLinear()

torch.nn.Linear(),与 Linear() 层的功能一样,只不过不需要输入维度参数 in_features。输入维度的参数由输入数据的最后一个维度推测的得到,即 input.shape[-1]

3. 稀疏层 Sparse Layers

3.1 torch.nn.Embedding()

torch.nn.Embedding:用于词向量的嵌入,参见 site

4. 其他

  • 激活函数层:参见博文:PyTorch 激活函数,site

  • 损失函数层:参见博文:PyTorch 损失函数,site

参考资料

文中代码:Colab, Github

posted @ 2022-05-20 12:24  veager  阅读(1982)  评论(0编辑  收藏  举报