• 博客园logo
  • 会员
  • 周边
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅

roma2

 

安装

环境

#============================== 安装  
检查 CUDA 环境(本文使用 CUDA 11.8):
# Point the shell at the CUDA 11.8 toolkit (the same CUDA version as the cu118 PyTorch wheels used below).
export CUDA_HOME=/usr/local/cuda-11.8
export PATH=/usr/local/cuda-11.8/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH

用 `nvcc -V` 查看版本,输出应类似:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

  

安装

# Start from a clean environment: delete any existing "romav2" conda env.
conda remove -n romav2 --all -y

conda create -n romav2 python=3.11 -y
conda activate romav2

# PyTorch 2.5.1 / torchvision 0.20.1 / torchaudio 2.5.1 from the CUDA 11.8 wheel index.
pip install \
    torch==2.5.1 \
    torchvision==0.20.1 \
    torchaudio==2.5.1 \
    --index-url https://download.pytorch.org/whl/cu118



# Force-reinstall the explicit +cu118 builds of the same versions.
pip install --force-reinstall \
torch==2.5.1+cu118 \
torchvision==0.20.1+cu118 \
torchaudio==2.5.1+cu118 \
--index-url https://download.pytorch.org/whl/cu118

  

修改 pyproject.toml:从 dependencies 中去掉默认的 PyTorch 依赖(torch 已在上面手动安装为 cu118 版本),修改后如下:

image

 

# RoMaV2 pyproject.toml as edited for this setup: the default PyTorch
# dependencies were removed from `dependencies`, because torch / torchvision /
# torchaudio are installed manually as the CUDA 11.8 (cu118) wheels.
[project]
name = "romav2"
version = "2.0.1"
description = "RoMa v2: Harder Better Faster Denser Feature Matching"
readme = "README.md"
authors = [
    { name = "Johan Edstedt", email = "johan.edstedt@liu.se" }
]
requires-python = ">=3.10"
dependencies = [
    "einops>=0.8.1",
    "pillow>=12.0.0",
    "rich>=14.2.0",
    "tqdm>=4.67.1",
    # Environment marker: only installed on Linux.
    "fused-local-corr ; sys_platform == 'linux'",
]

[build-system]
requires = ["uv_build>=0.8.15,<0.9.0"]
build-backend = "uv_build"

[project.optional-dependencies]
# Extra packages for evaluation (`pip install -e .[eval]`).
eval = [
    "kornia>=0.8.2",
    "matplotlib>=3.10.7",
    "opencv-python>=4.12.0.88",
    "wandb>=0.23.0",
    "wxbs-benchmark>=0.0.4",
]
dev = [
    "slurm-util>=0.2.7",
    "ruff>=0.14.5",
]

  

 

 然后在包含 pyproject.toml 的目录下以可编辑模式重新安装:

# Editable install of the package whose pyproject.toml is in the current directory.
pip install -e .

  

如果运行时报 cuDNN / CUDA 之类的错误:

确保环境中只有一个版本的 PyTorch,重新强制安装 cu118 版本:

# Force-reinstall the +cu118 builds again so the env ends up with a single torch build.
pip install --force-reinstall \
torch==2.5.1+cu118 \
torchvision==0.20.1+cu118 \
torchaudio==2.5.1+cu118 \
--index-url https://download.pytorch.org/whl/cu118

  

 

 

# 第一次运行时会自动下载模型权重:
Downloading: "https://github.com/Parskatt/RoMaV2/releases/download/v2.0.1/romav2.0.1.pt" to /home/dongdong/.cache/torch/hub/checkpoints/romav2.0.1.pt


测试 1:匹配点连线可视化

image

 

image

 

 

 

from pathlib import Path
import cv2
import numpy as np
import torch
import matplotlib
from PIL import Image

from romav2 import RoMaV2
from romav2.device import device

# -----------------------------
# 1. Image paths
# -----------------------------
img_A_path = '/media/dongdong/新加卷/0ubuntu20/1slam/数据/2RTK/City1-buildings/location21_fog_0325_8pm_133m/images/DJI_00001.jpg'

img_B_path = (
    "/media/dongdong/新加卷/0ubuntu20/1slam/数据/"
    "2RTK/City1-buildings/location11_night_0224_21pm_125m/"
    "pic_0224_night_yintian_2131pm_125/images/DJI_00006.jpg"
)

# -----------------------------
# 2. Load the model
# -----------------------------
model = RoMaV2()
model.apply_setting("precise")
model.eval()

# -----------------------------
# 3. Read images (RGB) and their sizes
# -----------------------------
im1 = Image.open(img_A_path).convert("RGB")
im2 = Image.open(img_B_path).convert("RGB")

H_A, W_A = im1.height, im1.width
H_B, W_B = im2.height, im2.width

# -----------------------------
# 4. Dense matching + sampling of 5000 correspondences
# -----------------------------
preds = model.match(img_A_path, img_B_path)
matches, overlaps, _, _ = model.sample(preds, 5000)

# -----------------------------
# 5. Convert matches to pixel coordinates of each image
# -----------------------------
kptsA, kptsB = model.to_pixel_coordinates(
    matches, H_A, W_A, H_B, W_B
)

# -----------------------------
# 6. Estimate the fundamental matrix with MAGSAC (USAC)
# -----------------------------
kptsA_np = kptsA.cpu().numpy()
kptsB_np = kptsB.cpu().numpy()

F, mask = cv2.findFundamentalMat(
    kptsA_np,
    kptsB_np,
    ransacReprojThreshold=0.2,
    method=cv2.USAC_MAGSAC,
    confidence=0.999999,
    maxIters=10000,
)

print("Fundamental Matrix:\n", F)
print("Inliers:", int(mask.sum()) if mask is not None else 0)

# -----------------------------
# 7. Interactive visualization of the first 100 inliers (colored lines)
# -----------------------------
WINDOW_NAME = "RoMaV2 Matches"

if mask is not None and int(mask.sum()) > 0:
    imgA = np.array(im1)
    imgB = np.array(im2)
    # Side-by-side canvas in RGB; image B is shifted right by W_A pixels.
    vis = np.concatenate([imgA, imgB], axis=1)

    inliers = mask.ravel().astype(bool)
    ptsA = kptsA_np[inliers][:100]
    ptsB = kptsB_np[inliers][:100]

    N = len(ptsA)
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; the colormap
    # registry + resampled() is the supported replacement (N discrete colors).
    cmap = matplotlib.colormaps["hsv"].resampled(N)

    for i, ((x1, y1), (x2, y2)) in enumerate(zip(ptsA, ptsB)):
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        cv2.circle(vis, (int(x1), int(y1)), 3, color, -1)
        cv2.circle(vis, (int(x2 + W_A), int(y2)), 3, color, -1)
        cv2.line(
            vis,
            (int(x1), int(y1)),
            (int(x2 + W_A), int(y2)),
            color,
            1,
            cv2.LINE_AA,
        )

    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    # The canvas is RGB; OpenCV displays BGR. cvtColor yields a contiguous array
    # (a [..., ::-1] view has negative strides).
    cv2.imshow(WINDOW_NAME, cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))

    print(f"[INFO] Showing {N} matches (press ESC to exit)")

    # Poll for ESC with a modest timeout instead of spinning every 1 ms.
    while True:
        key = cv2.waitKey(50) & 0xFF
        if key == 27:  # ESC
            break

    cv2.destroyAllWindows()
