PyTorch 深度学习实践 第5讲:用PyTorch实现线性回归
1. prepare dataset
import torch
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])
#注意:x,y是3*1的张量(矩阵)
2. Design model using class (inherit from nn.Module)
class LinearModel(torch.nn.Module):#继承自nn.Moudle
def __init__(self):#构造函数:初始化对象
super (LinearModel,self).__init__()#just do it
self.linear = torch.nn.Linear(1,1)#构造对象,Class nn.Linear 包含两个成员Tensor:w,b
#并说明输入输出的维数(特征数量)
def forward(self,x):
y_pred = self.linear(x)#调用对象,计算y=x*w+b
return y_pred
model = LinearModel()#实例化模型
3.construct loss and optimizer(using pytorch API)
criterion = torch.nn.MSELoss(size_average = False)
#计算mse
optimizer = torch.optim.SGD(model.parameters(),lr = 0.01)
#lr为学习率
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
4.Training cycle(forward ,backward, updata)
for epoch in range(1000):
y_pred = model(x_data)#1.forwar:predict
loss = criterion(y_pred,y_data)#2.forward:Loss
print(epoch,loss.item())
optimizer.zero_grad()#注意;backward的梯度也会被计算出,因此,在backward 的之前,记住让grad设置为0!!!
loss.backward()#3.自动计算backward
optimizer.step()#4.更新w,类似于04节手动更新w
print ('w=', model.linear.weight.item())
print ('b=',model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)
0 55.598304748535156
1 24.751205444335938
2 11.018953323364258
3 4.905735969543457
4 2.184299945831299
5 0.9727876782417297
6 0.4334515333175659
7 0.1933487355709076
8 0.08645600825548172
9 0.038864899426698685
10 0.01767338253557682
......
983 5.315996531862766e-10
984 5.209130904404446e-10
985 5.116476131661329e-10
986 5.014157977711875e-10
987 5.024958227295429e-10
988 4.832259037357289e-10
989 4.832259037357289e-10
990 4.743014869745821e-10
991 4.743014869745821e-10
992 4.5571368900709786e-10
993 4.5571368900709786e-10
994 4.376943252282217e-10
995 4.3456793719087727e-10
996 4.3456793719087727e-10
997 4.260982677806169e-10
998 4.1677594708744437e-10
999 4.177422852080781e-10
w= 1.9999865293502808
b= 3.079011366935447e-05
y_pred= tensor([[8.0000]])
补充:
在类的声明的时候定义一个call()函数:要使用一个可调用对象
class Foobar:
def __init__(self):
pass
def __call__(self,*args,**kwargs):
pass
参数args代表前面n个参数变成n元组,*kwargs会把参数变成一个字典,ex:
def func(*args,**kwargs):
print(args)
print(kwargs)
func(1,2,3,4,x=3,y=4)
结果:
(1, 2, 3, 4)
{'x': 3, 'y': 4}
参考