PyTorch深度学习:非线性模型

1. 这篇博客使用深度学习框架搭建了一个预测三次函数的模型

2. 归一化很重要:一定要对数据做归一化(normalize),否则神经网络的效果会非常差

 1 import torch
 2 from torch import nn,optim
 3 import torch.nn.functional as F
 4 from matplotlib import pyplot as plt
 5 
 6 class unLinear(nn.Module):
 7     def __init__(self,input_feature,num_hidden,output_size):
 8         super(unLinear,self).__init__()
 9         self.hidden=nn.Linear(input_feature,num_hidden)#一个层就是一个函数
10         self.out=nn.Linear(num_hidden,output_size)#可以把层理解成函数的右值引用
11 
12     def forward(self,x):
13         # x=F.relu(self.hidden(x))
14         # x = torch.sigmoid(self.hidden(x))
15         x=torch.tanh(self.hidden(x))
16         x=self.out(x)
17         return x
18 
19     def train(self,inputs,target,criterion,optimizer,epoches):
20         print(inputs.size())
21         print(target.size())
22         loss=0
23         for epoch in range(epoches):
24             output = model.forward(inputs)
25             # if epoch%1000==0:
26             #     plt.scatter(inputs.detach().numpy(), output.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
27             #     plt.show()
28             loss = criterion(output, target)
29             optimizer.zero_grad()
30             loss.backward()
31             optimizer.step()
32         return self, loss
33 
# Build the network and a noisy cubic data set sampled on [-2, 2).
model = unLinear(input_feature=1, num_hidden=20, output_size=1)
x = torch.arange(-2, 2, 0.1)
y = x.pow(3) + 0.1 * torch.rand(x.size())  # uniform noise in [0, 0.1)

# Show the raw training samples.
plt.scatter(x.numpy(), y.numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
plt.show()

# nn.Linear works on the last dimension, so reshape (N,) -> (N, 1).
inputs = torch.unsqueeze(x, dim=1)
target = torch.unsqueeze(y, dim=1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)

# train() returns a (model, last_loss) pair; unpack it so each name holds
# what it says.
new_model, final_loss = model.train(
    inputs=inputs,
    target=target,
    criterion=criterion,
    optimizer=optimizer,
    epoches=10000,
)

# Evaluate on a denser grid than the training data; no gradients are needed
# for plotting, so skip building the autograd graph.
x_predict = torch.unsqueeze(torch.arange(-2, 2, 0.05), dim=1)
with torch.no_grad():
    y_predict = model(x_predict)

# Drop the feature dimension again and convert to NumPy for matplotlib.
x_predict = torch.squeeze(x_predict).numpy()
y_predict = torch.squeeze(y_predict).numpy()
plt.scatter(x_predict, y_predict, s=10, alpha=0.8, label="test")
plt.show()

 

posted @ 2020-09-19 15:04  Lovaer  阅读(1092)  评论(0)  编辑  收藏  举报