The Principle of Monocular Triangulation

 

 


Why does CV use normalized coordinates?
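One way to see the point of this question: multiplying a pixel by K⁻¹ strips out the camera intrinsics, leaving a normalized coordinate that depends only on the 3D geometry, so epipolar and triangulation formulas become camera-independent. A minimal sketch (the intrinsic values below are made up for illustration):

```python
import numpy as np

# Intrinsics K (values are made up for illustration)
K = np.array([[1193.7, 0.0, 910.9],
              [0.0, 1193.2, 602.2],
              [0.0, 0.0, 1.0]])

def pixel_to_normalized(u, v, K):
    """Map a pixel (u, v) to normalized camera coordinates.

    Multiplying by K^-1 strips the intrinsics, leaving
    x_n = (X/Z, Y/Z, 1), which depends only on the 3D geometry."""
    p = np.linalg.inv(K) @ np.array([u, v, 1.0])
    return p / p[2]

x_n = pixel_to_normalized(960.0, 540.0, K)
print(x_n)  # [(u - cx)/fx, (v - cy)/fy, 1]
```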


 

==================================


Why is the last row of Vᵀ (i.e. the last column of V) the solution?
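This can be made concrete with a small DLT triangulation: stacking the two-view constraints gives a homogeneous system A·X = 0, and over unit-norm X the residual ‖A·X‖ is minimized by the right singular vector belonging to the smallest singular value. NumPy's `svd` returns that vector as the last row of `vt` (the last column of V). A self-contained sketch with a synthetic point:

```python
import numpy as np

def triangulate_dlt(P1, P2, x1, x2):
    """Linear triangulation: each view contributes two rows of A
    from x x (P X) = 0, giving the homogeneous system A X = 0."""
    A = np.vstack([
        x1[0] * P1[2] - P1[0],
        x1[1] * P1[2] - P1[1],
        x2[0] * P2[2] - P2[0],
        x2[1] * P2[2] - P2[1],
    ])
    # numpy returns V^T with singular values in descending order, so
    # vt[-1] (last row of V^T = last column of V) is the unit vector
    # minimizing ||A X|| -- the least-squares null vector of A.
    _, _, vt = np.linalg.svd(A)
    X = vt[-1]
    return X[:3] / X[3]

# Synthetic check: a known 3D point seen from two poses
P1 = np.hstack([np.eye(3), np.zeros((3, 1))])                  # [I | 0]
P2 = np.hstack([np.eye(3), np.array([[-1.0], [0.0], [0.0]])])  # 1-unit baseline
Xw = np.array([0.5, 0.2, 4.0, 1.0])
x1 = P1 @ Xw; x1 /= x1[2]
x2 = P2 @ Xw; x2 /= x2[2]
X_rec = triangulate_dlt(P1, P2, x1, x2)
print(X_rec)  # recovers [0.5, 0.2, 4.0]
```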

 


After the first and second frames have produced map points, when the third frame arrives, how do we localize it and merge the new map points it creates into the existing map's coordinate frame? Note that the scale of the relative pose between frames 2 and 3 is not necessarily consistent with that between frames 1 and 2. Refer to how ORB-SLAM2 solves this.

 


This is a crucial implementation detail. The process is known as "map-point projection matching," a coarse-to-fine search. ORB-SLAM2 completes the matching efficiently and accurately through a set of carefully designed strategies. The core idea: exploit the known geometry to drastically shrink the search region, rather than brute-force matching over the entire image. The detailed steps are as follows:

Step 1: Projection and search-window construction

  1. Project the map points:
    • For each map point P_w in the local map (a 3D point in world coordinates), use the third frame's initial pose estimate T_cw (from the constant-velocity model or the previous frame's pose) to project it onto the third frame's image plane.
    • Projection: p_pred = π(T_cw · P_w), where π is the camera projection model.
  2. Determine the search window and scale:
    • Search window: both the pose estimate and the map point's position are uncertain, so the projection p_pred is not necessarily the exact match location. The system therefore builds a rectangular search window centered on p_pred.
    • Scale uncertainty: a map point was observed at the scale of the keyframe that created it. Because the third frame may view the point from a different distance, the same map point can appear at a different image-pyramid level (scale) in different frames.
    • Scale prediction: ORB-SLAM2 predicts which pyramid level the point should appear at in the third frame from its distance d to the frame's optical center. Each map point stores the reference observation distance d_ref and the pyramid level l_ref from its creation.
      • The predicted level is roughly l_pred = l_ref + log2(d_ref / d).
      • l_pred determines which pyramid level the search runs on, and also the search-window radius (higher levels typically use a smaller radius).
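The projection and scale-prediction step above can be sketched as follows. This is a minimal illustration, not ORB-SLAM2's actual code: the function name and the base-1.2 pyramid (matching a typical ORB extractor `scaleFactor`) are assumptions.

```python
import numpy as np

def project_and_predict_scale(Pw, Tcw, K, d_ref, l_ref,
                              scale_factor=1.2, n_levels=8):
    """Project a world point with the pose estimate T_cw (4x4) and
    predict its ORB pyramid level from the viewing distance.

    d_ref / l_ref are the reference distance and level stored with the
    map point; scale_factor matches the ORB pyramid (assumed 1.2)."""
    Pc = Tcw[:3, :3] @ Pw + Tcw[:3, 3]          # camera-frame point
    if Pc[2] <= 0:
        return None, None                        # behind the camera
    uv = (K @ (Pc / Pc[2]))[:2]                  # pinhole projection p_pred
    d = np.linalg.norm(Pc)                       # current viewing distance
    # Closer than the reference -> the point looks larger -> coarser level
    l_pred = l_ref + np.log(d_ref / d) / np.log(scale_factor)
    return uv, int(np.clip(round(l_pred), 0, n_levels - 1))

# A point straight ahead at its reference distance keeps its level
K = np.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
uv, lvl = project_and_predict_scale(np.array([0.0, 0.0, 4.0]), np.eye(4), K, 4.0, 2)
print(uv, lvl)  # [640. 360.] 2
```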

Step 2: Matching inside the search window (descriptor distances)

The task is now: on the third frame, within a small region centered on the projection p_pred and near the predicted level l_pred, find the feature point whose ORB descriptor is most similar to the map point's descriptor.

  1. Get the map point's descriptor: each map point stores one representative ORB descriptor — in ORB-SLAM2, the observed descriptor with the smallest median Hamming distance to all the other observations of that point — which makes it robust to per-frame noise.
  2. Enumerate candidate keypoints inside the window:
    • On the third frame, collect all keypoints that lie inside the search window and whose pyramid level is close to the predicted level (e.g. l_pred ± 1). These are the potential match candidates.
  3. Compute descriptor distances:
    • Compare the map point's descriptor against every candidate keypoint's descriptor.
    • ORB descriptors are binary (256 bits), so they are compared with the Hamming distance — the number of differing bits. The smaller the distance, the more similar the descriptors.
  4. Select the best match:
    • Among all candidates, take the keypoint with the smallest Hamming distance to the map point's descriptor.
    • To ensure quality, this minimum distance must fall below a preset threshold (e.g. 50), and a nearest/second-nearest ratio test is applied: the best distance must be clearly smaller than the second-best (e.g. best < 0.8 × second-best) to reject ambiguous matches.
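Steps 1–4 above can be sketched as a windowed search. A minimal illustration only: the function name, the array layout (`kps` as an (N, 2) array, 32-byte ORB descriptors), and the default thresholds are assumptions matching the text, not ORB-SLAM2's code.

```python
import numpy as np

def search_in_window(mp_desc, mp_level, p_pred, kps, descs, levels,
                     radius=15.0, max_dist=50, nn_ratio=0.8):
    """Find the best match for one map point inside its search window.

    kps: (N, 2) keypoint pixel positions; descs: (N, 32) uint8 ORB
    descriptors; levels: pyramid level of each keypoint. The thresholds
    (window radius, Hamming 50, ratio 0.8) follow the text above."""
    best, second, best_idx = 256, 256, -1
    for i in range(len(kps)):
        # Geometric gating: inside the window, near the predicted level
        if np.linalg.norm(kps[i] - p_pred) > radius:
            continue
        if abs(levels[i] - mp_level) > 1:
            continue
        # Hamming distance between the 256-bit binary descriptors
        d = int(np.unpackbits(mp_desc ^ descs[i]).sum())
        if d < best:
            best, second, best_idx = d, best, i
        elif d < second:
            second = d
    # Absolute threshold plus nearest/second-nearest ratio test
    if best_idx >= 0 and best < max_dist and best < nn_ratio * second:
        return best_idx, best
    return -1, None

# Two keypoints; only the one inside the window can match
kps = np.array([[100.0, 100.0], [300.0, 300.0]])
descs = np.zeros((2, 32), dtype=np.uint8)
levels = [2, 2]
mp_desc = np.zeros(32, dtype=np.uint8)
idx, dist = search_in_window(mp_desc, 2, np.array([102.0, 101.0]), kps, descs, levels)
print(idx, dist)  # 0 0
```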

Step 3: Geometric verification and outlier rejection

Correspondences obtained from descriptor matching may still contain mismatches (outliers). ORB-SLAM2 performs strict geometric verification in the subsequent pose-optimization step.

  1. Direct linear transform (DLT) or RANSAC-PnP: in the initial matching stage, PnP inside RANSAC is sometimes used to quickly reject gross mismatches.
  2. Motion-only BA (reprojection-error optimization): this is the core step. Given an initial set of matches, the system solves
    min over T_cw of Σ_i ρ( ‖ p_i − π(T_cw · P_i^w) ‖² )
    • During the optimization, matches with excessive reprojection error are flagged as outliers. The optimizer (e.g. g2o) applies a robust kernel (such as the Huber kernel) to suppress their influence.
    • After optimization, matches whose error remains large are discarded outright.
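The robust kernel ρ in the objective above can be illustrated numerically. A minimal Huber-cost sketch: the threshold δ² = 5.991 is the chi-square 95% value for 2 degrees of freedom, which ORB-SLAM2 uses for monocular reprojection edges.

```python
import numpy as np

def huber(e2, delta=np.sqrt(5.991)):
    """Huber cost on a squared reprojection error e2 (pixels^2).

    Quadratic for small residuals, linear beyond delta, so a gross
    outlier no longer dominates the sum. delta^2 = 5.991 is the
    chi-square 95% threshold for 2 degrees of freedom."""
    e = np.sqrt(e2)
    if e <= delta:
        return e2
    return 2.0 * delta * e - delta * delta

# A 1 px inlier vs. a 20 px outlier: squared loss would weight the
# outlier 400x more; the Huber cost caps its influence near-linearly.
print(huber(1.0), huber(400.0))
```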

 

Reference code

 

 

'''

# Create the conda environment
conda create --name yuyislam python=3.10

# Activate it
conda activate yuyislam

# Install PyTorch with CUDA 11.8 wheels
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install OpenCV and other dependencies
conda install opencv numpy pandas matplotlib pyyaml


# Install an Open3D build compatible with Python 3.10 via pip
pip install open3d


# g2o Python bindings for bundle adjustment
pip install -U g2o-python


# Alternatively, create everything from an environment.yml file
conda env create -f environment.yml
conda activate yuyislam


data_path = "/home/dongdong/2project/0data/RTK/data_1_nwpuUp/data3_1130_13pm/300_location_14pm"  # change to your data path

'''
import os
import time
import pickle
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Set

import cv2
import numpy as np
import yaml
import open3d as o3d
import matplotlib.pyplot as plt
import g2o  # g2o-python bindings, required by BundleAdjustment

class Camera:
    def __init__(self, config_path: str):
        self.load_config(config_path)
        self.K = np.array([
            [self.fx, 0, self.cx],
            [0, self.fy, self.cy],
            [0, 0, 1]
        ])
        # Distortion coefficients in the (1, 5) shape OpenCV expects
        self.dist_coeffs = np.zeros((1, 5), dtype=np.float64)
        self.dist_coeffs[0, 0] = self.k1
        self.dist_coeffs[0, 1] = self.k2
        self.dist_coeffs[0, 2] = self.p1
        self.dist_coeffs[0, 3] = self.p2
        self.dist_coeffs[0, 4] = self.k3
        
        print(f"Camera intrinsics loaded:")
        print(f"  Name: {self.name}")
        print(f"  Model: {self.model}")
        print(f"  Image size: {self.cols} x {self.rows}")
        print(f"  Focal length: fx={self.fx:.2f}, fy={self.fy:.2f}")
        print(f"  Principal point: cx={self.cx:.2f}, cy={self.cy:.2f}")
        print(f"  Radial distortion: k1={self.k1}, k2={self.k2}, k3={self.k3}, k4={self.k4}")
        print(f"  Tangential distortion: p1={self.p1}, p2={self.p2}")
        print(f"  Color order: {self.color_order}")
        print(f"  Frame rate: {self.fps} FPS")
        print(f"  Distortion array shape: {self.dist_coeffs.shape}")
    
    def load_config(self, config_path: str):
        """Load camera intrinsics from a YAML file."""
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Camera config file not found: {config_path}")
        
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
        
        print("Raw config file contents:")
        print(config)
        
        # Read the 'Camera' section if present; fall back to the top level
        if 'Camera' in config:
            camera_config = config['Camera']
        else:
            camera_config = config
        
        # Read intrinsics stored under flat 'Camera.*' keys
        self.cols = camera_config.get('Camera.cols', 1805)
        self.rows = camera_config.get('Camera.rows', 1203)
        self.cx = camera_config.get('Camera.cx', 910.8785687265895)
        self.cy = camera_config.get('Camera.cy', 602.174293145834)
        self.fx = camera_config.get('Camera.fx', 1193.7076128098686)
        self.fy = camera_config.get('Camera.fy', 1193.1735265967602)
        self.k1 = camera_config.get('Camera.k1', 0.0)
        self.k2 = camera_config.get('Camera.k2', 0.0)
        self.k3 = camera_config.get('Camera.k3', 0.0)
        self.k4 = camera_config.get('Camera.k4', 0.0)
        self.p1 = camera_config.get('Camera.p1', 0.0)
        self.p2 = camera_config.get('Camera.p2', 0.0)
        self.model = camera_config.get('Camera.model', 'perspective')
        self.color_order = camera_config.get('Camera.color_order', 'RGB')
        self.name = camera_config.get('Camera.name', 'NWPU monocular')
        self.fps = camera_config.get('Camera.fps', 10)

class MapPoint:
    def __init__(self, point_3d: np.ndarray, descriptor: np.ndarray, creating_kf_id: int):
        self.point_3d = point_3d.copy()  # 3D position (world frame)
        self.descriptor = descriptor.copy() if descriptor is not None else None  # ORB descriptor
        self.observations = []  # observations as (frame_id, kp_index) pairs
        self.observing_keyframes = set()  # IDs of keyframes observing this point
        self.track_count = 0  # number of times this point has been tracked
        self.is_new = True  # newly created flag
        self.creating_kf_id = creating_kf_id  # ID of the keyframe that created this point
    
    def add_observation(self, frame_id: int, kp_index: int):
        """Register an observation of this point from a frame."""
        if frame_id not in self.observing_keyframes:
            self.observations.append((frame_id, kp_index))
            self.observing_keyframes.add(frame_id)
            self.track_count += 1
            if len(self.observing_keyframes) > 1:
                self.is_new = False

class KeyFrame:
    def __init__(self, frame_id: int, image: np.ndarray, keypoints: List[cv2.KeyPoint], 
                 descriptors: np.ndarray, pose: np.ndarray):
        self.frame_id = frame_id
        self.image = image
        self.keypoints = keypoints  # ORB keypoints
        self.descriptors = descriptors  # ORB descriptors
        self.pose = pose.copy()  # 4x4 transform (world frame)
        self.map_points = []  # indices of associated map points
        self.map_point_indices = {}  # keypoint index -> map point index

class BundleAdjustment:
    """Bundle adjustment (poses + structure) via g2o."""
    
    @staticmethod
    def optimize_pose_and_points(keyframes: List[KeyFrame], map_points: List[MapPoint], 
                                camera_K: np.ndarray, max_iterations: int = 10):
        """Optimize keyframe poses and map points with g2o."""
        if len(keyframes) < 2 or len(map_points) == 0:
            return keyframes, map_points
        
        try:
            # Build the sparse optimizer (Levenberg-Marquardt, SE3 block solver)
            optimizer = g2o.SparseOptimizer()
            solver = g2o.BlockSolverSE3(g2o.LinearSolverCSparseSE3())
            solver = g2o.OptimizationAlgorithmLevenberg(solver)
            optimizer.set_algorithm(solver)
            
            # Monocular camera parameters
            cam = g2o.CameraParameters(
                camera_K[0, 0],  # fx
                [camera_K[0, 2], camera_K[1, 2]],  # cx, cy
                0  # baseline (0 for monocular)
            )
            cam.set_id(0)
            optimizer.add_parameter(cam)
            
            # Keyframe pose vertices
            frame_vertices = {}
            for i, kf in enumerate(keyframes):
                v_se3 = g2o.VertexSE3Expmap()
                v_se3.set_id(i)
                v_se3.set_estimate(g2o.SE3Quat(kf.pose[:3, :3], kf.pose[:3, 3]))
                v_se3.set_fixed(i == 0)  # fix the first frame to anchor the gauge
                optimizer.add_vertex(v_se3)
                frame_vertices[kf.frame_id] = v_se3
            
            # Map point vertices
            point_vertices = {}
            for i, mp in enumerate(map_points):
                v_p = g2o.VertexPointXYZ()
                v_p.set_id(i + len(keyframes))
                v_p.set_estimate(mp.point_3d)
                v_p.set_marginalized(True)
                v_p.set_fixed(False)
                optimizer.add_vertex(v_p)
                point_vertices[i] = v_p
            
            # Observation edges (one per map-point sighting)
            edge_id = 0
            for mp_idx, mp in enumerate(map_points):
                for frame_id, kp_idx in mp.observations:
                    if frame_id in frame_vertices:
                        # locate the observing keyframe
                        kf = next((kf for kf in keyframes if kf.frame_id == frame_id), None)
                        if kf is not None and kp_idx < len(kf.keypoints):
                            # reprojection-error edge
                            edge = g2o.EdgeProjectXYZ2UV()
                            edge.set_vertex(0, point_vertices[mp_idx])
                            edge.set_vertex(1, frame_vertices[frame_id])
                            edge.set_measurement([kf.keypoints[kp_idx].pt[0], kf.keypoints[kp_idx].pt[1]])
                            edge.set_information(np.eye(2))
                            edge.set_parameter_id(0, 0)
                            edge.set_id(edge_id)
                            optimizer.add_edge(edge)
                            edge_id += 1
            
            # Run the optimization
            print(f"Starting g2o optimization: {len(keyframes)} keyframes, {len(map_points)} map points, {edge_id} edges")
            optimizer.initialize_optimization()
            optimizer.optimize(max_iterations)
            
            # Write the optimized poses and map points back
            for i, kf in enumerate(keyframes):
                if kf.frame_id in frame_vertices:
                    v_se3 = frame_vertices[kf.frame_id]
                    optimized_pose = v_se3.estimate()
                    kf.pose[:3, :3] = optimized_pose.rotation().matrix()
                    kf.pose[:3, 3] = optimized_pose.translation()
            
            for i, mp in enumerate(map_points):
                if i in point_vertices:
                    v_p = point_vertices[i]
                    mp.point_3d = v_p.estimate()
            
            print("g2o optimization finished")
            
        except Exception as e:
            print(f"g2o optimization failed: {e}")
        
        return keyframes, map_points

class MonoSLAM:
    def __init__(self, config_path: str, image_folder: str):
        self.camera = Camera(config_path)
        self.image_folder = image_folder
        
        if not os.path.exists(image_folder):
            raise FileNotFoundError(f"Image folder not found: {image_folder}")
        
        # ORB feature extractor
        self.orb = cv2.ORB_create(nfeatures=2000, scaleFactor=1.2, nlevels=8)
        self.bf_matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
        
        # SLAM state
        self.keyframes: List[KeyFrame] = []
        self.map_points: List[MapPoint] = []
        self.current_frame_id = 0
        self.world_pose = np.eye(4)  # current world pose (first frame is the origin)
        
        # Visualization
        self.vis = None
        self.point_cloud = None
        self.camera_trajectory = None
        
        # Parameters
        self.min_matches = 50
        self.keyframe_threshold = 0.1
        self.pnp_reprojection_threshold = 3.0
        self.descriptor_match_threshold = 50  # Hamming distance threshold for descriptor matches
        
        print("Monocular SLAM system initialized")
    
    def init_visualization(self):
        """Initialize the Open3D visualization window."""
        if self.vis is None:
            self.vis = o3d.visualization.Visualizer()
            self.vis.create_window(window_name='MonoSLAM', width=1200, height=800)
            
            # Set the initial viewpoint
            ctr = self.vis.get_view_control()
            ctr.set_front([0, 0, -1])
            ctr.set_lookat([0, 0, 0])
            ctr.set_up([0, -1, 0])
            ctr.set_zoom(0.8)
    
    def load_images(self) -> List[str]:
        """Load and numerically sort the image files."""
        image_files = [f for f in os.listdir(self.image_folder) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        def extract_number(filename):
            numbers = ''.join(filter(str.isdigit, os.path.splitext(filename)[0]))
            return int(numbers) if numbers else 0
        
        image_files.sort(key=extract_number)
        full_paths = [os.path.join(self.image_folder, f) for f in image_files]
        
        print(f"Found {len(image_files)} images")
        return full_paths
    
    def extract_features(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
        """Extract ORB keypoints and descriptors."""
        if len(image.shape) == 3:
            if self.camera.color_order == 'RGB':
                image_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            else:
                image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image_gray = image
        
        keypoints, descriptors = self.orb.detectAndCompute(image_gray, None)
        
        if descriptors is None:
            descriptors = np.array([])
        
        print(f"Extracted {len(keypoints)} keypoints")
        return keypoints, descriptors
    
    def match_features(self, desc1: np.ndarray, desc2: np.ndarray, ratio_test=0.8) -> List[cv2.DMatch]:
        """Match descriptors with kNN and a ratio test."""
        if desc1 is None or desc2 is None or len(desc1) == 0 or len(desc2) == 0:
            return []
        
        # kNN matching (k=2) for the ratio test
        matches = self.bf_matcher.knnMatch(desc1, desc2, k=2)
        
        # Lowe's ratio test
        good_matches = []
        for match_pair in matches:
            if len(match_pair) == 2:
                m, n = match_pair
                if m.distance < ratio_test * n.distance:
                    good_matches.append(m)
        
        return good_matches
    
    def match_with_map_points(self, descriptors: np.ndarray) -> List[Tuple[int, int]]:
        """Match current-frame descriptors against map-point descriptors."""
        if len(self.map_points) == 0 or descriptors is None or len(descriptors) == 0:
            return []
        
        # Collect the descriptors of all valid map points
        map_descriptors = []
        valid_map_indices = []
        
        for i, mp in enumerate(self.map_points):
            if mp.descriptor is not None and len(mp.descriptor) > 0:
                map_descriptors.append(mp.descriptor)
                valid_map_indices.append(i)
        
        if len(map_descriptors) == 0:
            return []
        
        map_descriptors = np.array(map_descriptors)
        
        # Match frame descriptors against map-point descriptors
        matches = self.bf_matcher.match(descriptors, map_descriptors)
        
        # Keep matches under the Hamming distance threshold
        good_matches = []
        for match in matches:
            if match.distance < self.descriptor_match_threshold:
                map_point_idx = valid_map_indices[match.trainIdx]
                good_matches.append((match.queryIdx, map_point_idx))
        
        print(f"Map-point matches: {len(good_matches)}/{len(descriptors)}")
        return good_matches
    
    def initialize_map(self, image1: np.ndarray, image2: np.ndarray) -> bool:
        """Initialize the map: triangulate frames 1 and 2 into initial map points."""
        print("=== Initializing map ===")
        
        # Features of the first frame
        kps1, descs1 = self.extract_features(image1)
        if len(kps1) < self.min_matches:
            print("Too few keypoints in frame 1; cannot initialize")
            return False
        
        # Features of the second frame
        kps2, descs2 = self.extract_features(image2)
        if len(kps2) < self.min_matches:
            print("Too few keypoints in frame 2; cannot initialize")
            return False
        
        # Match features
        matches = self.match_features(descs1, descs2)
        print(f"Initialization matches: {len(matches)}")
        
        if len(matches) < self.min_matches:
            print("Too few matches; cannot initialize")
            return False
        
        # Estimate the relative pose from the essential matrix
        pts1 = np.float32([kps1[m.queryIdx].pt for m in matches])
        pts2 = np.float32([kps2[m.trainIdx].pt for m in matches])
        
        E, mask = cv2.findEssentialMat(pts1, pts2, self.camera.K, method=cv2.RANSAC, prob=0.999, threshold=1.0)
        if E is None or mask.sum() < self.min_matches:
            print("Essential matrix estimation failed")
            return False
        
        _, R, t, mask = cv2.recoverPose(E, pts1, pts2, self.camera.K, mask=mask)
        inlier_indices = [i for i in range(len(matches)) if mask[i] > 0]
        
        print(f"Pose-estimation inliers: {len(inlier_indices)}/{len(matches)}")
        
        if len(inlier_indices) < self.min_matches:
            print("Too few inliers; cannot initialize")
            return False
        
        # First keyframe (world-frame origin)
        pose1 = np.eye(4)
        kf1 = KeyFrame(0, image1, kps1, descs1, pose1)
        self.keyframes.append(kf1)
        
        # Second keyframe
        pose2 = np.eye(4)
        pose2[:3, :3] = R
        pose2[:3, 3] = t.flatten()
        kf2 = KeyFrame(1, image2, kps2, descs2, pose2)
        self.keyframes.append(kf2)
        
        # Triangulate the initial map points
        inlier_matches = [matches[i] for i in inlier_indices]
        initial_map_points = self.triangulate_points(pose1, pose2, kps1, kps2, descs1, inlier_matches, 0, 1)
        
        if len(initial_map_points) < 10:
            print("Too few triangulated points; initialization failed")
            self.keyframes = []
            return False
        
        # Associate the keyframes with the new map points
        for mp_idx, mp in enumerate(initial_map_points):
            global_mp_idx = len(self.map_points) + mp_idx
            
            # link to keyframe 1
            for obs_kf_id, kp_idx in mp.observations:
                if obs_kf_id == 0:
                    kf1.map_points.append(global_mp_idx)
                    kf1.map_point_indices[kp_idx] = global_mp_idx
                    break
            
            # link to keyframe 2
            for obs_kf_id, kp_idx in mp.observations:
                if obs_kf_id == 1:
                    kf2.map_points.append(global_mp_idx)
                    kf2.map_point_indices[kp_idx] = global_mp_idx
                    break
        
        self.map_points.extend(initial_map_points)
        self.world_pose = pose2
        
        print(f"Map initialized: {len(initial_map_points)} map points")
        return True
    
    def triangulate_points(self, pose1: np.ndarray, pose2: np.ndarray, 
                          kp1: List[cv2.KeyPoint], kp2: List[cv2.KeyPoint],
                          descriptors1: np.ndarray, matches: List[cv2.DMatch], 
                          kf1_id: int, kf2_id: int) -> List[MapPoint]:
        """Triangulate matched keypoints between two keyframes into 3D map points."""
        if len(matches) < 8:
            return []
        
        # Projection matrices from both keyframe poses (T_cw convention),
        # so triangulated points come out in the world frame. Note: the
        # original code hard-coded P1 = K [I | 0], which is only valid
        # when pose1 is the identity (the initialization case).
        P1 = self.camera.K @ pose1[:3, :]
        P2 = self.camera.K @ pose2[:3, :]
        
        # Matched pixel coordinates
        pts1 = np.float32([kp1[m.queryIdx].pt for m in matches])
        pts2 = np.float32([kp2[m.trainIdx].pt for m in matches])
        
        # Triangulate (homogeneous 4-vectors -> Euclidean)
        points_4d = cv2.triangulatePoints(P1, P2, pts1.T, pts2.T)
        points_3d = points_4d[:3] / points_4d[3]
        
        map_points = []
        for i, point in enumerate(points_3d.T):
            if np.any(np.isinf(point)) or np.any(np.isnan(point)):
                continue
            # Cheirality check: positive depth in both camera frames
            z1 = (pose1[:3, :3] @ point + pose1[:3, 3])[2]
            z2 = (pose2[:3, :3] @ point + pose2[:3, 3])[2]
            if z1 <= 0 or z2 <= 0:
                continue
            
            # Descriptor of the observing keypoint in the first keyframe
            if descriptors1 is not None and matches[i].queryIdx < len(descriptors1):
                descriptor = descriptors1[matches[i].queryIdx]
            else:
                descriptor = None
            
            map_point = MapPoint(point.copy(), descriptor, kf2_id)
            
            # Register the observations in both keyframes
            map_point.add_observation(kf1_id, matches[i].queryIdx)
            map_point.add_observation(kf2_id, matches[i].trainIdx)
            
            map_points.append(map_point)
        
        return map_points
    
    def track_with_map_points(self, image: np.ndarray, kps: List[cv2.KeyPoint], desc: np.ndarray) -> bool:
        """Track by matching against existing map points (frame 3 onward)."""
        if len(self.keyframes) < 2 or len(self.map_points) == 0:
            return False
        
        print("=== Tracking against map points ===")
        
        # 1. Match current-frame descriptors against map-point descriptors
        map_matches = self.match_with_map_points(desc)
        print(f"Map-point matching: {len(map_matches)} matches")
        
        if len(map_matches) < self.min_matches:
            print("Too few map-point matches for PnP tracking")
            return False
        
        # 2. Build the 3D-2D correspondences for PnP
        object_points = []
        image_points = []
        matched_map_indices = []
        
        for kp_idx, map_idx in map_matches:
            if map_idx < len(self.map_points):
                object_points.append(self.map_points[map_idx].point_3d)
                image_points.append(kps[kp_idx].pt)
                matched_map_indices.append(map_idx)
        
        if len(object_points) < 6:
            print(f"Too few valid PnP points: {len(object_points)}")
            return False
        
        object_points = np.float32(object_points)
        image_points = np.float32(image_points)
        
        # 3. Solve the camera pose with RANSAC PnP
        try:
            success, rvec, tvec, inliers = cv2.solvePnPRansac(
                object_points, image_points, self.camera.K, self.camera.dist_coeffs,
                reprojectionError=self.pnp_reprojection_threshold, iterationsCount=100,
                flags=cv2.SOLVEPNP_ITERATIVE
            )
        except Exception as e:
            print(f"PnP solve failed: {e}")
            return False
        
        if not success or inliers is None or len(inliers) < 6:
            print("PnP solve failed")
            return False
        
        print(f"PnP succeeded: inliers {len(inliers)}/{len(object_points)}")
        
        # 4. Convert the Rodrigues rotation vector to a matrix
        R, _ = cv2.Rodrigues(rvec)
        
        # 5. Update the current world pose
        new_pose = np.eye(4)
        new_pose[:3, :3] = R
        new_pose[:3, 3] = tvec.flatten()
        self.world_pose = new_pose
        
        # 6. Keyframe decision
        last_kf = self.keyframes[-1]
        if self.is_keyframe(last_kf.pose, new_pose):
            # Create a new keyframe
            kf = KeyFrame(self.current_frame_id, image, kps, desc, new_pose.copy())
            
            # 7. Triangulate new points against the previous keyframe
            kf_matches = self.match_features(last_kf.descriptors, desc)
            if len(kf_matches) >= 8:
                new_map_points = self.triangulate_points(
                    last_kf.pose, kf.pose, last_kf.keypoints, kps, last_kf.descriptors, kf_matches,
                    last_kf.frame_id, kf.frame_id
                )
                
                # Link the new points to their observing keyframes
                for mp_idx, mp in enumerate(new_map_points):
                    global_mp_idx = len(self.map_points) + mp_idx
                    
                    # link to the previous keyframe
                    for obs_kf_id, kp_idx in mp.observations:
                        if obs_kf_id == last_kf.frame_id:
                            last_kf.map_points.append(global_mp_idx)
                            last_kf.map_point_indices[kp_idx] = global_mp_idx
                            break
                    
                    # link to the current keyframe
                    for obs_kf_id, kp_idx in mp.observations:
                        if obs_kf_id == kf.frame_id:
                            kf.map_points.append(global_mp_idx)
                            kf.map_point_indices[kp_idx] = global_mp_idx
                            break
                
                self.map_points.extend(new_map_points)
                print(f"Added {len(new_map_points)} new map points")
            
            self.keyframes.append(kf)
            print(f"Created keyframe {len(self.keyframes)}; total map points: {len(self.map_points)}")
            
            return True
        
        return True
    
    def is_keyframe(self, prev_pose: np.ndarray, curr_pose: np.ndarray) -> bool:
        """Decide whether the current pose warrants a new keyframe."""
        rel_pose = np.linalg.inv(prev_pose) @ curr_pose
        R = rel_pose[:3, :3]
        t = rel_pose[:3, 3]
        
        translation_norm = np.linalg.norm(t)
        rotation_angle = np.arccos(np.clip((np.trace(R) - 1) / 2, -1, 1))
        
        is_kf = translation_norm > self.keyframe_threshold or rotation_angle > 0.1
        
        if is_kf:
            print(f"Keyframe selected: translation={translation_norm:.3f}, rotation={rotation_angle:.3f}")
        
        return is_kf
    
    def update_visualization(self):
        """Refresh the Open3D visualization."""
        if self.vis is None:
            self.init_visualization()
        
        if self.vis is not None:
            self.vis.clear_geometries()
        
        # World coordinate frame
        coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.0)
        if self.vis is not None:
            self.vis.add_geometry(coordinate_frame)
        
        if len(self.keyframes) == 0 or len(self.map_points) == 0:
            if self.vis is not None:
                self.vis.poll_events()
                self.vis.update_renderer()
            return
        
        try:
            # Map-point cloud data
            points = [mp.point_3d for mp in self.map_points]
            
            if len(points) > 0:
                points_arr = np.array(points, dtype=np.float64)
                if len(points_arr.shape) == 2 and points_arr.shape[1] == 3:
                    point_cloud = o3d.geometry.PointCloud()
                    point_cloud.points = o3d.utility.Vector3dVector(points_arr)
                    
                    # Color by observation count
                    colors = []
                    for mp in self.map_points:
                        if mp.track_count == 1:
                            colors.append([0.0, 0.0, 1.0])  # blue
                        elif mp.track_count == 2:
                            colors.append([0.0, 1.0, 1.0])  # cyan
                        else:
                            colors.append([1.0, 0.0, 0.0])  # red
                    
                    point_cloud.colors = o3d.utility.Vector3dVector(np.array(colors))
                    if self.vis is not None:
                        self.vis.add_geometry(point_cloud)
            
            # Camera trajectory
            if len(self.keyframes) > 1:
                trajectory_points = [kf.pose[:3, 3] for kf in self.keyframes]
                lines = [[i, i+1] for i in range(len(trajectory_points)-1)]
                
                camera_trajectory = o3d.geometry.LineSet()
                camera_trajectory.points = o3d.utility.Vector3dVector(np.array(trajectory_points, dtype=np.float64))
                camera_trajectory.lines = o3d.utility.Vector2iVector(lines)
                camera_trajectory.paint_uniform_color([0.0, 1.0, 0.0])
                if self.vis is not None:
                    self.vis.add_geometry(camera_trajectory)
            
            # Keyframe coordinate frames (every 5th and the last)
            for i, kf in enumerate(self.keyframes):
                if i % 5 == 0 or i == len(self.keyframes) - 1:
                    kf_coordinate = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3)
                    kf_coordinate.transform(kf.pose)
                    if self.vis is not None:
                        self.vis.add_geometry(kf_coordinate)
            
            if self.vis is not None:
                self.vis.poll_events()
                self.vis.update_renderer()
                
                if len(points) > 0:
                    points_array = np.array(points)
                    print(f"Point-cloud extent: X[{points_array[:,0].min():.2f}, {points_array[:,0].max():.2f}] "
                          f"Y[{points_array[:,1].min():.2f}, {points_array[:,1].max():.2f}] "
                          f"Z[{points_array[:,2].min():.2f}, {points_array[:,2].max():.2f}]")
            
        except Exception as e:
            print(f"Visualization update failed: {e}")
            import traceback
            traceback.print_exc()
    
    def run_slam(self, max_frames: int = 50):
        """Run the full SLAM pipeline over the image sequence."""
        image_paths = self.load_images()
        if len(image_paths) == 0:
            return
        
        image_paths = image_paths[:max_frames]
        
        if len(image_paths) < 2:
            print("At least 2 images are required for initialization")
            return
        
        print("\n=== Starting SLAM ===")
        
        # Initialization: the first two frames
        image1 = cv2.imread(image_paths[0])
        image2 = cv2.imread(image_paths[1])
        
        if image1 is None or image2 is None:
            print("Could not read the first two images")
            return
        
        # Map initialization
        if not self.initialize_map(image1, image2):
            print("Map initialization failed")
            return
        
        self.current_frame_id = 2
        print("Map initialized; tracking subsequent frames...")
        
        # Tracking: process the remaining frames
        for i in range(2, len(image_paths)):
            image_path = image_paths[i]
            print(f"\n--- Processing frame {i+1}/{len(image_paths)}: {os.path.basename(image_path)} ---")
            
            image = cv2.imread(image_path)
            if image is None:
                continue
            
            kps, desc = self.extract_features(image)
            
            # Track against the map from frame 3 onward
            success = self.track_with_map_points(image, kps, desc)
            
            if success:
                # Refresh the visualization every other keyframe
                if len(self.keyframes) % 2 == 0:
                    self.update_visualization()
            
            self.current_frame_id += 1
            
            # Show the current frame (any key steps, 'q' quits)
            display_image = cv2.drawKeypoints(image, kps, None, color=(0, 255, 0))
            display_image = cv2.resize(display_image, (800, 600))
            cv2.imshow('Current Frame', display_image)
            if cv2.waitKey(0) & 0xFF == ord('q'):
                break
        
        # Final visualization
        self.update_visualization()
        
        print(f"\n=== SLAM finished ===")
        print(f"Keyframes created: {len(self.keyframes)}")
        print(f"Map points created: {len(self.map_points)}")
        
        # Keep the window open
        if self.vis is not None:
            print("Press 'q' to close the window...")
            while True:
                self.vis.poll_events()
                self.vis.update_renderer()
                if cv2.waitKey(100) & 0xFF == ord('q'):
                    break
        
        cv2.destroyAllWindows()
        if self.vis is not None:
            try:
                self.vis.destroy_window()
            except:
                pass

# Usage example
if __name__ == "__main__":
    # Set the data paths
    data_path = "/home/dongdong/2project/0data/RTK/data_1_nwpuUp/data3_1130_13pm/300_location_14pm"
    config_path = os.path.join(data_path, "slam_config", "GNSS_config.yaml")
    image_folder = os.path.join(data_path, "images")
    
    try:
        slam = MonoSLAM(config_path, image_folder)
        slam.run_slam(max_frames=50)
        
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

  

posted on 2025-11-03 18:48 by MKT-porter