小样本实验：基于数据增强的CNN声音分类简单模型搭建
样本太少了，但又头铁想做声音分类模型。脑子一灵光：不如自己写一个数据增强（实际项目中可以把GAN加进来，有空我再美美地把玩一下GAN）。
仅供娱乐
这还是我第一次弄CNN，很兴奋——虽然已经是大模型时代了，哎。
import os

import librosa
import librosa.display  # specshow lives here; `import librosa` alone does not guarantee it is loaded
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import (
    auc,
    classification_report,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from tensorflow.keras import layers, models

# ==================== Global configuration ====================
TARGET_SR = 22050  # sampling rate (Hz)
DURATION = 3  # clip length in seconds
N_SAMPLES = TARGET_SR * DURATION  # samples per clip after trimming / zero-padding
AUGMENT_TIMES = 500  # augmented samples generated per source recording
N_FFT = 2048  # STFT window length used for every spectral feature
HOP_LENGTH = 512  # STFT hop length used for every spectral feature
N_MELS = 128  # number of mel bands

EXP_DIR = "实验记录"  # root directory for all experiment outputs
SPEC_VIS_DIR = os.path.join(EXP_DIR, "频谱可视化")  # per-recording feature plots
TRAIN_VIS_DIR = os.path.join(EXP_DIR, "训练过程可视化")  # training curves
COMPARE_VIS_DIR = os.path.join(EXP_DIR, "特征对比可视化")  # cross-class feature comparison
AUGMENT_VIS_DIR = os.path.join(EXP_DIR, "数据增强可视化")  # augmentation examples
METRICS_VIS_DIR = os.path.join(EXP_DIR, "分类指标可视化")  # classification metric plots

# ==================== Audio files that must sit next to this script ====================
REQUIRED_AUDIOS = [
    "normal.wav",  # normal sound (label 0)
    "sharp.wav",  # sharp sound (label 1)
    "wave.wav",  # wave-like sound (label 2)
    "clunk.wav",  # impact sound 1 (label 3)
    "clunk2.wav",  # impact sound 2 (label 3)
]

# Integer class label of each source file; both clunk recordings share class 3.
LABEL_MAP = {
    "normal.wav": 0,
    "sharp.wav": 1,
    "wave.wav": 2,
    "clunk.wav": 3,
    "clunk2.wav": 3,
}

# Display name of each source file (used in plot titles and directory names).
FILE_LABEL_NAMES = {
    "normal.wav": "正常",
    "sharp.wav": "尖音",
    "wave.wav": "波浪音",
    "clunk.wav": "撞击音",
    "clunk2.wav": "撞击音",
}

# Display names indexed by integer class label.
CLASS_NAMES = ["正常", "尖音", "波浪音", "撞击音"]

# Recordings whose first second is skipped before the clip is taken.
OFFSET_FILES = ("clunk.wav", "clunk2.wav")

# ==================== Test-set size per class ====================
TESTSET_SAMPLES = {
    0: 300,  # normal
    1: 50,  # sharp
    2: 50,  # wave
    3: 50,  # impact (clunk + clunk2 combined)
}

# ==================== Initialisation ====================
for _output_dir in (EXP_DIR, SPEC_VIS_DIR, TRAIN_VIS_DIR, COMPARE_VIS_DIR, AUGMENT_VIS_DIR, METRICS_VIS_DIR):
    os.makedirs(_output_dir, exist_ok=True)

plt.rcParams['font.sans-serif'] = ['SimHei']  # CJK-capable font so Chinese labels render
plt.rcParams['axes.unicode_minus'] = False  # SimHei lacks U+2212; without this negative ticks show as boxes
plt.rcParams['figure.dpi'] = 300  # high-resolution figures
plt.rcParams['savefig.bbox'] = 'tight'  # trim whitespace when saving


# ==================== Shared helpers ====================
def _mel_db(audio):
    """Return the dB-scaled mel spectrogram of *audio*, referenced to its own peak (max = 0 dB)."""
    mel_spec = librosa.feature.melspectrogram(
        y=audio, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
    )
    return librosa.power_to_db(mel_spec, ref=np.max)


def _save_mel_plot(audio, title, save_path):
    """Render the mel spectrogram of *audio* with a dB colour bar and save it to *save_path*."""
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(
        _mel_db(audio), sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='mel', cmap='magma'
    )
    plt.colorbar(format='%+2.0f dB')
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("梅尔频率", fontsize=10)
    plt.savefig(save_path)
    plt.close()


def _spectral_centroid(audio):
    """Per-frame spectral centroid (Hz)."""
    return librosa.feature.spectral_centroid(y=audio, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH)[0]


def _zero_crossing_rate(audio):
    """Per-frame zero-crossing rate."""
    return librosa.feature.zero_crossing_rate(y=audio, hop_length=HOP_LENGTH)[0]


def _spectral_bandwidth(audio):
    """Per-frame spectral bandwidth (Hz)."""
    return librosa.feature.spectral_bandwidth(y=audio, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH)[0]


def _spectral_contrast(audio):
    """Per-frame spectral contrast of the first (lowest) sub-band only, in dB."""
    return librosa.feature.spectral_contrast(y=audio, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH)[0]


def _save_feature_curve(values, save_path, title, *, color, linewidth, label, ylabel, ylim=None):
    """Plot a per-frame feature curve against time (seconds) and save it to *save_path*."""
    time = librosa.times_like(values, sr=TARGET_SR, hop_length=HOP_LENGTH)
    plt.figure(figsize=(10, 4))
    plt.plot(time, values, color=color, linewidth=linewidth, label=label)
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel(ylabel, fontsize=10)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.grid(alpha=0.2)
    plt.legend()
    plt.savefig(save_path)
    plt.close()


# ==================== Core functions ====================
def check_audio_files():
    """Verify every file in REQUIRED_AUDIOS exists next to this script and return that directory.

    Raises:
        SystemExit: with the list of missing file names when any are absent.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    missing = [name for name in REQUIRED_AUDIOS if not os.path.exists(os.path.join(current_dir, name))]
    if missing:
        raise SystemExit(f"错误:以下文件未在代码目录找到(路径:{current_dir}):\n{missing}")
    return current_dir


def load_raw_audio(filename, base_dir):
    """Load *filename* from *base_dir* resampled to TARGET_SR and return exactly N_SAMPLES samples.

    Files in OFFSET_FILES skip their first second. Every clip is then truncated,
    or zero-padded at the end, to DURATION seconds.
    """
    audio, _ = librosa.load(os.path.join(base_dir, filename), sr=TARGET_SR)
    if filename in OFFSET_FILES:
        # Start at 1 s. fix_length below truncates the tail, or zero-pads when
        # fewer than DURATION seconds remain (i.e. the file is shorter than 4 s).
        audio = audio[TARGET_SR:]
    return librosa.util.fix_length(audio, size=N_SAMPLES)


def augment_audio(audio):
    """Return a randomly augmented version of *audio*; the input array is never modified.

    Each augmentation is applied independently with probability 0.5:
    additive Gaussian noise (std 0.005), circular time shift of up to ±10 %,
    pitch shift of -2..2 semitones, and time stretch by a factor in [0.9, 1.1).
    The result is truncated / zero-padded back to N_SAMPLES samples.
    """
    if np.random.random() > 0.5:
        # Out-of-place add: `audio += ...` wrote the noise into the caller's array,
        # so noise accumulated in the raw clip across repeated calls.
        audio = audio + (np.random.randn(len(audio)) * 0.005).astype(audio.dtype)
    if np.random.random() > 0.5:
        shift = np.random.randint(-len(audio) // 10, len(audio) // 10)
        audio = np.roll(audio, shift)
    if np.random.random() > 0.5:
        audio = librosa.effects.pitch_shift(audio, sr=TARGET_SR, n_steps=np.random.randint(-2, 3))
    if np.random.random() > 0.5:
        rate = np.random.uniform(0.9, 1.1)
        audio = librosa.effects.time_stretch(audio, rate=rate)
    return librosa.util.fix_length(audio, size=N_SAMPLES)


# ==================== Augmentation visualisation ====================
def visualize_augmentations(base_dir):
    """Save mel spectrograms of each recording: once unmodified and once per augmentation type."""
    aug_types = ["原始音频", "噪声添加", "时间平移", "音高变换", "时间拉伸"]
    for aug_type in aug_types:
        os.makedirs(os.path.join(AUGMENT_VIS_DIR, aug_type), exist_ok=True)

    for audio_file in REQUIRED_AUDIOS:
        audio = load_raw_audio(audio_file, base_dir)
        label_name = FILE_LABEL_NAMES[audio_file]
        sample_name = os.path.splitext(audio_file)[0]

        # 1. Original clip
        _save_mel_plot(
            audio,
            f"{label_name}原始音频梅尔频谱图",
            os.path.join(AUGMENT_VIS_DIR, "原始音频", f"{sample_name}_原始频谱.png"),
        )

        # Re-seed per file so every clip is shown with the same random parameters.
        # NOTE: this reseeds the *global* NumPy RNG, which makes the later
        # generate_dataset() / test split deterministic as a side effect.
        np.random.seed(42)

        # 2. Additive noise
        noisy = audio + np.random.randn(len(audio)) * 0.005
        _save_mel_plot(
            noisy,
            f"{label_name}噪声添加增强梅尔频谱图",
            os.path.join(AUGMENT_VIS_DIR, "噪声添加", f"{sample_name}_噪声频谱.png"),
        )

        # 3. Circular time shift
        shift = np.random.randint(-len(audio) // 10, len(audio) // 10)
        _save_mel_plot(
            np.roll(audio, shift),
            f"{label_name}时间平移增强梅尔频谱图",
            os.path.join(AUGMENT_VIS_DIR, "时间平移", f"{sample_name}_平移频谱.png"),
        )

        # 4. Pitch shift
        n_steps = np.random.randint(-2, 3)
        pitched = librosa.effects.pitch_shift(audio, sr=TARGET_SR, n_steps=n_steps)
        _save_mel_plot(
            pitched,
            f"{label_name}音高变换增强梅尔频谱图 (n_steps={n_steps})",
            os.path.join(AUGMENT_VIS_DIR, "音高变换", f"{sample_name}_音高频谱.png"),
        )

        # 5. Time stretch (restored to the fixed clip length)
        rate = np.random.uniform(0.9, 1.1)
        stretched = librosa.util.fix_length(librosa.effects.time_stretch(audio, rate=rate), size=N_SAMPLES)
        _save_mel_plot(
            stretched,
            f"{label_name}时间拉伸增强梅尔频谱图 (rate={rate:.2f})",
            os.path.join(AUGMENT_VIS_DIR, "时间拉伸", f"{sample_name}_拉伸频谱.png"),
        )


def generate_dataset(base_dir):
    """Create AUGMENT_TIMES augmented mel spectrograms for every recording in REQUIRED_AUDIOS.

    Returns:
        (data_dict, metadata_dict):
        - data_dict maps class label -> np.ndarray stacking that class's mel spectrograms.
        - metadata_dict maps class label -> list of dicts (source file, augmentation index,
          label) aligned index-for-index with data_dict.
    """
    data_dict = {label: [] for label in range(len(CLASS_NAMES))}
    metadata_dict = {label: [] for label in range(len(CLASS_NAMES))}
    for audio_file in REQUIRED_AUDIOS:
        raw_audio = load_raw_audio(audio_file, base_dir)
        label = LABEL_MAP[audio_file]
        for i in range(AUGMENT_TIMES):
            data_dict[label].append(_mel_db(augment_audio(raw_audio)))
            metadata_dict[label].append({
                "原始文件": audio_file,
                "增强编号": i,
                "标签": label,
            })
    return {k: np.array(v) for k, v in data_dict.items()}, metadata_dict


# ==================== Per-recording feature plots ====================
def plot_waveform(audio, save_path, title):
    """Plot amplitude over the DURATION-second clip and save it."""
    plt.figure(figsize=(10, 4))
    time = np.linspace(0, DURATION, len(audio))
    plt.plot(time, audio, color='steelblue', alpha=0.8)
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("振幅", fontsize=10)
    plt.grid(alpha=0.2)
    plt.savefig(save_path)
    plt.close()


def plot_spectral_centroid(audio, save_path, title):
    """Plot the spectral centroid over time (y-axis capped at Nyquist) and save it."""
    _save_feature_curve(_spectral_centroid(audio), save_path, title, color='firebrick', linewidth=1.5,
                        label='频谱质心', ylabel="频率(Hz)", ylim=(0, TARGET_SR // 2))


def plot_zero_crossing_rate(audio, save_path, title):
    """Plot the zero-crossing rate over time (y-axis 0..0.5) and save it."""
    _save_feature_curve(_zero_crossing_rate(audio), save_path, title, color='forestgreen', linewidth=1.2,
                        label='过零率', ylabel="过零率(次/帧)", ylim=(0, 0.5))


def plot_spectral_bandwidth(audio, save_path, title):
    """Plot the spectral bandwidth over time (y-axis capped at Nyquist) and save it."""
    _save_feature_curve(_spectral_bandwidth(audio), save_path, title, color='darkorange', linewidth=1.2,
                        label='频谱带宽', ylabel="频率(Hz)", ylim=(0, TARGET_SR // 2))


def plot_spectral_contrast(audio, save_path, title):
    """Plot the first sub-band's spectral contrast over time and save it."""
    _save_feature_curve(_spectral_contrast(audio), save_path, title, color='purple', linewidth=1.2,
                        label='频谱对比度', ylabel="对比度(dB)")


def visualize_audio_features(audio_path, label_name, sample_name):
    """Save waveform, mel spectrogram and four spectral-feature plots for one recording.

    Output goes to SPEC_VIS_DIR/<label_name>/, with file names prefixed by *sample_name*.
    """
    audio = load_raw_audio(os.path.basename(audio_path), os.path.dirname(audio_path))
    label_dir = os.path.join(SPEC_VIS_DIR, label_name)
    os.makedirs(label_dir, exist_ok=True)

    plot_waveform(audio, os.path.join(label_dir, f"{sample_name}_波形图.png"),
                  f"{label_name}音频波形图({sample_name})")
    _save_mel_plot(audio, f"{label_name}音频梅尔频谱图({sample_name})",
                   os.path.join(label_dir, f"{sample_name}_梅尔频谱图.png"))
    plot_spectral_centroid(audio, os.path.join(label_dir, f"{sample_name}_频谱质心图.png"),
                           f"{label_name}音频频谱质心图({sample_name})")
    plot_zero_crossing_rate(audio, os.path.join(label_dir, f"{sample_name}_过零率图.png"),
                            f"{label_name}音频过零率图({sample_name})")
    plot_spectral_bandwidth(audio, os.path.join(label_dir, f"{sample_name}_频谱带宽图.png"),
                            f"{label_name}音频频谱带宽图({sample_name})")
    plot_spectral_contrast(audio, os.path.join(label_dir, f"{sample_name}_频谱对比度图.png"),
                           f"{label_name}音频频谱对比度图({sample_name})")


# ==================== Cross-class feature comparison ====================
def _finish_axis(ax, title, xlabel, ylabel, ylim=None):
    """Apply the title, axis labels, optional y-limits, grid and legend shared by every panel."""
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.grid(alpha=0.2)
    ax.legend()


def plot_feature_comparison(base_dir):
    """Overlay six features of one recording per class on a 2x3 grid and save the figure."""
    # One representative recording and colour per class (clunk.wav stands in for class 3).
    label_info = [
        {"name": "正常", "file": "normal.wav", "color": "steelblue"},
        {"name": "尖音", "file": "sharp.wav", "color": "firebrick"},
        {"name": "波浪音", "file": "wave.wav", "color": "forestgreen"},
        {"name": "撞击音", "file": "clunk.wav", "color": "darkorange"},
    ]
    audios = [load_raw_audio(info["file"], base_dir) for info in label_info]

    fig, axes = plt.subplots(2, 3, figsize=(24, 12))
    fig.suptitle("不同类别音频特征对比图", fontsize=16, y=1.02)

    # Panel 1: waveform
    ax = axes[0, 0]
    for audio, info in zip(audios, label_info):
        ax.plot(np.linspace(0, DURATION, len(audio)), audio, color=info["color"], alpha=0.7, label=info["name"])
    _finish_axis(ax, "波形对比(时间-振幅)", "时间(秒)", "振幅")

    # Panels 2-5: per-frame spectral features plotted against time
    nyquist = (0, TARGET_SR // 2)
    curve_panels = [
        (axes[0, 1], _spectral_centroid, 1.5, "频谱质心对比(时间-频率)", "频率(Hz)", nyquist),
        (axes[0, 2], _zero_crossing_rate, 1.2, "过零率对比(时间-过零次数)", "过零率(次/帧)", (0, 0.5)),
        (axes[1, 0], _spectral_bandwidth, 1.2, "频谱带宽对比(时间-频率宽度)", "频率(Hz)", nyquist),
        (axes[1, 1], _spectral_contrast, 1.2, "频谱对比度对比(时间-频率差异)", "对比度(dB)", None),
    ]
    for ax, feature_fn, linewidth, title, ylabel, ylim in curve_panels:
        for audio, info in zip(audios, label_info):
            values = feature_fn(audio)
            time = librosa.times_like(values, sr=TARGET_SR, hop_length=HOP_LENGTH)
            ax.plot(time, values, color=info["color"], linewidth=linewidth, label=info["name"])
        _finish_axis(ax, title, "时间(秒)", ylabel, ylim)

    # Panel 6: time-averaged mel spectrum (one curve per class keeps the panel readable)
    ax = axes[1, 2]
    for audio, info in zip(audios, label_info):
        ax.plot(np.mean(_mel_db(audio), axis=1), color=info["color"], linewidth=1.5, label=info["name"])
    _finish_axis(ax, "梅尔频谱均值对比(梅尔频率-能量)", "梅尔频率索引(共128维)", "能量(dB)")

    plt.tight_layout()
    plt.savefig(os.path.join(COMPARE_VIS_DIR, "多标签特征对比图.png"), bbox_inches='tight')
    plt.close()


# ==================== Training-history plot ====================
def plot_training_history(history, save_dir):
    """Save side-by-side train/validation loss and accuracy curves from a Keras History."""
    hist = history.history
    epochs = range(1, len(hist['loss']) + 1)
    panels = [
        ('loss', 'val_loss', '训练损失', '验证损失', '训练与验证损失', '损失值'),
        ('accuracy', 'val_accuracy', '训练准确率', '验证准确率', '训练与验证准确率', '准确率'),
    ]
    plt.figure(figsize=(10, 4))
    for position, (train_key, val_key, train_label, val_label, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, hist[train_key], 'bo-', label=train_label)
        plt.plot(epochs, hist[val_key], 'ro-', label=val_label)
        plt.title(title, fontsize=12)
        plt.xlabel('轮次', fontsize=10)
        plt.ylabel(ylabel, fontsize=10)
        plt.legend()
        plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "训练过程曲线.png"))
    plt.close()


