1 of 16

Mamba

2 of 16

Transformers could be better:

  • Computing the attention score of each token against every other token (the QKV step) is O(n^2) in the length of the context window.
  • The key-value cache, which grows with the sequence length, has to be stored alongside the model weights during inference.
  • Having a fixed context window at all is rather limiting.

If we can do these things better, we can hopefully do better on context-heavy tasks like speech and video.

3 of 16

What are State Space Models?

We can think of an SSM as a black box mapping an input signal u(t) to an output signal y(t):

  • A, B, C, and D are learnable latent parameters
  • x(t) is the latent state: the solution of the linear ODE driven by u(t), as written out below
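
For reference, the standard continuous-time state space equations (the same ones used in the S4/Mamba line of work) are:

  x'(t) = A x(t) + B u(t)
  y(t)  = C x(t) + D u(t)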

4 of 16

SSMs: Continuous, Recurrent, and Convolutional

SSMs can be rewritten into different but equivalent views:

Recurrent:

Convolutional:
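
As a sketch of what these views look like under the standard discretization (Ā, B̄ are the discretized A and B; this follows S4, not necessarily the exact formulas shown on the slide):

  Recurrent:      x_k = Ā x_{k-1} + B̄ u_k,    y_k = C x_k
  Convolutional:  y = u * K̄,  with kernel  K̄ = (C B̄, C Ā B̄, C Ā² B̄, …)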

5 of 16

Let’s drop the state space model idea

Let’s say we’re just looking at the recurrent version: RNNs

Pros:

  • No context window (unlike convolutional view)
  • Efficient constant-time inference per step (unlike the continuous view)

Cons:

  • Not parallelizable
  • Exploding/vanishing gradients (if we truly want a large effective context)

6 of 16

Linear RNNs

7 of 16

How to Parallelize?

Normal RNNs are too complicated to parallelize, so let's remove the activation function and keep only the linear recurrence, written out below.
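
Concretely, dropping the nonlinearity σ from the usual update leaves a purely linear recurrence (notation ours):

  usual RNN:   h_t = σ(W_h h_{t-1} + W_x x_t)
  linear RNN:  h_t = W_h h_{t-1} + W_x x_t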

8 of 16

Blelloch

The Blelloch scan computes all prefix sums of an array in O(log n) parallel steps with only O(n) total work.

This works only because addition is associative: the elements can be combined in any grouping and still give the same result (a sketch follows).
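
A minimal sequential sketch of the Blelloch scan (exclusive prefix sums over any associative op); in a real implementation every level of the two sweeps runs in parallel. The function name and example values are illustrative, not from the talk:

def blelloch_scan(a, op, identity):
    """Exclusive prefix scan: up-sweep (reduce) then down-sweep."""
    n = len(a)
    size = 1 << (n - 1).bit_length()           # pad to a power of two
    x = list(a) + [identity] * (size - n)
    d = 1
    while d < size:                            # up-sweep: build partial reductions
        for i in range(0, size, 2 * d):
            x[i + 2 * d - 1] = op(x[i + d - 1], x[i + 2 * d - 1])
        d *= 2
    x[size - 1] = identity
    d = size // 2
    while d >= 1:                              # down-sweep: distribute prefixes
        for i in range(0, size, 2 * d):
            t = x[i + d - 1]
            x[i + d - 1] = x[i + 2 * d - 1]
            x[i + 2 * d - 1] = op(t, x[i + 2 * d - 1])
        d //= 2
    return x[:n]

print(blelloch_scan([3, 1, 7, 0, 4, 1, 6, 3], lambda p, q: p + q, 0))
# -> [0, 3, 4, 11, 11, 15, 16, 22]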

9 of 16

Associative RNN Iteration

This iteration function turns out to be associative, allowing us to run the RNN over all inputs with a parallel scan (W_h and W_x folded into a single affine step W; see the sketch below).
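
A hedged numpy sketch (not the talk's code) of why the step is associative: each update h_t = W h_{t-1} + b_t, with b_t = W_x x_t folded in, is an affine map represented by the pair (W, b_t), and composing two affine maps is itself an affine map:

import numpy as np

def combine(first, second):
    """Compose two affine maps h -> W h + b, applying `first` and then `second`."""
    W1, b1 = first
    W2, b2 = second
    return W2 @ W1, W2 @ b1 + b2

rng = np.random.default_rng(0)
d = 4
steps = [(rng.normal(size=(d, d)) * 0.3, rng.normal(size=d)) for _ in range(8)]

# Associativity: the grouping does not matter, so a Blelloch-style scan is legal.
left  = combine(combine(steps[0], steps[1]), steps[2])
right = combine(steps[0], combine(steps[1], steps[2]))
assert np.allclose(left[0], right[0]) and np.allclose(left[1], right[1])

# Folding every step reproduces the sequential recurrence from h_0 = 0.
h = np.zeros(d)
for W_t, b_t in steps:
    h = W_t @ h + b_t
W_all, b_all = steps[0]
for s in steps[1:]:
    W_all, b_all = combine((W_all, b_all), s)
assert np.allclose(h, b_all)     # equals W_all @ h_0 + b_all with h_0 = 0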

10 of 16

Associative RNN Iteration

Note this also means the scan has to carry the running products W_i W_j … W_k, each a d × d matrix, for every position in the array.

That's a lot, but luckily we can diagonalize W = P Λ P⁻¹ and store only the diagonal elements.
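
Why diagonalization helps (standard reasoning, not spelled out here): with a shared W = P Λ P⁻¹, every running product in the scan is a power of W, and powers of a diagonal matrix stay diagonal:

  W = P Λ P⁻¹,   Λ = diag(λ_1, …, λ_d)
  W^k = P Λ^k P⁻¹,   Λ^k = diag(λ_1^k, …, λ_d^k)

so each position only has to carry d numbers instead of a d × d matrix.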

11 of 16

We’re now parallelizable!

With a parallel scan, we can now compute all n hidden states in O(log n) parallel steps (O(n) total work for Blelloch, O(n log n) for the simpler scan)!

Cool facts

  • P and P⁻¹ are learned as separate parameters, so the model never has to invert any matrix, and the extra freedom adds expressivity
  • We still want nonlinearity, so we add a nonlinear layer after all the recurrent iterations (which are now much quicker), just like the dense layer after attention

12 of 16

But exploding gradients?

Initialize the recurrent weights very close to 1 in magnitude (just inside the unit circle), so repeated multiplication neither explodes nor vanishes.

And multiply all inputs by a very small number, because a model with weights that close to 1 is sensitive to input scale:

x := Δx
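
A minimal numpy sketch of the recipe; the particular constants and names (lam, delta) are illustrative assumptions, not values from the talk:

import numpy as np

rng = np.random.default_rng(0)
d, n = 16, 4096

lam = np.exp(-rng.uniform(0.001, 0.1, size=d))   # diagonal weights just below 1
delta = 0.01                                     # small input scale: x := delta * x

x = rng.normal(size=(n, d))
h = np.zeros(d)
max_h = 0.0
for t in range(n):
    h = lam * h + delta * x[t]                   # elementwise, since W is diagonal
    max_h = max(max_h, np.abs(h).max())
print(max_h)                                     # stays O(1) over thousands of steps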

13 of 16

Mamba

14 of 16

Selective SSM: Adding a Gate

RNNs have to hold too much information in h_t. We want to be selective about what to hold.

15 of 16

Selectivity Implemented

Remember that this is still essentially an RNN.

Here, we define functions that parameterize W_h and W_x based on the inputs themselves (sketched below).

(Introducing the sequence length L into the parameter shapes makes them time-dependent; input-dependence also means the parameters differ across the examples in a batch.)
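
A hedged numpy sketch of selectivity in the Mamba style; the projection names (W_delta, W_B), the scalar input channel u, and all sizes are assumptions for illustration:

import numpy as np

rng = np.random.default_rng(0)
d_in, d_state, n = 8, 16, 64

W_delta = rng.normal(size=(d_in,)) * 0.1           # x_t -> scalar step size
W_B     = rng.normal(size=(d_in, d_state)) * 0.1   # x_t -> input matrix B_t
A       = -np.ones(d_state)                        # fixed A (cf. A = -1 on the next slide)

x = rng.normal(size=(n, d_in))                     # features that drive the parameters
u = rng.normal(size=n)                             # signal being filtered (in Mamba, a channel of x)
h = np.zeros(d_state)
for t in range(n):
    delta = np.log1p(np.exp(x[t] @ W_delta))       # softplus -> positive step size
    B_t   = x[t] @ W_B                             # input-dependent B
    h = np.exp(delta * A) * h + delta * B_t * u[t] # discretized update (ZOH for A, Euler for B)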

16 of 16

Selectivity

The "selection functions" are chosen as simple projections of the input, Δ_t = softplus(Linear(x_t)), with B_t and C_t also linear functions of x_t,

such that, with A = -1 and B = 1, the gate at each head ends up looking like an ordinary sigmoid gate (derived below).
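
A short derivation of why the gate is exactly a sigmoid (the notation z_t = Linear(x_t) is ours):

  Δ_t = softplus(z_t) = log(1 + e^{z_t})
  exp(-Δ_t) = 1 / (1 + e^{z_t}) = σ(-z_t) = 1 - σ(z_t)

  With A = -1, B = 1 and zero-order-hold discretization:
  h_t = exp(-Δ_t) h_{t-1} + (1 - exp(-Δ_t)) x_t = (1 - σ(z_t)) h_{t-1} + σ(z_t) x_t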

Eloquent, isn’t it?