Mamba
Transformers could be done differently:
If we can handle long context better, we can hopefully do better at speech and video (context-heavy tasks)
What are State Space Models?
We can think of SSMs as a black-box mapping u(t) → y(t):
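For reference, the standard continuous-time state-space equations behind this black-box mapping (not shown on the slide; the feed-through term D u(t) is often dropped):

```latex
x'(t) = A\,x(t) + B\,u(t)
\qquad
y(t) = C\,x(t) + D\,u(t)
```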
SSMs: Continuous, Recurrent, and Convolutional
SSMs can be transformed into different views:
Recurrent: step a hidden state forward one input at a time (good for inference)
Convolutional: precompute a kernel from the SSM parameters and convolve it with the whole input at once (good for training)
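As a sanity check on the two views, here is a minimal sketch with random stand-in matrices (A, B, C and the sizes below are illustrative, not the slide's actual discretized parameters): stepping the recurrence and convolving the input with the kernel K = (CB, CAB, CA^2B, ...) produce the same outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 8                              # hypothetical state size and sequence length
A = rng.normal(size=(d, d)) * 0.3        # stand-ins for (discretized) SSM matrices
B = rng.normal(size=(d, 1))
C = rng.normal(size=(1, d))
u = rng.normal(size=n)

# Recurrent view: step a hidden state through the inputs one at a time.
h, y_rec = np.zeros((d, 1)), []
for t in range(n):
    h = A @ h + B * u[t]
    y_rec.append((C @ h).item())

# Convolutional view: precompute the kernel K_k = C A^k B, then convolve with u.
K = np.array([(C @ np.linalg.matrix_power(A, k) @ B).item() for k in range(n)])
y_conv = [sum(K[k] * u[t - k] for k in range(t + 1)) for t in range(n)]

print(np.allclose(y_rec, y_conv))        # True: both views agree
```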
Let’s drop the state space model idea
Let’s say we’re just looking at the recurrent version: RNNs
Pros: constant memory and constant work per generated token (no growing attention cache)
Cons: training is sequential (hard to parallelize across the sequence), and gradients can vanish or explode over long contexts
Linear RNNs
How to Parallelize?
Normal RNNs are too complicated to parallelize: let’s remove the nonlinear activation function.
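Concretely, dropping the nonlinearity leaves a purely linear recurrence (using the W_h, W_x notation that appears later in these notes):

```latex
h_t = W_h\,h_{t-1} + W_x\,x_t \qquad \text{(no nonlinearity } \sigma\text{)}
```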
The Blelloch Scan
The Blelloch scan lets us compute the prefix sums of an array very quickly, in parallel.
This works only because addition is associative.
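A self-contained sketch of the Blelloch (work-efficient, exclusive) scan. The inner loops are written sequentially here, but every iteration within a level is independent, which is what makes the scan parallelizable; the power-of-two length restriction is just to keep the sketch short.

```python
import numpy as np

def blelloch_scan(a, op=np.add, identity=0):
    """Exclusive prefix scan (Blelloch): an up-sweep then a down-sweep."""
    x = np.array(a, dtype=np.int64)
    n = len(x)
    assert n & (n - 1) == 0, "power-of-two length, for simplicity"
    # Up-sweep (reduce): build partial sums in a binary tree.
    d = 1
    while d < n:
        for i in range(0, n, 2 * d):          # iterations are independent
            x[i + 2 * d - 1] = op(x[i + d - 1], x[i + 2 * d - 1])
        d *= 2
    # Down-sweep: push prefixes back down the tree.
    x[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(0, n, 2 * d):          # independent again
            t = x[i + d - 1]
            x[i + d - 1] = x[i + 2 * d - 1]
            x[i + 2 * d - 1] = op(t, x[i + 2 * d - 1])
        d //= 2
    return x

print(blelloch_scan([3, 1, 7, 0, 4, 1, 6, 3]))  # -> [0, 3, 4, 11, 11, 15, 16, 22]
```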
Associative RNN Iteration
This update function turns out to be associative, allowing us to iterate over all inputs in the RNN quickly (W_h and W_x folded into W).
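The function itself isn't shown on the slide, but for a linear recurrence h_t = W h_{t-1} + b_t (with b_t = W_x x_t), the usual choice is to scan over pairs (W, b) representing affine maps; composing an earlier pair i with a later pair j is associative:

```latex
(W_j, b_j) \bullet (W_i, b_i) = (W_j W_i,\; W_j b_i + b_j)
```

Scanning this operator over the per-step pairs (W, W_x x_t) yields every hidden state h_t at once.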
Note this also means we have to cache the products W_i W_j ⋯ W_k, each of which is a d × d matrix, for every position in the array.
That’s a lot, but luckily we can diagonalize W and store only the diagonal elements.
We’re now parallelizable! A parallel scan covers the whole sequence in O(log n) parallel steps, with at most O(n log(n)) work.
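A hedged end-to-end sketch with a diagonal W: the same affine-map combine as above, applied level by level (Hillis-Steele style rather than Blelloch, for brevity), reproduces the sequential recurrence. All sizes and values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 8, 16                              # hypothetical sizes
w = rng.uniform(0.9, 0.999, size=d)       # diagonal of W, stored as a vector
W_x = rng.normal(size=(d, d)) * 0.1
x = rng.normal(size=(n, d))
b = x @ W_x.T                             # per-step inputs b_t = W_x x_t

# Sequential reference: h_t = w * h_{t-1} + b_t
h, h_seq = np.zeros(d), np.zeros((n, d))
for t in range(n):
    h = w * h + b[t]
    h_seq[t] = h

# Scan over affine maps (a_t, c_t): each of the ~log2(n) passes below combines
# all positions at once with vectorized ops, so each pass could run in parallel.
a = np.tile(w, (n, 1))                    # per-position multiplier, starts as w
c = b.copy()                              # per-position offset, starts as b_t
shift = 1
while shift < n:
    a_prev = np.vstack([np.ones((shift, d)), a[:-shift]])   # identity before t=0
    c_prev = np.vstack([np.zeros((shift, d)), c[:-shift]])
    # Combine the later map (a, c) with the earlier map (a_prev, c_prev):
    a, c = a * a_prev, a * c_prev + c
    shift *= 2

print(np.allclose(c, h_seq))              # True: the scan equals the recurrence
```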
Cool facts
But exploding gradients?
Initialize the recurrent weights very close to 1
And multiply all inputs by a very small number, because our model is sensitive!
x := Δx
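A tiny sketch of these two tricks together; the initialization range and the value of Δ below are assumptions for illustration, not values from the slides.

```python
import numpy as np

d = 16
delta = 1e-2                        # the small input scale Δ (illustrative value)

# Diagonal recurrent weights initialized just below 1, so repeated
# multiplication neither explodes nor decays away over long sequences.
w = np.random.uniform(0.99, 0.999, size=d)

def step(h, x):
    # Linear recurrence with the rescaled input: h_t = w * h_{t-1} + delta * x_t
    return w * h + delta * x
```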
Mamba
Selective SSM: Adding a Gate
RNNs have to hold too much info in h_t. We want to be selective about what to hold.
Selectivity Implemented
Remember that this is essentially an RNN
Here, we define functions that parameterize W_h and W_x based on the inputs themselves.
(Introducing the sequence length L makes the parameters time-dependent; input-dependence also means each example in a batch gets different parameters.)
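A hedged sketch of what input-dependent W_h and W_x could look like for a diagonal recurrence. The projection shapes and the sigmoid choice are assumptions for illustration, not the exact Mamba parameterization.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 8                                    # hypothetical model width
P_h = rng.normal(size=(d, d)) * 0.1      # projections that read the current input
P_x = rng.normal(size=(d, d)) * 0.1

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def selective_step(h, x):
    # The recurrence weights depend on x_t itself, so the model can decide,
    # per step, how much of the old state to keep and how much input to admit.
    w_h = sigmoid(P_h @ x)               # in (0, 1): how much of h_{t-1} to keep
    w_x = sigmoid(P_x @ x)               # how much of the new input to let in
    return w_h * h + w_x * x

# Because the weights differ per time step and per example, the parameters
# effectively gain a length dimension L and vary across the batch.
h = np.zeros(d)
for x in rng.normal(size=(5, d)):        # toy sequence of length 5
    h = selective_step(h, x)
```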
Selectivity
“Selecting functions” are chosen to depend on the input itself,
such that, with A = -1 and B = 1, the gate at each head ends up looking like the expression sketched below.
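The formula isn't reproduced on the slide, but this most likely refers to the gating connection noted in the Mamba paper: with A = -1, B = 1, and Δ obtained from a softplus of a linear function of the input, the selective recurrence reduces to a familiar sigmoid gate:

```latex
g_t = \sigma\big(\mathrm{Linear}(x_t)\big), \qquad h_t = (1 - g_t)\,h_{t-1} + g_t\,x_t
```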
Elegant, isn’t it?