BPRMF
def evaluate(self, u_ids):
    """Score the requested users against every item in the catalogue.

    Looks up the embedding of each id in ``u_ids`` and takes its inner
    product with every row of ``self.item_embeddings.weight``.

    Args:
        u_ids: user index tensor accepted by ``self.user_embeddings``
            (presumably 1-D ``[batch_size]`` — confirm against callers).

    Returns:
        ``user_vectors @ item_table.t()``; for 1-D ``u_ids`` this is a
        ``[batch_size, n_items]`` score matrix.
    """
    user_vectors = self.user_embeddings(u_ids)
    # The embedding table's weight holds one row per item, so multiplying
    # by its transpose scores each user against all items at once.
    item_table = self.item_embeddings.weight
    return user_vectors @ item_table.t()
def forward(self, u_ids, i_ids):
    """Score each (user, item) pair by the inner product of their embeddings.

    Args:
        u_ids: user index tensor (presumably 1-D ``[batch_size]``).
        i_ids: item index tensor with the same shape as ``u_ids``; the
            k-th user is paired with the k-th item.

    Returns:
        One score per pair, shaped like the ids (``[batch_size]`` for
        1-D input), including when ``batch_size == 1``.
    """
    u_e = self.user_embeddings(u_ids)  # [batch_size, embedding_size]
    i_e = self.item_embeddings(i_ids)  # [batch_size, embedding_size]
    # Row-wise dot product. The previous bmm(...).squeeze() removed *every*
    # size-1 dim, so a batch of one collapsed to a 0-d scalar; summing over
    # the embedding axis always keeps the batch dimension.
    return (u_e * i_e).sum(dim=-1)  # [batch_size]
posted on 2020-08-19 10:45 镸大的代价纯真都融化 阅读(80) 评论(0) 收藏 举报
浙公网安备 33010602011771号