1 of 29

Get the best from your scikit-learn classifier

EuroSciPy 2023

Guillaume Lemaitre - August 16, 2023

2 of 29

About me

Research Engineer

@glemaitre

@glemaitre@fosstodon.org

[Timeline figure: 2017 · 2019 · 2023]

3 of 29

Problem statement

Imbalanced classification

The number of cancer voxels is much smaller than the number of healthy voxels: roughly a 20:1 imbalance.

4 of 29

Problem statement

Imbalanced classification

5 of 29

Problem statement

Imbalanced classification

# The "textbook" recipe: resample the training set inside the pipeline.
from imblearn.pipeline import make_pipeline
from imblearn.over_sampling import SMOTE
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

# SMOTE oversamples the minority class before the classifier is fitted.
model = make_pipeline(SMOTE(), LogisticRegression())
cv_results = cross_validate(model, X, y)

6 of 29

~6 years…

7 of 29

Learning from imbalanced data:

I was wrong but I was not the only one

EuroSciPy 2023

Guillaume Lemaitre - August 16, 2023

8 of 29

Strong claim

There is no problem learning from imbalanced data

“No resampling technique will magically generate more information out of the few cases with the rare class” [1]

9 of 29

A “typical” use-case

Adult census dataset

from sklearn.datasets import fetch_openml

data, target = fetch_openml(
    "Adult", as_frame=True, return_X_y=True
)
data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 38393 entries, 0 to 6008
Data columns (total 12 columns):
 #   Column          Non-Null Count  Dtype
---  ------          --------------  -----
 0   age             38393 non-null  int64
 1   workclass       35821 non-null  category
 2   education       38393 non-null  category
 3   marital-status  38393 non-null  category
 4   occupation      35811 non-null  category
 5   relationship    38393 non-null  category
 6   race            38393 non-null  category
 7   sex             38393 non-null  category
 8   capital-gain    38393 non-null  int64
 9   capital-loss    38393 non-null  int64
 10  hours-per-week  38393 non-null  int64
 11  native-country  37731 non-null  category
dtypes: category(8), int64(4)
memory usage: 1.8 MB

target.value_counts()

<=50K    37155
>50K      1238
Name: count, dtype: int64

(i.e. roughly a 30:1 class imbalance)

10 of 29

A “typical” use-case

Experimental setup

from sklearn.model_selection import cross_validate

cv_results = cross_validate(model, data, target, scoring="balanced_accuracy")

Models evaluated:

  • Vanilla Random Forest
  • Vanilla Logistic Regression

11 of 29

A “typical” use-case

Results of vanilla strategy

12 of 29

A “typical” use-case

“Over”-fighting the imbalance

Models evaluated:

  • SMOTE + Logistic Regression
  • Balanced Random Forest

13 of 29

A “typical” use-case

“Over”-fighting the imbalance

Synthetic Minority Oversampling TEchnique (SMOTE)

SMOTE generates synthetic minority samples by interpolating between each minority sample and one of its k nearest minority-class neighbors.
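A minimal sketch of the resampling step (assuming a feature matrix X and labels y):

from imblearn.over_sampling import SMOTE

# fit_resample returns the original data augmented with synthetic
# minority samples interpolated from the 5 nearest neighbors (the default).
X_resampled, y_resampled = SMOTE(k_neighbors=5).fit_resample(X, y)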

Balanced Random Forest

Each tree in the forest is given a dataset derived from the original one by:

  • taking a bootstrap sample of the minority class
  • drawing at random, with replacement, the same number of samples from the majority class

(a sketch with imbalanced-learn follows)
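This strategy ships in imbalanced-learn; a minimal sketch:

from imblearn.ensemble import BalancedRandomForestClassifier

# Each tree is trained on a balanced bootstrap: a resample of the
# minority class plus a matching random draw from the majority class.
model = BalancedRandomForestClassifier(n_estimators=100, random_state=0)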

14 of 29

A “typical” use-case

“Over”-fighting the imbalance

15 of 29

Resampling for fighting class imbalance

The potential caveats

  • Is the resampling strategy optimal for the metric of interest?
  • How are the classifiers affected by the resampling?

16 of 29

Resampling breaks calibration!

Resampling is used to work around the inflexibility of the decision threshold (0.5 by default), but it distorts the classifier's probability estimates: the values returned by model.predict_proba become meaningless as probabilities!
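One way to visualize the damage (a sketch; vanilla_model and smote_model are assumed to be fitted pipelines, with a held-out X_test/y_test):

from sklearn.calibration import CalibrationDisplay

# A well-calibrated model tracks the diagonal; the resampled pipeline's
# curve drifts away from it because its minority-class scores are inflated.
CalibrationDisplay.from_estimator(vanilla_model, X_test, y_test)
CalibrationDisplay.from_estimator(smote_model, X_test, y_test)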

17 of 29

A new scikit-learn estimator [1]

Post-tuning the estimator decision threshold given a metric

from sklearn.model_selection import TunedThresholdClassifier

tuned_model = TunedThresholdClassifier(
    estimator=model, objective_metric="balanced_accuracy"
)
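(This estimator was still a pull request at the time of the talk; it was released in scikit-learn 1.5 under a different name and parameter. The equivalent with the released API:)

from sklearn.model_selection import TunedThresholdClassifierCV

tuned_model = TunedThresholdClassifierCV(
    estimator=model, scoring="balanced_accuracy"
)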

[Figure: results for the post-tuned-threshold Logistic Regression]

18 of 29

A new scikit-learn estimator [1]

Post-tuning the estimator decision threshold given a metric

19 of 29

A new scikit-learn estimator [1]

Post-tuning the estimator decision threshold given a metric

[Figure: post-tuned-threshold Logistic Regression vs. post-tuned-threshold Random Forest]

20 of 29

A new scikit-learn estimator [1]

Post-tuning the estimator decision threshold given a metric

21 of 29

Tune your hyperparameters: which metric?

Unthresholded, probabilistic metric (e.g. log loss, ROC AUC),
computed on model.predict_proba(X_test)

vs.

Thresholded metric (e.g. accuracy, balanced accuracy),
computed on model.predict(X_test)
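Concretely (a sketch; model is assumed fitted, with a held-out X_test/y_test):

from sklearn.metrics import balanced_accuracy_score, log_loss

# Thresholded: scores the hard class predictions.
balanced_accuracy_score(y_test, model.predict(X_test))

# Unthresholded: scores the probability estimates directly.
log_loss(y_test, model.predict_proba(X_test))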

22 of 29

Tuning hyperparameters on the probabilistic metric
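For instance, a randomized search scored with log loss (a sketch; the parameter grid is hypothetical and depends on the pipeline's step names):

from sklearn.model_selection import RandomizedSearchCV

search = RandomizedSearchCV(
    model,
    param_distributions={"randomforestclassifier__max_depth": [3, 5, 10, None]},
    scoring="neg_log_loss",  # probabilistic metric, evaluated on predict_proba
    n_iter=10,
)
search.fit(data, target)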

23 of 29

Effect of hyper-parameter tuning on final metric

Before tuning

24 of 29

Effect of hyper-parameter tuning on final metric

After tuning

25 of 29

Optimum decision threshold

Cost-sensitive learning [1]

Gain matrix per transaction (as encoded in the business metric on the next slide):

              Fraudulent      Legitimate
  Refused     50€ + amount    -5€
  Accepted    -amount         0.02 × amount

Credit card frauds example

# Fetch the credit card fraud dataset; keep the transaction amount
# aside as metadata for the business metric.
credit_card = fetch_openml(data_id=1597, as_frame=True)
columns_to_drop = ["Class", "Amount"]
data = credit_card.frame.drop(columns=columns_to_drop)
target = credit_card.frame["Class"].astype(int)
amount = credit_card.frame["Amount"].to_numpy()
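The next slides use data_train/data_test, target_train/target_test and amount_train/amount_test; a minimal split producing them (the stratify and random_state choices are assumptions):

from sklearn.model_selection import train_test_split

data_train, data_test, target_train, target_test, amount_train, amount_test = (
    train_test_split(data, target, amount, stratify=target, random_state=0)
)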

26 of 29

Cost-sensitive learning

Metadata routing (SLEP006 [1])

import sklearn
from sklearn.metrics import make_scorer

def business_metric(y_true, y_pred, amount):
    # y_true == 1: fraudulent transaction; y_pred == 1: transaction refused.
    mask_tp = (y_true == 1) & (y_pred == 1)
    mask_tn = (y_true == 0) & (y_pred == 0)
    mask_fp = (y_true == 0) & (y_pred == 1)
    mask_fn = (y_true == 1) & (y_pred == 0)
    fraudulent_refuse = (mask_tp.sum() * 50) + amount[mask_tp].sum()
    fraudulent_accept = -amount[mask_fn].sum()
    legitimate_refuse = mask_fp.sum() * -5
    legitimate_accept = (amount[mask_tn] * 0.02).sum()
    return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept

# Opt in to metadata routing so `amount` can flow to the scorer.
sklearn.set_config(enable_metadata_routing=True)
business_scorer = make_scorer(business_metric).set_score_request(amount=True)

# The scorer signature is (estimator, X, y, **metadata); `model` must be fitted.
business_scorer(model, data, target, amount=amount)

27 of 29

Cost-sensitive learning

Metadata routing (SLEP006 [1])

tuned_model = TunedThresholdClassifier(
    estimator=model, objective_metric=business_scorer
)
tuned_model.fit(data_train, target_train, amount=amount_train)

# Evaluate with the scorer (not the raw metric function): the scorer
# calls the estimator itself and receives `amount` through routing.
business_score = business_scorer(
    tuned_model, data_test, target_test, amount=amount_test
)
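As a usage note, with the released TunedThresholdClassifierCV the selected cutoff is exposed after fitting:

# Decision threshold chosen during fit, replacing the default 0.5.
print(tuned_model.best_threshold_)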


28 of 29

Cost-sensitive learning

Metadata routing (SLEP006 [1])

# `params` routes `amount` to the business scorer inside cross-validation.
cv_results_tuned_model = cross_validate(
    model, data, target, params={"amount": amount}, scoring=business_scorer
)


29 of 29

Conclusion

Take-away

  • Stop resampling to fight class imbalance
  • Tune the hyperparameters of your predictive model with an unthresholded, probabilistic metric
  • Tune the decision threshold of your predictive model with the metric used for decision-making
  • Always cross-validate your predictive model