2024.11.28(周四)

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from sklearn.model_selection import train_test_split

# Step (1) of the exercise: the iris data could instead be read from a local
# CSV file with pandas, for example:
#     iris_data = pd.read_csv('iris.csv')
# That route is left disabled here in favour of step (2).

# Step (2): load the iris dataset bundled with scikit-learn.
iris = load_iris()
X, y = iris.data, iris.target  # feature matrix and class labels


# Manually implemented k-fold cross-validation (five folds by default).
def manual_k_fold_cross_validation(X, y, model, k=5, random_state=None):
    """Evaluate ``model`` with a hand-written k-fold cross-validation loop.

    The sample indices are shuffled once and split into ``k`` contiguous
    folds whose sizes differ by at most one. Every sample therefore lands in
    exactly one test fold, even when the number of samples is not divisible
    by ``k``. For each fold the model is fitted on the other folds and scored
    on the held-out fold.

    Parameters
    ----------
    X : array-like supporting ``X.shape[0]`` and integer-array indexing
        Feature matrix (e.g. a NumPy array).
    y : array-like supporting integer-array indexing
        Labels aligned with the rows of ``X``.
    model : object
        Estimator exposing ``fit(X, y)`` and ``predict(X)``. The same
        instance is refitted on every fold.
    k : int, default 5
        Number of folds; must satisfy ``2 <= k <= n_samples``.
    random_state : int or None, default None
        Seed for the index shuffle. ``None`` shuffles with NumPy's global
        random state, matching the original behaviour.

    Returns
    -------
    tuple of float
        Mean accuracy, macro-averaged precision, macro-averaged recall and
        macro-averaged F1 score over the ``k`` folds.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 2 or larger than the number of samples,
        which would otherwise produce empty training or test folds.
    """
    n_samples = X.shape[0]
    if not 2 <= k <= n_samples:
        raise ValueError(f"k must be between 2 and n_samples={n_samples}, got {k}")

    indices = np.arange(n_samples)
    if random_state is None:
        np.random.shuffle(indices)
    else:
        np.random.default_rng(random_state).shuffle(indices)

    # array_split spreads the n_samples % k leftover samples over the first
    # folds; slicing with a fixed n_samples // k fold size would never test
    # them. With n_samples divisible by k the folds are the same as before.
    folds = np.array_split(indices, k)

    per_fold_scores = []
    for fold, test_indices in enumerate(folds):
        # Training set = every other fold, kept in shuffled order.
        train_indices = np.concatenate(folds[:fold] + folds[fold + 1:])

        X_train, X_test = X[train_indices], X[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]

        # Assumes fit() retrains from scratch (true for scikit-learn
        # estimators), so reusing one model instance across folds is safe.
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        per_fold_scores.append((
            accuracy_score(y_test, y_pred),
            precision_score(y_test, y_pred, average='macro'),
            recall_score(y_test, y_pred, average='macro'),
            f1_score(y_test, y_pred, average='macro'),
        ))

    # Column-wise mean over folds -> (accuracy, precision, recall, f1).
    mean_accuracy, mean_precision, mean_recall, mean_f1 = np.mean(per_fold_scores, axis=0)
    return mean_accuracy, mean_precision, mean_recall, mean_f1


# Random forest classifier; the fixed random_state makes the forest reproducible.
model = RandomForestClassifier(random_state=42)

# The fold shuffle inside manual_k_fold_cross_validation draws from NumPy's
# global random state. Seed it as well, so the reported scores do not change
# from run to run.
np.random.seed(42)

# Run five-fold cross-validation.
accuracy, precision, recall, f1 = manual_k_fold_cross_validation(X, y, model, k=5)

# Report the fold-averaged metrics (the header has no placeholders, so it is
# a plain string rather than an f-string).
print("五折交叉验证结果:")
print(f"准确度 (Accuracy): {accuracy:.4f}")
print(f"精度 (Precision): {precision:.4f}")
print(f"召回率 (Recall): {recall:.4f}")
print(f"F1 值 (F1 Score): {f1:.4f}")

 

posted @ 2024-12-02 16:36  记得关月亮  阅读(2)  评论(0)  编辑  收藏  举报