机器学习-鸢尾花分类
记得先通过命令安装 scikit-learn 环境!
pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple/
先上代码!
import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score # 加载鸢尾花数据集 iris = load_iris() # 将数据转化为 pandas DataFrame X = pd.DataFrame(iris.data, columns=iris.feature_names) # 特征数据 y = pd.Series(iris.target) # 标签数据 # 显示前五行数据 print(X.head()) # 划分训练集和测试集(80% 训练集,20% 测试集) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 标准化特征 scaler = StandardScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) # 创建 KNN 分类器 knn = KNeighborsClassifier(n_neighbors=3) # 训练模型 knn.fit(X_train, y_train) # 预测测试集 y_pred = knn.predict(X_test) # 计算准确率 accuracy = accuracy_score(y_test, y_pred) print(f'模型准确率: {accuracy:.2f}') # 可视化 - 这里只是一个简单示例,具体可根据实际情况选择绘图方式 plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', marker='o') plt.title("KNN Classification Results") plt.xlabel("Feature 1") plt.ylabel("Feature 2") plt.show()
一、 代码结构分析 :
- 这是一个完整的机器学习工作流,从数据加载到模型评估
- 使用了经典的鸢尾花数据集(Iris dataset)
- 实现了K最近邻(KNN)分类算法
二、主要功能模块 :
# 数据准备部分 iris = load_iris() # 加载数据集 X = pd.DataFrame(iris.data, columns=iris.feature_names) # 特征 y = pd.Series(iris.target) # 标签 # 数据预处理 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) # 拆分 scaler = StandardScaler() # 标准化 X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) # 模型训练与评估 knn = KNeighborsClassifier(n_neighbors=3) # KNN模型 knn.fit(X_train, y_train) # 训练 y_pred = knn.predict(X_test) # 预测 accuracy = accuracy_score(y_test, y_pred) # 评估
三、可视化部分 :
当前代码只绘制了前两个特征的散点图,可以进一步优化:
# 已经存在的代码 # 可视化 - 改进建议 plt.figure(figsize=(10, 6)) plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', alpha=0.7) plt.colorbar(label='Class') plt.title("KNN Classification Results (Sepal Length vs Sepal Width)") plt.xlabel(iris.feature_names[0]) plt.ylabel(iris.feature_names[1]) plt.show()
改进建议如下 :
- 可以添加交叉验证来更可靠地评估模型性能
- 可以尝试不同的K值并通过网格搜索找到最优参数
- 可以添加混淆矩阵来查看分类细节
- 可以扩展可视化部分,展示更多特征组合

浙公网安备 33010602011771号