Do We Need Attention? A Mamba Primer - 2024
Details
Title: Do We Need Attention? A Mamba Primer
Author(s): Sasha Rush 🤗
Link(s): https://www.youtube.com/watch?v=dVH1dRoMPBc
Rough Notes
The goal of State-Space Models (SSMs) for language, such as Mamba, is to overcome the scaling issues of Transformers: generating each token costs \(\mathcal{O}(L)\) (attention over a KV cache that grows with the context) and training costs \(\mathcal{O}(L^2)\), where \(L\) is the context length.
SSMs have some desirable properties:
- Fixed-size memory: the hidden state does not grow with the context, so each inference step takes constant time and memory.
- Compute is linear in the sequence length.
Plan of this tutorial:
- Understanding the model.
- Computing the model.
- Designing effective versions.
- Scaling the model to its best.
The SSM takes in a sequence of tokens: each token is combined with the previous hidden state to produce a new hidden state, and each hidden state is mapped to an output. We start with the vanilla RNN: \[ h_k = \sigma(\bar{A}h_{k-1} + \bar{B}x_k) \] \[ y_k = Ch_k \]
The Linear Time Invariant (LTI) model is the same as the vanilla RNN, but without the nonlinearity \(\sigma\) and with the same \(\bar{A}, \bar{B}, C\) at every step. The common wisdom was that such linear models perform poorly; the S4 model overturned this.
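A minimal pure-Python sketch of the recurrence, assuming a scalar input per step, a 2-dimensional hidden state, and hand-picked matrices (all values are illustrative, not trained). Passing the identity as `sigma` drops the nonlinearity and gives the LTI model.

```python
import math


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rnn(xs, A_bar, B_bar, C, sigma=math.tanh):
    """h_k = sigma(A_bar h_{k-1} + B_bar x_k), y_k = C h_k, starting from h_0 = 0.

    xs: scalar inputs; A_bar: n x n matrix; B_bar, C: length-n vectors.
    """
    h = [0.0] * len(B_bar)
    ys = []
    for x in xs:
        pre = [ah + b * x for ah, b in zip(matvec(A_bar, h), B_bar)]
        h = [sigma(p) for p in pre]  # nonlinearity (identity for LTI)
        ys.append(sum(c * hi for c, hi in zip(C, h)))
    return ys


A_bar = [[0.5, 0.1], [0.0, 0.9]]
B_bar = [1.0, 0.5]
C = [1.0, -1.0]
xs = [1.0, 0.0, 2.0]

y_rnn = rnn(xs, A_bar, B_bar, C)                     # vanilla RNN (tanh)
y_lti = rnn(xs, A_bar, B_bar, C, sigma=lambda z: z)  # LTI: same matrices, no nonlinearity
print(y_rnn)
print(y_lti)
```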
However, in overall perplexity, LTI models still fall short of Transformers. The Mamba paper highlights two issues that limit LTI models for language modelling:
- Lack of ability to filter, i.e. LTI models cannot ignore tokens. This is because \(\bar{B}\) is the same at every step, so every token is written into the state in the same way regardless of its content.
- Lack of ability to reset, i.e. LTI models cannot reset history. This is because \(\bar{A}\) is the same at every step, so the past state always decays at the same rate.
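To see why neither fix is possible in an LTI model, unroll the recurrence: with \(h_0 = 0\), \(y_k = \sum_{j \le k} C\bar{A}^{k-j}\bar{B}\,x_j\). The weight on token \(x_j\) depends only on the distance \(k-j\), never on what the token is. The sketch below checks this numerically, assuming a diagonal \(\bar{A}\) (stored as a list of per-dimension decay factors) and illustrative parameter values.

```python
def lti_scan(xs, a, b, c):
    """Recurrent form: h_k = a * h_{k-1} + b * x_k (elementwise, diagonal A_bar), y_k = c . h_k."""
    h = [0.0] * len(a)
    ys = []
    for x in xs:
        h = [ai * hi + bi * x for ai, hi, bi in zip(a, h, b)]
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys


def lti_kernel(a, b, c, length):
    """K_m = C A_bar^m B_bar: the fixed weight on a token m steps in the past."""
    return [sum(ci * ai ** m * bi for ai, bi, ci in zip(a, b, c)) for m in range(length)]


def lti_conv(xs, a, b, c):
    """Convolutional form: y_k = sum_j K_{k-j} x_j, with one kernel shared by all inputs."""
    K = lti_kernel(a, b, c, len(xs))
    return [sum(K[k - j] * xs[j] for j in range(k + 1)) for k in range(len(xs))]


a, b, c = [0.5, 0.9], [1.0, 1.0], [1.0, 1.0]
xs = [1.0, 2.0, 3.0]
print(lti_kernel(a, b, c, 3))  # depends on (a, b, c) only, not on xs
print(lti_scan(xs, a, b, c), lti_conv(xs, a, b, c))
```

Since the kernel is fixed before any input is seen, no choice of the shared parameters lets the model drop one particular token or wipe its state at one particular position.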
Historically, LSTMs and GRUs addressed this with explicit gates that let the model ignore inputs and throw away hidden state (note that LSTMs and GRUs have nonlinearities, while LTI models do not).
We can generalize LTI models to Linear Time Varying (LTV) models, where the matrices depend on the time step, i.e. \[ h_k = \bar{A}_kh_{k-1} + \bar{B}_k x_k \] \[ y_k = C_k h_k \]
This allows resetting by setting \(\bar{A}_k\) to 0, and filtering by setting \(\bar{B}_k\) to 0. These matrices can be computed as functions of the current token \(x_k\).
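A minimal sketch of both behaviours, assuming a scalar hidden state. In a real model \(\bar{A}_k, \bar{B}_k, C_k\) would come from learned functions of \(x_k\); here they are set by hand per step to isolate the effect.

```python
def ltv_scan(xs, As, Bs, Cs):
    """h_k = A_k h_{k-1} + B_k x_k, y_k = C_k h_k with a scalar state and per-step parameters."""
    h = 0.0
    ys = []
    for x, a, b, c in zip(xs, As, Bs, Cs):
        h = a * h + b * x
        ys.append(c * h)
    return ys


xs = [1.0, 5.0, 2.0]
Cs = [1.0, 1.0, 1.0]

# Filtering: B_2 = 0, so the second token never enters the state.
y_filter = ltv_scan(xs, [0.9, 0.9, 0.9], [1.0, 0.0, 1.0], Cs)

# Resetting: A_3 = 0, so everything before the third token is forgotten.
y_reset = ltv_scan(xs, [0.9, 0.9, 0.0], [1.0, 1.0, 1.0], Cs)
print(y_filter)
print(y_reset)
```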
For efficient computation, we reuse the technique for computing cumulative sums in parallel, namely associative scans. For LTI and LTV models, we define a new associative operator that composes consecutive recurrence steps, so the whole recurrence can be computed as a scan.
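A sketch of such an operator for a scalar state: each step is the affine map \(h \mapsto a_k h + b_k\) with \(a_k = \bar{A}_k\) and \(b_k = \bar{B}_k x_k\). Composing two steps gives \((a_1, b_1) \bullet (a_2, b_2) = (a_2 a_1,\ a_2 b_1 + b_2)\), which is associative, so any parallel prefix-scan algorithm applies. The scan below is Hillis-Steele, chosen for brevity (it does \(\mathcal{O}(L \log L)\) work, whereas work-efficient scans such as Blelloch's do \(\mathcal{O}(L)\)); it runs sequentially in Python, but the updates within each round are independent and could run in parallel.

```python
def combine(left, right):
    """Compose affine maps: apply `left` first, then `right`.

    Each element (a, b) stands for h -> a*h + b.
    """
    a1, b1 = left
    a2, b2 = right
    return (a2 * a1, a2 * b1 + b2)


def parallel_scan(elems):
    """Hillis-Steele inclusive scan: O(log L) rounds of independent combines."""
    out = list(elems)
    offset = 1
    while offset < len(out):
        # The comprehension reads the previous round's `out` before it is replaced.
        out = [out[i] if i < offset else combine(out[i - offset], out[i])
               for i in range(len(out))]
        offset *= 2
    return out


def sequential(As, bxs):
    """Reference loop: h_k = A_k h_{k-1} + (B_k x_k), with h_0 = 0."""
    h, hs = 0.0, []
    for a, bx in zip(As, bxs):
        h = a * h + bx
        hs.append(h)
    return hs


As = [0.9, 0.5, 0.0, 0.8]
Bx = [1.0, 2.0, 3.0, 4.0]  # B_k * x_k, precomputed
# With h_0 = 0, the b-component of each prefix is exactly h_k.
hs_scan = [b for _, b in parallel_scan(list(zip(As, Bx)))]
hs_seq = sequential(As, Bx)
print(hs_scan, hs_seq)
```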
Now, how do we actually get the parameters \(\bar{A}_k, \bar{B}_k, C_k\)?
- Option 1: (Griffin) Use a recurrence gate and an input gate.
- Option 2: (RetNet) Linear Attention.
- Option 3: (Mamba) Continuous-time SSM.
We now narrow down to Mamba. Since LTV models run in discrete steps, we first discretize a continuous-time SSM. The time gap for each token, \(\Delta_k\), is predicted from the token \(x_k\). We then have: \[ \bar{A}_k = \exp(\Delta_k A) \] \[ \bar{B}_k = (\bar{A}_k - I)A^{-1}B_k \] For a diagonal \(A\) this is elementwise: \(\bar{B}_k = (\bar{A}_k - 1)B_k / A\).
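A minimal sketch of one selective step, assuming a diagonal \(A\) with negative entries (stored as a list), a scalar input per step, and \(\Delta_k = \mathrm{softplus}(w x_k + b_\Delta)\). The projection parameters `w_delta` and `b_delta` are hypothetical placeholders for learned weights, and \(B\), \(C\) are held fixed for brevity, although Mamba also makes them functions of \(x_k\).

```python
import math


def softplus(z):
    """Numerically stable log(1 + e^z); keeps Delta positive."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def discretize(delta, A, B):
    """Diagonal A: A_bar = exp(delta*A), B_bar = (A_bar - 1) * B / A (elementwise)."""
    A_bar = [math.exp(delta * a) for a in A]
    # math.expm1(u) = exp(u) - 1, accurate even when delta*a is tiny
    B_bar = [math.expm1(delta * a) / a * b for a, b in zip(A, B)]
    return A_bar, B_bar


def selective_step(h, x, A, B, C, w_delta, b_delta):
    """Predict Delta from x, discretize, update the state, read out y."""
    delta = softplus(w_delta * x + b_delta)  # hypothetical scalar projection
    A_bar, B_bar = discretize(delta, A, B)
    h = [ab * hi + bb * x for ab, hi, bb in zip(A_bar, h, B_bar)]
    y = sum(c * hi for c, hi in zip(C, h))
    return h, y


A = [-1.0, -2.0]
B = [1.0, 1.0]
C = [1.0, 1.0]
h = [0.0, 0.0]
for x in [1.0, 0.5, -1.0]:
    h, y = selective_step(h, x, A, B, C, w_delta=1.0, b_delta=0.0)
    print(y)
```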
Recall the issues with LTIs:
- Filtering (LTI models cannot ignore tokens). Intuitively, filtering a token means shrinking its duration to zero, i.e. \(\Delta_k \to 0\) for tokens \(x_k\) we want to filter. Then \(\bar{A}_k \to I\) and \(\bar{B}_k \to 0\), so the state passes through unchanged and \(x_k\) contributes nothing.
- Resetting (LTI models cannot reset history by setting the hidden state to 0). For example, when we reach the title of a new chapter, we want to reset the hidden state. We can do this by giving the token an infinite duration, i.e. \(\Delta_k \to \infty\), modelling that a lot of time has passed and we are starting in a new 'state'. With \(A\) negative (a decaying state), \(\bar{A}_k \to 0\) and the old state is erased.
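A quick numerical check of both limits for a single state dimension, assuming the discretization formulas above with \(A = -1\) (negative, so the state decays) and \(B = 1\). The two values of \(\Delta\) are illustrative stand-ins for "very small" and "very large".

```python
import math


def zoh(delta, a, b):
    """Discretize one diagonal entry: A_bar = exp(delta*a), B_bar = (A_bar - 1) * b / a."""
    return math.exp(delta * a), math.expm1(delta * a) / a * b


def step(h, x, delta, a=-1.0, b=1.0):
    """One selective-SSM update for a scalar state."""
    a_bar, b_bar = zoh(delta, a, b)
    return a_bar * h + b_bar * x


h_old, x = 3.0, 10.0

# Delta -> 0: A_bar -> 1 and B_bar -> 0, so the state is unchanged (token filtered out)
h_filtered = step(h_old, x, delta=1e-8)

# Delta -> infinity: A_bar -> 0 and B_bar -> -b/a, so the old state is erased (reset)
h_reset = step(h_old, x, delta=1e3)

print(h_filtered, h_reset)
```

With \(A = -1\) and \(B = 1\), the reset state is \(-Bx/A = x\): after the reset, the state depends only on the current token.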
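Now, how do we scale a fixed-state SSM efficiently on GPUs? The key idea is to materialize the (expanded) hidden state only in fast on-chip SRAM, rather than writing it out to the GPU's slower main memory (HBM).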