毒蘑菇分类预测

使用3种模型（随机森林、梯度提升树、支持向量机）实现毒蘑菇分类：

"""
毒蘑菇分类预测项目
整合expanded.txt和agaricus-lepiota.data.txt数据文件
实现多个机器学习算法进行分类预测
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import warnings
# Silence every warning category for the whole process. Note that this also
# hides library deprecation notices (e.g. pandas / scikit-learn FutureWarning).
warnings.simplefilter('ignore')

# Configure matplotlib so the Chinese titles/labels render with the SimHei font,
# and draw minus signs as ASCII '-' (presumably because SimHei lacks the Unicode
# minus glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

class MushroomClassifier:
    """End-to-end mushroom edibility classification pipeline.

    Loads and merges the two mushroom data files, handles missing values,
    label-encodes the categorical attributes, plots a feature-correlation
    heatmap, then trains and evaluates a random forest, a gradient boosting
    model and an RBF-kernel SVM.
    """

    def __init__(self):
        self.data = None            # merged DataFrame; mutated by the preprocessing steps
        self.X = None               # feature DataFrame (every column except 'class')
        self.y = None               # encoded target Series
        self.X_train = None         # standardized training features (ndarray after prepare_data)
        self.X_test = None          # standardized test features (ndarray after prepare_data)
        self.y_train = None
        self.y_test = None
        self.label_encoders = {}    # column name -> fitted LabelEncoder (includes 'class')
        self.scaler = StandardScaler()

    def load_and_integrate_data(self):
        """Load both data files and concatenate them into ``self.data``.

        Returns:
            True on success; False if reading either file fails (the error is printed).
        """
        print("正在加载和整合数据文件...")

        # Attribute names in file order (per agaricus-lepiota.names.txt)
        column_names = [
            'class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
            'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
            'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
            'stalk-surface-below-ring', 'stalk-color-above-ring',
            'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',
            'ring-type', 'spore-print-color', 'population', 'habitat'
        ]

        try:
            original_data = pd.read_csv('data/agaricus-lepiota.data.txt', header=None, names=column_names)
            print(f"原始数据文件加载成功,共{len(original_data)}条记录")

            expanded_data = pd.read_csv('data/expanded.txt', header=None, names=column_names)
            print(f"扩展数据文件加载成功,共{len(expanded_data)}条记录")

            # Normalise the class labels (EDIBLE/POISONOUS -> e/p). Strip and upper-case
            # first so stray whitespace or casing does not turn a valid label into NaN.
            # Labels that still do not map become NaN and are dropped (not imputed) by
            # handle_missing_values.
            expanded_data['class'] = (
                expanded_data['class'].astype(str).str.strip().str.upper()
                .map({'EDIBLE': 'e', 'POISONOUS': 'p'})
            )

            # NOTE(review): only the class column is translated here. If expanded.txt spells
            # the other attributes as full words while the original file uses one-letter
            # codes, the same category will get two different label encodings — verify the
            # file formats. If both files describe overlapping specimens, concatenation also
            # creates duplicates that can fall on both sides of the train/test split — TODO confirm.
            self.data = pd.concat([original_data, expanded_data], ignore_index=True)
            print(f"数据整合完成,总记录数: {len(self.data)}")

            return True
        except Exception as e:
            # Deliberate best-effort boundary: report and let run_complete_analysis stop.
            print(f"数据加载失败: {e}")
            return False

    def explore_data(self):
        """Print shape, a preview, dtype info, class balance and NaN counts.

        Returns:
            ``self.data`` (unchanged).
        """
        print("\n=== 数据集探索 ===")
        print(f"数据集形状: {self.data.shape}")
        print(f"特征数量: {len(self.data.columns) - 1}")
        print(f"样本数量: {len(self.data)}")

        print("\n前6行数据:")
        print(self.data.head(6))

        print("\n数据集信息:")
        # DataFrame.info() writes to stdout itself and returns None; wrapping it in
        # print() would emit a stray "None".
        self.data.info()

        print("\n类别分布:")
        class_dist = self.data['class'].value_counts()
        print(class_dist)

        # Only NaN is counted here; '?' markers are reported by handle_missing_values.
        print("\n缺失值统计:")
        missing_values = self.data.isnull().sum()
        print(missing_values[missing_values > 0])

        return self.data

    def _drop_unlabeled_rows(self):
        """Remove rows whose 'class' label is NaN or '?' and reset the index to 0..n-1."""
        if 'class' not in self.data.columns:
            return
        unlabeled = self.data['class'].isnull() | self.data['class'].eq('?')
        n_unlabeled = int(unlabeled.sum())
        if n_unlabeled:
            self.data = self.data.loc[~unlabeled].reset_index(drop=True)
            print(f"删除缺少类别标签的样本: {n_unlabeled} 条")

    def handle_missing_values(self):
        """Drop unlabeled rows and impute missing feature values with the column mode.

        Missing values appear either as the literal '?' or as NaN. Rows without a
        usable class label are removed rather than imputed, because filling the target
        with its mode would fabricate training labels.

        Returns:
            The cleaned ``self.data`` (index reset to 0..n-1 if rows were dropped).
        """
        print("\n=== 缺失值处理 ===")

        self._drop_unlabeled_rows()

        # Snapshot after dropping so row positions line up with the processed data.
        data_before = self.data.copy()
        feature_columns = [col for col in self.data.columns if col != 'class']

        # Count '?' markers. eq('?') is simply False for non-string values, so no dtype
        # check is needed (a `dtype == 'object'` check would miss pandas' string dtype).
        question_mark_count = 0
        question_mark_cols = []
        for col in feature_columns:
            qm_count = int(self.data[col].eq('?').sum())
            if qm_count > 0:
                question_mark_count += qm_count
                question_mark_cols.append(col)
                print(f"  列 {col}: {qm_count} 个'?'标记的缺失值 ({qm_count/len(self.data)*100:.2f}%)")

        print(f"处理前'?'标记缺失值总数: {question_mark_count}")

        # Replace '?' with the mode of the column's remaining values.
        for col in question_mark_cols:
            non_qm_data = self.data.loc[self.data[col].ne('?'), col]
            if non_qm_data.empty:
                continue  # every value is '?': nothing to derive a mode from
            modes = non_qm_data.mode()
            mode_value = modes.iloc[0] if not modes.empty else 'unknown'
            self.data.loc[self.data[col] == '?', col] = mode_value
            print(f"  列 {col} 使用众数 '{mode_value}' 填充'?'标记缺失值")

        missing_before = int(self.data.isnull().sum().sum())
        print(f"处理前常规缺失值总数: {missing_before}")

        # Replace NaN with the column mode (mode() ignores NaN).
        for col in self.data.columns[self.data.isnull().any()]:
            modes = self.data[col].mode()
            mode_value = modes.iloc[0] if not modes.empty else 'unknown'
            # Reassign instead of chained `fillna(..., inplace=True)`, which does not
            # update self.data under pandas Copy-on-Write.
            self.data[col] = self.data[col].fillna(mode_value)
            print(f"  列 {col} 使用众数 '{mode_value}' 填充常规缺失值")

        question_mark_after = int(sum(self.data[col].eq('?').sum() for col in feature_columns))
        missing_after = int(self.data.isnull().sum().sum())
        print(f"处理后'?'标记缺失值总数: {question_mark_after}")
        print(f"处理后常规缺失值总数: {missing_after}")

        print("\n缺失值处理前后对比:")
        for col in question_mark_cols[:3]:  # only the first 3 affected columns
            before_count = (data_before[col] == '?').sum()
            after_count = (self.data[col] == '?').sum()
            print(f"  列 {col}: 处理前'?'缺失值 {before_count} -> 处理后'?'缺失值 {after_count}")

        # Show up to three rows (fewer if the dataset is smaller), first 6 columns only.
        print("\n三条缺失值处理后的数据结果:")
        for i in range(min(3, len(self.data))):
            print(f"\n样本 {i+1}:")
            sample = self.data.iloc[i]
            for col in self.data.columns[:6]:
                original_value = data_before.iloc[i][col]
                processed_value = sample[col]
                print(f"  {col}: {original_value} -> {processed_value}")

        return self.data

    def encode_categorical_features(self):
        """Label-encode every feature column and the 'class' target in place.

        Fitted encoders are stored in ``self.label_encoders`` keyed by column name.
        LabelEncoder assigns integer codes in sorted order of the category strings.

        Returns:
            The encoded ``self.data``.
        """
        print("\n=== 分类特征数值化编码 ===")

        data_before = self.data.copy()
        categorical_columns = self.data.columns.drop('class')

        for col in categorical_columns:
            le = LabelEncoder()
            self.data[col] = le.fit_transform(self.data[col].astype(str))
            self.label_encoders[col] = le
            print(f"  列 {col}: {len(le.classes_)} 个类别编码完成")

        le_class = LabelEncoder()
        self.data['class'] = le_class.fit_transform(self.data['class'])
        self.label_encoders['class'] = le_class

        print("\n编码前后对比示例:")
        for col in categorical_columns[:3]:  # only the first 3 columns
            print(f"\n列 {col}:")
            original_values = data_before[col].head(3).values
            encoded_values = self.data[col].head(3).values
            for i, (orig, enc) in enumerate(zip(original_values, encoded_values)):
                print(f"  样本{i+1}: {orig} -> {enc}")

        return self.data

    def visualize_data(self):
        """Print feature/target correlations and save a correlation heatmap.

        Note: these are Pearson correlations of label codes, so for nominal
        attributes they depend on the (alphabetical) code order.

        Returns:
            True once feature_correlation_matrix.png has been saved and shown.
        """
        print("\n=== 数据可视化分析 ===")

        correlation_matrix = self.data.corr()
        target_correlation = correlation_matrix['class'].sort_values(ascending=False)
        print("\n特征与目标变量的相关性:")
        print(target_correlation)

        plt.figure(figsize=(15, 12))
        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', center=0,
                    square=True, cbar_kws={"shrink": .8})
        plt.title('图4. 数据可视化 - 特征相关性矩阵图')
        plt.tight_layout()
        plt.savefig('feature_correlation_matrix.png', dpi=300, bbox_inches='tight')
        plt.show()

        return True

    def prepare_data(self):
        """Split features/target, make a stratified 70/30 split and standardize.

        The scaler is fitted on the training split only and then applied to the
        test split, so no test statistics leak into training.

        Returns:
            True.
        """
        print("\n=== 数据准备 ===")

        self.X = self.data.drop('class', axis=1)
        self.y = self.data['class']

        print(f"特征矩阵形状: {self.X.shape}")
        print(f"目标变量形状: {self.y.shape}")

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.3, random_state=42, stratify=self.y
        )

        print(f"训练集大小: {self.X_train.shape}")
        print(f"测试集大小: {self.X_test.shape}")

        self.X_train = self.scaler.fit_transform(self.X_train)
        self.X_test = self.scaler.transform(self.X_test)

        print("数据标准化完成")

        return True

    def _fit_and_evaluate(self, model):
        """Run 10-fold CV on the training split, refit on it, and score on the test split.

        Returns:
            (fitted model, test-set predictions, test-set accuracy)
        """
        # Reported spread is 2 x the standard deviation of the fold accuracies.
        cv_scores = cross_val_score(model, self.X_train, self.y_train, cv=10, scoring='accuracy')
        print(f"十折交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

        model.fit(self.X_train, self.y_train)
        y_pred = model.predict(self.X_test)

        accuracy = accuracy_score(self.y_test, y_pred)
        print(f"测试集准确率: {accuracy:.4f}")

        return model, y_pred, accuracy

    def train_random_forest(self):
        """Train and evaluate a random forest, then print the top-10 feature importances.

        Returns:
            (fitted model, test-set predictions, test-set accuracy)
        """
        print("\n=== 随机森林模型训练 ===")

        rf_model = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            random_state=42,
            n_jobs=-1
        )
        rf_model, y_pred_rf, accuracy_rf = self._fit_and_evaluate(rf_model)

        # Feature order matches self.X, which was built from these same columns.
        feature_importance = pd.DataFrame({
            'feature': self.data.columns.drop('class'),
            'importance': rf_model.feature_importances_
        }).sort_values('importance', ascending=False)

        print("\n特征重要性排名:")
        print(feature_importance.head(10))

        return rf_model, y_pred_rf, accuracy_rf

    def train_gradient_boosting(self):
        """Train and evaluate a gradient boosting classifier.

        Returns:
            (fitted model, test-set predictions, test-set accuracy)
        """
        print("\n=== 梯度提升树模型训练 ===")

        gb_model = GradientBoostingClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=3,
            random_state=42
        )
        return self._fit_and_evaluate(gb_model)

    def train_svm(self):
        """Train and evaluate an RBF-kernel support vector machine.

        Returns:
            (fitted model, test-set predictions, test-set accuracy)
        """
        print("\n=== 支持向量机模型训练 ===")

        svm_model = SVC(
            kernel='rbf',
            C=1.0,
            gamma='scale',
            random_state=42
        )
        return self._fit_and_evaluate(svm_model)

    def evaluate_models(self, y_pred_rf, y_pred_gb, y_pred_svm):
        """Print weighted-average metrics and plot one confusion matrix per model.

        Args:
            y_pred_rf, y_pred_gb, y_pred_svm: test-set predictions of each model.

        Returns:
            True once confusion_matrix.png has been saved and shown.
        """
        print("\n=== 模型性能评估 ===")

        models = {
            '随机森林': y_pred_rf,
            '梯度提升树': y_pred_gb,
            '支持向量机': y_pred_svm
        }

        # Show the original class names on the axes instead of the encoded integers.
        # Fixing `labels` keeps every matrix square even if a model never predicts a class.
        le_class = self.label_encoders.get('class')
        if le_class is not None:
            class_names = list(le_class.classes_)
            label_ids = list(range(len(class_names)))
        else:
            class_names, label_ids = 'auto', None

        _, axes = plt.subplots(1, len(models), figsize=(6 * len(models), 6), squeeze=False)
        axes = axes.ravel()

        for ax, (model_name, y_pred) in zip(axes, models.items()):
            cm = confusion_matrix(self.y_test, y_pred, labels=label_ids)

            report = classification_report(self.y_test, y_pred, output_dict=True, zero_division=0)
            accuracy = accuracy_score(self.y_test, y_pred)
            precision = report['weighted avg']['precision']
            recall = report['weighted avg']['recall']
            f1 = report['weighted avg']['f1-score']

            print(f"\n{model_name} 模型评估:")
            print(f"  准确率: {accuracy:.4f}")
            print(f"  查准率: {precision:.4f}")
            print(f"  查全率: {recall:.4f}")
            print(f"  F1值: {f1:.4f}")

            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                        xticklabels=class_names, yticklabels=class_names)
            ax.set_title(f'{model_name} - 混淆矩阵')
            ax.set_xlabel('预测标签')
            ax.set_ylabel('真实标签')

        plt.suptitle('图5. 分类混淆矩阵', fontsize=16)
        plt.tight_layout()
        plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.show()

        return True

    def run_complete_analysis(self):
        """Run the full pipeline: load, explore, clean, encode, plot, split, train, evaluate.

        Returns:
            None if loading fails; otherwise a dict mapping 'random_forest',
            'gradient_boosting' and 'svm' to {'model': fitted model, 'accuracy': test accuracy}.
        """
        print("开始毒蘑菇分类预测分析...")

        if not self.load_and_integrate_data():
            return None

        self.explore_data()
        self.handle_missing_values()
        self.encode_categorical_features()
        self.visualize_data()
        self.prepare_data()

        rf_model, y_pred_rf, accuracy_rf = self.train_random_forest()
        gb_model, y_pred_gb, accuracy_gb = self.train_gradient_boosting()
        svm_model, y_pred_svm, accuracy_svm = self.train_svm()

        self.evaluate_models(y_pred_rf, y_pred_gb, y_pred_svm)

        print("\n=== 分析完成 ===")
        print("结果已保存到以下图表文件中:")
        print("- feature_correlation_matrix.png (图4. 数据可视化)")
        print("- confusion_matrix.png (图5. 分类混淆矩阵)")

        return {
            'random_forest': {'model': rf_model, 'accuracy': accuracy_rf},
            'gradient_boosting': {'model': gb_model, 'accuracy': accuracy_gb},
            'svm': {'model': svm_model, 'accuracy': accuracy_svm},
        }

# Script entry point: run the whole analysis only when executed directly.
if __name__ == "__main__":
    MushroomClassifier().run_complete_analysis()
posted @ 2026-01-09 09:44  vivi_vimi  阅读(3)  评论(0)    收藏  举报