offline meta-RL | 经典论文速读记录



也请参考:offline meta-RL | 近期工作速读记录


(MAML) Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks [ICML 2017]

(MACAW) Offline Meta-Reinforcement Learning with Advantage Weighting [ICML 2021]

主要内容:

  • 提出了 offline meta-RL 的 setting:offline multi-task 数据集 + 新任务的少量 offline 数据(<5 条轨迹)用于适应新任务。
  • method:增强版 AWR(一种 offline 方法)+ MAML。
    • 内核替换:将 MAML 的策略梯度换成 AWR 回归(离线友好);
    • 增强表达能力:简单 MAML + AWR 会失败,因为 AWR 梯度信息量不足。MACAW 增加优势回归头,让梯度能同时编码“动作该是什么”和“优势有多大”;
    • 架构升级:引入权重变换层,突破普通 MLP"秩1更新"的限制,让内循环更强大。
  • 实验环境:MuJoCo 的 cheetah-direction、cheetah-velocity、walker-params、ant-direction。
  • 很好奇它的 baseline 是怎么做的,meta-BC 和 multi-task offline RL with fine-tuning 是怎么做的。

(PEARL) Efficient Off-Policy Meta-Reinforcement Learning via Probabilistic Context Variables [ICML 2019]

VariBAD: A Very Good Method for Bayes-Adaptive Deep RL via Meta-Learning

setting:

  • 可以与 N 个 task online 交互,它们的 state 和 action space 是相同的,但 transition 和 reward 可能不同。
  • evaluate 时,给出一个新 task,希望用尽可能少的步数,得到一个在这个 task 上表现好的策略。
  • varibad 生成,使用贝叶斯思想,policy 可以十分智能:在对 task 的推测不够确定的情况下,policy 会主动探索,降低 task 推测的不确定性;当对 task 的推测足够确定,policy 则会 exploitation。

关于 BAMDP(Bayes-adaptive Markov decision process):

  • BAMDP 这个 formulation 好像适用于 meta-RL。
  • BAMDP 把 state 扩展为了 (s, b),其中 b 是对于 task 猜测的 belief。
  • BAMDP 的 transition 公式是(公式 1),

\[\begin{aligned} T^+(s_{t+1}^+ | s_t^+, a_t) = & \underbrace{\mathbb{E}_{b_t}[T(s_{t+1}|s_t,a_t)]}_{\text{期望转移}} \\ & \cdot \underbrace{\delta\big(b_{t+1} = \text{BayesUpdate}(b_t, s_t,a_t,r_t,s_{t+1})\big)}_{\text{确定性信念更新}} \end{aligned} \]

基于贝叶斯思想的 meta-RL vs. 迭代 \(z = q(z | c)\)\(a = \pi(a | s,z)\) 的 meta-RL(如 focal):

  • 后者的流程是,encoder 输出一个确定性或单点采样的 task representation z,encoder 的目标是最小化 task 识别误差。这样执行的策略 \(\pi(a∣s,z)\) 仅依赖于猜测的单个 task,而不包含猜测 task 的不确定性信息;若 z 猜错,再使用后续数据修正 z。最优策略与 task 推断是解耦的。
  • 前者的思想是,encoder 输出猜测 task 的完整后验分布(后验的意思是 已经有环境相关的 context \(\tau\) 信息了,基于这个 \(\tau\) 再猜 task),varibad 把这个分布建模成正态分布,包含均值 \(\mu\) 和方差 \(\Sigma\)。然后,策略的形式是 \(\pi(a|s, \mu, \Sigma)\),策略能知道当前对 task 的猜测是否不确定性比较大,并通过训练,学会主动降低 task 猜测的不确定性。

