SORT 目标跟踪(卡尔曼滤波算法)
SORT 目标跟踪示例:仅依赖 OpenCV 库实现
#include <opencv2/opencv.hpp>
#include <vector>
#include <string>
#include <cmath>
#include <algorithm>
#include <limits>
#include <iostream>
#include <iomanip>
// ============================================================================
// Detection 结构体定义
// ============================================================================
/// One detector output for a single frame.
struct Detection
{
    int class_id = 0;          ///< Numeric class label.
    std::string className;     ///< Human-readable class label.
    float confidence = 0.0f;   ///< Detector score (rendered as a percentage when drawn).
    cv::Scalar color;          ///< Colour associated with this detection.
    cv::Rect box;              ///< Bounding box in integer pixel coordinates.
};
// ============================================================================
// KalmanFilter 类 - 用于状态预测
// ============================================================================
/// Linear constant-velocity Kalman filter over a bounding box.
///
/// State vector (8x1, CV_32F): [u, v, s, r, du, dv, ds, dr]
///   u, v           : box centre
///   s              : box area (width * height)
///   r              : aspect ratio (width / height)
///   du, dv, ds, dr : per-step rates of change of the four values above
/// Measurement vector (4x1, CV_32F): [u, v, s, r]
///
/// init() must be called before predict(), update() or get_state().
class KalmanFilter
{
public:
    static const int STATE_DIM = 8;
    static const int MEAS_DIM = 4; // measurement: [u, v, s, r]

    /// Builds the constant model matrices; the state itself is set by init().
    KalmanFilter()
    {
        // Transition F: identity, plus each value picks up its velocity per step.
        F = cv::Mat::zeros(STATE_DIM, STATE_DIM, CV_32F);
        for (int i = 0; i < STATE_DIM; ++i)
            F.at<float>(i, i) = 1.0f;
        for (int i = 0; i < 4; ++i)
            F.at<float>(i, i + 4) = 1.0f;

        // Measurement H: observes the first four state components directly.
        H = cv::Mat::zeros(MEAS_DIM, STATE_DIM, CV_32F);
        for (int i = 0; i < MEAS_DIM; ++i)
            H.at<float>(i, i) = 1.0f;

        Q = cv::Mat::eye(STATE_DIM, STATE_DIM, CV_32F) * 0.01f; // process noise
        R = cv::Mat::eye(MEAS_DIM, MEAS_DIM, CV_32F) * 1.0f;    // measurement noise
        P = cv::Mat::eye(STATE_DIM, STATE_DIM, CV_32F) * 10.0f; // initial state covariance
    }

    /// Sets the state from a bounding box; all velocities start at zero.
    /// @param bbox Initial box in integer pixel coordinates.
    void init(const cv::Rect& bbox)
    {
        // cv::Rect holds ints: convert first so the aspect ratio is a real
        // floating-point quotient instead of a truncated integer division.
        const float w = static_cast<float>(bbox.width);
        const float h = static_cast<float>(bbox.height);

        state = cv::Mat::zeros(STATE_DIM, 1, CV_32F);
        state.at<float>(0, 0) = bbox.x + w * 0.5f;
        state.at<float>(1, 0) = bbox.y + h * 0.5f;
        state.at<float>(2, 0) = w * h;
        // A zero-height box has no defined ratio; store 0 so get_bbox() reports
        // an empty rectangle rather than dividing by zero here.
        state.at<float>(3, 0) = (h > 0.0f) ? (w / h) : 0.0f;
    }

    /// Advances the filter one step: x' = F x, P' = F P F^T + Q.
    /// @return The predicted state (shares storage with the filter).
    cv::Mat predict()
    {
        state = F * state;
        P = F * P * F.t() + Q;
        return state;
    }

    /// Corrects the state with a measurement.
    /// @param measurement 4x1 CV_32F column [u, v, s, r].
    /// @return The corrected state (shares storage with the filter).
    cv::Mat update(const cv::Mat& measurement)
    {
        // Kalman gain: K = P H^T (H P H^T + R)^-1
        const cv::Mat S = H * P * H.t() + R;
        const cv::Mat K = P * H.t() * S.inv();

        // State correction: x = x + K (z - H x)
        const cv::Mat y = measurement - H * state;
        state = state + K * y;

        // Covariance correction: P = (I - K H) P
        const cv::Mat I = cv::Mat::eye(STATE_DIM, STATE_DIM, CV_32F);
        P = (I - K * H) * P;
        return state;
    }

