1 #!/usr/bin/python
2 # coding=utf-8
3 from sklearn.datasets import load_iris
4 from sklearn.model_selection import train_test_split
5 from sklearn.tree import DecisionTreeClassifier, export_graphviz
def dectree_demo(random_state=None, out_file="iris_tree.dot"):
    """Classify the iris dataset with an entropy-based decision tree.

    Loads the iris dataset, splits it into training and test sets, fits a
    ``DecisionTreeClassifier(criterion="entropy")`` on the training set,
    prints the test-set predictions, an element-wise comparison with the
    true labels and the accuracy, then exports the fitted tree in Graphviz
    DOT format.

    Args:
        random_state: Seed forwarded to ``train_test_split`` and to the
            classifier so a run can be reproduced. The default ``None``
            keeps the original unseeded behaviour.
        out_file: Path of the DOT file the fitted tree is written to.

    Returns:
        None. Results are reported via ``print`` and the DOT file.
    """
    # Load the data.
    iris = load_iris()

    # Split into training and test sets.
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=random_state)

    # Decision-tree estimator using information gain (entropy) for splits.
    estimator = DecisionTreeClassifier(criterion="entropy",
                                       random_state=random_state)
    estimator.fit(x_train, y_train)

    # Model evaluation.
    # Method 1: compare predictions with the true labels element-wise.
    # Single-argument print(...) behaves the same on Python 2 and 3; the
    # label and value are joined with .format so no extra space is emitted.
    y_predict = estimator.predict(x_test)
    print("y_predict:\n{}".format(y_predict))
    print("对比真实值和预测值:\n{}".format(y_test == y_predict))

    # Method 2: compute the accuracy on the test set.
    score = estimator.score(x_test, y_test)
    print("准确率:\n{}".format(score))

    # Visualise the decision tree. Class names label each node's majority
    # class in the DOT output.
    export_graphviz(
        estimator,
        out_file=out_file,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
    )
31
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    dectree_demo()