function [convergence_rates, fig_handle] = plotConvergence(dt_values, error_data, method_names, varargin)
% PLOTCONVERGENCE - Plot convergence analysis for numerical methods
% Compatible with MATLAB 2020a and earlier versions
%
% For every method the observed order of convergence is the slope p of a
% least-squares fit  log(err) = c + p*log(dt)  over that method's usable
% (finite, strictly positive) error values. Data are drawn as markers with
% solid lines; the fits are drawn as darker dashed lines spanning
% 0.8*min(dt) .. 1.2*max(dt) of the usable points.
%
% Syntax:
%   [rates, fig] = plotConvergence(dt_values, error_data, method_names)
%   [rates, fig] = plotConvergence(dt_values, error_data, method_names, 'Name', Value)
%
% Inputs:
%   dt_values    - Vector of time step sizes; entry k belongs to row k of
%                  error_data
%   error_data   - Matrix where each column contains errors for one method
%                  (a row vector is treated as a single method)
%                  OR cell array of error vectors (shorter vectors are
%                  padded with NaN at the end)
%   method_names - Cell array (or string array) of method names for legend
%
% Optional Name-Value pairs:
%   'SavePath'        - Path to save figure (default: 'figs/convergence_plot')
%   'FigureSize'      - Figure size in cm [width, height] (default: [13, 11.5])
%   'ShowGrid'        - Show major and minor grid (default: true)
%   'GridAlpha'       - Major grid transparency (default: 0.15)
%   'MinorGridAlpha'  - Minor grid transparency (default: 0.05)
%   'LegendLocation'  - Legend location (default: 'southeast')
%   'SaveFig'         - Save figure as a 300-dpi PNG (default: false)
%   'ForceScientific' - Use '%.0e' tick labels on both axes (default: true);
%                       when false, ticks are placed at 1, 2, 5 x 10^k
%
% Outputs:
%   convergence_rates - 1-by-num_methods vector of fitted rates; NaN for a
%                       method with fewer than two usable error values
%   fig_handle        - Figure handle
%
% Error values that are NaN, Inf, zero or negative cannot be drawn on a
% log axis; they are ignored by the fit, by the plot and by the automatic
% y-limits.

    % Parse input arguments
    p = inputParser;
    addRequired(p, 'dt_values');
    addRequired(p, 'error_data');
    addRequired(p, 'method_names');
    addParameter(p, 'SavePath', 'figs/convergence_plot');
    addParameter(p, 'FigureSize', [13, 11.5]);
    addParameter(p, 'ShowGrid', true);
    addParameter(p, 'GridAlpha', 0.15);
    addParameter(p, 'MinorGridAlpha', 0.05);
    addParameter(p, 'LegendLocation', 'southeast');
    addParameter(p, 'SaveFig', false);
    addParameter(p, 'ForceScientific', true);

    parse(p, dt_values, error_data, method_names, varargin{:});

    % Extract parsed values
    save_path = p.Results.SavePath;
    fig_size = p.Results.FigureSize;
    show_grid = p.Results.ShowGrid;
    grid_alpha = p.Results.GridAlpha;
    minor_grid_alpha = p.Results.MinorGridAlpha;
    legend_loc = p.Results.LegendLocation;
    save_fig = p.Results.SaveFig;
    force_scientific = p.Results.ForceScientific;

    % Convert dt_values to row vector
    dt_values = dt_values(:).';

    % Normalise error_data to a matrix with one column per method
    if iscell(error_data)
        num_methods = numel(error_data);
        max_length = max(cellfun(@numel, error_data));
        error_matrix = NaN(max_length, num_methods);
        for i = 1:num_methods
            len = numel(error_data{i});
            error_matrix(1:len, i) = error_data{i}(:);
        end
        error_data = error_matrix;
    else
        if isrow(error_data)
            error_data = error_data.';
        end
        num_methods = size(error_data, 2);
    end

    % Convert method_names to cell array of char (2020a compatibility)
    if isstring(method_names)
        method_names = cellstr(method_names);
    end

    % Validate inputs
    if length(method_names) ~= num_methods
        error('Number of method names must match number of error datasets');
    end

    % Values that can appear on a log axis: finite and strictly positive.
    usable = isfinite(error_data) & error_data > 0;

    % Every usable error needs a step size; report this clearly instead of
    % failing later with an out-of-range index into dt_values.
    num_steps = numel(dt_values);
    if any(any(usable(num_steps+1:end, :)))
        error('error_data has usable values in rows beyond numel(dt_values) = %d', num_steps);
    end

    % Fit log(err) = c + p*log(dt) per method; keep c to draw the fitted line
    convergence_rates = NaN(1, num_methods);
    fit_intercepts = NaN(1, num_methods);
    for i = 1:num_methods
        valid_idx = usable(:, i);
        if nnz(valid_idx) < 2
            warning('Method %d has insufficient valid data points', i);
            continue;
        end

        valid_dt = dt_values(valid_idx);
        valid_errors = error_data(valid_idx, i);

        X = [ones(numel(valid_dt), 1), log(valid_dt(:))];
        coeffs = X \ log(valid_errors(:));
        fit_intercepts(i) = coeffs(1);
        convergence_rates(i) = coeffs(2);
    end

    % Create figure and its axes
    fig_handle = figure('Units', 'centimeters', 'Position', [0, 0, fig_size], 'Color', 'white');
    ax = axes('Parent', fig_handle);

    % Define color palette and markers (cycled when there are more methods)
    colors = {[0.8500 0.3250 0.0980], [0.9290 0.6940 0.1250], [0.4660 0.6740 0.1880], ...
              [0 0.4470 0.7410], [0.4940 0.1840 0.5560], [0.3010 0.7450 0.9330], ...
              [0.6350 0.0780 0.1840], [0.8500 0.8500 0.8500]};
    markers = {'o', 's', '^', 'd', 'v', '>', '<', 'p'};

    % Plot data points first so the legend lists all data before all fits
    hold(ax, 'on');
    for i = 1:num_methods
        valid_idx = usable(:, i);
        if nnz(valid_idx) < 2
            continue;
        end
        color_idx = mod(i-1, numel(colors)) + 1;
        marker_idx = mod(i-1, numel(markers)) + 1;

        loglog(ax, dt_values(valid_idx), error_data(valid_idx, i), ...
               [markers{marker_idx} '-'], ...
               'Color', colors{color_idx}, ...
               'MarkerFaceColor', 'white', ...
               'MarkerEdgeColor', colors{color_idx}, ...
               'LineWidth', 2.5, ...
               'MarkerSize', 9, ...
               'DisplayName', sprintf('%s', method_names{i}));
    end

    % Plot fitted lines on top, reusing the coefficients computed above
    for i = 1:num_methods
        if isnan(convergence_rates(i))
            continue;
        end
        color_idx = mod(i-1, numel(colors)) + 1;
        valid_dt = dt_values(usable(:, i));

        fit_dt = logspace(log10(min(valid_dt) * 0.8), log10(max(valid_dt) * 1.2), 100);
        fit_errors = exp(fit_intercepts(i)) * fit_dt.^convergence_rates(i);

        loglog(ax, fit_dt, fit_errors, '--', ...
               'Color', colors{color_idx} * 0.7, ...
               'LineWidth', 3.5, ...
               'DisplayName', sprintf('(fitted rate=%.2f)', convergence_rates(i)));
    end

    % With hold on, loglog does not switch the scales itself; set them explicitly
    set(ax, 'XScale', 'log', 'YScale', 'log');

    % Customize plot
    xlabel(ax, 'Time Step Size (\tau)', 'FontSize', 14, 'FontName', 'Times New Roman');
    ylabel(ax, 'Error', 'FontSize', 14, 'FontName', 'Times New Roman');

    if show_grid
        grid(ax, 'on');
        set(ax, 'GridAlpha', grid_alpha);
        set(ax, 'MinorGridAlpha', minor_grid_alpha);
        set(ax, 'XMinorGrid', 'on', 'YMinorGrid', 'on');
    end

    set(ax, 'FontSize', 12, 'FontName', 'Times New Roman');
    set(ax, 'LineWidth', 1.5);
    set(ax, 'TickLength', [0.02 0.02]);

    % Set custom limits
    xlim(ax, [min(dt_values)*0.7, max(dt_values)*1.4]);

    % Y-limits are derived only from values that can actually be drawn, so a
    % zero/negative/Inf entry does not disable the limits altogether.
    plotted_errors = error_data(usable);
    has_error_range = ~isempty(plotted_errors);
    if has_error_range
        min_err = min(plotted_errors);
        max_err = max(plotted_errors);
        ylim(ax, [min_err*0.3, max_err*3]);
    end

    % ===== Tick configuration =====
    if force_scientific
        % Exponent = 0 suppresses the shared 'x10^n' ruler multiplier;
        % TickLabelFormat then prints every tick label as e.g. '1e-03'.
        ax.XAxis.Exponent = 0;
        ax.YAxis.Exponent = 0;
        ax.XAxis.TickLabelFormat = '%.0e';
        ax.YAxis.TickLabelFormat = '%.0e';
    else
        % Ticks at 1, 2, 5 x 10^k near the data when the range spans few
        % decades; otherwise only at powers of ten.
        dt_min_exp = floor(log10(min(dt_values)));
        dt_max_exp = ceil(log10(max(dt_values)));

        if (dt_max_exp - dt_min_exp) <= 3
            x_ticks = [];
            for exp_val = (dt_min_exp-1):(dt_max_exp+1)
                base = 10^exp_val;
                candidates = [base, 2*base, 5*base];
                valid = candidates(candidates >= min(dt_values)*0.5 & candidates <= max(dt_values)*2);
                x_ticks = [x_ticks, valid]; %#ok<AGROW>
            end
        else
            x_ticks = 10.^(dt_min_exp:dt_max_exp);
        end
        xticks(ax, unique(x_ticks));  % unique() also sorts

        if has_error_range
            err_min_exp = floor(log10(min_err));
            err_max_exp = ceil(log10(max_err));

            if (err_max_exp - err_min_exp) <= 4
                y_ticks = [];
                for exp_val = (err_min_exp-1):(err_max_exp+1)
                    base = 10^exp_val;
                    candidates = [base, 2*base, 5*base];
                    valid = candidates(candidates >= min_err*0.2 & candidates <= max_err*5);
                    y_ticks = [y_ticks, valid]; %#ok<AGROW>
                end
            else
                y_ticks = 10.^(err_min_exp:err_max_exp);
            end
            yticks(ax, unique(y_ticks));
        end
    end

    box(ax, 'on');

    % Create legend with tex interpreter
    leg = legend(ax, 'Location', legend_loc, 'FontSize', 12, 'Interpreter', 'tex');
    set(leg, 'Box', 'off');

    % Set tight layout
    set(ax, 'LooseInset', get(ax, 'TightInset'));

    hold(ax, 'off');

    % Print convergence rates
    fprintf('\n=== Convergence Analysis Results ===\n');
    for i = 1:num_methods
        fprintf('%-15s: %.4f\n', method_names{i}, convergence_rates(i));
    end
    fprintf('====================================\n\n');

    % Save figure if requested (explicit handle: never prints another figure)
    if save_fig
        [save_dir, ~, ~] = fileparts(save_path);
        if ~isempty(save_dir) && ~exist(save_dir, 'dir')
            mkdir(save_dir);
        end

        print(fig_handle, save_path, '-dpng', '-r300');
        % print(fig_handle, save_path, '-dpdf', '-r300');  % uncomment for PDF output
        fprintf('Figure saved to: %s\n', save_path);
    end