    /// Converts the current [u, v, s, r] back into a pixel rectangle.
    /// @return The box, or an empty cv::Rect when the filter is uninitialised
    ///         or the state does not describe a box with positive size.
    cv::Rect get_bbox() const
    {
        if (state.empty())
            return cv::Rect();

        const float u = state.at<float>(0, 0);
        const float v = state.at<float>(1, 0);
        const float s = state.at<float>(2, 0);
        const float r = state.at<float>(3, 0);

        // w = sqrt(area * ratio). A negative product yields NaN and a zero one
        // would make the height division below blow up, so reject both.
        const float w = std::sqrt(s * r);
        if (!(w > 0.0f) || !std::isfinite(w))
            return cv::Rect();

        const float h = s / w;
        return cv::Rect(static_cast<int>(u - w * 0.5f),
                        static_cast<int>(v - h * 0.5f),
                        static_cast<int>(w),
                        static_cast<int>(h));
    }

    /// @return The state vector; the returned cv::Mat header refers to the
    ///         filter's own buffer rather than a deep copy.
    cv::Mat get_state() const { return state; }

private:
    cv::Mat state; // state vector (8x1)
    cv::Mat F;     // state transition matrix
    cv::Mat H;     // measurement matrix
    cv::Mat Q;     // process noise covariance
    cv::Mat R;     // measurement noise covariance
    cv::Mat P;     // state covariance
};
// ============================================================================
// Track 类 - 单个跟踪对象
// ============================================================================
/// A single tracked object: a Kalman filter plus bookkeeping counters and the
/// display attributes copied from the detection that created it.
class Track
{
public:
    /// Starts a track from its first detection.
    /// @param detection Detection that seeds the filter and display attributes.
    /// @param track_id  Identifier reported by get_id().
    Track(const Detection& detection, int track_id)
        : id(track_id),
          hits(1),
          age(1),
          time_since_update(0),
          class_id(detection.class_id),
          className(detection.className),
          color(detection.color),
          confidence(detection.confidence)
    {
        kf.init(detection.box);
        state = cv::Rect2f(detection.box);
    }

    /// Corrects the track with a matched detection and resets the miss counter.
    /// @param detection Detection assigned to this track for the current frame.
    void update(const Detection& detection)
    {
        // The box stores ints: convert before dividing, otherwise the aspect
        // ratio is truncated (e.g. 60 / 120 == 0) and the filter is fed garbage.
        const float w = static_cast<float>(detection.box.width);
        const float h = static_cast<float>(detection.box.height);

        const float u = detection.box.x + w * 0.5f;
        const float v = detection.box.y + h * 0.5f;
        const float s = w * h;
        const float r = (h > 0.0f) ? (w / h) : 0.0f; // zero-height box: no ratio

        confidence = detection.confidence;

        const cv::Mat measurement = (cv::Mat_<float>(4, 1) << u, v, s, r);
        kf.update(measurement);
        state = cv::Rect2f(kf.get_bbox());

        hits++;
        time_since_update = 0;
    }

    /// Advances the filter one frame and counts the frame as not yet matched.
    void predict()
    {
        kf.predict();
        state = cv::Rect2f(kf.get_bbox());
        age++;
        time_since_update++;
    }

