Vision Transformer

ViT (Vision Transformer) applies the Transformer architecture to computer vision: the image is split into fixed-size patches, each patch is projected to an embedding vector, and the resulting token sequence is fed through a standard Transformer encoder.
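Before looking at the model, here is a minimal sketch of the patch-splitting idea, using torch.nn.functional.unfold purely for illustration; the model below achieves the same effect with a strided Conv2d, which also performs the linear projection in one step.

import torch as th
import torch.nn.functional as F

# Split one 28x28 grayscale image into non-overlapping 7x7 patches (illustration only)
img = th.randn(1, 1, 28, 28)                      # [B, C, H, W]
patches = F.unfold(img, kernel_size=7, stride=7)  # [1, 7*7*1, 16] = [1, 49, 16]
patches = patches.transpose(1, 2)                 # [1, 16, 49]: 16 patch vectors of 49 pixels each
print(patches.shape)                              # torch.Size([1, 16, 49])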


Model Definition

import torch as th
import torch.nn as nn

class SimpleViT(nn.Module):
    def __init__(
        self,
        image_size=28,
        image_dim=1,
        patch_size=7,
        num_classes=10,
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=256
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.dim = dim
        # Split the image into patches and project them to the embedding dimension
        self.to_patch = nn.Conv2d(
            in_channels=image_dim,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        self.cls_embedding = nn.Parameter(th.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(th.randn(1, 1+num_patches, dim))
        # Transformer encoder
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim, batch_first=True),
            num_layers=depth
        )
        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.to_patch(x).flatten(2).transpose(1, 2)  # [B,   N, D]
        cls = self.cls_embedding.expand(B, -1, -1)     # [B,   1, D]
        x = th.cat((cls, x), dim=1)                    # [B, 1+N, D]
        x = x + self.pos_embedding
        x = self.transformer(x)
        x = x[:, 0, :]                                 # take the cls token
        x = self.mlp_head(x)                           # classify from the cls token
        return x
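A quick shape check (an illustrative sketch; the dummy tensor is not part of the original post): with image_size=28 and patch_size=7 there are (28/7)² = 16 patches, so the encoder sees 17 tokens including cls, and the head outputs 10 class logits.

# Sanity-check the tensor shapes with a dummy batch (illustration only)
vit = SimpleViT()
dummy = th.randn(8, 1, 28, 28)                           # [B, C, H, W]
tokens = vit.to_patch(dummy).flatten(2).transpose(1, 2)  # patch tokens before cls/pos embedding
print(tokens.shape)                                      # torch.Size([8, 16, 128])
print(vit(dummy).shape)                                  # torch.Size([8, 10])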

Data Preparation

import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
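To verify what the loaders yield (a small illustrative check, not from the original post): each training batch contains 64 single-channel 28×28 images and 64 integer labels.

# Peek at one batch (illustration only)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])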

Training and Testing

device = th.device('cuda' if th.cuda.is_available() else 'cpu')
model = SimpleViT().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training
for epoch in range(5):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}/5 - Loss: {loss.item():.4f}')

# Testing
model.eval()
correct = 0
total = 0
with th.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, pred = th.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f}%')
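Inference on a single example works the same way, except the batch dimension has to be added by hand. The following is a minimal sketch using the first test image; argmax plays the same role as th.max above.

# Classify one test image (illustration only)
image, label = test_dataset[0]                     # tensor [1, 28, 28] and an int label
with th.no_grad():
    logits = model(image.unsqueeze(0).to(device))  # add batch dimension -> [1, 1, 28, 28]
pred = logits.argmax(dim=1).item()
print(f'Predicted: {pred}, actual: {label}')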

Comparison with CNNs

| Property | CNN | ViT |
| --- | --- | --- |
| Inductive bias | Strong (locality, translation invariance) | Weak (mostly learned from data) |
| Small-dataset performance | Good; easy to train and quick to converge | Prone to overfitting; needs more data and regularization |
| Parameters and compute | Fewer parameters, computationally efficient | More parameters, heavier compute |
| Suitable tasks | Small images, simple structure, limited data | Large images, complex structure, abundant data |
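To make the parameter-count row concrete, the two kinds of model can be compared directly; SmallCNN below is a hypothetical baseline added only for illustration and is not part of the original post.

# Compare parameter counts (SmallCNN is a made-up reference model)
class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(), nn.Linear(32 * 7 * 7, 10)
        )

    def forward(self, x):
        return self.net(x)

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(f'SimpleViT: {count_params(SimpleViT()):,} parameters')
print(f'SmallCNN:  {count_params(SmallCNN()):,} parameters')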