SAITS
SAITS
MIT(掩码插补任务)
- 目的 :通过在训练时隐藏部分已知观测值,教会模型如何预测缺失值。
- 掩码方式 :手动 / 人工掩码——随机掩藏一定比例(例如 20%)的观测条目。
- 损失计算 :仅针对人工掩码条目计算损失(即真实值已知但暂时隐藏的位置)。
ORT(观测重建任务)
- 目的 :确保模型不会降低已观测值的质量,即模型应保留已知的内容。
- 掩码方式 :不进行人工掩码,使用原始输入,其中包含真实的缺失值和观测值。
- 损失计算 :仅针对原始观测条目计算损失——如果模型输出与这些已知值有偏差,就会受到惩罚。
总结对比
| 方面 | MIT | ORT |
|---|---|---|
| 是否使用人工掩码 | ✅ 是 | ❌ 否 |
| 是否针对观测值 | ✅(暂时隐藏) | ✅(原始观测值) |
| 是否针对真实缺失值 | ✅(通过上下文间接影响) | ❌(忽略缺失条目) |
| 目的 | 教会模型如何插补 | 鼓励重建保真度 |
总结来说:
- MIT = “你能猜出这个隐藏的值吗?”
- ORT = “不要弄乱你已经知道的内容。”
需要展示训练流程图或伪训练循环,以展示 MIT 和 ORT 如何结合使用吗?
以下是关于 SAITS 中哪些关键组件对提高插补准确率贡献最大的内容翻译:
一、SAITS 中对准确率贡献最大的关键组件
| 排名 | 组件名称 | 对准确率的贡献程度 | 描述 |
|---|---|---|---|
| ⭐ 1 | 对角掩码自注意力(DMSA) | 至关重要 | 阻止每个时间步对自身进行注意力计算,迫使模型依据上下文(其他时间步)来推断缺失值。这是该模型的核心创新点。 |
| ⭐ 2 | 双重 DMSA 模块 | 高 | 堆叠的两个 DMSA 模块有助于学习更深的时序依赖关系和特征交互。每个模块旨在为插补提供不同的视角。 |
| ⭐ 3 | 加权组合层 | 高 | 使用学习到的权重(注意力分数)融合两个 DMSA 模块的输出,有效整合多视角信息。能够自适应地选择更优的插补结果。 |
| ⭐ 4 | 掩码插补任务(MIT) | 中 - 高 | 通过在训练期间隐藏已知值并学习恢复它们,促使模型提升缺失值预测能力。提供了强有力的监督。 |
| ⭐ 5 | 观测重建任务(ORT) | 中 | 确保模型保留观测数据,防止过度拟合到仅缺失的条目。有助于稳定训练过程。 |
| ✅ 6 | 位置嵌入 | 中 | 添加时间顺序信息,帮助模型区分不同的时间步。 |
二、削减实验的证据(来自论文)
- 移除 DMSA → 性能大幅下降。
- 仅使用 1 个 DMSA 模块 → 表现不如使用 2 个。
- 没有加权组合层 → 性能变差。
- 仅使用 MIT 或 ORT 中的一种任务 → 表现不如二者结合使用。
三、总结:关注重点
为了提高准确率(或对 SAITS 进行改进或适应性调整):
- DMSA 模块 是最为关键的。
- 使用多个 DMSA 模块并采用学习融合方式(加权组合)。
- 同时保留MIT 和 ORT 损失函数,以获得更好的泛化能力和鲁棒性。
需要根据这些见解提出关于如何进一步改进或修改 SAITS 的建议吗?
好的,以下是关于 MIT 和 ORT 的内容翻译:
一、MIT 和 ORT 是什么?
(一)MIT:掩码插补任务
- 目标 :通过在训练期间隐藏一些已观测数据,迫使模型学习如何填补缺失值。
- 工作原理 :
- 以一个本身存在部分缺失值的输入序列为起点。
- 随机掩码(隐藏)额外的观测值,例如,掩藏 20% 的观测值。
- 让模型对所有缺失值(原始缺失的和人为掩码的)进行预测。
- 仅在人为掩码的值上计算损失,引导模型训练。
- 损失函数 :在人为掩码的条目上计算绝对平均误差(MAE)
\(L_{MIT} = \text{MAE}(\hat{X}_c, X, I)\)
其中,I 是人为掩码值的二元指示符。
(二)ORT:观测重建任务
- 目标 :帮助模型保留对原始观测值的保真度。
- 工作原理 :
- 以包含观测值和缺失值的输入序列为起点。
- 在模型填补缺失值后,还对观测部分进行重建。
- 仅在原始观测值上计算损失。
- 损失函数 :在原始观测条目上计算绝对平均误差
\(L_{ORT} = \frac{1}{3} \left[ \text{MAE}(\tilde{X}_1, X, \hat{M}) + \text{MAE}(\tilde{X}_2, X, \hat{M}) + \text{MAE}(\tilde{X}_3, X, \hat{M}) \right]\)
其中,M^ 是缺失掩码,X1、X2、X~3 是不同层的输出。
二、MIT 和 ORT 的关键组件
| 组件名称 | 作用 |
|---|---|
X(原始输入) |
带有部分真实缺失值的原始多变量时间序列数据。 |
M(缺失掩码) |
二元矩阵:观测值为 1,原始缺失值为 0。 |
X̂(掩码输入) |
应用人工掩码后的输入(用于 MIT)。 |
I(指示掩码) |
二元矩阵:人为掩码值为 1,其他位置为 0。 |
X̂_c(完整输出) |
经过模型处理后的最终插补版本的输入。 |
| MAE 损失(用于 MIT 和 ORT) | 通过最小化已知真实值上的误差来训练模型。 |
三、为什么要同时使用 MIT 和 ORT?
- MIT 教会模型如何通过创建带有已知答案的受控 “测试用例” 来预测缺失值。
- ORT 确保模型不会忘记或损坏数据的 观测部分。
- 二者结合 ,使得模型能够:
- 通过 MIT 学习现实的插补方法。
- 通过 ORT 保持与真实观测数据的一致性。
需要展示训练中这两种损失是如何结合的伪代码或示意图吗?
以下是该论文的简要介绍:
任务:时间序列插补
目标:填补多变量时间序列数据中的缺失值,应对现实世界数据因传感器故障或不规则采样导致的缺失问题。
输入与输出数据:
- 输入 :多变量时间序列矩阵 X ∈ ℝT×D(T 为时间步数,D 为特征数,X 中存在缺失值,对应的掩码矩阵 M ∈ {0,1}T×D 用于指示观测值(1)或缺失值(0))。
- 输出 :对 X 中缺失值进行填补后的矩阵 X̂c ∈ ℝT×D 。
框架概览:SAITS
具有新颖的对角掩码自注意力(DMSA)模块的Transformer风格架构:
核心组件:
- 两个DMSA模块 :利用对角掩码多头自注意力,阻止时间步对自身进行注意力计算,促使模型基于上下文推断缺失值。
- 加权组合层 :依据缺失掩码及内部注意力分数学习注意力权重,融合两个DMSA模块的输出。
- 联合训练目标 :
- MIT(掩码插补任务) :遮蔽部分已知观测值,训练模型预测这些被遮蔽的值。
- ORT(观测重建任务) :重建原始观测值。
- 总损失函数:L = LORT + λL MIT 。
使用的数据集:
- PhysioNet-2012 :ICU临床数据(37个特征、48个时间步、80%数据缺失)。
- Air-Quality :跨越24个时间步的132个空气污染物特征。
- Electricity :370个电力读数,100个时间步。
- ETT :变压器温度数据(7个特征、24个时间步)。
需要展示SAITS架构或DMSA机制的示意图吗?
GP-Graph: Learning Pedestrian Group Representations for Multi-modal Trajectory Prediction
以下是该论文的简要介绍:
任务定义:行人轨迹预测
目标:通过学习个体行为和群体行为,在拥挤环境中预测行人符合社交规范的未来轨迹。
输入数据:
- 观测轨迹 :Xn = {(x t n, y t n) | t ∈ [1, ..., T obs ]} → 每个行人在过去时间步的二维坐标(例如,8帧=3.2秒)。
- 图结构 :
- 节点 :每个行人。
- 边 :社交互动(基于邻近性或相似性)。
输出数据:
- 预测轨迹 :Y^n = {(x t n, y t n) | t ∈ [T obs +1, ..., T pred ]} → 未来时间步的位置(例如,12帧=4.8秒)。
- 多模态预测 :对多种可能的未来情况的概率分布。
框架:GP-Graph架构
1. 群体分配模块
- 通过可学习的距离阈值,学习哪些行人属于同一个行为群体。
- 使用直通估计器实现离散群体分配的反向传播。
2. 图层次结构构建
- 组内图 :仅连接同一组成员之间的边。
- 组间图 :节点代表整个群体(通过池化),边捕获群体间的互动。
- 个体图 :个体之间的常规图。
3. 群体池化与解池化
- 池化 :将多个行人聚合到一个节点中(群体级表示)。
- 解池化 :利用共享的群体特征恢复个体节点级预测。
4. 轨迹预测
- 对所有 3 个图(个体图、组内图、组间图)应用共享预测模型。
- 通过群体融合模块融合输出,生成最终预测。
5. 群体级潜在向量采样
- 潜在向量引导多模态预测。
- 群体内的潜在向量共享,保持群体行为的一致性,同时允许群体间的多样性。
使用的数据集:
- ETH / UCY (留一法交叉验证)
- 斯坦福无人机数据集(SDD)
- Grand Central Station(GCS)
评估指标:
- ADE (平均位移误差)
- FDE (最终位移误差)
- COL (碰撞率)
- TCC (时间相关系数)
需要对框架组件进行总结性的示意图展示吗?
好的,以下是关于 “Unpooling” 以及 GP-Graph 管道的完整介绍:
一、“Unpooling” 是什么?
-
定义 :在 GP-Graph 中,“Unpooling” 是指从群体级别的池化特征中恢复出行人的个体特征。
-
动机 :
- 经过群体池化后,每个群体由一个单一的特征向量表示(例如,其成员的平均值)。
- 但轨迹预测最终必须针对每个个体进行。
- 因此,需要将群体特征广播回所有个体成员。
-
工作原理 :
- 设 Zk 为群体 Gk 的群体特征,Xn 为行人 n 的恢复特征。
- 则对于属于群体 Gk 的行人 n,Xn = Zk。
- 这意味着:
- 群体中的每个成员都会复制相同的群体级特征向量 Zk。
- 这个共享特征在预测过程中反映了群体的上下文。
二、预测管道 —— 三个图如何协同工作?
-
管道概览 :
[输入:过去轨迹 X] ↓ [群体分配模块] → 获取群体成员资格 ↓ [构建 3 个图]: 1. 个体图(G_ped) 2. 组内图(G_member) 3. 群间图(G_group,通过池化得到) ↓ [对每个图应用共享的 GNN / Transformer 模型] ↓ [对群间图输出进行 Unpool 操作(恢复到个体级别)] ↓ [融合所有三个输出 → 群体融合模块] ↓ [最终预测:Ŷ → 未来轨迹的概率图] -
三个图的使用详情 :
| 图类型 | 节点 | 边 | 用途 |
|----------|------|------|----------------------------------------------------|
| 个体图 | 单个行人 | 完全成对连接 | 捕捉人与人之间的直接交互 |
| 组内图 | 同上 | 仅组内边 | 建模组内动态(跟随、避开等) |
| 群间图 | 群体节点 | 群体之间 | 建模不同群体之间的高层行为 | -
共享预测模型 :
- 对每个图应用相同的 GNN/Transformer(共享权重)。
- 这确保了:
- 更少的参数量
- 一致的嵌入空间
- 多尺度推理(个体和群体)
-
群体融合模块 :
- 接收以下输出嵌入:
- 个体图预测头
- 组内图预测头
- 群间图预测头(Unpool 后)
- 融合它们(例如,通过连接或注意力机制)
- 输出最终轨迹分布 Ŷ
- 接收以下输出嵌入:
-
示例张量流程 :
- 输入 X:
X.shape = [N, T_obs, 2](N 个行人,T_obs 个观测步,2D 坐标) - 群体分配:输出群体索引
- 池化:
Z.shape = [K, d](K 个群体,d 特征维) - 图传递:
- 个体图:
GNN(X, G_ped) → H_ped - 组内图:
GNN(X, G_member) → H_member - 群间图:
GNN(Z, G_group) → H_group→ Unpool →H_group_upscaled
- 融合:
\[H_{\text{fused}} = \text{Fuse}(H_{\text{ped}}, H_{\text{member}}, H_{\text{group\_upscaled}}) \]- 预测:
\[\hat{Y} = \text{Decoder}(H_{\text{fused}}) \] - 输入 X:
需要该管道的示意图来更清晰地展示流程吗?
group的思路
You're asking the right questions. Let’s go step by step to fully understand what GP-Graph is doing and how graph, group, and trajectory prediction all tie together. I’ll explain it from the ground up.
✅ Big Picture: What is GP-Graph trying to do?
GP-Graph wants to predict the future trajectories of pedestrians in a crowded scene by understanding:
- how individuals behave
- how groups behave
- how groups interact with other groups
🧭 Step-by-Step Explanation of the Pipeline
🔹 Step 1: Input – Observed Trajectories
-
For each person \(i\), we have:
\[\text{Trajectory}_i = \{(x_1^i, y_1^i), ..., (x_{T_{obs}}^i, y_{T_{obs}}^i)\} \] -
These are the past 2D positions over time (e.g., 8 time steps).
✅ So input = N trajectories from a frame at a time (e.g., 8-step observed).
🔹 Step 2: Group Assignment
-
The model uses a learnable rule (based on proximity and similarity) to cluster pedestrians into groups.
-
Example:
- Group 1: Person 1, 2, 3
- Group 2: Person 4, 5
- Group 3: Person 6
✅ Now the crowd is divided into groups, but this assignment is learned and dynamic.
🔹 Step 3: Build Graphs
Now we use graph neural networks (GNNs) to model interactions.
But why use graphs?
Because we want to model interactions: who influences who, either individually or in a group.
Three graphs are built:
| Graph | Nodes | Edges | Purpose |
|---|---|---|---|
| Agent-wise | Individual people | All-to-all or nearby pedestrians | Local social interaction |
| Intra-group | Individuals in group | Only within the same group | Capture tight cooperation in the group |
| Inter-group | Groups (after pooling) | Between groups | Understand how whole groups affect each other |
🔹 Step 4: Encoding with GNN
Each of the three graphs is passed through the same GNN module (shared weights):
GNN(Graph) → Embeddings
You now have:
- H_ped: from agent-wise graph
- H_intra: from intra-group graph
- H_group: from inter-group graph
But H_group is still group-level, so we need:
🔹 Step 5: Unpooling
Unpooling means:
For each individual in a group, copy the group embedding to them.
Example:
- Group 1 embedding =
Z_1 - All members of Group 1 (e.g., 3 people) get
Z_1as their feature
🔹 Step 6: Fusion
Now each person has:
- Agent-wise feature
- Intra-group feature
- Unpooled inter-group feature
These are fused (e.g., concatenated or added) into one final embedding per person:
🔹 Step 7: Predict Future Trajectory
Finally, for each person \(i\), the model outputs:
This prediction can be:
- Multi-modal: multiple possible futures
- Sampled from a latent vector (shared within a group)
🧠 Summary Pipeline (In Your Words)
1. Input all observed trajectories
2. Split them into groups (learned dynamically)
3. Build 3 graphs:
a. agent-wise (individuals)
b. intra-group (within groups)
c. inter-group (between group embeddings)
4. Apply GNNs to extract interaction-aware embeddings
5. Unpool group-level embeddings back to individuals
6. Fuse all embeddings
7. Predict each trajectory’s future
Would you like me to draw a diagram of this pipeline to make it even clearer?

浙公网安备 33010602011771号