end
% Convergence-order plotting function for simultaneous space-time refinement

function [convergence_rates, fig_handle] = plotConvergenceST(error_data, method_names, varargin)
% PLOTCONVERGENCEST - Plot convergence under simultaneous space-time refinement
% Compatible with MATLAB 2020a and earlier versions
%
% Row k of error_data (k = 1, 2, ...) is taken to be the error obtained
% with steps (h/2^(k-1), tau/2^(k-1)). The observed order of convergence
% is the slope p of a least-squares fit  log(err) = c + p*log(s)  where
% s = 1, 1/2, 1/4, ... is the relative step size.
%
% Syntax:
%   [rates, fig] = plotConvergenceST(error_data, method_names)
%   [rates, fig] = plotConvergenceST(error_data, method_names, 'Name', Value)
%
% Inputs:
%   error_data   - Matrix where each column contains errors for one method
%                  (a row vector is treated as a single method)
%                  OR cell array of error vectors (shorter vectors are
%                  padded with NaN at the finest levels)
%   method_names - Cell array (or string array) of method names for legend
%
% Optional Name-Value pairs:
%   'SavePath'        - Path to save figure (default: 'figs/convergence_plot')
%   'FigureSize'      - Figure size in cm [width, height] (default: [13, 11.5])
%   'ShowGrid'        - Show grid, logical or 0/1 (default: true)
%   'GridAlpha'       - Major grid transparency (default: 0.15)
%   'MinorGridAlpha'  - Minor grid transparency (default: 0.05)
%   'LegendLocation'  - Legend location (default: 'southeast')
%   'SaveFig'         - Save as 300-dpi PNG, logical or 0/1 (default: false)
%   'YLabel'          - Y-axis label (default: 'Error')
%   'HSpatialSymbol'  - Spatial step symbol (default: 'h')
%   'HTemporalSymbol' - Temporal step symbol (default: '\tau')
%   Text-valued options accept char vectors or scalar strings.
%
% Notes:
%   - Step values are generated automatically as 1, 1/2, 1/4, ... from the
%     number of rows of error_data.
%   - X-axis tick labels are (h, tau), (h/2, tau/2), (h/4, tau/4), ...;
%     the tick marks themselves are hidden.
%   - NaN, Inf, zero and negative errors are ignored by the fit, the plot
%     and the automatic y-limits.
%
% Example:
%   errors = [1.0283e-02, 8.2165e-03, 4.0960e-03, 2.0417e-03];
%   methods = {'Method A'};
%   [rates, fig] = plotConvergenceST(errors, methods, 'SaveFig', true);
%
% Outputs:
%   convergence_rates - 1-by-num_methods vector of fitted rates; NaN when a
%                       method has fewer than two usable values
%   fig_handle        - Figure handle

    % Validators: accept char or scalar string for text, logical or numeric
    % scalar (e.g. 0/1) for flags.
    is_text = @(x) ischar(x) || (isstring(x) && isscalar(x));
    is_flag = @(x) isscalar(x) && (islogical(x) || (isnumeric(x) && ~isnan(x)));

    % Parse input arguments
    p = inputParser;
    addRequired(p, 'error_data');
    addRequired(p, 'method_names');
    addParameter(p, 'SavePath', 'figs/convergence_plot', is_text);
    addParameter(p, 'FigureSize', [13, 11.5], @isnumeric);
    addParameter(p, 'ShowGrid', true, is_flag);
    addParameter(p, 'GridAlpha', 0.15, @isnumeric);
    addParameter(p, 'MinorGridAlpha', 0.05, @isnumeric);
    addParameter(p, 'LegendLocation', 'southeast', is_text);
    addParameter(p, 'SaveFig', false, is_flag);
    addParameter(p, 'YLabel', 'Error', is_text);
    addParameter(p, 'HSpatialSymbol', 'h', is_text);
    addParameter(p, 'HTemporalSymbol', '\tau', is_text);

    parse(p, error_data, method_names, varargin{:});

    % Extract parsed values (normalised to char / logical)
    save_path = char(p.Results.SavePath);
    fig_size = p.Results.FigureSize;
    show_grid = logical(p.Results.ShowGrid);
    grid_alpha = p.Results.GridAlpha;
    minor_grid_alpha = p.Results.MinorGridAlpha;
    legend_loc = char(p.Results.LegendLocation);
    save_fig = logical(p.Results.SaveFig);
    y_label = char(p.Results.YLabel);
    h_spatial = char(p.Results.HSpatialSymbol);
    h_temporal = char(p.Results.HTemporalSymbol);

    % Normalise error_data to a matrix with one column per method
    if iscell(error_data)
        num_methods = numel(error_data);
        data_length = max(cellfun(@numel, error_data));
        error_matrix = NaN(data_length, num_methods);
        for i = 1:num_methods
            len = numel(error_data{i});
            error_matrix(1:len, i) = error_data{i}(:);
        end
        error_data = error_matrix;
    else
        if isrow(error_data)
            error_data = error_data.';
        end
        num_methods = size(error_data, 2);
        data_length = size(error_data, 1);
    end

    % Relative steps 1, 1/2, 1/4, ... belong to rows 1, 2, 3, ... . xticks
    % needs increasing values, so reverse both the steps and the error rows.
    step_values = fliplr(1 ./ (2.^(0:(data_length-1))));
    error_data = flipud(error_data);

    % Tick labels in the reversed order: position i is refinement level
    % data_length - i (level 0 = coarsest, labelled without a divisor).
    custom_tick_labels = cell(1, data_length);
    for i = 1:data_length
        level = data_length - i;
        if level == 0
            custom_tick_labels{i} = sprintf('(%s, %s)', h_spatial, h_temporal);
        else
            divisor = 2^level;
            custom_tick_labels{i} = sprintf('(%s/%d, %s/%d)', h_spatial, divisor, h_temporal, divisor);
        end
    end

    % Convert method_names to cell array of char (2020a compatibility)
    if isstring(method_names)
        method_names = cellstr(method_names);
    end

    % Validate inputs
    if length(method_names) ~= num_methods
        error('Number of method names must match number of error datasets');
    end

    % Values that can appear on a log axis: finite and strictly positive.
    usable = isfinite(error_data) & error_data > 0;

    % Fit log(err) = c + p*log(step) per method; keep c for the fitted line
    convergence_rates = NaN(1, num_methods);
    fit_intercepts = NaN(1, num_methods);
    for i = 1:num_methods
        valid_idx = usable(:, i);
        if nnz(valid_idx) < 2
            warning('Method %d has insufficient valid data points', i);
            continue;
        end

        valid_step = step_values(valid_idx);
        valid_errors = error_data(valid_idx, i);

        X = [ones(numel(valid_step), 1), log(valid_step(:))];
        coeffs = X \ log(valid_errors(:));
        fit_intercepts(i) = coeffs(1);
        convergence_rates(i) = coeffs(2);
    end

    % Create figure and its axes
    fig_handle = figure('Units', 'centimeters', 'Position', [0, 0, fig_size], 'Color', 'white');
    ax = axes('Parent', fig_handle);

    % Define color palette and markers (cycled when there are more methods)
    colors = {[0.8500 0.3250 0.0980], [0.9290 0.6940 0.1250], [0.4660 0.6740 0.1880], ...
              [0 0.4470 0.7410], [0.4940 0.1840 0.5560], [0.3010 0.7450 0.9330], ...
              [0.6350 0.0780 0.1840], [0.8500 0.8500 0.8500]};
    markers = {'o', 's', '^', 'd', 'v', '>', '<', 'p'};

    % Plot data points first so the legend lists all data before all fits
    hold(ax, 'on');
    for i = 1:num_methods
        valid_idx = usable(:, i);
        if nnz(valid_idx) < 2
            continue;
        end
        color_idx = mod(i-1, numel(colors)) + 1;
        marker_idx = mod(i-1, numel(markers)) + 1;

        loglog(ax, step_values(valid_idx), error_data(valid_idx, i), ...
               [markers{marker_idx} '-'], ...
               'Color', colors{color_idx}, ...
               'MarkerFaceColor', 'white', ...
               'MarkerEdgeColor', colors{color_idx}, ...
               'LineWidth', 2.5, ...
               'MarkerSize', 9, ...
               'DisplayName', sprintf('%s', method_names{i}));
    end

    % Plot fitted lines on top, reusing the coefficients computed above
    for i = 1:num_methods
        if isnan(convergence_rates(i))
            continue;
        end
        color_idx = mod(i-1, numel(colors)) + 1;
        valid_step = step_values(usable(:, i));

        fit_step = logspace(log10(min(valid_step) * 0.8), log10(max(valid_step) * 1.2), 100);
        fit_errors = exp(fit_intercepts(i)) * fit_step.^convergence_rates(i);

        loglog(ax, fit_step, fit_errors, '--', ...
               'Color', colors{color_idx} * 0.7, ...
               'LineWidth', 3.5, ...
               'DisplayName', sprintf('(fitted rate=%.2f)', convergence_rates(i)));
    end

    % With hold on, loglog does not switch the scales itself; set them explicitly
    set(ax, 'XScale', 'log', 'YScale', 'log');

    % Customize plot
    xlabel(ax, 'Mesh Refinement Level', 'FontSize', 14, 'FontName', 'Times New Roman');
    ylabel(ax, y_label, 'FontSize', 14, 'FontName', 'Times New Roman');

    if show_grid
        grid(ax, 'on');
        set(ax, 'GridAlpha', grid_alpha);
        set(ax, 'MinorGridAlpha', minor_grid_alpha);
        set(ax, 'XMinorGrid', 'on', 'YMinorGrid', 'on');
    end

    set(ax, 'FontSize', 12, 'FontName', 'Times New Roman');
    set(ax, 'LineWidth', 1.5);
    set(ax, 'TickLength', [0.02 0.02]);

    % Set custom limits
    xlim(ax, [min(step_values)*0.7, max(step_values)*1.4]);

    % Y-limits are derived only from values that can actually be drawn, so a
    % zero/negative/Inf entry does not disable the limits altogether.
    plotted_errors = error_data(usable);
    if ~isempty(plotted_errors)
        ylim(ax, [min(plotted_errors)*0.3, max(plotted_errors)*3]);
    end

    % ===== Custom x tick labels (step_values is increasing) =====
    xticks(ax, step_values);
    xticklabels(ax, custom_tick_labels);

    % Hide x tick marks and minor ticks; only the labels remain
    ax.XAxis.TickLength = [0 0];
    ax.XAxis.MinorTickValues = [];

    % Y axis: no shared 'x10^n' multiplier, labels formatted like '1e-03'
    ax.YAxis.Exponent = 0;
    ax.YAxis.TickLabelFormat = '%.0e';

    box(ax, 'on');

    % Create legend with tex interpreter
    leg = legend(ax, 'Location', legend_loc, 'FontSize', 12, 'Interpreter', 'tex');
    set(leg, 'Box', 'off');

    % Set tight layout
    set(ax, 'LooseInset', get(ax, 'TightInset'));

    hold(ax, 'off');

    % Print convergence rates
    fprintf('\n=== Convergence Analysis Results ===\n');
    fprintf('Analysis Type: Spatiotemporal Refinement\n');
    for i = 1:num_methods
        fprintf('%-15s: %.4f\n', method_names{i}, convergence_rates(i));
    end
    fprintf('====================================\n\n');

    % Save figure if requested (explicit handle: never prints another figure)
    if save_fig
        [save_dir, ~, ~] = fileparts(save_path);
        if ~isempty(save_dir) && ~exist(save_dir, 'dir')
            mkdir(save_dir);
        end

        print(fig_handle, save_path, '-dpng', '-r300');
        % print(fig_handle, save_path, '-dpdf', '-r300');  % uncomment for PDF output
        fprintf('Figure saved to: %s\n', save_path);
    end
