hrnet的get_joint
他的第18个keypoint存的是平均坐标。
def get_joints(self, anno):
num_people = len(anno)
area = np.zeros((num_people, 1))
joints = np.zeros((num_people, self.num_joints_with_center, 3))
for i, obj in enumerate(anno):
joints[i, :self.num_joints, :3] = \
np.array(obj['keypoints']).reshape([-1, 3])
area[i, 0] = self.cal_area_2_torch(
torch.tensor(joints[i:i+1,:,:]))
if obj['area'] < 32**2:
joints[i, -1, 2] = 0
continue
joints_sum = np.sum(joints[i, :-1, :2], axis=0)#把同列的坐标进行总和。
num_vis_joints = len(np.nonzero(joints[i, :-1, 2])[0])#计算不为0的关键点数量,0在coco数据集中表示未标注。
if num_vis_joints <= 0:
joints[i, -1, :2] = 0#表示不存在关键点
else:
joints[i, -1, :2] = joints_sum / num_vis_joints#关键点的均值
joints[i, -1, 2] = 1#存在关键点。
return joints, area

浙公网安备 33010602011771号