JoyBeanRobber

导航

李沐动手学深度学习9——多层感知机

使用 PyTorch 高级 API 实现一个单隐藏层的多层感知机(784 → 256 → 10,激活函数为 ReLU),并在 Fashion-MNIST 数据集上训练,见代码:

import torch
from torch import nn
import p4_practice


def init_weights(m):
    """Re-initialise the weight of a plain ``nn.Linear`` module in place.

    Meant to be passed to ``Module.apply`` so it is called once per
    submodule. Only modules whose exact type is ``nn.Linear`` are touched;
    their ``weight`` is filled from a normal distribution with ``std=0.01``
    (the mean is left at the ``nn.init.normal_`` default). The bias and every
    other kind of module are left unchanged.
    """
    # Exact type match on purpose: subclasses of nn.Linear are skipped.
    if type(m) is not nn.Linear:
        return
    nn.init.normal_(m.weight, std=0.01)


if __name__ == "__main__":
    # Single-hidden-layer MLP: flatten the input, map 784 -> 256 with a ReLU
    # activation, then 256 -> 10 output scores.
    layers = [
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ]
    net = nn.Sequential(*layers)
    # Apply the custom initialiser to every submodule of the network.
    net.apply(init_weights)

    # Training hyperparameters.
    batch_size = 256
    learning_rate = 0.1
    epochs = 10

    # reduction='none' keeps the raw loss value of every sample; no
    # aggregation (sum or mean) is performed by the loss itself.
    loss = nn.CrossEntropyLoss(reduction='none')
    trainer = torch.optim.SGD(net.parameters(), lr=learning_rate)

    # Data loading and the training loop come from the local p4_practice module.
    train_iter, test_iter = p4_practice.load_data_fashion_mnist(batch_size)
    p4_practice.train_ch3(net, train_iter, test_iter, loss, epochs, trainer)

 

posted on 2025-04-29 16:53  欢乐豆掠夺者  阅读(12)  评论(0)    收藏  举报