Deep Learning (Using torchvision Pretrained Weights)

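In deep learning, many tasks do not need to train the backbone from scratch. Starting from pretrained weights usually speeds up convergence and often gives better accuracy, especially when the target dataset is small.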

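The example below uses torchvision's efficientnet_b0. By default, the network downsamples the input to a feature map at 1/32 of its resolution before the classifier. Here the network is truncated after the stages that produce a 1/16 feature map, and new layers can be appended afterwards as the task requires. The commented-out lines show the equivalent truncation for mobilenet_v3_small and resnet18.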

```python
import torch
from torchvision.models import (
    mobilenet_v3_small, MobileNet_V3_Small_Weights,
    resnet18, ResNet18_Weights,
    efficientnet_b0, EfficientNet_B0_Weights,
)


# Extract features: a pretrained backbone truncated at 1/16 of the input resolution
class FeatureExtractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)

        # MobileNetV3-Small: drop the last 4 children of model.features -> H/16, W/16
        # model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
        # self.features = model.features[:-4]

        # ResNet-18: keep conv1 ... layer3 (drop layer4, avgpool, fc) -> H/16, W/16
        # model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # self.features = torch.nn.Sequential(*list(model.children())[:-3])

        # EfficientNet-B0: drop features[6:] (two stride-32 stages + final 1x1 conv) -> H/16, W/16
        model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        self.features = model.features[:-3]

    def forward(self, x):
        return self.features(x)


x = torch.randn(1, 3, 224, 224)
model = FeatureExtractor()
model.eval()  # evaluation mode: BatchNorm uses running statistics, stochastic depth is off
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 112, 14, 14])

# Export the truncated backbone to ONNX
torch.onnx.export(model, x, "efficientnet_b0_1_16.onnx", export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'], output_names=['output'])
```

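The table below lists the classification weights available in torchvision and can serve as a reference when choosing a backbone. Acc@1 and Acc@5 are single-crop top-1 and top-5 accuracies on ImageNet-1K. Params is the parameter count, and GFLOPS is the compute per image at each model's inference resolution.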

| Weights | Acc@1 | Acc@5 | Params | GFLOPS |
|---|---|---|---|---|
| AlexNet_Weights.IMAGENET1K_V1 | 56.522 | 79.066 | 61.1M | 0.71 |
| ConvNeXt_Base_Weights.IMAGENET1K_V1 | 84.062 | 96.87 | 88.6M | 15.36 |
| ConvNeXt_Large_Weights.IMAGENET1K_V1 | 84.414 | 96.976 | 197.8M | 34.36 |
| ConvNeXt_Small_Weights.IMAGENET1K_V1 | 83.616 | 96.65 | 50.2M | 8.68 |
| ConvNeXt_Tiny_Weights.IMAGENET1K_V1 | 82.52 | 96.146 | 28.6M | 4.46 |
| DenseNet121_Weights.IMAGENET1K_V1 | 74.434 | 91.972 | 8.0M | 2.83 |
| DenseNet161_Weights.IMAGENET1K_V1 | 77.138 | 93.56 | 28.7M | 7.73 |
| DenseNet169_Weights.IMAGENET1K_V1 | 75.6 | 92.806 | 14.1M | 3.36 |
| DenseNet201_Weights.IMAGENET1K_V1 | 76.896 | 93.37 | 20.0M | 4.29 |
| EfficientNet_B0_Weights.IMAGENET1K_V1 | 77.692 | 93.532 | 5.3M | 0.39 |
| EfficientNet_B1_Weights.IMAGENET1K_V1 | 78.642 | 94.186 | 7.8M | 0.69 |
| EfficientNet_B1_Weights.IMAGENET1K_V2 | 79.838 | 94.934 | 7.8M | 0.69 |
| EfficientNet_B2_Weights.IMAGENET1K_V1 | 80.608 | 95.31 | 9.1M | 1.09 |
| EfficientNet_B3_Weights.IMAGENET1K_V1 | 82.008 | 96.054 | 12.2M | 1.83 |
| EfficientNet_B4_Weights.IMAGENET1K_V1 | 83.384 | 96.594 | 19.3M | 4.39 |
| EfficientNet_B5_Weights.IMAGENET1K_V1 | 83.444 | 96.628 | 30.4M | 10.27 |
| EfficientNet_B6_Weights.IMAGENET1K_V1 | 84.008 | 96.916 | 43.0M | 19.07 |
| EfficientNet_B7_Weights.IMAGENET1K_V1 | 84.122 | 96.908 | 66.3M | 37.75 |
| EfficientNet_V2_L_Weights.IMAGENET1K_V1 | 85.808 | 97.788 | 118.5M | 56.08 |
| EfficientNet_V2_M_Weights.IMAGENET1K_V1 | 85.112 | 97.156 | 54.1M | 24.58 |
| EfficientNet_V2_S_Weights.IMAGENET1K_V1 | 84.228 | 96.878 | 21.5M | 8.37 |
| GoogLeNet_Weights.IMAGENET1K_V1 | 69.778 | 89.53 | 6.6M | 1.5 |
| Inception_V3_Weights.IMAGENET1K_V1 | 77.294 | 93.45 | 27.2M | 5.71 |
| MNASNet0_5_Weights.IMAGENET1K_V1 | 67.734 | 87.49 | 2.2M | 0.1 |
| MNASNet0_75_Weights.IMAGENET1K_V1 | 71.18 | 90.496 | 3.2M | 0.21 |
| MNASNet1_0_Weights.IMAGENET1K_V1 | 73.456 | 91.51 | 4.4M | 0.31 |
| MNASNet1_3_Weights.IMAGENET1K_V1 | 76.506 | 93.522 | 6.3M | 0.53 |
| MaxVit_T_Weights.IMAGENET1K_V1 | 83.7 | 96.722 | 30.9M | 5.56 |
| MobileNet_V2_Weights.IMAGENET1K_V1 | 71.878 | 90.286 | 3.5M | 0.3 |
| MobileNet_V2_Weights.IMAGENET1K_V2 | 72.154 | 90.822 | 3.5M | 0.3 |
| MobileNet_V3_Large_Weights.IMAGENET1K_V1 | 74.042 | 91.34 | 5.5M | 0.22 |
| MobileNet_V3_Large_Weights.IMAGENET1K_V2 | 75.274 | 92.566 | 5.5M | 0.22 |
| MobileNet_V3_Small_Weights.IMAGENET1K_V1 | 67.668 | 87.402 | 2.5M | 0.06 |
| RegNet_X_16GF_Weights.IMAGENET1K_V1 | 80.058 | 94.944 | 54.3M | 15.94 |
| RegNet_X_16GF_Weights.IMAGENET1K_V2 | 82.716 | 96.196 | 54.3M | 15.94 |
| RegNet_X_1_6GF_Weights.IMAGENET1K_V1 | 77.04 | 93.44 | 9.2M | 1.6 |
| RegNet_X_1_6GF_Weights.IMAGENET1K_V2 | 79.668 | 94.922 | 9.2M | 1.6 |
| RegNet_X_32GF_Weights.IMAGENET1K_V1 | 80.622 | 95.248 | 107.8M | 31.74 |
| RegNet_X_32GF_Weights.IMAGENET1K_V2 | 83.014 | 96.288 | 107.8M | 31.74 |
| RegNet_X_3_2GF_Weights.IMAGENET1K_V1 | 78.364 | 93.992 | 15.3M | 3.18 |
| RegNet_X_3_2GF_Weights.IMAGENET1K_V2 | 81.196 | 95.43 | 15.3M | 3.18 |
| RegNet_X_400MF_Weights.IMAGENET1K_V1 | 72.834 | 90.95 | 5.5M | 0.41 |
| RegNet_X_400MF_Weights.IMAGENET1K_V2 | 74.864 | 92.322 | 5.5M | 0.41 |
| RegNet_X_800MF_Weights.IMAGENET1K_V1 | 75.212 | 92.348 | 7.3M | 0.8 |
| RegNet_X_800MF_Weights.IMAGENET1K_V2 | 77.522 | 93.826 | 7.3M | 0.8 |
| RegNet_X_8GF_Weights.IMAGENET1K_V1 | 79.344 | 94.686 | 39.6M | 8 |
| RegNet_X_8GF_Weights.IMAGENET1K_V2 | 81.682 | 95.678 | 39.6M | 8 |
| RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1 | 88.228 | 98.682 | 644.8M | 374.57 |
| RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 86.068 | 97.844 | 644.8M | 127.52 |
| RegNet_Y_16GF_Weights.IMAGENET1K_V1 | 80.424 | 95.24 | 83.6M | 15.91 |
| RegNet_Y_16GF_Weights.IMAGENET1K_V2 | 82.886 | 96.328 | 83.6M | 15.91 |
| RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1 | 86.012 | 98.054 | 83.6M | 46.73 |
| RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 83.976 | 97.244 | 83.6M | 15.91 |
| RegNet_Y_1_6GF_Weights.IMAGENET1K_V1 | 77.95 | 93.966 | 11.2M | 1.61 |
| RegNet_Y_1_6GF_Weights.IMAGENET1K_V2 | 80.876 | 95.444 | 11.2M | 1.61 |
| RegNet_Y_32GF_Weights.IMAGENET1K_V1 | 80.878 | 95.34 | 145.0M | 32.28 |
| RegNet_Y_32GF_Weights.IMAGENET1K_V2 | 83.368 | 96.498 | 145.0M | 32.28 |
| RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1 | 86.838 | 98.362 | 145.0M | 94.83 |
| RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 84.622 | 97.48 | 145.0M | 32.28 |
| RegNet_Y_3_2GF_Weights.IMAGENET1K_V1 | 78.948 | 94.576 | 19.4M | 3.18 |
| RegNet_Y_3_2GF_Weights.IMAGENET1K_V2 | 81.982 | 95.972 | 19.4M | 3.18 |
| RegNet_Y_400MF_Weights.IMAGENET1K_V1 | 74.046 | 91.716 | 4.3M | 0.4 |
| RegNet_Y_400MF_Weights.IMAGENET1K_V2 | 75.804 | 92.742 | 4.3M | 0.4 |
| RegNet_Y_800MF_Weights.IMAGENET1K_V1 | 76.42 | 93.136 | 6.4M | 0.83 |
| RegNet_Y_800MF_Weights.IMAGENET1K_V2 | 78.828 | 94.502 | 6.4M | 0.83 |
| RegNet_Y_8GF_Weights.IMAGENET1K_V1 | 80.032 | 95.048 | 39.4M | 8.47 |
| RegNet_Y_8GF_Weights.IMAGENET1K_V2 | 82.828 | 96.33 | 39.4M | 8.47 |
| ResNeXt101_32X8D_Weights.IMAGENET1K_V1 | 79.312 | 94.526 | 88.8M | 16.41 |
| ResNeXt101_32X8D_Weights.IMAGENET1K_V2 | 82.834 | 96.228 | 88.8M | 16.41 |
| ResNeXt101_64X4D_Weights.IMAGENET1K_V1 | 83.246 | 96.454 | 83.5M | 15.46 |
| ResNeXt50_32X4D_Weights.IMAGENET1K_V1 | 77.618 | 93.698 | 25.0M | 4.23 |
| ResNeXt50_32X4D_Weights.IMAGENET1K_V2 | 81.198 | 95.34 | 25.0M | 4.23 |
| ResNet101_Weights.IMAGENET1K_V1 | 77.374 | 93.546 | 44.5M | 7.8 |
| ResNet101_Weights.IMAGENET1K_V2 | 81.886 | 95.78 | 44.5M | 7.8 |
| ResNet152_Weights.IMAGENET1K_V1 | 78.312 | 94.046 | 60.2M | 11.51 |
| ResNet152_Weights.IMAGENET1K_V2 | 82.284 | 96.002 | 60.2M | 11.51 |
| ResNet18_Weights.IMAGENET1K_V1 | 69.758 | 89.078 | 11.7M | 1.81 |
| ResNet34_Weights.IMAGENET1K_V1 | 73.314 | 91.42 | 21.8M | 3.66 |
| ResNet50_Weights.IMAGENET1K_V1 | 76.13 | 92.862 | 25.6M | 4.09 |
| ResNet50_Weights.IMAGENET1K_V2 | 80.858 | 95.434 | 25.6M | 4.09 |
| ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1 | 60.552 | 81.746 | 1.4M | 0.04 |
| ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 | 69.362 | 88.316 | 2.3M | 0.14 |
| ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1 | 72.996 | 91.086 | 3.5M | 0.3 |
| ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1 | 76.23 | 93.006 | 7.4M | 0.58 |
| SqueezeNet1_0_Weights.IMAGENET1K_V1 | 58.092 | 80.42 | 1.2M | 0.82 |
| SqueezeNet1_1_Weights.IMAGENET1K_V1 | 58.178 | 80.624 | 1.2M | 0.35 |
| Swin_B_Weights.IMAGENET1K_V1 | 83.582 | 96.64 | 87.8M | 15.43 |
| Swin_S_Weights.IMAGENET1K_V1 | 83.196 | 96.36 | 49.6M | 8.74 |
| Swin_T_Weights.IMAGENET1K_V1 | 81.474 | 95.776 | 28.3M | 4.49 |
| Swin_V2_B_Weights.IMAGENET1K_V1 | 84.112 | 96.864 | 87.9M | 20.32 |
| Swin_V2_S_Weights.IMAGENET1K_V1 | 83.712 | 96.816 | 49.7M | 11.55 |
| Swin_V2_T_Weights.IMAGENET1K_V1 | 82.072 | 96.132 | 28.4M | 5.94 |
| VGG11_BN_Weights.IMAGENET1K_V1 | 70.37 | 89.81 | 132.9M | 7.61 |
| VGG11_Weights.IMAGENET1K_V1 | 69.02 | 88.628 | 132.9M | 7.61 |
| VGG13_BN_Weights.IMAGENET1K_V1 | 71.586 | 90.374 | 133.1M | 11.31 |
| VGG13_Weights.IMAGENET1K_V1 | 69.928 | 89.246 | 133.0M | 11.31 |
| VGG16_BN_Weights.IMAGENET1K_V1 | 73.36 | 91.516 | 138.4M | 15.47 |
| VGG16_Weights.IMAGENET1K_V1 | 71.592 | 90.382 | 138.4M | 15.47 |
| VGG16_Weights.IMAGENET1K_FEATURES | N/A | N/A | 138.4M | 15.47 |
| VGG19_BN_Weights.IMAGENET1K_V1 | 74.218 | 91.842 | 143.7M | 19.63 |
| VGG19_Weights.IMAGENET1K_V1 | 72.376 | 90.876 | 143.7M | 19.63 |
| ViT_B_16_Weights.IMAGENET1K_V1 | 81.072 | 95.318 | 86.6M | 17.56 |
| ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 | 85.304 | 97.65 | 86.9M | 55.48 |
| ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 81.886 | 96.18 | 86.6M | 17.56 |
| ViT_B_32_Weights.IMAGENET1K_V1 | 75.912 | 92.466 | 88.2M | 4.41 |
| ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 | 88.552 | 98.694 | 633.5M | 1016.72 |
| ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 85.708 | 97.73 | 632.0M | 167.29 |
| ViT_L_16_Weights.IMAGENET1K_V1 | 79.662 | 94.638 | 304.3M | 61.55 |
| ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1 | 88.064 | 98.512 | 305.2M | 361.99 |
| ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 85.146 | 97.422 | 304.3M | 61.55 |
| ViT_L_32_Weights.IMAGENET1K_V1 | 76.972 | 93.07 | 306.5M | 15.38 |
| Wide_ResNet101_2_Weights.IMAGENET1K_V1 | 78.848 | 94.284 | 126.9M | 22.75 |
| Wide_ResNet101_2_Weights.IMAGENET1K_V2 | 82.51 | 96.02 | 126.9M | 22.75 |
| Wide_ResNet50_2_Weights.IMAGENET1K_V1 | 78.468 | 94.086 | 68.9M | 11.4 |
| Wide_ResNet50_2_Weights.IMAGENET1K_V2 | 81.602 | 95.758 | 68.9M | 11.4 |

N/A: `VGG16_Weights.IMAGENET1K_FEATURES` provides valid values only for the `features` module and is intended for feature extraction. Its classifier is not trained, so no accuracy is reported.

Reference: Models and pre-trained weights, Torchvision 0.22 documentation (PyTorch)

posted @ 2025-06-23 20:40  Dsp Tian