1 of 59

Right for the Right Reasons:

Training Differentiable Models by Constraining their Explanations

Andrew Slavin Ross, Michael C. Hughes, and Finale Doshi-Velez

August 24, 2017, IJCAI, Melbourne

doi.org/10.24963/ijcai.2017/371

Code & data: github.com/dtak/rrr

These slides: goo.gl/fMZiRu

2 of 59

Models don’t always learn what you think they learn

4 of 59

The picture in my head

[Figure: loss function over parameter space, with an easily-found local minimum that acts as a snow detector and a much subtler global minimum that acts as a husky-vs.-wolf face distinguisher]

5 of 59

What are explanations?

(let me try to explain...)

8 of 59

What are explanations?

One approach: interpretable surrogates [Ribeiro et al., KDD 2016, again]

Another: gradients of output probabilities with respect to input features

Actually quite similar!

11 of 59

Input gradients for image classifications
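To make this concrete, here is a minimal, hypothetical sketch (not the released dtak/rrr code) of input gradients for a toy differentiable classifier, using the autograd package; the model and shapes are purely illustrative.

```python
# Hypothetical sketch: input gradients as explanations for a toy softmax
# classifier (a stand-in for any differentiable model), using HIPS autograd.
import autograd.numpy as np
from autograd import grad

def log_proba(params, x):
    """Log class probabilities of a linear softmax model."""
    W, b = params
    logits = np.dot(x, W) + b
    return logits - np.log(np.sum(np.exp(logits)))   # log-softmax

def input_gradient(params, x):
    """Gradient of the summed log-probabilities w.r.t. the input features:
    entries with large magnitude mark the features the prediction is most
    sensitive to -- the 'explanation' used in this talk."""
    return grad(lambda x_: np.sum(log_proba(params, x_)))(x)

# Toy usage with 4 input features and 3 classes.
W = np.arange(12.0).reshape(4, 3) / 10.0
b = np.zeros(3)
x = np.array([1.0, -2.0, 0.5, 3.0])
print(input_gradient((W, b), x))   # one sensitivity value per input feature
```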

12 of 59

This kind of works!

So, we’re done, right?

...what do we do if the explanations are wrong?

14 of 59

Optimizing for the right reason

15 of 59

Optimizing for the right reason

[Diagram, traditional ML: training examples X, y feed a loss function L(θ|X, y), which is minimized over model parameters θ to give a model that predicts y given X]

Traditional ML

16 of 59

Optimizing for the right reason

[Diagram: training examples X, y and explanation annotations A feed a loss function L(θ|X, y, A), which is minimized over model parameters θ to give a model that predicts y given X because A]

Traditional ML + explanation regularization

(Ideally, only need a small subset of X to have extra annotations)

17 of 59

Case 1: Annotations are given

18 of 59

How we encode domain knowledge

[Figure: binary annotation matrix A, with one row per example and one column per feature; a 1 in entry (n, d) signifies that feature d of example n should be irrelevant to the model’s prediction — e.g., the highlighted entry marks the second feature of the first example as irrelevant]
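For instance, a tiny hypothetical mask in code (rows are examples, columns are features):

```python
import numpy as np

# Hypothetical 3-example, 4-feature annotation mask A.
# A[n, d] = 1 means "feature d should be irrelevant to the prediction for example n".
A = np.array([[0, 1, 0, 0],    # second feature of the first example marked irrelevant
              [0, 0, 0, 0],    # no constraints on this example
              [1, 1, 0, 0]])   # first two features marked irrelevant here
```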

19 of 59

Annotation Example

[Figure: an example input with regions labeled A = 1 (marked irrelevant) and A = 0 (unconstrained)]

20 of 59

Our loss function

With some overall strength, if a particular example’s feature is marked irrelevant, then penalize how much the prediction changes when the feature changes.
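Written out (reconstructed from the accompanying paper, so the notation may differ slightly from the slide’s figure), the loss over N examples, D features, and K classes is:

$$
L(\theta, X, y, A) \;=\;
\underbrace{\sum_{n=1}^{N}\sum_{k=1}^{K} -y_{nk}\,\log \hat{y}_{nk}}_{\text{right answers}}
\;+\; \lambda_1 \underbrace{\sum_{n=1}^{N}\sum_{d=1}^{D}\Big(A_{nd}\,\frac{\partial}{\partial x_{nd}}\sum_{k=1}^{K}\log \hat{y}_{nk}\Big)^{2}}_{\text{right reasons}}
\;+\; \lambda_2 \underbrace{\sum_{i}\theta_i^{2}}_{\text{small parameters}}
$$

Here λ1 is the overall strength, A_{nd} = 1 marks feature d of example n as irrelevant, and the squared input gradient measures how much the prediction (via its summed log-probabilities) changes when that feature changes.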

29 of 59

Experiments

Basic philosophy:

  • use or create datasets that we know can be classified with qualitatively different rules
  • see if we can use explanations to “select” which implicit rule the model learns

Models we used:

  • a 2-hidden-layer fully connected network, but the method works for CNNs and larger models (a sketch of the regularized objective follows below)
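A minimal, hypothetical sketch of such a regularized objective for a small fully connected network, using the autograd package (an illustration, not the released dtak/rrr implementation):

```python
# Hypothetical sketch: "right answers" + "right reasons" loss for a tiny MLP.
import autograd.numpy as np
from autograd import grad

def log_proba(params, X):
    """Log class probabilities of a small fully connected network.
    params is a list of (W, b) layer tuples; X has one example per row."""
    h = X
    for W, b in params[:-1]:
        h = np.tanh(np.dot(h, W) + b)
    W, b = params[-1]
    logits = np.dot(h, W) + b
    return logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))

def rrr_loss(params, X, y_onehot, A, l1=10.0, l2=1e-4):
    """Cross-entropy ("right answers") plus a penalty on the input gradients of
    the summed log-probabilities wherever A marks a feature irrelevant
    ("right reasons"), plus ordinary L2 weight decay ("small parameters")."""
    right_answers = -np.sum(y_onehot * log_proba(params, X))
    sum_log_prob = lambda X_: np.sum(log_proba(params, X_))
    input_grads = grad(sum_log_prob)(X)            # same shape as X
    right_reasons = np.sum((A * input_grads) ** 2)
    small_params = sum(np.sum(W ** 2) + np.sum(b ** 2) for W, b in params)
    return right_answers + l1 * right_reasons + l2 * small_params

# Gradient w.r.t. the parameters, usable with any gradient-based optimizer:
rrr_grad = grad(rrr_loss)
```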

30 of 59

Experiments: Toy Colors

[Figure, rule 1 (corners): Class 1 examples have all corner colors shared; Class 2 examples have at least two different]

31 of 59

Experiments: Toy Colors

[Figure, rule 2 (top-middle): Class 1 examples have all top-middle colors different; Class 2 examples have at least two shared]

32 of 59

Learning an otherwise-unreachable rule

By default, model appears to learn corner rule.

If we penalize corners, model discovers top-mid rule!

[Figure: pixels with the largest-magnitude input gradients under each model]
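“Penalizing corners” just means setting A = 1 on the corner pixels of every training image; a hypothetical sketch (assuming 5×5 RGB toy images flattened into feature vectors):

```python
import numpy as np

num_examples = 1000          # illustrative training-set size
A_image = np.zeros((5, 5, 3))
A_image[[0, 0, -1, -1], [0, -1, 0, -1], :] = 1      # the four corner pixels, all channels
A = np.tile(A_image.ravel(), (num_examples, 1))     # one identical annotation row per example
```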

33 of 59

How regularization strength affects what we learn

Smooth transition between model learning each rule!

(Can transition with 10s of annotations if we oversample in minibatches)

34 of 59

Experiments: Decoy MNIST

Swatch shades a simple function of y in train, but not in test.
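A hypothetical sketch of this kind of corruption (the exact shades, corner choices, and constants in the actual Decoy MNIST construction may differ):

```python
import numpy as np

def add_decoy_swatches(images, labels, train=True, size=4, seed=0):
    """Paint a small corner swatch whose gray level encodes the label at train
    time but is random at test time (illustrative Decoy-MNIST-style corruption)."""
    rng = np.random.RandomState(seed)
    out = images.copy().reshape(-1, 28, 28)
    for img, y in zip(out, labels):
        shade = 25 * y if train else rng.randint(0, 256)
        corner = rng.randint(4)
        r = 0 if corner < 2 else 28 - size
        c = 0 if corner % 2 == 0 else 28 - size
        img[r:r + size, c:c + size] = shade
    return out.reshape(images.shape)
```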

35 of 59

Experiments: Decoy MNIST

Normal model has low accuracy; gradients focus on swatches

+ = increasing pixel increases predicted label prob
- = increasing pixel decreases predicted label prob

36 of 59

Experiments: Decoy MNIST

Model with gradient regularization recovers baseline accuracy!

+ = increasing pixel increases predicted label prob
- = increasing pixel decreases predicted label prob

37 of 59

Case 2: What if we don’t have annotations?

38 of 59

Find-another-explanation

Overall goal: obtain an ensemble of models that are all accurate but for different reasons.

  • Train the 1st model on the training examples X, y as usual
  • Take the 1st model’s input gradients over the full training set
  • Set A_{n,d} = 1 where |∂f/∂X_{n,d}| is relatively large
  • Penalize the 2nd model for having large gradients where the 1st model did!
  • The 3rd model should be dissimilar from both the 1st and the 2nd, and so on...

[Diagram: each successive model θ1, θ2, θ3, ... is trained on X, y with explanation annotations A1, A2, ... derived from the earlier models’ input gradients]
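In code, the procedure might look like this (a hypothetical sketch; `train_rrr` and `input_gradients` are assumed helpers, e.g., built from the objective sketched earlier, not part of the released code):

```python
import numpy as np

def find_another_explanation(X, y, train_rrr, input_gradients,
                             n_models=3, quantile=0.67):
    """Train models that are (hopefully) accurate for different reasons.

    train_rrr(X, y, A)        -> model trained with the explanation penalty on mask A
    input_gradients(model, X) -> d f / d X for each example, same shape as X
    Both callables are assumptions for this sketch.
    """
    models = []
    A = np.zeros_like(X, dtype=float)          # nothing penalized for the first model
    for _ in range(n_models):
        model = train_rrr(X, y, A)
        grads = np.abs(input_gradients(model, X))
        # Mark the features this model relied on as irrelevant for the next one.
        cutoff = np.quantile(grads, quantile, axis=1, keepdims=True)
        A = np.maximum(A, (grads >= cutoff).astype(float))
        models.append(model)
    return models
```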

47 of 59

Back to the picture in my head

[Figure: loss function over parameter space; find-another-explanation iterates through local minima]

48 of 59

Find-another-explanation: Toy Colors

Model initially learns corner rule, falls back to top-three rule, then fails to learn anything.

49 of 59

Find-another-explanation: Decoy MNIST

Models initially learn decoy rule, then use other features.

Accuracy falls, but very slowly (MNIST is redundant)

50 of 59

Summary / Contributions

For when learning from X, y alone is insufficient:

  • Introduced a novel method of injecting domain knowledge into NN training
    • Works for any differentiable model, no need to modify architecture
    • Can start using it with a small number of annotated examples
  • Demonstrated how it can be used to obtain otherwise unreachable models
    • If we have domain knowledge, we can use it to avoid fitting to spurious correlations
    • If we don’t, we can obtain a diverse ensemble of models

May be more common than we think!

52 of 59

Future Work

  • Human-in-the-loop
    • Interactively select the best explanations, train new models
  • Bridging features and concepts
    • E.g. for images, “concepts” are only emergent at upper layers
    • If we can identify concepts like in [Bau et al. CVPR 2017], regularize wrt concepts?
  • Explore more options for loss functions and annotations
    • Use non-binary A, L1 regularization, or class-specific positive/negative penalties rather than a sum
  • Much bigger networks
    • Have already validated the approach for mid-size CNNs, but I’m a newbie
  • Defending against adversarial perturbations
    • Have results showing that setting A = 1 for all features confers robustness to FGSM and JSMA attacks
  • Applications to medical domain
    • Many types of medical knowledge are easily encodable as annotations

These slides again: goo.gl/fMZiRu

54 of 59

55 of 59

Learning with less data?

56 of 59

Best if “right answers” term ≈ “right reasons” term
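One simple way to act on this (a heuristic consistent with the slide, not necessarily the paper’s exact procedure) is to set λ1 from the ratio of the two terms at the start of training:

```python
def balance_lambda(right_answers_term, right_reasons_term, eps=1e-12):
    """Hypothetical heuristic: choose lambda_1 so that
    lambda_1 * right_reasons ~= right_answers at initialization."""
    return right_answers_term / (right_reasons_term + eps)
```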

57 of 59

Gradients are consistent with LIME but less sparse

58 of 59

Setting A = 1 for all features (“certainty regularization”)

[Figure rows: input images; sum-of-log-prob input gradients for a normal CNN; the same input gradients for a certainty-regularized CNN]

Certainty-regularized CNN is also much more resistant to adversarial perturbations

59 of 59

Information-Theoretic Interpretation of Loss Function

The “right reasons” term penalizes the input gradient of the summed log-probabilities, and

$$-\sum_{k=1}^{K}\log \hat{y}_{nk} \;=\; K \times \underbrace{\sum_{k=1}^{K}\tfrac{1}{K}\big(-\log \hat{y}_{nk}\big)}_{\text{cross-entropy of prediction w/ uniformly random guess}}$$

i.e., a “distance from total uncertainty.”