else:
    print("⚠️ No inliers found, skip visualization")

  

测试 2:加载正射图,按查询图像的 GPS 裁剪出对应区域,再与查询图像匹配并定位

# Run from the ODM checkout. All data paths below are machine-specific;
# --min-matches 5 relaxes the required number of homography inliers.
cd /home/dongdong/2project/1salm/OpenDroneMap/ODM

python3 examples/localize_query_sift.py \
--query /media/dongdong/新加卷/0ubuntu20/1slam/数据/2RTK/City5_wilderness/location3_cloudy_0306_16pm_135m/images/DJI_00298.jpg \
--ortho "/media/dongdong/新加卷/0ubuntu20/1slam/数据/2RTK/City5_wilderness/map_cloudy_0228_17pm_132/odm_result/odm_orthophoto/odm_orthophoto.tif" \
--dsm "/media/dongdong/新加卷/0ubuntu20/1slam/数据/2RTK/City5_wilderness/map_cloudy_0228_17pm_132/odm_result/odm_dem/dsm.tif" \
--out-dir "/media/dongdong/新加卷/0ubuntu20/1slam/数据/2RTK/City5_wilderness/map_cloudy_0228_17pm_132/localization_test/location_result" \
--gnss-config examples/GNSS_config.yaml \
--min-matches 5

image

image

 

image

 

 

 

 

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Iterable

import cv2
import numpy as np
from PIL import Image
from PIL.ExifTags import GPSTAGS, TAGS
import yaml

try:
    import rasterio
    from rasterio.windows import Window
    from pyproj import Geod, Transformer
except ImportError:
    rasterio = None
    Window = None
    Geod = None
    Transformer = None


