# ESMM model implemented in native TensorFlow (原生TensorFlow实现ESMM模型)

import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import matplotlib
# Switch matplotlib to the Agg backend (the original note says this avoids
# Chinese font problems).
matplotlib.use('Agg')
# Do not rely on Chinese fonts: use DejaVu Sans and disable the Unicode
# minus sign on the axes.
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans'],
    'axes.unicode_minus': False,
})

# 1. Prepare the data
# Simulated e-commerce log: user / item / category ids, price and item age,
# together with click and purchase labels.
np.random.seed(42)
n_samples = 100000


def _sigmoid(logit):
    """Logistic function mapping a real-valued score to a probability."""
    return 1 / (1 + np.exp(-logit))


# Raw features. The order of the random draws fixes the random stream, so
# these statements must stay in this order.
user_ids = np.random.randint(1, 10001, size=n_samples)
item_ids = np.random.randint(1, 50001, size=n_samples)
category_ids = np.random.randint(1, 101, size=n_samples)
price = np.random.uniform(1, 1000, size=n_samples)
item_age_days = np.random.randint(1, 365, size=n_samples)

# Log-features shared by the click model and the purchase model.
_log_user = np.log(user_ids)
_log_item = np.log(item_ids)
_log_category = np.log(category_ids)
_log_price = np.log(price)
_log_age = np.log(item_age_days)

# Click labels: Bernoulli draws from a logistic model over the log-features
# plus Gaussian noise. The realised click rate is printed below.
_click_logit = (
    0.1 * _log_user
    - 0.05 * _log_item
    + 0.02 * _log_category
    - 0.1 * _log_price
    - 0.05 * _log_age
    + np.random.normal(0, 0.1, n_samples)
)
p_click = _sigmoid(_click_logit)
clicks = np.random.binomial(1, p_click)

# Purchase labels: same construction with different coefficients, but a
# purchase is only drawn for samples that were clicked.
_buy_logit = (
    0.05 * _log_user
    - 0.1 * _log_item
    + 0.05 * _log_category
    - 0.2 * _log_price
    - 0.02 * _log_age
    + np.random.normal(0, 0.1, n_samples)
)
p_buy_given_click = _sigmoid(_buy_logit)
_clicked_mask = clicks == 1
purchases = np.zeros_like(clicks)
purchases[_clicked_mask] = np.random.binomial(1, p_buy_given_click[_clicked_mask])

# Assemble the dataset.
data = pd.DataFrame(dict(
    user_id=user_ids,
    item_id=item_ids,
    category_id=category_ids,
    price=price,
    item_age_days=item_age_days,
    is_clicked=clicks,
    is_purchased=purchases,
))

# CTCVR label: the sample was clicked AND purchased.
data['ctcvr'] = data['is_clicked'].mul(data['is_purchased'])

# Summary statistics of the generated labels.
_n_clicked = data['is_clicked'].sum()
_n_purchased = data['is_purchased'].sum()
_click_rate = data['is_clicked'].mean()
_conditional_cvr = data.loc[data['is_clicked'] == 1, 'is_purchased'].mean()
_ctcvr_rate = data['ctcvr'].mean()

print(f"Dataset size: {len(data)}")
print(f"Click samples: {_n_clicked}")
print(f"Purchase samples: {_n_purchased}")
print(f"Click rate: {_click_rate:.4f}")
print(f"Conditional conversion rate: {_conditional_cvr:.4f}")
print(f"CTCVR: {_ctcvr_rate:.4f}")

# 2. Feature processing
# Label-encode the id features; every fitted encoder is kept in `encoders`
# under the name of the column it was fitted on.
categorical_cols = ['user_id', 'item_id', 'category_id']
encoders = {col: LabelEncoder() for col in categorical_cols}
for col in categorical_cols:
    data[f"{col}_encoded"] = encoders[col].fit_transform(data[col])

# Turn the numeric features into categorical ones with quantile bins
# (10 bins for price, 5 for item age).
for _source_col, _bin_col, _n_bins in (
    ('price', 'price_bin', 10),
    ('item_age_days', 'item_age_bin', 5),
):
    data[_bin_col] = pd.qcut(data[_source_col], _n_bins, labels=False)

# Label-encode the bin indices the same way as the id features.
for _bin_col in ('price_bin', 'item_age_bin'):
    encoders[_bin_col] = LabelEncoder()
    data[f"{_bin_col}_encoded"] = encoders[_bin_col].fit_transform(data[_bin_col])

# 3. Split into training and test sets
# The model input is the matrix of encoded features, one column per feature.
feature_cols = [
    f"{name}_encoded"
    for name in ('user_id', 'item_id', 'category_id', 'price_bin', 'item_age_bin')
]
X = data[feature_cols].to_numpy()
# Both label vectors are reshaped into single-column matrices.
y_ctr = data['is_clicked'].to_numpy().reshape(-1, 1)
y_ctcvr = data['ctcvr'].to_numpy().reshape(-1, 1)