end
% Example driver: only the error data and the method names need to be passed.
% NOTE(review): MATLAB rejects statements that follow local functions in a
% file whose first line is 'function'; if this driver shares a file with
% plotConvergence/plotConvergenceST, move it into its own script (or put it
% above the functions in a script file).
error_v_cn = [2.3060e-03, 5.7582e-04, 1.4391e-04, 3.5976e-05, 8.9944e-06];
error_Q_cn = [3.1611e-03, 7.8095e-04, 1.9467e-04, 4.8631e-05, 1.2155e-05];

error_v_bdf2 = [2.2656e-03, 5.6576e-04, 1.4140e-04, 3.5347e-05, 8.8374e-06];
% The last entry must carry the 'e': written as 1.1911-05 it evaluates to
% 1.1911 - 5 = -3.8089, which the plotting function drops as non-positive.
error_Q_bdf2 = [3.0983e-03, 7.6529e-04, 1.9075e-04, 4.7653e-05, 1.1911e-05];


method_cn = {"Error_v of SGE-PDG", "Error_Q of SGE-PDG"};

method_bdf2 = {"Error_v of SGE-BDF2", "Error_Q of SGE-BDF2"};

% Transpose so each column holds one error series, ordered coarse -> fine
plotConvergenceST([error_v_cn; error_Q_cn]', method_cn, 'SavePath', 'convergence_sge_cn.png', 'SaveFig', true);
plotConvergenceST([error_v_bdf2; error_Q_bdf2]', method_bdf2, 'SavePath', 'convergence_sge_bdf2.png', 'SaveFig', true);