Babysitting a Small Language Model through One-Step Tree-of-Thoughts Knowledge Distillation
Zichen Zhang, Anurag Renduchintala, Shangjun Meng,
Adi Mahesh, Samuel Fang, Zimo Si
{zhangzzc, ranurag, shangjun, mahesha, swfang, zimosi}@umich.edu
My dear small model,
when you think, you need to explore different thoughts and backtrack if that thought does not work out …
University of Michigan
“When was the founder of Apple born?”
Thought 1: Who was the founder of Apple? - Steve Jobs
Thought 2: When was Steve Jobs born? - Feb 24, 1955 (Final Answer 🎉)
We think and explore different thoughts!
That is key to our human intelligence!
The key to intelligence is reasoning.
All complicated problems require reasoning.
Machines can’t be intelligent
if they can’t reason!
LLMs can talk, but not reason very well (arithmetically)
Reasoning problems (like Game of 24) require multiple steps.
LLMs can’t inherently explore thoughts or plan ahead, due to their linear thought processes!
Example Input: Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
4 4 6 8
Output from GPT-4o:
Answer: (4 * 6) * (8 - 4) = 24
Expression doesn’t equal 24!
Our Human’s Logic:
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
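One way to catch such mistakes automatically is to verify candidate answers programmatically. Below is a minimal sketch of such a checker (our own illustration, not part of any prompting framework), assuming the answer is given as a bare arithmetic expression over the four numbers.

import ast
from collections import Counter

def verify_game24(expression: str, numbers: list[int]) -> bool:
    """Check that `expression` evaluates to 24 and uses exactly `numbers`."""
    tree = ast.parse(expression, mode="eval")
    # Collect every numeric literal that appears in the expression.
    used = [node.value for node in ast.walk(tree) if isinstance(node, ast.Constant)]
    if Counter(used) != Counter(numbers):
        return False  # wrong multiset of numbers (missing, extra, or reused)
    try:
        value = eval(compile(tree, "<expr>", "eval"))  # arithmetic only; trusted input
    except ZeroDivisionError:
        return False
    return abs(value - 24) < 1e-6

print(verify_game24("(4 * 6) * (8 - 4)", [4, 4, 6, 8]))  # False: equals 96
print(verify_game24("(6 - 4) * (4 + 8)", [4, 4, 6, 8]))  # True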
Hmmm🤔,
how do we teach models to reason?
Regular IO Prompting:
Input: 4 4 6 8
Sorry, I don’t know the answer …

Chain of Thoughts:
Input: 4 4 6 8
Thought 1: 4 + 8 = 12 (left: 4 6 12)
Thought 2: 6 - 4 = 2 (left: 2 12)
…
Answer: (6 - 4) * (4 + 8) = 24
Chain of Thought continued
Five exemplar answers with correct thoughts were given to the model.
CoT Prompt Example (examples shown to the model):
Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
…(4 more examples)...
Input: 4 10 12 13

CoT Output:
Input: 4 10 12 13
Steps:
13 - 10 = 3 (left: 4 12 3)
12 * 3 = 36 (left: 4 36)
36 - 12 = 24 (left: 24)
Answer: ((13 - 10) * 12) - 12 = 24
(Note: this output reuses 12 and never uses 4, so it is not a valid solution.)
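For concreteness, a 5-shot CoT prompt like the one above can be assembled mechanically. The sketch below is a hypothetical illustration (the EXEMPLARS list and the ask_llm call are placeholders, not our exact prompt-construction code).

# Hypothetical sketch of assembling a 5-shot CoT prompt for Game of 24.
TASK = ("Use numbers and basic arithmetic operations (+ - * /) to obtain 24. "
        "Each step, you are only allowed to choose two of the remaining numbers "
        "to obtain a new number.")

# In practice this list holds the five worked exemplars; one is shown here.
EXEMPLARS = [
    "Input: 4 4 6 8\n"
    "Steps:\n"
    "4 + 8 = 12 (left: 4 6 12)\n"
    "6 - 4 = 2 (left: 2 12)\n"
    "2 * 12 = 24 (left: 24)\n"
    "Answer: (6 - 4) * (4 + 8) = 24",
]

def build_cot_prompt(puzzle: str) -> str:
    """Concatenate the task description, the exemplars, and the new puzzle."""
    return "\n\n".join([TASK, *EXEMPLARS, f"Input: {puzzle}\nSteps:"])

prompt = build_cot_prompt("4 10 12 13")
# response = ask_llm(prompt)  # placeholder for whichever LLM API is used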
What is the alternative?
Original Tree of Thoughts (Multi-Step)
Input: 4 4 6 8
Thought 1a: 4 + 4 = 8 (left: 8 6 8)
Thought 1b: 4 + 8 = 12 (left: 4 6 12)
Thought 1c: 4 + 6 = 10 (left: 4 10 8)
Thought 2a: 6 + 4 = 10 (left: 10 12)
Thought 2b: 4 + 8 = 12 (left: 4 6 12)
Thought 2c: 6 - 4 = 2 (left: 2 12)
…
Answer: (6 - 4) * (4 + 8) = 24
Each Intermediate Step:
Generate Prompt:
Input: 4 4 6 8
Please generate the possible next steps:
4 + 4 = 8 (left: 8 6 8)
4 + 8 = 12 (left: 4 6 12)
Evaluate: Can 8, 6, and 8 reach 24? → Impossible
Evaluate: Can 4, 6, and 12 reach 24? → Likely/Promising
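Put together, each level of the tree costs one generate call plus one evaluate call per candidate state. The loop below is a simplified sketch of that multi-step search (a breadth-first variant that keeps only promising states); llm_generate and llm_evaluate are hypothetical wrappers around the model, not the original ToT implementation.

# Simplified sketch of the multi-step ToT search loop (hypothetical helpers).
def llm_generate(state: str) -> list[str]:
    """Placeholder: ask the LLM for possible next steps from this state."""
    raise NotImplementedError

def llm_evaluate(state: str) -> str:
    """Placeholder: ask the LLM whether the remaining numbers can reach 24
    ('impossible' / 'likely' / 'promising')."""
    raise NotImplementedError

def tot_search(puzzle: str, max_depth: int = 3, beam: int = 5) -> list[str]:
    frontier = [puzzle]                     # each state lists the remaining numbers
    for _ in range(max_depth):              # three combination steps for four numbers
        candidates = []
        for state in frontier:
            candidates.extend(llm_generate(state))            # generate prompt
        scored = [(llm_evaluate(s), s) for s in candidates]   # evaluate prompt
        frontier = [s for verdict, s in scored if verdict != "impossible"][:beam]
    return frontier                         # surviving states; read off any '24' answers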
Extension: One-Step Tree-of-Thoughts
Original (Multi-Step) ToT:
Input: 4 4 6 8
Thought 1a: 4 + 4 = 8 (left: 8 6 8)
Thought 1b: 4 + 8 = 12 (left: 4 6 12)
Thought 1c: 4 + 6 = 10 (left: 4 10 8)
Thought 2a: 6 + 4 = 10 (left: 10 12)
Thought 2b: 4 + 8 = 12 (left: 4 6 12)
Thought 2c: 6 - 4 = 2 (left: 2 12)
…
Answer: (6 - 4) * (4 + 8) = 24
Problems with the multi-step approach motivate the one-step variant below:
One-Step ToT:
Attempts to: try a path (a pair of two numbers) and check whether the remaining numbers can still reach the goal of 24; if not, backtrack and attempt another pair.
Input: 4 4 6 8
Thought 1: 4 + 8 = 12 (left: 4 6 12)
Thought 2: 6 - 4 = 2 (left: 2 12)
…
Answer: (6 - 4) * (4 + 8) = 24
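The procedure the single prompt asks the model to imitate is essentially depth-first backtracking over pairs of numbers. The snippet below is our own reference sketch of that search (a plain solver, not the prompt itself or the model's reasoning).

from itertools import combinations

def solve24(numbers: list[float], eps: float = 1e-6) -> bool:
    """Try a pair of numbers, recurse on the rest; backtrack if no path reaches 24."""
    if len(numbers) == 1:
        return abs(numbers[0] - 24) < eps
    for (i, a), (j, b) in combinations(enumerate(numbers), 2):
        rest = [x for k, x in enumerate(numbers) if k not in (i, j)]
        results = [a + b, a - b, b - a, a * b]
        if abs(b) > eps:
            results.append(a / b)
        if abs(a) > eps:
            results.append(b / a)
        for r in results:                  # try a path ...
            if solve24(rest + [r], eps):
                return True                # ... and stop once 24 is reachable
        # otherwise backtrack and attempt another pair
    return False

print(solve24([4, 4, 6, 8]))   # True, e.g. (6 - 4) * (4 + 8) = 24
print(solve24([8, 6, 8]))      # False: the "Impossible" branch from the evaluate step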
How well does One-Step ToT do in comparison to the original CoT and ToT frameworks?
Additionally, do the failures at each step significantly differ between One-Step ToT and CoT, or are they relatively similar?
One-Step ToT is less likely to fail at the first two steps
[Figure: failure rates at each step, CoT Performance (Replication) vs. One-Step ToT Performance (Extension)]
Our One-Step ToT is better than CoT
[Figure: performance comparison, Replication (Baselines) vs. Extension]
Now we have a reasoner, but we want an efficient reasoner!
i.e., how can we teach step-by-step reasoning to a small language model?
Building efficient reasoners: Small Language Models
Small Language Models (SLMs) have a much smaller number of parameters, which means…
GPT-4o: estimated ~200 billion to 1 trillion parameters
vs.
SmolLM-360M: 360 million parameters
Pros: far fewer parameters, but…
Cons:
Naive SLMs can’t do reasoning at all
CoT even degrades SLM performance
[Figure: Performance (y-axis) vs. Model Size (x-axis)]
Perhaps we could use certain methods to induce CoT or ToT capability in SLMs?...
– Use One-Step ToT for Knowledge Distillation! (the teacher LLM teaches the student SLM)
Extension: Knowledge Distillation on Small Language Models
Teacher (LLM): given 1) the One-Step ToT prompt and 2) a Game of 24 puzzle, produces a response following the One-Step ToT format.
Filter out incorrect responses.
Synthesized dataset for fine-tuning: 1) One-Step ToT prompt, 2) Game of 24 puzzle, 3) LLM’s response.
Student (SLM): fine-tuned on the synthesized dataset with a next-word-prediction loss.
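A rough sketch of this pipeline is shown below, assuming a Hugging Face-style setup; teacher_generate, is_correct, the model id, and the hyperparameters are placeholders rather than our exact configuration.

# Hypothetical sketch of One-Step ToT knowledge distillation into an SLM.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def is_correct(response: str, puzzle: str) -> bool:
    """Placeholder: parse the final 'Answer:' line and verify it reaches 24."""
    raise NotImplementedError

def synthesize_dataset(puzzles, teacher_generate, one_step_tot_prompt):
    """Query the teacher LLM and keep only responses whose answer verifies."""
    examples = []
    for puzzle in puzzles:
        prompt = f"{one_step_tot_prompt}\nInput: {puzzle}\n"
        response = teacher_generate(prompt)      # placeholder teacher API call
        if is_correct(response, puzzle):         # filter out incorrect responses
            examples.append(prompt + response)
    return examples

def finetune_slm(examples, model_name="HuggingFaceTB/SmolLM-360M", epochs=1):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    model.train()
    for _ in range(epochs):
        for text in examples:
            batch = tokenizer(text, return_tensors="pt", truncation=True)
            # Next-word prediction: the model shifts the labels internally.
            loss = model(**batch, labels=batch["input_ids"]).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    return model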
LLMs are really bad at math 😭
Our fine-tuned model achieves the best accuracy with only a tiny fraction of GPT-4o's parameters!
Thank you for listening!
Questions?