# 毒蘑菇分类预测
# 使用3种模型实现毒蘑菇分类:
"""
毒蘑菇分类预测项目
整合expanded.txt和agaricus-lepiota.data.txt数据文件
实现多个机器学习算法进行分类预测
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import warnings
# Silence all warnings for the rest of the run.
warnings.filterwarnings('ignore')
# Use the SimHei font so the Chinese titles and axis labels in the plots
# render, and draw minus signs with an ASCII hyphen instead of U+2212.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
class MushroomClassifier:
    """End-to-end pipeline that classifies mushrooms as edible or poisonous.

    The steps are: load and merge two data files, impute missing values,
    label-encode every column, plot a correlation heatmap, split and
    standardize the data, then train and compare a random forest, a
    gradient-boosting model and an RBF-kernel SVM.

    The attributes are filled in step by step; ``run_complete_analysis``
    calls the steps in the order they depend on each other.
    """

    # Column names for the headerless data files (per
    # agaricus-lepiota.names.txt); 'class' is the prediction target.
    _COLUMN_NAMES = [
        'class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
        'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
        'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
        'stalk-surface-below-ring', 'stalk-color-above-ring',
        'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',
        'ring-type', 'spore-print-color', 'population', 'habitat',
    ]

    # Maps the long class labels of the expanded file to the short codes of
    # the original file; short codes are accepted unchanged.
    _CLASS_LABEL_MAP = {'EDIBLE': 'e', 'POISONOUS': 'p', 'e': 'e', 'p': 'p'}

    def __init__(self):
        """Create an empty pipeline; no data is loaded yet."""
        self.data = None            # merged DataFrame, mutated in place by the steps
        self.X = None               # feature columns (set by prepare_data)
        self.y = None               # encoded target column (set by prepare_data)
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.label_encoders = {}    # column name -> fitted LabelEncoder
        self.scaler = StandardScaler()

    def load_and_integrate_data(self,
                                original_path='data/agaricus-lepiota.data.txt',
                                expanded_path='data/expanded.txt'):
        """Load both data files and merge them into ``self.data``.

        Args:
            original_path: CSV file without header whose class column
                already holds the short codes 'e' / 'p'.
            expanded_path: CSV file without header whose class column holds
                'EDIBLE' / 'POISONOUS' (short codes are accepted as well).

        Returns:
            True when both files were loaded and merged, False when loading
            failed (the error is printed, not raised).
        """
        print("正在加载和整合数据文件...")
        try:
            original_data = pd.read_csv(original_path, header=None,
                                        names=self._COLUMN_NAMES)
            print(f"原始数据文件加载成功,共{len(original_data)}条记录")

            expanded_data = pd.read_csv(expanded_path, header=None,
                                        names=self._COLUMN_NAMES)
            print(f"扩展数据文件加载成功,共{len(expanded_data)}条记录")

            # Unify the class labels (EDIBLE/POISONOUS -> e/p). Labels that
            # are not recognized become NaN here and are dropped below.
            # NOTE(review): only the class label is normalized; this assumes
            # the feature columns of both files use the same category codes
            # -- confirm against the data files.
            expanded_data['class'] = expanded_data['class'].map(self._CLASS_LABEL_MAP)

            combined = pd.concat([original_data, expanded_data], ignore_index=True)

            # A row without a usable label must not reach the imputation
            # step: filling the target with the majority class would invent
            # training labels.
            unlabeled = combined['class'].isna()
            if unlabeled.any():
                print(f"丢弃{int(unlabeled.sum())}条类别标签无法识别的记录")
                combined = combined.loc[~unlabeled].reset_index(drop=True)

            self.data = combined
            print(f"数据整合完成,总记录数: {len(self.data)}")
            return True
        except Exception as e:
            # Deliberate best-effort boundary: report and let the caller stop.
            print(f"数据加载失败: {e}")
            return False

    def explore_data(self):
        """Print shape, head, dtypes, class distribution and NaN counts.

        Returns:
            ``self.data`` unchanged.
        """
        print("\n=== 数据集探索 ===")
        print(f"数据集形状: {self.data.shape}")
        print(f"特征数量: {len(self.data.columns) - 1}")
        print(f"样本数量: {len(self.data)}")
        print("\n前6行数据:")
        print(self.data.head(6))
        print("\n数据集信息:")
        # info() prints its report itself and returns None, so it must not
        # be wrapped in print().
        self.data.info()
        print("\n类别分布:")
        print(self.data['class'].value_counts())
        print("\n缺失值统计:")
        missing_values = self.data.isnull().sum()
        print(missing_values[missing_values > 0])
        return self.data

    def handle_missing_values(self):
        """Fill '?' markers and NaN values with the column mode, in place.

        A column that consists only of '?' is left untouched; when a column
        has no mode at all the literal 'unknown' is used as fill value.

        Returns:
            ``self.data`` after imputation.
        """
        print("\n=== 缺失值处理 ===")
        # Snapshot used only for the before/after report below.
        data_before = self.data.copy()
        total_rows = len(self.data)

        # DataFrame.eq finds the '?' markers without depending on the dtype
        # pandas chose for the text columns; non-text columns compare False.
        qm_counts = self.data.eq('?').sum()
        question_mark_cols = qm_counts[qm_counts > 0].index.tolist()
        for col in question_mark_cols:
            qm_count = qm_counts[col]
            print(f" 列 {col}: {qm_count} 个'?'标记的缺失值 ({qm_count/total_rows*100:.2f}%)")
        print(f"处理前'?'标记缺失值总数: {int(qm_counts.sum())}")

        # Replace '?' with the mode of the remaining values of the column.
        for col in question_mark_cols:
            is_marker = self.data[col] == '?'
            known_values = self.data.loc[~is_marker, col]
            if known_values.empty:
                continue
            modes = known_values.mode()
            mode_value = modes.iloc[0] if not modes.empty else 'unknown'
            self.data.loc[is_marker, col] = mode_value
            print(f" 列 {col} 使用众数 '{mode_value}' 填充'?'标记缺失值")

        missing_before = self.data.isnull().sum().sum()
        print(f"处理前常规缺失值总数: {missing_before}")

        # Fill regular NaN values. The result is assigned back because an
        # in-place fillna on a column selection does not reliably update
        # the frame.
        for col in self.data.columns[self.data.isnull().any()].tolist():
            modes = self.data[col].mode()
            mode_value = modes.iloc[0] if not modes.empty else 'unknown'
            self.data[col] = self.data[col].fillna(mode_value)
            print(f" 列 {col} 使用众数 '{mode_value}' 填充常规缺失值")

        question_mark_after = int(self.data.eq('?').sum().sum())
        missing_after = self.data.isnull().sum().sum()
        print(f"处理后'?'标记缺失值总数: {question_mark_after}")
        print(f"处理后常规缺失值总数: {missing_after}")

        print("\n缺失值处理前后对比:")
        for col in question_mark_cols[:3]:  # report the first 3 columns only
            before_count = (data_before[col] == '?').sum()
            after_count = (self.data[col] == '?').sum()
            print(f" 列 {col}: 处理前'?'缺失值 {before_count} -> 处理后'?'缺失值 {after_count}")

        print("\n三条缺失值处理后的数据结果:")
        # Show up to three rows; the bound keeps tiny frames from failing.
        for i in range(min(3, total_rows)):
            print(f"\n样本 {i+1}:")
            row_before = data_before.iloc[i]
            row_after = self.data.iloc[i]
            for col in self.data.columns[:6]:  # first 6 columns only
                print(f" {col}: {row_before[col]} -> {row_after[col]}")
        return self.data

    def encode_categorical_features(self):
        """Label-encode every column of ``self.data`` in place.

        The fitted encoder of each column (including 'class') is stored in
        ``self.label_encoders`` under the column name.

        Returns:
            ``self.data`` with integer-encoded columns.
        """
        print("\n=== 分类特征数值化编码 ===")
        # Snapshot used only for the before/after examples below.
        data_before = self.data.copy()

        categorical_columns = self.data.columns.drop('class')
        for col in categorical_columns:
            le = LabelEncoder()
            self.data[col] = le.fit_transform(self.data[col].astype(str))
            self.label_encoders[col] = le
            print(f" 列 {col}: {len(le.classes_)} 个类别编码完成")

        # Encode the target column as well.
        le_class = LabelEncoder()
        self.data['class'] = le_class.fit_transform(self.data['class'])
        self.label_encoders['class'] = le_class

        print("\n编码前后对比示例:")
        for col in categorical_columns[:3]:  # first 3 feature columns only
            print(f"\n列 {col}:")
            original_values = data_before[col].head(3).values
            encoded_values = self.data[col].head(3).values
            for i, (orig, enc) in enumerate(zip(original_values, encoded_values)):
                print(f" 样本{i+1}: {orig} -> {enc}")
        return self.data

    def visualize_data(self):
        """Print the correlations with the target and save a heatmap.

        The heatmap of the full correlation matrix is written to
        'feature_correlation_matrix.png' and then shown.

        Returns:
            True.
        """
        print("\n=== 数据可视化分析 ===")
        correlation_matrix = self.data.corr()
        target_correlation = correlation_matrix['class'].sort_values(ascending=False)
        print("\n特征与目标变量的相关性:")
        print(target_correlation)

        # Figure 4: correlation matrix of all columns.
        plt.figure(figsize=(15, 12))
        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', center=0,
                    square=True, cbar_kws={"shrink": .8})
        plt.title('图4. 数据可视化 - 特征相关性矩阵图')
        plt.tight_layout()
        plt.savefig('feature_correlation_matrix.png', dpi=300, bbox_inches='tight')
        plt.show()
        return True

    def prepare_data(self):
        """Split ``self.data`` 70/30 (stratified) and standardize features.

        The scaler is fitted on the training split only and then applied to
        the test split.

        Returns:
            True.
        """
        print("\n=== 数据准备 ===")
        self.X = self.data.drop('class', axis=1)
        self.y = self.data['class']
        print(f"特征矩阵形状: {self.X.shape}")
        print(f"目标变量形状: {self.y.shape}")

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.3, random_state=42, stratify=self.y
        )
        print(f"训练集大小: {self.X_train.shape}")
        print(f"测试集大小: {self.X_test.shape}")

        self.X_train = self.scaler.fit_transform(self.X_train)
        self.X_test = self.scaler.transform(self.X_test)
        print("数据标准化完成")
        return True

    def _cross_validate_fit_predict(self, model):
        """Run 10-fold CV on the training split, fit, and score the test split.

        Returns:
            Tuple ``(y_pred, accuracy)`` for the test split.
        """
        cv_scores = cross_val_score(model, self.X_train, self.y_train,
                                    cv=10, scoring='accuracy')
        print(f"十折交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
        model.fit(self.X_train, self.y_train)
        y_pred = model.predict(self.X_test)
        accuracy = accuracy_score(self.y_test, y_pred)
        print(f"测试集准确率: {accuracy:.4f}")
        return y_pred, accuracy

    def train_random_forest(self):
        """Train and evaluate a random forest; print the top 10 features.

        Returns:
            Tuple ``(model, test predictions, test accuracy)``.
        """
        print("\n=== 随机森林模型训练 ===")
        rf_model = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            random_state=42,
            n_jobs=-1,
        )
        y_pred_rf, accuracy_rf = self._cross_validate_fit_predict(rf_model)

        # self.X holds exactly the columns the model was trained on.
        feature_importance = pd.DataFrame({
            'feature': self.X.columns,
            'importance': rf_model.feature_importances_,
        }).sort_values('importance', ascending=False)
        print("\n特征重要性排名:")
        print(feature_importance.head(10))
        return rf_model, y_pred_rf, accuracy_rf

    def train_gradient_boosting(self):
        """Train and evaluate a gradient-boosting classifier.

        Returns:
            Tuple ``(model, test predictions, test accuracy)``.
        """
        print("\n=== 梯度提升树模型训练 ===")
        gb_model = GradientBoostingClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=3,
            random_state=42,
        )
        y_pred_gb, accuracy_gb = self._cross_validate_fit_predict(gb_model)
        return gb_model, y_pred_gb, accuracy_gb

    def train_svm(self):
        """Train and evaluate an RBF-kernel support vector classifier.

        Returns:
            Tuple ``(model, test predictions, test accuracy)``.
        """
        print("\n=== 支持向量机模型训练 ===")
        svm_model = SVC(
            kernel='rbf',
            C=1.0,
            gamma='scale',
            random_state=42,
        )
        y_pred_svm, accuracy_svm = self._cross_validate_fit_predict(svm_model)
        return svm_model, y_pred_svm, accuracy_svm

    def evaluate_models(self, y_pred_rf, y_pred_gb, y_pred_svm):
        """Print metrics for the three models and save their confusion matrices.

        Args:
            y_pred_rf: Test-split predictions of the random forest.
            y_pred_gb: Test-split predictions of the gradient-boosting model.
            y_pred_svm: Test-split predictions of the SVM.

        Returns:
            True. The figure is written to 'confusion_matrix.png'.
        """
        print("\n=== 模型性能评估 ===")
        models = {
            '随机森林': y_pred_rf,
            '梯度提升树': y_pred_gb,
            '支持向量机': y_pred_svm,
        }

        # Figure 5: one confusion matrix per model, side by side.
        _, axes = plt.subplots(1, 3, figsize=(18, 6))
        for ax, (model_name, y_pred) in zip(axes, models.items()):
            cm = confusion_matrix(self.y_test, y_pred)
            report = classification_report(self.y_test, y_pred, output_dict=True)
            accuracy = accuracy_score(self.y_test, y_pred)
            weighted = report['weighted avg']

            print(f"\n{model_name} 模型评估:")
            print(f" 准确率: {accuracy:.4f}")
            print(f" 查准率: {weighted['precision']:.4f}")
            print(f" 查全率: {weighted['recall']:.4f}")
            print(f" F1值: {weighted['f1-score']:.4f}")

            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
            ax.set_title(f'{model_name} - 混淆矩阵')
            ax.set_xlabel('预测标签')
            ax.set_ylabel('真实标签')

        plt.suptitle('图5. 分类混淆矩阵', fontsize=16)
        plt.tight_layout()
        plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.show()
        return True

    def run_complete_analysis(self):
        """Run every pipeline step in order; stop early if loading fails."""
        print("开始毒蘑菇分类预测分析...")
        # 1. Load and merge the data files.
        if not self.load_and_integrate_data():
            return
        # 2. Explore the data.
        self.explore_data()
        # 3. Impute missing values.
        self.handle_missing_values()
        # 4. Encode the categorical columns.
        self.encode_categorical_features()
        # 5. Visualize correlations.
        self.visualize_data()
        # 6. Split and scale.
        self.prepare_data()
        # 7. Train the models; only the predictions are needed afterwards.
        _, y_pred_rf, _ = self.train_random_forest()
        _, y_pred_gb, _ = self.train_gradient_boosting()
        _, y_pred_svm, _ = self.train_svm()
        # 8. Compare the models.
        self.evaluate_models(y_pred_rf, y_pred_gb, y_pred_svm)
        print("\n=== 分析完成 ===")
        print("结果已保存到以下图表文件中:")
        print("- feature_correlation_matrix.png (图4. 数据可视化)")
        print("- confusion_matrix.png (图5. 分类混淆矩阵)")
# Script entry point: build the pipeline and run every step.
if __name__ == "__main__":
    MushroomClassifier().run_complete_analysis()

# 浙公网安备 33010602011771号