PyTorch: adding L1 regularization on weights to the loss function

L1 regularization drives weights toward sparsity. A typical use case: pruning the vocabulary of a one-hot bag-of-words model by filtering words on their learned weights; the sparser the weights, the cleaner the selection.

L1_Weight is a hyperparameter; 1e-4 is a reasonable starting value.
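Before the full training loop, here is a minimal sketch (using a hypothetical toy `nn.Linear` layer as a stand-in for `model.fc`) of how the L1 penalty term is built and how it flows through autograd: the gradient of the penalty with respect to each weight is `L1_Weight * sign(weight)`, which is what nudges small weights toward exactly zero.

```python
import torch
import torch.nn as nn

L1_Weight = 1e-4  # strength of the L1 penalty

fc = nn.Linear(4, 2)  # hypothetical stand-in for model.fc
l1_penalty = L1_Weight * sum(p.abs().sum() for p in fc.parameters())

# The penalty is a scalar tensor that participates in autograd:
l1_penalty.backward()

# Each weight's gradient is L1_Weight * sign(weight)
print(torch.allclose(fc.weight.grad, L1_Weight * fc.weight.sign()))  # True
```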

import logging

import numpy as np
import torch
from tqdm import tqdm

import utils  # project-local helpers (provides binary_acc)

L1_Weight = 1e-4  # strength of the L1 penalty


def train(model, iterator, optimizer, criteon):
    avg_acc, avg_loss = [], []
    model.train()

    for batch in tqdm(iterator):
        text, label = batch[0].cuda(), batch[1].cuda()

        pred = model(text)
        # L1 penalty: sum of absolute values of the fc layer's parameters
        l1_penalty = L1_Weight * sum(p.abs().sum() for p in model.fc.parameters())
        loss = criteon(pred, label.long())
        loss_with_penalty = loss + l1_penalty

        acc = utils.binary_acc(torch.argmax(pred.cpu(), dim=1), label.cpu().long())
        avg_acc.append(acc)
        avg_loss.append(loss.item())  # log the raw loss, without the penalty

        optimizer.zero_grad()
        loss_with_penalty.backward()  # use loss.backward() instead to train without L1
        optimizer.step()

    avg_acc = np.array(avg_acc).mean()
    avg_loss = np.array(avg_loss).mean()
    train_metrics = {'train_acc': avg_acc,
                     'train_loss': avg_loss
                     }
    logging.info(train_metrics)
    return avg_acc, avg_loss
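The vocabulary-pruning use case mentioned earlier can be sketched as follows (the layer sizes and the `k=5` cutoff are illustrative assumptions, not from the original post): after training with the L1 penalty, rank each vocabulary entry by the magnitude of its fc weights and keep only the strongest entries.

```python
import torch
import torch.nn as nn

vocab_size, num_classes = 10, 2
fc = nn.Linear(vocab_size, num_classes)  # stand-in for a trained model.fc

# Importance of each vocabulary entry: the largest |weight| across classes.
# fc.weight has shape (num_classes, vocab_size).
importance = fc.weight.abs().max(dim=0).values  # shape: (vocab_size,)

# Keep the 5 words with the largest weights; with L1-sparsified weights,
# most entries are near zero and drop out naturally.
keep = torch.topk(importance, k=5).indices
pruned_vocab = sorted(keep.tolist())
print(pruned_vocab)  # the 5 surviving word indices
```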

 

posted @ 2021-06-29 16:48  最咸的鱼