PyTorch Basics: Linear Regression
Linear regression may look simple, but it is one of the most important mathematical models, and many other models are built on top of it.
The linear regression equation is:
y = Ax + B
A = slope of the line
B = bias (the point where the line intersects the y-axis)
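For example (numbers chosen only for illustration), with A = -0.5 and B = 9, a car price of x = 5 gives y = -0.5 * 5 + 9 = 6.5.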
In this example, we use a small set of car prices and sales figures to study the model.
Step 1: Create the data and convert it to tensors
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
# import Variable from the pytorch library
from torch.autograd import Variable
import torch

# As a car company we collected this data from previous sales
# let's define car prices
car_prices_array = [3, 4, 5, 6, 7, 8, 9]
car_price_np = np.array(car_prices_array, dtype=np.float32)
car_price_np = car_price_np.reshape(-1, 1)
car_price_tensor = Variable(torch.from_numpy(car_price_np))

# let's define the number of cars sold
number_of_car_sell_array = [7.5, 7, 6.5, 6.0, 5.5, 5.0, 4.5]
number_of_car_sell_np = np.array(number_of_car_sell_array, dtype=np.float32)
number_of_car_sell_np = number_of_car_sell_np.reshape(-1, 1)
number_of_car_sell_tensor = Variable(torch.from_numpy(number_of_car_sell_np))
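As a quick sanity check (not part of the original notebook), the seven points above lie exactly on a straight line, so we already know which slope A and bias B the model should learn; a minimal numpy sketch:

# fit A (slope) and B (bias) directly with ordinary least squares
A, B = np.polyfit(car_prices_array, number_of_car_sell_array, deg=1)
print(A, B)  # approximately -0.5 and 9.0: every point satisfies y = -0.5x + 9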
Step 2: Visualize the data (with Matplotlib)
# lets visualize our data
import matplotlib.pyplot as plt
plt.scatter(car_prices_array,number_of_car_sell_array)
plt.xlabel("Car Price $")
plt.ylabel("Number of Car Sell")
plt.title("Car Price$ VS Number of Car Sell")
plt.show()
Result: a scatter plot of car price versus number of cars sold.
Step 3: Build the LinearRegression class
# Linear Regression with Pytorch
# libraries
import torch
from torch.autograd import Variable
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
# create class
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        # super function. It inherits from nn.Module so we can access everything in nn.Module
        super(LinearRegression, self).__init__()
        # linear function
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)
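Before training, a quick forward pass with a dummy input is an easy way to confirm the class works (a minimal sketch, not part of the original notebook; the variable names are only for illustration):

# quick shape check: one forward pass through an untrained model
test_model = LinearRegression(1, 1)
dummy_input = torch.ones(7, 1)        # same shape as car_price_tensor
print(test_model(dummy_input).shape)  # expected: torch.Size([7, 1])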
Step 4: Define the LinearRegression model, with MSE as the loss function and SGD as the optimizer
# define model
input_dim = 1
output_dim = 1
model = LinearRegression(input_dim,output_dim) # input and output size are 1
# MSE
mse = nn.MSELoss()
# Optimization (find parameters that minimize error)
learning_rate = 0.02 # how fast we reach best parameters
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate)
# train model
loss_list = []
iteration_number = 1001
for iteration in range(iteration_number):
    # clear gradients from the previous iteration
    optimizer.zero_grad()
    # forward pass to get the output
    results = model(car_price_tensor)
    # calculate loss
    loss = mse(results, number_of_car_sell_tensor)
    # backward propagation
    loss.backward()
    # update parameters
    optimizer.step()
    # store loss
    loss_list.append(loss.data)
    # print loss every 50 iterations
    if(iteration % 50 == 0):
        print('epoch {}, loss {}'.format(iteration, loss.data))
plt.plot(range(iteration_number),loss_list)
plt.xlabel("Number of Iterations")
plt.ylabel("Loss")
plt.show()
Result:
epoch 0, loss 47.31208801269531
epoch 50, loss 4.446558952331543
epoch 100, loss 3.004725456237793
epoch 150, loss 2.0304176807403564
epoch 200, loss 1.372040033340454
epoch 250, loss 0.9271451830863953
epoch 300, loss 0.6265109777450562
epoch 350, loss 0.42335933446884155
epoch 400, loss 0.28608179092407227
epoch 450, loss 0.19331695139408112
epoch 500, loss 0.13063208758831024
epoch 550, loss 0.0882737785577774
epoch 600, loss 0.05965011566877365
epoch 650, loss 0.040308013558387756
epoch 700, loss 0.02723775990307331
epoch 750, loss 0.018405580893158913
epoch 800, loss 0.01243758574128151
epoch 850, loss 0.008404688909649849
epoch 900, loss 0.005679319612681866
epoch 950, loss 0.0038377337623387575
epoch 1000, loss 0.0025931724812835455
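After training, the learned weight and bias can be read off and compared with the slope A and bias B in the formula above (a minimal sketch; it simply accesses the self.linear layer defined in Step 3):

# inspect the learned slope (A) and bias (B)
print(model.linear.weight.data)  # should be close to -0.5
print(model.linear.bias.data)    # should be close to 9.0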
Step 5: Model predictions
# predict the number of cars sold for our training prices
predicted = model(car_price_tensor).data.numpy()
plt.scatter(car_prices_array,number_of_car_sell_array,label = "original data",color ="red")
plt.scatter(car_prices_array,predicted,label = "predicted data",color ="blue")
# predict if car price is 10$, what will be the number of car sell
#predicted_10 = model(torch.from_numpy(np.array([10]))).data.numpy()
#plt.scatter(10,predicted_10.data,label = "car price 10$",color ="green")
plt.legend()
plt.xlabel("Car Price $")
plt.ylabel("Number of Car Sell")
plt.title("Original vs Predicted values")
plt.show()
Result: a scatter plot comparing the original and predicted values.
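The commented-out lines in the code above hint at predicting sales for a price of 10; a minimal runnable version of that idea (assuming the trained model above, with the input cast to float32 and reshaped to match the training data):

# predict the number of cars sold when the price is 10
new_price = np.array([[10]], dtype=np.float32)
predicted_10 = model(torch.from_numpy(new_price)).data.numpy()
print(predicted_10)  # roughly 4.0 for this data (y = -0.5 * 10 + 9)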
This article is based on https://www.kaggle.com/kanncaa1/pytorch-tutorial-for-deep-learning-lovers, and the environment was provided by MatPool (矩池云). Interested readers can check out the author's original notebook.