经典视觉跟踪算法的MATLAB实现

几种经典视觉跟踪算法的MATLAB实现代码,包括均值漂移(Mean Shift)、卡尔曼滤波(Kalman Filter)和粒子滤波(Particle Filter)算法。

1. 均值漂移(Mean Shift)跟踪算法

function mean_shift_tracking()
    % 均值漂移跟踪算法实现
    
    % 读取视频文件
    videoFile = 'test_video.avi';
    videoReader = VideoReader(videoFile);
    
    % 获取第一帧并选择目标区域
    frame = readFrame(videoReader);
    figure, imshow(frame);
    title('选择目标区域');
    rect = getrect;
    x = rect(1);
    y = rect(2);
    width = rect(3);
    height = rect(4);
    
    % 初始化目标模型(使用HSV颜色空间)
    target_region = frame(round(y):round(y+height), round(x):round(x+width), :);
    target_hsv = rgb2hsv(target_region);
    target_hist = compute_histogram(target_hsv);
    
    % 创建视频写入对象
    outputVideo = VideoWriter('mean_shift_result.avi');
    open(outputVideo);
    
    % 处理视频帧
    frame_count = 1;
    positions = zeros(videoReader.NumFrames, 2);
    
    while hasFrame(videoReader)
        frame = readFrame(videoReader);
        current_hsv = rgb2hsv(frame);
        
        % 使用均值漂移算法跟踪目标
        [x, y] = mean_shift_iteration(current_hsv, x, y, width, height, target_hist);
        
        % 存储位置
        positions(frame_count, :) = [x + width/2, y + height/2];
        
        % 绘制跟踪结果
        tracked_frame = insertShape(frame, 'Rectangle', [x, y, width, height], ...
                                   'LineWidth', 3, 'Color', 'red');
        tracked_frame = insertText(tracked_frame, [10, 10], sprintf('Frame: %d', frame_count));
        
        % 显示和保存结果
        imshow(tracked_frame);
        writeVideo(outputVideo, tracked_frame);
        
        frame_count = frame_count + 1;
    end
    
    close(outputVideo);
    
    % 绘制跟踪轨迹
    figure;
    imshow(read(videoReader, 1));
    hold on;
    plot(positions(:, 1), positions(:, 2), 'r-', 'LineWidth', 2);
    title('目标跟踪轨迹');
    
    fprintf('均值漂移跟踪完成!\n');
end

function hist = compute_histogram(hsv_image)
    % 计算HSV图像的色调直方图
    h = hsv_image(:, :, 1);
    hist = zeros(16, 1);
    
    for i = 1:size(h, 1)
        for j = 1:size(h, 2)
            bin = floor(h(i, j) * 15) + 1;
            if bin > 16, bin = 16; end
            hist(bin) = hist(bin) + 1;
        end
    end
    
    % 归一化直方图
    hist = hist / sum(hist);
end

function [new_x, new_y] = mean_shift_iteration(hsv_frame, x, y, width, height, target_hist)
    % 均值漂移迭代
    
    max_iterations = 20;
    epsilon = 1;
    
    for iter = 1:max_iterations
        % 提取候选区域
        x1 = max(1, round(x));
        y1 = max(1, round(y));
        x2 = min(size(hsv_frame, 2), round(x + width));
        y2 = min(size(hsv_frame, 1), round(y + height));
        
        candidate_region = hsv_frame(y1:y2, x1:x2, :);
        candidate_hist = compute_histogram(candidate_region);
        
        % 计算权重图
        weights = compute_weights(candidate_region, target_hist, candidate_hist);
        
        % 计算新的位置
        [rows, cols] = size(weights);
        [col_grid, row_grid] = meshgrid(1:cols, 1:rows);
        
        total_weight = sum(weights(:));
        if total_weight > 0
            new_x = sum(col_grid(:) .* weights(:)) / total_weight + x1 - 1;
            new_y = sum(row_grid(:) .* weights(:)) / total_weight + y1 - 1;
        else
            new_x = x;
            new_y = y;
        end
        
        % 检查收敛
        if sqrt((new_x - x)^2 + (new_y - y)^2) < epsilon
            break;
        end
        
        x = new_x;
        y = new_y;
    end
end

