Eric的新家

导航

scikit-learn实现简单的决策树

#encoding=utf-8
import numpy as np
import pandas as pd

def main(output_path="./python_source/tree.doc"):
    """Train a decision tree on the iris dataset, print metrics, export the tree.

    Steps:
      1. Load the iris dataset and print it together with its sample count.
      2. Split it 80/20 into train/test sets (fixed ``random_state=1``).
      3. Fit an entropy-criterion ``DecisionTreeClassifier`` and predict
         the test set.
      4. Print the accuracy score and the confusion matrix.
      5. Write the fitted tree to ``output_path`` via ``export_graphviz``.

    Args:
        output_path: File that receives the exported tree. The default keeps
            the original hard-coded location; change it to suit your own
            directory layout. Note that ``export_graphviz`` produces
            Graphviz DOT text, so a ``.dot`` extension is the conventional
            choice even though the default keeps the original ``.doc`` name.

    Returns:
        None. Results are printed to stdout and written to ``output_path``.
    """
    import os

    # Pre-processing
    from sklearn.datasets import load_iris
    iris = load_iris()
    print(iris)
    print(len(iris["data"]))
    # train_test_split is imported from sklearn.model_selection; the script
    # previously referenced sklearn.cross_validation for the same helper.
    from sklearn.model_selection import train_test_split
    train_data, test_data, train_target, test_target = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=1
    )

    # Model
    from sklearn import tree
    clf = tree.DecisionTreeClassifier(criterion="entropy")
    clf.fit(train_data, train_target)
    y_pred = clf.predict(test_data)

    # Verify
    from sklearn import metrics
    # Accuracy score: the fraction of all test samples classified correctly.
    print(metrics.accuracy_score(y_true=test_target, y_pred=y_pred))
    # Confusion matrix of true vs. predicted classes.
    print(metrics.confusion_matrix(y_true=test_target, y_pred=y_pred))

    # Export: make sure the target directory exists first, otherwise open()
    # fails with FileNotFoundError on a fresh checkout.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w") as fw:
        tree.export_graphviz(clf, out_file=fw)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

posted on 2017-12-04 15:46  Eric的新家  阅读(367)  评论(0)  编辑  收藏  举报