# ==================== Classification-metric plots ====================
def plot_classification_metrics(y_true, y_pred, y_prob, save_dir):
    """Save ROC, precision-recall, metric-heatmap and class-count plots plus a CSV report.

    Args:
        y_true: true integer labels.
        y_pred: predicted integer labels.
        y_prob: predicted class probabilities, one column per entry of CLASS_NAMES.
        save_dir: output directory.
    """
    label_names = list(CLASS_NAMES)
    colors = ['blue', 'red', 'green', 'orange']
    # Use the fixed class list rather than np.unique(y_true): a class absent from the
    # test labels would otherwise shift names/columns out of alignment.
    n_classes = len(label_names)
    class_ids = list(range(n_classes))
    y_true_bin = label_binarize(y_true, classes=class_ids)

    # 1. One-vs-rest ROC curves
    plt.figure(figsize=(10, 8))
    for i, color in zip(class_ids, colors):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
        plt.plot(fpr, tpr, color=color, lw=2, label=f'{label_names[i]} (AUC = {auc(fpr, tpr):.3f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率 (FPR)', fontsize=12)
    plt.ylabel('真阳性率 (TPR)', fontsize=12)
    plt.title('多类别ROC曲线', fontsize=14)
    plt.legend(loc="lower right")
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(save_dir, "ROC曲线.png"))
    plt.close()

    # 2. One-vs-rest precision-recall curves
    plt.figure(figsize=(10, 8))
    for i in class_ids:
        precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_prob[:, i])
        plt.plot(recall, precision, lw=2, label=f'{label_names[i]}')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('召回率', fontsize=12)
    plt.ylabel('精确率', fontsize=12)
    plt.title('精确率-召回率曲线', fontsize=14)
    plt.legend(loc="lower left")
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(save_dir, "精确率-召回率曲线.png"))
    plt.close()

    # 3. Precision / recall / F1 heatmap
    plt.figure(figsize=(10, 8))
    report = classification_report(
        y_true, y_pred, labels=class_ids, target_names=label_names, output_dict=True, zero_division=0
    )
    metrics = ['precision', 'recall', 'f1-score']
    df = pd.DataFrame(
        [[report[name][metric] for metric in metrics] for name in label_names],
        index=label_names,
        columns=metrics,
    )
    sns.heatmap(df, annot=True, fmt='.3f', cmap='YlGnBu', annot_kws={'size': 12}, linewidths=0.5)
    plt.title('分类指标热力图', fontsize=14)
    plt.yticks(rotation=0, fontsize=12)
    plt.xticks(rotation=0, fontsize=12)
    plt.savefig(os.path.join(save_dir, "分类指标热力图.png"))
    plt.close()

    # 4. Test-set class distribution (missing classes shown as 0 so bars match names)
    plt.figure(figsize=(10, 6))
    class_counts = pd.Series(y_true).value_counts().reindex(class_ids, fill_value=0)
    plt.bar(label_names, class_counts.values, color=colors, alpha=0.7)
    for i, v in enumerate(class_counts.values):
        plt.text(i, v + 2, str(v), ha='center', fontsize=12)
    plt.title('测试集各类别样本数量分布', fontsize=14)
    plt.xlabel('类别', fontsize=12)
    plt.ylabel('样本数量', fontsize=12)
    plt.ylim(0, max(class_counts.values) + 20)
    plt.grid(axis='y', alpha=0.3)
    plt.savefig(os.path.join(save_dir, "类别分布.png"))
    plt.close()

    # 5. Full report as CSV
    pd.DataFrame(report).transpose().to_csv(os.path.join(save_dir, "分类报告.csv"))
    print(f"分类指标可视化完成,保存至:{save_dir}")


