1 of 23

Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling

Liliang Ren, Yang Liu, Yadong Lu, Yelong Shen, Chen Liang, Weizhu Chen

Microsoft GenAI

UIUC

2 of 23

Context length of LLMs


(Figure: context lengths of recent LLMs, from the Gemini 1.5 blog)

3 of 23

How to support infinite context length?

  • No Full Attention
    • We don’t have infinite physical memory and I/O is costly
  • We need Constant Memory size + Dynamic Memory Content
    • SeqBoat: Sliding Window Attention (SWA) with Hard Input Selection
  • Samba: SSMs with soft input selection + SWA


Mamba Layer

4 of 23

A closer look at the Mamba layer


  • Selective SSMs (S6): the SSM parameters (Δ, B, C) are input-dependent
  • Parallelized with a prefix-sum scan in O(log(n)) depth with n workers
  • O(1) inference complexity per step (see the step sketch below)

[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces

  • Mamba = GLU + ShortConv + S6
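
Below is a minimal numerical sketch of one selective-SSM step, assuming a diagonal state matrix and a simplified zero-order-hold discretization; the shapes and the `s6_step` helper are illustrative, not the fused CUDA kernel from the Mamba paper.

```python
import numpy as np

def s6_step(h, x_t, A, B_t, C_t, dt_t):
    """One recurrent S6 step: O(1) time and memory per token.

    h        : (d, n)  hidden state, one length-n state per channel
    x_t      : (d,)    current input
    A        : (d, n)  (negative) diagonal state matrix per channel
    B_t, C_t : (n,)    input-dependent ("selective") projections at step t
    dt_t     : (d,)    input-dependent step size at step t
    """
    A_bar = np.exp(dt_t[:, None] * A)              # discretized A
    B_bar = dt_t[:, None] * B_t[None, :]           # simplified ZOH for B
    h = A_bar * h + B_bar * x_t[:, None]           # state update
    y_t = (h * C_t[None, :]).sum(axis=-1)          # readout, shape (d,)
    return h, y_t
```

Because the state update composes associatively over the (Ā, B̄x) pairs, training can evaluate it with a parallel prefix-sum scan in O(log(n)) depth, while inference only carries `h` forward, which is the O(1) per-step cost noted above.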

5 of 23

A closer look at the Mamba layer

  • Mamba’s scan operator has a speed similar to FlashAttention 2 (FA2) at 2K sequence length


[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces

6 of 23

Why add SWA?

  • RNN:
    • (selectively) memorize/compress input into hidden state
    • Can recall from compressed memory
    • Cannot extrapolate infinitely
    • Constant memory size

  • SWA:
    • Can extrapolate
    • Can retrieve exact memories within the context window
    • Constant memory size (see the cache sketch below)
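
A minimal sketch of why SWA keeps memory constant: a fixed-capacity KV cache that evicts the oldest entries as the window slides (the class name and window size are illustrative).

```python
from collections import deque

class SlidingWindowKVCache:
    """Keeps only the last `window` key/value pairs, so memory is O(window)."""

    def __init__(self, window=2048):
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)

    def append(self, k, v):
        # a deque with maxlen evicts the oldest entry automatically
        self.keys.append(k)
        self.values.append(v)

    def snapshot(self):
        return list(self.keys), list(self.values)
```

Attention over this cache gives exact recall for anything still inside the window, while the recurrent state covers what has already slid out of it.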


7 of 23

Samba Architecture


  • We use a 2K window size for SWA and train all models with a 4K sequence length (block sketch below).
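
A structural sketch of one Samba unit, following the Mamba → MLP → SWA → MLP layer order described in the paper; the Mamba, SWA, and MLP sub-modules are assumed to be provided elsewhere and are injected here, with pre-norm residuals around each sub-layer.

```python
import torch.nn as nn

class SambaBlock(nn.Module):
    """One repeating unit: Mamba -> MLP -> SWA -> MLP, each with a pre-norm residual."""

    def __init__(self, d_model, mamba, swa, make_mlp):
        super().__init__()
        # sub-layers are passed in so the sketch does not fix their implementation
        self.layers = nn.ModuleList([mamba, make_mlp(), swa, make_mlp()])
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(4)])

    def forward(self, x):
        for norm, layer in zip(self.norms, self.layers):
            x = x + layer(norm(x))   # residual connection around each sub-layer
        return x
```

With a 2K attention window inside a 4K training length, the SWA layers handle precise short-range recall while the Mamba layers carry information across window boundaries.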

8 of 23

Samba 1.7B trained on 230B tokens from Phi-2


9 of 23

Samba 3.8B trained on 3.2T tokens from Phi-3


  • Outperforms Transformer++ by a large margin on most tasks

10 of 23

Efficient Length Extrapolation

  • We include Llama-3 with the SoTA length-extrapolation method Self-Extend for comparison (idea sketched below)
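
For context, a rough sketch of the idea behind the Self-Extend baseline: relative positions inside a neighbor window are kept as-is, while more distant positions are compressed by floor division so they stay within the range seen during training. The window and group sizes below are illustrative, not the exact baseline configuration.

```python
def self_extend_rel_pos(q_pos, k_pos, neighbor_window=1024, group_size=8):
    """Remap a (query, key) position pair to a Self-Extend-style relative position."""
    rel = q_pos - k_pos
    if rel < neighbor_window:
        return rel                                    # normal attention nearby
    grouped = q_pos // group_size - k_pos // group_size
    # shift so the grouped region continues smoothly after the neighbor window
    return grouped + neighbor_window - neighbor_window // group_size
```

Samba itself needs no such remapping: the SWA window simply slides and the SSM state carries over, so decoding length is not tied to the training length.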


11 of 23

Samba can memorize long-term information

  • Samba 1.7B (left) and Mistral 1.6B (right), instruction-tuned on Passkey Retrieval with 4K sequence length (task format sketched below)
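
For reference, a hedged sketch of how a passkey-retrieval example is typically constructed; the filler text, passkey format, and lengths are assumptions rather than the exact evaluation data.

```python
import random

def make_passkey_example(n_filler=400):
    """Build a long distractor context with a 5-digit passkey hidden inside."""
    filler = "The grass is green. The sky is blue. The sun is yellow. "
    passkey = random.randint(10000, 99999)
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    pos = random.randint(0, n_filler)                  # where to hide the needle
    context = filler * pos + needle + filler * (n_filler - pos)
    prompt = context + "\nWhat is the pass key? The pass key is"
    return prompt, str(passkey)
```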


12 of 23

Training Curves for Instruction Tuning

  • Samba reaches near-perfect retrieval accuracy within 150 SFT steps


13 of 23

Samba is good at long-context summarization

  • We instruction-tuned Samba following the recipe from Phi-3-mini.


14 of 23

What about training on open-source data?


Models trained on SlimPajama

15 of 23

Ablation: How to train models with SWA?

  • We fix the window size at 2K for the 438M Llama-2 model with SWA


16 of 23

Why not hybridize with full attention?

  • We replace the Mamba layers with full attention at different layer positions


17 of 23

How to allocate parameters for attention?

  • We use GQA/MQA and increase the MLP size.
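
A back-of-the-envelope sketch of the trade-off: GQA/MQA shrink the K/V projections, and the saved parameters can be spent on a wider MLP. The sizes below are hypothetical, not the Samba configuration.

```python
def attn_params(d_model, n_heads, n_kv_heads):
    """Parameter count of the Q/K/V/O projections for one attention layer."""
    head_dim = d_model // n_heads
    q_proj = d_model * n_heads * head_dim
    kv_proj = 2 * d_model * n_kv_heads * head_dim      # shrinks under GQA/MQA
    o_proj = n_heads * head_dim * d_model
    return q_proj + kv_proj + o_proj

d_model = 2048
mha = attn_params(d_model, n_heads=16, n_kv_heads=16)  # full multi-head attention
gqa = attn_params(d_model, n_heads=16, n_kv_heads=4)   # grouped-query attention
print(f"saved per layer: {mha - gqa:,} params")        # budget for a wider MLP
```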


18 of 23

Why does the hybrid work?


  • Samba has higher selection entropy and lower attention entropy in its middle layers (entropy computed as sketched below)
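
A sketch of how such an entropy can be measured, assuming `probs` holds row-normalized attention weights; the same formula can in principle be applied to normalized Mamba selection weights, though the exact quantity used in the paper may be defined differently.

```python
import torch

def mean_entropy(probs, eps=1e-9):
    """Average entropy of a batch of distributions (last dimension sums to 1)."""
    return -(probs * (probs + eps).log()).sum(dim=-1).mean()

# example: uniform attention over 2048 keys has entropy log(2048) ~= 7.62 nats
uniform = torch.full((1, 1, 2048), 1.0 / 2048)
print(mean_entropy(uniform))
```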

19 of 23

Effect of Short Convolution

  • ShortConv can even boost SWA’s performance
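
A minimal sketch of the short causal depthwise convolution being ablated here (kernel size 4, as in the Mamba block); the module name and defaults are illustrative.

```python
import torch.nn as nn

class ShortConv(nn.Module):
    """Depthwise causal Conv1d with a small kernel, applied along the sequence."""

    def __init__(self, d_model, kernel_size=4):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size,
                              groups=d_model, padding=kernel_size - 1)
        self.trim = kernel_size - 1

    def forward(self, x):                        # x: (batch, seq_len, d_model)
        x = self.conv(x.transpose(1, 2))         # -> (batch, d_model, seq_len + trim)
        if self.trim:
            x = x[..., : -self.trim]             # drop right padding to stay causal
        return x.transpose(1, 2)
```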


20 of 23

Conclusion

  • Samba outperforms SoTA Transformers on short-context tasks
  • Samba can memorize long-term information through instruction tuning
  • Samba has linear complexity with unlimited context length
  • Still needs improvement on long-term retrieval
    • We may need Dynamically Sparse SWA from SeqBoat


21 of 23

Future Directions

  • How to better post-train Samba/SSMs?
  • How to improve the recall ability of Samba while still keeping linear complexity?
  • How to design task-adaptive hybrid architectures?


22 of 23

Thanks for your time!

23 of 23