pytorch-day05 (Overfitting)
1. Overfitting vs. underfitting
2. How to detect overfitting (train/val/test splits and cross validation)

Here the validation set (called "test" in some code) is used to select the model's hyperparameters and to stop training early, which prevents overfitting.
The validation set is for choosing model hyperparameters; the test set is the data the customer uses at acceptance time, and it must never influence training.
So split the data into a train set, a val set, and a test set.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

batch_size = 200
learning_rate = 0.01
epochs = 10

train_db = datasets.MNIST('../data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))

test_db = datasets.MNIST('../data', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))
                         ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)

print('train:', len(train_db), 'test:', len(test_db))
# Hold out 10,000 of the 60,000 training images as a validation set
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))

train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        return self.model(x)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.to(device)

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # Validation: used to pick hyperparameters and decide when to stop
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.to(device)
        logits = net(data)
        val_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))

# Final evaluation on the held-out test set (touched only once, at the end)
test_loss = 0
correct = 0
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    data, target = data.to(device), target.to(device)
    logits = net(data)
    test_loss += criteon(logits, target).item()

    pred = logits.data.max(1)[1]
    correct += pred.eq(target.data).sum()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
```
k-fold cross validation: the test set stays untouched; each round uses k-1 folds as the train set and the remaining fold as the val set, rotating which fold is held out.
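The rotation above can be sketched with plain index arithmetic (no PyTorch needed; in practice the resulting index lists could be wrapped with `torch.utils.data.Subset` to build per-fold loaders):

```python
def k_fold_splits(n_samples, k):
    """Yield (train_indices, val_indices) pairs for k-fold cross validation.

    The test set is held out elsewhere and never touched here; each fold
    uses k-1 parts for training and the remaining part for validation.
    """
    indices = list(range(n_samples))
    fold_size = n_samples // k
    for fold in range(k):
        start, stop = fold * fold_size, (fold + 1) * fold_size
        # the `fold`-th slice is validation; everything else is training
        val_idx = indices[start:stop]
        train_idx = indices[:start] + indices[stop:]
        yield train_idx, val_idx

splits = list(k_fold_splits(60000, 6))
print(len(splits))                           # 6 folds
print(len(splits[0][0]), len(splits[0][1]))  # 50000 10000
```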
3. How to prevent overfitting (regularization)

For example, the binary cross-entropy loss:

J(\theta) = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \ln \hat{y}_i + (1-y_i)\ln(1-\hat{y}_i)\right]

Add a penalty term on the norm of the parameters:

J(\theta) = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \ln \hat{y}_i + (1-y_i)\ln(1-\hat{y}_i)\right] + \frac{\lambda}{2}\lVert\theta\rVert^2

That is, as the model fits the data well, the added term (the norm of the parameters) is pushed toward 0 ("enforce weights close to 0"). A small parameter norm means fewer effective parameters (lower parameter dimensionality), which yields a lower-complexity network and thereby prevents overfitting. The L2 variant is also called weight decay.
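The penalized objective can be sketched numerically in plain Python (the toy labels, predictions, and weights below are made up for illustration; in PyTorch the same effect comes from the optimizer's `weight_decay` argument, e.g. `optim.SGD(net.parameters(), lr=0.01, weight_decay=0.01)`):

```python
import math

def bce(y_true, y_pred):
    # binary cross-entropy averaged over the N samples
    n = len(y_true)
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for y, p in zip(y_true, y_pred)) / n

def l2_penalty(weights, lam):
    # (lambda / 2) * ||theta||^2 -- the added regularization term
    return 0.5 * lam * sum(w * w for w in weights)

y_true = [1, 0, 1]
y_pred = [0.9, 0.2, 0.8]
weights = [0.5, -1.5, 2.0]          # hypothetical model parameters

plain = bce(y_true, y_pred)
regularized = plain + l2_penalty(weights, lam=0.01)
print(round(plain, 4), round(regularized, 4))
```

Large weights inflate the regularized loss, so gradient descent on it trades data fit against weight magnitude.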
4. Momentum (inertia) and learning rate decay
4.1 Momentum
Plain gradient descent updates w^{k+1} = w^k - \alpha \nabla f(w^k). With momentum, the update keeps a running "velocity" that remembers the previous direction:

z^{k+1} = \beta z^k + \nabla f(w^k)
w^{k+1} = w^k - \alpha z^{k+1}

With \beta around 0.9, the trajectory oscillates less and can coast past shallow local minima.
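The two update equations can be checked on a toy 1-D problem in plain Python (the function and hyperparameters below are made up for illustration; in PyTorch this is just `optim.SGD(net.parameters(), lr=..., momentum=0.9)`):

```python
def sgd_momentum_steps(grad, w0, lr=0.1, beta=0.9, steps=200):
    """Gradient descent with momentum: z <- beta*z + grad(w); w <- w - lr*z.

    The velocity z accumulates past gradients, so each step keeps some
    'inertia' from the previous direction of travel.
    """
    w, z = w0, 0.0
    for _ in range(steps):
        z = beta * z + grad(w)
        w = w - lr * z
    return w

# minimize f(w) = w^2, whose gradient is 2w; the minimum is at w = 0
w_final = sgd_momentum_steps(lambda w: 2 * w, w0=5.0)
print(round(w_final, 6))
```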
4.2 Learning rate decay
Gradually reduce the learning rate as training progresses:
Method 1: reduce the learning rate when a monitored metric (e.g. the validation loss) has stopped improving (what `torch.optim.lr_scheduler.ReduceLROnPlateau` does).
Method 2 (more brute-force): decay by a fixed factor every fixed number of epochs, regardless of the metrics (what `torch.optim.lr_scheduler.StepLR` does).
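The brute-force schedule is simple enough to write out directly (plain-Python sketch of the `StepLR` rule; the numbers below are illustrative):

```python
def step_decay(lr0, gamma, step_size, epoch):
    # lr = lr0 * gamma^(epoch // step_size): multiply the learning rate
    # by a fixed factor gamma once every `step_size` epochs
    return lr0 * gamma ** (epoch // step_size)

# halve a 0.1 learning rate every 10 epochs
lrs = [step_decay(0.1, 0.5, 10, e) for e in range(0, 40, 10)]
print(lrs)   # [0.1, 0.05, 0.025, 0.0125]
```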
5. Early stopping & Dropout
5.1 Early stopping
Monitor performance on the validation set while training, and stop at the point where validation performance peaks: beyond it, training loss keeps falling but the model is memorizing the train set and validation accuracy degrades.
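A patience-based version of this rule can be sketched as a small plain-Python helper (the class name and the toy loss sequence are hypothetical, not a PyTorch API):

```python
class EarlyStopping:
    """Stop when the validation loss hasn't improved for `patience` checks."""
    def __init__(self, patience=3):
        self.patience = patience
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation result; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss   # improvement: remember it, reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1   # no improvement this check
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=3)
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64]
stopped_at = next(i for i, l in enumerate(losses) if stopper.step(l))
print(stopped_at)   # stops at index 5: three checks in a row without improvement
```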
5.2 Dropout
During training, each unit is dropped (zeroed) with probability p, so every forward pass trains a randomly thinned sub-network ("learning less to learn better"); at test time all units are active. Note that PyTorch's `nn.Dropout(p)` takes the probability of *dropping*, whereas TensorFlow's old `tf.nn.dropout` took `keep_prob`, the probability of *keeping*. Remember to switch with `net.train()` / `net.eval()` so dropout is disabled during evaluation.
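The mechanism can be sketched in plain Python as "inverted dropout" (the helper below is illustrative, not PyTorch's implementation, though `nn.Dropout` behaves the same way: survivors are scaled by 1/(1-p) so the expected activation is unchanged):

```python
import random

def dropout(xs, p, training=True):
    """Inverted dropout: zero each value with probability p during training,
    scaling survivors by 1/(1-p) so the expected output equals the input.
    At test time (training=False) the input passes through untouched."""
    if not training or p == 0.0:
        return list(xs)
    keep = 1.0 - p
    return [x / keep if random.random() < keep else 0.0 for x in xs]

random.seed(0)
print(dropout([1.0, 2.0, 3.0, 4.0], p=0.5))            # survivors doubled, rest zero
print(dropout([1.0, 2.0], p=0.5, training=False))      # [1.0, 2.0] at eval time
```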
5.3 Stochastic gradient descent
"Stochastic" here does not mean the update is arbitrary: it means the gradient is computed on a randomly drawn mini-batch instead of the whole dataset, giving a noisy but unbiased estimate of the true gradient at a fraction of the compute and memory cost.
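The full-batch vs. mini-batch estimates can be compared on a toy 1-D least-squares fit (the data and step size below are made up for illustration):

```python
import random

def full_gradient(w, xs, ys):
    # exact gradient of the mean of 0.5*(w*x - y)^2 over ALL data points
    n = len(xs)
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / n

def minibatch_gradient(w, xs, ys, batch_size):
    # stochastic estimate: same formula, but on a random mini-batch only
    idx = random.sample(range(len(xs)), batch_size)
    return sum((w * xs[i] - ys[i]) * xs[i] for i in idx) / batch_size

random.seed(1)
xs = [x / 10 for x in range(1, 101)]
ys = [3.0 * x for x in xs]          # noiseless line with true slope w = 3

w = 0.0
for _ in range(500):
    w -= 0.01 * minibatch_gradient(w, xs, ys, batch_size=10)
print(round(w, 2))   # converges to the true slope, ~3.0
```

Each mini-batch step costs a tenth of the full-batch gradient here, yet the iterates still converge because the noise averages out across steps.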