# One call splits features and both label arrays consistently (80% / 20%).
(
    X_train, X_test,
    y_ctr_train, y_ctr_test,
    y_ctcvr_train, y_ctcvr_test,
) = train_test_split(X, y_ctr, y_ctcvr, test_size=0.2, random_state=42)

# 4. Define the ESMM model
class ESMM(tf.keras.Model):
    """Entire Space Multi-task Model for joint CTR / CTCVR prediction.

    Every categorical feature is embedded, the embeddings are concatenated
    and fed through a shared MLP, and two task towers sit on top: one
    predicts pCTR, the other pCVR. Following the ESMM formulation, the
    second model output is the product ``pCTCVR = pCTR * pCVR``, so pCTCVR
    never exceeds pCTR and a CVR recovered as ``pCTCVR / pCTR`` stays
    within [0, 1]. The conversion tower is exposed as ``cvr_tower``.

    Args:
        vocab_sizes: Number of distinct ids of each categorical feature, in
            the column order of the model input.
        embed_dims: Embedding width of each feature; must have the same
            length as ``vocab_sizes``.
        dnn_hidden_units: Widths of the dense layers of the shared bottom.
        dnn_activation: Activation of the shared dense layers and of the
            hidden layer of each tower.
        dnn_dropout: Dropout rate applied after every shared dense layer;
            no dropout layer is created when it is 0.

    Raises:
        ValueError: If ``vocab_sizes`` and ``embed_dims`` differ in length.
    """

    def __init__(self, vocab_sizes, embed_dims, dnn_hidden_units=(64, 32), dnn_activation='relu', dnn_dropout=0.0):
        super().__init__()

        vocab_sizes = list(vocab_sizes)
        embed_dims = list(embed_dims)
        # zip() would silently drop the surplus entries of the longer list.
        if len(vocab_sizes) != len(embed_dims):
            raise ValueError(
                f"vocab_sizes has {len(vocab_sizes)} entries but embed_dims has {len(embed_dims)}"
            )

        # One embedding table per input column.
        self.embedding_layers = [
            tf.keras.layers.Embedding(
                input_dim=vocab_size,
                output_dim=embed_dim,
                name=f'embedding_{i}'
            )
            for i, (vocab_size, embed_dim) in enumerate(zip(vocab_sizes, embed_dims))
        ]

        # Shared bottom network used by both tasks.
        self.shared_bottom = tf.keras.Sequential(name='shared_bottom')
        for units in dnn_hidden_units:
            self.shared_bottom.add(tf.keras.layers.Dense(units, activation=dnn_activation))
            if dnn_dropout > 0:
                self.shared_bottom.add(tf.keras.layers.Dropout(dnn_dropout))

        # CTR tower: P(click | impression).
        self.ctr_tower = tf.keras.Sequential(name='ctr_tower')
        self.ctr_tower.add(tf.keras.layers.Dense(32, activation=dnn_activation))
        self.ctr_tower.add(tf.keras.layers.Dense(1, activation='sigmoid', name='ctr_output'))

        # CVR tower: P(purchase | click). It has no loss of its own; it is
        # trained through the CTCVR product computed in call().
        self.cvr_tower = tf.keras.Sequential(name='cvr_tower')
        self.cvr_tower.add(tf.keras.layers.Dense(32, activation=dnn_activation))
        self.cvr_tower.add(tf.keras.layers.Dense(1, activation='sigmoid', name='cvr_output'))

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: 2-D tensor of encoded feature ids; column ``i`` is
                looked up in the ``i``-th embedding table.
            training: Whether the call happens in training mode. It is
                forwarded to the sub-networks so that Dropout is only
                active while training.

        Returns:
            Tuple ``(ctr_output, ctcvr_output)``, each with one column per
            sample, where ``ctcvr_output = ctr_output * pCVR``.
        """
        # Look up every feature column in its own embedding table.
        embeddings = [
            embedding_layer(inputs[:, i])
            for i, embedding_layer in enumerate(self.embedding_layers)
        ]

        # Concatenate the per-feature embeddings along the feature axis.
        concat_embeddings = tf.concat(embeddings, axis=1)

        # Pass `training` down explicitly instead of relying on implicit
        # propagation, so the Dropout layers follow the fit/predict mode.
        shared_output = self.shared_bottom(concat_embeddings, training=training)

        # CTR task.
        ctr_output = self.ctr_tower(shared_output, training=training)

        # CVR task, combined with CTR over the entire impression space.
        cvr_output = self.cvr_tower(shared_output, training=training)
        ctcvr_output = ctr_output * cvr_output

        return ctr_output, ctcvr_output

# 5. Create the model
# Vocabulary size of every feature = number of classes seen by its encoder,
# listed in the same order as the columns of the model input.
vocab_sizes = [
    len(encoders[_feature].classes_)
    for _feature in ('user_id', 'item_id', 'category_id', 'price_bin', 'item_age_bin')
]

# Embedding width per feature, in the same order as `vocab_sizes`.
embed_dims = [32, 32, 16, 8, 8]

_model_config = dict(
    vocab_sizes=vocab_sizes,
    embed_dims=embed_dims,
    dnn_hidden_units=(128, 64, 32),
    dnn_activation='relu',
    dnn_dropout=0.2,
)
model = ESMM(**_model_config)

# 6. Compile the model
# One metric list per model output: AUC and binary accuracy for each task.
_ctr_metrics = [
    tf.keras.metrics.AUC(name='ctr_auc'),
    tf.keras.metrics.BinaryAccuracy(name='ctr_acc'),
]
_ctcvr_metrics = [
    tf.keras.metrics.AUC(name='ctcvr_auc'),
    tf.keras.metrics.BinaryAccuracy(name='ctcvr_acc'),
]

# Both outputs use binary cross-entropy with equal loss weights.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=['binary_crossentropy', 'binary_crossentropy'],
    loss_weights=[1.0, 1.0],
    metrics=[_ctr_metrics, _ctcvr_metrics],
)

# 7. Define the callbacks
# Stop when the validation loss has not improved for 3 epochs and roll back
# to the best weights.
_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True
)
# Halve the learning rate after 2 epochs without improvement, down to 1e-5.
_reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=2, min_lr=1e-5
)
# Write TensorBoard logs under ./logs/esmm.
_tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs/esmm')

callbacks = [_early_stopping, _reduce_lr, _tensorboard]

# 8. Train the model
batch_size = 1024
# Targets are passed in the same order as the model outputs: CTR, CTCVR.
_train_targets = [y_ctr_train, y_ctcvr_train]
# NOTE: the test split is also used as validation data, so it drives the
# early-stopping and learning-rate callbacks.
_validation_data = (X_test, [y_ctr_test, y_ctcvr_test])

history = model.fit(
    x=X_train,
    y=_train_targets,
    batch_size=batch_size,
    epochs=20,
    validation_data=_validation_data,
    callbacks=callbacks,
)

# 9. Evaluate the model
eval_results = model.evaluate(X_test, [y_ctr_test, y_ctcvr_test])
print("Test results:", eval_results)

# 10. Predict
ctr_preds, ctcvr_preds = model.predict(X_test)

# Derive CVR as pCTCVR / pCTR. The small constant avoids division by zero,
# and the clip keeps the ratio a valid probability: nothing in the division
# itself prevents the quotient from exceeding 1.
cvr_preds = np.clip(ctcvr_preds / (ctr_preds + 1e-8), 0.0, 1.0)

# 11. Result analysis
# The true CVR is only defined on clicked samples.
clicked_indices = y_ctr_test.flatten() == 1
true_cvr = y_ctcvr_test[clicked_indices] / y_ctr_test[clicked_indices]
pred_cvr_on_clicked = cvr_preds[clicked_indices]

# Measure the CVR prediction quality on the clicked samples.
from sklearn.metrics import roc_auc_score, log_loss, mean_squared_error

_clicked_labels = y_ctcvr_test[clicked_indices].ravel()
# ROC AUC is undefined unless both classes occur, so checking only that
# some clicked samples exist is not enough: roc_auc_score raises when the
# clicked subset contains a single label value.
if np.unique(_clicked_labels).size > 1:
    cvr_auc = roc_auc_score(_clicked_labels, pred_cvr_on_clicked.ravel())
    print(f"CVR AUC on clicked samples: {cvr_auc:.4f}")
else:
    print("CVR AUC on clicked samples: undefined (needs both purchased and non-purchased clicks)")

# 12. Visualise the training process
def _history_series(history_dict, metric, validation=False):
    """Return the per-epoch values of `metric` from a Keras history dict.

    The plain key (``metric`` or ``val_`` + ``metric``) is tried first. Some
    Keras versions prefix the metrics of multi-output models with the output
    name (for example ``output_1_ctr_auc``), so a key ending in
    ``_`` + ``metric`` on the requested train/validation side is accepted
    as a fallback.

    Raises:
        KeyError: If no matching key exists.
    """
    exact_key = ('val_' if validation else '') + metric
    if exact_key in history_dict:
        return history_dict[exact_key]
    for key, values in history_dict.items():
        if key.startswith('val_') == validation and key.endswith('_' + metric):
            return values
    raise KeyError(exact_key)


# (history key, legend suffix, panel title, y-axis label) for each panel.
_panels = (
    ('loss', 'Loss', 'Model Loss', 'Loss'),
    ('ctr_auc', 'CTR AUC', 'CTR AUC', 'AUC'),
    ('ctcvr_auc', 'CTCVR AUC', 'CTCVR AUC', 'AUC'),
)

plt.figure(figsize=(15, 5))
for _position, (_metric, _label, _title, _ylabel) in enumerate(_panels, start=1):
    plt.subplot(1, len(_panels), _position)
    plt.plot(_history_series(history.history, _metric), label=f'Train {_label}')
    plt.plot(
        _history_series(history.history, _metric, validation=True),
        label=f'Validation {_label}',
    )
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()

plt.tight_layout()
plt.savefig('esmm_training_history.png')
# The script selects the non-interactive Agg backend, where plt.show() only
# emits a warning; the figure is written to disk and then released instead.
plt.close()

# 13. Save the model weights (file name fixed: per the original note it must
# end with ".weights.h5").
_weights_path = 'esmm_model_weights.weights.h5'
model.save_weights(_weights_path)

# 14. Model usage example - batch prediction
def _encode_known_or_zero(encoder, values):
    """Map raw values to their index in ``encoder.classes_``.

    Values that are not among the fitted classes map to index 0, which is
    also the index of the first fitted class. The lookup is a single
    vectorised binary search instead of one membership scan and one
    ``transform`` call per element.
    """
    classes = np.asarray(encoder.classes_)
    values = np.asarray(values)
    if values.size == 0 or classes.size == 0:
        return np.zeros(values.shape, dtype=np.int64)
    # Search through a sorting permutation so the result is an index into
    # classes_ even if classes_ were not stored in sorted order.
    order = np.argsort(classes, kind='stable')
    slots = np.searchsorted(classes, values, sorter=order)
    # Values larger than every class land one past the end; clamp them so
    # the equality check below can reject them.
    slots = np.minimum(slots, classes.size - 1)
    candidates = order[slots]
    known = classes[candidates] == values
    return np.where(known, candidates, 0).astype(np.int64)


def predict_batch(user_ids, item_ids, category_ids, price_bins, item_age_bins):
    """Predict CTR, CVR and CTCVR for a batch of raw (un-encoded) samples.

    Args:
        user_ids: Raw user ids, one per sample.
        item_ids: Raw item ids, one per sample.
        category_ids: Raw category ids, one per sample.
        price_bins: Price bin indices, one per sample.
        item_age_bins: Item-age bin indices, one per sample.

    Returns:
        Tuple ``(ctr, cvr, ctcvr)`` of flat arrays with one value per sample.
        CVR is derived as ``ctcvr / ctr`` and clipped to [0, 1]. For an empty
        batch three empty arrays are returned without calling the model.
    """
    # Encode every raw feature with the encoder fitted on the same column.
    raw_columns = (
        ('user_id', user_ids),
        ('item_id', item_ids),
        ('category_id', category_ids),
        ('price_bin', price_bins),
        ('item_age_bin', item_age_bins),
    )
    # Model input: one column per feature, in the training column order.
    inputs = np.column_stack([
        _encode_known_or_zero(encoders[name], values)
        for name, values in raw_columns
    ])

    # Nothing to predict for an empty batch.
    if inputs.shape[0] == 0:
        return (
            np.zeros(0, dtype=np.float32),
            np.zeros(0, dtype=np.float32),
            np.zeros(0, dtype=np.float32),
        )

    # Model prediction; the clip keeps the derived CVR a valid probability.
    ctr_preds, ctcvr_preds = model.predict(inputs)
    cvr_preds = np.clip(ctcvr_preds / (ctr_preds + 1e-8), 0.0, 1.0)

    return ctr_preds.flatten(), cvr_preds.flatten(), ctcvr_preds.flatten()

# Example: predict a batch of samples (the first five rows of the dataset).
_example_rows = data.iloc[:5]
sample_users = _example_rows['user_id'].values
sample_items = _example_rows['item_id'].values
sample_categories = _example_rows['category_id'].values
sample_price_bins = _example_rows['price_bin'].values
sample_item_age_bins = _example_rows['item_age_bin'].values

ctr_preds, cvr_preds, ctcvr_preds = predict_batch(
    sample_users,
    sample_items,
    sample_categories,
    sample_price_bins,
    sample_item_age_bins,
)

# Show the predictions next to the actual labels.
results_df = pd.DataFrame(dict(
    user_id=sample_users,
    item_id=sample_items,
    category_id=sample_categories,
    predicted_ctr=ctr_preds,
    predicted_cvr=cvr_preds,
    predicted_ctcvr=ctcvr_preds,
    actual_click=_example_rows['is_clicked'].values,
    actual_purchase=_example_rows['is_purchased'].values,
))

print("\nPrediction results example:")
print(results_df)
# Posted 2025-03-06 20:34 by zedliu (blog page footer: 57 views, 0 comments)