# GBDT (gradient-boosted decision trees) regression example

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import ensemble
from sklearn import datasets
from sklearn.utils import shuffle
from sklearn.metrics import mean_squared_error

###############################################################################
# Load the dataset. Every column except the last is treated as a feature;
# the last column is the regression target.
df = pd.read_csv('/home/caijiao/code/kaijuan/data.csv')

X = df.iloc[:, :-1].to_numpy()  # feature matrix: all columns but the last
y = df.iloc[:, -1].to_numpy()   # target vector: the last column, 1-D

# The first half of the rows (in file order, no shuffling) is the training
# set; the remaining rows form the test set.
split_at = df.shape[0] // 2
X_train, X_test = X[:split_at], X[split_at:]
y_train, y_test = y[:split_at], y[split_at:]
# print X_test
# print y_test
###############################################################################
# Fit regression model
#
# 'loss' is intentionally not set: least squares is the estimator's default,
# and the old 'ls' spelling is rejected by current scikit-learn releases, so
# relying on the default keeps the script working across versions.
params = {
    'n_estimators': 1500,
    'max_depth': 4,
    'min_samples_split': 2,
    'learning_rate': 0.01,
}
clf = ensemble.GradientBoostingRegressor(**params)
clf.fit(X_train, y_train)

# Predict once and reuse the result for both the per-sample listing and the
# mean squared error.
res = clf.predict(X_test)
mse = mean_squared_error(y_test, res)

# One line per test sample: "<actual> <predicted>".
for actual, predicted in zip(y_test, res):
    print(actual, predicted)
print("MSE: %.4f" % mse)
###############################################################################
# Plot training deviance

# Test-set deviance after each boosting stage. The model is fitted with the
# least-squares loss, so the deviance is the mean squared error of the staged
# predictions; computing it directly avoids the estimator's `loss_`
# attribute, which is not available in current scikit-learn releases.
test_score = np.array(
    [mean_squared_error(y_test, y_pred)
     for y_pred in clf.staged_predict(X_test)],
    dtype=np.float64,
)

# 1-based stage numbers for the x-axis, derived from the number of stages
# actually produced by the fitted model.
iterations = np.arange(1, len(test_score) + 1)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title('Deviance')
plt.plot(iterations, clf.train_score_, 'b-',
         label='Training Set Deviance')
plt.plot(iterations, test_score, 'r-',
         label='Test Set Deviance')
plt.legend(loc='upper right')
plt.xlabel('Boosting Iterations')
plt.ylabel('Deviance')
plt.show()
###############################################################################
# Read only the first line of the CSV and split it on commas.
# The context manager guarantees the file handle is closed, and readline()
# replaces the manual counter-and-break loop. An empty file now yields ['']
# instead of raising IndexError.
with open('/home/caijiao/code/kaijuan/data.csv', 'r') as csvfile:
    data = csvfile.readline().strip().split(',')
###############################################################################
# NOTE(review): no new figure is created between the earlier plt.show() and
# this call; presumably a leftover -- confirm before removing.
plt.show()

# posted @ 2019-01-14 16:23  superxiaoying  阅读(268)  评论(0)  编辑  收藏  举报