# ==================== Main flow ====================
def _split_train_test(data_dict, metadata_dict):
    """Randomly draw TESTSET_SAMPLES[label] samples per class for testing; the rest are for training.

    Returns:
        (X_train, y_train, X_test, y_test, meta_test) — arrays for X/y, list of dicts for meta_test.

    Raises:
        ValueError: if a class has fewer samples than its test quota.
    """
    X_train, y_train = [], []
    X_test, y_test, meta_test = [], [], []
    for label, required in TESTSET_SAMPLES.items():
        available = len(data_dict[label])
        if available < required:
            raise ValueError(f"标签{label}可用样本数{available}不足测试集需求{required}")
        indices = np.random.choice(available, size=required, replace=False)
        train_mask = ~np.isin(np.arange(available), indices)

        X_test.extend(data_dict[label][indices])
        y_test.extend([label] * required)
        meta_test.extend(metadata_dict[label][i] for i in indices)

        X_train.extend(data_dict[label][train_mask])
        y_train.extend([label] * (available - required))
    return np.array(X_train), np.array(y_train), np.array(X_test), np.array(y_test), meta_test


def _build_model(input_shape, n_classes):
    """Build and compile a small 2-block CNN over (mel bands, frames) spectrogram inputs."""
    n_mels, n_frames = input_shape
    model = models.Sequential([
        layers.Input(shape=(n_mels, n_frames)),
        layers.Reshape((n_mels, n_frames, 1)),  # add a channel axis for Conv2D
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(n_classes, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model


def main():
    """Run the whole experiment: visualise features, augment, split, train, evaluate, save outputs."""
    base_dir = check_audio_files()
    print(f"检测到所有音频文件,路径:{base_dir}")

    # 1. Multi-feature visualisation of each original recording
    for audio_file in REQUIRED_AUDIOS:
        sample_name = os.path.splitext(audio_file)[0]
        visualize_audio_features(
            os.path.join(base_dir, audio_file), FILE_LABEL_NAMES[audio_file], f"{sample_name}_原始"
        )
    print(f"多维度音频特征可视化完成,保存至:{SPEC_VIS_DIR}")

    plot_feature_comparison(base_dir)
    print(f"多标签特征对比图生成完成,保存至:{COMPARE_VIS_DIR}")

    visualize_augmentations(base_dir)
    print(f"数据增强可视化完成,保存至:{AUGMENT_VIS_DIR}")

    # 2. Augmented dataset
    data_dict, metadata_dict = generate_dataset(base_dir)
    total_samples = sum(len(v) for v in data_dict.values())
    print(f"成功生成数据集(总样本数:{total_samples})")

    # 3. Fixed-size per-class test set.
    # NOTE: train and test samples are augmentations of the same five recordings, so
    # test accuracy measures fit to those recordings rather than generalisation.
    X_train, y_train, X_test, y_test, meta_test = _split_train_test(data_dict, metadata_dict)
    np.save(os.path.join(EXP_DIR, "测试集梅尔频谱.npy"), X_test)
    np.save(os.path.join(EXP_DIR, "测试集标签.npy"), y_test)
    pd.DataFrame(meta_test).to_csv(os.path.join(EXP_DIR, "测试集元数据.csv"), index=False)
    print(f"测试集已保存(正常:{TESTSET_SAMPLES[0]},尖音:{TESTSET_SAMPLES[1]},波浪音:{TESTSET_SAMPLES[2]},撞击音:{TESTSET_SAMPLES[3]})")

    # Keras' validation_split takes the LAST 20 % of the arrays before any shuffling.
    # X_train is grouped by class (label 3 last), so without this permutation the
    # validation set would consist solely of 撞击音 samples.
    order = np.random.permutation(len(X_train))
    X_train, y_train = X_train[order], y_train[order]

    # 4. Build and train the model (input shape follows the data, e.g. (128, 130) for 3 s clips)
    model = _build_model(X_train.shape[1:], len(CLASS_NAMES))
    history = model.fit(X_train, y_train, validation_split=0.2, epochs=20, batch_size=32)

    # 5. Training curves
    plot_training_history(history, TRAIN_VIS_DIR)
    print(f"训练过程可视化完成,保存至:{TRAIN_VIS_DIR}")

    # 6. Evaluation
    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    print(f"\n测试准确率:{test_acc:.4f}")

    y_prob = model.predict(X_test)
    y_pred = np.argmax(y_prob, axis=1)
    plot_classification_metrics(y_test, y_pred, y_prob, METRICS_VIS_DIR)

    plt.figure(figsize=(8, 6))
    cm = confusion_matrix(y_test, y_pred, labels=list(range(len(CLASS_NAMES))))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.xlabel("预测标签")
    plt.ylabel("真实标签")
    plt.title("混淆矩阵")
    plt.savefig(os.path.join(EXP_DIR, "混淆矩阵.png"))
    plt.close()
    print(f"实验结果已保存至:{EXP_DIR}")


if __name__ == "__main__":
    main()

浙公网安备 33010602011771号