FIGS: Fast interpretable greedy-tree sums
Tan*, Singh*, Nasseri, Agarwal, & Yu
arXiv, submitted to ICML
[Figure: a single deep tree splitting repeatedly on X1 > 0, X2 > 0, X3 > 0 vs. a sum of three single-split trees, one per feature]
Random Forest & Gradient Boosting
Number of trees? Depth of each tree? Tree size / ensemble size is not adaptive.
FIGS
Trees compete with each other to predict the outcome.
Each split fits the residuals of the other trees.
[Figure: animation of FIGS adding one split at a time, with the remaining unexplained variance shown at each step]
CART algorithm: Step-by-step
More precisely: a candidate split s (e.g. X1 > b) partitions node t into a left child tL and a right child tR.
The split is chosen to maximize the impurity decrease, i.e. the reduction in within-node variance going from t to tL and tR.
[Panel: generative model used to illustrate the steps]
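To make the split rule concrete, here is a minimal sketch of the CART split search for regression, assuming squared-error impurity and an exhaustive scan over features and thresholds (helper names are illustrative, not from the paper):

```python
import numpy as np

def impurity(y):
    """Squared-error impurity of a node: sum of squared deviations from the node mean."""
    return float(np.sum((y - y.mean()) ** 2)) if len(y) else 0.0

def best_split(X, y):
    """Scan all (feature, threshold) pairs and return the split s: X_j > b
    maximizing the impurity decrease  impurity(t) - impurity(tL) - impurity(tR)."""
    parent = impurity(y)
    best_gain, best_feat, best_thresh = 0.0, None, None
    for j in range(X.shape[1]):
        for b in np.unique(X[:, j])[:-1]:      # candidate thresholds for feature j
            left = X[:, j] <= b                # left child tL, right child tR
            gain = parent - impurity(y[left]) - impurity(y[~left])
            if gain > best_gain:
                best_gain, best_feat, best_thresh = gain, j, b
    return best_feat, best_thresh, best_gain
```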
FIGS algorithm: Step-by-step
How to choose splits? Compute the impurity decrease of each candidate split in terms of the residuals of the other trees.
Predict the average of the residuals over the new leaves.
[Panel: generative model used to illustrate the steps]
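A condensed sketch of this loop for regression (a simplification for illustration, not the authors' implementation: classification, impurity normalization, and other details of the paper's code are omitted, and helper names are made up here):

```python
import numpy as np

def sse(y):
    """Squared-error impurity (sum of squared deviations from the mean)."""
    return float(np.sum((y - y.mean()) ** 2)) if len(y) else 0.0

def best_split(X, y):
    """Best CART-style split of (X, y): returns (gain, feature, threshold)."""
    best = (0.0, None, None)
    for j in range(X.shape[1]):
        for b in np.unique(X[:, j])[:-1]:
            left = X[:, j] <= b
            gain = sse(y) - sse(y[left]) - sse(y[~left])
            if gain > best[0]:
                best = (gain, j, b)
    return best

class Leaf:
    """A node in one of the summed trees; leaves store a fitted value."""
    def __init__(self, idx, value):
        self.idx, self.value = idx, value   # training points reaching this node, fitted value
        self.feat = self.thresh = self.left = self.right = None

    def predict(self, X):
        if self.feat is None:
            return np.full(len(X), self.value)
        go_left = X[:, self.feat] <= self.thresh
        out = np.empty(len(X))
        out[go_left] = self.left.predict(X[go_left])
        out[~go_left] = self.right.predict(X[~go_left])
        return out

    def leaves(self):
        return [self] if self.feat is None else self.left.leaves() + self.right.leaves()

def fit_figs(X, y, max_splits=10):
    """Grow a *sum* of trees greedily: each iteration splits whichever current leaf
    (or a brand-new root) best explains the residuals of the other trees."""
    roots = []
    for _ in range(max_splits):
        pred = sum(r.predict(X) for r in roots) if roots else np.zeros(len(y))
        new_root = Leaf(np.arange(len(y)), 0.0)       # candidate: start another tree
        candidates = [lf for r in roots for lf in r.leaves()] + [new_root]
        scored = []
        for lf in candidates:
            resid = (y - pred + lf.value)[lf.idx]     # residuals w.r.t. the *other* trees
            scored.append((best_split(X[lf.idx], resid), lf, resid))
        (gain, feat, thresh), lf, resid = max(scored, key=lambda s: s[0][0])
        if feat is None:                              # no split reduces the impurity
            break
        go_left = X[lf.idx, feat] <= thresh
        lf.feat, lf.thresh = feat, thresh             # make the winning split and predict
        lf.left = Leaf(lf.idx[go_left], resid[go_left].mean())     # the average residual
        lf.right = Leaf(lf.idx[~go_left], resid[~go_left].mean())  # over each new leaf
        if lf is new_root:
            roots.append(new_root)
    return roots

# Example: an additive signal is recovered with one small tree per component
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = (X[:, 0] > 0) + (X[:, 1] > 0) + 0.1 * rng.normal(size=500)
trees = fit_figs(X, y, max_splits=4)
y_hat = sum(t.predict(X) for t in trees)
```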
FIGS algorithm: Runtime analysis
Worst-case runtime: O(n²r²d), where n = no. of samples, d = no. of features, r = no. of splits made.
For comparison, CART's runtime is O(n²rd); the extra factor of r reflects that each new FIGS split changes the residuals, so candidate splits at all current leaves must be re-scored.
Real-data experiments
Improved prediction performance
Disentangling additive components
Does FIGS predict well with few splits on real-world data sets?
[Figure: classification results, AUROC vs. no. of splits]
[Figure: regression results]
Pima Native American Diabetes classification dataset (n=768, p=8); target = onset of diabetes within five years.
CART: AUC = 0.817. FIGS: AUC = 0.820.
Example FIGS prediction for Glucose = 150, BMI = 30, Age = 23: predicted risk = 0.22 + 0.26 = 0.48 (one contribution from each tree in the sum).
Theoretical results
(with oracle assumptions)
Generalization upper bound
Disentangling additive components
Generative model for theory and simulations
Generative model: y = f1(x_I1) + … + fs(x_Is) + ε, with independent noise ε and disjoint feature blocks I1, I2, …, Is.
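For concreteness, one way to simulate data of this form (the blocks, component functions, and noise level below are illustrative choices, not the paper's exact simulation settings):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 50
X = rng.uniform(-1, 1, size=(n, d))

# Disjoint feature blocks I1, ..., Is and one component function per block
blocks = [[0], [1], [2, 3]]                  # e.g. two univariate terms + one interaction
components = [
    lambda x: 2.0 * x[:, 0],                 # f1(x_I1): linear in feature 0
    lambda x: x[:, 0] ** 2,                  # f2(x_I2): quadratic in feature 1
    lambda x: x[:, 0] * x[:, 1],             # f3(x_I3): pairwise interaction of features 2, 3
]

# y = sum of component functions over their blocks, plus independent noise
y = sum(f(X[:, I]) for f, I in zip(components, blocks)) + 0.1 * rng.normal(size=n)
```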
Defining the ERM tree-sum estimator
Given: data (x1, y1), …, (xn, yn) and a collection of candidate trees.
ERM tree-sum: the sum of trees T1 + … + Ts from the collection minimizing the training squared error,
(T1, …, Ts) = argmin Σi (yi − Σk Tk(xi))².
Oracle generalization upper bound
Theorem 2 (TSNAY, Oracle generalization): Suppose the oracle conditions hold for each component k = 1, …, s. Then there is a collection of trees such that the ERM tree-sum satisfies the stated generalization upper bound.
Corollary (Oracle generalization for additive models): Suppose further that |Ik| = 1 for each k. Then the bound specializes to the additive-model rate.
Comparing to other rates
Corollary (Oracle generalization for additive models): with |Ik| = 1 for each k, the ERM tree-sum attains the additive-model rate above.
Tree lower bound from Part 1: Suppose |φj'(t)| ≥ β0 for j = 1, …, s. Then any single tree T has risk bounded below by the Part 1 lower bound.
Theorem (Raskutti et al., 12)*: the minimax rate for C¹ additive models is achievable by penalized kernel methods.
[Figure: MSE for CART, RF, and FIGS on a sparse linear generative model (d = 50, s = 10)]
FIGS can disentangle polynomial interactions
[Figure: learned split-count structure observed at n = 2500 vs. ground truth]
For a given feature, count the number of times it gets split upon in each tree.
Compute cosine similarities between the resulting count vectors.
Theory: provable disentanglement if FIGS makes use of population impurity decreases and average values (large-sample limit).
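A sketch of this diagnostic (the input format, a list of the features each fitted tree splits on, is an assumption for illustration):

```python
import numpy as np

def split_count_vectors(trees_to_features, d):
    """For each tree, count how many times each of the d features is split upon."""
    counts = np.zeros((len(trees_to_features), d))
    for t, feats in enumerate(trees_to_features):
        for j in feats:
            counts[t, j] += 1
    return counts

def pairwise_cosine(counts):
    """Cosine similarity between the split-count vectors of different trees.
    Off-diagonal values near 0 mean the trees split on disjoint feature sets."""
    norms = np.linalg.norm(counts, axis=1, keepdims=True)
    unit = counts / np.clip(norms, 1e-12, None)
    return unit @ unit.T

# Example: tree 0 splits only on features {0, 1}, tree 1 only on feature {2}
counts = split_count_vectors([[0, 1, 0], [2, 2]], d=4)
print(pairwise_cosine(counts))   # off-diagonal ≈ 0 → disentangled
```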
[Figure: MSE for CART, RF, XGB, GAM, and FIGS on a linear model (left) and a sum of polynomial interactions (right)]
⭐ 600+
⬇ 24k+
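The star/download counts refer to the accompanying open-source implementation (presumably the imodels package). A usage sketch assuming it exposes FIGS with a scikit-learn-style interface; the class and parameter names below (FIGSClassifier, max_rules) should be checked against the package's current docs:

```python
# pip install imodels    # open-source package that includes a FIGS implementation
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from imodels import FIGSClassifier      # class name as exposed by imodels (check current docs)

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = FIGSClassifier(max_rules=10)    # cap on the total number of splits across all trees
model.fit(X_train, y_train)
acc = (model.predict(X_test) == y_test).mean()
print(model)                            # printing the model shows the fitted sum of trees
print(f"test accuracy: {acc:.3f}")
```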
Appendix
A decision tree is a piecewise constant model obtained from recursive partitioning of the covariate space.
[Figure: the (X1, X2) plane partitioned by the splits X2 > a and then X1 > b, with a constant prediction (0 or 1) in each cell]
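As a minimal illustration of "piecewise constant", the partition sketched on this slide corresponds to a function like the following (the thresholds a, b and the assignment of the leaf values 0, 1, 1 to cells are illustrative):

```python
def tree_predict(x1, x2, a=0.5, b=0.5, leaf_values=(0.0, 1.0, 1.0)):
    """Piecewise-constant prediction defined by the recursive partition
    X2 > a, then X1 > b (leaf values 0, 1, 1 follow the slide's figure)."""
    if x2 <= a:
        return leaf_values[0]     # cell below the X2 = a line
    if x1 <= b:
        return leaf_values[1]     # cell with X2 > a and X1 <= b
    return leaf_values[2]         # cell with X2 > a and X1 > b
```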
Trees are “inefficient” at expressing additive structures
Idea: Grow multiple trees…
[Figure: a single deep tree splitting repeatedly on X1 > 0, X2 > 0, X3 > 0 vs. a sum of three single-split trees, one per feature]
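A worked check of the inefficiency claim, under the assumption (consistent with the figure) that the target is the additive function f(x) = 1{X1 > 0} + 1{X2 > 0} + 1{X3 > 0}: a sum of three one-split trees represents f exactly with 3 splits, whereas a single tree needs a leaf per sign pattern, i.e. 2³ = 8 leaves and 7 splits.

```python
import itertools

# Additive target suggested by the figure: f(x) = 1{x1 > 0} + 1{x2 > 0} + 1{x3 > 0}
f = lambda x: sum(xi > 0 for xi in x)

# Sum-of-trees representation: three single-split trees, 3 splits in total
stumps = [lambda x, j=j: 1.0 if x[j] > 0 else 0.0 for j in range(3)]
tree_sum = lambda x: sum(stump(x) for stump in stumps)

# A single tree needs a separate leaf for each of the 2^3 sign patterns
# (f is non-constant on any box that straddles a coordinate sign), i.e.
# 8 leaves and hence 7 internal splits, to represent f exactly.
for x in itertools.product([-1.0, 1.0], repeat=3):
    assert tree_sum(x) == f(x)
print("3 splits (sum of stumps) vs. 7 splits (single tree) for the same function")
```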
We want to
Overcome CART’s limitation
Able to disentangle additive components of a model (single tree for each component).
Able to do this without knowing which features correspond to each component.
Preserve CART’s strengths
Able to adapt to higher order interactions if present.
Able to do this without knowing the true order of the interaction, or the optimal tree depth.
Able to easily constrain model complexity (total number of splits).
By performing disentanglement while preserving adaptivity, we hope to have
Improved prediction performance.
Fewer splits, hence a simpler model.
Fewer false discoveries for interactions.
Existing tree-sum methods cannot fulfill these goals simultaneously
Random forests
Each tree is grown in exactly the same way; trees tend to split on all features.
Shares CART's weakness (lack of disentanglement).
Gradient boosting
Needs each tree to have a pre-specified depth.
Unable to adapt to the "true" order of interaction.
Loses CART's strengths (adaptivity).
Neither method is designed to have its complexity constrained.
More precisely…
Fast Interpretable Greedy-Tree Sums (FIGS)
Modifies the CART split rule to grow a flexible number of trees simultaneously.
FIGS algorithm: Step-by-step
How to choose splits? Compute the impurity decreases of each candidate split in terms of the residuals of the other trees.
Predict the average of the residuals over the new leaves.
[Panel: generative model used to illustrate the steps]
Decision tree algorithms
Algorithms differ in their search strategy for splits.
Greedy tree methods for regression & classification: AID [Morgan & Sonquist, 63], THAID [Messenger & Mandell, 73], CHAID [Kass, 80], CART [Breiman et al., 84], ID3 [Quinlan, 86], C4.5 [Quinlan, 93]
Dealing with missing data, de-biasing splits, etc.: CRUISE [Kim & Loh, 01], GUIDE [Loh, 09]
Ensemble methods: Random forests [Breiman, 01], Gradient boosting [Friedman, 01], BART [Chipman et al., 10]
Greedy trees for other problems: Quantile regression [Meinshausen, 06], Survival analysis [Ishwaran, 08], Ranking [Clemençon, 13], Causal inference [Athey & Imbens, 16]
Globally optimal trees: [Bennett, 94], GOSDT (dynamic programming) [Lin et al., 20], Mixed integer optimization [Aghaei et al., 21]