function weights = compute_weights(region, target_hist, candidate_hist)
    % 计算权重图
    h = region(:, :, 1);
    weights = zeros(size(h));
    
    for i = 1:size(h, 1)
        for j = 1:size(h, 2)
            bin = floor(h(i, j) * 15) + 1;
            if bin > 16, bin = 16; end
            
            if candidate_hist(bin) > 0
                weights(i, j) = sqrt(target_hist(bin) / candidate_hist(bin));
            else
                weights(i, j) = 0;
            end
        end
    end
end

2. 卡尔曼滤波跟踪算法

function kalman_filter_tracking()
    % 卡尔曼滤波跟踪算法实现
    
    % 读取视频文件
    videoFile = 'test_video.avi';
    videoReader = VideoReader(videoFile);
    
    % 获取第一帧并选择目标区域
    frame = readFrame(videoReader);
    figure, imshow(frame);
    title('选择目标区域');
    rect = getrect;
    x = rect(1);
    y = rect(2);
    width = rect(3);
    height = rect(4);
    
    % 初始化卡尔曼滤波器
    % 状态向量: [x, y, vx, vy, width, height]
    dt = 1; % 时间间隔
    A = [1, 0, dt, 0, 0, 0;  % 状态转移矩阵
         0, 1, 0, dt, 0, 0;
         0, 0, 1, 0, 0, 0;
         0, 0, 0, 1, 0, 0;
         0, 0, 0, 0, 1, 0;
         0, 0, 0, 0, 0, 1];
     
    H = [1, 0, 0, 0, 0, 0;  % 观测矩阵
         0, 1, 0, 0, 0, 0;
         0, 0, 0, 0, 1, 0;
         0, 0, 0, 0, 0, 1];
     
    Q = 0.01 * eye(6);  % 过程噪声协方差
    R = 1 * eye(4);     % 测量噪声协方差
    
    % 初始状态和协方差
    state = [x; y; 0; 0; width; height];
    covariance = 10 * eye(6);
    
    % 创建视频写入对象
    outputVideo = VideoWriter('kalman_filter_result.avi');
    open(outputVideo);
    
    % 处理视频帧
    frame_count = 1;
    positions = zeros(videoReader.NumFrames, 2);
    measurements = [];
    
    while hasFrame(videoReader)
        frame = readFrame(videoReader);
        
        % 预测步骤
        [state, covariance] = predict_kalman(state, covariance, A, Q);
        
        % 检测目标(简单实现,实际应用中应使用更复杂的检测器)
        if mod(frame_count, 5) == 1 || isempty(measurements)
            % 每5帧或丢失目标时重新检测
            measurements = detect_object(frame, state);
        end
        
        % 如果有测量值,则更新卡尔曼滤波器
        if ~isempty(measurements)
            [state, covariance] = update_kalman(state, covariance, measurements, H, R);
        end
        
        % 存储位置
        positions(frame_count, :) = [state(1) + state(5)/2, state(2) + state(6)/2];
        
        % 绘制跟踪结果
        tracked_frame = insertShape(frame, 'Rectangle', [state(1), state(2), state(5), state(6)], ...
                                   'LineWidth', 3, 'Color', 'green');
        tracked_frame = insertText(tracked_frame, [10, 10], sprintf('Frame: %d', frame_count));
        
        % 显示和保存结果
        imshow(tracked_frame);
        writeVideo(outputVideo, tracked_frame);
        
        frame_count = frame_count + 1;
    end
    
    close(outputVideo);
    
    % 绘制跟踪轨迹
    figure;
    imshow(read(videoReader, 1));
    hold on;
    plot(positions(:, 1), positions(:, 2), 'g-', 'LineWidth', 2);
    title('卡尔曼滤波跟踪轨迹');
    
    fprintf('卡尔曼滤波跟踪完成!\n');
end

function [new_state, new_covariance] = predict_kalman(state, covariance, A, Q)
    % 卡尔曼滤波预测步骤
    new_state = A * state;
    new_covariance = A * covariance * A' + Q;
end

function [updated_state, updated_covariance] = update_kalman(state, covariance, measurement, H, R)
    % 卡尔曼滤波更新步骤
    % 计算卡尔曼增益
    K = covariance * H' / (H * covariance * H' + R);
    
    % 更新状态估计
    updated_state = state + K * (measurement - H * state);
    
    % 更新协方差估计
    updated_covariance = (eye(size(covariance)) - K * H) * covariance;
