GlenTt

导航

PLE模型简洁解读

PLE模型简洁解读

image

基础设定

  • 有 2 个任务:CTR、CVR
  • 使用 1 层 PLE(num_levels = 1)
  • 每个任务 2 个任务特定专家(specific_expert_num = 2)
  • 有 1 个共享专家(shared_expert_num = 1)
  • 输入 embedding 是:[batch_size, 64] 的拼接向量

我们来看看“这一层”里的每一个步骤数据是如何流动的。


第 1 步:准备输入

ple_inputs = [x_ctr, x_cvr, x_shared]
  • x_ctr = CTR 的输入 = 原始 embedding 向量 [B, 64]
  • x_cvr = CVR 的输入 = 同上
  • x_shared = Shared 的输入 = 同上

注意:这三个向量在第 1 层是一样的,但在后续层会变得不同。


第 2 步:任务专家和共享专家网络

每个任务的 experts:

每个任务有 2 个 specific expert,输入是自己:

  • CTR 的两个专家 → 输入 x_ctr → 输出 [B, 64]
  • CVR 的两个专家 → 输入 x_cvr → 输出 [B, 64]

共享 experts:

只有 1 个共享专家,输入是 x_shared,输出 [B, 64]


第 3 步:Gate 网络

我们看 CTR 任务的 gate 是怎么处理的:

CTR 的 gate 做了什么?

  1. 输入: x_ctr → shape [B, 64]

  2. 过一个小 DNN: 输出变为 [B, H](比如 H=32)

  3. 线性变换 + softmax: 输出为 [B, 3],表示对 3 个专家的权重:

    • expert_1_ctr
    • expert_2_ctr
    • expert_shared
gate_input = DNN(...)(x_ctr)   # [B, 32]
gate_weights = Dense(3, activation='softmax')(gate_input)  # [B, 3]

第 4 步:Gate × Experts

将所有专家输出堆叠:

expert_outputs = tf.stack([expert_1_ctr, expert_2_ctr, expert_shared], axis=1)  # [B, 3, 64]

将 gate 权重 reshape:

gate_weights = tf.expand_dims(gate_weights, -1)  # [B, 3, 1]

点乘加权求和:

fused_output = tf.reduce_sum(expert_outputs * gate_weights, axis=1)  # [B, 64]

✅ 这就是 CTR 任务在这一层提取到的特征,来自自己和共享专家的动态组合。

CVR 任务也完全一样,只是换成用 x_cvr 输入,构建自己的 gate 和专家融合。


下一层(若存在):

然后这些输出(fused_output_ctr, fused_output_cvr, fused_output_shared)会作为下一层的输入,继续重复这一机制。

每一层都会重新生成:

  • 专家网络(不同任务分开)
  • gate(使用该层输入为条件)

从而实现「逐层提纯」。


Gate 的本质:

Gate 是一个小的 DNN 网络,输入是当前任务的 embedding,输出是对所有专家的 softmax 权重
决定了“这个任务现在要听谁的话”


PLE的pytorch实现

class Expert(nn.Module):
    def __init__(self, input_dim, expert_dim):
        super(Expert, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dim, expert_dim),
            nn.ReLU(),
            nn.BatchNorm1d(expert_dim),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        return self.layer(x)

class Gate(nn.Module):
    def __init__(self, input_dim, n_experts):
        super(Gate, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(input_dim, n_experts),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        weights = self.gate(x)  # [B, n_experts]
        return weights.unsqueeze(-1)  # [B, n_experts, 1]

class PLELayer(nn.Module):
    def __init__(self, input_dim, expert_dim, n_tasks, n_task_experts, n_shared_experts):
        super(PLELayer, self).__init__()
        self.n_tasks = n_tasks
        self.task_experts = nn.ModuleList([
            nn.ModuleList([Expert(input_dim, expert_dim) for _ in range(n_task_experts)])
            for _ in range(n_tasks)
        ])
        self.shared_experts = nn.ModuleList([
            Expert(input_dim, expert_dim) for _ in range(n_shared_experts)
        ])
        self.task_gates = nn.ModuleList([
            Gate(input_dim, n_task_experts + n_shared_experts)
            for _ in range(n_tasks)
        ])
        self.shared_gate = Gate(input_dim, n_tasks * n_task_experts + n_shared_experts)

    def forward(self, task_inputs, shared_input):
        # Compute expert outputs
        task_outputs = []
        for i in range(self.n_tasks):
            task_outputs.append([expert(task_inputs[i]) for expert in self.task_experts[i]])
        shared_outputs = [expert(shared_input) for expert in self.shared_experts]

        # Task-specific gate outputs
        next_task_inputs = []
        for i in range(self.n_tasks):
            all_expert_outputs = task_outputs[i] + shared_outputs
            stacked = torch.stack(all_expert_outputs, dim=1)  # [B, n_experts, D]
            weights = self.task_gates[i](task_inputs[i])      # [B, n_experts, 1]
            fused = torch.sum(stacked * weights, dim=1)       # [B, D]
            next_task_inputs.append(fused)

        # Shared gate output (for next layer's shared input)
        flat_all_experts = sum(task_outputs, []) + shared_outputs
        stacked_shared = torch.stack(flat_all_experts, dim=1)
        shared_weights = self.shared_gate(shared_input)
        next_shared_input = torch.sum(stacked_shared * shared_weights, dim=1)  # [B, D]

        return next_task_inputs, next_shared_input

class PLE(nn.Module):
    # 正确处理多层维度
    def __init__(self, input_dim, expert_dim, n_tasks=3, n_layers=2,
                 n_task_experts=2, n_shared_experts=1):
        super(PLE, self).__init__()
        self.n_tasks = n_tasks
        self.ple_layers = nn.ModuleList()
        
        # 为每一层设置正确的输入维度
        for layer_idx in range(n_layers):
            if layer_idx == 0:
                # 第一层:使用原始输入维度
                current_input_dim = input_dim
            else:
                # 后续层:使用expert输出维度作为输入
                current_input_dim = expert_dim
                
            self.ple_layers.append(
                PLELayer(
                    input_dim=current_input_dim,  # 动态设置输入维度
                    expert_dim=expert_dim,
                    n_tasks=n_tasks,
                    n_task_experts=n_task_experts,
                    n_shared_experts=n_shared_experts
                )
            )

    def forward(self, x):
        # Initial input: shared across all tasks and shared experts
        task_inputs = [x for _ in range(self.n_tasks)]
        shared_input = x

        for layer in self.ple_layers:
            task_inputs, shared_input = layer(task_inputs, shared_input)

        return task_inputs  # final task-specific vectors [task1_repr, task2_repr, task3_repr] 

posted on 2025-08-07 20:10  GRITJW  阅读(128)  评论(0)    收藏  举报