第五节下,图像分类半监督

class semiDataset(Dataset):
    """Pseudo-labelled dataset built from a model's confident predictions.

    The model is run over ``no_label_loder``; every sample whose top softmax
    probability exceeds ``thres`` is kept, with the predicted class as label.

    Attributes:
        flag: True when at least one sample passed the threshold.
        X: numpy array of the kept raw samples (empty when ``flag`` is False).
        Y: ``torch.LongTensor`` of predicted labels (empty when ``flag`` is
            False).
        transform: ``train_transform`` (module-level name defined elsewhere)
            when samples were kept, otherwise None.
    """

    def __init__(self, no_label_loder, model, device, thres=0.99):
        x, y = self.get_label(no_label_loder, model, device, thres)
        self.flag = bool(x)
        if self.flag:
            self.X = np.array(x)
            self.Y = torch.LongTensor(y)
            self.transform = train_transform
        else:
            # Keep the attributes defined so len() on an unused dataset
            # returns 0 instead of raising AttributeError.
            self.X = np.array([])
            self.Y = torch.LongTensor([])
            self.transform = None

    def get_label(self, no_label_loder, model, device, thres):
        """Return ``(x, y)`` lists of confidently predicted samples.

        Args:
            no_label_loder: iterable of 2-item batches whose first item is the
                model input; must expose ``.dataset`` where ``dataset[i][1]``
                is the raw (untransformed) sample.
            model: module producing per-class scores of shape (batch, classes).
            device: device the model and each batch are moved to.
            thres: a sample is kept only if its top probability is > thres.

        Returns:
            x: list of raw samples taken from ``no_label_loder.dataset``.
            y: list of predicted class indices (ints), aligned with ``x``.
        """
        model = model.to(device)
        # Explicit dim: class scores live on dim 1, matching ``.max(1)`` below.
        soft = nn.Softmax(dim=1)
        pred_prob = []
        labels = []
        with torch.no_grad():
            # Each batch is (transformed input, raw input); only the
            # transformed input is fed to the model.
            for bat_x, _ in no_label_loder:
                bat_x = bat_x.to(device)
                pred_soft = soft(model(bat_x))
                pred_max, pred_value = pred_soft.max(1)
                # extend (not append) so the per-batch lists are flattened.
                pred_prob.extend(pred_max.cpu().numpy().tolist())
                labels.extend(pred_value.cpu().numpy().tolist())

        x = []
        y = []
        # NOTE: the position in pred_prob is reused as a dataset index, which
        # only lines up if the loader walks the dataset in order (no shuffle).
        for index, prob in enumerate(pred_prob):
            if prob > thres:
                # Index the underlying dataset to fetch the raw sample.
                x.append(no_label_loder.dataset[index][1])
                y.append(labels[index])
        return x, y

    def __getitem__(self, item):
        """Return ``(transformed sample, label)`` for position ``item``."""
        return self.transform(self.X[item]), self.Y[item]

    def __len__(self):
        """Return the number of pseudo-labelled samples (0 when none kept)."""
        return len(self.X)
class noLabDataset(Dataset):
    """Pseudo-labelled dataset built from confident predictions on a loader.

    The model is run over ``dataloader``; every sample whose top softmax
    probability exceeds ``thres`` is kept, with the predicted class as label.

    Attributes:
        flag: True when at least one sample passed the threshold; callers
            should skip this dataset when it is False.
        x: numpy array of the kept raw samples (empty when ``flag`` is False).
        y: ``torch.LongTensor`` of predicted labels (empty when ``flag`` is
            False).
        transformers: ``train_transform`` (module-level name defined
            elsewhere), applied to each sample in ``__getitem__``.
    """

    def __init__(self, dataloader, model, device, thres=0.85):
        super().__init__()
        self.model = model    # used as-is; it is NOT moved to ``device`` here
        self.device = device
        self.thres = thres    # confidence threshold (default 0.85)
        x, y = self._model_pred(dataloader)    # collect new training samples
        self._store_samples(x, y)
        self.transformers = train_transform

    def _store_samples(self, x, y):
        """Store selected samples/labels and set ``flag`` accordingly."""
        # x and y are always defined, even when empty, so len() on an unused
        # dataset returns 0 instead of raising AttributeError.
        self.flag = bool(x)
        self.x = np.array(x)
        self.y = torch.LongTensor(y)

    def _model_pred(self, dataloader):
        """Return ``(x, y)`` lists of confidently predicted samples.

        Args:
            dataloader: iterable of batches whose item 0 is the model input;
                must expose ``.dataset`` where ``dataset[i][1]`` is the sample
                to keep.

        Returns:
            x: list of samples taken from ``dataloader.dataset``.
            y: list of predicted class indices (ints), aligned with ``x``.
        """
        model = self.model
        device = self.device
        thres = self.thres
        # Softmax over dim 1 turns the scores into a probability distribution;
        # built once rather than on every batch.
        soft = torch.nn.Softmax(dim=1)
        pred_probs = []
        labels = []
        with torch.no_grad():    # inference only, no gradients needed
            for data in dataloader:
                imgs = data[0].to(device)
                pred_p = soft(model(imgs))
                # max over dim 1 -> (confidence, predicted label)
                pred_max, preds = pred_p.max(1)
                pred_probs.extend(pred_max.cpu().numpy().tolist())
                labels.extend(preds.cpu().numpy().tolist())

        x = []
        y = []
        # NOTE: the position in pred_probs is reused as a dataset index, which
        # only lines up if the loader walks the dataset in order (no shuffle).
        for index, prob in enumerate(pred_probs):
            if prob > thres:    # confident enough to become training data
                x.append(dataloader.dataset[index][1])
                y.append(labels[index])
        return x, y

    def __getitem__(self, index):
        """Return ``(transformed sample, label)`` for position ``index``."""
        x = self.transformers(self.x[index])
        y = self.y[index]
        return x, y

    def __len__(self):
        """Return the number of pseudo-labelled samples (0 when none kept)."""
        return len(self.x)

# Loader for the 'train_unl' split, built by getDataLoader (defined elsewhere).
# Presumably this is the unlabeled training data used for pseudo-labelling -
# confirm against getDataLoader.
no_label_Loader = getDataLoader(filepath,'train_unl', batchSize)

def get_semi_loader(dataloader, model, device, thres):
    """Build a shuffled loader over pseudo-labelled samples, or return None.

    A ``noLabDataset`` is created from ``dataloader`` using ``model``,
    ``device`` and ``thres``. If its ``flag`` is falsy (no usable samples),
    None is returned; otherwise the dataset is wrapped in a ``DataLoader``
    that reuses the batch size of the input loader.
    """
    pseudo_set = noLabDataset(dataloader, model, device, thres)
    if not pseudo_set.flag:
        # Nothing passed the confidence threshold: signal "no loader".
        return None
    return DataLoader(
        pseudo_set,
        batch_size=dataloader.batch_size,
        shuffle=True,
    )
posted @ 2025-02-01 11:20  JYP0222  阅读(26)  评论(0)    收藏  举报