# 23 - VGG

import torch
import torch.nn as nn
from d2l import torch as d2l

# Define a VGG block
def vgg_block(num, in_channels, out_channels):
    """Build one VGG stage: ``num`` (3x3 conv + ReLU) pairs, then a 2x2 max-pool.

    The first convolution maps ``in_channels`` to ``out_channels``; every later
    convolution in the stage keeps ``out_channels``. Each conv uses
    ``kernel_size=3, padding=1`` (stride 1), so H and W are unchanged through
    the convs; the final ``MaxPool2d(kernel_size=2, stride=2)`` halves them.

    Args:
        num: number of conv + ReLU pairs in the stage (0 yields only the pool).
        in_channels: channel count of the tensor entering the stage.
        out_channels: channel count produced by every conv in the stage.

    Returns:
        An ``nn.Sequential`` containing the stage's layers in order.
    """
    # Channel sequence seen by the convs: in -> out -> out -> ... (num convs).
    channels = [in_channels] + [out_channels] * num
    modules = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        modules += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
    modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    # nn.Sequential takes modules as separate positional args, hence the unpack.
    return nn.Sequential(*modules)

# VGG stage configuration: one (number of 3x3 convs, output channels) pair per
# stage. 1 + 1 + 2 + 2 + 2 = 8 conv layers; with the 3 linear layers of the
# classifier head this gives the 11 weight layers of VGG-11.
conv_arch = (
    (1, 64),
    (1, 128),
    (2, 256),
    (2, 512),
    (2, 512),
)

# VGG-11
def vgg_11(conv_arch, *, in_channels=1, num_classes=10):
    """Build a VGG-style classifier from a stage configuration.

    Each ``(num, out_channels)`` pair in ``conv_arch`` becomes one
    ``vgg_block``, with channels chained from one block to the next. The conv
    stack is followed by the classifier head: ``Flatten``, two 4096-wide
    ``Linear -> ReLU -> Dropout(0.5)`` layers, and a final ``Linear`` to
    ``num_classes`` outputs. With the module-level five-stage configuration
    (8 conv layers + 3 linear layers) this is VGG-11.

    The first ``Linear`` is sized for a 7x7 feature map after the conv stack,
    e.g. a 224x224 input passed through five blocks (each block's stride-2
    max-pool halves H and W: 224 -> 112 -> 56 -> 28 -> 14 -> 7). Any other
    final spatial size makes the forward pass fail with a shape mismatch.

    Args:
        conv_arch: iterable of ``(num_convs, out_channels)`` pairs, one per stage.
        in_channels: channel count of the input images. Defaults to 1, the
            value previously hard-coded here.
        num_classes: width of the output layer. Defaults to 10, the value
            previously hard-coded here.

    Returns:
        An ``nn.Sequential`` containing the conv stages and the classifier head.
    """
    conv_blocks = []
    for num, out_channels in conv_arch:
        conv_blocks.append(vgg_block(num, in_channels, out_channels))
        # The next stage consumes what this stage produced.
        in_channels = out_channels

    return nn.Sequential(
        *conv_blocks,
        nn.Flatten(),
        # in_channels now holds the last stage's channel count; 7 * 7 is the
        # assumed spatial size of the final feature map (see docstring).
        nn.Linear(in_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, num_classes),
    )

# Build the model from the five-stage configuration. This runs at import time
# and allocates every parameter (the first Linear alone is 25088 x 4096).
net = vgg_11(conv_arch=conv_arch)


# Posted @ 2024-08-26 16:10 by 不是孩子了 — 阅读(15) 评论(0) (views: 15, comments: 0)