线性回归

  1   
  2 # -*- coding: UTF-8 -*-
  3 """
  4 此脚本用于展示使用sklearn搭建线性回归模型
  5 """
  6 
  7 
  8 import os
  9 import sys
 10 
 11 import numpy as np
 12 import matplotlib.pyplot as plt
 13 import pandas as pd
 14 from sklearn import linear_model
 15 
 16 
 17 def evaluateModel(model, testData, features, labels):
 18     """
 19     计算线性模型的均方差和决定系数
 20     参数
 21     ----
 22     model : LinearRegression, 训练完成的线性模型
 23     testData : DataFrame,测试数据
 24     features : list[str],特征名列表
 25     labels : list[str],标签名列表
 26     返回
 27     ----
 28     error : np.float64,均方差
 29     score : np.float64,决定系数
 30     """
 31     # 均方差(The mean squared error),均方差越小越好
 32     error = np.mean(
 33         (model.predict(testData[features]) - testData[labels]) ** 2)
 34     # 决定系数(Coefficient of determination),决定系数越接近1越好
 35     score = model.score(testData[features], testData[labels])
 36     return error, score
 37 
 38 
 39 def visualizeModel(model, data, features, labels, error, score):
 40     """
 41     模型可视化
 42     """
 43     # 为在Matplotlib中显示中文,设置特殊字体
 44     plt.rcParams['font.sans-serif']=['SimHei']
 45     # 创建一个图形框
 46     fig = plt.figure(figsize=(6, 6), dpi=80)
 47     # 在图形框里只画一幅图
 48     ax = fig.add_subplot(111)
 49     # 在Matplotlib中显示中文,需要使用unicode
 50     # 在Python3中,str不需要decode
 51     if sys.version_info[0] == 3:
 52         ax.set_title(u'%s' % "线性回归示例")
 53     else:
 54         ax.set_title(u'%s' % "线性回归示例".decode("utf-8"))
 55     ax.set_xlabel('$x$')
 56     ax.set_ylabel('$y$')
 57     # 画点图,用蓝色圆点表示原始数据
 58     # 在Python3中,str不需要decode
 59     if sys.version_info[0] == 3:
 60         ax.scatter(data[features], data[labels], color='b',
 61             label=u'%s: $y = x + \epsilon$' % "真实值")
 62     else:
 63         ax.scatter(data[features], data[labels], color='b',
 64             label=u'%s: $y = x + \epsilon$' % "真实值".decode("utf-8"))
 65     # 根据截距的正负,打印不同的标签
 66     if model.intercept_ > 0:
 67         # 画线图,用红色线条表示模型结果
 68         # 在Python3中,str不需要decode
 69         if sys.version_info[0] == 3:
 70             ax.plot(data[features], model.predict(data[features]), color='r',
 71                 label=u'%s: $y = %.3fx$ + %.3f'\
 72                 % ("预测值", model.coef_, model.intercept_))
 73         else:
 74             ax.plot(data[features], model.predict(data[features]), color='r',
 75                 label=u'%s: $y = %.3fx$ + %.3f'\
 76                 % ("预测值".decode("utf-8"), model.coef_, model.intercept_))
 77     else:
 78         # 在Python3中,str不需要decode
 79         if sys.version_info[0] == 3:
 80             ax.plot(data[features], model.predict(data[features]), color='r',
 81                 label=u'%s: $y = %.3fx$ - %.3f'\
 82                 % ("预测值", model.coef_, abs(model.intercept_)))
 83         else:
 84             ax.plot(data[features], model.predict(data[features]), color='r',
 85                 label=u'%s: $y = %.3fx$ - %.3f'\
 86                 % ("预测值".decode("utf-8"), model.coef_, abs(model.intercept_)))
 87     legend = plt.legend(shadow=True)
 88     legend.get_frame().set_facecolor('#6F93AE')
 89     # 显示均方差和决定系数
 90     # 在Python3中,str不需要decode
 91     if sys.version_info[0] == 3:
 92         ax.text(0.99, 0.01, 
 93             u'%s%.3f\n%s%.3f'\
 94             % ("均方差:", error, "决定系数:", score),
 95             style='italic', verticalalignment='bottom', horizontalalignment='right',
 96             transform=ax.transAxes, color='m', fontsize=13)
 97     else:
 98          ax.text(0.99, 0.01, 
 99             u'%s%.3f\n%s%.3f'\
100             % ("均方差:".decode("utf-8"), error, "决定系数:".decode("utf-8"), score),
101             style='italic', verticalalignment='bottom', horizontalalignment='right',
102             transform=ax.transAxes, color='m', fontsize=13)
103     # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
104     # 在Python shell里面,可以设置参数"block=False",使阻断失效。
105     plt.show()
106 
107 
108 def trainModel(trainData, features, labels):
109     """
110     利用训练数据,估计模型参数
111     参数
112     ----
113     trainData : DataFrame,训练数据集,包含特征和标签
114     features : 特征名列表
115     labels : 标签名列表
116     返回
117     ----
118     model : LinearRegression, 训练好的线性模型
119     """
120     # 创建一个线性回归模型
121     model = linear_model.LinearRegression()
122     # 训练模型,估计模型参数
123     model.fit(trainData[features], trainData[labels])
124     return model
125 
126 
127 def linearModel(data):
128     """
129     线性回归模型建模步骤展示
130     参数
131     ----
132     data : DataFrame,建模数据
133     """
134     features = ["x"]
135     labels = ["y"]
136     # 划分训练集和测试集
137     trainData = data[:15]
138     testData = data[15:]
139     # 产生并训练模型
140     model = trainModel(trainData, features, labels)
141     # 评价模型效果
142     error, score = evaluateModel(model, testData, features, labels)
143     # 图形化模型结果
144     visualizeModel(model, data, features, labels, error, score)
145 
146 
147 def readData(path):
148     """
149     使用pandas读取数据
150     """
151     data = pd.read_csv(path)
152     return data
153 
154 
155 if __name__ == "__main__":    #主模块的名字是__main__,import的模块名字是自己
156     homePath = os.path.dirname(os.path.abspath(__file__))  #os.path.dirname 是去掉文件名的路径 ,abspath获取当前文件路径
157     # Windows下的存储路径与Linux并不相同
158     if os.name == "nt":   #判断当前使用的平台,nt为windows
159         dataPath = "%s\\data\\simple_example.csv" % homePath
160     else:
161         dataPath = "%s/data/simple_example.csv" % homePath
162     data = readData(dataPath)
163     linearModel(data)
164 © 2019 GitHub, Inc.
165 Terms
166 Privacy
167 Security
168 Status
169 Help
170 Contact GitHub
171 Pricing
172 API
173 Training
174 Blog
175 About

 

posted @ 2019-04-29 23:28  bbgoal  阅读(197)  评论(0编辑  收藏  举报