LeNet(PyTorch)实现

LeNet

LeNet,它是最早发布的卷积神经网络之一,因其在计算机视觉任务中的高效性能而受到广泛关注。 这个模型是由AT&T贝尔实验室的研究员Yann LeCun在1989年提出的(并以其命名),目的是识别图像 (LeCun et al., 1998)中的手写数字。 当时,Yann LeCun发表了第一篇通过反向传播成功训练卷积神经网络的研究,这项工作代表了十多年来神经网络研究开发的成果。
当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。 时至今日,一些自动取款机仍在运行Yann LeCun和他的同事Leon Bottou在上世纪90年代写的代码。

模型结构

image
总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

PyTorch实现

import torch
import torch.nn as nn


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.Sigmoid(),
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid(),
            nn.Linear(in_features=84, out_features=10),
        )

    def forward(self, x):
        x = self.features(x)
        return self.fc(x)

模型测试

net = LeNet()
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
print(f"Initial input shape: {X.shape}\n")
X_features = X
print("--- Features Block ---")
for i, layer in enumerate(net.features):
    X_features = layer(X_features)
    print(
        f"Features[{i}] ({layer.__class__.__name__}): output shape {X_features.shape}"
    )

print("\n--- Fully Connected Block ---")
X_fc = X_features
for i, layer in enumerate(net.fc):
    X_fc = layer(X_fc)
    print(f"FC[{i}] ({layer.__class__.__name__}): output shape {X_fc.shape}")

输出

Initial input shape: torch.Size([1, 1, 28, 28])

--- Features Block ---
Features[0] (Conv2d): output shape torch.Size([1, 6, 28, 28])
Features[1] (Sigmoid): output shape torch.Size([1, 6, 28, 28])
Features[2] (AvgPool2d): output shape torch.Size([1, 6, 14, 14])
Features[3] (Conv2d): output shape torch.Size([1, 16, 10, 10])
Features[4] (Sigmoid): output shape torch.Size([1, 16, 10, 10])
Features[5] (AvgPool2d): output shape torch.Size([1, 16, 5, 5])

--- Fully Connected Block ---
FC[0] (Flatten): output shape torch.Size([1, 400])
FC[1] (Linear): output shape torch.Size([1, 120])
FC[2] (Sigmoid): output shape torch.Size([1, 120])
FC[3] (Linear): output shape torch.Size([1, 84])
FC[4] (Sigmoid): output shape torch.Size([1, 84])
FC[5] (Linear): output shape torch.Size([1, 10])
posted @ 2025-08-15 15:59  里列昂遗失的记事本  阅读(18)  评论(0)    收藏  举报