PyTorch 实现 VGG

一、网络结构和参数

特点：堆叠多个小尺寸（3×3）卷积核，获得与单个大卷积核相同的感受野（例如两个 3×3 卷积层的感受野等同于一个 5×5 卷积，三个等同于一个 7×7 卷积），在减少网络参数的同时加深了网络深度（设通道数为 C，三个 3×3 卷积共 27C² 个权重，而一个 7×7 卷积需要 49C²）。

 

 

 

二、模型定义和训练代码

model.py

 1 import torch.nn as nn
 2 import torch
 3 
 4 
 5 class VGG(nn.Module):
 6     def __init__(self, features, num_classes=1000, init_weights=False):
 7         super(VGG, self).__init__()
 8         self.features = features
 9         self.classifier = nn.Sequential(
10             nn.Dropout(p=0.5),
11             nn.Linear(512*7*7, 2048),
12             nn.ReLU(True),
13             nn.Dropout(p=0.5),
14             nn.Linear(2048, 2048),
15             nn.ReLU(True),
16             nn.Linear(2048, num_classes)
17         )
18         if init_weights:
19             self._initialize_weights()
20 
21     def forward(self, x):
22         # N x 3 x 224 x 224
23         x = self.features(x)
24         # N x 512 x 7 x 7
25         x = torch.flatten(x, start_dim=1)
26         # N x 512*7*7
27         x = self.classifier(x)
28         return x
29 
30     def _initialize_weights(self):
31         for m in self.modules():
32             if isinstance(m, nn.Conv2d):
33                 # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
34                 nn.init.xavier_uniform_(m.weight)
35                 if m.bias is not None:
36                     nn.init.constant_(m.bias, 0)
37             elif isinstance(m, nn.Linear):
38                 nn.init.xavier_uniform_(m.weight)
39                 # nn.init.normal_(m.weight, 0, 0.01)
40                 nn.init.constant_(m.bias, 0)
41 
42 
def make_features(cfg: list):
    """Build the convolutional trunk described by ``cfg``, starting from 3 input channels.

    Every integer in ``cfg`` appends a 3x3 convolution (padding 1, so height and width
    are preserved) with that many output channels, followed by an in-place ReLU. Every
    ``"M"`` appends a 2x2 max-pool with stride 2, halving height and width (rounded down).

    Returns:
        ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    modules = []
    channels = 3
    for item in cfg:
        if item != "M":
            modules.append(nn.Conv2d(channels, item, kernel_size=3, padding=1))
            modules.append(nn.ReLU(True))
            # The next convolution consumes this one's output channels.
            channels = item
            continue
        modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*modules)
54 
55 
# Layer layouts for the VGG variants, consumed by make_features: an integer is the
# output-channel count of a 3x3 conv (+ReLU), "M" is a 2x2 / stride-2 max-pool.
# Conv count plus the three Linear layers of VGG.classifier gives the variant's depth
# (e.g. vgg16: 13 convs + 3 Linear = 16).
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    """Build a VGG network from one of the layouts in ``cfgs``.

    Args:
        model_name: key into ``cfgs`` ("vgg11", "vgg13", "vgg16" or "vgg19").
        **kwargs: forwarded to ``VGG`` (e.g. ``num_classes``, ``init_weights``).

    Returns:
        A ``VGG`` whose feature extractor is ``make_features(cfgs[model_name])``.

    Raises:
        ValueError: if ``model_name`` is not a key of ``cfgs``.
    """
    try:
        cfg = cfgs[model_name]
    except KeyError:
        # Raise rather than print-and-exit: library code should not terminate the
        # process, and only a missing key (not arbitrary errors) means "unknown model".
        raise ValueError(
            "model name {!r} not in cfgs dict; expected one of {}".format(
                model_name, sorted(cfgs)
            )
        ) from None
    return VGG(make_features(cfg), **kwargs)

train.py

  1 import torch.nn as nn
  2 from torchvision import transforms, datasets
  3 import json
  4 import os
  5 import torch.optim as optim
  6 from model import vgg
  7 import torch
  8 
  9 
 10 def main():
 11     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 12     print("using {} device.".format(device))
 13 
 14     data_transform = {
 15         "train": transforms.Compose([transforms.RandomResizedCrop(224),
 16                                      transforms.RandomHorizontalFlip(),
 17                                      transforms.ToTensor(),
 18                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
 19         "val": transforms.Compose([transforms.Resize((224, 224)),
 20                                    transforms.ToTensor(),
 21                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
 22 
 23     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
 24     image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
 25     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
 26     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
 27                                          transform=data_transform["train"])
 28     train_num = len(train_dataset)
 29 
 30     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
 31     flower_list = train_dataset.class_to_idx
 32     cla_dict = dict((val, key) for key, val in flower_list.items())
 33     # write dict into json file
 34     json_str = json.dumps(cla_dict, indent=4)
 35     with open('class_indices.json', 'w') as json_file:
 36         json_file.write(json_str)
 37 
 38     batch_size = 32
 39     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
 40     print('Using {} dataloader workers every process'.format(nw))
 41 
 42     train_loader = torch.utils.data.DataLoader(train_dataset,
 43                                                batch_size=batch_size, shuffle=True,
 44                                                num_workers=0)
 45 
 46     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
 47                                             transform=data_transform["val"])
 48     val_num = len(validate_dataset)
 49     validate_loader = torch.utils.data.DataLoader(validate_dataset,
 50                                                   batch_size=batch_size, shuffle=False,
 51                                                   num_workers=0)
 52     print("using {} images for training, {} images fot validation.".format(train_num,
 53                                                                            val_num))
 54 
 55     # test_data_iter = iter(validate_loader)
 56     # test_image, test_label = test_data_iter.next()
 57 
 58     model_name = "vgg16"
 59     net = vgg(model_name=model_name, num_classes=5, init_weights=True)
 60     net.to(device)
 61     loss_function = nn.CrossEntropyLoss()
 62     optimizer = optim.Adam(net.parameters(), lr=0.0001)
 63 
 64     best_acc = 0.0
 65     save_path = './{}Net.pth'.format(model_name)
 66     for epoch in range(30):
 67         # train
 68         net.train()
 69         running_loss = 0.0
 70         for step, data in enumerate(train_loader, start=0):
 71             images, labels = data
 72             optimizer.zero_grad()
 73             outputs = net(images.to(device))
 74             loss = loss_function(outputs, labels.to(device))
 75             loss.backward()
 76             optimizer.step()
 77 
 78             # print statistics
 79             running_loss += loss.item()
 80             # print train process
 81             rate = (step + 1) / len(train_loader)
 82             a = "*" * int(rate * 50)
 83             b = "." * int((1 - rate) * 50)
 84             print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
 85         print()
 86 
 87         # validate
 88         net.eval()
 89         acc = 0.0  # accumulate accurate number / epoch
 90         with torch.no_grad():
 91             for val_data in validate_loader:
 92                 val_images, val_labels = val_data
 93                 optimizer.zero_grad()
 94                 outputs = net(val_images.to(device))
 95                 predict_y = torch.max(outputs, dim=1)[1]
 96                 acc += (predict_y == val_labels.to(device)).sum().item()
 97             val_accurate = acc / val_num
 98             if val_accurate > best_acc:
 99                 best_acc = val_accurate
100                 torch.save(net.state_dict(), save_path)
101             print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
102                   (epoch + 1, running_loss / step, val_accurate))
103 
104     print('Finished Training')
105 
106 
107 if __name__ == '__main__':
108     main()

 

posted @ 2020-12-20 16:37  荼离伤花  阅读(581)  评论(0)  编辑  收藏  举报