pytorch 学习笔记——创建一个简单的神经网络(以2分网络为例)

Part1:导入必要的库

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt

Part2:定义输入层大小、隐藏层大小、输出层大小和批量大小

    n_in, n_h, n_out, batch_size = 10, 5, 1, 10

Part3:创建虚拟输入数据和目标数据

    x = torch.randn(batch_size, n_in)  # 随机生成输入数据
    y = torch.tensor([[1.0], [0.0], [0.0],
                      [1.0], [1.0], [1.0], [0.0], [0.0], [1.0], [1.0]])  # 目标输出数据

Part4:定义模型

    class Net(nn.Module):
        def __init__(self, ins, h, out):
            super(Net, self).__init__()
            self.layer=nn.Sequential(
                nn.Linear(ins, h),  # 输入层到隐藏层的线性变换
                nn.ReLU(),  # 隐藏层的ReLU激活函数
                nn.Linear(h, out),  # 隐藏层到输出层的线性变换
                nn.Sigmoid()
            )

        def forward(self, x):
            x=self.layer(x)
            return x


Part5:实例化模型,定义损失函数和优化器

    model = Net(n_in, n_h, n_out)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 学习率为0.01

    # 用于存储每轮的损失值
    losses = []

Part6:执行梯度下降算法进行模型训练

    for epoch in range(5000):  # 迭代50次
        y_pred = model(x)  # 前向传播,计算预测值
        loss = criterion(y_pred, y)  # 计算损失
        losses.append(loss.item())  # 记录损失值
        print(f'Epoch [{epoch+1}/50], Loss: {loss.item():.4f}')  # 打印损失值

        optimizer.zero_grad()  # 清零梯度
        loss.backward()  # 反向传播,计算梯度
        optimizer.step()  # 更新模型参数


Part7:模型训练结束,可视化结果

    # 可视化损失变化曲线
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, 5001), losses, label='Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.legend()
    plt.grid()
    plt.show()

    # 可视化预测结果与实际目标值对比
    y_pred_final = model(x).detach().numpy()  # 最终预测值
    y_actual = y.numpy()  # 实际值

    plt.figure(figsize=(8, 5))
    plt.plot(range(1, batch_size + 1), y_actual, 'o-', label='Actual', color='blue')
    plt.plot(range(1, batch_size + 1), y_pred_final, 'x--', label='Predicted', color='red')
    plt.xlabel('Sample Index')
    plt.ylabel('Value')
    plt.title('Actual vs Predicted Values')
    plt.legend()
    plt.grid()
    plt.show()
posted @ 2025-09-23 17:04  Oaths  阅读(21)  评论(0)    收藏  举报