【pytorch】关于深度学习模型是怎么使数据从头流动到尾的
问题描述
之前在看cycleGAN的代码时想到一个问题
代码里用类的方式定义cycleGAN模型,各个模块是以一个列表里的变量存在的
那么模型在进行forward时,是怎么知道各个模块之间的顺序的?或者说是怎么控制张量正确地从头走到尾的?
cycleGAN的方式
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
回头重新看了一下cycleGAN的代码,发现这里是用nn.Sequential方法把添加好的模块串起来的
因为cycleGAN的生成器结构比较简单,各个模块之间直接头尾相接就可以了
nn.Sequential方法
当一个模型较简单的时候,我们可以使用torch.nn.Sequential类来实现简单的顺序连接模型。这个模型也是继承自Module类的
参考文献:https://blog.csdn.net/qq_27825451/article/details/90551513
21. 最简单的序贯模型
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential(
(0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(3): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
注意:这样做有一个问题,每一个层是没有名称,默认的是以0、1、2、3来命名,从上面的运行结果也可以看出。
2.2 给每一个层添加名称
import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential(
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(relu2): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
注意:从上面的结果中可以看出,这个时候每一个层都有了自己的名称,但是此时需要注意,我并不能够通过名称直接获取层,依然只能通过索引index,即
model[2] 是正确的
model["conv2"] 是错误的
这其实是由它的定义实现的,看上面的Sequenrial定义可知,只支持index访问。
2.3 Sequential的第三种实现
import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential()
model.add_module("conv1",nn.Conv2d(1,20,5))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(20,64,5))
model.add_module('relu2', nn.ReLU())
print(model)
print(model[2]) # 通过索引获取第几个层
稍复杂的模型的连接方式
以下是yolov10模型的定义代码
class DetectionModel(BaseModel):
"""YOLOv8检测模型。"""
def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # 模型配置、输入通道数、类别数量
"""使用给定的配置和参数初始化YOLOv8检测模型。"""
super().__init__()
# 加载配置:如果cfg是字典则直接使用,否则从yaml文件加载
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # 配置字典
# 定义模型输入通道数
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # 输入通道数(默认从配置获取,否则使用传入的ch)
# 如果指定了类别数且与配置中的不同,则覆盖配置中的类别数
if nc and nc != self.yaml["nc"]:
LOGGER.info(f"覆盖model.yaml中的nc={self.yaml['nc']}为nc={nc}")
self.yaml["nc"] = nc # 覆盖YAML中的类别数
# 解析模型配置,创建模型结构和需要保存的层列表
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # 模型结构和保存列表
# 初始化类别名称字典(默认用索引作为名称)
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # 默认类别名称字典
# 是否启用原地操作(in-place operation),从配置获取(默认True)
self.inplace = self.yaml.get("inplace", True)
# 计算模型的步长(stride)
m = self.model[-1] # 获取检测头(Detect层)
# 判断是否为检测类层(包括Segment、Pose等子类)
if isinstance(m, Detect): # 包括所有Detect子类,如Segment、Pose、OBB、WorldDetect
s = 256 # 最小步长的2倍(用于计算实际步长)
m.inplace = self.inplace # 设置检测头的原地操作属性
# 定义前向传播函数:根据不同检测头类型处理输出
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
# 针对v10Detect类型的特殊处理
if isinstance(m, v10Detect):
forward = lambda x: self.forward(x)["one2many"]
# 计算步长:通过输入一个随机张量,根据输出特征图尺寸反推
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # 前向计算获取步长
self.stride = m.stride # 保存步长到模型属性
m.bias_init() # 初始化偏置(仅运行一次)
else:
self.stride = torch.Tensor([32]) # 其他类型模型的默认步长(如RTDETR)
# 初始化权重和偏置
initialize_weights(self)
# 如果启用verbose模式,打印模型信息
if verbose:
self.info()
LOGGER.info("")
def _predict_augment(self, x):
"""对输入图像x执行数据增强,并返回增强后的推理结果和训练输出。"""
img_size = x.shape[-2:] # 获取图像尺寸(高,宽)
s = [1, 0.83, 0.67] # 缩放比例列表
f = [None, 3, None] # 翻转方式(2-上下翻转,3-左右翻转)
y = [] # 存储增强后的输出结果
# 遍历每个缩放比例和翻转方式
for si, fi in zip(s, f):
# 对图像进行翻转(如果需要)和缩放,保持步长对齐
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
yi = super().predict(xi) # 前向推理获取结果
# 处理不同类型的输出(如yolov10的字典输出)
if isinstance(yi, dict):
yi = yi["one2one"] # yolov10的输出处理
if isinstance(yi, (list, tuple)):
yi = yi[0] # 取列表/元组的第一个元素
# 对预测结果进行逆缩放和逆翻转,恢复到原始图像尺度
yi = self._descale_pred(yi, fi, si, img_size)
y.append(yi) # 保存处理后的结果
y = self._clip_augmented(y) # 裁剪增强结果的冗余部分
return torch.cat(y, -1), None # 返回拼接的增强推理结果和None(训练输出)
@staticmethod
def _descale_pred(p, flips, scale, img_size, dim=1):
"""对增强推理后的预测结果进行逆缩放(增强的逆操作)。"""
p[:, :4] /= scale # 对坐标和宽高进行逆缩放(恢复到原始尺度)
# 将预测结果拆分:x坐标、y坐标、宽高、类别
x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
# 处理上下翻转的逆操作
if flips == 2:
y = img_size[0] - y # 上下翻转恢复:y坐标 = 图像高度 - y
# 处理左右翻转的逆操作
elif flips == 3:
x = img_size[1] - x # 左右翻转恢复:x坐标 = 图像宽度 - x
# 拼接处理后的结果
return torch.cat((x, y, wh, cls), dim)
def _clip_augmented(self, y):
"""裁剪YOLO增强推理结果的冗余尾部(去除增强带来的无效部分)。"""
nl = self.model[-1].nl # 检测层数量(如P3-P5对应3层)
g = sum(4**x for x in range(nl)) # 计算总网格点数量
e = 1 # 要排除的层数量
# 计算大尺度增强结果的裁剪索引
i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # 索引计算
y[0] = y[0][..., :-i] # 裁剪大尺度结果的尾部
# 计算小尺度增强结果的裁剪索引
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # 索引计算
y[-1] = y[-1][..., i:] # 裁剪小尺度结果的头部
return y
def init_criterion(self):
"""初始化DetectionModel的损失函数。"""
return v8DetectionLoss(self) # 返回YOLOv8的检测损失函数实例
在DetectionModel的BaseModel里,预测过程的前向传播是这个函数定义的
def _predict_once(self, x, profile=False, visualize=False, embed=None):
"""
Perform a forward pass through the network.
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): The last output of the model.
"""
y, dt, embeddings = [], [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
# 非顺序结构是在这里实现的
# 通过m.f来判断本模块m的输入数据是否直接来自列表中的上一个模块,-1代表直接上一个,其他参数则代表其他层
# 如果不是的话,则定位到对应的模块,获取其输出
# 把问题简化为:要么直接来自上一个模块,要么来自之前的某个模块
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
if profile:
self._profile_one_layer(m, x, dt)
x = m(x) # run:使用m模块对数据x进行处理
y.append(x if m.i in self.save else None) # save output
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
if embed and m.i in embed:
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
if m.i == max(embed):
return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x

浙公网安备 33010602011771号