3D Gaussian Splatting 07: Code Reading - Loading Training Data and Saving Results

Loading Training Data

In train.py, the call stack for loading data is shown below. Since convert.py preprocesses the data with COLMAP, reading the data ultimately calls the readColmapSceneInfo method:

Scene(dataset, gaussians)
└─sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
  └─readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8)
    ├─readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_folder, depths_folder, test_cam_names_list)
    ├─read_points3D_binary(path_to_model_file)
    ├─storePly(path, xyz, rgb)
    └─fetchPly(path)

The reading flow is:

  1. Read the camera intrinsics and image extrinsics (poses) from cameras.bin and images.bin
  2. Split the images into training and test sets
  3. Read the 3D point cloud from points3D.bin

In the readColmapSceneInfo() method, if the --eval flag is set, the cam_names are sorted and then split into training and test sets by whether the index modulo llffhold is 0. llffhold is 8, so the train/test ratio is 7:1. If --eval is not given, all images are used for training. To specify the test set manually, create a test.txt under sparse/0 and change the default value of the llffhold parameter to 0.

if eval:
    if "360" in path:
        llffhold = 8
    if llffhold:
        print("------------LLFF HOLD-------------")
        cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
        cam_names = sorted(cam_names)
        test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
    else:
        with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
            test_cam_names_list = [line.strip() for line in file]
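For intuition, a minimal sketch of the modulo split rule (with hypothetical file names):

cam_names = sorted(f"img_{i:03d}.jpg" for i in range(24))
test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % 8 == 0]
print(len(test_cam_names_list), len(cam_names) - len(test_cam_names_list))  # 3 21, i.e. a 7:1 train/test split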

The resulting test_cam_names_list is used when readColmapCameras builds each CameraInfo: it sets the camera's is_test attribute, which is later used during training, rendering, and evaluation to tell training cameras from test cameras.

cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, depth_params=depth_params,
                        image_path=image_path, image_name=image_name, depth_path=depth_path,
                        width=width, height=height, is_test=image_name in test_cam_names_list)

When reading the camera focal length there is a conversion from focal length to field of view; for the SIMPLE_PINHOLE camera model a single focal length intr.params[0] is used for both axes:

focal_length_x = intr.params[0]
FovY = focal2fov(focal_length_x, height)

The conversion functions are defined as:

import math

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2*math.atan(pixels/(2*focal))

Concepts:

  • Field of view, FOV: the angular extent of the camera's view, usually computed separately for the vertical and horizontal directions
  • Focal length: the distance from the optical center to the image plane (expressed in pixels here); the image size of an object is proportional to it

fov2focal(fov, pixels)

  • Converts a field of view (FOV) to a focal length
  • Parameters
    • fov: field of view in radians, the angular range the camera can observe
    • pixels: the number of pixels along one sensor dimension (width or height)

Based on the pinhole camera model, the focal length \(f\) is given by

\[ f = \frac{\text{pixels}}{2 \cdot \tan(\text{FOV}/2)} \]

If the image height is 1080 pixels and the vertical FOV is 60°, i.e. math.radians(60) in radians:

focal = 1080 / (2 * math.tan(math.radians(60) / 2))  # ≈ 935.3

focal2fov(focal, pixels)

  • Converts a focal length to a field of view (FOV)
  • Parameters
    • focal: focal length (in pixels)
    • pixels: the number of pixels along one sensor dimension

The field of view \(\theta\) is computed as

\[ \theta = 2 \cdot \arctan\left(\frac{\text{pixels}}{2 \cdot f}\right) \]

If the focal length is 935.3 pixels and the image height is 1080 pixels:

fov = 2 * math.atan(1080 / (2 * 935.3))  # ≈ 1.047 rad ≈ 60°
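Since the two functions are exact inverses, a round trip recovers the original value; a quick sanity check reusing the definitions above:

fov = math.radians(60)
focal = fov2focal(fov, 1080)                      # ≈ 935.3
assert abs(focal2fov(focal, 1080) - fov) < 1e-9   # round trip recovers the FOV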

Next the method checks whether points3D.ply exists; if it does it is read directly, otherwise it is first converted from points3D.bin (falling back to points3D.txt) and then read:

ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin")
txt_path = os.path.join(path, "sparse/0/points3D.txt")
if not os.path.exists(ply_path):
    print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
    try:
        xyz, rgb, _ = read_points3D_binary(bin_path)
    except:
        xyz, rgb, _ = read_points3D_text(txt_path)
    storePly(ply_path, xyz, rgb)
try:
    pcd = fetchPly(ply_path)
except:
    pcd = None

Reading the 3D point cloud from points3D.bin:

def read_points3D_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """

    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]

        # Allocate uninitialized output arrays (values are arbitrary until assigned below)
        xyzs = np.empty((num_points, 3))
        rgbs = np.empty((num_points, 3))
        errors = np.empty((num_points, 1))

        for p_id in range(num_points):
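            # 43 bytes per point: point_id (Q, 8) + x, y, z (ddd, 24) + r, g, b (BBB, 3) + error (d, 8)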
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            xyzs[p_id] = xyz
            rgbs[p_id] = rgb
            errors[p_id] = error
    return xyzs, rgbs, errors

The read_next_bytes helper used here reads a chunk of bytes from the binary file and unpacks it into Python values with struct.unpack, according to the given format string:

import struct

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)
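As a standalone illustration of the 43-byte point record (a sketch, not code from the repo):

import struct

# Pack one fake point record: id=1, xyz=(0, 0, 0), rgb=(255, 0, 0), error=0.5
buf = struct.pack("<QdddBBBd", 1, 0.0, 0.0, 0.0, 255, 0, 0, 0.5)
print(len(buf))                          # 43, matching num_bytes=43 above
print(struct.unpack("<QdddBBBd", buf))   # (1, 0.0, 0.0, 0.0, 255, 0, 0, 0.5)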

BasicPointCloud

BasicPointCloud is the basic data structure for a 3D point cloud, holding point coordinates, colors, and normals:

from typing import NamedTuple
import numpy as np

class BasicPointCloud(NamedTuple):
    points : np.array
    colors : np.array
    normals : np.array
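For example, wrapping arrays in a BasicPointCloud (a hypothetical sketch with random data):

pts = np.random.rand(100, 3)
pcd = BasicPointCloud(points=pts,
                      colors=np.full_like(pts, 0.5),  # RGB in [0, 1]
                      normals=np.zeros_like(pts))     # normals are unused downstream
print(pcd.points.shape)  # (100, 3)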

The remaining helpers in the same module (utils/graphics_utils.py in the 3DGS repository) implement the coordinate and projection transforms; they rely on math, numpy, and torch:

import math
import torch

def geom_transform_points(points, transf_matrix):
    # Convert the points to homogeneous coordinates, apply the transform matrix
    # (as row vectors: points_hom @ M), and return the 3D coordinates after
    # perspective division. A PyTorch implementation supporting batched transforms.
    P, _ = points.shape
    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
    points_hom = torch.cat([points, ones], dim=1)
    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))

    denom = points_out[..., 3:] + 0.0000001
    return (points_out[..., :3] / denom).squeeze(dim=0)
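A small sanity check (sketch): an identity matrix should leave the points unchanged, up to the epsilon added to the denominator.

pts = torch.rand(5, 3)
out = geom_transform_points(pts, torch.eye(4))
assert torch.allclose(out, pts, atol=1e-5)  # identity transform is a no-op
print(out.shape)  # torch.Size([5, 3])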

def getWorld2View(R, t):
    # Build the 4x4 world-to-camera (view) matrix. R: 3x3 rotation matrix, t: 3D translation vector
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)

def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    # Extended view transform that also supports translating and scaling the scene,
    # implemented by adjusting the camera center in the camera-to-world (inverse) transform
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)
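With the default translate and scale, getWorld2View2 reduces to getWorld2View; a quick check (sketch):

import numpy as np

R = np.eye(3)
t = np.array([1.0, 2.0, 3.0])
# translate=(0, 0, 0) and scale=1.0 leave the camera center unchanged
assert np.allclose(getWorld2View(R, t), getWorld2View2(R, t), atol=1e-6)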

def getProjectionMatrix(znear, zfar, fovX, fovY):
    # Build the 4x4 perspective projection matrix from the near/far clipping planes
    # and the horizontal/vertical fields of view
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P
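For a symmetric frustum the diagonal entries simplify: 2.0 * znear / (right - left) = 1 / tanHalfFovX, so P[0, 0] = 1/tan(fovX/2) and likewise P[1, 1] = 1/tan(fovY/2). A quick check (sketch):

import math

P = getProjectionMatrix(znear=0.01, zfar=100.0,
                        fovX=math.radians(60), fovY=math.radians(45))
assert abs(P[0, 0] - 1 / math.tan(math.radians(60) / 2)) < 1e-5
assert abs(P[1, 1] - 1 / math.tan(math.radians(45) / 2)) < 1e-5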

def fov2focal(fov, pixels):
    # Field of view to focal length (in pixels)
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    # Focal length (in pixels) to field of view
    return 2*math.atan(pixels/(2*focal))

Training Result Data Structure

Install pyntcloud:

pip install pyntcloud

Inspect the ply file:

>>> from pyntcloud import PyntCloud
>>> cloud = PyntCloud.from_file("output/1ed8e6a1-9/point_cloud/iteration_7000/point_cloud.ply")
>>> print(cloud)
PyntCloud
743269 points with 59 scalar fields
0 faces in mesh
0 kdtrees
0 voxelgrids
Centroid: 1.6537141799926758, -2.9306182861328125, -4.471662521362305

>>> type(cloud.points)
<class 'pandas.core.frame.DataFrame'>

The points are stored as a pandas DataFrame. Inspecting the attribute columns of the first point shows that every field is a float32 (4 bytes), but there are too many attributes and the listing is truncated:

>>> print(cloud.points.loc[0])
x          1.947371
y         -0.500535
z          1.388533
nx         0.000000
ny         0.000000
             ...   
scale_2   -4.380099
rot_0      0.840099
rot_1     -0.143527
rot_2      0.065419
rot_3      0.179504
Name: 0, Length: 62, dtype: float32

After removing the row limit, the full listing can be printed. There are 62 columns in total: 3 position + 3 normal + 3 f_dc + 45 f_rest + 1 opacity + 3 scale + 4 rotation:

>>> import pandas as pd
>>> pd.set_option('display.max_rows', None)
>>> print(cloud.points.loc[0])
x            1.947371
y           -0.500535
z            1.388533
nx           0.000000
ny           0.000000
nz           0.000000
f_dc_0      -0.264158
f_dc_1       0.352959
f_dc_2       0.361867
f_rest_0     0.012889
f_rest_1    -0.001385
f_rest_2     0.044487
f_rest_3     0.013909
# fields f_rest_4 through f_rest_40 omitted
f_rest_41   -0.038870
f_rest_42   -0.015730
f_rest_43    0.042109
f_rest_44    0.021378
opacity     -1.817663
scale_0     -5.108221
scale_1     -4.811676
scale_2     -4.380099
rot_0        0.840099
rot_1       -0.143527
rot_2        0.065419
rot_3        0.179504
Name: 0, dtype: float32

When the result is saved, the output attributes are produced by concatenating the parameter arrays:

attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)

The attributes have the following meanings (decoded in the sketch after this list):

  • x, y, z: 3D point position
  • nx, ny, nz: normals, unused
  • f_dc_0 - f_dc_2, f_rest_0 - f_rest_44: DC and remaining components of the color features; degree-3 spherical harmonics give 16 coefficients per RGB channel (48 values in total)
  • opacity: opacity parameter (stored pre-sigmoid)
  • scale_0 - scale_2: scale parameters (stored as logarithms)
  • rot_0 - rot_3: rotation parameters (a quaternion)
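The stored values are pre-activation parameters, so they have to be decoded before use. A sketch of the decoding (the SH constant C0 and the activations follow the 3DGS convention; treat the exact field handling as an assumption):

import numpy as np
from pyntcloud import PyntCloud

cloud = PyntCloud.from_file("output/1ed8e6a1-9/point_cloud/iteration_7000/point_cloud.ply")
p = cloud.points.loc[0]

C0 = 0.28209479177387814                                       # zeroth-order SH basis constant
rgb = C0 * p[["f_dc_0", "f_dc_1", "f_dc_2"]].values + 0.5      # base color, roughly in [0, 1]
opacity = 1 / (1 + np.exp(-p["opacity"]))                      # sigmoid activation
scale = np.exp(p[["scale_0", "scale_1", "scale_2"]].values)    # exp activation
quat = p[["rot_0", "rot_1", "rot_2", "rot_3"]].values
quat = quat / np.linalg.norm(quat)                             # normalize the quaternion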
