Deep Symbolic Regression: Recovering Mathematical Expressions from Data via Risk-Seeking Policy Gradients - 2021
Details
Title: Deep Symbolic Regression: Recovering Mathematical Expressions from Data via Risk-Seeking Policy Gradients
Author(s): Petersen, Brenden K. and Landajuela, Mikel and Mundhenk, T. Nathan and Santiago, Claudio P. and Kim, Soo K. and Kim, Joanne T.
Link(s): http://arxiv.org/abs/1912.04871
Rough Notes
Notes from online talk
The problem of discrete sequence optimization is to find the following: \(\text{argmax}_{n\leq N, \tau_1,\cdots,\tau_n} R(\tau_1,\cdots,\tau_n)\) where each \(\tau_i\) is drawn from a finite set of symbols \(\mathcal{ L }\) and \(R\) is some black-box reward function. Examples include neural architecture search, antibody design, and, most relevant to this work, symbolic regression.
In Symbolic Regression, given a dataset, we aim to find a mathematical expression \(f\) alongside its parameters such that \(f(x)\approx y\). There are community-vetted benchmarks, strong existing baselines (typically Genetic Algorithms), and a computationally cheap reward function, e.g. compared to neural architecture search, which requires training the network for each evaluation. This is highly applicable in AI for Science.
Expressions can be described as trees, and the trees used in symbolic regression have a one-to-one correspondence with sequences since for each operation node, we know how many children it can have, based on whether the operation is binary or unary.
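The tree/sequence correspondence can be illustrated with a minimal sketch, assuming the sequence is the pre-order (prefix) traversal of the tree. The token names and the `LIBRARY` table are hypothetical, chosen only for illustration. Because each token's arity is known, the sequence can be parsed back into a tree (here, evaluated directly) without any brackets.

```python
import math

# Hypothetical token library: name -> (arity, implementation).
LIBRARY = {
    "add": (2, lambda a, b: a + b),
    "mul": (2, lambda a, b: a * b),
    "sin": (1, math.sin),
    "x": (0, None),  # input variable (terminal)
}

def evaluate_prefix(tokens, x):
    """Evaluate a pre-order token sequence at the input value x."""
    def helper(pos):
        arity, fn = LIBRARY[tokens[pos]]
        pos += 1
        if arity == 0:
            return x, pos
        args = []
        for _ in range(arity):
            # Each child is itself a complete sub-sequence.
            value, pos = helper(pos)
            args.append(value)
        return fn(*args), pos

    value, end = helper(0)
    if end != len(tokens):
        raise ValueError("sequence has trailing tokens")
    return value

# sin(x) + x * x written in pre-order.
expr = ["add", "sin", "x", "mul", "x", "x"]
result = evaluate_prefix(expr, 2.0)
```

The arity table is what makes the mapping unambiguous: the same tokens in a different order describe a different tree.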
This work uses an RNN to generate the sequence one token at a time. There is no explicit stop token: generation ends once the expression tree is complete, i.e. every open position has been filled and all leaves are terminals (input variables or constants).
This approach makes it easy to incorporate domain knowledge. For example, the constraint that no trigonometric function may be a descendant of another trigonometric function can be enforced by setting the softmax probability of any violating token to 0 during sequence generation. Other common constraints include:
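The stopping rule can be sketched with a counter of unfilled positions in the tree. This is a minimal illustration, not the authors' implementation: the `ARITY` table is hypothetical, and a uniform random choice stands in for sampling from the RNN's softmax output.

```python
import random

# Hypothetical token set: name -> number of children (arity).
ARITY = {"add": 2, "mul": 2, "sin": 1, "x": 0, "const": 0}

def sample_expression(rng, max_length=30):
    """Sample a complete pre-order token sequence of at most max_length tokens."""
    tokens = []
    open_slots = 1  # the root position is waiting for a token
    while open_slots > 0:
        # Only allow tokens whose children can still fit in the length budget.
        budget = max_length - len(tokens) - open_slots
        choices = [t for t, arity in ARITY.items() if arity <= budget]
        token = rng.choice(choices)  # stand-in for sampling from the RNN
        tokens.append(token)
        # The token fills one slot and opens one new slot per child.
        open_slots += ARITY[token] - 1
    return tokens

expr = sample_expression(random.Random(0))
```

Generation stops exactly when `open_slots` reaches zero, so no dedicated stop token is needed, and the budget check guarantees the sequence never exceeds `max_length`.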
- Each input variable must appear at least once.
- Each operator/output must follow its physical units (requires attaching units to input variables).
- The length of the expression (e.g. minimum and maximum bounds).
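As a rough sketch of how such a constraint could be applied at sampling time, the logits of forbidden tokens can be set to negative infinity before the softmax, so they receive probability exactly 0. The token list and the `inside_trig` flag are hypothetical. A real implementation would derive the flag from the partially built tree.

```python
import math

# Hypothetical token list; TRIG marks the trigonometric operators.
TOKENS = ["add", "mul", "sin", "cos", "x"]
TRIG = {"sin", "cos"}

def allowed_tokens(inside_trig):
    """Return one boolean per token: False if the token would nest trig in trig."""
    return [not (inside_trig and t in TRIG) for t in TOKENS]

def masked_softmax(logits, allowed):
    """Softmax that gives probability 0 to disallowed tokens.

    Assumes at least one token is allowed.
    """
    masked = [l if ok else -math.inf for l, ok in zip(logits, allowed)]
    m = max(masked)
    exps = [math.exp(v - m) for v in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Example: the next token is a descendant of a sin node.
probs = masked_softmax([0.1, 0.2, 1.5, 1.0, 0.3], allowed_tokens(inside_trig=True))
```

Masking before normalization keeps the remaining probabilities a valid distribution, so the constrained sampler never wastes samples on invalid expressions.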
The RNN is a generative model which defines a distribution \(\pi(\theta)\) over mathematical expressions (#DOUBT Is it?). The negative MSE \(R(\tau)\) is not differentiable with respect to the RNN parameters. However, the expected reward \(J(\theta)=\mathbf{ E }_{\tau \sim \pi(\theta)}[R(\tau)]\) is, which allows the use of policy gradients.
Policy gradients maximize the expected reward, but search algorithms are evaluated on the single best or the few best performing samples. So how can we optimize for best-case performance instead of average-case performance? To do this, this work introduces Risk-Seeking policy gradients, where the objective function is now \(J'(\theta)=\mathbf{ E}_{\tau\sim \pi(\theta)}[R(\tau)|R(\tau)>R_\epsilon(\theta)]\), where \(R_\epsilon(\theta)\) is the \((1-\epsilon)\) quantile of the reward distribution under the current policy, i.e. maximize the expected reward conditioned on the reward being greater than that quantile. In short:
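For reference, the differentiability of \(J(\theta)\) follows from the standard log-derivative (REINFORCE) identity. Writing \(\pi(\tau;\theta)\) for the probability the RNN assigns to a sequence \(\tau\) (notation assumed here, not taken from the talk):

\[
\nabla_\theta J(\theta)
= \sum_{\tau} R(\tau)\,\nabla_\theta \pi(\tau;\theta)
= \sum_{\tau} \pi(\tau;\theta)\,R(\tau)\,\nabla_\theta \log \pi(\tau;\theta)
= \mathbf{E}_{\tau\sim\pi(\theta)}\left[R(\tau)\,\nabla_\theta \log \pi(\tau;\theta)\right]
\approx \frac{1}{N}\sum_{i=1}^{N} R(\tau^{(i)})\,\nabla_\theta \log \pi(\tau^{(i)};\theta)
\]

where \(\tau^{(1)},\dots,\tau^{(N)}\) are sampled from the current policy. Only \(\log \pi\) is differentiated, so \(R\) can remain a black box.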
- Use the RNN to generate \(N\) sample symbolic expression sequences, compute the rewards for each of them, filter them to get the highest performing ones, then train on this filtered batch to update the RNN parameters, and move on.
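The sample/filter/update loop can be sketched on a toy problem. This is an illustration under strong simplifying assumptions, not the authors' code: a single categorical distribution over a few tokens stands in for the RNN, the function names are hypothetical, and the empirical batch quantile is used both as the filter threshold and as the baseline.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def risk_seeking_gradient(samples, rewards, probs, epsilon):
    """Estimate the gradient of J' w.r.t. the logits of a categorical policy."""
    n = len(samples)
    # Empirical (1 - epsilon) quantile of the rewards in this batch.
    threshold = sorted(rewards)[min(n - 1, int((1 - epsilon) * n))]
    grad = [0.0] * len(probs)
    for token, reward in zip(samples, rewards):
        if reward < threshold:
            continue  # discard samples below the quantile
        weight = (reward - threshold) / (epsilon * n)
        # d log softmax(logits)[token] / d logits = one_hot(token) - probs
        for k, p in enumerate(probs):
            grad[k] += weight * ((1.0 if k == token else 0.0) - p)
    return grad, threshold

def training_step(logits, reward_fn, rng, n_samples=100, epsilon=0.1, lr=1.0):
    """Sample a batch, keep the top-epsilon fraction, take one ascent step."""
    probs = softmax(logits)
    samples = rng.choices(range(len(logits)), weights=probs, k=n_samples)
    rewards = [reward_fn(s) for s in samples]
    grad, _ = risk_seeking_gradient(samples, rewards, probs, epsilon)
    return [l + lr * g for l, g in zip(logits, grad)]

# Hand-made batch: only the samples at or above the quantile contribute.
grad, threshold = risk_seeking_gradient(
    samples=[0, 1, 2, 3], rewards=[0.1, 0.2, 0.6, 1.0],
    probs=[0.25] * 4, epsilon=0.5)
```

With the threshold as baseline, a sample sitting exactly at the quantile gets zero weight, so the update is driven only by how far the best samples exceed it. For a real sequence model, the `one_hot - probs` term would be replaced by the gradient of the summed per-token log-probabilities.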
Experimental setup:
- Nguyen benchmarks: only 20 data points, success is measured by symbolic equivalence.
Baselines:
- PQT: Priority queue training
- VPG: Vanilla policy gradient
- Eureqa (commercial software)
- Wolfram (commercial software)
Symbolic expressions also give better generalization (Wignerian Prior). E.g. for predicting harmonic numbers, MLPs extrapolate poorly compared to the symbolic regression expression (even though there is no closed-form expression for harmonic numbers). They also proved that \(\forall n \in \mathbb{ N }\) their symbolic regression expression has an error of less than 0.000001%, and it even approximated the constant in the expression (the Euler-Mascheroni constant) correctly to 5-6 decimals.
For noisy data, the recovery rate drops: the method returns incorrect expressions that fit the noisy data better, which is overfitting in a different form (#DOUBT What does he mean by that). In the limit of infinite data, noise is not a problem.
Preventing the sampling of equivalent expressions (e.g. \(a+b\) and \(b+a\)) is a needle-in-a-haystack problem. We could define semantic equivalence classes, but it is not clear how to search over a canonical representative of each class.
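To give a feel for why a compact symbolic form can extrapolate well here, the sketch below compares exact harmonic numbers with the well-known asymptotic approximation \(\ln n + \gamma + \frac{1}{2n}\). This is a textbook expansion used purely as an illustration. It is not claimed to be the expression recovered in the talk.

```python
import math

# Euler-Mascheroni constant, to double precision.
EULER_GAMMA = 0.5772156649015329

def harmonic(n):
    """Exact n-th harmonic number H_n = 1 + 1/2 + ... + 1/n."""
    return sum(1.0 / k for k in range(1, n + 1))

def harmonic_approx(n):
    """Textbook asymptotic approximation: ln(n) + gamma + 1/(2n)."""
    return math.log(n) + EULER_GAMMA + 1.0 / (2 * n)

for n in (10, 100, 1000, 10000):
    error = abs(harmonic(n) - harmonic_approx(n))
    print(n, error)
```

The error shrinks as \(n\) grows, so a formula of this shape keeps working far outside any range of \(n\) it might have been fitted on.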