    int get_id() const { return id; }
    /// @return Current box estimate, converted to an integer rectangle.
    cv::Rect get_state() const { return state; }
    int get_hits() const { return hits; }
    int get_age() const { return age; }
    int get_time_since_update() const { return time_since_update; }
    int get_class_id() const { return class_id; }
    std::string get_class_name() const { return className; }
    cv::Scalar get_color() const { return color; }
    float get_confidence() const { return confidence; }

private:
    int id;
    int hits;              // number of detections matched to this track (starts at 1)
    int age;               // incremented on every predict()
    int time_since_update; // predict() calls since the last update()
    KalmanFilter kf;
    cv::Rect2f state;      // current box estimate [x, y, w, h]
    int class_id;
    std::string className;
    cv::Scalar color;
    float confidence;
};
// ============================================================================
// 辅助函数 - IOU 计算
// ============================================================================
/// Intersection-over-union of two integer rectangles.
/// @return Overlap area divided by union area; a small epsilon in the
///         denominator keeps the result at 0 when both boxes have zero area.
static float calculate_iou(const cv::Rect& box1, const cv::Rect& box2)
{
    // Corners of the overlap region (may be inverted when the boxes are disjoint).
    const int left   = std::max(box1.x, box2.x);
    const int top    = std::max(box1.y, box2.y);
    const int right  = std::min(box1.x + box1.width, box2.x + box2.width);
    const int bottom = std::min(box1.y + box1.height, box2.y + box2.height);

    const int overlap_w = std::max(0, right - left);
    const int overlap_h = std::max(0, bottom - top);
    const int overlap   = overlap_w * overlap_h;

    const int union_area = box1.width * box1.height
                         + box2.width * box2.height
                         - overlap;

    return static_cast<float>(overlap) / (union_area + 1e-6f);
}
// ============================================================================
// 匈牙利算法实现 - 用于最优二分图匹配
// ============================================================================
/// Minimum-cost bipartite assignment (Hungarian method with row/column
/// potentials, cubic in the padded matrix size).
class HungarianAlgorithm
{
public:
    /// Solves the assignment problem for a rows x cols cost matrix.
    ///
    /// A non-square matrix is padded with zero-cost dummy rows or columns so
    /// the solver can work on a square problem; dummies never appear in the
    /// result.
    ///
    /// @param cost_matrix Single-channel cost matrix; converted to CV_32F
    ///                    internally.
    /// @return One entry per input row: the index of the column assigned to
    ///         that row, or -1 when the row was left unmatched (i.e. it was
    ///         paired with a padding column). Empty when the input is empty.
    static std::vector<int> solve(const cv::Mat& cost_matrix)
    {
        if (cost_matrix.empty())
            return {};
        CV_Assert(cost_matrix.channels() == 1);

        const int rows = cost_matrix.rows;
        const int cols = cost_matrix.cols;
        const int n = std::max(rows, cols);

        // Square working copy; cells outside the real matrix keep cost 0.
        cv::Mat cm = cv::Mat::zeros(n, n, CV_32F);
        cv::Mat real_part = cm(cv::Rect(0, 0, cols, rows));
        cost_matrix.convertTo(real_part, CV_32F); // same size/type: written in place

        const float inf = std::numeric_limits<float>::max();

        // 1-based arrays; index 0 is a sentinel column.
        std::vector<float> u(n + 1, 0.0f); // row potentials
        std::vector<float> v(n + 1, 0.0f); // column potentials
        std::vector<int> p(n + 1, 0);      // p[j]: row currently matched to column j (0 = none)
        std::vector<int> way(n + 1, 0);    // way[j]: previous column on the augmenting path

        for (int i = 1; i <= n; ++i)
        {
            p[0] = i;
            int j0 = 0;
            std::vector<float> minv(n + 1, inf);
            std::vector<bool> used(n + 1, false);

            // Grow the alternating tree until a free column is reached.
            do
            {
                used[j0] = true;
                const int i0 = p[j0];
                float delta = inf;
                int j1 = 0;

                for (int j = 1; j <= n; ++j)
                {
                    if (used[j])
                        continue;
                    const float cur = cm.at<float>(i0 - 1, j - 1) - u[i0] - v[j];
                    if (cur < minv[j])
                    {
                        minv[j] = cur;
                        way[j] = j0;
                    }
                    if (minv[j] < delta)
                    {
                        delta = minv[j];
                        j1 = j;
                    }
                }

                // Shift potentials so the cheapest reachable edge becomes tight.
                for (int j = 0; j <= n; ++j)
                {
                    if (used[j])
                    {
                        u[p[j]] += delta;
                        v[j] -= delta;
                    }
                    else
                    {
                        minv[j] -= delta;
                    }
                }
                j0 = j1;
            } while (p[j0] != 0);

            // Flip the matching along the augmenting path back to the sentinel.
            do
            {
                const int j1 = way[j0];
                p[j0] = p[j1];
                j0 = j1;
            } while (j0 != 0);
        }

        // Report only real rows matched to real columns; everything that ended
        // up on padding stays at -1.
        std::vector<int> assignment(rows, -1);
        for (int j = 1; j <= cols; ++j)
        {
            const int row = p[j] - 1;
            if (row >= 0 && row < rows)
                assignment[row] = j - 1;
        }
        return assignment;
    }
};
// ============================================================================
// SORT 类 - 多目标跟踪主类
// ============================================================================
class SORT
{
public:
SORT(float iou_threshold = 0.3f, int max_age = 5, int min_hits = 2)
: iou_threshold_(iou_threshold),
max_age_(max_age),
min_hits_(min_hits),
next_track_id_(1)
{}
std::vector<Track> update(const std::vector<Detection>& detections)
{
// 预测所有跟踪器状态
for (auto& track : tracks_)
{
track.predict();
}
// 如果没有检测,移除过旧的跟踪器
if (detections.empty())
{
remove_old_tracks();
return get_active_tracks();
}
// 构建 IOU 矩阵
int n_tracks = tracks_.size();
int n_detections = detections.size();
cv::Mat iou_matrix = cv::Mat::zeros(n_tracks, n_detections, CV_32F);
for (int i = 0; i < n_tracks; ++i)
{
for (int j = 0; j < n_detections; ++j)
{
iou_matrix.at<float>(i, j) = calculate_iou(
tracks_[i].get_state(), detections[j].box);
}
}
// 使用匈牙利算法进行匹配
// 成本矩阵 = 1 - IOU
cv::Mat cost_matrix = cv::Mat::ones(n_tracks, n_detections, CV_32F) - iou_matrix;
std::vector<int> assignment = HungarianAlgorithm::solve(cost_matrix);
// 处理匹配结果
std::vector<bool> matched_detections(n_detections, false);
std::vector<bool> matched_tracks(n_tracks, false);
for (int i = 0; i < n_tracks; ++i)
{
if (assignment[i] >= 0 &&
iou_matrix.at<float>(i, assignment[i]) >= iou_threshold_)
{
tracks_[i].update(detections[assignment[i]]);
matched_tracks[i] = true;
matched_detections[assignment[i]] = true;
}
}
// 创建新的跟踪器
for (int j = 0; j < n_detections; ++j)
{
if (!matched_detections[j])
{
tracks_.emplace_back(detections[j], next_track_id_++);
}
}
// 移除过旧的跟踪器
remove_old_tracks();
return get_active_tracks();
}
int get_track_count() const { return tracks_.size(); }
private:
void remove_old_tracks()
{
auto it = tracks_.begin();
while (it != tracks_.end())
{
if (it->get_time_since_update() > max_age_)
{
it = tracks_.erase(it);
}
else
{
++it;
}
}
}
std::vector<Track> get_active_tracks() const
{
std::vector<Track> active_tracks;
for (const auto& track : tracks_)
{
if (track.get_hits() >= min_hits_)
{
active_tracks.push_back(track);
}
}
return active_tracks;
}
float iou_threshold_;
int max_age_;
int min_hits_;
int next_track_id_;
std::vector<Track> tracks_;
};
// ============================================================================
// 测试代码 - 模拟检测数据
// ============================================================================
// 生成模拟检测数据 - 模拟多个目标在场景中移动
std::vector<Detection> generate_simulated_detections(int frame_idx)
{
std::vector<Detection> detections;
// 目标1: 从左向右移动
if (frame_idx >= 0 && frame_idx < 60)
{
Detection det;
det.class_id = 1;
det.className = "person";
det.confidence = 0.85f;
det.color = cv::Scalar(0, 255, 0);
int x = 50 + frame_idx * 5;
int y = 100 + (frame_idx % 20) * 2;
det.box = cv::Rect(x, y, 60, 120);
detections.push_back(det);
}
// 目标2: 从右向左移动
if (frame_idx >= 10 && frame_idx < 70)
{
Detection det;
det.class_id = 1;
det.className = "person";
det.confidence = 0.90f;
det.color = cv::Scalar(255, 0, 0);
int x = 600 - (frame_idx - 10) * 4;
int y = 200 + std::sin(frame_idx * 0.1) * 20;
det.box = cv::Rect(x, y, 50, 100);
detections.push_back(det);
}
// 目标3: 从上向下移动
if (frame_idx >= 20 && frame_idx < 80)
{
Detection det;
det.class_id = 2;
det.className = "car";
det.confidence = 0.88f;
det.color = cv::Scalar(0, 0, 255);
int x = 300 + (frame_idx % 30) * 3;
int y = 50 + (frame_idx - 20) * 3;
det.box = cv::Rect(x, y, 80, 60);
detections.push_back(det);
}
// 目标4: 静态目标
if (frame_idx >= 30 && frame_idx < 90)
{
Detection det;
det.class_id = 3;
det.className = "bicycle";
det.confidence = 0.82f;
det.color = cv::Scalar(255, 255, 0);
det.box = cv::Rect(450, 350, 40, 80);
detections.push_back(det);
}
// 目标5: 对角线移动
if (frame_idx >= 40 && frame_idx < 100)
{
Detection det;
det.class_id = 1;
det.className = "person";
det.confidence = 0.87f;
det.color = cv::Scalar(255, 0, 255);
int x = 100 + (frame_idx - 40) * 4;
int y = 400 - (frame_idx - 40) * 3;
det.box = cv::Rect(x, y, 55, 110);
detections.push_back(det);
}
// ========== 极限测试用例 ==========
// 测试用例1: 突然出现一次的obj (应该在frame 15出现一次)
// 预期: 不会被激活,因为min_hits=2
if (frame_idx == 15)
{
Detection det;
det.class_id = 4;
det.className = "noise_once";
det.confidence = 0.95f;
det.color = cv::Scalar(128, 128, 128);
det.box = cv::Rect(100, 500, 30, 30);
detections.push_back(det);
}
// 测试用例2: 连续出现两次再也不出现的obj (frame 25-26)
// 预期: 会被激活(hits=2 >= min_hits=2),但很快消失
if (frame_idx >= 25 && frame_idx <= 26)
{
Detection det;
det.class_id = 5;
det.className = "noise_twice";
det.confidence = 0.92f;
det.color = cv::Scalar(255, 128, 0);
det.box = cv::Rect(200, 500, 35, 35);
detections.push_back(det);
}
// 测试用例3: 连续出现3次再也不出现的obj (frame 35-37)
// 预期: 会被激活(hits=3 >= min_hits=2),然后消失
if (frame_idx >= 35 && frame_idx <= 37)
{
Detection det;
det.class_id = 6;
det.className = "noise_thrice";
det.confidence = 0.90f;
det.color = cv::Scalar(128, 255, 128);
det.box = cv::Rect(300, 500, 40, 40);
detections.push_back(det);
}
// 测试用例4: 另一个突然出现一次的obj (frame 50)
// 预期: 不会被激活
if (frame_idx == 50)
{
Detection det;
det.class_id = 7;
det.className = "noise_once_2";
det.confidence = 0.88f;
det.color = cv::Scalar(128, 128, 255);
det.box = cv::Rect(400, 500, 25, 25);
detections.push_back(det);
}
// 测试用例5: 另一个连续出现两次再也不出现的obj (frame 60-61)
// 预期: 会被激活,但很快消失
if (frame_idx >= 60 && frame_idx <= 61)
{
Detection det;
det.class_id = 8;
det.className = "noise_twice_2";
det.confidence = 0.85f;
det.color = cv::Scalar(255, 128, 128);
det.box = cv::Rect(500, 500, 30, 30);
detections.push_back(det);
}
// 测试用例6: 另一个连续出现3次再也不出现的obj (frame 70-72)
// 预期: 会被激活,然后消失
if (frame_idx >= 70 && frame_idx <= 72)
{
Detection det;
det.class_id = 9;
det.className = "noise_thrice_2";
det.confidence = 0.93f;
det.color = cv::Scalar(255, 255, 128);
det.box = cv::Rect(600, 500, 35, 35);
detections.push_back(det);
}
return detections;
}
// 可视化检测结果和跟踪结果
/// Draws raw detections and tracker output onto a frame.
/// @param frame      Image drawn on in place.
/// @param detections Drawn as thin green boxes labelled "<class> <score>%".
/// @param tracks     Drawn as thicker boxes in each track's own colour,
///                   labelled "ID:<id> <class>".
void visualize_results(cv::Mat& frame,
                       const std::vector<Detection>& detections,
                       const std::vector<Track>& tracks)
{
    const cv::Scalar detection_color(0, 255, 0);

    // Detections.
    for (size_t i = 0; i < detections.size(); ++i)
    {
        const Detection& det = detections[i];
        const int percent = static_cast<int>(det.confidence * 100);
        const std::string text = det.className + " " + std::to_string(percent) + "%";
        const cv::Point anchor(det.box.x, det.box.y - 10);

        cv::rectangle(frame, det.box, detection_color, 2);
        cv::putText(frame, text, anchor,
                    cv::FONT_HERSHEY_SIMPLEX, 0.5, detection_color, 2);
    }

    // Tracks.
    for (size_t i = 0; i < tracks.size(); ++i)
    {
        const Track& track = tracks[i];
        const cv::Rect track_box = track.get_state();
        const cv::Scalar track_color = track.get_color();
        const std::string text =
            "ID:" + std::to_string(track.get_id()) + " " + track.get_class_name();
        const cv::Point anchor(track_box.x, track_box.y - 15);

        cv::rectangle(frame, track_box, track_color, 3);
        cv::putText(frame, text, anchor,
                    cv::FONT_HERSHEY_SIMPLEX, 0.6, track_color, 2);
    }
}
// ============================================================================
// 主测试函数
// ============================================================================
int main()
{
std::cout << "========================================" << std::endl;
std::cout << "SORT 多目标跟踪算法测试" << std::endl;
std::cout << "========================================" << std::endl;
std::cout << std::endl;
// 创建 SORT 跟踪器
SORT tracker(0.3f, 5, 2);
// 模拟 100 帧数据
int total_frames = 100;
std::cout << "开始模拟跟踪..." << std::endl;
std::cout << std::endl;
std::cout << "SORT参数: IOU阈值=" << 0.3f << ", 最大年龄=" << 5 << ", 最小命中数=" << 2 << std::endl;
std::cout << "========================================" << std::endl;
std::cout << "极限测试用例说明:" << std::endl;
std::cout << " - Frame 16: 突然出现一次的obj (noise_once)" << std::endl;
std::cout << " - Frame 26-27: 连续出现两次的obj (noise_twice)" << std::endl;
std::cout << " - Frame 36-38: 连续出现三次的obj (noise_thrice)" << std::endl;
std::cout << " - Frame 51: 突然出现一次的obj (noise_once_2)" << std::endl;
std::cout << " - Frame 61-62: 连续出现两次的obj (noise_twice_2)" << std::endl;
std::cout << " - Frame 71-73: 连续出现三次的obj (noise_thrice_2)" << std::endl;
std::cout << "========================================" << std::endl;
std::cout << std::endl;
for (int frame_idx = 0; frame_idx < total_frames; ++frame_idx)
{
// 生成模拟检测数据
std::vector<Detection> detections = generate_simulated_detections(frame_idx);
// 更新跟踪器
std::vector<Track> tracks = tracker.update(detections);
// 打印跟踪信息
if (frame_idx % 10 == 0 || frame_idx == total_frames - 1 ||
frame_idx == 15 || frame_idx == 16 || // 突然出现一次
frame_idx == 25 || frame_idx == 26 || frame_idx == 27 || // 连续出现两次
frame_idx == 35 || frame_idx == 36 || frame_idx == 37 || frame_idx == 38 || // 连续出现三次
frame_idx == 50 || frame_idx == 51 || // 突然出现一次
frame_idx == 60 || frame_idx == 61 || frame_idx == 62 || // 连续出现两次
frame_idx == 70 || frame_idx == 71 || frame_idx == 72 || frame_idx == 73) // 连续出现三次
{
std::cout << "Frame " << std::setw(3) << frame_idx + 1 << ": "
<< "Detections=" << std::setw(2) << detections.size()
<< ", Active Tracks=" << std::setw(2) << tracks.size();
if (!tracks.empty())
{
std::cout << " [";
for (size_t i = 0; i < tracks.size(); ++i)
{
if (i > 0) std::cout << ", ";
std::cout << "ID" << tracks[i].get_id();
}
std::cout << "]";
}
// 打印检测到的特殊对象
bool has_special = false;
for (const auto& det : detections)
{
if (det.className.find("noise") != std::string::npos)
{
if (!has_special)
{
std::cout << " | 特殊检测: ";
has_special = true;
}
else
{
std::cout << ", ";
}
std::cout << det.className;
}
}
std::cout << std::endl;
}
}
std::cout << std::endl;
std::cout << "========================================" << std::endl;
std::cout << "测试完成!" << std::endl;
std::cout << "========================================" << std::endl;
return 0;
}
运行结果
cd /home/neardi/Desktop/project/01_das/build && make test_kalman -j4 && ./test_kalman
Consolidate compiler generated dependencies of target test_kalman
[ 50%] Building CXX object CMakeFiles/test_kalman.dir/test_kalman.cpp.o
[100%] Linking CXX executable test_kalman
[100%] Built target test_kalman
========================================
SORT 多目标跟踪算法测试
========================================
开始模拟跟踪...
SORT参数: IOU阈值=0.3, 最大年龄=5, 最小命中数=2
========================================
极限测试用例说明:
- Frame 16: 突然出现一次的obj (noise_once)
- Frame 26-27: 连续出现两次的obj (noise_twice)
- Frame 36-38: 连续出现三次的obj (noise_thrice)
- Frame 51: 突然出现一次的obj (noise_once_2)
- Frame 61-62: 连续出现两次的obj (noise_twice_2)
- Frame 71-73: 连续出现三次的obj (noise_thrice_2)
========================================
Frame 1: Detections= 1, Active Tracks= 0
Frame 11: Detections= 2, Active Tracks= 0
Frame 16: Detections= 3, Active Tracks= 0 | 特殊检测: noise_once
Frame 17: Detections= 2, Active Tracks= 0
Frame 21: Detections= 3, Active Tracks= 1 [ID25]
Frame 26: Detections= 4, Active Tracks= 2 [ID32, ID34] | 特殊检测: noise_twice
Frame 27: Detections= 4, Active Tracks= 3 [ID32, ID34, ID45] | 特殊检测: noise_twice
Frame 28: Detections= 3, Active Tracks= 3 [ID32, ID34, ID45]
Frame 31: Detections= 4, Active Tracks= 2 [ID34, ID45]
Frame 36: Detections= 5, Active Tracks= 1 [ID56] | 特殊检测: noise_thrice
Frame 37: Detections= 5, Active Tracks= 2 [ID56, ID73] | 特殊检测: noise_thrice
Frame 38: Detections= 5, Active Tracks= 2 [ID56, ID73] | 特殊检测: noise_thrice
Frame 39: Detections= 4, Active Tracks= 2 [ID56, ID73]
Frame 41: Detections= 5, Active Tracks= 2 [ID56, ID73]
Frame 51: Detections= 6, Active Tracks= 1 [ID56] | 特殊检测: noise_once_2
Frame 52: Detections= 5, Active Tracks= 1 [ID56]
Frame 61: Detections= 5, Active Tracks= 1 [ID56] | 特殊检测: noise_twice_2
Frame 62: Detections= 5, Active Tracks= 3 [ID56, ID168, ID171] | 特殊检测: noise_twice_2
Frame 63: Detections= 4, Active Tracks= 3 [ID56, ID168, ID171]
Frame 71: Detections= 4, Active Tracks= 2 [ID168, ID179] | 特殊检测: noise_thrice_2
Frame 72: Detections= 4, Active Tracks= 3 [ID168, ID197, ID201] | 特殊检测: noise_thrice_2
Frame 73: Detections= 4, Active Tracks= 3 [ID168, ID197, ID201] | 特殊检测: noise_thrice_2
Frame 74: Detections= 3, Active Tracks= 3 [ID168, ID197, ID201]
Frame 81: Detections= 2, Active Tracks= 2 [ID168, ID217]
Frame 91: Detections= 1, Active Tracks= 2 [ID229, ID234]
Frame 100: Detections= 1, Active Tracks= 0
========================================
测试完成!
========================================

浙公网安备 33010602011771号