11-ViT图像分类

 

参考：https://blog.csdn.net/weixin_42392454/article/details/122667271

ViT 的 PyTorch 实现代码（保存为 vit.py，下方训练脚本通过 `from vit import ViT` 导入）：

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Normalise a size argument to a pair.

    Returns ``t`` unchanged when it is already a tuple; otherwise returns the
    2-tuple ``(t, t)``. This lets ``image_size`` / ``patch_size`` be given either
    as one value (square) or as an explicit ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-norm position-wise MLP used inside each Transformer layer.

    Computes ``Dropout(Linear(Dropout(GELU(Linear(LayerNorm(x))))))`` over the last
    dimension. It maps ``dim -> hidden_dim -> dim``, so the output has the same
    trailing size as the input; the residual connection is added by the caller.

    Args:
        dim: size of the last dimension of the input and of the output.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability after the GELU and after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # The list order defines the Sequential indices, which are the state_dict
        # keys (net.0 = LayerNorm, net.1 / net.4 = Linear), so keep it fixed.
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``; any leading shape, last dimension must be ``dim``."""
        return self.net(x)

class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    Args:
        dim: token embedding size (last dimension of the input).
        heads: number of attention heads.
        dim_head: width of each head's query/key/value vectors.
        dropout: dropout probability on the attention weights and on the output projection.

    ``forward`` takes ``(batch, tokens, dim)``; that rank is fixed by the
    ``'b n (h d)'`` rearrange pattern. It returns a tensor of the same shape.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # One head that is already ``dim`` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # A single bias-free projection yields queries, keys and values side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        # Split the fused projection into q, k, v and move the head axis forward:
        # (b, n, h*d) -> (b, h, n, d).
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        # Scaled dot-product scores, shape (b, h, n, n).
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of values, then merge the heads: (b, h, n, d) -> (b, n, h*d).
        merged = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers followed by a final LayerNorm.

    Each layer applies residual attention and then a residual MLP:
    ``x = x + Attention(x)``; ``x = x + FeedForward(x)``.

    Args:
        dim: token embedding size (last dimension of the input).
        depth: number of layers.
        heads, dim_head: passed to every ``Attention``.
        mlp_dim: hidden width of every ``FeedForward``.
        dropout: dropout probability used in both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Each entry is a [attention, feed_forward] pair, built in layer order.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)

class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    The model works in five steps:
    1. Cut the image into non-overlapping ``patch_size`` patches.
    2. Flatten each patch and embed it to ``dim``.
    3. Prepend a learned class token and add learned positional embeddings.
    4. Run the token sequence through a ``Transformer``.
    5. Feed one pooled vector to a linear head. With ``pool='cls'`` it is the
       class token; with ``pool='mean'`` it is the mean over all tokens,
       including the class token.

    Keyword-only args:
        image_size, patch_size: int or (height, width). Image sides must be
            divisible by the matching patch sides.
        num_classes: number of output logits.
        dim, depth, heads, dim_head, mlp_dim: Transformer width, layer count, head
            count, per-head width and MLP hidden width.
        pool: 'cls' or 'mean'.
        channels: number of input image channels.
        dropout: dropout inside the Transformer.
        emb_dropout: dropout applied to the embedded token sequence.

    ``forward(img)`` expects ``img`` of shape (batch, channels, height, width) and
    returns logits of shape (batch, num_classes).
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        img_h, img_w = pair(image_size)
        patch_h, patch_w = pair(patch_size)

        assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Image dimensions must be divisible by the patch size.'

        grid_h, grid_w = img_h // patch_h, img_w // patch_w
        num_patches = grid_h * grid_w
        patch_dim = channels * patch_h * patch_w
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2 = patch_w),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One learned position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        tokens = self.to_patch_embedding(img)
        batch, n_patches, _ = tokens.shape

        # Put one copy of the class token in front of every sequence in the batch.
        cls = repeat(self.cls_token, '1 1 d -> b 1 d', b = batch)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens += self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))
View Code

 

训练脚本 classfy_main.py（在 CIFAR-10 上训练上面的 ViT，并绘制每轮的训练/测试准确率曲线）：

  1 import torch
  2 from torch.utils.data import DataLoader
  3 from torch import nn, optim
  4 from torchvision import datasets, transforms
  5 from torchvision.transforms.functional import InterpolationMode
  6 
  7 from matplotlib import pyplot as plt
  8 
  9 import time
 10 
 11 from Lenet5 import Lenet5_new
 12 # from Resnet18 import ResNet18,ResNet18_new
 13 # from AlexNet import AlexNet
 14 # from Vgg16 import VGGNet16
 15 # from Densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet264
 16 
 17 # from NIN import NIN_Net
 18 # from GoogleNet import GoogLeNet
 19 # from MobileNet_v3 import mobilenet_v3
 20 from shuffleNet import shuffleNet_g3_
 21 
 22 from vit import ViT
 23 
 24 def main():
 25 
 26     print("Load datasets...")
 27 
 28     # transforms.RandomHorizontalFlip(p=0.5)---以0.5的概率对图片做水平横向翻转
 29     # transforms.ToTensor()---shape从(H,W,C)->(C,H,W), 每个像素点从(0-255)映射到(0-1):直接除以255
 30     # transforms.Normalize---先将输入归一化到(0,1),像素点通过"(x-mean)/std",将每个元素分布到(-1,1)
 31     transform_train = transforms.Compose([
 32                         transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
 33                         # transforms.RandomCrop(32, padding=4),  # 先四周填充0,在吧图像随机裁剪成32*32
 34                         transforms.RandomHorizontalFlip(p=0.5),
 35                         transforms.ToTensor(),
 36                         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 37                     ])
 38 
 39     transform_test = transforms.Compose([
 40                         transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
 41                         # transforms.RandomCrop(32, padding=4),  # 先四周填充0,在吧图像随机裁剪成32*32
 42                         transforms.ToTensor(),
 43                         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 44                     ])
 45     
 46     dataset_path = "/big-data/person/zhaopengpeng/deepfake_zpp/code/Transformer_code/Coco_train_code"
 47 
 48     # 内置函数下载数据集
 49     train_dataset = datasets.CIFAR10(root= dataset_path +"/data/Cifar10/", train=True,
 50                                      transform = transform_train,
 51                                      download=True)
 52     test_dataset = datasets.CIFAR10(root = dataset_path +"/data/Cifar10/",
 53                                     train = False,
 54                                     transform = transform_test,
 55                                     download=True)
 56 
 57     print(len(train_dataset), len(test_dataset))
 58 
 59     Batch_size = 64
 60     train_loader = DataLoader(train_dataset, batch_size=Batch_size,  shuffle = True, num_workers=4)
 61     test_loader = DataLoader(test_dataset, batch_size = Batch_size, shuffle = False, num_workers=4)
 62 
 63     # 设置CUDA
 64     device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
 65 
 66     # 初始化模型
 67     # 直接更换模型就行,其他无需操作
 68     # model = Lenet5_new().to(device)
 69     # model = ResNet18().to(device)
 70     # model = ResNet18_new().to(device)
 71     # model = VGGNet16().to(device)
 72     # model = DenseNet121().to(device)
 73     # model  = DenseNet169().to(device)
 74 
 75     # model = NIN_Net().to(device)
 76 
 77     # model = GoogLeNet().to(device)
 78     # model = mobilenet_v3().to(device)
 79     
 80     # model = ViT(image_size=(32, 32), patch_size=(4, 4), num_classes=10, dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1)
 81     model = ViT(image_size=(224, 224), patch_size=(16, 16), num_classes=10, dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1).to(device)
 82     
 83     # model = shuffleNet_g3_().to(device)
 84     # model = AlexNet(num_classes=10, init_weights=True).to(device)
 85     print(" ViTViT train...")
 86 
 87     # 构造损失函数和优化器
 88     criterion = nn.CrossEntropyLoss() # 多分类softmax构造损失
 89     # opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.8, weight_decay=0.001)
 90     opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
 91 
 92     # 动态更新学习率 ------每隔step_size : lr = lr * gamma
 93     schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.6, last_epoch=-1)
 94 
 95     # 开始训练
 96     print("Start Train...")
 97 
 98     epochs = 100
 99 
100     loss_list = []
101     train_acc_list =[]
102     test_acc_list = []
103     epochs_list = []
104 
105     for epoch in range(0, epochs):
106 
107         start = time.time()
108 
109         model.train()
110 
111         running_loss = 0.0
112         batch_num = 0
113 
114         for i, (inputs, labels) in enumerate(train_loader):
115 
116             inputs, labels = inputs.to(device), labels.to(device)
117 
118             # 将数据送入模型训练
119             outputs = model(inputs)
120             # 计算损失
121             loss = criterion(outputs, labels).to(device)
122 
123             # 重置梯度
124             opt.zero_grad()
125             # 计算梯度,反向传播
126             loss.backward()
127             # 根据反向传播的梯度值优化更新参数
128             opt.step()
129 
130             # 100个batch的 loss 之和
131             running_loss += loss.item()
132             # loss_list.append(loss.item())
133             batch_num+=1
134 
135 
136         epochs_list.append(epoch)
137 
138         # 每一轮结束输出一下当前的学习率 lr
139         lr_1 = opt.param_groups[0]['lr']
140         print("learn_rate:%.15f" % lr_1)
141         schedule.step()
142 
143         end = time.time()
144         print('epoch = %d/100, batch_num = %d, loss = %.6f, time = %.3f' % (epoch+1, batch_num, running_loss/batch_num, end-start))
145         running_loss=0.0
146 
147         # 每个epoch训练结束,都进行一次测试验证
148         model.eval()
149         train_correct = 0.0
150         train_total = 0
151 
152         test_correct = 0.0
153         test_total = 0
154 
155          # 训练模式不需要反向传播更新梯度
156         with torch.no_grad():
157 
158             # print("=======================train=======================")
159             for inputs, labels in train_loader:
160                 inputs, labels = inputs.to(device), labels.to(device)
161                 outputs = model(inputs)
162 
163                 pred = outputs.argmax(dim=1)  # 返回每一行中最大值元素索引
164                 train_total += inputs.size(0)
165                 train_correct += torch.eq(pred, labels).sum().item()
166 
167 
168             # print("=======================test=======================")
169             for inputs, labels in test_loader:
170                 inputs, labels = inputs.to(device), labels.to(device)
171                 outputs = model(inputs)
172 
173                 pred = outputs.argmax(dim=1)  # 返回每一行中最大值元素索引
174                 test_total += inputs.size(0)
175                 test_correct += torch.eq(pred, labels).sum().item()
176 
177             print("train_total = %d, Accuracy = %.5f %%,  test_total= %d, Accuracy = %.5f %%" %(train_total, 100 * train_correct / train_total, test_total, 100 * test_correct / test_total))
178 
179             train_acc_list.append(100 * train_correct / train_total)
180             test_acc_list.append(100 * test_correct / test_total)
181 
182         # print("Accuracy of the network on the 10000 test images:%.5f %%" % (100 * test_correct / test_total))
183         # print("===============================================")
184 
185     fig = plt.figure(figsize=(4, 4))
186 
187     plt.plot(epochs_list, train_acc_list, label='train_acc_list')
188     plt.plot(epochs_list, test_acc_list, label='test_acc_list')
189     plt.legend()
190     plt.title("train_test_acc")
191     # plt.savefig('shuffleNet_g3_acc_epoch_{:04d}.png'.format(epochs))
192     plt.savefig('ViT_acc_epoch_{:04d}.png'.format(epochs))
193     plt.close()
194 
195 if __name__ == "__main__":
196 
197     main()
View Code

 

posted @ 2024-05-10 10:53  赵家小伙儿  阅读(34)  评论(0)    收藏  举报