机器学习树模型第一节

决策树

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import joblib                         # 用于保存和加载模型
import os                             # 仅用于检查文件是否存在(可选)

# ---------- 1. Load the data and attach Chinese labels ----------
# Display names for the four feature columns, and the species names indexed
# by the integer class codes stored in iris.target.
feature_names_cn = ['花萼长度(cm)', '花萼宽度(cm)', '花瓣长度(cm)', '花瓣宽度(cm)']
target_names_cn = ['山鸢尾', '变色鸢尾', '维吉尼亚鸢尾']
iris = load_iris()

# Translate every integer class code into its species name.
species_labels = list(map(target_names_cn.__getitem__, iris.target))
df = pd.DataFrame(data=iris.data, columns=feature_names_cn)
df['最终分类结果'] = species_labels

# Preview the first five rows of the labelled table.
print("========= 1. 数据集前5行预览 ==========")
print(df.head())

# ---------- 2. Hold out 30% of the samples as the test set (70% train) ----------
# A fixed random_state keeps the split reproducible between runs.
split_options = {'test_size': 0.3, 'random_state': 42}
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, **split_options
)

# ---------- 3. Fit a decision tree (Gini criterion, maximum depth 3) ----------
tree_params = {'criterion': 'gini', 'max_depth': 3, 'random_state': 42}
clf = DecisionTreeClassifier(**tree_params).fit(X_train, y_train)  # fit() returns the estimator

# ---------- 4. Predict on the test set and report accuracy as a percentage ----------
y_pred = clf.predict(X_test)
acc = 100 * accuracy_score(y_test, y_pred)
print(f"\n决策树模型的预测准确率为: {acc:.2f}%")

# ---------- 5. Plot the decision tree (Chinese feature and class names) ----------
# Matplotlib's default font has no CJK glyphs, so the Chinese feature names,
# class names and title would be drawn as empty boxes. Put whichever
# CJK-capable fonts are installed at the front of the sans-serif family
# before any text is rendered.
from matplotlib import font_manager

_cjk_font_candidates = [
    'SimHei', 'Microsoft YaHei',                      # common on Windows
    'PingFang SC', 'Heiti TC', 'Arial Unicode MS',    # common on macOS
    'Noto Sans CJK SC', 'WenQuanYi Micro Hei',        # common on Linux
]
_installed_fonts = {font.name for font in font_manager.fontManager.ttflist}
_cjk_fonts = [name for name in _cjk_font_candidates if name in _installed_fonts]
if _cjk_fonts:
    # Keep the previous entries as fallbacks behind the CJK fonts.
    plt.rcParams['font.sans-serif'] = _cjk_fonts + list(plt.rcParams['font.sans-serif'])
    # Many CJK fonts lack the Unicode minus sign (U+2212); use ASCII '-' instead.
    plt.rcParams['axes.unicode_minus'] = False
# If no candidate is installed the plot is still drawn, only with the
# default font (Chinese text may then be unreadable).

plt.figure(figsize=(16, 12))
plot_tree(clf,
          feature_names=feature_names_cn,
          class_names=target_names_cn,
          filled=True,
          rounded=True,
          impurity=False)   # hide the per-node impurity value
plt.title("鸢尾花数据集决策树(最大深度=3,基尼系数)")
plt.show()   # Display the figure (comment this line out in a headless environment)

# ========== Added: export (save) the model ==========
# Persist the fitted classifier to a file in the current working directory.
model_filename = 'iris_tree_model.pkl'
joblib.dump(value=clf, filename=model_filename)
print("\n模型已保存为:" + model_filename)

# ========== Added: load the model and run predictions ==========
# Demonstrates restoring the model from a file (this could equally live in a
# separate script).
# NOTE: joblib.load unpickles the file, which can execute arbitrary code, so
# only load model files that come from a trusted source.
loaded_model = joblib.load(model_filename)
print(f"模型加载成功,类型:{type(loaded_model)}")

# Predict the first test sample with the restored model.
sample = X_test[0].reshape(1, -1)   # predict() needs a 2-D array: one row per sample
true_label = y_test[0]
pred_label = loaded_model.predict(sample)[0]

print("\n===== 使用加载的模型进行单样本预测 =====")
print(f"测试样本特征:{X_test[0]}")
print(f"真实类别:{target_names_cn[true_label]} (编码{true_label})")
print(f"预测类别:{target_names_cn[pred_label]} (编码{pred_label})")

# Re-evaluate the whole test set with the restored model and compare its
# predictions with those of the in-memory model (y_pred), so the consistency
# message reflects an actual check rather than an assumption.
y_pred_loaded = loaded_model.predict(X_test)
acc_loaded = accuracy_score(y_test, y_pred_loaded) * 100
if (y_pred_loaded == y_pred).all():
    consistency_note = "与之前一致"
else:
    consistency_note = "与之前不一致,请检查模型文件"
print(f"\n加载模型在测试集上的准确率:{acc_loaded:.2f}% ({consistency_note})")
posted @ 2026-06-24 22:20  cerofang  阅读(8)  评论(0)    收藏  举报