1 of 62

FIGS: Fast interpretable greedy-tree sums

Tan*, Singh*, Nasseri, Agarwal, & Yu

arXiv, submitted to ICML

2 of 62

3 of 62

[Figure: a single deep tree that must repeat splits on X1 > 0, X2 > 0, X3 > 0 across its branches, versus a sum of three single-split trees (stumps) on X1 > 0, X2 > 0, X3 > 0]

4 of 62


Number of trees?

Depth of each tree?

5 of 62

Random Forest & Gradient Boosting

Tree size / ensemble size is not adaptive

[Figure: an ensemble written as a sum of many trees]

6 of 62

FIGS

Trees compete with each other to predict the outcome

Unexplained variance

Each split fits the residuals of the other trees
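In symbols (a sketch of the idea, with notation introduced here rather than taken from the slides): if the current trees are $\hat f_1, \dots, \hat f_K$, a split added to tree $k$ is chosen to fit the residual left over by the other trees,

$$ r_i^{(k)} \;=\; y_i \;-\; \sum_{k' \neq k} \hat f_{k'}(x_i). $$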

16 of 62

CART algorithm: Step-by-step

More precisely…

Node t

Left child tL

Right child tR

Split s: X1 > b

Impurity decrease

Generative model
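The impurity-decrease formula itself appears only as a figure; the standard CART quantity it refers to (stated here for regression as background, not copied from the slide) is

$$ \hat\Delta(s, t) \;=\; \frac{1}{N(t)} \Big[ \sum_{x_i \in t} (y_i - \bar y_{t})^2 \;-\; \sum_{x_i \in t_L} (y_i - \bar y_{t_L})^2 \;-\; \sum_{x_i \in t_R} (y_i - \bar y_{t_R})^2 \Big], $$

where $N(t)$ is the number of samples in node $t$ and $\bar y_t$ their mean response; CART greedily takes the split $s$ with the largest $\hat\Delta(s, t)$.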

22 of 62

FIGS algorithm: Step-by-step

How to choose splits?

Compute the impurity decreases of this split in terms of the residuals of the other trees

Predict average of residuals over new leaves

Generative model
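To make the split-selection loop concrete, here is a deliberately simplified sketch in Python (not the authors' implementation: it only ever adds new depth-1 trees, whereas FIGS also considers splitting the leaves of trees it has already grown and takes whichever single split gives the largest impurity decrease on the residuals; the names best_stump and figs_stumps are illustrative):

import numpy as np

def best_stump(x_col, r):
    """Best threshold on one feature for reducing the squared error of residual r."""
    order = np.argsort(x_col)
    xs, rs = x_col[order], r[order]
    n, total = len(rs), rs.sum()
    csum = np.cumsum(rs)
    best_gain, best_thr = 0.0, None
    for i in range(1, n):
        if xs[i] == xs[i - 1]:
            continue  # cannot split between equal feature values
        sl, sr = csum[i - 1], total - csum[i - 1]
        # impurity decrease = SSE(no split) - SSE(split into left/right means)
        gain = sl ** 2 / i + sr ** 2 / (n - i) - total ** 2 / n
        if gain > best_gain:
            best_gain, best_thr = gain, (xs[i - 1] + xs[i]) / 2
    return best_gain, best_thr

def figs_stumps(X, y, max_rules=10):
    """Greedy sum of stumps: each new split fits the residual left by the other trees."""
    stumps, pred = [], np.zeros(len(y))
    for _ in range(max_rules):
        r = y - pred                      # residual of all trees fit so far
        best = None                       # (gain, feature, threshold)
        for j in range(X.shape[1]):
            gain, thr = best_stump(X[:, j], r)
            if thr is not None and (best is None or gain > best[0]):
                best = (gain, j, thr)
        if best is None:
            break                         # no split improves the fit
        _, j, thr = best
        left = X[:, j] <= thr
        lv, rv = r[left].mean(), r[~left].mean()   # predict average residual in new leaves
        stumps.append((j, thr, lv, rv))
        pred += np.where(left, lv, rv)
    return stumps, pred

# Tiny demo on an additive target: typically recovers one stump per feature.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = (X[:, 0] > 0).astype(float) + (X[:, 1] > 0).astype(float) + (X[:, 2] > 0).astype(float) \
    + 0.1 * rng.normal(size=500)
print(figs_stumps(X, y, max_rules=3)[0])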

27 of 62

FIGS algorithm: Runtime analysis

Worst-case runtime: O(n²r²d)

n = no. of samples

d = no. of features

r = no. of splits made

CART runtime: O(n²rd)

28 of 62

Real-data experiments

Improved prediction performance

Disentangling additive components

29 of 62

Does FIGS predict well with few splits on real-world data sets?


30 of 62

Classification results: AU-ROC vs no. of splits


31 of 62

Classification results: AU-ROC vs no. of splits

32 of 62

Regression results

33 of 62

CART

AUC=0.817

Pima Native American Diabetes classification dataset (n=768, p=8)

Target = Onset of diabetes within five years

FIGS

AUC=0.820

34 of 62

Pima Native American Diabetes classification dataset (n=768, p=8)

Target = Onset of diabetes within five years

FIGS

AUC=0.820

Glucose = 150

BMI = 30

Age = 23

Predicted risk

= 0.22 + 0.26

= 0.48
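To fit a model like this yourself, a minimal sketch using the authors' imodels package (assuming it exposes FIGSClassifier with a max_rules cap and the usual scikit-learn fit / predict_proba interface; the data below is a synthetic stand-in, not the Pima dataset behind the numbers above):

# pip install imodels scikit-learn
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from imodels import FIGSClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(768, 8))                      # stand-in for the 8 Pima features
y = (X[:, 1] + 0.5 * X[:, 5] + rng.normal(size=768) > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
model = FIGSClassifier(max_rules=12)               # cap the total number of splits
model.fit(X_tr, y_tr)
print("test AUC:", roc_auc_score(y_te, model.predict_proba(X_te)[:, 1]))
print(model)                                       # inspect the fitted model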

35 of 62

Theoretical results

(with oracle assumptions)

Generalization upper bound

Disentangling additive components

36 of 62

Generative model for theory and simulations

$$ y \;=\; \sum_{k=1}^{s} f_k\big(X_{I_k}\big) \;+\; \epsilon, \qquad \epsilon \text{ independent noise}, $$

where $I_1, I_2, \dots, I_s$ index the features used by each additive component.

39 of 62

Defining the ERM tree-sum estimator

Given: data $(x_1, y_1), \dots, (x_n, y_n)$ drawn from the generative model above.

The ERM tree-sum is the least-squares fit over sums of trees:

$$ \hat f \;=\; \operatorname*{argmin}_{f \,=\, f_1 + \cdots + f_K,\ f_k \in \text{Trees}} \; \sum_{i=1}^{n} \big(y_i - f(x_i)\big)^2. $$

40 of 62

Oracle generalization upper bound

Theorem 2 (TSNAY, Oracle generalization): Suppose

for k=1,...,s. There is a collection of trees such that the ERM tree-sum satisfies

Corollary (Oracle generalization for additive models): Suppose further that |Ik| = 1 for each k. Then

42 of 62

Comparing to other rates

Corollary (Oracle generalization for additive models): Suppose further that |Ik| = 1 for each k. Then

Tree lower bound from Part 1: Suppose |φ_j′(t)| ≥ β_0 for j = 1, ..., s. Then for any tree T,

Theorem (Raskutti et al., 12)*: The minimax rate for C1 additive models is . This is achievable by penalized kernel methods.

43 of 62

MSE for CART, RF, FIGS on a sparse linear generative model (d=50, s=10)

44 of 62

FIGS can disentangle polynomial interactions

Observed at n=2500

Ground truth

For a given feature, count number of times it gets split upon in each tree.

Compute cosine similarities between different count vectors.

Theory: Provable disentanglement if FIGS makes use of population impurity decrease and average values (large sample limit).
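A sketch of this diagnostic in Python (assuming each fitted tree is available as the list of feature indices used by its splits; how those are extracted depends on the implementation):

import numpy as np

def split_count_vector(tree_split_features, d):
    """Count how many times each of the d features is split on in one tree."""
    counts = np.zeros(d)
    for feat in tree_split_features:
        counts[feat] += 1
    return counts

def cosine_similarity(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-12))

# Hypothetical split lists for three fitted trees over d = 6 features
trees = [[0, 1, 0], [2, 3], [4, 5, 4]]
vecs = [split_count_vector(t, d=6) for t in trees]
sims = np.array([[cosine_similarity(u, v) for v in vecs] for u in vecs])
print(np.round(sims, 2))  # off-diagonal ≈ 0 means the trees use disjoint features (disentangled)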

45 of 62

MSE for CART, RF, XGB, GAM, FIGS on a linear model (left), sum of poly interactions (right)

46 of 62


47 of 62

Appendix

48 of 62

49 of 62

FIGS

Trees compete with each other to predict the outcome

Unexplained variance

52 of 62

A decision tree is a piecewise constant model obtained from recursive partitioning of the covariate space

[Figure: the (X1, X2) plane is split first by X2 > a and then by X1 > b, giving rectangular cells with constant predictions 0 and 1]
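In formula form (standard notation, not taken from the slides): a tree with leaves $R_1, \dots, R_M$ and leaf values $c_1, \dots, c_M$ predicts

$$ f(x) \;=\; \sum_{m=1}^{M} c_m \, \mathbf{1}\{x \in R_m\}. $$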

55 of 62

Trees are “inefficient” at expressing additive structures

Idea: Grow multiple trees…

[Figure: a single tree for 1{X1 > 0} + 1{X2 > 0} + 1{X3 > 0} must repeat the X2 > 0 and X3 > 0 splits inside each branch, whereas the sum of three stumps (X1 > 0) + (X2 > 0) + (X3 > 0) uses only three splits]
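To make the inefficiency concrete (a worked count consistent with the figure, not a statement copied from the slides): for

$$ f(x) \;=\; \mathbf{1}\{x_1 > 0\} + \mathbf{1}\{x_2 > 0\} + \mathbf{1}\{x_3 > 0\}, $$

a sum of three stumps represents $f$ exactly with 3 splits, while a single tree must keep each of the 8 orthants in its own leaf (adjacent orthants differ in $f$), so it needs 8 leaves and hence 7 splits; with $s$ additive terms a single tree needs $2^s$ leaves.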

56 of 62

We want to

Overcome CART’s limitation

Able to disentangle additive components of a model (single tree for each component).

Able to do this without knowing which features correspond to each component.

Preserve CART’s strengths

Able to adapt to higher order interactions if present.

Able to do this without knowing the true order of the interaction, or the optimal tree depth.

Able to easily constrain model complexity (total number of splits).

57 of 62

By performing disentanglement while preserving adaptivity, we hope to have

Improved prediction performance.

Fewer splits, hence a simpler model.

Fewer false discoveries for interactions.

58 of 62

Existing tree-sum methods not able to fulfill these simultaneously

Random forests

Each tree is grown in exactly the same way; trees tend to split on all features.

Shares CART's weakness (lack of disentanglement).

Gradient boosting

Needs each tree to have a pre-specified depth.

Unable to adapt to the “true” order of the interaction.

Loses CART’s strengths (adaptivity).

Neither method is designed to have its complexity constrained.

59 of 62

More precisely…

Fast Interpretable Greedy-Tree Sums (FIGS)

Modifies the CART split rule to grow a flexible number of trees simultaneously.

60 of 62

FIGS algorithm: Step-by-step

Generative model

Compute the impurity decreases of this split in terms of the residuals of the other trees

Predict average of residuals over new leaves

How to choose splits?

61 of 62

Decision tree algorithms

Greedy tree methods for regression & classification

AID [Morgan & Sonquist, 63]

THAID [Messenger & Mandell, 73]

CHAID [Kass, 80]

CART [Breiman et al., 84]

ID3 [Quinlan, 86]

C4.5 [Quinlan, 93]

Dealing with missing data, de-biasing splits, etc.

GUIDE [Loh, 09]

CRUISE [Kim & Loh, 01]

Algorithms differ in their search strategy for splits

62 of 62

Decision tree algorithms

Greedy tree methods for regression & classification

AID [Morgan & Sonquist, 63]

THAID [Messenger & Mandell, 73]

CHAID [Kass, 80]

CART [Breiman et al., 84]

ID3 [Quinlan, 86]

C4.5 [Quinlan, 93]

Dealing with missing data, de-biasing splits, etc.

GUIDE [Loh, 09]

CRUISE [Kim & Loh, 01]

Greedy trees for other problems

Quantile regression [Meinshausen, 06]

Survival analysis [Ishwaran, 08]

Ranking [Clemençon, 13]

Causal inference [Athey & Imbens, 16]

Ensemble methods

Random forests [Breiman, 01]

Gradient boosting [Friedman, 01]

BART [Chipman et al., 10]

Globally optimal trees

[Bennett, 94]

GOSDT (dynamic programming) [Lin et al., 20]

Mixed integer optimization [Aghaei et al., 21]