end

function measurement = detect_object(frame, state)
    % 简单的目标检测函数
    % 在实际应用中,应使用更复杂的检测器(如相关滤波、深度学习等)
    
    % 提取搜索区域
    search_margin = 30;
    x1 = max(1, round(state(1)) - search_margin);
    y1 = max(1, round(state(2)) - search_margin);
    x2 = min(size(frame, 2), round(state(1) + state(5)) + search_margin);
    y2 = min(size(frame, 1), round(state(2) + state(6)) + search_margin);
    
    search_region = frame(y1:y2, x1:x2, :);
    
    % 简单模板匹配(实际应用中应使用更复杂的方法)
    % 这里只是示例,实际效果可能不佳
    template = imcrop(frame, [state(1), state(2), state(5), state(6)]);
    
    if size(template, 1) > 0 && size(template, 2) > 0
        % 调整模板大小以适应搜索区域
        template = imresize(template, [size(search_region, 1), size(search_region, 2)]);
        
        % 计算相关性
        correlation = normxcorr2(rgb2gray(template), rgb2gray(search_region));
        
        % 找到最大相关位置
        [ypeak, xpeak] = find(correlation == max(correlation(:)));
        
        if ~isempty(ypeak) && ~isempty(xpeak)
            % 计算测量值
            meas_x = x1 + xpeak(1) - size(template, 2);
            meas_y = y1 + ypeak(1) - size(template, 1);
            meas_width = state(5);
            meas_height = state(6);
            
            measurement = [meas_x; meas_y; meas_width; meas_height];
            return;
        end
    end
    
    % 如果检测失败,返回空值
    measurement = [];
end

3. 粒子滤波跟踪算法

function particle_filter_tracking()
    % 粒子滤波跟踪算法实现
    
    % 读取视频文件
    videoFile = 'test_video.avi';
    videoReader = VideoReader(videoFile);
    
    % 获取第一帧并选择目标区域
    frame = readFrame(videoReader);
    figure, imshow(frame);
    title('选择目标区域');
    rect = getrect;
    x = rect(1);
    y = rect(2);
    width = rect(3);
    height = rect(4);
    
    % 初始化粒子滤波器
    n_particles = 100;  % 粒子数量
    particles = initialize_particles(x, y, width, height, n_particles);
    
    % 提取目标模型(颜色直方图)
    target_region = frame(round(y):round(y+height), round(x):round(x+width), :);
    target_hist = compute_color_histogram(target_region);
    
    % 创建视频写入对象
    outputVideo = VideoWriter('particle_filter_result.avi');
    open(outputVideo);
    
    % 处理视频帧
    frame_count = 1;
    positions = zeros(videoReader.NumFrames, 2);
    
    while hasFrame(videoReader)
        frame = readFrame(videoReader);
        
        % 粒子滤波跟踪
        [particles, estimate] = particle_filter_step(frame, particles, target_hist);
        
        % 存储位置
        positions(frame_count, :) = [estimate(1) + estimate(3)/2, estimate(2) + estimate(4)/2];
        
        % 绘制跟踪结果
        tracked_frame = insertShape(frame, 'Rectangle', estimate, 'LineWidth', 3, 'Color', 'blue');
        
        % 可选:绘制粒子
        % for i = 1:size(particles, 1)
        %     tracked_frame = insertShape(tracked_frame, 'Rectangle', ...
        %         [particles(i,1), particles(i,2), particles(i,3), particles(i,4)], ...
        %         'LineWidth', 1, 'Color', 'yellow', 'Opacity', 0.2);
        % end
        
        tracked_frame = insertText(tracked_frame, [10, 10], sprintf('Frame: %d', frame_count));
        
        % 显示和保存结果
        imshow(tracked_frame);
        writeVideo(outputVideo, tracked_frame);
        
        frame_count = frame_count + 1;
    end
    
    close(outputVideo);
    
    % 绘制跟踪轨迹
    figure;
    imshow(read(videoReader, 1));
    hold on;
    plot(positions(:, 1), positions(:, 2), 'b-', 'LineWidth', 2);
    title('粒子滤波跟踪轨迹');
    
    fprintf('粒子滤波跟踪完成!\n');
end

