Right for the Right Reasons:
Training Differentiable Models by Constraining their Explanations
Andrew Slavin Ross, Michael C. Hughes, and Finale Doshi-Velez
August 24, 2017, IJCAI, Melbourne
doi.org/10.24963/ijcai.2017/371
Code & data: github.com/dtak/rrr
These slides: goo.gl/fMZiRu
Models don’t always learn what you think they learn
The picture in my head
[Figure: a loss function over parameter space, with an easily-found local minimum that acts as a snow detector and a much subtler global minimum that acts as a husky-vs-wolf face distinguisher]
What are explanations?
(let me try to explain...)
One approach: interpretable surrogates [Ribeiro et al., KDD 2016]
Another: gradients of output probabilities with respect to input features
Actually quite similar!
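To make the gradient notion concrete, here is a minimal sketch (PyTorch; not the released rrr code) of an input-gradient explanation: the derivative of a predicted class probability with respect to each input feature.

import torch
import torch.nn.functional as F

def input_gradient_explanation(model, x, target_class):
    # d p(target_class | x) / dx: which input features most affect
    # the predicted probability (a simple saliency map).
    x = x.clone().requires_grad_(True)
    probs = F.softmax(model(x.unsqueeze(0)), dim=1)
    probs[0, target_class].backward()
    return x.grad  # same shape as x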
Input gradients for image classifications
This kind of works!
So, we’re done, right?
...what do we do if the explanations are wrong?
Optimizing for the right reason
Traditional ML: training examples (X, y) and model parameters θ feed into a loss function L(θ | X, y); minimizing it yields updated parameters θ′ and a model that predicts y given X.
Traditional ML + explanation regularization: training examples (X, y), explanation annotations A, and model parameters θ feed into a loss function L(θ | X, y, A); minimizing it yields θ′ and a model that predicts y given X because A.
(Ideally, only a small subset of X needs these extra annotations.)
Case 1: Annotations are given
How we encode domain knowledge: an annotation matrix A with one row per example and one column per feature. A = 1 signifies that the feature should be irrelevant to the model’s prediction for that example (e.g., the second feature of the first example); A = 0 leaves the feature unconstrained.
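As a tiny illustration (hypothetical sizes, not from the slides), the annotation pictured above could be built as:

import numpy as np

# Hypothetical toy setup: 4 examples with 3 features each.
A = np.zeros((4, 3), dtype=int)

# Mark the second feature of the first example as one the model
# should not rely on; all other entries stay unconstrained.
A[0, 1] = 1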
Our loss function
With some overall strength, if a particular example’s feature is marked irrelevant, penalize how much the prediction changes when that feature changes.
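One way to realize this in code (a sketch consistent with the description above, assuming a cross-entropy "right answers" term and penalizing the squared, A-masked input gradients of the summed log-probabilities; details may differ from the released code at github.com/dtak/rrr):

import torch
import torch.nn.functional as F

def rrr_loss(model, X, y, A, lam=1000.0):
    # "Right answers": ordinary cross-entropy on the labels.
    # "Right reasons": squared input gradients of the summed
    # log-probabilities, masked by A (1 = should be irrelevant).
    X = X.clone().requires_grad_(True)
    log_probs = F.log_softmax(model(X), dim=1)
    right_answers = F.nll_loss(log_probs, y)
    # create_graph=True so the penalty itself can be backpropagated through.
    grads, = torch.autograd.grad(log_probs.sum(), X, create_graph=True)
    right_reasons = lam * ((A * grads) ** 2).sum()
    return right_answers + right_reasons

The strength λ (lam above) works best when the two terms have comparable magnitudes (see the backup slides).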
Experiments
Basic philosophy:
Models we used:
Experiments: Toy Colors
[Figure: two redundant rules separate the classes.
Corner pixels: Class 1 has all colors shared, Class 2 has at least two different.
Top-middle pixels: Class 1 has all colors different, Class 2 has at least two shared.]
Learning an otherwise-unreachable rule
By default, model appears to learn corner rule.
If we penalize corners, model discovers top-mid rule!
[Figure: pixels with the largest-magnitude input gradients]
How regularization strength affects what we learn
The model transitions smoothly between learning each rule as the strength varies!
(Can transition with 10s of annotations if we oversample in minibatches)
Experiments: Decoy MNIST
Swatch shades are a simple function of the label y at training time, but not at test time.
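For concreteness, a hypothetical construction in this spirit (the swatch size, placement, and shade formula here are assumptions, not taken from the slides):

import numpy as np

def add_decoy_swatch(images, labels, train=True, rng=np.random):
    # Paint a small gray swatch in a random corner of each image.
    # At train time its shade encodes the label; at test time it is random.
    out = images.copy()
    for img, y in zip(out, labels):
        shade = 255 - 25 * y if train else rng.randint(0, 256)
        corner = rng.randint(4)
        if corner == 0:
            img[:4, :4] = shade
        elif corner == 1:
            img[:4, -4:] = shade
        elif corner == 2:
            img[-4:, :4] = shade
        else:
            img[-4:, -4:] = shade
    return out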
Experiments: Decoy MNIST
[Figure legend: + = increasing a pixel increases the predicted label probability; − = increasing a pixel decreases it]
Normal model has low accuracy; its gradients focus on the swatches.
Model with gradient regularization recovers baseline accuracy!
Case 2: What if we don’t have annotations?
Find-another-explanation
Overall goal: obtain an ensemble of models that are all accurate but for different reasons.
[Diagram: train θ1 on the training examples (X, y); extract explanation annotations A1 from it; train θ2 on (X, y, A1), penalizing it for having large gradients where the 1st model did; extract A2 and train θ3 on (X, y, A1 ∪ A2), so that it is dissimilar from both the 1st and 2nd models; and so on...]
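A rough sketch of that loop (make_model, train_with_rrr, and explanation_mask are hypothetical helpers: train_with_rrr could minimize the rrr_loss sketched earlier, and explanation_mask could mark the features with the largest-magnitude input gradients):

import torch

def find_another_explanation(make_model, train_with_rrr, explanation_mask,
                             X, y, n_models=3):
    # Train a sequence of models, each penalized for relying on input
    # features that the previous models relied on.
    models = []
    A = torch.zeros_like(X)
    for _ in range(n_models):
        model = make_model()
        train_with_rrr(model, X, y, A)
        # Union the running annotation mask with this model's explanation.
        A = ((A + explanation_mask(model, X)) > 0).float()
        models.append(model)
    return models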
Back to the picture in my head
[Figure: the same loss function over parameter space; find-another-explanation iterates through its local minima]
Find-another-explanation: Toy Colors
Model initially learns the corner rule, falls back to the top-middle rule, then fails to learn anything.
Find-another-explanation: Decoy MNIST
Models initially learn decoy rule, then use other features.
Accuracy falls, but very slowly (MNIST is redundant)
Summary / Contributions
For when learning from X, y alone is insufficient:
May be more common than we think!
Future Work
These slides again: goo.gl/fMZiRu
Learning with less data?
Best if “right answers” term ≈ “right reasons” term
Gradients are consistent with LIME but less sparse
Setting A = 1 for all features (“certainty regularization”)
[Figure: input images, with sum-of-log-prob input gradients for a normal CNN and for a certainty-regularized CNN]
Certainty-regularized CNN is also much more resistant to adversarial perturbations.
Information-Theoretic Interpretation of Loss Function
Up to sign, the quantity differentiated in the “right reasons” term, Σ_k log ŷ_k, equals K × the cross-entropy of the prediction with a uniformly random guess: a “distance from total uncertainty”.
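A one-line check of that identity (ŷ_k are the predicted class probabilities, K the number of classes, and H(u, ŷ) the cross-entropy with the uniform distribution u):

-\sum_{k=1}^{K} \log \hat{y}_k = K \left( -\sum_{k=1}^{K} \tfrac{1}{K} \log \hat{y}_k \right) = K \, H(u, \hat{y})

So penalizing its input gradient limits how quickly the model’s confidence can move away from total uncertainty along the annotated (or, with certainty regularization, all) input directions.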