# 第五节下：图像分类半监督 (Section 5, part 2: semi-supervised image classification)
class semiDataset(Dataset):
    """Pseudo-labelled dataset built from confident predictions on unlabeled data.

    The model predicts on every batch of ``no_label_loder``. A sample is kept
    when its largest softmax probability is strictly greater than ``thres``;
    its label is the predicted class index.

    Attributes:
        flag: False when no sample passed the threshold (the dataset is then
            empty and ``len(self) == 0``).
        X: numpy array of the kept raw samples.
        Y: LongTensor of the predicted labels.
        transform: transform applied in ``__getitem__``.
    """

    def __init__(self, no_label_loder, model, device, thres=0.99):
        super().__init__()
        x, y = self.get_label(no_label_loder, model, device, thres)
        # The dataset is only worth using if at least one sample was confident enough.
        self.flag = len(x) > 0
        # Always define the storage so len() works (returns 0) even when nothing
        # passed the threshold; previously these attributes were missing then.
        self.X = np.array(x)
        self.Y = torch.LongTensor(y)
        # NOTE(review): `train_transform` is presumably a module-level transform
        # defined elsewhere in this file.
        self.transform = train_transform

    def get_label(self, no_label_loder, model, device, thres):
        """Return ``(x, y)``: raw samples and labels whose confidence exceeds ``thres``.

        Each batch of ``no_label_loder`` is unpacked as ``(transformed_x, raw_x)``.
        The transformed input is fed to the model, and ``no_label_loder.dataset[i][1]``
        (the raw sample) is stored for every kept index ``i``. Indices therefore
        only line up if the loader iterates in dataset order.
        NOTE(review): confirm the loader does not shuffle.

        The model is switched to eval mode for prediction, so dropout and
        batch-norm use inference behaviour. Its previous train/eval mode is
        restored afterwards.
        """
        model = model.to(device)
        was_training = model.training
        model.eval()
        # Explicit dim: softmax over the class axis of (batch, classes) logits.
        soft = nn.Softmax(dim=1)
        pred_prob = []
        labels = []
        try:
            with torch.no_grad():
                for bat_x, _ in no_label_loder:  # (transformed x, raw x)
                    pred_soft = soft(model(bat_x.to(device)))
                    pred_max, pred_value = pred_soft.max(1)  # confidence, class index
                    pred_prob.extend(pred_max.cpu().tolist())
                    labels.extend(pred_value.cpu().tolist())
        finally:
            model.train(was_training)

        x = []
        y = []
        for index, (prob, label) in enumerate(zip(pred_prob, labels)):
            if prob > thres:
                # Fetch the untransformed sample through the dataset's __getitem__.
                x.append(no_label_loder.dataset[index][1])
                y.append(label)
        return x, y

    def __getitem__(self, item):
        return self.transform(self.X[item]), self.Y[item]

    def __len__(self):
        return len(self.X)
class noLabDataset(Dataset):
    """Dataset of unlabeled samples whose model prediction is confident enough.

    Args:
        dataloader: loader whose items are ``(transformed_x, raw_x, ...)``; the
            first element is fed to the model and ``dataset[i][1]`` is kept.
        model: classifier used to produce pseudo-labels.
        device: device the batches are moved to before prediction.
        thres: keep a sample when its largest softmax probability is strictly
            greater than this value.

    Attributes:
        flag: False when no sample passed the threshold (the dataset is then
            empty and ``len(self) == 0``).
    """

    def __init__(self, dataloader, model, device, thres=0.85):
        super(noLabDataset, self).__init__()
        self.model = model  # the model must be passed in to make predictions
        self.device = device
        # Confidence threshold (default 0.85; the original author reports
        # calling this with 0.99).
        self.thres = thres
        x, y = self._model_pred(dataloader)  # core step: build the new training data
        # If nothing qualified, the caller should not use this dataset.
        self.flag = len(x) > 0
        # Always define the storage so len() works (returns 0) even when empty.
        self.x = np.array(x)
        self.y = torch.LongTensor(y)
        # Alternative (disabled): concatenate with the labeled train_dataset here
        # instead of training on a separate pseudo-label loader.
        # NOTE(review): `train_transform` is presumably defined elsewhere in this file.
        self.transformers = train_transform

    def _model_pred(self, dataloader):
        """Return ``(x, y)``: raw samples and predicted labels with confidence > ``self.thres``.

        Predictions are matched back to ``dataloader.dataset`` by position, so the
        loader must iterate in dataset order.
        NOTE(review): confirm the loader does not shuffle.

        The model runs in eval mode, with gradients disabled, and its previous
        train/eval mode is restored afterwards.
        """
        model = self.model
        device = self.device
        thres = self.thres
        soft = torch.nn.Softmax(dim=1)  # built once; turns logits into a probability distribution
        pred_probs = []
        labels = []
        was_training = model.training
        model.eval()  # inference behaviour for dropout / batch-norm
        try:
            with torch.no_grad():  # no training here, so gradients are off
                for data in dataloader:
                    imgs = data[0].to(device)
                    pred_p = soft(model(imgs))
                    # Max value and its position: the confidence and the label.
                    pred_max, preds = pred_p.max(1)
                    pred_probs.extend(pred_max.cpu().tolist())
                    labels.extend(preds.cpu().tolist())
        finally:
            model.train(was_training)

        x = []
        y = []
        for index, (prob, label) in enumerate(zip(pred_probs, labels)):
            if prob > thres:  # confident enough -> becomes trusted training data
                x.append(dataloader.dataset[index][1])
                y.append(label)
        return x, y

    def __getitem__(self, index):
        x = self.transformers(self.x[index])
        y = self.y[index]
        return x, y

    def __len__(self):
        return len(self.x)
# Loader over the 'train_unl' split (presumably the unlabeled training images).
# getDataLoader, filepath and batchSize are defined elsewhere in the file.
# NOTE(review): noLabDataset / semiDataset look raw samples up by position
# (dataloader.dataset[index]), so a loader passed to them must not shuffle --
# confirm getDataLoader's behaviour for 'train_unl'.
no_label_Loader = getDataLoader(filepath,'train_unl', batchSize)
def get_semi_loader(dataloader, model, device, thres):
    """Build a shuffled loader over confidently pseudo-labelled samples.

    A ``noLabDataset`` is created from ``dataloader`` using ``model``, ``device``
    and ``thres``. If no sample passed the threshold, ``None`` is returned.
    Otherwise the result is a ``DataLoader`` over that dataset, with the same
    batch size as ``dataloader`` and ``shuffle=True``.
    """
    pseudo_set = noLabDataset(dataloader, model, device, thres)
    if not pseudo_set.flag:
        # Nothing usable: signal the caller with None.
        return None
    return DataLoader(pseudo_set, batch_size=dataloader.batch_size, shuffle=True)

浙公网安备 33010602011771号