# PyTorch: loading a pretrained model that ships with PyTorch (torchvision)

import torch.utils.model_zoo
from torch import nn
from torchvision import models


class PoseEstimationWithMobileNet3(nn.Module):
    """Pose-estimation network skeleton backed by a truncated MobileNetV3-Small.

    Only the backbone is constructed here: the first four blocks of the
    ``features`` stack of torchvision's pretrained ``mobilenet_v3_small``.
    No ``forward`` method or prediction head is defined in this class body.
    """

    def __init__(
        self,
        num_refinement_stages=1,
        num_channels=128,
        num_heatmaps=19,
        num_pafs=38,
    ):
        """Create the pretrained, truncated backbone.

        Args:
            num_refinement_stages: Accepted for interface compatibility; not
                used by the code in this class (presumably meant for
                refinement stages added later -- TODO confirm).
            num_channels: Accepted but unused here (presumably the feature
                width the backbone output should be projected to).
            num_heatmaps: Accepted but unused here (presumably the number of
                keypoint heatmap channels).
            num_pafs: Accepted but unused here (presumably the number of
                part-affinity-field channels).
        """
        super().__init__()

        # NOTE: `pretrained=True` fetches the published weights on first use.
        # torchvision >= 0.13 deprecates it in favour of `weights=...`; it is
        # kept unchanged so the snippet still runs on older releases.
        pretrained = models.mobilenet_v3_small(pretrained=True)

        # Keep only features[0:4]. Original author's note: the output of this
        # slice has 24 channels, so a 24 -> 128 pointwise (1x1) convolution
        # should be attached after it.
        self.backbone = nn.Sequential(*list(pretrained.features.children())[0:4])

posted @ 2022-05-28 19:36  祥瑞哈哈哈  阅读(137)  评论(0)    收藏  举报