Small-sample experiment: building a simple CNN sound-classification model on top of data augmentation

Way too few samples, but I was stubborn enough to want a sound-classification model anyway. Then it hit me: why not just write my own data augmentation? (In a real project you could throw a GAN into the mix; I'll have a proper play with GANs when I find the time.)

For entertainment only.

This is actually my first time building a CNN, so I'm pretty excited, even though we're well into the era of large models now. Oh well.

import os
import numpy as np
import librosa
import librosa.display  # specshow lives here; plain `import librosa` does not always expose it
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
from sklearn.preprocessing import label_binarize
from tensorflow.keras import layers, models
import pandas as pd
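
# Assumption: this was written against librosa 0.10.x, where rate=, n_steps=, and
# size= are keyword arguments; the explicit `import librosa.display` above keeps
# specshow available on versions that don't lazy-load submodules.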

# ==================== Global config ====================
TARGET_SR = 22050    # sample rate
DURATION = 3         # clip length in seconds
AUGMENT_TIMES = 500  # augmented samples generated per source file
EXP_DIR = "实验记录"  # root directory for experiment artifacts
SPEC_VIS_DIR = os.path.join(EXP_DIR, "频谱可视化")         # spectrogram visualizations
TRAIN_VIS_DIR = os.path.join(EXP_DIR, "训练过程可视化")     # training-curve plots
COMPARE_VIS_DIR = os.path.join(EXP_DIR, "特征对比可视化")   # cross-class feature comparisons
AUGMENT_VIS_DIR = os.path.join(EXP_DIR, "数据增强可视化")   # augmentation visualizations
METRICS_VIS_DIR = os.path.join(EXP_DIR, "分类指标可视化")   # classification-metric plots

# ==================== Required audio files (same directory as this script) ====================
REQUIRED_AUDIOS = [
    "normal.wav",   # normal sound (label 0)
    "sharp.wav",    # squeal (label 1)
    "wave.wav",     # warble (label 2)
    "clunk.wav",    # clunk variant 1 (label 3)
    "clunk2.wav"    # clunk variant 2 (label 3)
]

# ==================== Test-set sample counts ====================
TESTSET_SAMPLES = {
    0: 300,  # 300 normal samples
    1: 50,   # 50 squeal samples
    2: 50,   # 50 warble samples
    3: 50    # 50 clunk samples (clunk + clunk2 pooled)
}
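
# With AUGMENT_TIMES = 500, labels 0/1/2 each end up with 500 samples and label 3
# with 1000 (clunk + clunk2), so this split leaves 200/450/450/950 per class for
# training -- 2050 of the 2500 generated samples.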

# ==================== Initialization ====================
os.makedirs(EXP_DIR, exist_ok=True)
os.makedirs(SPEC_VIS_DIR, exist_ok=True)
os.makedirs(TRAIN_VIS_DIR, exist_ok=True)
os.makedirs(COMPARE_VIS_DIR, exist_ok=True)
os.makedirs(AUGMENT_VIS_DIR, exist_ok=True)
os.makedirs(METRICS_VIS_DIR, exist_ok=True)
plt.rcParams['font.sans-serif'] = ['SimHei']  # CJK-capable font so the Chinese plot labels render
plt.rcParams['figure.dpi'] = 300              # high resolution
plt.rcParams['savefig.bbox'] = 'tight'        # trim figure margins on save
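# Note: SimHei is a Windows font; on systems without it, matplotlib falls back
# and the Chinese figure text renders as empty boxes.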

# ==================== Core functions ====================
def check_audio_files():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    missing = []
    for audio in REQUIRED_AUDIOS:
        if not os.path.exists(os.path.join(current_dir, audio)):
            missing.append(audio)
    if missing:
        raise SystemExit(f"错误:以下文件未在代码目录找到(路径:{current_dir}):\n{missing}")
    return current_dir

def load_raw_audio(filename, base_dir):
    audio_path = os.path.join(base_dir, filename)
    audio, _ = librosa.load(audio_path, sr=TARGET_SR)

    # Special case for clunk.wav / clunk2.wav: take 3 seconds starting at the
    # 1-second mark, i.e. samples [22050, 88200) at 22050 Hz (seconds 1-4)
    if filename in ["clunk.wav", "clunk2.wav"]:
        start_sample = TARGET_SR                           # index of the 1-second mark
        end_sample = start_sample + TARGET_SR * DURATION   # 3 seconds later

        if len(audio) < end_sample:
            # Recording shorter than 4 s: take from the 1-second mark to the end, then pad to 3 s
            audio = audio[start_sample:]
            audio = librosa.util.fix_length(audio, size=TARGET_SR * DURATION)
        else:
            # Normal case: slice out exactly 3 seconds
            audio = audio[start_sample:end_sample]
    else:
        # All other files: keep the first 3 seconds (zero-pad if shorter)
        audio = librosa.util.fix_length(audio, size=TARGET_SR * DURATION)

    return audio

def augment_audio(audio):
    # Work on a copy: the original in-place `+=` mutated the caller's array,
    # so noise quietly accumulated in the source audio across the 500 passes
    audio = audio.copy()
    if np.random.random() > 0.5:
        audio = audio + np.random.randn(len(audio)) * 0.005  # additive Gaussian noise
    if np.random.random() > 0.5:
        shift = np.random.randint(-len(audio)//10, len(audio)//10)
        audio = np.roll(audio, shift)  # circular time shift of up to ±10%
    if np.random.random() > 0.5:
        # pitch shift of -2..+2 semitones
        audio = librosa.effects.pitch_shift(audio, sr=TARGET_SR, n_steps=np.random.randint(-2, 3))
    if np.random.random() > 0.5:
        rate = np.random.uniform(0.9, 1.1)
        audio = librosa.effects.time_stretch(audio, rate=rate)  # speed up/slow down by up to 10%
        audio = librosa.util.fix_length(audio, size=TARGET_SR * DURATION)  # re-fit to exactly 3 s
    return audio
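
# A possible fifth augmentation, sketched here but not wired into augment_audio
# above (whether random gain suits these classes is my assumption, not something
# the original pipeline does):
def augment_gain(audio, low=0.8, high=1.2):
    return audio * np.random.uniform(low, high)  # random amplitude scaling, label-preserving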

# ==================== Augmentation visualization ====================
def visualize_augmentations(base_dir):
    """Render a mel spectrogram of each augmentation type for every source file."""
    label_names = {
        "normal.wav": "正常",
        "sharp.wav": "尖音",
        "wave.wav": "波浪音",
        "clunk.wav": "撞击音",
        "clunk2.wav": "撞击音"
    }

    # One subdirectory per augmentation type, plus one for the unmodified originals
    aug_types = ["原始音频", "噪声添加", "时间平移", "音高变换", "时间拉伸"]
    for aug_type in aug_types:
        os.makedirs(os.path.join(AUGMENT_VIS_DIR, aug_type), exist_ok=True)

    # Visualize the effect of each augmentation on each source file
    for audio_file in REQUIRED_AUDIOS:
        audio = load_raw_audio(audio_file, base_dir)
        label_name = label_names[audio_file]
        sample_name = os.path.splitext(audio_file)[0]
        
        # 1. Mel spectrogram of the unmodified audio
        mel_spec = librosa.feature.melspectrogram(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(mel_db, sr=TARGET_SR, hop_length=512, x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f"{label_name}原始音频梅尔频谱图", fontsize=12)
        plt.xlabel("时间(秒)", fontsize=10)
        plt.ylabel("梅尔频率", fontsize=10)
        plt.savefig(os.path.join(AUGMENT_VIS_DIR, "原始音频", f"{sample_name}_原始频谱.png"))
        plt.close()
        
        # 2. Additive-noise augmentation
        # Local generator: reproducible figures without reseeding the global RNG
        # that generate_dataset() later draws from (np.random.seed(42) leaked there)
        rng = np.random.default_rng(42)
        aug_audio = audio + rng.standard_normal(len(audio)) * 0.005
        mel_spec = librosa.feature.melspectrogram(y=aug_audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(mel_db, sr=TARGET_SR, hop_length=512, x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f"{label_name}噪声添加增强梅尔频谱图", fontsize=12)
        plt.xlabel("时间(秒)", fontsize=10)
        plt.ylabel("梅尔频率", fontsize=10)
        plt.savefig(os.path.join(AUGMENT_VIS_DIR, "噪声添加", f"{sample_name}_噪声频谱.png"))
        plt.close()
        
        # 3. Time-shift augmentation
        shift = int(rng.integers(-len(audio)//10, len(audio)//10))
        aug_audio = np.roll(audio, shift)
        mel_spec = librosa.feature.melspectrogram(y=aug_audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(mel_db, sr=TARGET_SR, hop_length=512, x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f"{label_name}时间平移增强梅尔频谱图", fontsize=12)
        plt.xlabel("时间(秒)", fontsize=10)
        plt.ylabel("梅尔频率", fontsize=10)
        plt.savefig(os.path.join(AUGMENT_VIS_DIR, "时间平移", f"{sample_name}_平移频谱.png"))
        plt.close()
        
        # 4. Pitch-shift augmentation
        n_steps = int(rng.integers(-2, 3))
        aug_audio = librosa.effects.pitch_shift(audio, sr=TARGET_SR, n_steps=n_steps)
        mel_spec = librosa.feature.melspectrogram(y=aug_audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(mel_db, sr=TARGET_SR, hop_length=512, x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f"{label_name}音高变换增强梅尔频谱图 (n_steps={n_steps})", fontsize=12)
        plt.xlabel("时间(秒)", fontsize=10)
        plt.ylabel("梅尔频率", fontsize=10)
        plt.savefig(os.path.join(AUGMENT_VIS_DIR, "音高变换", f"{sample_name}_音高频谱.png"))
        plt.close()
        
        # 5. Time-stretch augmentation
        rate = float(rng.uniform(0.9, 1.1))
        aug_audio = librosa.effects.time_stretch(audio, rate=rate)
        aug_audio = librosa.util.fix_length(aug_audio, size=TARGET_SR * DURATION)
        mel_spec = librosa.feature.melspectrogram(y=aug_audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(mel_db, sr=TARGET_SR, hop_length=512, x_axis='time', y_axis='mel', cmap='magma')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f"{label_name}时间拉伸增强梅尔频谱图 (rate={rate:.2f})", fontsize=12)
        plt.xlabel("时间(秒)", fontsize=10)
        plt.ylabel("梅尔频率", fontsize=10)
        plt.savefig(os.path.join(AUGMENT_VIS_DIR, "时间拉伸", f"{sample_name}_拉伸频谱.png"))
        plt.close()

def generate_dataset(base_dir):
    label_map = {
        "normal.wav": 0,
        "sharp.wav": 1,
        "wave.wav": 2,
        "clunk.wav": 3,
        "clunk2.wav": 3
    }
    data_dict = {0: [], 1: [], 2: [], 3: []}
    metadata_dict = {0: [], 1: [], 2: [], 3: []}

    for audio_file in REQUIRED_AUDIOS:
        raw_audio = load_raw_audio(audio_file, base_dir)
        label = label_map[audio_file]
        
        for i in range(AUGMENT_TIMES):
            augmented = augment_audio(raw_audio)
            mel = librosa.feature.melspectrogram(y=augmented, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
            mel_db = librosa.power_to_db(mel, ref=np.max)
            
            data_dict[label].append(mel_db)
            metadata_dict[label].append({
                "原始文件": audio_file,
                "增强编号": i,
                "标签": label
            })
    
    return {k: np.array(v) for k, v in data_dict.items()}, metadata_dict
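
# Shape check: 3 s at 22050 Hz is 66150 samples; with hop_length=512 the mel
# spectrogram has 1 + 66150 // 512 = 130 frames, so every mel_db above is
# (128 mels, 130 frames), matching the (128, 130) model input declared below.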

# ==================== Multi-feature audio visualization ====================
def plot_waveform(audio, save_path, title):
    plt.figure(figsize=(10, 4))
    time = np.linspace(0, DURATION, len(audio))
    plt.plot(time, audio, color='steelblue', alpha=0.8)
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("振幅", fontsize=10)
    plt.grid(alpha=0.2)
    plt.savefig(save_path)
    plt.close()

def plot_spectral_centroid(audio, save_path, title):
    spectral_centroids = librosa.feature.spectral_centroid(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512)[0]
    time = librosa.times_like(spectral_centroids, sr=TARGET_SR, hop_length=512)
    
    plt.figure(figsize=(10, 4))
    plt.plot(time, spectral_centroids, color='firebrick', linewidth=1.5, label='频谱质心')
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("频率(Hz)", fontsize=10)
    plt.ylim(0, TARGET_SR//2)
    plt.grid(alpha=0.2)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def plot_zero_crossing_rate(audio, save_path, title):
    zcr = librosa.feature.zero_crossing_rate(audio, hop_length=512)[0]
    time = librosa.times_like(zcr, sr=TARGET_SR, hop_length=512)
    
    plt.figure(figsize=(10, 4))
    plt.plot(time, zcr, color='forestgreen', linewidth=1.2, label='过零率')
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("过零率(次/帧)", fontsize=10)
    plt.ylim(0, 0.5)
    plt.grid(alpha=0.2)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def plot_spectral_bandwidth(audio, save_path, title):
    spec_bw = librosa.feature.spectral_bandwidth(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512)[0]
    time = librosa.times_like(spec_bw, sr=TARGET_SR, hop_length=512)
    
    plt.figure(figsize=(10, 4))
    plt.plot(time, spec_bw, color='darkorange', linewidth=1.2, label='频谱带宽')
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("频率(Hz)", fontsize=10)
    plt.ylim(0, TARGET_SR//2)
    plt.grid(alpha=0.2)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def plot_spectral_contrast(audio, save_path, title):
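    # Note: spectral_contrast returns one row per band (n_bands+1 = 7 by default);
    # [0] keeps only the lowest sub-band. Averaging over rows would summarize all bands.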
    contrast = librosa.feature.spectral_contrast(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512)[0]
    time = librosa.times_like(contrast, sr=TARGET_SR, hop_length=512)
    
    plt.figure(figsize=(10, 4))
    plt.plot(time, contrast, color='purple', linewidth=1.2, label='频谱对比度')
    plt.title(title, fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("对比度(dB)", fontsize=10)
    plt.grid(alpha=0.2)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def visualize_audio_features(audio_path, label_name, sample_name):
    audio = load_raw_audio(os.path.basename(audio_path), os.path.dirname(audio_path))
    label_dir = os.path.join(SPEC_VIS_DIR, label_name)
    os.makedirs(label_dir, exist_ok=True)

    plot_waveform(audio, os.path.join(label_dir, f"{sample_name}_波形图.png"), 
                 f"{label_name}音频波形图({sample_name})")
    
    mel_spec = librosa.feature.melspectrogram(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
    mel_db = librosa.power_to_db(mel_spec, ref=np.max)
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(mel_db, sr=TARGET_SR, hop_length=512, x_axis='time', y_axis='mel', cmap='magma')
    plt.colorbar(format='%+2.0f dB')
    plt.title(f"{label_name}音频梅尔频谱图({sample_name})", fontsize=12)
    plt.xlabel("时间(秒)", fontsize=10)
    plt.ylabel("梅尔频率", fontsize=10)
    plt.savefig(os.path.join(label_dir, f"{sample_name}_梅尔频谱图.png"))
    plt.close()
    
    plot_spectral_centroid(audio, os.path.join(label_dir, f"{sample_name}_频谱质心图.png"), 
                         f"{label_name}音频频谱质心图({sample_name})")
    
    plot_zero_crossing_rate(audio, os.path.join(label_dir, f"{sample_name}_过零率图.png"), 
                          f"{label_name}音频过零率图({sample_name})")
    
    plot_spectral_bandwidth(audio, os.path.join(label_dir, f"{sample_name}_频谱带宽图.png"), 
                          f"{label_name}音频频谱带宽图({sample_name})")
    
    plot_spectral_contrast(audio, os.path.join(label_dir, f"{sample_name}_频谱对比度图.png"), 
                         f"{label_name}音频频谱对比度图({sample_name})")

# ==================== Multi-class feature comparison ====================
def plot_feature_comparison(base_dir):
    """Plot a 6-feature x 4-class comparison grid from the raw audio."""
    # Label-to-color mapping (one color per class)
    label_info = [
        {"name": "正常", "file": "normal.wav", "color": "steelblue"},
        {"name": "尖音", "file": "sharp.wav", "color": "firebrick"},
        {"name": "波浪音", "file": "wave.wav", "color": "forestgreen"},
        {"name": "撞击音", "file": "clunk.wav", "color": "darkorange"}
    ]
    
    # Load one raw clip per class
    audios = []
    for info in label_info:
        audio = load_raw_audio(info["file"], base_dir)
        audios.append(audio)
    
    # 2x3 grid of subplots, one per feature
    fig, axes = plt.subplots(2, 3, figsize=(24, 12))
    fig.suptitle("不同类别音频特征对比图", fontsize=16, y=1.02)

    # 1. Waveform comparison (row 1, col 1)
    ax = axes[0, 0]
    for i, (audio, info) in enumerate(zip(audios, label_info)):
        time = np.linspace(0, DURATION, len(audio))
        ax.plot(time, audio, color=info["color"], alpha=0.7, label=info["name"])
    ax.set_title("波形对比(时间-振幅)", fontsize=14)
    ax.set_xlabel("时间(秒)", fontsize=12)
    ax.set_ylabel("振幅", fontsize=12)
    ax.grid(alpha=0.2)
    ax.legend()

    # 2. Spectral-centroid comparison (row 1, col 2)
    ax = axes[0, 1]
    for i, (audio, info) in enumerate(zip(audios, label_info)):
        centroids = librosa.feature.spectral_centroid(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512)[0]
        time = librosa.times_like(centroids, sr=TARGET_SR, hop_length=512)
        ax.plot(time, centroids, color=info["color"], linewidth=1.5, label=info["name"])
    ax.set_title("频谱质心对比(时间-频率)", fontsize=14)
    ax.set_xlabel("时间(秒)", fontsize=12)
    ax.set_ylabel("频率(Hz)", fontsize=12)
    ax.set_ylim(0, TARGET_SR//2)
    ax.grid(alpha=0.2)
    ax.legend()

    # 3. Zero-crossing-rate comparison (row 1, col 3)
    ax = axes[0, 2]
    for i, (audio, info) in enumerate(zip(audios, label_info)):
        zcr = librosa.feature.zero_crossing_rate(audio, hop_length=512)[0]
        time = librosa.times_like(zcr, sr=TARGET_SR, hop_length=512)
        ax.plot(time, zcr, color=info["color"], linewidth=1.2, label=info["name"])
    ax.set_title("过零率对比(时间-过零次数)", fontsize=14)
    ax.set_xlabel("时间(秒)", fontsize=12)
    ax.set_ylabel("过零率(次/帧)", fontsize=12)
    ax.set_ylim(0, 0.5)
    ax.grid(alpha=0.2)
    ax.legend()

    # 4. Spectral-bandwidth comparison (row 2, col 1)
    ax = axes[1, 0]
    for i, (audio, info) in enumerate(zip(audios, label_info)):
        spec_bw = librosa.feature.spectral_bandwidth(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512)[0]
        time = librosa.times_like(spec_bw, sr=TARGET_SR, hop_length=512)
        ax.plot(time, spec_bw, color=info["color"], linewidth=1.2, label=info["name"])
    ax.set_title("频谱带宽对比(时间-频率宽度)", fontsize=14)
    ax.set_xlabel("时间(秒)", fontsize=12)
    ax.set_ylabel("频率(Hz)", fontsize=12)
    ax.set_ylim(0, TARGET_SR//2)
    ax.grid(alpha=0.2)
    ax.legend()

    # 5. Spectral-contrast comparison (row 2, col 2)
    ax = axes[1, 1]
    for i, (audio, info) in enumerate(zip(audios, label_info)):
        contrast = librosa.feature.spectral_contrast(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512)[0]
        time = librosa.times_like(contrast, sr=TARGET_SR, hop_length=512)
        ax.plot(time, contrast, color=info["color"], linewidth=1.2, label=info["name"])
    ax.set_title("频谱对比度对比(时间-频率差异)", fontsize=14)
    ax.set_xlabel("时间(秒)", fontsize=12)
    ax.set_ylabel("对比度(dB)", fontsize=12)
    ax.grid(alpha=0.2)
    ax.legend()

    # 6. Mel-spectrum comparison (row 2, col 3)
    ax = axes[1, 2]
    for i, (audio, info) in enumerate(zip(audios, label_info)):
        mel_spec = librosa.feature.melspectrogram(y=audio, sr=TARGET_SR, n_fft=2048, hop_length=512, n_mels=128)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)
        # Average over time to get a typical spectrum per class (keeps the subplot readable)
        avg_mel = np.mean(mel_db, axis=1)
        ax.plot(avg_mel, color=info["color"], linewidth=1.5, label=info["name"])
    ax.set_title("梅尔频谱均值对比(梅尔频率-能量)", fontsize=14)
    ax.set_xlabel("梅尔频率索引(共128维)", fontsize=12)
    ax.set_ylabel("能量(dB)", fontsize=12)
    ax.grid(alpha=0.2)
    ax.legend()

    # Tighten spacing and save the comparison figure
    plt.tight_layout()
    save_path = os.path.join(COMPARE_VIS_DIR, "多标签特征对比图.png")
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()

# ==================== Training-curve visualization ====================
def plot_training_history(history, save_dir):
    epochs = range(1, len(history.history['loss']) + 1)
    train_loss = history.history['loss']
    val_loss = history.history['val_loss']
    train_acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, 'bo-', label='训练损失')
    plt.plot(epochs, val_loss, 'ro-', label='验证损失')
    plt.title('训练与验证损失', fontsize=12)
    plt.xlabel('轮次', fontsize=10)
    plt.ylabel('损失值', fontsize=10)
    plt.legend()
    plt.grid(alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_acc, 'bo-', label='训练准确率')
    plt.plot(epochs, val_acc, 'ro-', label='验证准确率')
    plt.title('训练与验证准确率', fontsize=12)
    plt.xlabel('轮次', fontsize=10)
    plt.ylabel('准确率', fontsize=10)
    plt.legend()
    plt.grid(alpha=0.3)

    plt.tight_layout()
    save_path = os.path.join(save_dir, "训练过程曲线.png")
    plt.savefig(save_path)
    plt.close()

# ==================== Classification-metric visualization ====================
def plot_classification_metrics(y_true, y_pred, y_prob, save_dir):
    """
    Plot a set of classification-metric figures.
    Args:
    - y_true: ground-truth labels
    - y_pred: predicted labels
    - y_prob: predicted class probabilities
    - save_dir: output directory
    """
    # 1. ROC curves
    plt.figure(figsize=(10, 8))
    
    # One-vs-rest binarization of the labels for per-class ROC
    n_classes = len(np.unique(y_true))
    y_true_bin = label_binarize(y_true, classes=list(range(n_classes)))
    
    # ROC curve and AUC for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    # Plot each class's ROC curve
    label_names = ["正常", "尖音", "波浪音", "撞击音"]
    colors = ['blue', 'red', 'green', 'orange']
    
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label=f'{label_names[i]} (AUC = {roc_auc[i]:.3f})')
    
    # Chance diagonal (random guessing)
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率 (FPR)', fontsize=12)
    plt.ylabel('真阳性率 (TPR)', fontsize=12)
    plt.title('多类别ROC曲线', fontsize=14)
    plt.legend(loc="lower right")
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(save_dir, "ROC曲线.png"))
    plt.close()
    
    # 2. Precision-recall curves
    plt.figure(figsize=(10, 8))
    
    # Precision-recall curve for each class
    precision = dict()
    recall = dict()
    
    for i in range(n_classes):
        precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], y_prob[:, i])
        plt.plot(recall[i], precision[i], lw=2,
                 label=f'{label_names[i]}')
    
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('召回率', fontsize=12)
    plt.ylabel('精确率', fontsize=12)
    plt.title('精确率-召回率曲线', fontsize=14)
    plt.legend(loc="lower left")
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(save_dir, "精确率-召回率曲线.png"))
    plt.close()
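    # Note: with the imbalanced test split (50 positives per fault class vs ~400
    # negatives), the precision-recall view is usually more telling than ROC,
    # which can look deceptively good under imbalance.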
    
    # 3. Classification-report heatmap
    plt.figure(figsize=(10, 8))
    
    # Build the classification report as a dict
    report = classification_report(y_true, y_pred, 
                                   target_names=label_names, 
                                   output_dict=True)
    
    # Pull precision, recall, and F1 for each class
    metrics = ['precision', 'recall', 'f1-score']
    results = []
    
    for label in label_names:
        row = [report[label][metric] for metric in metrics]
        results.append(row)
    
    # Convert to a DataFrame and draw the heatmap
    df = pd.DataFrame(results, index=label_names, columns=metrics)
    
    sns.heatmap(df, annot=True, fmt='.3f', cmap='YlGnBu', 
                annot_kws={'size': 12}, linewidths=0.5)
    
    plt.title('分类指标热力图', fontsize=14)
    plt.yticks(rotation=0, fontsize=12)
    plt.xticks(rotation=0, fontsize=12)
    plt.savefig(os.path.join(save_dir, "分类指标热力图.png"))
    plt.close()
    
    # 4. Per-class sample counts in the test set
    plt.figure(figsize=(10, 6))
    
    # Count samples per class
    class_counts = pd.Series(y_true).value_counts().sort_index()
    
    # Bar chart
    plt.bar(label_names, class_counts, color=colors, alpha=0.7)
    
    # Numeric labels above the bars
    for i, v in enumerate(class_counts):
        plt.text(i, v + 2, str(v), ha='center', fontsize=12)
    
    plt.title('测试集各类别样本数量分布', fontsize=14)
    plt.xlabel('类别', fontsize=12)
    plt.ylabel('样本数量', fontsize=12)
    plt.ylim(0, max(class_counts) + 20)
    plt.grid(axis='y', alpha=0.3)
    plt.savefig(os.path.join(save_dir, "类别分布.png"))
    plt.close()
    
    # 5. Save the full classification report as CSV
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv(os.path.join(save_dir, "分类报告.csv"))
    
    print(f"分类指标可视化完成,保存至:{save_dir}")

# ==================== Main ====================
if __name__ == "__main__":
    base_dir = check_audio_files()
    print(f"检测到所有音频文件,路径:{base_dir}")

    # 1. Multi-feature visualization of the raw audio
    label_names = {
        "normal.wav": "正常",
        "sharp.wav": "尖音",
        "wave.wav": "波浪音",
        "clunk.wav": "撞击音",
        "clunk2.wav": "撞击音"
    }
    for audio_file in REQUIRED_AUDIOS:
        audio_path = os.path.join(base_dir, audio_file)
        label_name = label_names[audio_file]
        sample_name = os.path.splitext(audio_file)[0]
        visualize_audio_features(audio_path, label_name, f"{sample_name}_原始")
    print(f"多维度音频特征可视化完成,保存至:{SPEC_VIS_DIR}")

    # Multi-class feature comparison figure
    plot_feature_comparison(base_dir)
    print(f"多标签特征对比图生成完成,保存至:{COMPARE_VIS_DIR}")
    
    # Augmentation visualization (the function creates its own output subdirectories)
    visualize_augmentations(base_dir)
    print(f"数据增强可视化完成,保存至:{AUGMENT_VIS_DIR}")

    # 2. Generate the per-label dataset
    data_dict, metadata_dict = generate_dataset(base_dir)
    total_samples = sum([len(v) for v in data_dict.values()])
    print(f"成功生成数据集(总样本数:{total_samples})")

    # 3. Carve out the fixed-size test set
    X_test, y_test, meta_test = [], [], []
    X_train, y_train, meta_train = [], [], []
    
    for label in TESTSET_SAMPLES:
        available = len(data_dict[label])
        required = TESTSET_SAMPLES[label]
        if available < required:
            raise ValueError(f"标签{label}可用样本数{available}不足测试集需求{required}")
        
        indices = np.random.choice(available, size=required, replace=False)
        X_test.extend(data_dict[label][indices])
        y_test.extend([label]*required)
        meta_test.extend([metadata_dict[label][i] for i in indices])
        
        train_mask = ~np.isin(np.arange(available), indices)
        X_train.extend(data_dict[label][train_mask])
        y_train.extend([label]*(available - required))
        meta_train.extend([metadata_dict[label][i] for i in np.where(train_mask)[0]])
    
    X_train = np.array(X_train)
    X_test = np.array(X_test)
    y_train = np.array(y_train)
    y_test = np.array(y_test)

    # Shuffle the training set: X_train was built label by label, and Keras's
    # validation_split slices off the *last* 20% without shuffling, so the
    # validation set would otherwise consist entirely of label 3.
    perm = np.random.permutation(len(X_train))
    X_train, y_train = X_train[perm], y_train[perm]

    np.save(os.path.join(EXP_DIR, "测试集梅尔频谱.npy"), X_test)
    np.save(os.path.join(EXP_DIR, "测试集标签.npy"), y_test)
    pd.DataFrame(meta_test).to_csv(os.path.join(EXP_DIR, "测试集元数据.csv"), index=False)
    print(f"测试集已保存(正常:{TESTSET_SAMPLES[0]},尖音:{TESTSET_SAMPLES[1]},波浪音:{TESTSET_SAMPLES[2]},撞击音:{TESTSET_SAMPLES[3]})")

    # 4. Build and train the model
    model = models.Sequential([
        layers.Input(shape=(128, 130)),  # 128 mel bands x 130 frames for 3 s of audio (see shape note above)
        layers.Reshape((128, 130, 1)),
        layers.Conv2D(32, (3,3), activation='relu'),
        layers.MaxPooling2D((2,2)),
        layers.Conv2D(64, (3,3), activation='relu'),
        layers.MaxPooling2D((2,2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(4, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(X_train, y_train, validation_split=0.2, epochs=20, batch_size=32)
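    # A possible tweak (my assumption, not something the original run did): the dB
    # spectrograms span roughly [-80, 0], so min-max scaling before fit(), e.g.
    # X_train = (X_train + 80.0) / 80.0, often makes optimization better behaved.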

    # 5. Plot training curves
    plot_training_history(history, TRAIN_VIS_DIR)
    print(f"训练过程可视化完成,保存至:{TRAIN_VIS_DIR}")

    # 6. Evaluate and save results
    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    print(f"\n测试准确率:{test_acc:.4f}")
    
    # Predicted probabilities and hard labels
    y_prob = model.predict(X_test)
    y_pred = np.argmax(y_prob, axis=1)
    
    # Classification-metric figures
    plot_classification_metrics(y_test, y_pred, y_prob, METRICS_VIS_DIR)
    
    # Confusion matrix
    plt.figure(figsize=(8,6))
    cm = confusion_matrix(y_test, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=["正常", "尖音", "波浪音", "撞击音"],
                yticklabels=["正常", "尖音", "波浪音", "撞击音"])
    plt.xlabel("预测标签")
    plt.ylabel("真实标签")
    plt.title("混淆矩阵")
    plt.savefig(os.path.join(EXP_DIR, "混淆矩阵.png"))
    plt.close()
    print(f"实验结果已保存至:{EXP_DIR}")

 

posted @ 2025-06-12 17:50  土星狗蛋