bg_weight 是 0.1

mask_list[0] (64, 64)
mask (18, 64, 64)
ignored (18, 64, 64)


def __call__(self, joints, sgm, ct_sgm, bg_weight=1.0):
    """Render per-joint heatmaps and a matching per-pixel weight map.

    Args:
        joints: Array-like iterated as people -> joints -> point, where each
            point exposes ``(x, y, visibility)`` at indices 0, 1 and 2.
            ``joints.shape[1]`` must equal ``self.num_joints_with_center``.
            (Presumably shaped ``(num_people, num_joints_with_center, 3)``
            -- confirm against the caller.)
        sgm: Sigma used for channels with index below 17.
        ct_sgm: Sigma used for channels with index 17 and above.
        bg_weight: Weight written to every pixel that no rendered joint
            patch touched.

    Returns:
        A two-element list ``[heatmaps, weights]``; both are float32 arrays
        of shape ``(num_joints_with_center, output_res, output_res)``.
        ``heatmaps`` holds the element-wise maximum of all rendered patches;
        ``weights`` is ``1`` inside any rendered patch and ``bg_weight``
        elsewhere.
    """
    assert self.num_joints_with_center == joints.shape[1], \
        'the number of joints should be %d' % self.num_joints_with_center

    res = self.output_res
    map_shape = (self.num_joints_with_center, res, res)
    heatmaps = np.zeros(map_shape, dtype=np.float32)
    # 2 is a temporary "untouched" marker; it is swapped for bg_weight below.
    weights = np.full(map_shape, 2, dtype=np.float32)

    for person in joints:
        for idx, pt in enumerate(person):
            # NOTE(review): 17 is hard-coded; presumably the index of the
            # center channel when there are 17 body joints -- confirm.
            sigma = sgm if idx < 17 else ct_sgm

            # Skip joints whose visibility flag is not positive.
            if not (pt[2] > 0):
                continue

            x, y = pt[0], pt[1]
            # Skip joints that fall outside the output map.
            if x < 0 or y < 0 or x >= res or y >= res:
                continue

            # Patch bounds: roughly 3 sigma around the joint, clipped to
            # the map.
            reach = 3 * sigma
            left = max(0, int(np.floor(x - reach - 1)))
            top = max(0, int(np.floor(y - reach - 1)))
            right = min(int(np.ceil(x + reach + 2)), res)
            bottom = min(int(np.ceil(y + reach + 2)), res)

            patch = np.zeros((bottom - top, right - left))
            for row, sy in enumerate(range(top, bottom)):
                for col, sx in enumerate(range(left, right)):
                    patch[row, col] = self.get_heat_val(sigma, sx, sy, x, y)

            region = (idx, slice(top, bottom), slice(left, right))
            # Overlapping people keep the strongest response per pixel.
            heatmaps[region] = np.maximum(heatmaps[region], patch)
            weights[region] = 1.

    # Every pixel still carrying the marker was never covered by a patch.
    weights[weights == 2] = bg_weight

    return [heatmaps, weights]

posted @ 2022-12-29 13:25  祥瑞哈哈哈  阅读(22)  评论(0)    收藏  举报