DKT代码中数据处理部分

在这里本来有个代码看不懂,然后给他逐步分解:

skill_inputs = [torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))[:-1]
                    for (s, l) in zip(skill_ids, labels)]

分解代码:
image

    for (s, l) in zip(skill_ids, labels):
        print(s,l)
        print(torch.zeros(1, dtype=torch.long), s * 2 + l + 1)
        print(torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1)))
        print(torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))[:-1])

其中这是s

# s,也就是skill_ids
tensor([143, 143, 143, 143, 143, 143, 145, 142, 144, 142, 144, 145, 121, 120, 120, 121, 120, 121, 121])

#l,也就是labels
tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0])

# torch.zeros(1, dtype=torch.long)
tensor([0])

# s * 2 + l + 1
tensor([287, 288, 287, 287, 287, 287, 291, 285, 290, 285, 290, 291, 243, 242, 242, 243, 242, 243, 243])

# torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))
tensor([  0, 287, 288, 287, 287, 287, 287, 291, 285, 290, 285, 290, 291, 243, 242, 242, 243, 242, 243, 243])

# torch.cat((torch.zeros(1, dtype=torch.long), s * 2 + l + 1))[:-1]
tensor([  0, 287, 288, 287, 287, 287, 287, 291, 285, 290, 285, 290, 291, 243, 242, 242, 243, 242, 243])
posted @ 2024-09-18 11:28  lipu123  阅读(26)  评论(0)    收藏  举报