机器学习-鸢尾花分类

记得先通过命令安装 scikit-learn 环境!

pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple/

先上代码!

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 加载鸢尾花数据集
iris = load_iris()

# 将数据转化为 pandas DataFrame
X = pd.DataFrame(iris.data, columns=iris.feature_names)  # 特征数据
y = pd.Series(iris.target)  # 标签数据

# 显示前五行数据
print(X.head())

# 划分训练集和测试集(80% 训练集,20% 测试集)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 标准化特征
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 创建 KNN 分类器
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

# 预测测试集
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f'模型准确率: {accuracy:.2f}')

# 可视化 - 这里只是一个简单示例,具体可根据实际情况选择绘图方式
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', marker='o')
plt.title("KNN Classification Results")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

一、 代码结构分析 :
- 这是一个完整的机器学习工作流,从数据加载到模型评估
- 使用了经典的鸢尾花数据集(Iris dataset)
- 实现了K最近邻(KNN)分类算法

二、主要功能模块 :

# 数据准备部分
iris = load_iris()  # 加载数据集
X = pd.DataFrame(iris.data, columns=iris.feature_names)  # 特征
y = pd.Series(iris.target)  # 标签

# 数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 拆分
scaler = StandardScaler()  # 标准化
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 模型训练与评估
knn = KNeighborsClassifier(n_neighbors=3)  # KNN模型
knn.fit(X_train, y_train)  # 训练
y_pred = knn.predict(X_test)  # 预测
accuracy = accuracy_score(y_test, y_pred)  # 评估

三、可视化部分 :

当前代码只绘制了前两个特征的散点图,可以进一步优化:

# 已经存在的代码
# 可视化 - 改进建议
plt.figure(figsize=(10, 6))
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', alpha=0.7)
plt.colorbar(label='Class')
plt.title("KNN Classification Results (Sepal Length vs Sepal Width)")
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.show()

改进建议如下 :
- 可以添加交叉验证来更可靠地评估模型性能
- 可以尝试不同的K值并通过网格搜索找到最优参数
- 可以添加混淆矩阵来查看分类细节
- 可以扩展可视化部分,展示更多特征组合

posted @ 2025-04-22 17:17  写个博客玩玩  阅读(61)  评论(0)    收藏  举报