varibad 包含三个模块:

  • encoder(推断网络)\(q_\phi(m|\tau_{:t})\),输入是历史轨迹 \(\tau_{:t} = (s_0, a_0, r_1, ..., s_t)\),输出是 task 猜测 m 的后验分布 \(q_\phi(m|\tau_{:t}) = \mathcal{N}(\mu_t, \sigma_t^2)\)
    • 实现:历史序列喂入 GRU(应该是一种循环网络),最后一步隐状态映射到分布参数。
  • decoder(生成模型)\(p_\theta^T(s_{i+1} | s_i, a_i; m) , \; p_\theta^R(r_{i} | s_i, a_i,s_{i+1}; m)\),输入是采样的 task embedding \(m \sim q_\phi(m|\tau_{:t})\) + 动作序列 \(a_{:H^+-1}\),输出是整条轨迹(包括未来轨迹)的预测 \(p_\theta(\tau_{:H^+}|m) = \prod_i p(s_{i+1}|s_i,a_i,m) \cdot p(r_{i+1}|s_i,a_i,s_{i+1},m)\)
    • decoder 的角色类似于 world model,输入 (s,a),输出 (s',r)。
    • 这里有一个不能传梯度的 在正态分布中采样,用 reparameterization trick 让它变得能传梯度。
    • encoder 和 decoder 的联合训练,能让 m 通过预测未来的轨迹会怎么走,提取到本质的 task 信息,而非机械背诵历史信息(kimi 说的)。
    • 实现:输入是 m(5-dim)拼接 (s_t, a_t) → MLP(64-dim ReLU → 32-dim ReLU) → 分裂两输出头预测 s', r。
  • 策略 \(\pi_\psi(a_t | s_t, q_\phi(m|\tau_{:t}))\),输入是当前状态 \(s_t\) + 后验分布参数 \((\mu_t, \sigma_t)\),输出是动作分布 \(\pi_\psi(a_t|s_t, \mu_t, \sigma_t)\),跟正常的 policy 一样。
    • 实现:使用 on-policy 方法,虽然不知道为什么。

varibad 的优化目标:

总目标(公式 10),由 RL 目标和 encoder-decoder 目标组成:

\[\mathcal{L}(\phi,\theta,\psi) = \mathbb{E}_{p(M)}\left[\mathcal{J}(\psi,\phi) + \lambda \sum_{t=0}^{H^+} \text{ELBO}_t(\phi,\theta)\right] \]

RL 目标(A2C / PPO):(大概就是正常最大化真 reward 的 return 吧,而非 decoder 预测的 reward 的 return)

\[\mathcal{J}(\psi,\phi) = \mathbb{E}_{\tau \sim \pi_\psi, q_\phi}\left[\sum_{t=0}^{H^+-1} \gamma^t r_{t+1} - \alpha \cdot \text{policy entropy} + \beta \cdot \text{value loss}\right] \]

encoder-decoder 的 VAE reconstruction 目标(ELBO,公式 8,最大化下式)(第二项应该是让对 task 的猜测 m 不能变得太大,不太确定这一项是 ELBO 推出来的,还是训练稳定性的 trick):

\[\begin{aligned} \text{ELBO}_t(\phi,\theta) = & \underbrace{\mathbb{E}_{m \sim q_\phi(m|\tau_{:t})}\left[\sum_{i=0}^{H^+-1} \log p_\theta(s_{i+1}|s_i,a_i,m) + \log p_\theta(r_{i+1}|s_i,a_i,s_{i+1},m)\right]}_{\text{重建损失}} \\ & ~~ - \underbrace{\text{KL}\left(q_\phi(m|\tau_{:t}) \,\|\, q_\phi(m|\tau_{:t-1})\right)}_{\text{动态先验(初始为 } \mathcal{N}(0,I)\text{)}} \end{aligned} \]

实现细节:

  • 重建损失用 MSE 预测奖励,MSE 预测状态。
  • KL 散度对高斯分布有闭式解:\(\text{KL} = \frac{1}{2}\sum_d \left(\log\frac{\sigma_{t-1,d}^2}{\sigma_{t,d}^2} + \frac{\sigma_{t,d}^2 + (\mu_{t,d}-\mu_{t-1,d})^2}{\sigma_{t-1,d}^2} - 1\right)\)
  • 梯度计算:RL 损失不回传至编码器 \(\phi\),VAE 损失不回传至策略 \(\psi\)

(BOReL) Offline Meta Reinforcement Learning -- Identifiability Challenges and Effective Data Collection Strategies [NeurIPS 2021]

(第三节)或许可以将 BOReL 理解为 offline 版的 varibad:

  • 核心主张:在 offline 场景下,只要数据是完整轨迹,而非零散的 transition 样本,就可以通过状态重标注 trick,为每一个 (s,a,r,s') transition 标注 belief b,让标准的离线 Q-learning 算法(DQN / SAC)直接在信念增强的状态 (s,b) 上训练。
  • 为什么在概念上,直接扩展 VariBAD 会遇到问题?因为在 BAMDP 中,增强状态是 \(s^+\) = (s, b),其中信念 b 是不可直接观测的;同时,要计算 Q(\(s^+\), a) 需要知道 \(R^+\)\(P^+\)(BAMDP 的奖励/转移),但这些依赖未知的信念 b。(不确定这里真的理解对了)
  • relabel 的 trick:首先用所有 offline 轨迹训练 varibad 的 VAE encoder-decoder,然后对于每条轨迹,把轨迹 \(\tau_{:t}\) 塞进 encoder 里,就得到了 belief,从而可以把 s 扩展成 \(s^+\)。然后,直接拿这些扩展的 transition,做 DQN / SAC 就可以了。proposition 1 似乎就是证明,这样直接 relabel 的期望,跟如果知道真的 belief b 是一样的。

(第四节)MDP ambiguity 的问题:

  • kimi 的一句话总结:在 offline meta-RL 中,如果不同任务的数据,没有重叠的区分性状态-动作,VAE 将无法分辨数据来自哪个任务,导致信念估计失效。解决方案是策略回放或奖励重标注,人为制造重叠数据。
  • MDP ambiguity 的直观描述:想象有两个 task,一个 task 的 reward 在蓝色圆圈里,另一个 task 的 reward 在黄色圆圈里。然而在 offline 数据中,蓝 task 的 agent 只去蓝色区域,黄 task 的 agent 只去黄色区域。此时我们无法判断,这是两个不同 task,还是一个 task 有两个奖励点(蓝+黄)。borel 声称,这将导致 belief 无法收敛,基于 belief 不确定性的探索策略也失效。
    • 还没想清楚。这个具体失效的过程是什么?是 VAE 先分不出来这是两个不同 task,把它们的 task embedding m 学到了一起,然后 evaluate 的时候,发现 history \(\tau_{:t}\) 不落在学到的任何一种模式上,导致猜不出来正确的 belief 吗?
  • borel 形式化定义了 MDP ambiguity。Definition 1,MDP 模糊性 = 存在一个 MDP,能用不同策略解释多组(本来在不同 MDP 里采到的)离线数据。
  • Proposition 2:可识别性的充分条件。只要每对任务的 offline 数据集里,都有至少一个共同的区分性状态-动作,数据就是可识别的(identifiable)。区分性指的是,对于一个 (s,a),它的 reward r 或 next state s' 的分布是不同的。
    • (这意味着,在 BOReL 和绝大多数现有 offline meta-RL 工作中,数据集是显式标注了 task 归属的。也就是说,我们知道每条轨迹 / transition 来自哪一个 task。
    • (还没想清楚)因为,如果没有标注轨迹属于哪个 task,可能需要先恢复 task 结构,用无监督方法把数据聚类成 N 个 task,这本身就容易受 MDP ambiguity 影响。在这种情况下,即使满足 proposition 2,也无法解决 task ambiguity 的问题,或许可以认为那些可识别的 (s,a,r,s') 属于不同的 task,而剩下的 transition 属于同一个 task。
  • Proposition 3:策略回放保证可识别性。大意是,让收集 task i 的 policy \(\pi_i\) 去 task j 生产一些数据,然后把这些数据也并入 task j 的 offline dataset,就能满足可识别性。
    • 这是一种对收集 offline 数据的指导,并不是 offline meta-RL 方法。
  • 如果各个 task 的 transition 一样,只有 reward 不同,可以直接使用 task j 的奖励函数来标记 task i 的 offline 轨迹,然后把这些数据也并入 task j 的 offline dataset,就能满足可识别性。
    • 这种方法不需要重新生成轨迹,只需要重新标记 offline 轨迹,能在 offline 上做,但前提是能拿到各个 task 的 ground truth reward function。

实验(kimi 速读):

1 实验 Setting:7 个任务:覆盖离散/连续、稀疏/密集奖励、奖励/转移函数变化

  • Gridworld:5×5网格,稀疏奖励,目标位置未知(21个训练任务)
  • Semi-circle:2D点机器人,目标在未知角度(80任务)
  • Ant-Semi-circle:Ant机器人版,高度稀疏奖励(80任务)
  • Half-Cheetah-Vel:经典密集奖励任务,目标速度未知(100任务)
  • Reacher-Image:图像输入,密集奖励,目标位置未知(50任务)
  • Wind:转移函数变化(风力扰动),密集奖励(40任务)
  • Escape-Room:稀疏奖励,转移函数变化(出口位置未知,60任务)

输入输出:

  • 训练时:输入完整离线轨迹 τ^{i,j}(标注了 task ID)
  • 输入元策略:增强状态 s⁺ = (s, b),其中信念 b 是 VAE 输出的高斯分布 (μ, Σ)
  • 输出:动作 a(离散用 DQN,连续用 SAC)

2 Baseline 对比:主要对比两类(均为 Thompson 采样 方法,非贝叶斯最优):

  • 精确 Thompson 采样:仅在离散 Gridworld 上可计算,作为理论上限参考
  • 在线 PEARL:连续任务上的强大 baseline
    • 优势:允许在线训练,不受离线数据限制
    • 机制:先推断任务后验,然后直接用任务条件化策略执行(不主动规划探索)
    • 对比意义:证明 BOReL 即使数据受限,仍因贝叶斯最优探索而胜出
  • 消融版本:BOReL w/o RR/PR:去掉奖励重标注或策略回放,验证 MDP 模糊性的影响

3 核心评价指标

  • 新测试任务上,前 2 个 episode 的平均回报(探索能力关键期)
  • Gridworld:前 4 个 episode(任务更难)
  • Wind:仅第 1 个 episode(无多 episode 聚合)
  • 补充:也报告了更多 adaptation episode 的性能曲线

4 具体实验结果

  1. 稀疏奖励任务碾压性优势(Semi-circle, Ant-Semi-circle, Escape-Room)。BOReL 学会系统性搜索(第一 episode 沿半圆/墙壁探索,第二 episode 直奔目标),性能显著优于在线 PEARL,差距可达 2-3 倍。
  2. RR/PR 在稀疏任务上至关重要:去掉 RR / PR 后性能暴跌,信念无法更新(如图 7,信念错误锁定首次碰到的点)。在密集奖励任务(Cheetah-Vel, Wind)上,RR/PR 几乎无影响,因识别状态天然重叠)
  3. 数据多样性很重要,但非绝对。均匀初始状态分布 > 排除关键区域 > 固定初始点。即使在固定初始点(几乎无探索数据),BOReL 仍能学到搜索行为(表1,图10)
  4. 与离线 RL 方法结合的结果:用 CQL 训练 critic 有轻微提升,但数据多样性影响更显著(表 1)。这证明 MDP 模糊性是 meta-RL 特有挑战,并非标准的 offline RL 存在的问题。
  5. 在线迁移:BOReL 用于在线训练时,比原 VariBAD 样本效率提升显著(图11)

5 细节实验设置

数据收集:

  • 用标准 RL(DQN/SAC)训练每个任务,保存完整训练日志(含所有探索阶段)
  • 每迭代收集 2-5 个 episode,共训练 50-1000 次迭代(依任务难度)
  • 关键 trick:聚合 k 个连续 episode 为长轨迹(k=2 或 4),不重置 VAE 的 RNN 隐藏状态,使跨 episode 信念持续

VAE 架构 :

  • 编码器:GRU(64/128 维)处理 \((a_t, r_{t+1}, s_{t+1})\),输出潜变量 m ∼ N(μ, Σ),维度为 5
  • 解码器:2 层 FC(32 维),仅预测奖励(稀疏任务)或同时预测转移
  • KL 权重 β = 0.05 (非标准 VAE 的 1.0)
  • 学习率:3×10⁻⁴,batch size 256

meta-RL 训练 :

  • 在状态重标注后的 BAMDP 数据上训练 DQN/SAC。
  • 网络:2-3 层 FC(128-256 维),soft target update τ=0.005
  • 折扣因子 γ:0.9(稀疏任务)或 0.99(密集任务)
  • 熵系数:0.01-0.2(SAC)

(MBML) Multi-task Batch Reinforcement Learning with Metric Learning [NeurIPS 2020]

主要内容:

Meta-Q-Learning [ICLR 2020]

FOCAL: Efficient Fully-Offline Meta-Reinforcement Learning via Distance Metric Learning and Behavior Regularization [ICLR 2021]

主要内容:

  • 这篇文章提出了 fully-offline context-based actor-critic meta-RL algorithm(FOCAL),声称是首个端到端的 model-free 的 offline meta-RL 方法。
  • preliminaries:TA-MDP,一个用 task embedding z 来表示 multi-task 的 MDP 定义。其中 state space 是 \(S \times Z\)(原始状态 + task embedding),transition 是 \(P_z(s'|s,a)\) 的形式,reward 是 \(R(s,a,z) = R_z(s,a)\) 的形式。
  • setting:在训练时,给定 N 个 task 的带 reward 和 task 标签的 offline 数据集;在测试时,给出一些新任务的 offline 数据集。
  • method:
    • focal 训练一个 inference network \(q_\phi(z|c)\),用来从 context c 中猜出现在在做什么 task z,其中 context c 是一批 (s,a,s',r) 数据。
    • 训练的方式是,希望相同 task 的 z 聚集,而不同 task 的 z 相互原理,focal 说这本质上是距离度量学习(这听起来像对比学习)。具体的,focal 提出了以下损失函数(Eq 13),第一项把相同 task 拉近,第二项把不同 task 推远:

      \[\mathcal{L}_{dml}(x_i, x_j; q) = \mathbb{1}\{y_i = y_j\} \|q_i - q_j\|^2 + \mathbb{1}\{y_i \neq y_j\} \cdot \frac{\beta}{\|q_i - q_j\|^n + \epsilon} \]

    • Eq 13 是 Eq 12 的改进版,因为 Eq 12 的梯度在高维 z 空间中好像会不 work,而 Eq 13 则没有这种问题。
    • focal 的算法流程:通过 Eq 13 训练 \(q_\phi(z|c)\),同时训练带 z 的 actor 和 critic。拿到新 task 之后,使用新 task 给定的 offline 数据集推断 z,然后使用 \(\pi_\theta(a | s, z)\) 作为策略。
  • 实验环境:环境具体描述见附录 D.2,其中 Point-Robot-Wind、Walker-2D-Params 是改变 transition dynamics 的任务,而其他是改变 reward function 的任务。
  • offline 数据集:数据集貌似是自己生成的,使用训练 SAC 过程中保存下来的 checkpoint;有一些使用 expert 数据集,另一些则使用 mixed 数据集,详见 Table 2。
  • 所比较的 baseline:Batch PEARL、Contextual BCQ、MBML。focal 直接采用了 MBML 代码中的 Contextual BCQ 和 MBML 的实现。
  • 实验结果:outperform baseline。展示的是 curve,其中横轴是用于 offline 训练的数据集大小,即 transition 的个数,而纵轴是 average return。focal 是 fully-offline 的,没有 fine-tune。

其他信息:

  • intro 的前两段,在讲 offline 和 meta-RL 的动机和故事,第三段就直接提出 focal 了。
  • related work 中提到了大量 meta-RL 方法,不过看时间可能都是 19 年之前的,可能会比较老。focal 声称自己与 PEARL 最为相关。
  • focal 只考虑确定性的 MDP,而不考虑状态转移具有随机性的 MDP。
    • 在确定性 MDP 下,给定一个 (s,a) 和对应的 task,存在唯一的 (s', r)。Assumption 1 假设,如果两个 task 的 transition 和 reward 都是一样的,那么它们是同一个 task。由此可以推出,给定一个 \((s,a,s',r)\),可以唯一确定 task。
    • Figure 4(b) 将 focal 的 \(q_\phi(z|c)\) 换成了 probabilistic context encoder(虽然不是很明白怎么做到的),发现性能下降。
  • Finn et al. (2017) and Rakelly et al. (2019) 这两篇文章提出了 meta-RL 的 benchmark,疑似重要,需要 check 一下。
  • focal 在优化 actor 和 critic 时,会把 z 冻结住,不会更新 \(q_\phi(z|c)\),论文在 4.3 节使用 disentangle 这个词来说明这件事。
    • 5.2.3 节通过实验说明了这种 disentangle 的必要性。
    • (同时,论文在附录 C 说明,若任务嵌入 \(z_i, z_j\) 过于接近,连续神经网络无法区分其价值函数。如果 Q 函数的 bootstrapping 误差反向传播到 \(q_\phi(z|c)\),会迫使编码器生成相近的的 z 来最小化 TD 误差,破坏任务可分性。(论文好像没有真这样说,应该是 kimi 的幻觉)
  • 关于使用的 offline 数据集:
    • focal 说可能 expert 数据集里,一个 task 里一个 s 只会对应一个 a,不同 task 的 state-action 分布几乎没有重叠,会导致 agent 学到一些只跟 state 对应,而不跟 (s,a,s',r) transition 对应的 pattern;而 mixed 数据集则可以缓解这个问题。“这正是 Li 等人(2019b)旨在解决的问题,称为 MDP 模糊性。”

感谢师弟和参考博客的讲解🍵


posted @ 2025-12-07 10:35  MoonOut  阅读(114)  评论(0)    收藏  举报