Tractable Uncertainty for Causal Structure Learning @ APML Seminar Series

Details

Title: Tractable Uncertainty for Causal Structure Learning @ APML Seminar
Author(s): Benjie Wang
Link(s):

Rough Notes

Task: learning causal structure, together with its uncertainty, in a scalable manner. This requires:

  • Compact representation of the uncertainty.
  • Efficient and parallel learning of uncertainty.
  • Efficient and flexible reasoning over uncertainty.

Motivation

Part of the motivation comes from bioinformatics (Shen et al. 2020).

We have statistical uncertainty due to limited data, and causal uncertainty due to non-identifiability.

Other questions of interest include "What is the probability that X causes Y?" and "What is the expected causal effect of X on Y?"

Background

Bayesian structure learning takes as inputs:

  • A prior \(P(G)\) over DAGs.
  • Data \(D\).
  • A likelihood \(P(D|G)\), assumed to be available in closed form.

(Q) Why not a prior over Markov equivalence classes (MECs) rather than DAGs? Strictly speaking, shouldn't we want a prior over the space of MECs instead of DAGs, since with observational data alone the MEC is all that can be identified?

Goal: compute the posterior \(P(G|D) \propto P(D|G)P(G)\) in order to answer queries of the form \(\mathbb{E}_{P(G|D)}[f(G)|g(G)]\).
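For a handful of variables the posterior can be computed exactly by enumeration, which makes the goal concrete. The sketch below is illustrative and rests on assumptions not in the talk: a uniform prior over DAGs, a linear-Gaussian BIC score standing in for the closed-form \(\log P(D|G)\), synthetic data from a chain \(X_0 \to X_1 \to X_2\), and hypothetical helper names (`all_dags`, `bic_score`). The query \(f(G) = \mathbb{1}[X_0 \to X_1 \in G]\) gives a posterior edge probability.

```python
import itertools
import numpy as np

def is_acyclic(A):
    # A directed graph on d nodes is acyclic iff A^d = 0 (no walks of length d)
    return not np.linalg.matrix_power(A, A.shape[0]).any()

def all_dags(d):
    """Enumerate every DAG on d labelled nodes as a d x d 0/1 adjacency matrix."""
    pairs = [(i, j) for i in range(d) for j in range(d) if i != j]
    dags = []
    for bits in itertools.product([0, 1], repeat=len(pairs)):
        A = np.zeros((d, d), dtype=int)
        for (i, j), b in zip(pairs, bits):
            A[i, j] = b
        if is_acyclic(A):
            dags.append(A)
    return dags

def bic_score(A, X):
    """Linear-Gaussian BIC, used here only as a stand-in for log P(D | G)."""
    n, d = X.shape
    score = 0.0
    for j in range(d):
        parents = np.flatnonzero(A[:, j])  # i is a parent of j iff A[i, j] = 1
        Z = np.column_stack([np.ones(n), X[:, parents]])
        coef, *_ = np.linalg.lstsq(Z, X[:, j], rcond=None)
        resid = X[:, j] - Z @ coef
        score += -0.5 * n * np.log(resid @ resid / n) - 0.5 * Z.shape[1] * np.log(n)
    return score

# Synthetic data from the chain X0 -> X1 -> X2 (illustrative only)
rng = np.random.default_rng(0)
n = 500
x0 = rng.normal(size=n)
x1 = 2.0 * x0 + rng.normal(size=n)
x2 = -1.0 * x1 + rng.normal(size=n)
X = np.column_stack([x0, x1, x2])

dags = all_dags(3)
log_post = np.array([bic_score(A, X) for A in dags])  # a uniform prior P(G) cancels out
post = np.exp(log_post - log_post.max())
post /= post.sum()

# E_{P(G|D)}[f(G)] with f(G) = 1[X0 -> X1 in G], plus the direction-free adjacency probability
p_edge_01 = sum(p for p, A in zip(post, dags) if A[0, 1])
p_adj_01 = sum(p for p, A in zip(post, dags) if A[0, 1] or A[1, 0])
print(len(dags), p_edge_01, p_adj_01)
```

Because Markov-equivalent graphs receive the same BIC score, the adjacency between \(X_0\) and \(X_1\) should come out close to 1 while the direction of the edge stays split across the equivalence class, which is the non-identifiability uncertainty mentioned above. Enumeration stops being feasible quickly: there are 25 DAGs on 3 nodes, 543 on 4 and 29,281 on 5.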

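Exact computation is intractable beyond a handful of variables, since the number of DAGs grows super-exponentially with the number of nodes, so MCMC and variational approximations are often used.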

The speaker's work is on the variational approximation side.

A mean-field approximation over the graph could use independent Bernoulli variables for the edges, but this does not guarantee acyclicity. Neural autoregressive models over the edges are difficult to train to encode acyclicity.
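A quick Monte Carlo check shows how often independent edge variables produce cycles. This is a sketch with hypothetical helper names and edge probability \(p = 0.5\) for every edge; it tests acyclicity via the fact that a directed graph on \(d\) nodes has no cycle exactly when \(A^d = 0\). For \(d = 3\) all \(2^6 = 64\) graphs are equally likely and 25 of them are DAGs, so about 39% of samples should be acyclic; the fraction collapses towards zero as \(d\) grows.

```python
import numpy as np

def is_acyclic(A):
    # A directed graph on d nodes is acyclic iff A^d = 0 (no walks of length d)
    return not np.linalg.matrix_power(A, A.shape[0]).any()

def sample_mean_field(edge_probs, rng):
    """Sample each edge i -> j (i != j) independently with probability edge_probs[i, j]."""
    A = (rng.random(edge_probs.shape) < edge_probs).astype(int)
    np.fill_diagonal(A, 0)  # drop self-loops
    return A

def acyclic_fraction(d, p, n_samples, rng):
    """Monte Carlo estimate of how often a mean-field sample is a DAG."""
    edge_probs = np.full((d, d), p)
    hits = [is_acyclic(sample_mean_field(edge_probs, rng)) for _ in range(n_samples)]
    return float(np.mean(hits))

rng = np.random.default_rng(0)
for d in (3, 5, 8):
    print(d, acyclic_fraction(d, 0.5, 5000, rng))
```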

Representation

The proposed model, OrderSPN, represents a family of distributions over causal structures.

First, introduce orderings and work on the joint space of topological orders \(\sigma\) and directed graphs \(G\). This relies on two facts:

  • Every DAG is consistent with at least one order.
  • Every directed graph consistent with an order is acyclic.
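The two facts can be checked directly with a small sketch (helper names are hypothetical): sampling an order first and then only forward edges always yields a DAG, and a given DAG can be consistent with several orders, e.g. the empty graph on three nodes with all \(3! = 6\) of them, while the chain \(X_0 \to X_1 \to X_2\) is consistent with exactly one. Consequently, the probability of a graph \(G\) alone is obtained by summing over all orders consistent with \(G\).

```python
import itertools
import numpy as np

def consistent_with(A, order):
    """True if every edge i -> j in A points from an earlier to a later node in `order`."""
    pos = {v: k for k, v in enumerate(order)}
    return all(pos[i] < pos[j] for i, j in zip(*np.nonzero(A)))

def sample_order_then_graph(d, p, rng):
    """Sample an order uniformly, then include each order-consistent edge with probability p."""
    order = rng.permutation(d)
    A = np.zeros((d, d), dtype=int)
    for a in range(d):
        for b in range(a + 1, d):
            if rng.random() < p:
                A[order[a], order[b]] = 1  # edges only go forward in the order
    return order, A

def n_consistent_orders(A):
    """Number of topological orders that a given graph is consistent with."""
    d = A.shape[0]
    return sum(consistent_with(A, order) for order in itertools.permutations(range(d)))

rng = np.random.default_rng(0)
samples = [sample_order_then_graph(6, 0.5, rng) for _ in range(1000)]
# A^d = 0 for every sample, i.e. no sampled graph contains a cycle
print(all(not np.linalg.matrix_power(A, 6).any() for _, A in samples))

empty = np.zeros((3, 3), dtype=int)
chain = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])  # X0 -> X1 -> X2
print(n_consistent_orders(empty), n_consistent_orders(chain))
```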

Using sum-product networks we can parametrize a flexible model class over this space.

Sum-product networks (SPNs) are a type of tractable probabilistic model for expressing a distribution over a set of variables \(\mathbf{X}\). They are rooted DAGs consisting of three types of nodes:

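  • L (leaf): simple base distributions \(L(\mathbf{X})\), e.g. Gaussians.
  • \(\times\) (product): factorized distributions \(P(\mathbf{X}) = C_1(\mathbf{X}_1) \times C_2(\mathbf{X}_2)\), where \(\mathbf{X}_1\) and \(\mathbf{X}_2\) partition \(\mathbf{X}\).
  • \(+\) (sum): mixtures of component distributions \(S(\mathbf{X}) = \sum_j w_j C_j(\mathbf{X})\), with \(w_j \ge 0\) and \(\sum_j w_j = 1\).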
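A minimal sketch of SPN evaluation, assuming Bernoulli leaves over binary variables (Gaussians are mentioned above; Bernoulli keeps the arithmetic checkable) and hypothetical class names. It also shows where tractability comes from: a marginal is computed in the same single bottom-up pass as a joint, by letting every leaf of a marginalized variable return 1. This relies on the standard validity conditions: product children have disjoint scopes, and sum children share the same scope.

```python
from dataclasses import dataclass

@dataclass
class Leaf:
    var: str
    p: float  # Bernoulli base distribution: P(var = 1) = p

@dataclass
class Product:
    children: list  # children must have disjoint scopes (decomposability)

@dataclass
class Sum:
    weights: list  # non-negative, summing to 1
    children: list  # children share the same scope (smoothness)

def evaluate(node, x):
    """P(x) for a partial assignment x (dict var -> 0/1); absent variables are marginalized out."""
    if isinstance(node, Leaf):
        if node.var not in x:
            return 1.0  # a normalized leaf sums to 1 over its variable
        return node.p if x[node.var] == 1 else 1.0 - node.p
    if isinstance(node, Product):
        result = 1.0
        for child in node.children:
            result *= evaluate(child, x)
        return result
    return sum(w * evaluate(child, x) for w, child in zip(node.weights, node.children))

# A two-component mixture over binary variables A and B
spn = Sum(
    weights=[0.3, 0.7],
    children=[
        Product([Leaf("A", 0.9), Leaf("B", 0.2)]),
        Product([Leaf("A", 0.1), Leaf("B", 0.6)]),
    ],
)
print(evaluate(spn, {"A": 1, "B": 1}))  # joint P(A=1, B=1)
print(evaluate(spn, {"A": 1}))  # marginal P(A=1), same single pass
```

Here \(P(A=1, B=1) = 0.3 \cdot 0.9 \cdot 0.2 + 0.7 \cdot 0.1 \cdot 0.6 = 0.096\) and \(P(A=1) = 0.3 \cdot 0.9 + 0.7 \cdot 0.1 = 0.34\).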

They introduce a special type of SPN called an OrderSPN. Sum nodes represent different partitions of the orderings, and product nodes multiply over the decomposed sub-orderings (#DOUBT). At the bottom, each leaf node involves only one variable.
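To make the "sum over splits, product over sub-orderings" idea concrete, here is a toy recursive distribution over orderings. It is a hypothetical simplification built only from the description above, not the OrderSPN construction from the talk: each sum node puts uniform weight on which half of the variables comes first, each product node orders the two halves independently, and a single-variable leaf has exactly one ordering. Checking that the probabilities over all orderings sum to 1 shows that such a recursion defines a valid distribution over orderings.

```python
import itertools
import math
from functools import lru_cache

@lru_cache(maxsize=None)
def order_prob(order):
    """P(order) under a hypothetical toy recursive model over orderings.

    Sum node: choose uniformly which subset of size len(order) // 2 is placed first.
    Product node: order the first subset and the remaining subset independently.
    Leaf: a single variable has exactly one ordering.
    """
    n = len(order)
    if n == 1:
        return 1.0
    k = n // 2
    split_weight = 1.0 / math.comb(n, k)  # uniform mixture weight over the possible splits
    return split_weight * order_prob(order[:k]) * order_prob(order[k:])

variables = ("A", "B", "C", "D")
probs = {o: order_prob(o) for o in itertools.permutations(variables)}
print(len(probs), sum(probs.values()))
```

With uniform split weights every ordering receives \(1/n!\); non-uniform weights at the sum nodes would concentrate mass on particular orderings.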

(#TODO Fully understand the model).

Learning

Learning OrderSPNs involves two tasks:

  • Given the candidate splits of the orderings, use the data to decide which splits to keep.
  • Then learn the corresponding parameters.

Ordering selection: at each sum node, use an existing Bayesian structure learning method (e.g. MCMC) as an oracle to sample orderings, which determine the splits to keep.
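The sketch below illustrates one way sampled orderings could be turned into the children of a sum node: count which set of variables appears first in each sample, keep the most frequent sets, and set initial weights proportional to their counts. This is a hypothetical heuristic for illustration, not the procedure from the talk; the "sampled" orderings are made-up stand-ins for MCMC output, and `select_splits`, `k` and `max_children` are invented names.

```python
from collections import Counter

def select_splits(sampled_orders, k, max_children):
    """Hypothetical heuristic: keep the most frequent 'first k variables' sets as sum-node
    children, with initial mixture weights proportional to how often each set was sampled."""
    counts = Counter(frozenset(order[:k]) for order in sampled_orders)
    kept = counts.most_common(max_children)
    total = sum(c for _, c in kept)
    return [(split, c / total) for split, c in kept]

# Stand-in for orderings returned by an MCMC oracle (made-up values for illustration only)
sampled_orders = [
    ("A", "B", "C", "D"),
    ("B", "A", "C", "D"),
    ("A", "C", "B", "D"),
    ("B", "A", "D", "C"),
    ("C", "D", "A", "B"),
]
for split, weight in select_splits(sampled_orders, k=2, max_children=2):
    print(sorted(split), weight)
```

With these inputs, \(\{A, B\}\) comes first in 3 of the 5 orderings and is kept with weight 0.75; \(\{A, C\}\) and \(\{C, D\}\) tie with one occurrence each, and `Counter.most_common` keeps the one encountered first.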

Parameter learning: use variational inference to optimize the parameters. The ELBO can be computed exactly for OrderSPNs, so there is no need for the reparameterization trick.
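Why exactness matters: for a distribution \(q\) over graphs, \(\mathrm{ELBO}(q) = \mathbb{E}_{q}[\log P(D|G) + \log P(G)] + H(q) \le \log P(D)\). If this expectation can be summed exactly, the ELBO is a deterministic function of the parameters of \(q\) and can be differentiated directly, without a sampled gradient estimator. The sketch below assumes an explicit support of four graphs with hypothetical log-joint values; in an OrderSPN the sum would instead be organized by the SPN structure, which this toy does not model.

```python
import numpy as np

def exact_elbo(q, log_joint):
    """ELBO = E_q[log P(D, G)] + H(q), computed by exact summation over an explicit support."""
    q = np.asarray(q, dtype=float)
    mask = q > 0  # graphs with q(G) = 0 contribute nothing
    return float(np.sum(q[mask] * (log_joint[mask] - np.log(q[mask]))))

# Hypothetical log P(D | G) + log P(G) for four graphs
log_joint = np.array([-10.0, -11.0, -12.5, -15.0])
log_evidence = float(np.log(np.sum(np.exp(log_joint))))  # log P(D)
posterior = np.exp(log_joint - log_evidence)
uniform = np.full(4, 0.25)

print(exact_elbo(uniform, log_joint), exact_elbo(posterior, log_joint), log_evidence)
```

The uniform \(q\) gives a strict lower bound on \(\log P(D)\), and the bound is tight exactly when \(q\) equals the posterior.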

Reasoning

Tractable queries follow from existing results for SPNs; for example, marginals can be computed exactly in a single pass through the network.

With such a model, we can compute Bayesian causal effects (e.g. CAE) via Bayesian model averaging.
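Bayesian model averaging weights each graph's causal effect by its posterior probability, \(\mathbb{E}_{P(G|D)}[\mathrm{effect}(G)] = \sum_G P(G|D)\,\mathrm{effect}(G)\). The sketch below assumes a linear-Gaussian model, where the effect of \(X_1\) on \(X_2\) under a given graph can be estimated by regressing \(X_2\) on \(X_1\) and the parents of \(X_1\) (back-door adjustment by the parent set). The three graphs, their posterior weights and the helper name `causal_effect` are hypothetical.

```python
import numpy as np

def causal_effect(X, x, y, parents_of_x):
    """Linear-Gaussian effect of do(X_x) on X_y in one graph, adjusting for the parents of X_x."""
    if y in parents_of_x:
        return 0.0  # X_y is upstream of X_x in this graph, so intervening on X_x cannot change it
    Z = np.column_stack([np.ones(len(X)), X[:, x], X[:, list(parents_of_x)]])
    coef, *_ = np.linalg.lstsq(Z, X[:, y], rcond=None)
    return float(coef[1])  # coefficient on X_x

# Synthetic data from X0 -> X1 -> X2 (illustrative only)
rng = np.random.default_rng(0)
n = 5000
x0 = rng.normal(size=n)
x1 = 1.5 * x0 + rng.normal(size=n)
x2 = 0.8 * x1 + rng.normal(size=n)
X = np.column_stack([x0, x1, x2])

# Hypothetical posterior: parents of X1 in each graph, and P(G | D)
graphs = {
    "X0 -> X1 -> X2": ((0,), 0.5),
    "X2 -> X1 -> X0": ((2,), 0.2),
    "X0 <- X1 -> X2": ((), 0.3),
}
bma = sum(w * causal_effect(X, 1, 2, pa) for pa, w in graphs.values())
print(bma)
```

Graphs in which \(X_2\) is a parent of \(X_1\) contribute zero, so the averaged effect (about \(0.8 \times 0.8 = 0.64\) here) is smaller than the effect under the true graph.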

Experimental results