def require_geotiff_deps() -> None:
    """Fail fast when the optional GeoTIFF stack could not be imported.

    The module-level ``try`` import sets ``rasterio``, ``Window``, ``Geod`` and
    ``Transformer`` to ``None`` on ``ImportError``; any of them being ``None``
    means the stack is unusable.

    Raises:
        RuntimeError: with an install hint when a dependency is missing.
    """
    missing = any(dep is None for dep in (rasterio, Transformer, Geod, Window))
    if missing:
        raise RuntimeError("Missing dependency. Install with: python3 -m pip install rasterio pyproj")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the localization script.

    Returns:
        The parsed ``argparse.Namespace`` (dashes in option names become
        underscores, e.g. ``--out-dir`` -> ``args.out_dir``).
    """
    parser = argparse.ArgumentParser(
        description="Localize one UAV query image against an ODM orthophoto + DSM using RoMaV2 matching + PnP."
    )
    parser.add_argument("--query", required=True, help="Query UAV image path.")
    parser.add_argument("--ortho", required=True, help="ODM orthophoto GeoTIFF path.")
    parser.add_argument("--dsm", required=True, help="ODM DSM GeoTIFF path.")
    parser.add_argument("--out-dir", required=True, help="Output directory.")
    parser.add_argument("--gnss-config", default=None, help="GNSS_config.yaml path for camera intrinsics.")
    parser.add_argument("--fx", type=float, default=None, help="Override camera focal length fx in pixels.")
    parser.add_argument("--fy", type=float, default=None, help="Override camera focal length fy in pixels.")
    parser.add_argument("--cx", type=float, default=None, help="Override camera principal point cx in pixels.")
    parser.add_argument("--cy", type=float, default=None, help="Override camera principal point cy in pixels.")
    parser.add_argument("--dist", default=None, help="Override distortion coeffs, e.g. k1,k2,p1,p2,k3.")
    parser.add_argument("--query-max-size", type=int, default=1800, help="Max query side before matching.")
    # The RoMaV2 model is never moved to this device; RoMaV2 selects its own
    # device (romav2.device.device), so the value is only recorded in the stats.
    parser.add_argument(
        "--romav2-device",
        default="cuda",
        help="Requested RoMaV2 device, e.g. cuda, cuda:0, cpu (recorded only; RoMaV2 picks its device via romav2.device).",
    )
    parser.add_argument("--romav2-sample-num", type=int, default=5000, help="Number of matches to sample from RoMaV2 dense output.")
    # This threshold feeds cv2.findHomography (homography_inliers), not a fundamental matrix.
    parser.add_argument(
        "--ransac-threshold",
        type=float,
        default=0.2,
        help="Homography RANSAC reprojection threshold in pixels (matching image space).",
    )
    parser.add_argument("--min-matches", type=int, default=20, help="Minimum homography inlier matches.")
    parser.add_argument("--pnp-grid-cols", type=int, default=0, help="Grid columns for PnP sampling. 0 disables grid limiting.")
    parser.add_argument("--pnp-grid-rows", type=int, default=0, help="Grid rows for PnP sampling. 0 disables grid limiting.")
    parser.add_argument("--pnp-max-per-cell", type=int, default=0, help="Max PnP matches kept per grid cell. 0 disables grid limiting.")
    parser.add_argument("--pnp-max-points", type=int, default=0, help="Max PnP matches after filtering. 0 keeps all.")
    parser.add_argument("--no-lk-refine", action="store_true", help="Disable Lucas-Kanade refinement of match endpoints.")
    parser.add_argument("--lk-max-error", type=float, default=20.0, help="Max LK tracking error kept after match refinement.")
    return parser.parse_args()


def find_gnss_config(query_path: Path, explicit_path: str | None) -> Path | None:
    """Locate the GNSS/camera YAML configuration.

    A non-empty ``explicit_path`` always wins and is returned expanded and
    resolved, whether or not it exists. Otherwise the first existing
    ``GNSS_config.yaml`` is returned from, in order: the query image's folder,
    that folder's parent, the current working directory, and
    ``<cwd>/examples``. Returns ``None`` when none of them exists.
    """
    if explicit_path:
        return Path(explicit_path).expanduser().resolve()

    search_dirs = (
        query_path.parent,
        query_path.parent.parent,
        Path.cwd(),
        Path.cwd() / "examples",
    )
    found = next(
        (folder / "GNSS_config.yaml" for folder in search_dirs if (folder / "GNSS_config.yaml").is_file()),
        None,
    )
    return None if found is None else found.resolve()


def load_camera_from_gnss_config(path: Path | None) -> dict:
    """Read camera intrinsics from a GNSS_config.yaml file.

    Looks up the flat keys ``Camera.fx``, ``Camera.fy``, ``Camera.cx``,
    ``Camera.cy``, ``Camera.cols`` and ``Camera.rows``, plus the distortion
    coefficients ``Camera.k1``, ``k2``, ``p1``, ``p2`` and ``k3``. Each missing
    coefficient defaults to 0.0.

    Returns:
        A dict with the keys ``fx``, ``fy``, ``cx``, ``cy``, ``cols``, ``rows``
        and ``dist``. Keys whose value is ``None`` are dropped. An empty dict is
        returned when ``path`` is ``None``.
    """
    if path is None:
        return {}
    with path.open("r", encoding="utf-8") as handle:
        cfg = yaml.safe_load(handle) or {}

    camera = {name: cfg.get(f"Camera.{name}") for name in ("fx", "fy", "cx", "cy", "cols", "rows")}
    camera["dist"] = [cfg.get(f"Camera.{name}", 0.0) for name in ("k1", "k2", "p1", "p2", "k3")]
    return {name: value for name, value in camera.items() if value is not None}


def read_image_for_matching(path: Path, max_size: int) -> tuple[np.ndarray, float]:
    """Load an image as BGR and shrink it so its longer side is at most ``max_size``.

    Images are never upscaled (INTER_AREA is used when shrinking).

    Returns:
        ``(image, scale)`` where ``scale <= 1.0`` is the factor applied to the
        original width/height.

    Raises:
        FileNotFoundError: when OpenCV cannot read the file.
    """
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(path)
    rows, cols = bgr.shape[:2]
    factor = min(1.0, max_size / max(rows, cols))
    if factor >= 1.0:
        return bgr, factor
    target_size = (round(cols * factor), round(rows * factor))
    return cv2.resize(bgr, target_size, interpolation=cv2.INTER_AREA), factor


def read_ortho_crop_near_gps(
    path: Path,
    gps_lat: float,
    gps_lon: float,
    crop_width: int,
    crop_height: int,
) -> tuple[np.ndarray, tuple[int, int], dict]:
    """Cut a ``crop_width`` x ``crop_height`` window of the orthophoto centred on a GPS fix.

    The WGS84 position is transformed into the orthophoto CRS and converted to a
    pixel (row, col) with ``src.index``. The window is read boundless, so parts
    outside the raster are filled with 0.

    Args:
        path: Orthophoto GeoTIFF.
        gps_lat / gps_lon: WGS84 latitude / longitude in degrees.
        crop_width / crop_height: Window size in orthophoto pixels.

    Returns:
        ``(bgr_uint8_crop, (col_off, row_off), crop_meta)``. ``col_off`` and
        ``row_off`` are the window's top-left corner in full-orthophoto pixels.

    Raises:
        RuntimeError: if the orthophoto has no CRS, or if rasterio/pyproj are missing.
    """
    require_geotiff_deps()
    with rasterio.open(path) as src:
        if src.crs is None:
            raise RuntimeError(f"Orthophoto has no CRS: {path}")
        to_ortho = Transformer.from_crs("EPSG:4326", src.crs, always_xy=True)
        x, y = to_ortho.transform(gps_lon, gps_lat)
        center_row, center_col = src.index(x, y)
        col_off = int(round(center_col - crop_width / 2.0))
        row_off = int(round(center_row - crop_height / 2.0))
        window = Window(col_off, row_off, crop_width, crop_height)

        # Three or more bands: treat bands 1-3 as RGB (extra bands such as alpha
        # are ignored). One or two bands (gray, or gray + alpha): only band 1
        # carries image content; a 2-channel stack would break RGB2BGR.
        band_count = 3 if src.count >= 3 else 1
        data = src.read(
            indexes=list(range(1, band_count + 1)),
            window=window,
            boundless=True,
            fill_value=0,
        )
        if band_count == 1:
            gray = normalize_to_uint8(data[0])
            bgr = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        else:
            # Each band is stretched independently to uint8.
            rgb = np.dstack([normalize_to_uint8(data[i]) for i in range(band_count)])
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

        crop_meta = {
            "gps_lon_lat": [float(gps_lon), float(gps_lat)],
            "gps_map_xy": [float(x), float(y)],
            "center_pixel_col_row": [int(center_col), int(center_row)],
            "crop_origin_col_row": [int(col_off), int(row_off)],
            "crop_size_wh": [int(crop_width), int(crop_height)],
            "ortho_size_wh": [int(src.width), int(src.height)],
        }
    return bgr, (col_off, row_off), crop_meta


def normalize_to_uint8(array: np.ndarray) -> np.ndarray:
    """Contrast-stretch a raster band to uint8 using its 1st/99th percentiles.

    ``uint8`` input is returned unchanged (same object). Percentiles are taken
    over finite values only. Values are mapped linearly so that the 1st
    percentile becomes 0 and the 99th becomes 255, then clipped. Non-finite
    pixels (NaN/inf) are written as 0, the same value used as fill for pixels
    outside the raster. An array without finite values yields all zeros.
    """
    if array.dtype == np.uint8:
        return array
    finite_mask = np.isfinite(array)
    finite = array[finite_mask]
    if finite.size == 0:
        return np.zeros(array.shape, dtype=np.uint8)
    low, high = np.percentile(finite, [1, 99])
    if high <= low:
        # Constant band: avoid division by zero.
        high = low + 1
    # Replace non-finite pixels by ``low`` before scaling so they become 0.
    # Casting NaN to uint8 is undefined and raises "invalid value" warnings.
    filled = np.where(finite_mask, array, low)
    return np.clip((filled - low) * 255.0 / (high - low), 0, 255).astype(np.uint8)


def romav2_match_points(
    query_bgr: np.ndarray,
    ortho_bgr: np.ndarray,
    device: str,
    sample_num: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Dense matching with RoMaV2, followed by sparse sampling of correspondences.

    Args:
        query_bgr: Query image in OpenCV BGR order (matching resolution).
        ortho_bgr: Orthophoto crop in OpenCV BGR order (matching resolution).
        device: Requested device string. The model is NOT moved to it; RoMaV2
            runs on ``romav2.device.device``. The request is only recorded in
            the returned stats.
        sample_num: Number of matches requested from ``model.sample``.

    Returns:
        q_pts: (N, 2) query pixel coords (in the resized matching image space).
        o_pts: (N, 2) ortho pixel coords (in the resized matching image space).
        conf:  (N,) per-match ``overlaps`` values returned by ``model.sample``.
        stats: JSON-serializable matching statistics.
    """
    import torch
    from romav2 import RoMaV2
    from romav2.device import device as roma_device

    model = RoMaV2()
    model.apply_setting("precise")
    model.eval()

    H_q, W_q = query_bgr.shape[:2]
    H_o, W_o = ortho_bgr.shape[:2]

    # OpenCV images are BGR, while RGB-trained matchers expect RGB.
    # NOTE(review): assumes RoMaV2.match treats HxWx3 uint8 numpy input as RGB
    # (the same channel order as PIL-loaded images) — confirm against romav2's loader.
    query_rgb = cv2.cvtColor(query_bgr, cv2.COLOR_BGR2RGB)
    ortho_rgb = cv2.cvtColor(ortho_bgr, cv2.COLOR_BGR2RGB)

    # Keep matching, sampling and coordinate conversion inside inference mode,
    # so that no autograd state is recorded for any of them.
    with torch.inference_mode():
        preds = model.match(query_rgb, ortho_rgb)
        matches, overlaps, _, _ = model.sample(preds, sample_num)
        q_pts, o_pts = model.to_pixel_coordinates(matches, H_q, W_q, H_o, W_o)

    q_pts = q_pts.cpu().numpy()
    o_pts = o_pts.cpu().numpy()
    conf = overlaps.detach().cpu().numpy().ravel()

    has_conf = conf.size > 0
    stats = {
        "matcher": "romav2",
        # Strings only: torch.device is not JSON serializable.
        "device": str(roma_device),
        "requested_device": str(device),
        "sampled_matches": int(len(q_pts)),
        "conf_min": float(conf.min()) if has_conf else 0.0,
        "conf_median": float(np.median(conf)) if has_conf else 0.0,
        "conf_max": float(conf.max()) if has_conf else 0.0,
    }
    return q_pts, o_pts, conf, stats


