pytorch模型层的增、删、改

print打印下模型,可以看到各层的名字(小括号里的是层名)

如果层名是数字,写成 model.model[0].act 来定位到act层

import torch
import ultralytics
from ultralytics import YOLO

yolo = YOLO("yolo11n-cls.pt")
print(yolo)  #小括号里的是层名

,删除act层

del yolo.model.model[0].act #删除act层
print(yolo) #再查看下是否删除了

 

 改,改linear层

将输出由1000类改为6类

in_features=yolo.model.model[10].linear.in_features
yolo.model.model[10].linear=nn.Linear(in_features,out_features=6)
# in_features=yolo.get_submodule("model.model.10.linear").in_features
# yolo.set_submodule("model.model.10.linear",nn.Linear(in_features,out_features=6))
print(yolo)

两种方式:直接 yolo.model.model[10].linear,或者用set_submodule()方法(10如果在字符串里就直接写10即可)

 

 

怎么改某个class的forward???

YOLOv5改进 | Head | 将yolov5的检测头替换为ASFF_Detect_yolov5head改进-CSDN博客

CV党福音:YOLOv8实现分类_51CTO博客_yolov1实现

posted @ 2024-12-09 15:44  夕西行  阅读(122)  评论(0)    收藏  举报