torch-rechub学习打卡笔记(一)

【Torch-RecHub 学习笔记】Task 1:环境搭建与 DSSM 召回实战

1. 任务背景

在推荐系统领域,高效的特征处理和模型框架是开展研究的基础。本次 Task 1 的核心目标是完成 Torch-RecHub 环境配置,并跑通基础的召回模型流程。


2. 环境准备

根据官方文档要求,确保系统满足以下基础配置:

  • Python: 3.9+
  • PyTorch: 1.7+(推荐 CUDA 版本)
  • 核心依赖: NumPy, Pandas, SciPy, Scikit-learn

3. 框架设计与 DSSM 模型

Torch-RecHub 采用模块化设计,将推荐任务抽象为 特征(Features)模型(Models)训练器(Trainers)

本次实验使用 DSSM(Deep Structured Semantic Model) 作为召回模型。DSSM 是一种经典的双塔结构,通过分别构建用户塔和物品塔,将二者映射到同一低维向量空间中,再利用向量相似度完成匹配。


4. 实验流程说明

整体实验流程如下:

  1. 数据加载与预处理(MovieLens-1M 采样数据)
  2. 类别特征编码(Label Encoding)
  3. 构建用户特征与物品特征
  4. 构造序列特征与训练样本
  5. 定义 DSSM 双塔模型
  6. 模型训练
  7. 用户与物品向量导出

5. 关键代码实现

5.1 数据加载与预处理

import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder

from torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import df_to_dict, MatchDataGenerator
from torch_rechub.utils.match import generate_seq_feature_match, gen_model_input

torch.manual_seed(2022)

# Load data
data_url = "https://raw.githubusercontent.com/datawhalechina/torch-rechub/main/examples/matching/data/ml-1m/ml-1m_sample.csv"
data = pd.read_csv(data_url)
print(f"Dataset size: {len(data)} records")

# Category feature
data["cate_id"] = data["genres"].apply(lambda x: x.split("|")[0])

5.2 特征编码

user_col, item_col = "user_id", "movie_id"
sparse_features = ["user_id", "movie_id", "gender", "age", "occupation", "zip", "cate_id"]

feature_max_idx = {}
for feat in sparse_features:
    encoder = LabelEncoder()
    data[feat] = encoder.fit_transform(data[feat]) + 1
    feature_max_idx[feat] = data[feat].max() + 1

5.3 用户与物品画像构建

user_cols = ["user_id", "gender", "age", "occupation", "zip"]
item_cols = ["movie_id", "cate_id"]

user_profile = data[user_cols].drop_duplicates("user_id")
item_profile = data[item_cols].drop_duplicates("movie_id")

5.4 序列特征与训练数据生成

df_train, df_test = generate_seq_feature_match(
    data,
    user_col,
    item_col,
    time_col="timestamp",
    item_attribute_cols=[],
    sample_method=1,
    mode=0,
    neg_ratio=3,
    min_item=0
)

x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=50)
y_train = x_train["label"]

x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=50)

5.5 特征类型定义(重点修正)

user_features = [
    SparseFeature(name, vocab_size=feature_max_idx[name], embed_dim=16)
    for name in user_cols
]

user_features += [
    SequenceFeature(
        "hist_movie_id",
        vocab_size=feature_max_idx["movie_id"],
        embed_dim=16,
        pooling="mean",
        shared_with="movie_id"
    )
]

item_features = [
    SparseFeature(name, vocab_size=feature_max_idx[name], embed_dim=16)
    for name in item_cols
]

5.6 DataLoader 与模型定义

all_item = df_to_dict(item_profile)
test_user = x_test

dg = MatchDataGenerator(x=x_train, y=y_train)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)

model = DSSM(
    user_features,
    item_features,
    temperature=0.02,
    user_params={"dims": [128, 64, 32], "activation": "prelu"},
    item_params={"dims": [128, 64, 32], "activation": "prelu"},
)

5.7 模型训练与向量导出

trainer = MatchTrainer(
    model,
    mode=0,
    optimizer_params={"lr": 1e-4, "weight_decay": 1e-6},
    n_epoch=3,
    device="cpu",
)

trainer.fit(train_dl)

user_embedding = trainer.inference_embedding(model, mode="user", data_loader=test_dl, model_path="./")
item_embedding = trainer.inference_embedding(model, mode="item", data_loader=item_dl, model_path="./")

print(user_embedding.shape)
print(item_embedding.shape)

6. 运行结果与分析

  • 样本规模:共处理 100 条采样记录,生成训练集 384 条,测试集 2 条。
  • 向量维度:成功导出用户与物品向量,Embedding 维度均为 32。
User embedding shape: torch.Size([2, 32])
Item embedding shape: torch.Size([93, 32])

7. 学习总结

通过 Task 1 的学习,我完整跑通了从环境搭建、特征工程到 DSSM 模型推理的全过程。Torch-RecHub 在特征抽象与训练流程上的高度封装,使得实验实现更加清晰高效。后续计划将社区搜索与召回模型相结合,探索更复杂的交互式推荐场景。


posted @ 2026-02-10 18:17  oyller  阅读(4)  评论(0)    收藏  举报