DKT代码中数据处理部分
在这里本来有个代码看不懂,然后给他逐步分解:
skill_inputs = [torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))[:-1]
for (s, l) in zip(skill_ids, labels)]
分解代码:
for (s, l) in zip(skill_ids, labels):
print(s,l)
print(torch.zeros(1, dtype=torch.long), s * 2 + l + 1)
print(torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1)))
print(torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))[:-1])
其中这是s
# s,也就是skill_ids
tensor([143, 143, 143, 143, 143, 143, 145, 142, 144, 142, 144, 145, 121, 120, 120, 121, 120, 121, 121])
#l,也就是labels
tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0])
# torch.zeros(1, dtype=torch.long)
tensor([0])
# s * 2 + l + 1
tensor([287, 288, 287, 287, 287, 287, 291, 285, 290, 285, 290, 291, 243, 242, 242, 243, 242, 243, 243])
# torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))
tensor([ 0, 287, 288, 287, 287, 287, 287, 291, 285, 290, 285, 290, 291, 243, 242, 242, 243, 242, 243, 243])
# torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))[:-1]
tensor([ 0, 287, 288, 287, 287, 287, 287, 291, 285, 290, 285, 290, 291, 243, 242, 242, 243, 242, 243])