Advances in Neural Processes - 2021
Details
Title: Advances in Neural Processes
Author(s): Richard Turner
Link(s): https://www.youtube.com/watch?v=5I_E0nWpXcM
Rough Notes
Consider the standard regression setup with data \(D_c = \{x_n,y_n\}_{n=1}^N\). In standard deep learning we would use the model \(p(y_T|x_T,\theta)=\mathcal{N}(y_T;\mu_\theta(x_T),\sigma^2_\theta(x_T))\), where \(\mu_\theta,\sigma^2_\theta\) are neural networks parametrized by \(\theta\), which could be fit via e.g. maximum likelihood estimation (MLE). MLE would, however, overfit badly. To overcome this, we can take a different perspective on \(\mu_\theta, \sigma^2_\theta\): after training, the parameters depend on the data, \(\theta(D_c)\), so the predictive mean and variance can be written as \(\mu(x_T,D_c), \sigma^2(x_T,D_c)\). (#DOUBT Technically isn't this the case even before training?)
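As a concrete reference point, here is a minimal PyTorch-style sketch of that standard setup (the names `GaussianRegressor`, `nll_loss` and the layer sizes are my own choices, not from the talk): a network with mean and log-std heads, fit by minimising the Gaussian negative log-likelihood.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianRegressor(nn.Module):
    """Standard deep-learning regressor: x -> (mu_theta(x), sigma_theta(x))."""
    def __init__(self, dim_x=1, dim_h=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim_x, dim_h), nn.ReLU(),
                                  nn.Linear(dim_h, dim_h), nn.ReLU())
        self.mu_head = nn.Linear(dim_h, 1)
        self.log_sigma_head = nn.Linear(dim_h, 1)  # predict log-std so sigma > 0

    def forward(self, x):
        h = self.body(x)
        return self.mu_head(h), torch.exp(self.log_sigma_head(h))

def nll_loss(model, x, y):
    """Negative log-likelihood of y under N(mu(x), sigma^2(x)) -- the MLE objective."""
    mu, sigma = model(x)
    return -Normal(mu, sigma).log_prob(y).mean()
```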
Neural Processes (NPs) model \(\mu_\theta(x_T,D_c), \sigma^2_\theta(x_T,D_c)\) directly using a neural network, i.e. these functions take in a whole dataset and output a mean and variance for a query point \(x_T\). This can be done using the theorem that a function over a set \(S\) is a valid (permutation-invariant) set function iff it can be decomposed into the form \(\rho\left(\sum_{s\in S}\phi(s)\right)\), where \(\rho, \phi\) are suitable (possibly nonlinear) functions - this result is also used in the Deep Sets paper.
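To illustrate the decomposition, a minimal PyTorch-style sketch of a sum-decomposed set function (the class name `SumDecomposedSetFunction` and dimensions are my own choices): permutation invariance follows by construction because the only interaction between elements is the sum.

```python
import torch
import torch.nn as nn

class SumDecomposedSetFunction(nn.Module):
    """f(S) = rho(sum_{s in S} phi(s)): permutation invariant by construction."""
    def __init__(self, dim_in=2, dim_r=128, dim_out=1):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(dim_in, dim_r), nn.ReLU(),
                                 nn.Linear(dim_r, dim_r))
        self.rho = nn.Sequential(nn.Linear(dim_r, dim_r), nn.ReLU(),
                                 nn.Linear(dim_r, dim_out))

    def forward(self, s):                         # s: (set_size, dim_in)
        return self.rho(self.phi(s).sum(dim=0))   # sum over set elements
```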
For each dataset, NPs have a set encoder \(E(D_c)\), which passes each input-output pair \((x_i, y_i)\) through an MLP \(\phi\) to get a representation \(r_i\); the output is \(E(D_c)=\sum_i r_i := r\).
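A minimal sketch of such a set encoder, again in PyTorch style (the name `DeepSetEncoder` and the dimensions are assumptions of mine):

```python
import torch
import torch.nn as nn

class DeepSetEncoder(nn.Module):
    """E(D_c): embed each (x_i, y_i) pair with an MLP phi, then sum to get r."""
    def __init__(self, dim_x=1, dim_y=1, dim_r=128):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(dim_x + dim_y, dim_r), nn.ReLU(),
                                 nn.Linear(dim_r, dim_r))

    def forward(self, x_c, y_c):                       # x_c: (N, dim_x), y_c: (N, dim_y)
        r_i = self.phi(torch.cat([x_c, y_c], dim=-1))  # per-pair representations r_i
        return r_i.sum(dim=0)                          # order-invariant summary r
```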
The decoder then takes \(r\) and, for each prediction point \(x_{t_i}\), passes the prediction point together with \(r\) into another neural network (the decoder network), which outputs the mean and variance at that point.
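A matching decoder sketch (names and sizes are mine; it consumes the summary `r` produced by the encoder sketch above and predicts a mean and standard deviation for each target input):

```python
import torch
import torch.nn as nn

class CNPDecoder(nn.Module):
    """Decoder: (r, x_t) -> (mu(x_t, D_c), sigma(x_t, D_c))."""
    def __init__(self, dim_x=1, dim_r=128, dim_h=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim_x + dim_r, dim_h), nn.ReLU(),
                                 nn.Linear(dim_h, dim_h), nn.ReLU(),
                                 nn.Linear(dim_h, 2))   # mean and log-std

    def forward(self, r, x_t):                          # r: (dim_r,), x_t: (M, dim_x)
        rr = r.expand(x_t.shape[0], -1)                 # broadcast summary to each target
        out = self.net(torch.cat([x_t, rr], dim=-1))
        mu, log_sigma = out.chunk(2, dim=-1)
        return mu, torch.exp(log_sigma)
```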
Using the encoder-decoder framework above results in what is called the Conditional Neural Process (CNP). The MLE estimate for a CNP is \(\theta^* = \text{argmax}_\theta \mathbb{E}_{(D_c,D_t)\sim P} \log p(D_t|D_c,\theta)\), where \(D_t\) is the dataset to predict on (the target set). This is well suited to scenarios where we have many datasets, each of which has only a few samples. This is in effect a meta-learning framework.
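A hedged sketch of one training step of this objective, assuming the encoder/decoder sketches above and that some outer loop supplies a batch of sampled (context, target) splits; the function name and batching scheme are my own, not the talk's:

```python
import torch
from torch.distributions import Normal

def cnp_meta_train_step(encoder, decoder, optimizer, task_batch):
    """One step of theta* = argmax_theta E_{(D_c, D_t)~P} log p(D_t | D_c, theta).

    task_batch: list of tuples (x_c, y_c, x_t, y_t), one per sampled dataset.
    """
    loss = 0.0
    for x_c, y_c, x_t, y_t in task_batch:
        r = encoder(x_c, y_c)                      # summarise the context set D_c
        mu, sigma = decoder(r, x_t)                # predictive distribution at targets
        log_p = Normal(mu, sigma).log_prob(y_t)    # log p(D_t | D_c, theta)
        loss = loss - log_p.mean()                 # negative log-likelihood
    loss = loss / len(task_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```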
NPs are often used on datasets with all sorts of missing data, small dataset sizes, irregularly sampled spatio-temporal data, etc. They are also a natural model for Continual Learning, and are relevant for Sim2Real transfer, where we could train on a simulator and deploy on real data.
Some problems CNPs currently face include underfitting and failure to extrapolate. One solution that helps is building in translation equivariance, i.e. shifting the data by a constant should shift the predictions by the same constant. This is done by encoding \(D_c\) into a function rather than a vector, which can be achieved by replacing the MLP \(\phi\) with a CNN; the encoder outputs the function \(E(D_c)(x')=\sum_{(x,y)\in D_c}\phi_y(y)\,\psi(x-x')\). This keeps the set function permutation invariant (as before) and also makes it translation equivariant.
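A rough sketch of such a functional encoder, evaluated on a uniform grid of \(x'\) values so that a CNN can process the result. I am assuming an RBF kernel for \(\psi\) and that \(\phi_y\) appends a constant "density" channel alongside \(y\) (as in the ConvCNP construction), so this is an illustration under those assumptions rather than the talk's exact recipe:

```python
import torch
import torch.nn as nn

class SetConvEncoder(nn.Module):
    """Functional encoder E(D_c)(x') = sum_i phi_y(y_i) * psi(x_i - x'),
    evaluated on a grid of x' locations; translation equivariant because
    psi depends only on the difference x_i - x'."""
    def __init__(self, lengthscale=0.1):
        super().__init__()
        self.log_ls = nn.Parameter(torch.log(torch.tensor(lengthscale)))

    def forward(self, x_c, y_c, x_grid):          # x_c: (N,1), y_c: (N,1), x_grid: (G,1)
        diff = x_grid[:, None, :] - x_c[None, :, :]                    # (G, N, 1)
        psi = torch.exp(-0.5 * (diff / torch.exp(self.log_ls)) ** 2).squeeze(-1)  # (G, N)
        density = psi.sum(dim=1, keepdim=True)    # channel 0: "how much data is nearby"
        signal = psi @ y_c                        # channel 1: kernel-weighted y values
        return torch.cat([density, signal], dim=-1)   # (G, 2), to be fed to a 1-D CNN
```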