function particles = initialize_particles(x, y, width, height, n_particles)
    % 初始化粒子
    particles = zeros(n_particles, 4);
    
    for i = 1:n_particles
        % 在目标位置附近随机分布粒子
        particles(i, 1) = x + randn * 10;  % x坐标
        particles(i, 2) = y + randn * 10;  % y坐标
        particles(i, 3) = width * (0.9 + 0.2 * rand);   % 宽度
        particles(i, 4) = height * (0.9 + 0.2 * rand);  % 高度
    end
end

function hist = compute_color_histogram(region)
    % 计算RGB颜色直方图
    bins = 8;  % 每个颜色通道的直方图bin数量
    
    r = region(:, :, 1);
    g = region(:, :, 2);
    b = region(:, :, 3);
    
    r_hist = histcounts(r(:), bins, 'Normalization', 'probability');
    g_hist = histcounts(g(:), bins, 'Normalization', 'probability');
    b_hist = histcounts(b(:), bins, 'Normalization', 'probability');
    
    hist = [r_hist, g_hist, b_hist];
end

function [updated_particles, estimate] = particle_filter_step(frame, particles, target_hist)
    % 粒子滤波单步处理
    
    n_particles = size(particles, 1);
    weights = zeros(n_particles, 1);
    
    % 为每个粒子计算权重
    for i = 1:n_particles
        x = particles(i, 1);
        y = particles(i, 2);
        width = particles(i, 3);
        height = particles(i, 4);
        
        % 提取粒子区域
        x1 = max(1, round(x));
        y1 = max(1, round(y));
        x2 = min(size(frame, 2), round(x + width));
        y2 = min(size(frame, 1), round(y + height));
        
        if x2 > x1 && y2 > y1
            particle_region = frame(y1:y2, x1:x2, :);
            particle_hist = compute_color_histogram(particle_region);
            
            % 计算Bhattacharyya系数作为相似度度量
            similarity = sum(sqrt(target_hist .* particle_hist));
            
            % 权重与相似度成正比
            weights(i) = similarity;
        else
            weights(i) = 0;
        end
    end
    
    % 归一化权重
    if sum(weights) > 0
        weights = weights / sum(weights);
    else
        weights = ones(n_particles, 1) / n_particles;
    end
    
    % 重采样
    indices = resample_particles(weights);
    updated_particles = particles(indices, :);
    
    % 添加随机噪声(防止粒子退化)
    process_noise = [5, 5, 2, 2];  % 位置和尺寸的噪声水平
    for i = 1:n_particles
        updated_particles(i, :) = updated_particles(i, :) + randn(1, 4) .* process_noise;
    end
    
    % 计算状态估计(加权平均)
    estimate = sum(updated_particles .* weights, 1);
end

function indices = resample_particles(weights)
    % 系统重采样
    n_particles = length(weights);
    indices = zeros(n_particles, 1);
    
    % 计算累积分布函数
    cdf = cumsum(weights);
    
    % 生成均匀分布的随机数
    u = rand / n_particles;
    
    i = 1;
    for j = 1:n_particles
        while u > cdf(i)
            i = i + 1;
            if i > n_particles
                i = n_particles;
                break;
            end
        end
        indices(j) = i;
        u = u + 1 / n_particles;
    end
end

使用

  1. 准备视频文件:将您的视频文件命名为test_video.avi并放在MATLAB工作目录中,或者修改代码中的视频文件路径。

  2. 运行跟踪算法:选择并运行上述任一算法函数(mean_shift_tracking, kalman_filter_tracking, 或particle_filter_tracking)。

  3. 选择目标区域:程序会显示第一帧图像,请用鼠标选择您要跟踪的目标区域。

  4. 查看结果:算法会处理视频的每一帧,显示跟踪结果,并保存为新的视频文件。

推荐代码 经典的视觉跟踪算法的MATLAB代码 www.3dddown.com/cna/50672.html

算法比较

算法 优点 缺点 适用场景
均值漂移 计算简单,实时性好 对快速运动和遮挡敏感 颜色特征明显、运动缓慢的目标
卡尔曼滤波 对线性系统有最优估计,能预测目标位置 假设系统为线性高斯模型 运动模式可预测的目标
粒子滤波 能处理非线性非高斯系统,鲁棒性强 计算复杂度高,需要较多粒子 复杂运动模式、需要高鲁棒性的场景
posted @ 2025-12-19 16:22  荒川之主  阅读(0)  评论(0)    收藏  举报