基于PCA白化和K均值聚类的轴承故障诊断系统

本系统结合了降维、去相关和聚类分析方法,能够有效识别不同的轴承故障模式。

1. 轴承故障数据生成与特征提取

classdef BearingFaultData
    % BearingFaultData  Synthetic bearing-fault vibration data and features.
    %
    % Generates simulated vibration signals for a 6205 bearing under four
    % conditions (Normal, InnerRace, OuterRace, BallFault) and provides
    % static utilities for time/frequency-domain feature extraction,
    % envelope-spectrum computation, and diagnostic plotting.

    properties
        sampling_freq     % sampling frequency (Hz)
        duration          % signal duration (s)
        fault_types       % cell array of supported fault-type names
        bearing_params    % struct of 6205 bearing geometry (d, D, n, contact_angle)
    end

    methods
        function obj = BearingFaultData(fs, duration)
            % Constructor.
            %   fs       - sampling frequency in Hz
            %   duration - signal length in seconds
            obj.sampling_freq = fs;
            obj.duration = duration;

            % Supported operating conditions.
            obj.fault_types = {'Normal', 'InnerRace', 'OuterRace', 'BallFault'};

            % 6205 bearing geometry.
            obj.bearing_params.d = 7.94;   % ball diameter (mm)
            obj.bearing_params.D = 39.04;  % pitch diameter (mm)
            obj.bearing_params.n = 9;      % number of rolling elements
            obj.bearing_params.contact_angle = 0; % contact angle (deg)
        end

        function [signals, labels] = generate_fault_data(obj, num_samples_per_class)
            % Generate a labelled dataset of simulated fault signals.
            %   num_samples_per_class - samples generated per fault type
            % Returns:
            %   signals - (total_samples x signal_length), one signal per row
            %   labels  - cell array of fault-type names, one per row
            fprintf('生成轴承故障数据...\n');

            total_samples = length(obj.fault_types) * num_samples_per_class;
            % round() keeps the matrix width consistent with the time vector
            % even when fs*duration is not exactly an integer.
            signal_length = round(obj.sampling_freq * obj.duration);

            signals = zeros(total_samples, signal_length);
            labels = cell(total_samples, 1);

            sample_count = 1;

            for i = 1:length(obj.fault_types)
                fault_type = obj.fault_types{i};
                fprintf('生成 %s 故障数据...\n', fault_type);

                for j = 1:num_samples_per_class
                    % Each sample is an independent noisy realisation.
                    signal = obj.generate_single_signal(fault_type);
                    signals(sample_count, :) = signal;
                    labels{sample_count} = fault_type;
                    sample_count = sample_count + 1;
                end
            end

            fprintf('数据生成完成: %d 个样本, %d 种故障类型\n', ...
                total_samples, length(obj.fault_types));
        end

        function signal = generate_single_signal(obj, fault_type)
            % Generate one vibration signal for the given fault type:
            % baseline vibration + fault impact train + white Gaussian noise.
            t = 0:1/obj.sampling_freq:obj.duration-1/obj.sampling_freq;
            N = length(t);

            % Baseline machine vibration (healthy running condition).
            base_vibration = 0.5 * sin(2*pi*30*t) + 0.3 * sin(2*pi*60*t);

            switch fault_type
                case 'Normal'
                    % Healthy: baseline plus noise only.
                    fault_component = zeros(1, N);
                    noise_level = 0.1;

                case 'InnerRace'
                    % Inner-race defect: periodic impacts at BPFI.
                    fault_freq = obj.calculate_inner_race_frequency(1800); % shaft at 1800 RPM
                    fault_component = obj.generate_impact_signal(t, fault_freq, 0.8);
                    noise_level = 0.15;

                case 'OuterRace'
                    % Outer-race defect: periodic impacts at BPFO.
                    fault_freq = obj.calculate_outer_race_frequency(1800);
                    fault_component = obj.generate_impact_signal(t, fault_freq, 0.7);
                    noise_level = 0.12;

                case 'BallFault'
                    % Rolling-element defect: periodic impacts at BSF.
                    fault_freq = obj.calculate_ball_fault_frequency(1800);
                    fault_component = obj.generate_impact_signal(t, fault_freq, 0.6);
                    noise_level = 0.13;

                otherwise
                    % Unknown type falls back to the healthy model.
                    fault_component = zeros(1, N);
                    noise_level = 0.1;
            end

            % Combine components with additive white Gaussian noise.
            signal = base_vibration + fault_component + ...
                     noise_level * randn(1, N);
        end

        function impact_signal = generate_impact_signal(obj, t, fault_freq, amplitude)
            % Build a train of exponentially-decaying resonance impacts
            % repeating at fault_freq (the fault's characteristic frequency).
            N = length(t);
            impact_signal = zeros(1, N);

            % Impact template is loop-invariant: compute it once.
            tau = 0.002;      % decay time constant (s)
            osc_freq = 3000;  % structural resonance frequency (Hz)
            time_window = 0:1/obj.sampling_freq:0.01;
            impact = amplitude * exp(-time_window/tau) .* ...
                     sin(2*pi*osc_freq*time_window);
            L = length(impact);

            % Samples between consecutive impacts.
            impact_interval = round(obj.sampling_freq / fault_freq);
            num_impacts = floor(N / impact_interval);

            for i = 1:num_impacts
                impact_pos = (i-1) * impact_interval + 1;
                % Guard with the actual template length instead of the
                % original magic constant 100, so no impact is truncated.
                if impact_pos + L - 1 <= N
                    end_pos = impact_pos + L - 1;
                    impact_signal(impact_pos:end_pos) = ...
                        impact_signal(impact_pos:end_pos) + impact;
                end
            end
        end

        function fi = calculate_inner_race_frequency(obj, rpm)
            % Ball-pass frequency, inner race (BPFI) for shaft speed rpm.
            fr = rpm / 60; % shaft rotation frequency (Hz)
            fi = 0.5 * obj.bearing_params.n * fr * ...
                 (1 + obj.bearing_params.d/obj.bearing_params.D * cosd(obj.bearing_params.contact_angle));
        end

        function fo = calculate_outer_race_frequency(obj, rpm)
            % Ball-pass frequency, outer race (BPFO) for shaft speed rpm.
            fr = rpm / 60;
            fo = 0.5 * obj.bearing_params.n * fr * ...
                 (1 - obj.bearing_params.d/obj.bearing_params.D * cosd(obj.bearing_params.contact_angle));
        end

        function fb = calculate_ball_fault_frequency(obj, rpm)
            % Ball spin frequency (BSF) for shaft speed rpm.
            fr = rpm / 60;
            fb = 0.5 * fr * obj.bearing_params.D/obj.bearing_params.d * ...
                 (1 - (obj.bearing_params.d/obj.bearing_params.D * cosd(obj.bearing_params.contact_angle))^2);
        end
    end

    methods (Static)
        function features = extract_features(signals, fs)
            % Extract 18 time- and frequency-domain features per signal row.
            %   signals - (num_samples x signal_length) matrix
            %   fs      - sampling frequency (Hz)
            % Returns a (num_samples x 18) feature matrix; column order
            % follows feature_names below.
            fprintf('提取信号特征...\n');

            [num_samples, ~] = size(signals);

            % Feature catalogue (order matches the columns assigned below).
            feature_names = {
                'Mean', 'Std', 'RMS', 'Skewness', 'Kurtosis', ...
                'Peak2Peak', 'CrestFactor', 'ImpulseFactor', ...
                'MarginFactor', 'ShapeFactor', ...
                'FrequencyRMS', 'FrequencyMean', 'FrequencyStd', ...
                'SpectralCentroid', 'SpectralSpread', 'SpectralSkewness', ...
                'SpectralKurtosis', 'SpectralRolloff'
            };

            num_features = length(feature_names);
            features = zeros(num_samples, num_features);

            for i = 1:num_samples
                signal = signals(i, :);

                % Time-domain statistics.
                features(i, 1) = mean(signal);
                features(i, 2) = std(signal);
                features(i, 3) = rms(signal);
                features(i, 4) = skewness(signal);
                features(i, 5) = kurtosis(signal);
                features(i, 6) = peak2peak(signal);

                % Dimensionless shape indicators.
                rms_val = rms(signal);
                peak_val = max(abs(signal));
                features(i, 7) = peak_val / rms_val;              % crest factor
                features(i, 8) = peak_val / mean(abs(signal));    % impulse factor
                features(i, 9) = peak_val / (mean(sqrt(abs(signal)))^2); % margin factor
                features(i, 10) = rms_val / mean(abs(signal));    % shape factor

                % Frequency-domain statistics.
                [freq_features, ~] = BearingFaultData.extract_frequency_features(signal, fs);
                features(i, 11:18) = freq_features;
            end

            fprintf('特征提取完成: %d 个样本, %d 个特征\n', num_samples, num_features);
        end

        function [freq_features, f, P1] = extract_frequency_features(signal, fs)
            % Compute 8 spectral statistics of a signal.
            % Returns:
            %   freq_features - 1x8 vector of spectral statistics
            %   f             - single-sided frequency axis (Hz)
            %   P1            - single-sided amplitude spectrum, same length
            %                   as f (new optional output, used for plotting)
            N = length(signal);
            f = (0:N-1) * fs / N;

            % Single-sided amplitude spectrum via FFT.
            Y = fft(signal);
            P2 = abs(Y/N);
            P1 = P2(1:floor(N/2)+1);
            P1(2:end-1) = 2*P1(2:end-1);
            f = f(1:floor(N/2)+1);

            freq_features = zeros(1, 8);
            freq_features(1) = rms(P1);    % spectral RMS
            freq_features(2) = mean(P1);   % spectral mean
            freq_features(3) = std(P1);    % spectral standard deviation

            % Spectral centroid (amplitude-weighted mean frequency).
            freq_features(4) = sum(f .* P1) / sum(P1);

            % Spectral spread about the centroid.
            centroid = freq_features(4);
            freq_features(5) = sqrt(sum(((f - centroid).^2) .* P1) / sum(P1));

            % Spectral skewness.
            freq_features(6) = sum(((f - centroid).^3) .* P1) / (sum(P1) * freq_features(5)^3);

            % Spectral kurtosis.
            freq_features(7) = sum(((f - centroid).^4) .* P1) / (sum(P1) * freq_features(5)^4);

            % Spectral rolloff: frequency below which 85% of power lies.
            total_power = sum(P1);
            cumulative_power = cumsum(P1);
            rolloff_idx = find(cumulative_power >= 0.85 * total_power, 1);
            freq_features(8) = f(rolloff_idx);
        end

        function plot_sample_signals(signals, labels, fs, duration)
            % Plot time waveform, amplitude spectrum, and envelope spectrum
            % for the first sample of every fault type.
            figure('Position', [100, 100, 1400, 1000]);

            unique_labels = unique(labels);
            num_types = length(unique_labels);

            for i = 1:num_types
                % First sample of this class.
                idx = find(strcmp(labels, unique_labels{i}), 1);
                signal = signals(idx, :);
                t = 0:1/fs:duration-1/fs;

                % Time-domain waveform.
                subplot(3, num_types, i);
                plot(t, signal, 'b-', 'LineWidth', 1);
                title(sprintf('%s - 时域信号', unique_labels{i}));
                xlabel('时间 (s)');
                ylabel('幅度');
                grid on;

                % Amplitude spectrum.
                % BUGFIX: the original plotted the 1x8 feature vector
                % against the full frequency axis (dimension mismatch);
                % plot the actual single-sided spectrum P1 instead.
                subplot(3, num_types, i + num_types);
                [~, f, P1] = BearingFaultData.extract_frequency_features(signal, fs);
                plot(f, P1, 'r-', 'LineWidth', 1);
                title(sprintf('%s - 频域特征', unique_labels{i}));
                xlabel('频率 (Hz)');
                ylabel('幅度');
                grid on;
                xlim([0, 1000]);

                % Envelope spectrum (exposes fault repetition frequencies).
                subplot(3, num_types, i + 2*num_types);
                [env_spectrum, env_f] = BearingFaultData.compute_envelope_spectrum(signal, fs);
                plot(env_f, env_spectrum, 'g-', 'LineWidth', 1);
                title(sprintf('%s - 包络谱', unique_labels{i}));
                xlabel('频率 (Hz)');
                ylabel('幅度');
                grid on;
                xlim([0, 500]);
            end

            sgtitle('轴承故障信号样本分析');
        end

        function [env_spectrum, f] = compute_envelope_spectrum(signal, fs)
            % Envelope spectrum: Hilbert-transform envelope, then its
            % single-sided amplitude spectrum.
            analytic_signal = hilbert(signal);
            envelope_signal = abs(analytic_signal);

            N = length(envelope_signal);
            f = (0:N-1) * fs / N;
            env_spectrum = abs(fft(envelope_signal)) / N;
            env_spectrum = env_spectrum(1:floor(N/2)+1);
            env_spectrum(2:end-1) = 2 * env_spectrum(2:end-1);
            f = f(1:floor(N/2)+1);
        end
    end
end

2. PCA白化处理

classdef PCAWhitening
    % PCAWhitening  PCA-based whitening of feature matrices.
    %
    % Static utilities to whiten a feature matrix (decorrelate and scale
    % components to unit variance), visualise the transform, and quantify
    % the whitening effect.

    methods (Static)
        function [features_whitened, pca_model] = apply_pca_whitening(features, varargin)
            % Apply PCA whitening.
            %   features - raw feature matrix (samples x features)
            % Name-value options:
            %   'variance_retained' - fraction of variance to keep (default 0.95)
            %   'epsilon'           - regulariser added to eigenvalues (default 1e-5)
            % Returns:
            %   features_whitened - whitened features (samples x num_components)
            %   pca_model         - struct with eigenpairs, explained variance,
            %                       the whitening transform, and the
            %                       standardisation statistics

            fprintf('应用PCA白化...\n');

            p = inputParser;
            addParameter(p, 'variance_retained', 0.95, @(x) x>0 && x<=1);
            addParameter(p, 'epsilon', 1e-5, @(x) x>0);
            parse(p, varargin{:});

            variance_retained = p.Results.variance_retained;
            epsilon = p.Results.epsilon;

            [~, num_features] = size(features);

            % 1. Standardise to zero mean, unit variance.
            % Guard zero-variance columns: plain zscore would divide by
            % zero and propagate NaNs through the whole pipeline.
            feature_mean = mean(features);
            feature_std = std(features);
            safe_std = feature_std;
            safe_std(safe_std == 0) = 1;
            features_standardized = (features - feature_mean) ./ safe_std;

            % 2. Covariance of the standardised data.
            covariance_matrix = cov(features_standardized);

            % 3. Eigendecomposition.
            [eigenvectors, eigenvalues] = eig(covariance_matrix);
            eigenvalues = diag(eigenvalues);
            % Clamp tiny negative eigenvalues caused by round-off.
            eigenvalues = max(eigenvalues, 0);

            % Sort eigenpairs by decreasing eigenvalue.
            [eigenvalues, idx] = sort(eigenvalues, 'descend');
            eigenvectors = eigenvectors(:, idx);

            % 4. Number of components needed for the requested variance.
            explained_variance = cumsum(eigenvalues) / sum(eigenvalues);
            num_components = find(explained_variance >= variance_retained, 1);
            if isempty(num_components)
                % Round-off can leave the cumulative ratio just below the
                % target (e.g. variance_retained == 1): keep everything.
                num_components = numel(eigenvalues);
            end

            fprintf('原始特征数: %d, 保留主成分: %d (方差保留: %.1f%%)\n', ...
                num_features, num_components, variance_retained*100);

            % 5. Keep the leading components.
            eigenvectors_reduced = eigenvectors(:, 1:num_components);
            eigenvalues_reduced = eigenvalues(1:num_components);

            % 6. Whitening transform: project, then rescale each component
            % to (approximately) unit variance; epsilon regularises.
            whitening_transform = eigenvectors_reduced * diag(1 ./ sqrt(eigenvalues_reduced + epsilon));
            features_whitened = features_standardized * whitening_transform;

            % Persist model parameters for later use / inspection.
            pca_model.eigenvectors = eigenvectors;
            pca_model.eigenvalues = eigenvalues;
            pca_model.explained_variance = explained_variance;
            pca_model.num_components = num_components;
            pca_model.whitening_transform = whitening_transform;
            pca_model.feature_mean = feature_mean;
            pca_model.feature_std = feature_std;

            fprintf('PCA白化完成\n');
        end

        function visualize_pca_results(features, features_whitened, labels, pca_model)
            % Visualise the effect of PCA whitening: correlation matrices,
            % explained variance, 2-D projections, and eigenvalue spectrum.

            figure('Position', [100, 100, 1600, 1200]);

            % Panel 1: correlation matrix of the raw features.
            subplot(2, 3, 1);
            correlation_original = corr(features);
            imagesc(correlation_original);
            colorbar;
            title('原始特征相关性矩阵');
            xlabel('特征索引');
            ylabel('特征索引');
            axis equal tight;

            % Panel 2: correlation matrix of the whitened features
            % (should be close to the identity).
            subplot(2, 3, 2);
            correlation_whitened = corr(features_whitened);
            imagesc(correlation_whitened);
            colorbar;
            title('白化特征相关性矩阵');
            xlabel('特征索引');
            ylabel('特征索引');
            axis equal tight;

            % Panel 3: cumulative explained-variance curve with the
            % selected component count highlighted.
            subplot(2, 3, 3);
            plot(pca_model.explained_variance, 'bo-', 'LineWidth', 2, 'MarkerSize', 6);
            hold on;
            plot(pca_model.num_components, pca_model.explained_variance(pca_model.num_components), ...
                'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
            xlabel('主成分数量');
            ylabel('累积方差解释率');
            title('PCA方差解释率');
            grid on;
            legend('累积方差', '选择点', 'Location', 'southeast');

            % Panel 4: raw features projected onto their first two PCs.
            subplot(2, 3, 4);
            [~, scores_original] = pca(features);
            visualize_clustering(scores_original(:, 1:2), labels, '原始特征PCA');

            % Panel 5: first two whitened components.
            subplot(2, 3, 5);
            visualize_clustering(features_whitened(:, 1:2), labels, '白化特征');

            % Panel 6: eigenvalue spectrum (log scale).
            subplot(2, 3, 6);
            semilogy(pca_model.eigenvalues, 'bo-', 'LineWidth', 2, 'MarkerSize', 6);
            xlabel('主成分索引');
            ylabel('特征值 (对数尺度)');
            title('PCA特征值分布');
            grid on;

            sgtitle('PCA白化分析结果');

            function visualize_clustering(data, labels, title_str)
                % Scatter the 2-D projection, one colour/marker per class.
                unique_labels = unique(labels);
                colors = lines(length(unique_labels));
                markers = 'osd^v><ph';

                hold on;
                for i = 1:length(unique_labels)
                    idx = strcmp(labels, unique_labels{i});
                    scatter(data(idx, 1), data(idx, 2), 50, colors(i, :), ...
                        markers(mod(i-1, length(markers))+1), 'filled', ...
                        'DisplayName', unique_labels{i});
                end
                title(title_str);
                xlabel('主成分 1');
                ylabel('主成分 2');
                legend('Location', 'best');
                grid on;
            end
        end

        function analyze_whitening_effect(features, features_whitened)
            % Quantify the whitening effect: condition numbers, mean
            % absolute correlations, and eigenvalue ranges, plus plots.

            fprintf('\n=== PCA白化效果分析 ===\n');

            % Raw-feature statistics.
            original_cov = cov(features);
            original_corr = corr(features);

            % Whitened-feature statistics.
            whitened_cov = cov(features_whitened);
            whitened_corr = corr(features_whitened);

            fprintf('原始特征协方差矩阵条件数: %.2e\n', cond(original_cov));
            fprintf('白化特征协方差矩阵条件数: %.2f\n', cond(whitened_cov));

            % Mean absolute off-diagonal correlation (entries equal to 1,
            % i.e. the diagonal, are excluded).
            fprintf('原始特征平均相关性: %.4f\n', mean(abs(original_corr(original_corr < 1))));
            fprintf('白化特征平均相关性: %.4f\n', mean(abs(whitened_corr(whitened_corr < 1))));

            % Eigenvalue distributions before/after whitening.
            original_eigvals = eig(original_cov);
            whitened_eigvals = eig(whitened_cov);

            fprintf('原始特征值范围: [%.2e, %.2e]\n', min(original_eigvals), max(original_eigvals));
            fprintf('白化特征值范围: [%.2f, %.2f]\n', min(whitened_eigvals), max(whitened_eigvals));

            % Side-by-side visual comparison.
            figure('Position', [100, 100, 1200, 500]);

            subplot(1, 2, 1);
            semilogy(sort(original_eigvals, 'descend'), 'bo-', 'LineWidth', 2, 'DisplayName', '原始特征');
            hold on;
            semilogy(sort(whitened_eigvals, 'descend'), 'rs-', 'LineWidth', 2, 'DisplayName', '白化特征');
            xlabel('特征值索引');
            ylabel('特征值 (对数尺度)');
            title('特征值分布比较');
            legend;
            grid on;

            subplot(1, 2, 2);
            boxplot([mean(abs(original_corr(original_corr < 1))), ...
                     mean(abs(whitened_corr(whitened_corr < 1)))], ...
                    {'原始特征', '白化特征'});
            ylabel('平均绝对相关性');
            title('特征相关性比较');
            grid on;
        end
    end
end

3. K均值聚类分析

classdef KMeansClustering
    % KMeansClustering  K-means clustering with quality evaluation.
    %
    % Static utilities to run k-means, compute internal (silhouette,
    % Davies-Bouldin, Calinski-Harabasz) and external (adjusted Rand,
    % NMI, best-mapping accuracy) quality metrics, and search for an
    % optimal cluster count.

    methods (Static)
        function [cluster_labels, centroid_positions, clustering_model] = ...
                perform_clustering(features, true_labels, varargin)
            % Run k-means and evaluate the resulting partition.
            %   features    - (samples x dims) feature matrix
            %   true_labels - ground-truth labels (may be empty)
            % Name-value options: 'k' (default 4), 'max_iterations' (100),
            % 'replicates' (10), 'distance_metric' ('sqeuclidean').

            fprintf('执行K均值聚类分析...\n');

            p = inputParser;
            addParameter(p, 'k', 4, @(x) x>0); % number of clusters
            addParameter(p, 'max_iterations', 100, @(x) x>0);
            addParameter(p, 'replicates', 10, @(x) x>0);
            addParameter(p, 'distance_metric', 'sqeuclidean', @ischar);
            parse(p, varargin{:});

            k = p.Results.k;
            max_iterations = p.Results.max_iterations;
            replicates = p.Results.replicates;
            distance_metric = p.Results.distance_metric;

            % Multiple replicates reduce sensitivity to initialisation.
            [cluster_labels, centroid_positions, sumd, distances] = ...
                kmeans(features, k, ...
                'MaxIter', max_iterations, ...
                'Replicates', replicates, ...
                'Distance', distance_metric, ...
                'Display', 'final');

            clustering_quality = KMeansClustering.evaluate_clustering_quality(...
                features, cluster_labels, centroid_positions, true_labels);

            % Persist everything needed to inspect / reuse the result.
            clustering_model.cluster_labels = cluster_labels;
            clustering_model.centroid_positions = centroid_positions;
            clustering_model.sumd = sumd;
            clustering_model.distances = distances;
            clustering_model.quality = clustering_quality;
            clustering_model.parameters = p.Results;

            fprintf('K均值聚类完成: %d 个聚类\n', k);
        end

        function quality = evaluate_clustering_quality(features, cluster_labels, centroids, true_labels)
            % Compute internal metrics and, when ground truth is supplied,
            % external metrics; returns everything in one struct.

            % Internal (unsupervised) metrics.
            % 1. Silhouette coefficient.
            silhouette_vals = silhouette(features, cluster_labels);
            quality.silhouette_mean = mean(silhouette_vals);
            quality.silhouette_std = std(silhouette_vals);

            % 2. Davies-Bouldin index.
            quality.db_index = KMeansClustering.calculate_db_index(features, cluster_labels, centroids);

            % 3. Calinski-Harabasz index.
            quality.ch_index = KMeansClustering.calculate_ch_index(features, cluster_labels, centroids);

            % External (supervised) metrics, only with ground truth.
            if ~isempty(true_labels)
                % 4. Adjusted Rand index.
                quality.adj_rand_index = KMeansClustering.calculate_adjusted_rand_index(...
                    true_labels, cluster_labels);

                % 5. Normalized mutual information.
                quality.normalized_mutual_info = KMeansClustering.calculate_nmi(...
                    true_labels, cluster_labels);

                % 6. Best-mapping clustering accuracy.
                quality.clustering_accuracy = KMeansClustering.calculate_clustering_accuracy(...
                    true_labels, cluster_labels);
            else
                quality.adj_rand_index = NaN;
                quality.normalized_mutual_info = NaN;
                quality.clustering_accuracy = NaN;
            end

            % Average within-cluster and between-cluster distances.
            [quality.within_cluster_distance, quality.between_cluster_distance] = ...
                KMeansClustering.calculate_distance_metrics(features, cluster_labels, centroids);

            fprintf('聚类质量指标:\n');
            fprintf('  轮廓系数: %.4f\n', quality.silhouette_mean);
            fprintf('  Davies-Bouldin指数: %.4f\n', quality.db_index);
            fprintf('  Calinski-Harabasz指数: %.4f\n', quality.ch_index);
            if ~isempty(true_labels)
                fprintf('  调整兰德指数: %.4f\n', quality.adj_rand_index);
                fprintf('  归一化互信息: %.4f\n', quality.normalized_mutual_info);
                fprintf('  聚类准确率: %.4f\n', quality.clustering_accuracy);
            end
        end

        function db_index = calculate_db_index(features, labels, centroids)
            % Davies-Bouldin index (lower is better).
            k = size(centroids, 1);
            cluster_dispersion = zeros(1, k);
            cluster_sizes = zeros(1, k);

            % Mean distance of each cluster's points to its centroid.
            for i = 1:k
                cluster_points = features(labels == i, :);
                cluster_sizes(i) = size(cluster_points, 1);
                if cluster_sizes(i) > 0
                    cluster_dispersion(i) = mean(sqrt(sum((cluster_points - centroids(i, :)).^2, 2)));
                end
            end

            % Pairwise similarity ratios.
            % BUGFIX: R must be symmetric; the original filled only the
            % upper triangle, so each row's maximum ignored pairs with
            % j < i (and row k was always all-zero).
            R = zeros(k);
            for i = 1:k
                for j = i+1:k
                    if cluster_sizes(i) > 0 && cluster_sizes(j) > 0
                        d_ij = norm(centroids(i, :) - centroids(j, :));
                        R(i, j) = (cluster_dispersion(i) + cluster_dispersion(j)) / d_ij;
                        R(j, i) = R(i, j);
                    end
                end
            end

            max_R = max(R, [], 2);
            db_index = mean(max_R(max_R > 0));
        end

        function ch_index = calculate_ch_index(features, labels, centroids)
            % Calinski-Harabasz index (higher is better).
            num_samples = size(features, 1);
            k = size(centroids, 1);

            overall_centroid = mean(features, 1);

            % Between-cluster sum of squares, weighted by cluster size.
            between_ss = 0;
            cluster_sizes = zeros(1, k);
            for i = 1:k
                cluster_points = features(labels == i, :);
                cluster_sizes(i) = size(cluster_points, 1);
                between_ss = between_ss + cluster_sizes(i) * norm(centroids(i, :) - overall_centroid)^2;
            end

            % Within-cluster sum of squares.
            within_ss = 0;
            for i = 1:k
                cluster_points = features(labels == i, :);
                within_ss = within_ss + sum(sum((cluster_points - centroids(i, :)).^2, 2));
            end

            ch_index = (between_ss / (k-1)) / (within_ss / (num_samples - k));
        end

        function adj_rand_index = calculate_adjusted_rand_index(true_labels, pred_labels)
            % Adjusted Rand index between two labelings.
            % Accepts cell arrays of strings and numeric vectors alike.
            [~, ~, true_numeric] = unique(true_labels);
            [~, ~, pred_numeric] = unique(pred_labels);

            % Contingency table (accumarray removes the crosstab dependency).
            contingency_matrix = accumarray([true_numeric(:) pred_numeric(:)], 1);

            n = sum(contingency_matrix(:));
            nij = contingency_matrix;
            ai = sum(contingency_matrix, 2);
            bj = sum(contingency_matrix, 1);

            % Pair-counting terms of the ARI formula.
            n_choose_2 = n*(n-1)/2;
            sum_nij_choose_2 = sum(nij(:).*(nij(:)-1)/2);
            sum_ai_choose_2 = sum(ai.*(ai-1)/2);
            sum_bj_choose_2 = sum(bj.*(bj-1)/2);

            expected_index = sum_ai_choose_2 * sum_bj_choose_2 / n_choose_2;
            max_index = 0.5 * (sum_ai_choose_2 + sum_bj_choose_2);

            adj_rand_index = (sum_nij_choose_2 - expected_index) / (max_index - expected_index);
        end

        function nmi = calculate_nmi(true_labels, pred_labels)
            % Normalized mutual information between two labelings.
            [~, ~, true_numeric] = unique(true_labels);
            [~, ~, pred_numeric] = unique(pred_labels);

            % eps keeps empty cells away from log(0).
            contingency_matrix = accumarray([true_numeric(:) pred_numeric(:)], 1) + eps;
            joint_prob = contingency_matrix / sum(contingency_matrix(:));

            marginal_true = sum(joint_prob, 2);
            marginal_pred = sum(joint_prob, 1);

            % Mutual information of the joint distribution.
            mutual_info = 0;
            for i = 1:size(joint_prob, 1)
                for j = 1:size(joint_prob, 2)
                    mutual_info = mutual_info + joint_prob(i,j) * ...
                        log(joint_prob(i,j) / (marginal_true(i) * marginal_pred(j)));
                end
            end

            % Entropies of the two marginals.
            entropy_true = -sum(marginal_true .* log(marginal_true));
            entropy_pred = -sum(marginal_pred .* log(marginal_pred));

            nmi = 2 * mutual_info / (entropy_true + entropy_pred);
        end

        function accuracy = calculate_clustering_accuracy(true_labels, pred_labels)
            % Clustering accuracy under the best one-to-one cluster-to-class
            % mapping.
            % BUGFIX: the original brace-indexed the numeric kmeans labels
            % (unique_pred{j} on a double array errors) and depended on the
            % non-builtin munkres(); this version is self-contained and
            % accepts cellstr or numeric labels on either side.
            [~, ~, true_numeric] = unique(true_labels);
            [~, ~, pred_numeric] = unique(pred_labels);

            % Contingency table: true class (rows) x predicted cluster
            % (columns), zero-padded to square for the assignment search.
            C = accumarray([true_numeric(:) pred_numeric(:)], 1);
            n = max(size(C));
            Cs = zeros(n);
            Cs(1:size(C, 1), 1:size(C, 2)) = C;

            if n <= 8
                % Exhaustive assignment search (optimal; n! <= 40320).
                P = perms(1:n);
                best = 0;
                for r = 1:size(P, 1)
                    s = sum(Cs(sub2ind([n n], 1:n, P(r, :))));
                    if s > best
                        best = s;
                    end
                end
            else
                % Greedy fallback for many clusters (near-optimal).
                best = 0;
                available = true(1, n);
                for i = 1:n
                    row = Cs(i, :);
                    row(~available) = -1;
                    [v, j] = max(row);
                    best = best + max(v, 0);
                    available(j) = false;
                end
            end

            accuracy = best / numel(true_numeric);
        end

        function [within_dist, between_dist] = calculate_distance_metrics(features, labels, centroids)
            % Average within-cluster distance (point to own centroid) and
            % average between-cluster centroid distance.
            k = size(centroids, 1);

            % Within-cluster: mean distance over all assigned points.
            within_dist = 0;
            cluster_counts = zeros(1, k);
            for i = 1:k
                cluster_points = features(labels == i, :);
                if size(cluster_points, 1) > 0
                    distances = sqrt(sum((cluster_points - centroids(i, :)).^2, 2));
                    within_dist = within_dist + sum(distances);
                    cluster_counts(i) = size(cluster_points, 1);
                end
            end
            within_dist = within_dist / sum(cluster_counts);

            % Between-cluster: mean pairwise centroid distance over
            % non-empty cluster pairs.
            between_dist = 0;
            count = 0;
            for i = 1:k
                for j = i+1:k
                    if cluster_counts(i) > 0 && cluster_counts(j) > 0
                        between_dist = between_dist + norm(centroids(i, :) - centroids(j, :));
                        count = count + 1;
                    end
                end
            end
            between_dist = between_dist / max(count, 1);
        end

        function optimal_k = find_optimal_k(features, max_k, true_labels)
            % Scan k = 2..max_k and pick the k with the best composite
            % score over all available quality metrics.
            fprintf('寻找最优聚类数量 (k = 2 到 %d)...\n', max_k);

            k_values = 2:max_k;
            num_metrics = 5;
            metrics = zeros(length(k_values), num_metrics);

            figure('Position', [100, 100, 1200, 800]);

            for i = 1:length(k_values)
                k = k_values(i);
                fprintf('测试 k = %d...\n', k);

                % Cluster at this k.
                [cluster_labels, centroids, ~] = kmeans(features, k, ...
                    'Replicates', 10, 'Display', 'off');

                % Evaluate the partition.
                quality = KMeansClustering.evaluate_clustering_quality(...
                    features, cluster_labels, centroids, true_labels);

                metrics(i, 1) = quality.silhouette_mean;  % higher is better
                metrics(i, 2) = quality.db_index;         % lower is better
                metrics(i, 3) = quality.ch_index;         % higher is better
                if ~isempty(true_labels)
                    metrics(i, 4) = quality.adj_rand_index; % higher is better
                    metrics(i, 5) = quality.clustering_accuracy; % higher is better
                else
                    metrics(i, 4:5) = NaN;
                end
            end

            % Min-max normalise each metric for the composite score.
            % BUGFIX: guard against a zero range (constant metric), which
            % previously produced NaN/Inf columns; uninformative columns
            % are left at zero, matching the original NaN-column handling.
            normalized_metrics = zeros(size(metrics));
            for j = 1:num_metrics
                col = metrics(:, j);
                col_range = max(col) - min(col);
                if any(isnan(col)) || col_range == 0
                    continue;
                end
                scaled = (col - min(col)) / col_range;
                if j == 2 % DB index: smaller is better, so invert
                    scaled = 1 - scaled;
                end
                normalized_metrics(:, j) = scaled;
            end

            % Composite score (NaN-safe mean across metrics).
            composite_scores = mean(normalized_metrics, 2, 'omitnan');

            % Best k maximises the composite score.
            [~, optimal_idx] = max(composite_scores);
            optimal_k = k_values(optimal_idx);

            fprintf('推荐最优聚类数量: k = %d\n', optimal_k);

            % Per-metric curves with the chosen k highlighted.
            subplot(2, 3, 1);
            plot(k_values, metrics(:, 1), 'bo-', 'LineWidth', 2, 'MarkerSize', 8);
            hold on;
            plot(optimal_k, metrics(optimal_idx, 1), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
            xlabel('聚类数量 k');
            ylabel('轮廓系数');
            title('轮廓系数 vs k');
            grid on;
            legend('轮廓系数', '最优k', 'Location', 'best');

            subplot(2, 3, 2);
            plot(k_values, metrics(:, 2), 'rs-', 'LineWidth', 2, 'MarkerSize', 8);
            hold on;
            plot(optimal_k, metrics(optimal_idx, 2), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
            xlabel('聚类数量 k');
            ylabel('DB指数');
            title('DB指数 vs k');
            grid on;
            legend('DB指数', '最优k', 'Location', 'best');

            subplot(2, 3, 3);
            plot(k_values, metrics(:, 3), 'g^-', 'LineWidth', 2, 'MarkerSize', 8);
            hold on;
            plot(optimal_k, metrics(optimal_idx, 3), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
            xlabel('聚类数量 k');
            ylabel('CH指数');
            title('CH指数 vs k');
            grid on;
            legend('CH指数', '最优k', 'Location', 'best');

            if ~isempty(true_labels)
                subplot(2, 3, 4);
                plot(k_values, metrics(:, 4), 'md-', 'LineWidth', 2, 'MarkerSize', 8);
                hold on;
                plot(optimal_k, metrics(optimal_idx, 4), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
                xlabel('聚类数量 k');
                ylabel('调整兰德指数');
                title('调整兰德指数 vs k');
                grid on;
                legend('ARI', '最优k', 'Location', 'best');

                subplot(2, 3, 5);
                plot(k_values, metrics(:, 5), 'ch-', 'LineWidth', 2, 'MarkerSize', 8);
                hold on;
                plot(optimal_k, metrics(optimal_idx, 5), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
                xlabel('聚类数量 k');
                ylabel('聚类准确率');
                title('聚类准确率 vs k');
                grid on;
                legend('准确率', '最优k', 'Location', 'best');
            end

            subplot(2, 3, 6);
            plot(k_values, composite_scores, 'ko-', 'LineWidth', 2, 'MarkerSize', 8);
            hold on;
            plot(optimal_k, composite_scores(optimal_idx), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'red');
            xlabel('聚类数量 k');
            ylabel('综合评分');
            title('综合评分 vs k');
            grid on;
            legend('综合评分', '最优k', 'Location', 'best');

            sgtitle('最优聚类数量选择分析');
        end
    end
end

4. 完整的主程序与结果分析

function main_bearing_fault_clustering()
    % Main driver for the bearing-fault clustering pipeline:
    % synthetic data generation -> feature extraction -> PCA whitening ->
    % optimal-k search -> k-means clustering -> visualization, comparison
    % and diagnosis reporting.

    fprintf('=== 基于PCA白化和K均值的轴承故障聚类分析 ===\n\n');

    % Step 1: synthesize vibration signals for each fault class
    fprintf('步骤1: 生成轴承故障数据...\n');
    sample_rate = 10000;     % sampling frequency (Hz)
    signal_seconds = 1;      % duration of each signal (s)
    samples_per_class = 50;  % number of signals per fault class

    data_gen = BearingFaultData(sample_rate, signal_seconds);
    [signals, true_labels] = data_gen.generate_fault_data(samples_per_class);

    % Show representative raw waveforms per class
    BearingFaultData.plot_sample_signals(signals, true_labels, sample_rate, signal_seconds);

    % Step 2: time/frequency-domain feature extraction
    fprintf('\n步骤2: 特征提取...\n');
    features = BearingFaultData.extract_features(signals, sample_rate);

    % Step 3: PCA whitening (retain 95%% of variance)
    fprintf('\n步骤3: PCA白化处理...\n');
    [features_whitened, pca_model] = PCAWhitening.apply_pca_whitening(...
        features, 'variance_retained', 0.95);

    % Inspect the whitening transform and its effect
    PCAWhitening.visualize_pca_results(features, features_whitened, true_labels, pca_model);
    PCAWhitening.analyze_whitening_effect(features, features_whitened);

    % Step 4: pick the number of clusters (search up to k = 8)
    fprintf('\n步骤4: 寻找最优聚类数量...\n');
    optimal_k = KMeansClustering.find_optimal_k(features_whitened, 8, true_labels);

    % Step 5: run k-means on the whitened features
    fprintf('\n步骤5: K均值聚类分析...\n');
    [cluster_labels, centroids, clustering_model] = ...
        KMeansClustering.perform_clustering(features_whitened, true_labels, 'k', optimal_k);

    % Step 6: plots of clusters, confusion matrix and quality metrics
    fprintf('\n步骤6: 结果可视化与分析...\n');
    visualize_clustering_results(features_whitened, true_labels, cluster_labels, ...
        centroids, clustering_model, pca_model);

    % Step 7: raw features vs. whitened features, same k
    fprintf('\n步骤7: 性能比较分析...\n');
    compare_clustering_performance(features, features_whitened, true_labels, optimal_k);

    % Step 8: console + text-file diagnosis summary
    fprintf('\n步骤8: 生成故障诊断报告...\n');
    generate_diagnosis_report(true_labels, cluster_labels, clustering_model);

    fprintf('\n=== 轴承故障聚类分析完成 ===\n');
end

function visualize_clustering_results(features, true_labels, cluster_labels, ...
                                      centroids, clustering_model, pca_model)
    % Visualize clustering results in a 2x3 figure:
    % true-label scatter, k-means scatter with centroids, confusion matrix,
    % silhouette plot, quality-metric radar chart, PCA variance importance.
    %
    % Inputs:
    %   features         - whitened feature matrix (samples x components)
    %   true_labels      - ground-truth class labels (one per sample)
    %   cluster_labels   - cluster index assigned to each sample
    %   centroids        - k-means centroids (k x components)
    %   clustering_model - struct with a .quality field of clustering metrics
    %   pca_model        - struct with .explained_variance from the PCA step

    figure('Position', [100, 100, 1600, 1200]);

    % Class names used for confusion-matrix row labels
    unique_true_labels = unique(true_labels);

    % Subplot 1: ground-truth distribution over the first two components
    subplot(2, 3, 1);
    gscatter(features(:, 1), features(:, 2), true_labels, [], 'os^d', 10);
    title('真实标签分布 (白化特征)');
    xlabel('主成分 1');
    ylabel('主成分 2');
    grid on;

    % Subplot 2: k-means assignments plus centroid markers
    subplot(2, 3, 2);
    gscatter(features(:, 1), features(:, 2), cluster_labels, [], 'os^d', 10);
    hold on;
    plot(centroids(:, 1), centroids(:, 2), 'kx', 'MarkerSize', 15, ...
        'LineWidth', 3, 'DisplayName', '聚类中心');
    title('K均值聚类结果');
    xlabel('主成分 1');
    ylabel('主成分 2');
    grid on;
    legend('Location', 'best');

    % Subplot 3: true class vs. predicted cluster confusion matrix
    subplot(2, 3, 3);
    plot_confusion_matrix(true_labels, cluster_labels, unique_true_labels);
    title('聚类混淆矩阵');

    % Subplot 4: per-sample silhouette values grouped by cluster
    subplot(2, 3, 4);
    silhouette(features, cluster_labels);
    title('轮廓系数分析');
    xlabel('轮廓系数值');
    ylabel('聚类标签');

    % Subplot 5: radar chart of the aggregate quality metrics
    subplot(2, 3, 5);
    plot_clustering_quality_radar(clustering_model.quality);
    title('聚类质量指标雷达图');

    % Subplot 6: explained-variance profile of the PCA components
    subplot(2, 3, 6);
    plot_feature_importance(pca_model);
    title('PCA特征重要性');

    sgtitle('轴承故障聚类分析结果');
end

function plot_confusion_matrix(true_labels, pred_labels, class_names)
    % Render a row-normalized confusion matrix of true fault classes (rows)
    % against predicted cluster assignments (columns).
    %
    % Inputs:
    %   true_labels - true class labels, one per sample (cell array)
    %   pred_labels - cluster indices assigned by k-means
    %   class_names - display names for the true classes (row labels)
    %
    % Note: when the number of clusters differs from the number of classes,
    % confusionmat returns a non-square matrix; both axes and loops below
    % use the actual matrix dimensions rather than assuming squareness.

    % Map both label sets to consecutive integers
    [~, ~, true_numeric] = unique(true_labels);
    [~, ~, pred_numeric] = unique(pred_labels);

    cm = confusionmat(true_numeric, pred_numeric);
    [n_true, n_pred] = size(cm);

    % Row-normalize; guard empty rows so 0/0 does not produce NaN
    row_totals = max(sum(cm, 2), 1);
    cm_normalized = cm ./ row_totals;

    imagesc(cm_normalized);
    colorbar;

    % Rows carry class names; columns are numbered predicted clusters
    set(gca, 'YTick', 1:n_true, ...
        'YTickLabel', class_names(1:min(n_true, numel(class_names))));
    set(gca, 'XTick', 1:n_pred);
    xlabel('预测标签');
    ylabel('真实标签');

    % Annotate each cell with its normalized rate and raw count
    for i = 1:n_true
        for j = 1:n_pred
            text(j, i, sprintf('%.2f\n(%d)', cm_normalized(i,j), cm(i,j)), ...
                'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle', ...
                'Color', 'white', 'FontWeight', 'bold');
        end
    end
end

function plot_clustering_quality_radar(quality)
    % Draw a radar (polar) chart of the clustering quality metrics.
    %
    % Input:
    %   quality - struct with fields silhouette_mean, db_index, ch_index,
    %             adj_rand_index, clustering_accuracy
    %
    % DB index is inverted (lower is better) and CH index is rescaled so
    % all axes share a "higher is better" orientation before normalization.

    metrics = {'轮廓系数', 'DB指数', 'CH指数', '调整兰德', '聚类准确率'};

    % Guard against a zero DB index before inverting it
    db_inverse = 1 / max(quality.db_index, eps);
    values = [quality.silhouette_mean, db_inverse, ...
              quality.ch_index/1000, quality.adj_rand_index, quality.clustering_accuracy];

    % Clamp negatives (silhouette/ARI can be < 0) and normalize to [0, 1];
    % eps guard avoids 0/0 when every metric is non-positive
    clipped = max(values, 0);
    normalized_values = clipped / max(max(clipped), eps);

    % Evenly spaced spokes, closing the polygon back to the first point
    angles = linspace(0, 2*pi, length(metrics)+1);
    angles = angles(1:end-1);

    polarplot([angles, angles(1)], [normalized_values, normalized_values(1)], 'ro-', 'LineWidth', 2);
    thetaticks(rad2deg(angles));
    thetaticklabels(metrics);
    rlim([0, 1]);
    title('聚类质量指标');
end

function plot_feature_importance(pca_model)
    % Plot per-component explained variance (bars) and its cumulative sum
    % (line) for the leading PCA components.
    %
    % Input:
    %   pca_model - struct with field explained_variance (vector of per-
    %               component variance ratios)

    explained_variance = pca_model.explained_variance;
    cumulative_variance = cumsum(explained_variance);

    % Show at most 10 components; fewer may have been retained by PCA
    n_show = min(10, numel(explained_variance));

    bar(explained_variance(1:n_show), 'b', 'FaceAlpha', 0.7);
    hold on;
    plot(cumulative_variance(1:n_show), 'ro-', 'LineWidth', 2, 'MarkerSize', 6);

    xlabel('主成分序号');
    ylabel('方差解释率');
    title(sprintf('前%d个主成分的方差解释率', n_show));
    legend('单个主成分', '累积解释率', 'Location', 'southeast');
    grid on;
end

function compare_clustering_performance(features_original, features_whitened, true_labels, k)
    % Compare k-means quality on raw features vs. PCA-whitened features.
    %
    % Inputs:
    %   features_original - raw feature matrix (samples x features)
    %   features_whitened - PCA-whitened feature matrix
    %   true_labels       - ground-truth class labels
    %   k                 - number of clusters to use for both runs

    fprintf('比较原始特征与白化特征的聚类性能...\n');

    % Cluster the raw features (centroids unused here)
    [labels_original, ~, model_original] = ...
        KMeansClustering.perform_clustering(features_original, true_labels, 'k', k);

    % Cluster the whitened features
    [labels_whitened, ~, model_whitened] = ...
        KMeansClustering.perform_clustering(features_whitened, true_labels, 'k', k);

    figure('Position', [100, 100, 1200, 800]);

    metrics = {'轮廓系数', 'DB指数', 'CH指数', '调整兰德指数', '聚类准确率'};
    % DB index is inverted (lower is better) and CH rescaled so every
    % metric reads "higher is better" on a comparable scale
    values_original = [
        model_original.quality.silhouette_mean,
        1/max(model_original.quality.db_index, eps),
        model_original.quality.ch_index/1000,
        model_original.quality.adj_rand_index,
        model_original.quality.clustering_accuracy
    ];

    values_whitened = [
        model_whitened.quality.silhouette_mean,
        1/max(model_whitened.quality.db_index, eps),
        model_whitened.quality.ch_index/1000,
        model_whitened.quality.adj_rand_index,
        model_whitened.quality.clustering_accuracy
    ];

    % Grouped bar chart of the two runs
    subplot(2, 2, 1);
    bar_data = [values_original; values_whitened]';
    bar(bar_data);
    set(gca, 'XTickLabel', metrics);
    ylabel('指标值');
    title('聚类性能指标比较');
    legend('原始特征', 'PCA白化特征', 'Location', 'best');
    grid on;
    xtickangle(45);  % built-in (R2016b+); rotateXLabels is a File Exchange add-on

    % Relative improvement; eps guard avoids 0/0 for zero-valued metrics
    subplot(2, 2, 2);
    improvement = (values_whitened - values_original) ./ max(abs(values_original), eps) * 100;
    bar(improvement, 'FaceColor', [0.2, 0.6, 0.2], 'FaceAlpha', 0.7);
    set(gca, 'XTickLabel', metrics);
    ylabel('性能提升 (%)');
    title('PCA白化带来的性能提升');
    grid on;
    xtickangle(45);

    % Scatter of the raw-feature clustering, projected by PCA for display
    subplot(2, 2, 3);
    [~, scores_original] = pca(features_original);
    gscatter(scores_original(:, 1), scores_original(:, 2), labels_original);
    title('原始特征聚类结果');
    xlabel('主成分 1');
    ylabel('主成分 2');
    grid on;

    subplot(2, 2, 4);
    gscatter(features_whitened(:, 1), features_whitened(:, 2), labels_whitened);
    title('白化特征聚类结果');
    xlabel('主成分 1');
    ylabel('主成分 2');
    grid on;

    sgtitle('原始特征 vs PCA白化特征聚类性能比较');

    % Console summary; %+.1f prints an explicit sign so regressions do not
    % render as '+-x.x%'
    fprintf('\n=== 性能比较结果 ===\n');
    fprintf('指标\t\t\t原始特征\t白化特征\t提升\n');
    fprintf('轮廓系数\t\t%.4f\t\t%.4f\t\t%+.1f%%\n', ...
        model_original.quality.silhouette_mean, model_whitened.quality.silhouette_mean, ...
        improvement(1));
    fprintf('DB指数\t\t\t%.4f\t\t%.4f\t\t%+.1f%%\n', ...
        model_original.quality.db_index, model_whitened.quality.db_index, ...
        improvement(2));
    fprintf('CH指数\t\t\t%.1f\t\t%.1f\t\t%+.1f%%\n', ...
        model_original.quality.ch_index, model_whitened.quality.ch_index, ...
        improvement(3));
    fprintf('调整兰德指数\t\t%.4f\t\t%.4f\t\t%+.1f%%\n', ...
        model_original.quality.adj_rand_index, model_whitened.quality.adj_rand_index, ...
        improvement(4));
    fprintf('聚类准确率\t\t%.4f\t\t%.4f\t\t%+.1f%%\n', ...
        model_original.quality.clustering_accuracy, model_whitened.quality.clustering_accuracy, ...
        improvement(5));
end

function generate_diagnosis_report(true_labels, cluster_labels, clustering_model)
    % Print a fault-diagnosis summary to the console and save a timestamped
    % text report to the current directory.
    %
    % Inputs:
    %   true_labels      - ground-truth class labels (cell array)
    %   cluster_labels   - cluster index assigned to each sample
    %   clustering_model - struct with a .quality field of clustering metrics

    fprintf('\n=== 轴承故障诊断报告 ===\n\n');

    % Basic counts
    unique_true = unique(true_labels);
    unique_cluster = unique(cluster_labels);

    fprintf('诊断统计:\n');
    fprintf('  总样本数: %d\n', length(true_labels));
    fprintf('  真实故障类型: %d 种\n', length(unique_true));
    fprintf('  识别出的聚类: %d 个\n', length(unique_cluster));
    fprintf('  聚类轮廓系数: %.4f\n', clustering_model.quality.silhouette_mean);
    fprintf('  聚类准确率: %.2f%%\n', clustering_model.quality.clustering_accuracy * 100);

    % Map each true class to its dominant cluster; crosstab rows follow
    % the same order as unique(true_labels)
    fprintf('\n聚类-故障类型对应关系:\n');
    contingency = crosstab(true_labels, cluster_labels);

    for i = 1:size(contingency, 1)
        [~, dominant_cluster] = max(contingency(i, :));
        fprintf('  %s -> 聚类 %d (%.1f%%)\n', ...
            unique_true{i}, dominant_cluster, ...
            contingency(i, dominant_cluster) / sum(contingency(i, :)) * 100);
    end

    % Threshold-based quality advice
    fprintf('\n诊断建议:\n');
    if clustering_model.quality.silhouette_mean > 0.7
        fprintf('  ✅ 聚类质量优秀,故障分离明显\n');
    elseif clustering_model.quality.silhouette_mean > 0.5
        fprintf('  ⚠️  聚类质量良好,部分故障有重叠\n');
    else
        fprintf('  ❗ 聚类质量较差,建议检查特征提取或调整参数\n');
    end

    if clustering_model.quality.clustering_accuracy > 0.9
        fprintf('  ✅ 故障识别准确率高,诊断可靠\n');
    elseif clustering_model.quality.clustering_accuracy > 0.7
        fprintf('  ⚠️  故障识别准确率中等,建议增加样本\n');
    else
        fprintf('  ❗ 故障识别准确率低,需要优化方法\n');
    end

    % Persist the summary; skip gracefully if the file cannot be created
    timestamp = datestr(now, 'yyyymmdd_HHMMSS');
    report_filename = sprintf('bearing_diagnosis_report_%s.txt', timestamp);

    fid = fopen(report_filename, 'w');
    if fid == -1
        warning('无法创建报告文件: %s', report_filename);
        return;
    end
    fprintf(fid, '轴承故障诊断报告\n');
    fprintf(fid, '生成时间: %s\n\n', datestr(now));
    fprintf(fid, '诊断统计:\n');
    fprintf(fid, '  总样本数: %d\n', length(true_labels));
    fprintf(fid, '  聚类轮廓系数: %.4f\n', clustering_model.quality.silhouette_mean);
    fprintf(fid, '  聚类准确率: %.2f%%\n', clustering_model.quality.clustering_accuracy * 100);
    fclose(fid);

    fprintf('\n诊断报告已保存: %s\n', report_filename);
end

% Script entry point: run the full clustering analysis pipeline
main_bearing_fault_clustering();

参考代码 采用PCA白化和K均值对轴承故障进行聚类分析 www.youwenfan.com/contentcnl/80087.html

总结

特点:

  1. 真实数据模拟:基于轴承故障机理生成仿真振动信号
  2. 全面特征提取:18个时域和频域统计特征
  3. PCA白化处理:降维、去相关、特征增强
  4. 智能聚类分析:自动确定最优聚类数量
  5. 多维度评估:内部和外部聚类质量指标

技术优势:

  • 物理真实性:基于真实的轴承故障频率和冲击特征
  • 数学严谨性:完整的PCA白化理论和K均值算法
  • 工程实用性:可直接应用于实际轴承故障诊断
  • 可视化分析:丰富的图形化结果展示

应用价值:

  • 为旋转机械故障诊断提供智能解决方案
  • 支持无监督的故障模式识别
  • 可用于状态监测和预测性维护
  • 为工业4.0和智能制造提供技术支持
posted @ 2025-11-11 09:19  u95900090  阅读(6)  评论(0)    收藏  举报