Learning monodepth2, Part 5: Data Loading
Data loading
Data loading in monodepth2 is fairly simple. It starts from a base class, MonoDataset, which inherits from PyTorch's Dataset and implements the three methods a Dataset needs for data loading: __init__, __getitem__ and __len__. After the images are read, the class also preprocesses them, resizing each image to several scales and applying color (brightness/contrast/saturation/hue) augmentation. The code is as follows:
class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Args:
        data_path
        filenames
        height
        width
        frame_idxs
        num_scales
        is_train
        img_ext
    """
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales  # number of scales in the image pyramid
        self.interp = Image.ANTIALIAS  # anti-aliased resizing (removed in newer Pillow; Image.LANCZOS is the replacement)
        self.frame_idxs = frame_idxs
        self.is_train = is_train
        self.img_ext = img_ext  # image file extension
        self.loader = pil_loader  # PIL-based image loader
        self.to_tensor = transforms.ToTensor()  # PIL image -> torch tensor

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        # (this sets up the color-jitter parameters; the resize transforms follow below)
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()
    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            frame = inputs[k]
            if "color" in k:
                n, im, i = k
                for i in range(self.num_scales):
                    # each scale is produced by resizing the previous (larger) scale
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
    def __len__(self):
        return len(self.filenames)
    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)       for raw colour images,
            ("color_aug", <frame_id>, <scale>)   for augmented colour images,
            ("K", scale) or ("inv_K", scale)     for camera intrinsics,
            "stereo_T"                           for camera extrinsics, and
            "depth_gt"                           for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1  images at native resolution as loaded from disk
             0  images resized to (self.width,      self.height     )
             1  images resized to (self.width // 2, self.height // 2)
             2  images resized to (self.width // 4, self.height // 4)
             3  images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5  # randomly apply color jitter
        do_flip = self.is_train and random.random() > 0.5       # randomly apply horizontal flip

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                # stereo case: load the opposite image of the stereo pair
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            # resizing the images means the normalized intrinsics K must be rescaled
            # back to pixel units at each scale (focal lengths and principal point alike)
            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)  # identity: use the original images unchanged

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            # the raw -1 scale images are no longer needed once scales 0..3 exist
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs
    # the following helpers are dataset-specific and are implemented by subclasses
    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError
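Before moving on to the KITTI subclasses, here is a minimal sketch of how monodepth2's trainer.py wires such a dataset into a DataLoader. The values follow monodepth2's default options, but treat the exact arguments as illustrative rather than a verbatim copy of the repo:

from torch.utils.data import DataLoader
from datasets import KITTIRAWDataset
from utils import readlines

# readlines is monodepth2's helper for reading a split file into a list of lines
train_filenames = readlines("splits/eigen_zhou/train_files.txt")

train_dataset = KITTIRAWDataset(
    data_path="kitti_data",    # root folder of the raw KITTI recordings
    filenames=train_filenames,
    height=192, width=640,     # monodepth2's default training resolution
    frame_idxs=[0, -1, 1],     # current frame plus its two temporal neighbours
    num_scales=4,              # scales 0..3, i.e. /1, /2, /4 and /8 of full size
    is_train=True)

train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True,
                          num_workers=12, pin_memory=True, drop_last=True)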
MonoDataset is only the base class; the class actually used is KITTIRAWDataset, which inherits from KITTIDataset (itself a subclass of MonoDataset). KITTIDataset implements check_depth and get_color and defines parameters such as the intrinsics matrix K. The code is as follows:
class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.58, 0, 0.5, 0],
                           [0, 1.92, 0.5, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(pil.FLIP_LEFT_RIGHT)

        return color
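To make the NOTE above concrete, here is a small illustrative calculation showing how the per-scale loop in __getitem__ converts this normalized K back into pixel units (the 640x192 resolution is monodepth2's default, not something fixed by this class):

import numpy as np

# reproduces the K scaling from MonoDataset.__getitem__ for scale 0
K = np.array([[0.58, 0, 0.5, 0],
              [0, 1.92, 0.5, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1]], dtype=np.float32)
width, height, scale = 640, 192, 0
K[0, :] *= width // (2 ** scale)   # fx = 0.58 * 640 = 371.2,  cx = 0.5 * 640 = 320
K[1, :] *= height // (2 ** scale)  # fy = 1.92 * 192 = 368.64, cy = 0.5 * 192 = 96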
class KITTIRAWDataset(KITTIDataset):
    """KITTI dataset which loads the original velodyne depth maps for ground truth
    """
    def __init__(self, *args, **kwargs):
        super(KITTIRAWDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        f_str = "{:010d}{}".format(frame_index, self.img_ext)
        image_path = os.path.join(
            self.data_path, folder, "image_0{}/data".format(self.side_map[side]), f_str)
        return image_path

    def get_depth(self, folder, frame_index, side, do_flip):
        calib_path = os.path.join(self.data_path, folder.split("/")[0])

        velo_filename = os.path.join(
            self.data_path,
            folder,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        depth_gt = generate_depth_map(calib_path, velo_filename, self.side_map[side])
        depth_gt = skimage.transform.resize(
            depth_gt, self.full_res_shape[::-1], order=0, preserve_range=True, mode='constant')

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt
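Tracing one concrete example through these classes ties everything together. The split line below follows the format of monodepth2's eigen_zhou split files ("<folder> <frame_index> <side>"); the specific drive and frame number are only an example:

# one split-file line: folder, frame index, side
line = "2011_09_26/2011_09_26_drive_0022_sync 473 l"

# With frame_idxs = [0, -1, 1], __getitem__ calls get_color for frames 473, 472 and 474.
# Since "l" maps to camera 2 in side_map, get_image_path builds paths such as:
#   <data_path>/2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000473.jpg
#
# The returned inputs dictionary then holds, among others:
#   ("color", 0, 0) ... ("color", 1, 3)   raw images at scales 0..3
#   ("color_aug", 0, 0) ...               the same images after color jitter
#   ("K", 0), ("inv_K", 0), ...           per-scale camera intrinsics
#   "depth_gt"                            only if the velodyne .bin files exist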
This part is straightforward: KITTIRAWDataset only adds image path construction and velodyne-based ground-truth depth loading on top of what the base classes already provide.
