代码学习

看到一部分这样的代码
# build models
    model = build_model(args)
    ema_model = build_model(args, ema=True)

  通过百度发现,似乎是一种叫mean teacher的半监督处理方法:(参考自

第一个model被称为student model

第二个ema_model被称为teacher model(EMA即Exponential Moving Average,指数移动平均)
  在半监督中,每个输入Batch包含一半已标注的图像与一般未标注的图像。首先,整个Batch会被送入Student Model中,得到一个预测结果。对于Batch中的已标注部分,利用结果与真值计算loss,进行梯度反向传播,从而更新Student Model的参数。而对于Batch中的未标注部分,其输入Student Model也会得到一个结果(记为A),未标注的图像在加入随机噪声后,也会被送入Teacher Model中,得到一个预测结果(记为B):

  我们希望A==B,这样的话说明模型的参数比较鲁棒泛化。

 

 

 

在TRSSL的代码中,我想加入一个磁瓦的数据集,把路径加进去之后发现bug频出,然后发现args是init里的形参,但是他在get_data里调用了.....

lass tinyimagenet_dataset():
    def __init__(self, args):
        # augmentations
        self.transform_train = transforms.Compose([
            transforms.RandomChoice([
                    transforms.RandomCrop(64, padding=8),
                    transforms.RandomResizedCrop(64, (0.5, 1.0)),
                ]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.5),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(tinyimagenet_mean, tinyimagenet_std),
        ])

        self.transform_val = transforms.Compose([
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize(mean=tinyimagenet_mean, std=tinyimagenet_std)
        ])

        base_dataset = datasets.ImageFolder(os.path.join(args.data_root, 'train'))
        base_dataset_targets = np.array(base_dataset.imgs)
        base_dataset_targets = base_dataset_targets[:,1]
        base_dataset_targets= list(map(int, base_dataset_targets.tolist()))
        train_labeled_idxs, train_unlabeled_idxs = x_u_split_seen_novel(base_dataset_targets, args.lbl_percent, args.no_class, list(range(0,args.no_seen)), list(range(args.no_seen, args.no_class)), args.imb_factor)

        self.train_labeled_idxs = train_labeled_idxs
        self.train_unlabeled_idxs = train_unlabeled_idxs
        self.temperature = args.temperature
        self.data_root = args.data_root
        self.no_seen = args.no_seen
        self.no_class = args.no_class

    def get_dataset(self, temp_uncr=None):
        train_labeled_idxs = self.train_labeled_idxs.copy()
        train_unlabeled_idxs = self.train_unlabeled_idxs.copy()

        train_labeled_dataset = GenericSSL(os.path.join(args.data_root, 'train'), train_labeled_idxs, transform=self.transform_train, temperature=self.temperature)
        train_unlabeled_dataset = GenericSSL(os.path.join(args.data_root, 'train'), train_unlabeled_idxs, transform=TransformTwice(self.transform_train), temperature=self.temperature, temp_uncr=temp_uncr)

        if temp_uncr is not None:
            return train_labeled_dataset, train_unlabeled_dataset

        train_uncr_dataset = GenericUNCR(os.path.join(args.data_root, 'train'), train_unlabeled_idxs, transform=self.transform_train)
        test_dataset_seen = GenericTEST(os.path.join(args.data_root, 'test'), no_class=args.no_class, transform=self.transform_val, labeled_set=list(range(0,args.no_seen)))
        test_dataset_novel = GenericTEST(os.path.join(args.data_root, 'test'), no_class=args.no_class, transform=self.transform_val, labeled_set=list(range(args.no_seen, args.no_class)))
        test_dataset_all = GenericTEST(os.path.join(args.data_root, 'test'), no_class=args.no_class, transform=self.transform_val)
        return train_labeled_dataset, train_unlabeled_dataset, train_uncr_dataset, test_dataset_all, test_dataset_seen, test_dataset_novel

 

posted @ 2023-04-07 21:08  浪矢-CL  阅读(15)  评论(0编辑  收藏  举报