def homography_inliers(q_pts: np.ndarray, o_pts: np.ndarray, threshold: float = 5.0):
    """Fit a query->ortho homography with RANSAC.

    Args:
        q_pts: (N, 2) query points.
        o_pts: (N, 2) ortho points paired row-by-row with ``q_pts``.
        threshold: RANSAC reprojection threshold in pixels.

    Returns:
        ``(h_mat, inlier_mask)``: the 3x3 homography and an (N,) boolean mask.

    Raises:
        RuntimeError: fewer than 4 pairs, mismatched lengths, or RANSAC failure.
    """
    # findHomography needs at least 4 correspondences; report this as the same
    # RuntimeError callers already handle instead of an OpenCV assertion.
    if len(q_pts) < 4 or len(q_pts) != len(o_pts):
        raise RuntimeError(
            f"Homography RANSAC needs >= 4 paired points, got {len(q_pts)} query / {len(o_pts)} ortho."
        )
    h_mat, mask = cv2.findHomography(q_pts, o_pts, cv2.RANSAC, threshold)
    if h_mat is None or mask is None:
        raise RuntimeError("Homography RANSAC failed.")
    inlier_mask = mask.ravel().astype(bool)
    return h_mat, inlier_mask


def refine_ortho_points_lk(
    query_bgr: np.ndarray,
    ortho_bgr: np.ndarray,
    q_pts: np.ndarray,
    o_pts: np.ndarray,
    max_error: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Refine ortho-side match endpoints with pyramidal Lucas-Kanade.

    Each query point is tracked into the ortho image, starting from its matched
    ortho point (``OPTFLOW_USE_INITIAL_FLOW``). A match is kept when LK reports
    success, its tracking error is ``<= max_error``, and the refined point lies
    inside the ortho image.

    Returns:
        ``(kept_q_pts, refined_o_pts, keep_mask, stats)``. ``keep_mask`` is an
        (N,) boolean mask over the input points, and the returned points are
        always exactly the rows it selects. Callers can therefore filter
        parallel arrays (e.g. confidences) with it.
    """
    n_in = int(len(q_pts))
    failed_stats = {
        "enabled": True,
        "input": n_in,
        "kept": 0,
        "max_error": float(max_error),
        "kept_error_median": None,
        "kept_error_max": None,
    }
    if n_in == 0:
        return q_pts, o_pts, np.zeros(0, dtype=bool), failed_stats

    q_gray = cv2.cvtColor(query_bgr, cv2.COLOR_BGR2GRAY)
    o_gray = cv2.cvtColor(ortho_bgr, cv2.COLOR_BGR2GRAY)
    q_init = q_pts.astype(np.float32).reshape(-1, 1, 2)
    o_init = o_pts.astype(np.float32).reshape(-1, 1, 2)
    refined, status, err = cv2.calcOpticalFlowPyrLK(
        q_gray,
        o_gray,
        q_init,
        o_init.copy(),
        winSize=(31, 31),
        maxLevel=3,
        criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 30, 0.01),
        flags=cv2.OPTFLOW_USE_INITIAL_FLOW,
    )
    if refined is None or status is None:
        # LK failed outright: keep nothing. The returned points must agree with
        # the all-False mask, or callers' parallel arrays fall out of sync.
        none_kept = np.zeros(n_in, dtype=bool)
        return q_pts[none_kept], o_pts[none_kept], none_kept, failed_stats

    status_mask = status.ravel().astype(bool)
    err_values = None if err is None else err.ravel().astype(np.float32)
    if err_values is not None:
        status_mask &= err_values <= max_error

    refined_pts = refined.reshape(-1, 2)
    h, w = ortho_bgr.shape[:2]
    status_mask &= (
        (refined_pts[:, 0] >= 0)
        & (refined_pts[:, 0] < w)
        & (refined_pts[:, 1] >= 0)
        & (refined_pts[:, 1] < h)
    )

    # Error statistics only when LK reported errors and something was kept
    # (avoids emitting inf/NaN into the JSON report).
    kept_err = err_values[status_mask] if err_values is not None and np.any(status_mask) else None
    stats = {
        "enabled": True,
        "input": n_in,
        "kept": int(status_mask.sum()),
        "max_error": float(max_error),
        "kept_error_median": None if kept_err is None else float(np.median(kept_err)),
        "kept_error_max": None if kept_err is None else float(np.max(kept_err)),
    }
    return q_pts[status_mask], refined_pts[status_mask], status_mask, stats


def draw_matches_points(
    query_bgr: np.ndarray,
    ortho_bgr: np.ndarray,
    q_pts: np.ndarray,
    o_pts: np.ndarray,
    out_path: Path,
    max_matches: int = 120,
) -> None:
    """Save a ``cv2.drawMatches`` visualization of the first ``max_matches`` pairs.

    Row i of ``q_pts`` is paired with row i of ``o_pts``. Unmatched keypoints
    are not drawn.
    """
    def to_keypoints(points: np.ndarray) -> list:
        # Size 1 is arbitrary; only the location is drawn.
        return [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in points]

    query_kps = to_keypoints(q_pts)
    ortho_kps = to_keypoints(o_pts)
    pair_count = min(len(query_kps), max_matches)
    pairs = [cv2.DMatch(k, k, 0.0) for k in range(pair_count)]
    canvas = cv2.drawMatches(
        query_bgr,
        query_kps,
        ortho_bgr,
        ortho_kps,
        pairs,
        None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )
    cv2.imwrite(str(out_path), canvas)


def draw_region(query_bgr: np.ndarray, ortho_bgr: np.ndarray, h_mat: np.ndarray, out_path: Path) -> None:
    """Outline the query footprint on a copy of the ortho crop and save it.

    The four query image corners are mapped through ``h_mat`` (query -> ortho)
    and drawn as a closed red polygon (BGR (0, 0, 255), 4 px).
    """
    rows, cols = query_bgr.shape[:2]
    outline = np.array([[0, 0], [cols, 0], [cols, rows], [0, rows]], dtype=np.float32).reshape(-1, 1, 2)
    warped = cv2.perspectiveTransform(outline, h_mat)
    overlay = ortho_bgr.copy()
    cv2.polylines(overlay, [warped.astype(np.int32)], isClosed=True, color=(0, 0, 255), thickness=4)
    cv2.imwrite(str(out_path), overlay)


def select_pnp_matches(
    q_pts: np.ndarray,
    o_pts: np.ndarray,
    conf: np.ndarray,
    image_width: int,
    image_height: int,
    min_conf: float,
    grid_cols: int,
    grid_rows: int,
    max_per_cell: int,
    max_points: int,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Choose the correspondences that are passed on to PnP.

    Matches with a non-finite confidence are dropped. When ``min_conf > 0``,
    matches below that confidence are dropped too. The survivors are visited in
    descending confidence order.

    If ``grid_cols``, ``grid_rows`` and ``max_per_cell`` are all positive, the
    query image is split into a ``grid_cols`` x ``grid_rows`` grid and each
    cell keeps at most ``max_per_cell`` matches. A positive ``max_points``
    caps the total.

    Returns:
        ``(q_kept, o_kept, stats)``, ordered by descending confidence.

    Raises:
        RuntimeError: nothing survives confidence filtering, or fewer than 6
            matches remain.
    """
    if len(q_pts) == 0:
        return q_pts, o_pts, {"input": 0, "kept": 0}

    usable = np.isfinite(conf)
    if min_conf > 0:
        usable = usable & (conf >= min_conf)
    if not usable.any():
        raise RuntimeError("No PnP matches remain after confidence filtering.")

    q_ok, o_ok, c_ok = q_pts[usable], o_pts[usable], conf[usable]
    ranking = np.argsort(-c_ok)

    gridded = grid_cols > 0 and grid_rows > 0 and max_per_cell > 0
    if gridded:
        grid_cols, grid_rows, max_per_cell = max(1, grid_cols), max(1, grid_rows), max(1, max_per_cell)

    per_cell: dict[tuple[int, int], int] = {}
    chosen = []
    for idx in ranking:
        if gridded:
            cell = (
                min(grid_cols - 1, max(0, int(q_ok[idx, 0] * grid_cols / image_width))),
                min(grid_rows - 1, max(0, int(q_ok[idx, 1] * grid_rows / image_height))),
            )
            used = per_cell.get(cell, 0)
            if used >= max_per_cell:
                continue
            per_cell[cell] = used + 1
        chosen.append(idx)
        if 0 < max_points <= len(chosen):
            break

    if len(chosen) < 6:
        raise RuntimeError(f"Too few PnP matches after filtering: {len(chosen)}")

    picked = np.asarray(chosen, dtype=np.int64)
    picked_conf = c_ok[picked]
    stats = {
        "input": int(len(q_pts)),
        "after_conf": int(len(q_ok)),
        "kept": int(len(picked)),
        "use_grid": bool(gridded),
        "occupied_cells": int(len(per_cell)),
        "grid_cols": int(grid_cols),
        "grid_rows": int(grid_rows),
        "max_per_cell": int(max_per_cell),
        "max_points": int(max_points),
        "kept_conf_min": float(np.min(picked_conf)),
        "kept_conf_median": float(np.median(picked_conf)),
        "kept_conf_max": float(np.max(picked_conf)),
    }
    return q_ok[picked], o_ok[picked], stats


def parse_dist_coeffs(text: str | None, config_dist: list | None = None) -> np.ndarray:
    """Build the OpenCV distortion vector as a float64 column.

    ``text`` (comma-separated, e.g. ``"k1,k2,p1,p2,k3"``) takes precedence over
    ``config_dist``. A blank ``text`` means "no distortion". When nothing is
    given, five zero coefficients (shape (5, 1)) are returned.
    """
    if text is not None:
        coeffs = [float(token) for token in text.split(",")] if text.strip() else []
    else:
        coeffs = [] if config_dist is None else config_dist
    if not coeffs:
        return np.zeros((5, 1), dtype=np.float64)
    return np.asarray(coeffs, dtype=np.float64).reshape(-1, 1)


def gps_from_exif(path: Path) -> tuple[float, float, float | None] | None:
    """Read the GPS position from an image's EXIF GPS IFD.

    Returns:
        ``(lat, lon, alt)`` with lat/lon in decimal degrees. ``alt`` is the EXIF
        ``GPSAltitude`` (negated when ``GPSAltitudeRef`` is 1), or ``None`` when
        it is absent or not a finite number. Returns ``None`` when the image has
        no EXIF or no latitude/longitude.
    """
    # Close the file handle deterministically (the original left it open).
    with Image.open(path) as image:
        exif = image.getexif()
        if not exif:
            return None
        gps_raw = None
        if hasattr(exif, "get_ifd"):
            gps_raw = exif.get_ifd(34853)  # 0x8825: GPSInfo IFD
        if not gps_raw:
            for tag_id, value in exif.items():
                if TAGS.get(tag_id) == "GPSInfo":
                    gps_raw = value
                    break
        # The GPSInfo value may be an IFD offset (int) rather than a mapping;
        # only a mapping carries the actual GPS tags.
        if not gps_raw or not hasattr(gps_raw, "items"):
            return None
        gps = {GPSTAGS.get(k, k): v for k, v in gps_raw.items()}

    if "GPSLatitude" not in gps or "GPSLongitude" not in gps:
        return None

    def ref_text(key: str, default: str) -> str:
        # Refs may be bytes or padded with NULs/spaces; normalize to "N"/"S"/"E"/"W".
        value = gps.get(key, default)
        if isinstance(value, bytes):
            value = value.decode("ascii", "ignore")
        return str(value).replace("\x00", "").strip().upper() or default

    lat = dms_to_decimal(gps["GPSLatitude"], ref_text("GPSLatitudeRef", "N"))
    lon = dms_to_decimal(gps["GPSLongitude"], ref_text("GPSLongitudeRef", "E"))
    alt = None
    if "GPSAltitude" in gps:
        alt = float(gps["GPSAltitude"])
        if not math.isfinite(alt):
            # e.g. a rational with a zero denominator.
            alt = None
        else:
            alt_ref = gps.get("GPSAltitudeRef", 0)
            if isinstance(alt_ref, bytes):
                alt_ref = alt_ref[0] if alt_ref else 0
            if int(alt_ref) == 1:  # 1 = below sea level
                alt = -alt
    return lat, lon, alt


def dms_to_decimal(dms: Iterable, ref: str | bytes) -> float:
    """Convert a (degrees, minutes, seconds) triple to signed decimal degrees.

    ``ref`` is the hemisphere reference. "S" and "W" produce a negative value.
    Bytes, lowercase, surrounding whitespace and NUL padding are tolerated, so
    that a southern/western fix is not silently reported as north/east.
    """
    deg, minute, sec = [float(v) for v in dms]
    value = deg + minute / 60.0 + sec / 3600.0
    if isinstance(ref, bytes):
        ref = ref.decode("ascii", "ignore")
    hemisphere = str(ref).replace("\x00", "").strip().upper()
    if hemisphere in ("S", "W"):
        value = -value
    return value


def build_pnp_points(
    ortho_path: Path,
    dsm_path: Path,
    query_points: np.ndarray,
    ortho_points_full: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Lift ortho pixel matches to 3D map points using the DSM.

    For each pair, the ortho pixel ``(col, row)`` is converted to ortho-CRS map
    coordinates. That position is reprojected into the DSM CRS, and band 1 of
    the DSM is sampled there. A pair is dropped when:

    - the reprojected position is non-finite,
    - it falls outside the DSM raster,
    - the sampled height is non-finite or equals the DSM nodata value.

    Args:
        ortho_path: Orthophoto GeoTIFF.
        dsm_path: DSM GeoTIFF.
        query_points: (N, 2) full-resolution query pixels, paired row-by-row
            with ``ortho_points_full``.
        ortho_points_full: (N, 2) full-orthophoto pixel coordinates as (col, row).

    Returns:
        ``(object_points, image_points, meta)``. ``object_points`` is (M, 3)
        float64 ``[x, y, z]``, with x/y in the ortho CRS and z taken from the
        DSM. ``image_points`` is (M, 2) float64. ``meta`` holds the CRS and
        geotransform of both rasters.

    Raises:
        RuntimeError: fewer than 6 valid points remain, or rasterio/pyproj are missing.
    """
    require_geotiff_deps()
    object_points = []
    image_points = []
    with rasterio.open(ortho_path) as ortho, rasterio.open(dsm_path) as dsm:
        to_dsm = Transformer.from_crs(ortho.crs, dsm.crs, always_xy=True)
        for q_pt, o_pt in zip(query_points, ortho_points_full):
            col, row = float(o_pt[0]), float(o_pt[1])
            x, y = ortho.xy(row, col)
            dx, dy = to_dsm.transform(x, y)
            if not (math.isfinite(dx) and math.isfinite(dy)):
                # Failed CRS transform; dsm.index() cannot handle inf/NaN.
                continue
            # sample() does not reject positions outside the raster. Such points
            # may come back as fill values (0 when the DSM has no nodata), which
            # would be accepted as real heights, so require an in-extent pixel.
            dsm_row, dsm_col = dsm.index(dx, dy)
            if not (0 <= dsm_row < dsm.height and 0 <= dsm_col < dsm.width):
                continue
            values = list(dsm.sample([(dx, dy)]))
            if not values:
                continue
            z = float(values[0][0])
            if not np.isfinite(z):
                continue
            if dsm.nodata is not None and math.isclose(z, float(dsm.nodata)):
                continue
            object_points.append([x, y, z])
            image_points.append(q_pt)

        meta = {
            "ortho_crs": str(ortho.crs),
            "dsm_crs": str(dsm.crs),
            "ortho_transform": tuple(ortho.transform),
            "dsm_transform": tuple(dsm.transform),
        }

    if len(object_points) < 6:
        raise RuntimeError(f"Not enough valid DSM-backed points for PnP: {len(object_points)}")
    return np.asarray(object_points, dtype=np.float64), np.asarray(image_points, dtype=np.float64), meta


def solve_pose(object_points_map: np.ndarray, image_points: np.ndarray, camera_matrix: np.ndarray, dist_coeffs: np.ndarray):
    """Estimate the camera pose from 2D-3D correspondences.

    Object points are first shifted to their centroid, which keeps large map
    coordinates out of the solver. RANSAC PnP (3 px threshold, 300 iterations,
    0.99 confidence) runs first. With at least 6 inliers the pose is refined with
    ``solvePnPRefineLM``, or with iterative ``solvePnP`` when that function is
    unavailable.

    Returns:
        ``(rvec, tvec, inlier_idx, camera_center_map)``. ``tvec`` is relative to
        the centroid-shifted frame, and ``camera_center_map`` is the camera
        centre in the original map frame.

    Raises:
        RuntimeError: when ``solvePnPRansac`` reports failure.
    """
    centroid = object_points_map.mean(axis=0)
    local_pts = object_points_map - centroid

    ransac_ok, rvec, tvec, ransac_inliers = cv2.solvePnPRansac(
        local_pts,
        image_points,
        camera_matrix,
        dist_coeffs,
        flags=cv2.SOLVEPNP_ITERATIVE,
        reprojectionError=3,
        iterationsCount=300,
        confidence=0.99,
    )
    if not ransac_ok or ransac_inliers is None:
        raise RuntimeError("solvePnPRansac failed.")

    keep = ransac_inliers.ravel()
    if len(keep) >= 6:
        obj_in = local_pts[keep]
        img_in = image_points[keep]
        try:
            rvec, tvec = cv2.solvePnPRefineLM(obj_in, img_in, camera_matrix, dist_coeffs, rvec, tvec)
        except AttributeError:
            # Older OpenCV without solvePnPRefineLM: iterative refinement from the RANSAC guess.
            refined_ok, rvec_ref, tvec_ref = cv2.solvePnP(
                obj_in,
                img_in,
                camera_matrix,
                dist_coeffs,
                rvec,
                tvec,
                useExtrinsicGuess=True,
                flags=cv2.SOLVEPNP_ITERATIVE,
            )
            if refined_ok:
                rvec, tvec = rvec_ref, tvec_ref

    rotation = cv2.Rodrigues(rvec)[0]
    # Camera centre C = -R^T t in the local frame, then shifted back to map coordinates.
    center_map = (-rotation.T @ tvec).reshape(3) + centroid
    return rvec, tvec, keep, center_map


def draw_matches_colored(
    query_bgr: np.ndarray,
    ortho_bgr: np.ndarray,
    q_pts: np.ndarray,
    o_pts: np.ndarray,
    out_path: Path,
    max_matches: int = 100,
) -> None:
    """Save a side-by-side image with the first ``max_matches`` pairs as colored lines.

    The ortho image is placed right of the query image; the shorter image is
    zero-padded at the bottom so both can share one canvas. Each pair gets its
    own color from a discretized "hsv" colormap. The result is written to
    ``out_path``; nothing is displayed.
    """
    import matplotlib

    h_q, w_q = query_bgr.shape[:2]
    canvas_h = max(h_q, ortho_bgr.shape[0])

    def pad_to_height(img: np.ndarray) -> np.ndarray:
        # Bottom zero-padding; np.concatenate(axis=1) needs equal heights.
        extra = canvas_h - img.shape[0]
        if extra == 0:
            return img
        return np.pad(img, ((0, extra), (0, 0)) + ((0, 0),) * (img.ndim - 2))

    vis = np.concatenate([pad_to_height(query_bgr), pad_to_height(ortho_bgr)], axis=1)

    num = min(len(q_pts), max_matches)
    if num > 0:
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; use the registry.
        cmap = matplotlib.colormaps["hsv"].resampled(num)
        for i in range(num):
            color = tuple(int(c * 255) for c in cmap(i)[:3])
            x1, y1 = int(q_pts[i, 0]), int(q_pts[i, 1])
            x2, y2 = int(o_pts[i, 0] + w_q), int(o_pts[i, 1])

            cv2.circle(vis, (x1, y1), 3, color, -1)
            cv2.circle(vis, (x2, y2), 3, color, -1)
            cv2.line(vis, (x1, y1), (x2, y2), color, 1, cv2.LINE_AA)

    cv2.imwrite(str(out_path), vis)


def main() -> None:
    """Localize one UAV query image against an ODM orthophoto + DSM.

    Pipeline:

    1. Take the query's EXIF GPS and crop the orthophoto around it. The crop
       has the query's full pixel size.
    2. Match with RoMaV2 at the (possibly downscaled) matching size.
    3. Optionally refine with LK, falling back to the unrefined matches.
    4. Filter with homography RANSAC.
    5. Lift the ortho matches to 3D through the DSM and solve PnP.

    Outputs written to ``--out-dir``: ``match_result.json`` (before PnP),
    ``pose_result.json``, and the visualizations listed under "outputs".

    Raises:
        RuntimeError: no EXIF GPS, too few homography inliers, missing focal
            length, or a failure reported by one of the helper steps.
    """
    args = parse_args()
    query_path = Path(args.query)
    ortho_path = Path(args.ortho)
    dsm_path = Path(args.dsm)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    matches_path = out_dir / "romav2_matches.jpg"
    matches_colored_path = out_dir / "romav2_matches_colored.jpg"
    ortho_crop_path = out_dir / "ortho_gps_crop.jpg"
    region_path = out_dir / "matched_region_on_ortho.jpg"
    outputs = {
        "matches": str(matches_path),
        "matches_colored": str(matches_colored_path),
        "ortho_crop": str(ortho_crop_path),
        "region": str(region_path),
    }

    gnss_config_path = find_gnss_config(query_path, args.gnss_config)
    camera_config = load_camera_from_gnss_config(gnss_config_path)

    # Full-resolution query size. The intrinsics and the ortho crop size refer to it.
    # NOTE(review): PIL's width/height ignore EXIF orientation, while cv2.imread
    # applies it by default — confirm the query images are not EXIF-rotated.
    with Image.open(query_path) as query_image:
        width = query_image.width
        height = query_image.height

    query_bgr, query_scale = read_image_for_matching(query_path, args.query_max_size)

    gps = gps_from_exif(query_path)
    if gps is None:
        raise RuntimeError("Query image has no EXIF GPS. GPS is required to crop the orthophoto search area.")
    gps_lat, gps_lon, gps_alt = gps

    match_h, match_w = query_bgr.shape[:2]
    ortho_bgr_full_crop, crop_origin, crop_meta = read_ortho_crop_near_gps(
        ortho_path,
        gps_lat,
        gps_lon,
        width,
        height,
    )
    if (match_w, match_h) == (width, height):
        ortho_bgr = ortho_bgr_full_crop
    else:
        # Bring the crop to the matching size, so both matcher inputs share one scale.
        ortho_bgr = cv2.resize(ortho_bgr_full_crop, (match_w, match_h), interpolation=cv2.INTER_AREA)

    crop_match_scale_x = match_w / float(width)
    crop_match_scale_y = match_h / float(height)
    crop_meta["match_size_wh"] = [int(match_w), int(match_h)]
    crop_meta["match_scale_xy"] = [float(crop_match_scale_x), float(crop_match_scale_y)]

    # RoMaV2 dense matching (replaces MASt3R).
    q_pts_scaled, o_pts_scaled, match_conf, romav2_stats = romav2_match_points(
        query_bgr,
        ortho_bgr,
        device=args.romav2_device,
        sample_num=args.romav2_sample_num,
    )
    q_pts_raw = q_pts_scaled.copy()
    o_pts_raw = o_pts_scaled.copy()
    match_conf_raw = match_conf.copy()

    def try_homography(q_pts, o_pts, conf):
        """Homography RANSAC that reports failure as zero inliers instead of raising."""
        # Fewer than 4 points (or misaligned arrays after LK) cannot be fitted.
        if len(q_pts) < 4 or len(conf) != len(q_pts):
            return None, None, 0
        try:
            h, mask = homography_inliers(q_pts, o_pts, args.ransac_threshold)
        except (RuntimeError, cv2.error):
            return None, None, 0
        return h, mask, int(mask.sum())

    if args.no_lk_refine:
        n_matches = int(len(q_pts_scaled))
        lk_stats = {"enabled": False, "input": n_matches, "kept": n_matches, "fallback_to_unrefined": False}
        h_mat, inlier_mask = homography_inliers(q_pts_scaled, o_pts_scaled, args.ransac_threshold)
        inlier_count = int(inlier_mask.sum())
    else:
        q_pts_scaled, o_pts_scaled, lk_mask, lk_stats = refine_ortho_points_lk(
            query_bgr,
            ortho_bgr,
            q_pts_scaled,
            o_pts_scaled,
            args.lk_max_error,
        )
        match_conf = match_conf[lk_mask]
        h_mat, inlier_mask, inlier_count = try_homography(q_pts_scaled, o_pts_scaled, match_conf)
        if inlier_count < args.min_matches:
            # LK left too few usable points (or RANSAC failed on them):
            # redo RANSAC on the unrefined RoMaV2 matches.
            q_pts_scaled = q_pts_raw
            o_pts_scaled = o_pts_raw
            match_conf = match_conf_raw
            h_mat, inlier_mask = homography_inliers(q_pts_scaled, o_pts_scaled, args.ransac_threshold)
            inlier_count = int(inlier_mask.sum())
            lk_stats["fallback_to_unrefined"] = True
            lk_stats["fallback_reason"] = "too_few_homography_inliers_after_lk"
        else:
            lk_stats["fallback_to_unrefined"] = False

    q_inliers = q_pts_scaled[inlier_mask]
    o_inliers = o_pts_scaled[inlier_mask]

    # Match visualizations (matching image space).
    draw_matches_points(query_bgr, ortho_bgr, q_inliers, o_inliers, matches_path)
    draw_matches_colored(query_bgr, ortho_bgr, q_inliers, o_inliers, matches_colored_path, max_matches=100)
    cv2.imwrite(str(ortho_crop_path), ortho_bgr)
    draw_region(query_bgr, ortho_bgr, h_mat, region_path)

    # Shared by match_result.json and pose_result.json (same key order as before).
    summary = {
        "query": str(query_path),
        "ortho": str(ortho_path),
        "dsm": str(dsm_path),
        "gnss_config": None if gnss_config_path is None else str(gnss_config_path),
        "matching_mode": "gps_center_crop",
        "query_scale": float(query_scale),
        "ortho_crop": crop_meta,
        "matcher": "romav2",
        "romav2": romav2_stats,
        "lk_refine": lk_stats,
        "romav2_matches": int(len(q_pts_scaled)),
        "homography_inliers": inlier_count,
    }
    pre_pnp = {**summary, "outputs": outputs}
    # ensure_ascii=False keeps non-ASCII paths readable; the file is written as UTF-8.
    (out_dir / "match_result.json").write_text(json.dumps(pre_pnp, indent=2, ensure_ascii=False), encoding="utf-8")

    if inlier_count < args.min_matches:
        raise RuntimeError(f"Too few homography inliers: {inlier_count} < {args.min_matches}")

    # Matching space -> full resolution. The ortho crop was read at the full
    # query size, so the same factors apply to both images. The crop origin then
    # moves ortho points into full-orthophoto pixels.
    full_scale = np.asarray([width / float(match_w), height / float(match_h)], dtype=np.float32)
    q_pts_full = q_inliers * full_scale
    crop_col_off, crop_row_off = crop_origin
    o_pts_full = o_inliers * full_scale + np.asarray([crop_col_off, crop_row_off], dtype=np.float32)

    # PnP match selection (grid/confidence filtering).
    q_pts_full, o_pts_full, pnp_match_stats = select_pnp_matches(
        q_pts_full,
        o_pts_full,
        match_conf[inlier_mask],
        width,
        height,
        0.0,
        args.pnp_grid_cols,
        args.pnp_grid_rows,
        args.pnp_max_per_cell,
        args.pnp_max_points,
    )

    # 3D object points from the DSM.
    object_points, image_points, geo_meta = build_pnp_points(ortho_path, dsm_path, q_pts_full, o_pts_full)

    # Camera intrinsics: CLI overrides win over the GNSS config; principal point defaults to the image centre.
    fx = args.fx if args.fx is not None else camera_config.get("fx")
    fy = args.fy if args.fy is not None else camera_config.get("fy")
    if fx is None or fy is None:
        raise RuntimeError("Camera fx/fy are required. Pass --gnss-config or --fx/--fy.")
    cx = args.cx if args.cx is not None else camera_config.get("cx", width / 2.0)
    cy = args.cy if args.cy is not None else camera_config.get("cy", height / 2.0)
    camera_matrix = np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)
    dist_coeffs = parse_dist_coeffs(args.dist, camera_config.get("dist"))

    rvec, tvec, pnp_inliers, camera_center_map = solve_pose(
        object_points, image_points, camera_matrix, dist_coeffs
    )

    # Error of the PnP camera centre against the EXIF GPS fix.
    require_geotiff_deps()
    with rasterio.open(ortho_path) as ortho:
        to_wgs84 = Transformer.from_crs(ortho.crs, "EPSG:4326", always_xy=True)
        pred_lon, pred_lat = to_wgs84.transform(camera_center_map[0], camera_center_map[1])

    geod = Geod(ellps="WGS84")
    _, _, horizontal_error_m = geod.inv(pred_lon, pred_lat, gps_lon, gps_lat)
    vertical_error_m = None
    error_3d_m = None
    if gps_alt is not None:
        # NOTE(review): camera_center_map[2] is in the DSM's height reference, while
        # EXIF GPSAltitude may use another datum — confirm before trusting this value.
        vertical_error_m = float(camera_center_map[2]) - float(gps_alt)
        error_3d_m = math.sqrt(horizontal_error_m**2 + vertical_error_m**2)

    result = {
        **summary,
        "pnp_match_filter": pnp_match_stats,
        "pnp_points_with_valid_dsm": int(len(object_points)),
        "pnp_inliers": int(len(pnp_inliers)),
        "camera_center_map_xyz": camera_center_map.tolist(),
        "camera_lon_lat_height": [float(pred_lon), float(pred_lat), float(camera_center_map[2])],
        "query_exif_gps_lon_lat_alt": [float(gps_lon), float(gps_lat), gps_alt],
        "gps_error_m": {
            "horizontal": float(horizontal_error_m),
            "vertical_signed": None if vertical_error_m is None else float(vertical_error_m),
            "vertical_abs": None if vertical_error_m is None else abs(float(vertical_error_m)),
            "3d": None if error_3d_m is None else float(error_3d_m),
        },
        "horizontal_gps_error_m": float(horizontal_error_m),
        "vertical_gps_error_m": None if vertical_error_m is None else abs(float(vertical_error_m)),
        "gps_error_3d_m": None if error_3d_m is None else float(error_3d_m),
        "rvec": rvec.reshape(-1).tolist(),
        "tvec_local_origin": tvec.reshape(-1).tolist(),
        "camera_matrix": camera_matrix.tolist(),
        "dist_coeffs": dist_coeffs.reshape(-1).tolist(),
        "geo": geo_meta,
        "outputs": outputs,
    }
    (out_dir / "pose_result.json").write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()

  

posted on 2026-06-11 00:04  MKT-porter  阅读(2)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2026
浙公网安备 33010602011771号 浙ICP备2021040463号-3