Lecture 29:
Advanced: LLMs
CS 136: Spring 2024
Katie Keith
Record on Zoom
📣 Announcements
This week: Advanced Topics
The last day of content included on the final exam was last Friday (May 2)
Two AI Paradigms
Symbolic AI (we'll explore: A* Search)
Statistical AI (we'll explore: LLMs)
🎯 Today’s Learning Objectives
Mainstream optimism for AI and LLMs
in policy and healthcare
Recent LLMs
Mainstream optimism for AI and LLMs
in policy and healthcare
Decisions
Policy
Treatment
Outcomes
Causal Inference!
In proximal causal inference:
Katie’s Research
LLMs (and other ML models) are helpful but not sufficient subroutines of data science and causal inference.
ChatGPT in 3 slides
User's Prompt: "Recite Asimov's first law."
Generated Text (so far): "A robot must"
Next word | Probability
not       | 0.81
fulfill   | 0.13
adhere    | 0.02
...       | ...
sample
P(next word | previous words)
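The sampling step can be sketched in Python using the probability table above (only the top three candidate words are shown; `random.choices` performs the weighted draw):

```python
import random

# Model's predicted distribution over the next word (top entries from the table).
next_word_probs = {"not": 0.81, "fulfill": 0.13, "adhere": 0.02}

# Draw one next word, weighted by its probability.
words = list(next_word_probs)
weights = list(next_word_probs.values())
next_word = random.choices(words, weights=weights, k=1)[0]
print(next_word)  # "not" about 81% of the time
```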
Where does this come from?
Previous Words
2. Optimize
P(next word | previous words)
Gradient descent on non-linear function w/ billions of parameters!
1. Data from the internet
3. Data: Humans rank multiple outputs
#1
...
#2
#2 > #1
Loss function: Increases if the model's prediction for a masked word is incorrect
We use python to analyze our data and make visualizations.
The python slithered silently through the jungle underbrush.
Sum loss for billions of masked words
We use python to analyze our data and make visualizations.
The camels slithered silently through the jungle underbrush.
We help train ChatGPT too...
🎯 Today’s Learning Objectives
A line in 2D space
y = mx + b
m: slope
b: intercept
Linear regression, i.e. fitting lines
y = mx + b
m: slope
b: intercept
Linear regression:
Find m and b that minimize the sum of squared residuals
Linear regression beyond 2D
Source: Jacob Watters
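As a concrete sketch: ordinary least squares has a closed-form solution, which NumPy computes directly. The data and true slope below are made up for illustration:

```python
import numpy as np

# Toy data: y = 2x + 1 plus a little noise.
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=50)
y = 2 * x + 1 + 0.1 * rng.normal(size=50)

# Design matrix with a column of ones so the intercept b is learned too.
X = np.column_stack([x, np.ones_like(x)])

# Least-squares solution: minimizes the sum of squared residuals.
(m, b), *_ = np.linalg.lstsq(X, y, rcond=None)
print(m, b)  # close to 2 and 1
```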
Classification: Logistic regression
We want to learn a decision boundary between discrete classes.
Example: Classifying movie genres from their reviews
Warm-up. Draw me a decision boundary (straight line) that perfectly separates the two classes (colors).
💡Think-pair-share
History of deep learning: XOR problem
Minsky and Papert (1969) proved that a perceptron (a simple linear classifier with no non-linear activation) cannot solve the logical operation XOR.
Exclusive-or (XOR) outputs true (1) only when the inputs are different from each other
OR is linearly separable:

x1 | x2 | OR
0  | 0  | 0
0  | 1  | 1
1  | 0  | 1
1  | 1  | 1

XOR is not linearly separable:

x1 | x2 | XOR
0  | 0  | 0
0  | 1  | 1
1  | 0  | 1
1  | 1  | 0
Models: Input → Deep Learning Network → Output
Many deep learning “architecture” options
Metaphor: Stacking lego pieces
Logistic regression as a "shallow" network:
Input (feature vector) → Linear layer* → Non-linear activation → Output
*also called a "fully connected" or "affine" layer
Feedforward deep learning network:
Input → Linear layer → Non-linear activation → Linear layer → Non-linear activation → Output
"Architecture" design decision: we can stack as many layers as we want!
In a feedforward network, the computation proceeds iteratively from one layer of units to the next (and there are no cycles).
Deep networks can learn non-linear relationships
✅
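To make this concrete, here is a tiny two-layer network that computes XOR. The weights are hand-picked for illustration, not learned; the point is that one hidden layer plus a non-linear activation is enough to solve a problem a single linear classifier cannot:

```python
import numpy as np

def step(z):
    # Non-linear activation: outputs 1 where z > 0, else 0.
    return (z > 0).astype(int)

def xor_net(x1, x2):
    # Hand-picked (not learned) weights: hidden unit h1 computes OR,
    # hidden unit h2 computes AND.
    x = np.array([x1, x2])
    W1 = np.array([[1.0, 1.0],
                   [1.0, 1.0]])
    b1 = np.array([-0.5, -1.5])
    h = step(W1 @ x + b1)
    # Output layer: "OR and not AND" is exactly XOR.
    return int(h[0] - h[1] - 0.5 > 0)

for x1, x2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(x1, x2, "->", xor_net(x1, x2))
```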
🎯 Today’s Learning Objectives
Linear regression, aka fitting lines
y = mx + b
m: slope
b: intercept
Linear regression:
Find m and b that minimize the sum of squared residuals (d_i)
Loss function:
Mean squared error (MSE)
Training linear regression
For linear regression, we can derive an explicit formula for the optimal weights using the least-squares approach.
However, for logistic regression, there is no closed-form (analytical) solution for the optimal weights. We will need to find approximately optimal weights computationally.
Gradient Descent (intuition)
Thought experiment: How would you get to the bottom of a crater if you were blindfolded?
Coconino County, AZ
Meteor Impact Site
Gradient descent in two dimensions: surface plot
Take small steps in the “steepest downhill direction”
[Surface plot: loss function over two weight/slope dimensions (Dim 1, Dim 2)]
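A minimal sketch of gradient descent for a single slope m (intercept fixed at 0 for simplicity; the toy data and learning rate are made up for illustration):

```python
import numpy as np

# Toy data generated with true slope 3.
x = np.array([0.0, 1.0, 2.0, 3.0])
y = 3 * x

m = 0.0    # initial guess for the slope
lr = 0.05  # learning rate: size of each downhill step

for _ in range(200):
    residuals = m * x - y
    # Gradient of the MSE loss mean((m*x - y)^2) with respect to m.
    grad = 2 * np.mean(residuals * x)
    m -= lr * grad  # small step in the steepest downhill direction

print(m)  # approaches 3
```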
Classification: Logistic regression
We want to learn a decision boundary between discrete classes.
Expensive to get classification labels
Definition by example
Guess. What is the meaning of the word tezgüino?
Word in context:
Example credit: Lin 1998, Eisenstein ANLP, 14.1
Distributional hypothesis (linguistics)
A word’s meaning can be derived from its context.
Example word in context:
J.R. Firth
Linguist, 1890-1960
“A word is characterized by the company it keeps.”
Self-supervision: Predict next words
Self-supervision is the process by which a model learns to predict part of the data using other parts of the same data as implicit labels.
Advantage: Cheap and abundant training data! No need for manually labeled training examples.
Model
Next word prediction ➡️ many tasks!
P(“The cat sat on the mat.”) > P(“The cat sats on the mat.”)
P(“The cat sat on the mat.”) > P(“The whale sat on the mat.”)
P(“4” | “2+2=”) > P(“5” | “2+2=”)
P(“1 star” | “That movie was terrible. I’d give it ”) > P(“5 stars” | “That movie was terrible. I’d give it ”)
Grammaticality; Subject-verb agreement
World Knowledge
Addition
Sentiment analysis
Examples adapted from Alec Radford
Next-word Prediction for LLMs
LLMs are pre-trained using a next word prediction self-supervised task.
Pseudocode:
Input: Corpus of text
For each masked token t:
    Model predicts P(t | surrounding context)
    Model weights updated via (a variant of) gradient descent using the true (unmasked) word
Lots of variation in masking selection,
e.g., BERT: randomly select 15% of tokens
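The pseudocode above might be sketched in Python as follows. `model`, `predict_proba`, and `gradient_step` are hypothetical stand-ins for illustration, not a real framework API:

```python
import math
import random

def train_masked_lm(corpus_tokens, model, mask_rate=0.15):
    """Sketch of self-supervised masked-token pre-training (BERT-style).

    `model` is a hypothetical object exposing predict_proba(tokens, position)
    (a distribution over the vocabulary) and gradient_step(loss); real
    frameworks look quite different.
    """
    for i, true_token in enumerate(corpus_tokens):
        # Randomly select tokens to mask (e.g., BERT masks ~15%).
        if random.random() > mask_rate:
            continue
        masked = corpus_tokens[:i] + ["[MASK]"] + corpus_tokens[i + 1:]
        probs = model.predict_proba(masked, i)
        # Cross-entropy loss: large when the true (unmasked) word
        # is assigned low probability.
        loss = -math.log(probs[true_token])
        model.gradient_step(loss)
```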
LLMs require large amounts of compute
Apple M1 Pro 16-Core GPU: 5.3 x 10^12 FLOPS (floating point operations per second)
GPT-3 (total training compute): 3.14 x 10^23 FLOPs (total floating point operations)
Factor of ~60 billion
Autoregressive generation
Figure credit: Jay Alammar
An autoregressive model generates a token, appends that generated token to its input sequence, and repeats.
LLM
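The autoregressive loop can be sketched as follows. `next_word` is a hypothetical function that samples from P(next word | previous words), and `<eos>` is an assumed end-of-sequence marker:

```python
def generate(prompt_tokens, next_word, max_new_tokens=20):
    """Autoregressive generation: each sampled token is appended to the
    input sequence before predicting the next one."""
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        token = next_word(tokens)  # sample from P(next | previous)
        if token == "<eos>":       # stop at the end-of-sequence marker
            break
        tokens.append(token)       # feed the generated token back in
    return tokens
```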
AI2 training OLMo “scare” 😱
When training OLMo, the AI2 team was monitoring the loss function and saw “fast spikes” in the figure below.
[Plot: loss vs. gradient descent iteration, showing sudden spikes]
Slide credit: Hannaneh Hajishirzi, COLM, 2024.
Q: Guesses why? 🍬
“Bug” found in the training data
Takeaway: “Garbage in, garbage out”
✅
✅
🎯 Today’s Learning Objectives
Unstructured text can be more abundant, expressive, and flexible than structured data
… but difficult to directly incorporate text data into existing causal methods
And only increasing…
For many applications, text data may be a summary (or record) of structured confounding variables
e.g., clinical notes in an electronic health record
[DAG with nodes C, A, Y, U, and Text. Treatment A: blood thinner vs. clot buster. Unobserved confounder U: e.g., atrial fibrillation; age, sex, severity, family history…]
Zero-shot classifiers perform an unseen task with no supervised examples
Figures credit: Wei et al. “Finetuned language models are zero-shot learners.” ICLR, 2022.
Zero-shot predictions from LLMs
Context: {Text} \n
Is it likely the patient has {U}?\n
Constraint: Even if you are uncertain, you must pick either “Yes” or “No” without using any other words.
Prompt template
Context: Patient reports intermittent episodes of rapid, fluttering heartbeats over the past week, often accompanied by lightheadedness and shortness of breath \n
Is it likely the patient has atrial fibrillation?\n
Constraint: Even if you are uncertain, you must pick either “Yes” or “No” without using any other words.
Example instance (contrived for this talk)
LLM
FLAN-T5 XXL
(Chung et al. 2024)
or
OLMo-7B-Instruct (Groeneveld et al. 2024)
“Yes”
Generated text output
Deterministic answer extraction
If “Yes” in output, W=1
Else W=0
*Same for Z
Simple set-up for our proof-of-concept. Likely could be engineered for improvement.
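A minimal sketch of this setup; `query_llm` is a hypothetical stand-in for a call to FLAN-T5 or OLMo:

```python
# Prompt template from the slide; {text} and {u} are filled in per instance.
TEMPLATE = (
    "Context: {text} \n"
    "Is it likely the patient has {u}?\n"
    "Constraint: Even if you are uncertain, you must pick either "
    '"Yes" or "No" without using any other words.'
)

def zero_shot_label(text, u, query_llm):
    """Deterministic answer extraction: map the generated text to a
    binary proxy W (the same procedure is used for Z)."""
    prompt = TEMPLATE.format(text=text, u=u)
    output = query_llm(prompt)  # hypothetical LLM call
    return 1 if "Yes" in output else 0
```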
Our pipeline: proximal causal inference with text
[Pipeline diagram: (1) MIMIC-III clinical notes → (2) remove discharge summaries → (3) split via metadata into clinical note categories (Echocardiogram, Nursing notes) → (4-5) zero-shot classification with LLM-1 (FLAN-T5) and LLM-2 (OLMo) → (6) odds ratio heuristic (passes or fails) → proximal g-formula → estimate of the causal effect]
Results Highlights: Semi-synthetic for ACE estimates
[Figure: Estimated ACE for three approaches: W directly in backdoor; one LLM (FLAN-T5); two LLMs (FLAN-T5, OLMo). U: A-sis (coronary atherosclerosis of the native coronary artery). Text proxies: T1pre = Echocardiogram, T2pre = Radiology. 95% confidence intervals via bootstrap resampling. Blue: passed odds ratio heuristic; Red: failed odds ratio heuristic.]
Takeaways:
LLMs are helpful but not sufficient subroutines for causal inference with text
LLMs Helpful:
LLMs Not Sufficient:
✅
✅
✅
🎯 Today’s Learning Objectives