CORE: Towards Scalable and Efficient Causal Discovery with Reinforcement Learning | 代码 | 因果发现、强化学习、部分可观测马尔可夫决策过程、泛化性、干预规划
论文信息
论文标题:CORE: Towards Scalable and Efficient Causal Discovery with Reinforcement Learning
论文作者:Andreas W.M. Sauter, Nicolò Botteghi, Erman Acar, Aske Plaat
论文来源:AAMAS 2024
论文地址:link
论文代码:link
Abstract
1. 研究背景:因果发现的核心挑战与解决思路
任务定义:明确 “因果发现(Causal Discovery)” 的本质是从数据中推断因果结构,这一任务本身具有较高挑战性。
核心痛点:引用 Pearl 的因果层级理论(Pearl’s Causal Hierarchy, PCH)指出,仅依靠被动观测数据无法区分变量间的 “相关性” 与 “因果性”—— 这是传统因果发现方法的关键局限。
解决方向:为突破上述局限,近年来研究趋势转向将 “干预(Interventions)” 融入机器学习领域;而强化学习(Reinforcement Learning, RL)因具备处理序贯决策与交互学习的能力,成为实现 “主动干预式因果发现” 的理想框架。
2. 方法设计:CORE 的核心定位与功能
3. 核心成果:性能与泛化能力验证
1 Introduction and Related work
1.1 核心概念:因果发现(Causal Discovery, CD)
- 传统方法分类:
- constraint-based 方法:基于统计独立性约束(如条件独立测试)推断因果图结构,代表研究 [14, 31]。
- score-based 方法:定义评分函数(如 BIC、AIC)评估候选因果图与数据的拟合度,通过搜索最优评分图确定结构,代表研究 [6]。
- continuous optimization-based 方法:将因果图结构学习转化为连续空间的优化问题(如通过矩阵分解、梯度下降求解),近年新兴方向,代表研究 [38, 39]。
- 传统方法局限:仅依赖纯观测数据,无法满足佩尔因果层次结构(PCH)的要求,难以有效区分因果关系与虚假相关。
1.2 关键理论基础:佩尔因果层次结构(Pearl’s Causal Hierarchy, PCH)
- 核心断言:要从 “单纯相关性” 中识别 “真实因果关系”,通常需要整合干预(Interventions) ,而非仅依赖被动观测数据 [3]。
- 研究趋势影响:推动因果发现研究向 “结合干预” 方向发展,包括在机器学习领域的应用 [5, 18, 28],例如通过主动操纵变量获取干预数据,辅助因果结构推断。
1.3 强化学习(Reinforcement Learning, RL)在因果发现中的角色与相关工作
1.3.1 RL 的核心优势与适配性
-
-
直接用于恢复环境的因果结构 [40];
-
用于学习通用的因果发现算法 [28],成为灵活的因果发现工具。
-
1.3.2 相关工作细分
|
研究方向
|
核心思路
|
代表研究
|
|
基于 RL 的因果结构空间搜索
|
基于固定数据集,用 RL 更高效地搜索因果结构空间,可融入先验知识
|
[11, 36, 40]
|
|
RL 相关的 GFlowNets 应用
|
利用生成流网络(GFlowNets)生成对真实因果结构的优质估计
|
[4, 7, 17]
|
|
RL 与因果的监督学习整合
|
将因果发现过程限制在监督学习框架内,结合 RL 优化任务目标
|
[9, 16, 21, 24, 34]
|
|
基于 RL 的干预选择
|
学习 “选择最优干预” 的策略,以提升因果发现效率
|
[1, 29, 33]
|
1.3.3 现有方法的共性挑战
- 可扩展性(Scalability):难以处理变量数较多的大规模因果图;
- 泛化性(Generalization):在未见过的因果结构上表现不佳;
- 干预规划(Intervention Planning):无法高效平衡 “干预信息增益” 与 “干预成本”。
1.3.4 本文贡献:CORE 方法的提出
1.3.4.1 CORE 核心定位
1.3.4.2 四大核心贡献
- 问题形式化:将 “学习因果发现算法” 的任务形式化为部分可观测马尔可夫决策过程(POMDP) ,为序贯因果发现提供严谨的数学框架;
- 双 Q 学习架构:提出双 Q 学习(dual Q-learning)设置,实现 “干预设计” 与 “结构估计” 的同步学习;
- 泛化性验证:验证 CORE 可成功应用于未见过的、变量数最多达 10 个的因果图;
- 联合学习价值与局限性分析:论证 “联合学习干预选择与图生成” 的重要性,并分析该方法在真实世界应用中的局限性。
1.3.4.3 CORE 与现有 SOTA 方法的差异
|
方法
|
核心差异
|
局限性对比
|
|
CORE
|
不预设固定因果识别算法,而是通过 RL 学习算法;双网络分离处理干预与结构更新
|
可处理 10 变量图,无离线数据依赖
|
|
MCD [28]
|
同样学习因果发现算法,但干预与结构更新需序贯执行
|
难以扩展到 4 变量以上的图
|
|
AVICI [19]
|
学习图生成器,但依赖离线数据集,无法主动干预
|
无法处理在线序贯干预场景
|
2 Preliminaries and Notation
2.1 章节核心目标
2.2、因果模型(Causal Models)
2.2.1 结构因果模型(SCM)定义
|
组件
|
符号定义
|
核心解释
|
|
内生变量
|
$ X = \{X_1, ..., X_n\}$
|
问题相关的随机变量(即待分析的核心变量),其取值由其他变量(父变量)和噪声共同决定
|
|
外生变量
|
$ U = \{U_1, ..., U_n\}$
|
不可观测的随机变量(又称噪声变量),独立于内生变量,用于表示模型外的不确定性
|
|
结构函数
|
$ F = \{f_1, ..., f_n\}$
|
每个函数 $ f_i$ 满足 $ X_i \leftarrow f_i(Pa(X_i), U_i)$ ,其中 $ Pa(X_i) \subseteq X \setminus \{X_i\}$ 是 $ X_i$ 的内生父变量集合,描述父变量对 $ X_i$ 的直接因果影响
|
|
概率分布
|
$ P = \{P_1, ..., P_n\}$
|
外生变量的独立概率分布集合,满足 $ U_i \sim P_i$ (即每个噪声变量服从各自的独立分布)
|
2.2.2 SCM 诱导的因果图与分布特性
(1)因果图的诱导规则
- 节点与边:将内生变量 $ X_i$ 视为节点,若 $ X_j \in Pa(X_i)$ (即 $ X_j$ 是 $ X_i$ 的父变量),则添加有向边 $ X_j \to X_i$ ,形成的有向图称为因果图(G)。
- 无环性假设:默认无变量是自身的原因(无循环函数依赖),因此因果图为有向无环图(DAG);边的 “存在 / 缺失” 均有意义 —— 存在表示直接因果关系,缺失表示无直接因果关系。
(2)观测分布与马尔可夫条件
- 观测分布定义:SCM 诱导内生变量的联合分布 $ P_M(X)$ ,称为观测分布(即无干预时的自然分布)。
- 马尔可夫条件:因果图的结构特性使 $ P_M(X)$ 满足马尔可夫条件 —— 给定父变量 $ Pa(X_i)$ ,每个变量 $ X_i$ 与其非后代变量独立。
- 分布分解:结合噪声变量的独立性,观测分布可分解为:
(3)SCM 的生成性
- 从外生变量的分布 $ P$ 中采样 $ U_i$ ;
- 根据结构函数 $ F$ 中的 $ f_i$ ,结合 $ Pa(X_i)$ 的值与 $ U_i$ ,计算得到 $ X_i$ 的值;
- 最终得到内生变量的联合采样结果,等价于从观测分布 $ P_M(X)$ 中采样 [26]。
2.3 干预(Interventions)
2.3.1 干预的核心作用
2.3.2 干预的形式化定义
(1)do - 操作(Hard Intervention)
- 定义:对变量 $ X$ 的干预(记为 $ do(X=x)$ )指将 $ X$ 的值强制设为固定值 $ x$ ,使其独立于原有的父变量(即切断所有指向 $ X$ 的因果边),此时 $ Pa(X) = \emptyset$ 。
- 图形化变化:干预后因果图 $ G$ 变为 $ G'$ ,核心变化是 “删除所有指向干预目标 $ X$ 的边”(示例见图 1,干预 $ X$ 后,原指向 $ X$ 的边被移除,$ X$ 取值固定为 $ x$ )。
(2)干预对 SCM 的影响
干预 $ do(X=x)$ 会修改 SCM:将原结构函数 $ f_X \in F$ 替换为 $ X = x$ ,形成新的 SCM $ M'$ 。此时,干预目标 $ X$ 的条件分布变为:
$ P_{M'}(X \mid Pa(X)) = P_{M'}(X \mid \emptyset) = \delta_x $
(3)干预后分布(Post-interventional Distribution)
干预后的内生变量联合分布(干预后分布)可分解为:
$ P_M(\mathcal{X} \mid do(X=x)) = \prod_{X_i \in \mathcal{X} \setminus \{X\}} P_M(X_i \mid Pa(X_i)) \cdot P_{M'}(X=x) $
- 解释:除干预目标 $ X$ 外,其他变量的条件分布保持与原 SCM 一致;$ X$ 的分布被替换为 $ \delta_x$ (即 $ P_{M'}(X=x) = 1$ )。
- 简化符号:当干预目标或取值明确时,可简写为 $ P_{M_{do(X=x)}}(X)$ 或 $ P_{M_{do(X)}}(X)$ 。
2.4 强化学习(Reinforcement Learning, RL)
2.4.1 RL 的核心框架
2.4.2 部分可观测马尔可夫决策过程(POMDP)
因果发现中,环境状态(如真实因果结构)无法完全观测,因此采用 POMDP 建模,其定义为元组 $ (S, A, T, R, \Omega, O, \gamma)$ ,各组件含义如下:
|
组件
|
符号定义
|
核心解释
|
|
状态空间
|
$ S$
|
环境所有可能状态的集合(在 CORE 中对应不同 SCM)
|
|
动作空间
|
$ A$
|
智能体可执行的所有动作的集合(如干预动作、结构更新动作)
|
|
转移概率
|
$ T: S \times A \times S \to [0,1]$
|
给定当前状态 $ s$ 和动作 $ a$ ,转移到下一状态 $ s'$ 的概率
|
|
奖励函数
|
$ R: S \times A \to \mathbb{R}$
|
在状态 $ s$ 执行动作 $ a$ 后获得的即时奖励
|
|
观测空间
|
$ \Omega$
|
智能体可观测到的所有信息的集合(如采样的内生变量值)
|
|
观测概率
|
$ O: S \times A \times \Omega \to [0,1]$
|
给定状态 $ s$ 、动作 $ a$ ,观测到 $ o \in \Omega$ 的概率
|
|
折扣因子
|
$ \gamma \in [0,1)$
|
用于权衡即时奖励与未来奖励的权重,$ \gamma$ 越小越重视即时奖励
|
2.4.3 策略与价值函数
(1)策略(Policy)
- 确定性策略:$ \pi: S \to A$ ,将状态直接映射到唯一动作;
- 随机性策略:$ \pi: S \times A \to [0,1]$ ,给出在状态 $ s$ 选择动作 $ a$ 的概率。
(2)价值函数
- 状态价值函数 $ V(s)$ :从状态 $ s$ 出发,遵循策略 $ \pi$ 获得的期望累积奖励;
- 动作价值函数(Q 函数) $ Q(s,a)$ :从状态 $ s$ 出发,执行动作 $ a$ 后遵循策略 $ \pi$ 获得的期望累积奖励。
(3)最优策略与最优 Q 函数
-
最优 Q 函数 $ Q^*$ :满足贝尔曼最优方程,即对任意状态 $ s$ 和动作 $ a$ :
$ Q^*(s,a) = \mathbb{E}_{s' \sim T(s,a,\cdot)} \left[ R(s,a) + \gamma \max_{a'} Q^*(s',a') \right] $
-
最优策略 $ \pi^*$ :由 $ Q^*$ 推导,即 $ \pi^*(s) = \arg\max_a Q^*(s,a)$ (贪婪选择使 Q 值最大的动作)[32]。
2.4.4 深度 Q 学习(Deep Q-Learning, DQN)
(1)时序差分(TD)学习
DQN 基于 TD 学习更新 Q 函数,核心思想是将 “当前 Q 值估计” 与 “即时奖励 + 未来最优 Q 值估计” 对齐,更新规则为:
$ Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r(s,a) + \gamma \max_{a'} Q(s',a') - Q(s,a) \right] $
其中 $ \alpha$ 是学习率,$ r(s,a)$ 是即时奖励。
(2)目标网络(Target Network)
为稳定训练,DQN 引入目标网络(参数为 $ \theta^-$ ),与主网络(参数为 $ \theta$ )结构相同,但参数更新滞后于主网络。损失函数定义为:
$ \mathcal{L}(\theta) = \mathbb{E}_{s,a,r,s' \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s',a'; \theta^-) - Q(s,a; \theta) \right)^2 \right] $
- 经验回放池 $ \mathcal{D}$ :存储历史($ s,a,r,s'$ )样本,随机采样训练以打破样本相关性,提升训练稳定性;
- 目标网络更新:定期将主网络参数复制到目标网络(如每 1000 步更新一次),避免目标 Q 值波动过大。
3 Learning a causal discovery policy with informative interventions
3.1 章节核心目标
本节提出 CORE 算法的核心框架,旨在通过强化学习(RL)学习一种 “因果发现策略(CD policy)”。该策略能结合观测数据与干预数据,序贯推断真实因果结构,同时主动选择 “信息性干预”(即对提升因果结构估计准确性有帮助的干预),最终实现对未知因果结构的高效、准确识别。
3.2 核心设计理念
- 交互范式:采用 RL 中经典的 “智能体 - 环境” 交互模式,智能体通过对环境(结构因果模型 SCM)执行干预动作,收集数据并更新因果图估计,环境则反馈奖励以指导智能体优化策略。
- 干预的信息性:强调 “信息性干预” 的重要性 —— 当干预预算有限时,智能体需学习选择能最大化 “因果结构推断增益” 的干预,避免无效操作。
- 训练与推理分离:因果发现模块仅在训练阶段需要真实因果结构(用于计算奖励),推理阶段可直接应用于真实世界中 “真实因果结构未知” 的场景:
- 训练阶段:基于合成数据(可轻松生成真实因果结构)训练 CD 策略;
- 推理阶段:将训练好的策略应用于未知环境,估计其因果结构。
3.3 因果发现的 POMDP 形式化建模(核心贡献 1)
由于因果发现过程中,智能体无法完全观测环境的真实状态(如真实因果结构、噪声分布),因此将其建模为部分可观测马尔可夫决策过程(POMDP),这是本节的核心贡献之一。以下是 POMDP 各组件的具体定义:
3.3.1 状态空间(State Space)
- 状态定义:环境状态由 SCM 决定,每个状态对应一个 SCM。若 SCM 包含 $n$ 个内生变量,则状态 $s = \{f_0, ..., f_{n-1}\}$ ,其中 $f_i$ 是定义第 $i$ 个内生变量的结构函数(即 $X_i \leftarrow f_i(Pa(X_i), U_i)$ )。
- 附加信息:每个状态还包含由 “无干预 SCM( $M_{do(\emptyset)}$ )” 诱导的真实观测图 $G_s^*$ ,作为后续计算奖励的基准。
3.3.2 动作空间(Action Space)
动作空间为多离散空间,定义为 $A = [n+1] \times [2n(n-1)+1]$,分为 “干预动作” 和 “结构动作” 两个维度,具体如下:
|
动作维度
|
核心作用
|
具体定义
|
动作数量
|
优化设计
|
|
干预动作(第一维度)
|
选择干预目标或收集观测数据
|
- 对每个内生变量 $X_i$ ,执行干预 $do(X_i = c)$ ( $c$ 为预设常数);- 不干预,仅收集观测数据( $do(\emptyset)$ )。
|
$n+1$ ( $n$ 个干预动作 + 1 个观测动作)
|
无
|
|
结构动作(第二维度)
|
更新当前因果图估计(添加 / 删除边)
|
- 对当前估计图,添加某条不存在的边;- 对当前估计图,删除某条已存在的边;- 无结构更新(空操作)。
|
理论上 $2n(n-1)+1$ ( $n(n-1)$ 条可能边,每条边对应 “添加”“删除” 2 个动作,加 1 个空操作)
|
动作掩码:禁止 “添加已存在的边” 或 “删除不存在的边”,实际动作数量减半,提升效率。
|
3.3.3 转移动态(Transition Dynamics)
- 初始状态:每个 episode(序贯决策周期)始于 “无干预 SCM( $M_{do(\emptyset)}$ )”,此时未执行任何干预,仅观测自然分布数据。
- 干预转移:若在状态 $M_{do(X_i)}$ (即已干预 $X_i$ 的 SCM)上执行干预 $do(X_j = c)$ ,则环境确定性转移到新状态 $M_{do(X_j)}$ —— 此时 SCM 中 $f_j$ 被替换为常数 $c$ ,即 $T(M_{do(X_i)}, do(X_j), M_{do(X_j)}) = 1$ 。
- 示例:若 SCM 含 2 个内生变量 $X_0, X_1$ ,从 $M_{do(\emptyset)}$ 执行 $do(X_0 = c)$ ,则转移到 $M_{do(X_0)}$ ,此时 $X_0$ 的结构函数变为 $X_0 = c$ (见图 2 左部)。
3.3.4 观测(Observations)
- 观测定义:在每个决策步骤 $t$ ,智能体从当前 SCM( $M_{do(X)}$ )诱导的联合分布 $P_{M_{do(X)}}(X)$ 中采样内生变量值 $\{x_0, ..., x_{n-1}\}$ ,该采样结果即为观测 $o_t$ ,即 $o_t \sim P_{M_{do(X)}}(X)$ 。
- 状态表示:单一观测无法反映环境全貌(POMDP 特性),因此智能体需构建观测 - 动作历史 $h_t$ 作为状态表示,即 $h_t = [x_0, a_0, ..., x_t, a_t]$ ,其中 $x_k$ 是第 $k$ 步观测, $a_k$ 是第 $k$ 步动作。
3.3.5 奖励函数(Reward Function)
1)基础奖励公式
SHD 衡量两个 DAG 的差异(计数不同边的数量),奖励定义为 “当前估计图与真实图的 SHD” 与 “下一步估计图与真实图的 SHD” 的差值:
$r(s,a) = SHD(G_s^*, \hat{G}_t) - SHD(G_s^*, \hat{G}_{t'})$
2)简化奖励公式
-
- 若动作 $a$ 是添加边:
-
- 若动作 $a$ 是删除边:
-
- 若动作 $a$ 是空操作(无结构更新): $r(s,a) = 0$ 。
3)简化公式的优势
-
-
计算高效:仅需对比 “操纵边 $E(a)$ ” 与真实图 $G_s^*$ ,无需遍历完整图;
-
奖励更密集:每步结构动作均产生即时奖励,避免 “仅在 episode 结束时反馈奖励” 的稀疏性问题;
-
状态依赖简化:奖励仅依赖当前状态 $s$ 和动作 $a$ ,不依赖历史结构动作。
-
3.4 数据生成(Data-Generation)
3.4.1 因果图生成
- 小节点数(3-4 个变量):生成所有可能的 DAG(3 变量共 25 个,4 变量共 543 个),确保覆盖完整结构空间。
- 大节点数(>4 个变量):由于 DAG 数量随节点数超指数增长,采用Erdös-Rényi 随机图模型生成子集,边生成概率设为 0.2(与文献常见设置一致)。
3.4.2 数据集划分
- 生成所有 / 部分 DAG 后,打乱顺序以保证各稀疏度的图均匀分布;
- 划分为训练集与评估集,确保评估集中无训练集见过的图(验证泛化性);
- 假设:训练集图数量越多,智能体泛化到未知图的能力越强(符合机器学习通用规律)。
3.4.3 SCM 生成
3.5 学习方法(Learning Approach)
3.5.1 双 Q 网络设计
智能体维护两个独立的多层感知器(MLP)Q 网络,分别对应两类动作,具体如下:
|
Q 网络
|
输入
|
输出
|
目标网络
|
作用
|
|
$Q_{st}(h, a_{st}; \Theta_{st})$
|
观测 - 动作历史 $h$ 、结构动作 $a_{st}$
|
结构动作 $a_{st}$ 的 Q 值(即该动作的长期收益)
|
对应目标网络 $Q_{st}^-(h, a_{st}; \Theta_{st}^-)$
|
优化结构更新策略,选择 “能提升图估计准确性” 的边操作
|
|
$Q_{in}(h, a_{in}; \Theta_{in})$
|
观测 - 动作历史 $h$ 、干预动作 $a_{in}$
|
干预动作 $a_{in}$ 的 Q 值
|
对应目标网络 $Q_{in}^-(h, a_{in}; \Theta_{in}^-)$
|
优化干预策略,选择 “信息性最强” 的干预目标
|
-
网络共性:除输出层(因动作空间维度不同)外,两网络结构完全一致;目标网络参数定期从主网络复制,用于稳定训练(避免目标 Q 值波动过大)。
3.5.2 策略与训练流程
1)策略选择
-
- 最优干预动作: $a_{in}^* = \arg\max_{a_{in}} Q_{in}(h, a_{in}; \Theta_{in})$
-
- 最优结构动作: $a_{st}^* = \arg\max_{a_{st}} Q_{st}(h, a_{st}; \Theta_{st})$
-
- 实际训练中加入 $\epsilon$ - 贪婪探索(以小概率随机选择动作),平衡探索与利用。
2)训练流程(见图 2 右部)
-
-
初始化:每个 episode 开始时,从训练集中随机采样一个 SCM 作为初始环境,初始化 “空因果图” 作为初始估计图,清空观测 - 动作历史 $h$ ;
-
序贯交互:
-
智能体根据当前历史 $h$ ,通过双 Q 网络选择干预动作 $a_{in}$ 和结构动作 $a_{st}$ ;
-
环境执行干预动作 $a_{in}$ ,转移到新 SCM,采样观测 $o$ ;
-
智能体根据结构动作 $a_{st}$ 更新估计图,结合真实图 $G_s^*$ 计算奖励 $r$ ;
-
将 $(h, a_{in}, a_{st}, r, h')$ 存入经验回放池(用于 DQN 训练),更新历史 $h$ 为 $h'$ ;
-
-
网络更新:从经验回放池随机采样批次样本,分别计算两 Q 网络的损失(基于目标网络),通过梯度下降更新主网络参数;
-
终止条件:达到预设的 episode 步数(推理时间由步数固定,保证高效性),结束当前 episode,进入下一轮训练。
-

浙公网安备 33010602011771号