1
2 # -*- coding: UTF-8 -*-
3 """
4 此脚本用于展示使用sklearn搭建线性回归模型
5 """
6
7
8 import os
9 import sys
10
11 import numpy as np
12 import matplotlib.pyplot as plt
13 import pandas as pd
14 from sklearn import linear_model
15
16
17 def evaluateModel(model, testData, features, labels):
18 """
19 计算线性模型的均方差和决定系数
20 参数
21 ----
22 model : LinearRegression, 训练完成的线性模型
23 testData : DataFrame,测试数据
24 features : list[str],特征名列表
25 labels : list[str],标签名列表
26 返回
27 ----
28 error : np.float64,均方差
29 score : np.float64,决定系数
30 """
31 # 均方差(The mean squared error),均方差越小越好
32 error = np.mean(
33 (model.predict(testData[features]) - testData[labels]) ** 2)
34 # 决定系数(Coefficient of determination),决定系数越接近1越好
35 score = model.score(testData[features], testData[labels])
36 return error, score
37
38
39 def visualizeModel(model, data, features, labels, error, score):
40 """
41 模型可视化
42 """
43 # 为在Matplotlib中显示中文,设置特殊字体
44 plt.rcParams['font.sans-serif']=['SimHei']
45 # 创建一个图形框
46 fig = plt.figure(figsize=(6, 6), dpi=80)
47 # 在图形框里只画一幅图
48 ax = fig.add_subplot(111)
49 # 在Matplotlib中显示中文,需要使用unicode
50 # 在Python3中,str不需要decode
51 if sys.version_info[0] == 3:
52 ax.set_title(u'%s' % "线性回归示例")
53 else:
54 ax.set_title(u'%s' % "线性回归示例".decode("utf-8"))
55 ax.set_xlabel('$x$')
56 ax.set_ylabel('$y$')
57 # 画点图,用蓝色圆点表示原始数据
58 # 在Python3中,str不需要decode
59 if sys.version_info[0] == 3:
60 ax.scatter(data[features], data[labels], color='b',
61 label=u'%s: $y = x + \epsilon$' % "真实值")
62 else:
63 ax.scatter(data[features], data[labels], color='b',
64 label=u'%s: $y = x + \epsilon$' % "真实值".decode("utf-8"))
65 # 根据截距的正负,打印不同的标签
66 if model.intercept_ > 0:
67 # 画线图,用红色线条表示模型结果
68 # 在Python3中,str不需要decode
69 if sys.version_info[0] == 3:
70 ax.plot(data[features], model.predict(data[features]), color='r',
71 label=u'%s: $y = %.3fx$ + %.3f'\
72 % ("预测值", model.coef_, model.intercept_))
73 else:
74 ax.plot(data[features], model.predict(data[features]), color='r',
75 label=u'%s: $y = %.3fx$ + %.3f'\
76 % ("预测值".decode("utf-8"), model.coef_, model.intercept_))
77 else:
78 # 在Python3中,str不需要decode
79 if sys.version_info[0] == 3:
80 ax.plot(data[features], model.predict(data[features]), color='r',
81 label=u'%s: $y = %.3fx$ - %.3f'\
82 % ("预测值", model.coef_, abs(model.intercept_)))
83 else:
84 ax.plot(data[features], model.predict(data[features]), color='r',
85 label=u'%s: $y = %.3fx$ - %.3f'\
86 % ("预测值".decode("utf-8"), model.coef_, abs(model.intercept_)))
87 legend = plt.legend(shadow=True)
88 legend.get_frame().set_facecolor('#6F93AE')
89 # 显示均方差和决定系数
90 # 在Python3中,str不需要decode
91 if sys.version_info[0] == 3:
92 ax.text(0.99, 0.01,
93 u'%s%.3f\n%s%.3f'\
94 % ("均方差:", error, "决定系数:", score),
95 style='italic', verticalalignment='bottom', horizontalalignment='right',
96 transform=ax.transAxes, color='m', fontsize=13)
97 else:
98 ax.text(0.99, 0.01,
99 u'%s%.3f\n%s%.3f'\
100 % ("均方差:".decode("utf-8"), error, "决定系数:".decode("utf-8"), score),
101 style='italic', verticalalignment='bottom', horizontalalignment='right',
102 transform=ax.transAxes, color='m', fontsize=13)
103 # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
104 # 在Python shell里面,可以设置参数"block=False",使阻断失效。
105 plt.show()
106
107
108 def trainModel(trainData, features, labels):
109 """
110 利用训练数据,估计模型参数
111 参数
112 ----
113 trainData : DataFrame,训练数据集,包含特征和标签
114 features : 特征名列表
115 labels : 标签名列表
116 返回
117 ----
118 model : LinearRegression, 训练好的线性模型
119 """
120 # 创建一个线性回归模型
121 model = linear_model.LinearRegression()
122 # 训练模型,估计模型参数
123 model.fit(trainData[features], trainData[labels])
124 return model
125
126
127 def linearModel(data):
128 """
129 线性回归模型建模步骤展示
130 参数
131 ----
132 data : DataFrame,建模数据
133 """
134 features = ["x"]
135 labels = ["y"]
136 # 划分训练集和测试集
137 trainData = data[:15]
138 testData = data[15:]
139 # 产生并训练模型
140 model = trainModel(trainData, features, labels)
141 # 评价模型效果
142 error, score = evaluateModel(model, testData, features, labels)
143 # 图形化模型结果
144 visualizeModel(model, data, features, labels, error, score)
145
146
147 def readData(path):
148 """
149 使用pandas读取数据
150 """
151 data = pd.read_csv(path)
152 return data
153
154
155 if __name__ == "__main__": #主模块的名字是__main__,import的模块名字是自己
156 homePath = os.path.dirname(os.path.abspath(__file__)) #os.path.dirname 是去掉文件名的路径 ,abspath获取当前文件路径
157 # Windows下的存储路径与Linux并不相同
158 if os.name == "nt": #判断当前使用的平台,nt为windows
159 dataPath = "%s\\data\\simple_example.csv" % homePath
160 else:
161 dataPath = "%s/data/simple_example.csv" % homePath
162 data = readData(dataPath)
163 linearModel(data)
164 © 2019 GitHub, Inc.
165 Terms
166 Privacy
167 Security
168 Status
169 Help
170 Contact GitHub
171 Pricing
172 API
173 Training
174 Blog
175 About