Loading

Pytorch学习笔记 3 实现简单的神经网络

无API实现神经网络

假装一个数据集

a = torch.Tensor([[1,2,3],[4,5,6]])
b = torch.Tensor([[1],[2],[3]])

设置参数

w1 = nn.Parameter(torch.randn(2, 4)*0.01)
b1 = nn.Parameter(torch.zeros(4))
w2 = nn.Parameter(torch.randn(4,1)*0.01)
b2 = nn.Parameter(torch.zeros(1))

设置模型

def two_network(x):
    x1 = torch.mm(x, w1) + b1
    x1 = F.tanh(x1)
    x2 = torch.mm(x1, w2) + b2
    return x2

设置优化器和参数

optimizer = torch.optim.SGD([w1, w2, b1, b2], 0.01)    #定义优化器
criterion = nn.MSELoss()     #定义损失函数
def two_network(x):
    x1 = torch.mm(x, w1) + b1
    x1 = nn.functional.tanh(x1)
    x2 = torch.mm(x1, w2) + b2
    return x2
for i in range(10000):
    out = two_network(Variable(a))
    loss = criterion(out, Variable(b))    #注意这里:这里
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i+1)%1000 == 0:
        print('epoch:{}, loss:{}]'.format(i+1,loss.item()))

网络的API实现

nn.Sequential可以定义序列化的模型,上面的可以由这种进行定义

seq_net = nn.Sequential(
    nn.Linear(3, 4),  #就是wx+b的那部分
    nn.Tanh(),
    nn.Linear(4, 1)
)

此时训练代码变成这样

optimizer = torch.optim.SGD(seq_net.parameters(), 0.01)    #定义优化器
criterion = nn.MSELoss()      #定义损失函数
for i in range(10000):
    out = seq_net(Variable(a))
    loss = criterion(out, Variable(b))    #注意这里:这里
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i+1)%1000 == 0:
        print('epoch:{}, loss:{}]'.format(i+1,loss.item()))

或者Module的使用

class module_net(nn.Module):
    def __init__(self):
        super(module_net, self).__init__()
        self.layer1 = nn.Linear(3,4)
        self.layer2 = nn.Tanh()
        self.layer3 = nn.Linear(4,1)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
model = module_net()    #实例化
optimizer = torch.optim.SGD(model.parameters(), 0.01)    #定义优化器
criterion = nn.MSELoss()      #定义损失函数
for i in range(10000):
    out = model(Variable(a))
    loss = criterion(out, Variable(b))    #注意这里:这里
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i+1)%1000 == 0:
        print('epoch:{}, loss:{}]'.format(i+1,loss.item()))
posted @ 2021-09-20 00:46  笑云博文  阅读(65)  评论(0)    收藏  举报