1 of 21

Babysitting a Small Language Model through One-Step Tree-of-Thoughts Knowledge Distillation

Zichen Zhang, Anurag Renduchintala, Shangjun Meng,

Adi Mahesh, Samuel Fang, Zimo Si

{zhangzzc, ranurag, shangjun, mahesha, swfang, zimosi}@umich.edu

My dear small model,

when you think, you need to explore different thoughts and backtrack if a thought does not work out …

University of Michigan

2 of 21

Thought 1:

Who was the founder of Apple?

- Steve Jobs

Thought 2:

When was Steve Jobs born?

- Feb 24, 1955 (Final Answer 🎉)

“When was the founder of Apple born?”

zhangzzc

We think and explore different thoughts!

Key to our Human Intelligence!


3 of 21

The key to intelligence is reasoning.

All complicated problems require reasoning.

Machines can’t be intelligent

if they can’t reason!

zhangzzc


4 of 21

LLMs can talk, but not reason very well (arithmetically)

Reasoning problems (like Game of 24) require multiple steps.

LLMs can’t inherently explore thoughts or plan ahead, due to their linear thought processes!

Example Input: Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.

Input: 4 4 6 8

Output from GPT-4o

Answer: (4 * 6) * (8 - 4) = 24

Expression doesn’t equal 24!

Our Human’s Logic

Steps:

4 + 8 = 12 (left: 4 6 12)

6 - 4 = 2 (left: 2 12)

2 * 12 = 24 (left: 24)

Answer: (6 - 4) * (4 + 8) = 24

Hmmm🤔,

how do we teach models to reason?
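The check behind "Expression doesn’t equal 24!" can be sketched in a few lines. This is a minimal illustration, not the evaluation code used for the slides: the names `check_answer` and `_evaluate` are hypothetical, and exact arithmetic with `Fraction` is an assumption made to avoid floating-point surprises with division.

```python
import ast
from collections import Counter
from fractions import Fraction

# Only the four operations allowed by the Game of 24 prompt.
_OPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
}

def _evaluate(node, used):
    """Evaluate a parsed expression exactly, recording every integer it uses."""
    if isinstance(node, ast.Constant) and type(node.value) is int:
        used.append(node.value)
        return Fraction(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        left = _evaluate(node.left, used)
        right = _evaluate(node.right, used)
        return _OPS[type(node.op)](left, right)
    raise ValueError("only integers and + - * / are allowed")

def check_answer(numbers, expression, target=24):
    """Return (ok, value). ok is True only if the expression uses each
    input number exactly once and evaluates to the target."""
    used = []
    try:
        value = _evaluate(ast.parse(expression, mode="eval").body, used)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False, None
    ok = Counter(used) == Counter(numbers) and value == target
    return ok, value

if __name__ == "__main__":
    print(check_answer([4, 4, 6, 8], "(4 * 6) * (8 - 4)"))  # the GPT-4o answer
    print(check_answer([4, 4, 6, 8], "(6 - 4) * (4 + 8)"))  # the human answer
```

The GPT-4o answer uses the right numbers but evaluates to 96, so it is rejected; the human answer passes both checks.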

zhangzzc


5 of 21

Regular IO Prompting

4 4 6 8

Sorry, I don’t know the answer …

Chain of Thought

4 4 6 8

Thought 1:

4 + 8 = 12 (left: 4 6 12)

Thought 2:

6 - 4 = 2 (left: 2 12)

Answer: (6 - 4) * (4 + 8) = 24

mahesha


6 of 21

Chain of Thought continued

Five exemplar answers with correct thoughts were given to the GPT model

Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.

Input: 4 4 6 8

Steps:

4 + 8 = 12 (left: 4 6 12)

6 - 4 = 2 (left: 2 12)

2 * 12 = 24 (left: 24)

Answer: (6 - 4) * (4 + 8) = 24

…(4 more examples)...

Input: 4 10 12 13

CoT Prompt Example
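Assembling a few-shot prompt of this shape can be sketched as follows. The helper names `format_exemplar` and `build_cot_prompt` are hypothetical, and only the one exemplar shown on the slide is included here; the other four are elided in the source and are not reconstructed.

```python
INSTRUCTION = (
    "Use numbers and basic arithmetic operations (+ - * /) to obtain 24. "
    "Each step, you are only allowed to choose two of the remaining numbers "
    "to obtain a new number."
)

def format_exemplar(numbers, steps, answer):
    """Render one solved puzzle in the Input / Steps / Answer layout."""
    lines = ["Input: " + " ".join(map(str, numbers)), "Steps:"]
    lines.extend(steps)
    lines.append("Answer: " + answer)
    return "\n".join(lines)

def build_cot_prompt(exemplars, puzzle):
    """Instruction, then the exemplars, then the unsolved puzzle last."""
    parts = [INSTRUCTION]
    parts.extend(format_exemplar(*exemplar) for exemplar in exemplars)
    parts.append("Input: " + " ".join(map(str, puzzle)))
    return "\n\n".join(parts)

if __name__ == "__main__":
    exemplars = [
        (
            [4, 4, 6, 8],
            [
                "4 + 8 = 12 (left: 4 6 12)",
                "6 - 4 = 2 (left: 2 12)",
                "2 * 12 = 24 (left: 24)",
            ],
            "(6 - 4) * (4 + 8) = 24",
        ),
    ]
    print(build_cot_prompt(exemplars, [4, 10, 12, 13]))
```

Ending the prompt with the bare `Input:` line leaves the model to continue with its own `Steps:` and `Answer:` lines.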

Input: 4 10 12 13

Steps:

13 - 10 = 3 (left: 4 12 3)

12 * 3 = 36 (left: 4 36)

36 - 12 = 24 (left: 24)

Answer: ((13 - 10) * 12) - 12 = 24

CoT Output

mahesha

Examples shown to the model

What is the alternative?


7 of 21

Original Tree of Thoughts (Multi-Step)

4 4 6 8

Thought 1b:

4 + 8 = 12 (left: 4 6 12)

Thought 1a:

4 + 4 = 8 (left: 8 6 8)

Thought 2a:

6 + 4 = 10 (left: 10 12)

Thought 2b:

4 + 8 = 12 (left: 4 6 12)

Thought 1c:

4 + 6 = 10 (left: 4 10 8)

Thought 2c:

6 - 4 = 2 (left: 2 12)

Answer: (6 - 4) * (4 + 8) = 24

swfang

Each Intermediate Step:

Generate Prompt:

Input: 4 4 6 8

Please generate the possible next steps:

4 + 4 = 8 (left: 8 6 8)

4 + 8 = 12 (left: 4 6 12)

Evaluate: Can 8, 6, and 8 reach 24?

Impossible

Evaluate: Can 4, 6, and 12 reach 24?

Likely/Promising
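The generate-then-evaluate loop can be sketched as a breadth-limited search. This is a stand-in under stated assumptions: in multi-step ToT both the generate and the evaluate steps are LLM calls, whereas here `propose` enumerates moves exhaustively and `can_reach` is an exact brute-force check, so the sketch shows only the control flow. All function names and the `breadth` parameter are hypothetical.

```python
from fractions import Fraction
from itertools import combinations

def propose(state):
    """Generate: every (step text, new state) from combining two remaining numbers."""
    results = []
    for i, j in combinations(range(len(state)), 2):
        a, b = state[i], state[j]
        rest = tuple(x for k, x in enumerate(state) if k not in (i, j))
        options = [(a + b, a, "+", b), (a * b, a, "*", b),
                   (a - b, a, "-", b), (b - a, b, "-", a)]
        if b != 0:
            options.append((a / b, a, "/", b))
        if a != 0:
            options.append((b / a, b, "/", a))
        for value, x, op, y in options:
            new_state = rest + (value,)
            left = " ".join(str(n) for n in new_state)
            results.append((f"{x} {op} {y} = {value} (left: {left})", new_state))
    return results

def can_reach(state, target=24):
    """Evaluate: exact stand-in for the LLM's Impossible / Likely judgement."""
    if len(state) == 1:
        return state[0] == target
    return any(can_reach(new_state, target) for _, new_state in propose(state))

def tree_of_thoughts(numbers, target=24, breadth=5):
    """Expand level by level, keeping at most `breadth` promising partial solutions."""
    frontier = [([], tuple(Fraction(n) for n in numbers))]
    for _ in range(len(numbers) - 1):
        candidates = []
        for steps, state in frontier:
            for step, new_state in propose(state):
                if can_reach(new_state, target):
                    candidates.append((steps + [step], new_state))
        frontier = candidates[:breadth]
    if frontier and frontier[0][1] == (target,):
        return frontier[0][0]
    return None

if __name__ == "__main__":
    for line in tree_of_thoughts([4, 4, 6, 8]) or ["no solution"]:
        print(line)
```

Each level of the tree needs one generate call and one evaluate call per candidate, which is why the multi-step version costs many model calls per puzzle.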


8 of 21

Extension: One-Step Tree-of-Thoughts

4 4 6 8

Thought 1b:

4 + 8 = 12 (left: 4 6 12)

Thought 1a:

4 + 4 = 8 (left: 8 6 8)

Thought 2a:

6 + 4 = 10 (left: 10 12)

Thought 2b:

4 + 8 = 12 (left: 4 6 12)

Thought 1c:

4 + 6 = 10 (left: 4 10 8)

Thought 2c:

6 - 4 = 2 (left: 2 12)

Answer: (6 - 4) * (4 + 8) = 24

Original (Multi-Step) ToT

Problems:

  • Difficult to Generalize to other Tasks
  • Infeasible to train on for knowledge distillation

One-Step ToT

Thought 1:

4 + 8 = 12 (left: 4 6 12)

Thought 2:

6 - 4 = 2 (left: 2 12)

Answer: (6 - 4) * (4 + 8) = 24

4 4 6 8

Try a path (pick a pair of numbers) and check whether the remaining numbers can still reach the goal of 24. If not, backtrack and try another pair.

Attempts To:

  • Emulate Original ToT
  • Use system prompting to induce ToT reasoning
  • Achieve better performance than CoT
  • Create a more generalizable framework
  • Allow for Knowledge Distillation
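To make the "try a path, check, backtrack" idea concrete, the sketch below produces one linear trace that contains both the rejected attempts and the accepted thoughts. It is an illustration only: the trace wording is hypothetical rather than the actual One-Step ToT output format, and the feasibility check is an exact search standing in for the model's own judgement.

```python
from fractions import Fraction
from itertools import combinations

def moves(state):
    """Yield (step text, new state) for every pair of remaining numbers and operation."""
    for i, j in combinations(range(len(state)), 2):
        a, b = state[i], state[j]
        rest = tuple(x for k, x in enumerate(state) if k not in (i, j))
        options = [(a + b, a, "+", b), (a * b, a, "*", b),
                   (a - b, a, "-", b), (b - a, b, "-", a)]
        if b != 0:
            options.append((a / b, a, "/", b))
        if a != 0:
            options.append((b / a, b, "/", a))
        for value, x, op, y in options:
            new_state = rest + (value,)
            left = " ".join(str(n) for n in new_state)
            yield f"{x} {op} {y} = {value} (left: {left})", new_state

def feasible(state, target=24):
    """Can the remaining numbers still reach the target? (exact check)"""
    if len(state) == 1:
        return state[0] == target
    return any(feasible(new_state, target) for _, new_state in moves(state))

def one_step_trace(numbers, target=24):
    """Return (trace, solved): a single linear trace with attempts and backtracks."""
    state = tuple(Fraction(n) for n in numbers)
    trace, depth = [], 1
    while len(state) > 1:
        for step, new_state in moves(state):
            if feasible(new_state, target):
                trace.append(f"Thought {depth}: {step}")
                state, depth = new_state, depth + 1
                break
            trace.append(f"Tried {step}: cannot reach {target}, backtrack")
        else:
            return trace, False  # every pair was rejected
    return trace, state[0] == target

if __name__ == "__main__":
    trace, solved = one_step_trace([4, 4, 6, 8])
    print("\n".join(trace))
    print("solved:", solved)
```

Because the whole exploration lives in one sequence of text, a trace like this can serve directly as a training target, which is not the case for a search spread over many separate model calls.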

swfang


9 of 21

How well does One-Step ToT do in comparison to the original CoT and ToT frameworks?

Additionally, do the failures at each step significantly differ between One-Step ToT and CoT, or are they relatively similar?
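One way to locate the step at which a trace fails is to replay it line by line. The sketch below is an assumption about how such a per-step check could work, not the analysis code behind the results: `first_failed_step` is a hypothetical name, and the parser handles only integer steps written in the `a op b = c (left: ...)` layout used on these slides.

```python
import re
from collections import Counter
from fractions import Fraction

STEP = re.compile(
    r"^\s*(-?\d+)\s*([-+*/])\s*(-?\d+)\s*=\s*(-?\d+)\s*\(left:\s*([-\d\s]+)\)\s*$"
)

def first_failed_step(numbers, steps, target=24):
    """Return the 1-based index of the first invalid step, or None if every
    step is valid and the trace ends with only the target left."""
    remaining = Counter(numbers)
    for index, line in enumerate(steps, start=1):
        match = STEP.match(line)
        if match is None:
            return index  # line is not in the expected format
        a, op, b, result, left = match.groups()
        a, b, result = int(a), int(b), int(result)
        needed = Counter([a, b])
        if any(remaining[n] < count for n, count in needed.items()):
            return index  # uses a number that is not available
        if op == "/" and b == 0:
            return index
        value = {"+": a + b, "-": a - b, "*": a * b}.get(op)
        if op == "/":
            value = Fraction(a, b)
        if value != result:
            return index  # arithmetic is wrong
        remaining = remaining - needed + Counter([result])
        if remaining != Counter(int(x) for x in left.split()):
            return index  # "left" list does not match
    if remaining == Counter([target]):
        return None
    return max(len(steps), 1)  # valid steps, but the target was not reached

if __name__ == "__main__":
    cot_output = [
        "13 - 10 = 3 (left: 4 12 3)",
        "12 * 3 = 36 (left: 4 36)",
        "36 - 12 = 24 (left: 24)",
    ]
    print(first_failed_step([4, 10, 12, 13], cot_output))
```

Applied to the CoT output shown earlier for 4 10 12 13, the replay stops at the third step, which reuses 12 after it has already been consumed.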

ranurag


10 of 21

One-Step ToT is less likely to fail at the first two steps

ranurag

[Figure panels: CoT Performance (Replication) | One-Step ToT Performance (Extension)]


11 of 21

ranurag

Our One-Step ToT is better than CoT

[Figure: results for Replication (Baselines) and Extension]


12 of 21

Now we have a reasoner,

but we want to get an efficient reasoner!

i.e., how can we teach step-by-step reasoning to a small language model?

zimosi


13 of 21

Building efficient reasoners: Small Language Models

Pros:

  • Fast!
  • Energy and Resource Efficient!

but…

Small Language Models (SLMs) have far fewer parameters, which means…

GPT-4o (estimate: ~200 billion to 1 trillion parameters)

VS

SmolLM-360M (360 million parameters)
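To put the gap in numbers, taking the estimated range for GPT-4o at face value (it is an estimate, not a published figure), the size ratio works out to roughly:

```latex
\frac{2 \times 10^{11}}{3.6 \times 10^{8}} \approx 556
\qquad\text{to}\qquad
\frac{1 \times 10^{12}}{3.6 \times 10^{8}} \approx 2778
```

So SmolLM-360M would have somewhere between about 1/556 and 1/2778 of GPT-4o's parameters under that estimate.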

zimosi


14 of 21

Naive SLMs can’t do reasoning at all

zimosi

CoT even degrades SLM performance

x-axis: Model Size

y-axis: Performance

Cons:

  • Less capability for reasoning
  • Poor Performance even under CoT and ToT :(


15 of 21

Perhaps we could use certain methods to induce CoT or ToT capability in SLMs?

– Use One-Step ToT for Knowledge Distillation!

teaches

shangjun


16 of 21

Extension: Knowledge Distillation on Small Language Models

  • Essentially teaching the SLM to emulate the LLM
  • A next-token-prediction task with the original prompt as the input and the LLM output as the target sequence
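The next-token-prediction objective can be written out as a sketch. Here $x$ is the original prompt (One-Step ToT prompt plus puzzle), $y = (y_1, \dots, y_T)$ is the teacher LLM's response, and $p_\theta$ is the student SLM. Averaging over the response tokens only, rather than over prompt tokens as well, is an assumption; the slides do not specify the normalization or masking.

```latex
\mathcal{L}(\theta) \;=\; -\frac{1}{T} \sum_{t=1}^{T} \log p_\theta\!\left(y_t \,\middle|\, x,\; y_{<t}\right)
```

Minimizing this loss pushes the student to assign high probability to each token of the teacher's response given the prompt and the response tokens before it.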

Teacher (LLM)

Input: 1) One-Step ToT Prompting, 2) Game of 24 Puzzle

Output: LLM’s Response based on One-Step ToT Prompt

Filter out incorrect responses

Synthesized Dataset for Fine-tuning: 1) One-Step ToT Prompting, 2) Game of 24 Puzzle, 3) LLM’s Response

Student (SLM)

SLM’s Next Word Prediction

Loss
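The filtering and dataset-synthesis stage of this pipeline can be sketched as below. Everything here is illustrative: the function names, the `prompt`/`completion` record layout, and the two teacher responses in the usage example are hypothetical, and the One-Step ToT system prompt is left as a placeholder because its text is not given on the slides.

```python
import ast
import json
import operator
import re
from collections import Counter
from fractions import Fraction

OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
       ast.Mult: operator.mul, ast.Div: operator.truediv}
ANSWER = re.compile(r"Answer:\s*(.+?)\s*=\s*24\s*$", re.MULTILINE)

def is_correct(numbers, response):
    """True if the response's Answer line uses each number once and equals 24."""
    match = ANSWER.search(response)
    if match is None:
        return False
    used = []

    def evaluate(node):
        if isinstance(node, ast.Constant) and type(node.value) is int:
            used.append(node.value)
            return Fraction(node.value)
        if isinstance(node, ast.BinOp) and type(node.op) in OPS:
            return OPS[type(node.op)](evaluate(node.left), evaluate(node.right))
        raise ValueError("unsupported expression")

    try:
        value = evaluate(ast.parse(match.group(1), mode="eval").body)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False
    return value == 24 and Counter(used) == Counter(numbers)

def build_dataset(system_prompt, samples):
    """Keep only correct teacher responses and pair each with its prompt."""
    records = []
    for numbers, response in samples:
        if is_correct(numbers, response):
            puzzle = " ".join(map(str, numbers))
            records.append({
                "prompt": f"{system_prompt}\n\nInput: {puzzle}\n",
                "completion": response,
            })
    return records

def to_jsonl(records):
    """Serialize records as one JSON object per line."""
    return "\n".join(json.dumps(record) for record in records)

if __name__ == "__main__":
    samples = [  # hypothetical teacher outputs for the same puzzle
        ([4, 4, 6, 8], "Thought 1:\n4 + 8 = 12 (left: 4 6 12)\n"
                       "Thought 2:\n6 - 4 = 2 (left: 2 12)\n"
                       "Answer: (6 - 4) * (4 + 8) = 24"),
        ([4, 4, 6, 8], "Answer: (4 * 6) * (8 - 4) = 24"),
    ]
    print(to_jsonl(build_dataset("<One-Step ToT system prompt>", samples)))
```

Only the first response survives the filter, so the student is never trained on a teacher trace whose final answer is wrong.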

shangjun


17 of 21

LLMs are really bad at math 😭

Our fine-tuned model achieves the best accuracy with only a tiny fraction of GPT-4o’s parameters!

shangjun


18 of 21

Thank you for listening!

Questions?


