pytorch-day05 (Overfitting)

1、Overfitting and Underfitting

 Underfitting: high error on both the train and test sets (model capacity too low). Overfitting: low train error but high test error (the model has memorized the training set instead of generalizing).

2、How to Detect Overfitting (Train / Val / Test Split, Cross-Validation)

 The "test" set here (really a validation set) is used to select the model's parameters and to terminate training early, preventing overfitting.

 The validation set is used to select model parameters; the test set is the data the client uses for final acceptance testing.

 Split the dataset into a train set, a val set, and a test set.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

batch_size = 200
learning_rate = 0.01
epochs = 10

train_db = datasets.MNIST('../data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))

test_db = datasets.MNIST('../data', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))
                         ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)

print('train:', len(train_db), 'test:', len(test_db))

# Split the 60k training images into 50k for training and 10k for validation.
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))

train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        return self.model(x)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.to(device)

        logits = net(data)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # Evaluate on the VALIDATION set every epoch: this curve is what we use
    # to tune hyperparameters and decide when to stop training.
    net.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.view(-1, 28 * 28)
            data, target = data.to(device), target.to(device)
            logits = net(data)
            val_loss += criterion(logits, target).item()
            pred = logits.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))

# The TEST set is touched exactly once, at the very end (acceptance check).
net.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.to(device)
        logits = net(data)
        test_loss += criterion(logits, target).item()
        pred = logits.argmax(dim=1)
        correct += pred.eq(target).sum().item()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

 k-fold cross validation: the test set stays untouched; each round uses k-1 folds as the train set and the remaining 1 fold as the val set.
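The fold bookkeeping above can be sketched as a plain index split (a minimal sketch; `kfold_split` is a hypothetical helper, and in practice `torch.utils.data.Subset` would turn the index lists into datasets):

```python
import torch

def kfold_split(n_samples, k, fold, seed=0):
    """Split n_samples indices into k folds; return (train_idx, val_idx) for
    one fold. The test set is held out entirely and never enters this split."""
    g = torch.Generator().manual_seed(seed)
    indices = torch.randperm(n_samples, generator=g).tolist()
    fold_size = n_samples // k
    start, end = fold * fold_size, (fold + 1) * fold_size
    val_idx = indices[start:end]                  # the 1 held-out fold
    train_idx = indices[:start] + indices[end:]   # the other k-1 folds
    return train_idx, val_idx

# 60k MNIST training images, k = 6: each round trains on 50k, validates on 10k.
train_idx, val_idx = kfold_split(60000, 6, fold=0)
# torch.utils.data.Subset(train_db, train_idx) would then build that fold's train set.
```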

3、How to Prevent Overfitting (Regularization)

 For example, the cross-entropy loss for binary classification:

  J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\big[y_i \ln \hat{y}_i + (1-y_i)\ln(1-\hat{y}_i)\big]

 Add a regularization term (the norm of the parameters):

  J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\big[y_i \ln \hat{y}_i + (1-y_i)\ln(1-\hat{y}_i)\big] + \frac{\lambda}{2}\|\theta\|^2

 That is, minimizing the total loss also pushes the added term (the norm of the parameters) toward 0 ("enforce weights close to 0"). A parameter norm near 0 means few effective parameters (low parameter dimensionality), which yields a lower-complexity network and thus prevents overfitting. This is also known as weight decay.

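In PyTorch, the L2 penalty is built into the optimizers as the `weight_decay` argument. A minimal sketch (using a dummy zero loss, an assumption for illustration, so the only update comes from the decay term) showing the weights being pulled toward 0:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
net = nn.Linear(10, 1)

# weight_decay=lambda adds lambda * w to every gradient, i.e. the L2 penalty.
optimizer = optim.SGD(net.parameters(), lr=0.1, weight_decay=0.01)

norm_before = net.weight.norm().item()
for _ in range(100):
    optimizer.zero_grad()
    # A dummy loss whose gradient is zero: the only update comes from the
    # decay term, so each step shrinks the weights by (1 - lr * lambda).
    loss = (net(torch.zeros(4, 10)) * 0.0).sum()
    loss.backward()
    optimizer.step()
norm_after = net.weight.norm().item()   # smaller: weights pulled toward 0
```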

4、Momentum (inertia) and Learning Rate Decay

 4.1、momentum:

   Plain gradient descent updates w^{k+1} = w^k - \alpha \nabla f(w^k). With momentum, the update direction also carries the previous direction (inertia):

   z^{k+1} = \beta z^k + \nabla f(w^k)
   w^{k+1} = w^k - \alpha z^{k+1}

   where \beta (typically around 0.9) controls how much of the old direction is kept; this damps oscillation and helps the iterate roll past shallow local minima.
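The update above is exposed as the `momentum` argument of `optim.SGD`. A minimal sketch on a toy problem, minimizing f(w) = w² (the lr and step count are illustrative, not tuned):

```python
import torch
import torch.optim as optim

w = torch.tensor([10.0], requires_grad=True)
optimizer = optim.SGD([w], lr=0.05, momentum=0.9)   # beta = 0.9

for _ in range(200):
    optimizer.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()        # grad f(w) = 2w
    optimizer.step()       # z <- 0.9 z + grad;  w <- w - 0.05 z

final = abs(w.item())      # w has been driven close to the minimum at 0
```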

 4.2、learning rate decay:

  Gradually shrink the learning rate as training proceeds:

  Method 1: reduce the lr when a monitored metric (e.g. the val loss) stops improving (ReduceLROnPlateau).

  Method 2 (more brute-force): decay the lr by a fixed factor on a fixed schedule, regardless of the loss, e.g. multiply by gamma every step_size epochs (StepLR).
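A minimal sketch of Method 2 with `torch.optim.lr_scheduler.StepLR` (the tiny model and epoch counts are placeholders for a real training loop):

```python
import torch
import torch.optim as optim

net = torch.nn.Linear(2, 1)
optimizer = optim.SGD(net.parameters(), lr=0.1)

# Method 2: multiply the lr by gamma every step_size epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Method 1 would instead be e.g.:
#   scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10)
# followed by scheduler.step(val_loss) each epoch.

lrs = []
for epoch in range(30):
    # ... one epoch of training would go here ...
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# the lr drops by 10x after every 10 scheduler steps
```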


5、Early stopping & Dropout

 5.1、Early stopping

   Track validation performance while training and stop at the point where val accuracy peaks: beyond that point the train loss keeps falling, but the model is starting to overfit.
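The stopping rule can be sketched as a patience counter (`early_stop_epoch` is a hypothetical helper, and the accuracy numbers are made up for illustration):

```python
def early_stop_epoch(val_accs, patience=3):
    """Return the epoch of the best val accuracy, stopping once it has
    failed to improve for `patience` consecutive epochs."""
    best, best_epoch, wait = float('-inf'), 0, 0
    for epoch, acc in enumerate(val_accs):
        if acc > best:
            best, best_epoch, wait = acc, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                break   # stop training; restore the checkpoint from best_epoch
    return best_epoch

# Simulated val accuracy: peaks at epoch 3, then declines as the model overfits.
accs = [0.80, 0.85, 0.88, 0.90, 0.89, 0.89, 0.88, 0.87]
stop = early_stop_epoch(accs, patience=3)
```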

 5.2、Dropout

   During training, each connection is randomly dropped with probability p, so every step trains a thinned sub-network; this discourages co-adaptation between units and reduces overfitting.

   Note the convention: PyTorch's torch.nn.Dropout(p) takes the probability of dropping a unit, whereas (older) TensorFlow's tf.nn.dropout took the keep probability.

   Dropout is only active in training mode; call net.eval() before validation/testing so the full network (all connections) is used.
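A minimal sketch of the train/eval behavior difference (the layer sizes mirror the MLP above; the input is random dummy data):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(784, 200),
    nn.Dropout(0.5),     # PyTorch: p is the probability of DROPPING a unit
    nn.LeakyReLU(inplace=True),
    nn.Linear(200, 10),
)
x = torch.randn(1, 784)

net.train()              # dropout active: a random subset of activations is zeroed
out_train = net(x)

net.eval()               # dropout off: the full network runs, output is deterministic
out_eval1 = net(x)
out_eval2 = net(x)
```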

 5.3、Stochastic Gradient Descent

   "Stochastic" does not mean the gradient is random noise; it means each step computes the gradient on a randomly sampled mini-batch rather than on the whole dataset, giving a noisy but unbiased and much cheaper estimate of the true gradient.
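The distinction can be sketched on a synthetic regression problem (the data and batch size of 32 are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
X = torch.randn(1000, 3)                     # 1000 samples, 3 features
y = X @ torch.tensor([1.0, -2.0, 3.0])       # synthetic linear targets
w = torch.zeros(3, requires_grad=True)

# Full-batch gradient: one step looks at all 1000 samples.
loss = ((X @ w - y) ** 2).mean()
loss.backward()
full_grad = w.grad.clone()

# Stochastic (mini-batch) gradient: one step looks at a random batch of 32.
w.grad = None
idx = torch.randint(0, 1000, (32,))
loss = ((X[idx] @ w - y[idx]) ** 2).mean()
loss.backward()
batch_grad = w.grad.clone()
# batch_grad is a noisy but unbiased estimate of full_grad at ~1/30 the cost.
```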


 

posted @ 2020-07-28 17:42  小吴的日常