Causality Discussion Group - Representation Learning: A Causal Perspective

Details

Title: Causality Discussion Group: Yixin Wang
Author(s): Yixin Wang
Link(s):

Rough Notes

TLDR: Non-spuriousness, efficiency and disentanglement are formalized using causal notions such as counterfactuals, so that we can evaluate whether our ML algorithms have these desirable properties.

Often we start with high-dimensional data (possibly of mixed types) and want to learn representations of it that yield features relevant for making predictions. Ideally, these representations should be non-spurious, efficient and disentangled. Can we treat these desiderata as constraints and bake them into representation learning algorithms?

The work in this setting assumes a single dataset, without multiple environments, invariance assumptions or auxiliary labels (unlike, e.g., Invariant Risk Minimization).

Why might naive representation learning produce spurious features? Take the example of learning representations for dogs: suppose we have a supervised learning setting with images and a binary label indicating whether a dog is present in the image. The common approach nowadays is to train a neural network. However, the network's learned representation may pick up the feature "Is there grass in the image?", since in real images the presence of grass may be highly correlated with the presence of dogs. This is a spurious feature. It is not a failure of neural network training: it can happen even when predictive accuracy on the test set is high.
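A toy simulation of the dog example makes the last point concrete. It is a sketch under made-up assumptions: each image is reduced to two binary features, a dog face (present exactly when a dog is) and grass (present 90% of the time with a dog, 10% without). The names `make_images`, `grass_only` and `face_only` are hypothetical. A rule that looks only at grass still scores about 90% on held-out data from the same distribution, so test accuracy alone cannot reveal that it is spurious.

```python
import random

def make_images(n, p_grass_if_dog=0.9, p_grass_if_no_dog=0.1, seed=0):
    """Hypothetical toy stand-in for the dog example. Each 'image' is reduced to
    (has_dog_face, has_grass, dog_present); the probabilities are made up."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        dog = rng.random() < 0.5
        dog_face = dog  # causal feature: present exactly when a dog is present
        p_grass = p_grass_if_dog if dog else p_grass_if_no_dog
        grass = rng.random() < p_grass  # spurious: only correlated with the label
        rows.append((dog_face, grass, dog))
    return rows

def accuracy(predict, rows):
    """Fraction of rows where predict(dog_face, grass) equals the label."""
    return sum(predict(face, grass) == dog for face, grass, dog in rows) / len(rows)

def grass_only(face, grass):
    return grass  # relies only on the spurious feature

def face_only(face, grass):
    return face  # relies only on the causal feature

held_out = make_images(10_000, seed=1)  # test set from the same distribution
print(accuracy(grass_only, held_out))  # about 0.9
print(accuracy(face_only, held_out))   # 1.0
```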

The problem is the training objective: maximizing predictive accuracy does not prevent the model from learning spurious features. If we want to optimize for non-spuriousness, we first need a mathematical definition of a non-spurious representation.

We would also want efficiency: e.g. having only \(f_1\) = "Is there a dog face?" is more efficient than having both \(f_1\) and \(f_2\) = "Are there 4 legs?", since \(f_2\) is largely redundant once \(f_1\) is known. Similarly, we would prefer disentanglement, e.g. having \((f_1, f_2)\) rather than \((f_1+f_2, f_1-f_2)\), where each coordinate mixes both features.

What does non-spuriousness even mean? A non-spurious representation \(Z=f(X)\) captures features that causally determine the label: adding such a feature is sufficient to change the label. From this intuition, we can quantify non-spuriousness using the Probability of Sufficiency (PS).

What does efficiency even mean? An efficient representation \(Z=f(X)\) captures only essential features, without redundancy: removing such a feature would change the label. From this intuition, we can quantify efficiency using the Probability of Necessity (PN).

We can quantify non-spuriousness and efficiency simultaneously using the Probability of Necessity and Sufficiency (PNS) - asking whether a feature is both necessary and sufficient for the prediction task.
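For reference, the standard (Pearl) definitions for a binary feature \(X\) (\(x\) = present, \(x'\) = absent) and a binary label \(Y\) (\(y\) = positive, \(y'\) = negative) are \(\mathrm{PN} = P(Y_{x'} = y' \mid X = x, Y = y)\), \(\mathrm{PS} = P(Y_x = y \mid X = x', Y = y')\) and \(\mathrm{PNS} = P(Y_x = y, Y_{x'} = y')\). The sketch below computes all three in a hypothetical toy structural causal model, not the talk's code: a hidden \(U_Y\) picks one of four "response types" describing how \(Y\) reacts to \(X\), and \(X\) is assumed independent of \(U_Y\). Counterfactuals come from plugging \(x = 1\) or \(x = 0\) into the same unit's response function.

```python
from itertools import product

# Hypothetical response types: how the label Y reacts to the binary feature X.
RESPONSE = {
    "always":   lambda x: 1,      # label is 1 whatever the feature
    "never":    lambda x: 0,      # label is 0 whatever the feature
    "complier": lambda x: x,      # feature is necessary and sufficient
    "defier":   lambda x: 1 - x,  # feature pushes the label the other way
}

def causal_probs(p_x, p_type):
    """Return (PN, PS, PNS) for X ~ Bernoulli(p_x), Y = RESPONSE[U_Y](X),
    with X independent of U_Y. Assumes P(X=1, Y=1) > 0 and P(X=0, Y=0) > 0."""
    pn_num = pn_den = ps_num = ps_den = pns = 0.0
    for x, (t, pt) in product((0, 1), p_type.items()):
        w = (p_x if x == 1 else 1 - p_x) * pt    # P(X = x, U_Y = t)
        y, y1, y0 = RESPONSE[t](x), RESPONSE[t](1), RESPONSE[t](0)
        pns += w * (y1 == 1 and y0 == 0)         # P(Y_1 = 1, Y_0 = 0)
        if x == 1 and y == 1:                    # PN conditions on X = 1, Y = 1
            pn_den += w
            pn_num += w * (y0 == 0)              # would Y be 0 without the feature?
        if x == 0 and y == 0:                    # PS conditions on X = 0, Y = 0
            ps_den += w
            ps_num += w * (y1 == 1)              # would Y be 1 with the feature added?
    return pn_num / pn_den, ps_num / ps_den, pns

types = {"always": 0.2, "never": 0.1, "complier": 0.6, "defier": 0.1}
pn, ps, pns = causal_probs(0.5, types)
print(round(pn, 3), round(ps, 3), round(pns, 3))  # 0.75 0.857 0.6
```

Here only the "complier" units have a feature that is both necessary and sufficient, so PNS equals their share (0.6), while PN and PS are larger because each conditions on a subgroup.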

This leads to the view that representation learning is about finding necessary and sufficient causes of the label. The optimization problem should not be about finding a function that maximizes predictive accuracy; rather, it should be about finding a function that maximizes the PNS value. The speaker calls this type of learning CAUSAL-REP.

Now, if we can compute PS, PN or PNS, which are counterfactual quantities, we are getting somewhere. In general, however, PNS cannot be identified exactly; it can only be bounded, and this bound depends on two interventional distributions.
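My notes do not record which bound the talk used. A common choice, assuming a binary feature and label and only interventional data, is the Tian and Pearl bound \(\max\{0, P(y_x) - P(y_{x'})\} \le \mathrm{PNS} \le \min\{P(y_x), P(y'_{x'})\}\), which indeed needs just the two interventional distributions \(P(Y \mid do(X=1))\) and \(P(Y \mid do(X=0))\). A minimal sketch, with `pns_bounds` as a hypothetical helper:

```python
def pns_bounds(p_y_do_x, p_y_do_not_x):
    """Bounds on PNS for binary X and Y from two interventional probabilities.

    p_y_do_x     = P(Y = 1 | do(X = 1)), i.e. P(y_x)
    p_y_do_not_x = P(Y = 1 | do(X = 0)), i.e. P(y_x')
    Returns (lower, upper) with
        max(0, P(y_x) - P(y_x')) <= PNS <= min(P(y_x), 1 - P(y_x')).
    """
    for p in (p_y_do_x, p_y_do_not_x):
        if not 0.0 <= p <= 1.0:
            raise ValueError("probabilities must lie in [0, 1]")
    lower = max(0.0, p_y_do_x - p_y_do_not_x)
    upper = min(p_y_do_x, 1.0 - p_y_do_not_x)
    return lower, upper

lower, upper = pns_bounds(0.8, 0.3)
print(round(lower, 3), round(upper, 3))  # 0.5 0.7
```

Under the extra assumption of monotonicity (the feature never switches the label from positive to negative), the lower bound is attained, i.e. \(\mathrm{PNS} = P(y_x) - P(y_{x'})\).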

One experiment is a binary sentiment analysis task using bag-of-words features, where uninformative words like "as", "an" and "the" are added. Compared to logistic regression, CAUSAL-REP learns features that are relevant to the reviews. Another experiment uses MNIST digits with added colours: in the training set, colour and label are positively correlated, while in the test set they are negatively correlated. CAUSAL-REP achieves good out-of-distribution (OOD) predictive accuracy compared to neural networks and Variational Autoencoders (VAEs). (#DOUBT I guess the second plot is of random flips of 25% of the labels in both training and test data.)
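A hypothetical two-feature version of the coloured-digit shift (not the actual MNIST experiment) shows why accuracy-driven training fails here. Assumptions: a "shape" feature matches the label 75% of the time (an assumed noise level), and a "colour" feature matches it 90% of the time in training but only 10% at test time. The names `colored_split` and `accuracy` are illustrative.

```python
import random

def colored_split(n, p_color_matches_label, seed):
    """Hypothetical two-feature stand-in for coloured digits. 'shape' is the
    causal feature (matches the label 75% of the time, an assumed noise level);
    'colour' matches the label with probability p_color_matches_label."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        label = rng.randint(0, 1)
        shape = label if rng.random() < 0.75 else 1 - label
        color = label if rng.random() < p_color_matches_label else 1 - label
        rows.append((shape, color, label))
    return rows

SHAPE, COLOR = 0, 1

def accuracy(feature, rows):
    """Accuracy of the rule 'predict the label equal to this feature'."""
    return sum(row[feature] == row[2] for row in rows) / len(rows)

train_rows = colored_split(10_000, p_color_matches_label=0.9, seed=0)  # positive correlation
test_rows = colored_split(10_000, p_color_matches_label=0.1, seed=1)   # negative correlation
for name, feature in (("colour", COLOR), ("shape", SHAPE)):
    print(name, round(accuracy(feature, train_rows), 2), round(accuracy(feature, test_rows), 2))
```

On the training set colour beats shape (about 0.9 vs 0.75), so a learner that only maximizes training accuracy prefers colour, and that choice drops to about 0.1 on the test set while shape stays near 0.75.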

Moving on to unsupervised learning, we would want the learned representations to be disentangled. One definition: a representation \(G = (G_1, \ldots, G_d)\) is (causally) disentangled if \(G_1, \ldots, G_d\) represent (possibly correlated) factors of variation that do not causally affect each other.

How can we assess causal disentanglement? One can show that causal disentanglement implies independent support (#DOUBT I think this means the set of values one variable can take does not depend on the values taken by the others, i.e. the joint support is the product of the marginal supports). We can use this to define a measure of disentanglement, which they call IOSS, and add it to the training objective as an extra penalty term.
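A rough sample-based proxy in the spirit of IOSS, under loudly stated assumptions: the exact score in the talk may be defined differently, and `support_gap` is a hypothetical name. Shuffling each coordinate of the samples independently gives draws from the product of the marginals; if the joint support is the product of the marginal supports, those draws stay close to the observed samples, so their Hausdorff distance is small. This plain-Python version only illustrates the quantity; as a training penalty it would be computed on encoder outputs inside an autodiff framework.

```python
import math
import random

def hausdorff(a, b):
    """Symmetric Hausdorff distance between two finite point sets."""
    def directed(p, q):
        # largest distance from a point of p to its nearest neighbour in q
        return max(min(math.dist(x, y) for y in q) for x in p)
    return max(directed(a, b), directed(b, a))

def support_gap(samples, seed=0):
    """Hypothetical independence-of-support proxy: Hausdorff distance between
    the samples and a column-wise shuffled copy (product of the marginals)."""
    rng = random.Random(seed)
    columns = [list(col) for col in zip(*samples)]
    for col in columns:
        rng.shuffle(col)  # break the dependence between coordinates
    product_samples = list(zip(*columns))
    return hausdorff(samples, product_samples)

rng = random.Random(1)
# Independent support: uniform on the unit square.
square = [(rng.random(), rng.random()) for _ in range(300)]
# Dependent support: a thin band around the diagonal y = x.
band = []
for _ in range(300):
    t = rng.random()
    band.append((t, t + 0.05 * rng.random()))
print(round(support_gap(square), 3), round(support_gap(band), 3))  # small vs. large
```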
