1 of 56

Data Engineering (AI5308/AI4005)

Apr 6: Training Data: Data Imbalance and Data Augmentation (Ch. 4)

Sundong Kim

Course website: https://sundong.kim/courses/dataeng23sp/
Contents from CS 329S (Chip Huyen, 2022) | cs329s.stanford.edu

2 of 56

Eugene Yan’s page


3 of 56

Pop Quiz Results

Task: You want to build a model to classify whether a tweet spreads misinformation.


4 of 56

Question 1

Suppose you receive a continuous stream of tweets whose total number is unknown, and you don’t have enough memory to store all of them. How can you sample 10 million tweets such that each tweet has an equal probability of being chosen?
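A standard answer is reservoir sampling, which maintains a uniform k-element sample over a stream of unknown length in one pass and O(k) memory. A minimal sketch (function name and scale are illustrative; in practice k would be 10 million):

```python
import random

def reservoir_sample(stream, k):
    """Keep a uniform random sample of k items from a stream of unknown
    length (Algorithm R): one pass, O(k) memory."""
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)       # fill the reservoir first
        else:
            j = random.randint(0, i)     # pick a slot in [0, i]
            if j < k:                    # replace with probability k / (i + 1)
                reservoir[j] = item
    return reservoir
```

Each item ends up in the final sample with probability exactly k/n, which is what "equal probability of being chosen" requires.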


5 of 56

Question 2

Now you have a set of 10 million tweets, and they come from 10,000 different users over a period of 24 months. However, all the tweets are unlabeled, and you want to label a portion of them to train a classifier. How would you select a sample of 100,000 tweets to label?

  • # tweets/user follows a long-tail distribution
  • You estimate 1% of tweets are misinformation
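There is no single right answer; one option is to stratify the labeling budget across users so that a few prolific accounts in the long tail's head don't dominate the labeled set. A toy sketch (the tweet layout and helper name are hypothetical, and a real scheme would also stratify by time and oversample suspected misinformation):

```python
import random
from collections import defaultdict

def stratified_sample_by_user(tweets, n):
    """Spread a labeling budget of n tweets roughly evenly across users."""
    by_user = defaultdict(list)
    for t in tweets:
        by_user[t["user"]].append(t)
    per_user = max(1, n // len(by_user))     # equal budget per user
    sample = []
    for user_tweets in by_user.values():
        k = min(per_user, len(user_tweets))
        sample.extend(random.sample(user_tweets, k))
    return sample[:n]
```

Under a long-tail distribution this guarantees coverage of rare users at the cost of under-representing heavy posters, which is often the right trade-off for labeling.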


6 of 56

Question 3

You have 100K labels from 20 annotators and need to estimate their quality. What is the appropriate number of labels to examine, and how should they be sampled?


7 of 56

Ch. 4: Training Data


  • Sampling
  • Labeling
  • Class Imbalance
  • Data Augmentation

8 of 56

Class Imbalance


9 of 56

Class imbalance is the norm

  • Fraud detection
  • Spam detection
  • Disease screening
  • Churn prediction
  • Resume screening
    • E.g. 2% of resumes pass screening
  • Object detection
    • Most bounding boxes don’t contain any object


People are more interested in unusual/potentially catastrophic events

Image from PyImageSearch

10 of 56

Class Imbalance


11 of 56

Why is class imbalance hard?

  • Not enough signal to learn about rare classes


12 of 56

Why is class imbalance hard?

  • Not enough signal to learn about rare classes
  • Statistically, predicting majority label has higher chance of being right
    • If the majority class accounts for 99% of the data, always predicting it gives 99% accuracy


13 of 56

Why is class imbalance hard?

  • Not enough signal to learn about rare classes
  • Statistically, predicting majority label has higher chance of being right
  • Asymmetric cost of errors: different cost of wrong predictions


14 of 56

Why is class imbalance hard?

  • Not enough signal to learn about rare classes
  • Statistically, predicting majority label has higher chance of being right
  • Asymmetric cost of errors: different cost of wrong predictions


15 of 56

Asymmetric cost of errors: regression

  • 95th percentile: $10K
  • Median: $250


Thanks Eugene Yan for this example!

16 of 56

Asymmetric cost of errors: regression

Same 100% relative error, very different absolute consequences:

  • $10K bill: off by $10K → not OK
  • $250 bill: off by $250 → OK

Thanks Eugene Yan for this example!

17 of 56

How to deal with class imbalance

  1. Choose the right metrics
  2. Data-level methods
  3. Algorithm-level methods


18 of 56

  1. Choose the right metrics

Model A vs. Model B confusion matrices

Model A

                   Actual CANCER   Actual NORMAL
Predicted CANCER        10              10
Predicted NORMAL        90             890

Model B

                   Actual CANCER   Actual NORMAL
Predicted CANCER        90              90
Predicted NORMAL        10             810

Poll: Which model would you choose?

19 of 56

Choose the right metrics

Model A vs. Model B confusion matrices

Model A

                   Actual CANCER   Actual NORMAL
Predicted CANCER        10              10
Predicted NORMAL        90             890

Model B

                   Actual CANCER   Actual NORMAL
Predicted CANCER        90              90
Predicted NORMAL        10             810

Model B has a better chance of telling if you have cancer

Both have the same accuracy: 90%

20 of 56

Symmetric metrics vs. asymmetric metrics

  • TP: True positives
  • TN: True negatives
  • FP: False positives
  • FN: False negatives

Symmetric metrics (treat all classes the same): accuracy

Asymmetric metrics (measure a model’s performance w.r.t. a class): F1, recall, precision, ROC

21 of 56

Class imbalance: asymmetric metrics

  • Your model’s performance w.r.t. a class (here, precision, recall, and F1 take CANCER as the positive class)

             CANCER (1)   NORMAL (0)   Accuracy   Precision   Recall   F1
Model A        10/100      890/900       0.9         0.5        0.1    0.17
Model B        90/100      810/900       0.9         0.5        0.9    0.64

Model A

                   Actual CANCER   Actual NORMAL
Predicted CANCER        10              10
Predicted NORMAL        90             890

Model B

                   Actual CANCER   Actual NORMAL
Predicted CANCER        90              90
Predicted NORMAL        10             810

22 of 56

Class imbalance: asymmetric metrics

  • Your model’s performance w.r.t. a class

             CANCER (1)   NORMAL (0)   Accuracy   Precision   Recall   F1
Model A        10/100      890/900       0.9         0.5        0.1    0.17
Model B        90/100      810/900       0.9         0.5        0.9    0.64

⚠ F1 score for CANCER as 1 is different from F1 score for NORMAL as 1 ⚠
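The precision, recall, and F1 values in the table can be reproduced directly from the confusion-matrix counts; a minimal sketch:

```python
def prf(tp, fp, fn):
    """Precision, recall, F1 for one class from its confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Model A, CANCER as the positive class: TP=10, FP=10, FN=90
p_a, r_a, f_a = prf(10, 10, 90)   # precision 0.5, recall 0.1, F1 ≈ 0.17
# Model B, CANCER as the positive class: TP=90, FP=90, FN=10
p_b, r_b, f_b = prf(90, 90, 10)   # precision 0.5, recall 0.9, F1 ≈ 0.64
```

Swapping which class counts as positive changes TP/FP/FN, which is exactly why the F1 score with CANCER as 1 differs from the F1 score with NORMAL as 1.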

23 of 56

2. Data-level methods: Resampling

  • Undersampling: remove samples from the majority class
  • Oversampling: add more examples to the minority class

https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets#t1

24 of 56

2. Data-level methods: Resampling

  • Undersampling: remove samples from the majority class → can cause loss of information
  • Oversampling: add more examples to the minority class → can cause overfitting

https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets#t1

25 of 56

Undersampling: Tomek Links

  • Find pairs of close samples of opposite classes
  • Remove the sample of majority class in each pair
    • Pros: Make decision boundary more clear
    • Cons: Make model less robust
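Finding Tomek links can be sketched on 1-D points (illustrative only; a real implementation works on feature vectors, e.g. imbalanced-learn's `TomekLinks`):

```python
def tomek_links(points, labels):
    """Tomek links (1-D sketch): pairs of opposite-class samples that are
    each other's nearest neighbor. Undersampling then removes the
    majority-class member of each pair."""
    def nearest(i):
        return min((j for j in range(len(points)) if j != i),
                   key=lambda j: abs(points[i] - points[j]))
    links = []
    for i in range(len(points)):
        j = nearest(i)
        if labels[i] != labels[j] and nearest(j) == i and (j, i) not in links:
            links.append((i, j))
    return links
```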


Image from https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets

26 of 56

Oversampling: SMOTE

  • Synthesize new minority-class samples as convex (linear) combinations of existing minority-class points and their nearest neighbors of the same class.


Image from Analytics Vidhya

27 of 56

Oversampling: SMOTE

  • Synthesize new minority-class samples as convex (linear) combinations of existing minority-class points and their nearest neighbors of the same class.


Image from Analytics Vidhya

Both SMOTE and Tomek links only work on low-dimensional data!
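The convex-combination step can be sketched as follows (1-D and simplified; real SMOTE works on feature vectors with a proper k-NN search):

```python
import random

def smote_like(minority, k=2, n_new=4):
    """SMOTE-style synthesis (sketch): each synthetic point lies on the
    segment between a minority sample and one of its k nearest
    same-class neighbors."""
    synthetic = []
    for _ in range(n_new):
        x = random.choice(minority)
        # nearest same-class neighbors, excluding x itself
        neighbors = sorted(minority, key=lambda z: abs(z - x))[1:k + 1]
        nb = random.choice(neighbors)
        lam = random.random()                  # convex-combination weight
        synthetic.append(x + lam * (nb - x))
    return synthetic
```

Because every synthetic point is a convex combination, it stays inside the region spanned by the existing minority samples.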

28 of 56

3. Algorithm-level methods

  • Naive loss: all samples contribute equally to the loss
  • Idea: training samples we care about should contribute more to the loss


29 of 56

3. Algorithm-level methods

  • Cost-sensitive learning
  • Class-balanced loss
  • Focal loss


30 of 56

Cost-sensitive learning

  • Cij: the cost of classifying an instance of class i as class j

  • The loss caused by an instance x of class i becomes the average over all possible classifications of x, weighted by the corresponding costs Cij.
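In code, the expected-cost idea looks like this (a sketch; the cost matrix and class indices below are illustrative):

```python
def cost_sensitive_loss(probs, true_class, C):
    """Expected misclassification cost of one instance: the model's
    predicted distribution `probs` weighted by the cost-matrix row of
    the true class. C[i][j] = cost of predicting class j for a class-i
    instance."""
    return sum(C[true_class][j] * p for j, p in enumerate(probs))

# Hypothetical fraud example: missing fraud (class 1) costs 10x a false alarm.
C = [[0, 1],    # true class 0 (normal): predicting 1 costs 1
     [10, 0]]   # true class 1 (fraud):  predicting 0 costs 10
```

With this C, a model that puts 80% of its mass on "normal" for a fraud instance incurs a loss of 8.0, while the same mistake in the other direction costs only 0.2.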


31 of 56

Class-balanced loss

  • Give more weight to rare classes


Non-weighted loss vs. weighted loss: in the weighted version, each sample’s loss term is scaled by the weight assigned to its class.

model.fit(features, labels, epochs=10, batch_size=32, class_weight={0: 0.1, 1: 0.9})  # Keras expects integer class indices, e.g. 0 = normal, 1 = fraud
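One common way to pick such weights (an assumption, not the only scheme) is inverse class frequency, normalized to sum to 1:

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Weight each class inversely to its frequency in the training labels,
    normalized so the weights sum to 1."""
    counts = Counter(labels)
    inv = {c: 1.0 / n for c, n in counts.items()}
    total = sum(inv.values())
    return {c: w / total for c, w in inv.items()}
```

For a 90/10 split this yields weights of 0.1 for the majority class and 0.9 for the minority class, matching the fraud example above.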

32 of 56

Focal loss

  • Give more weight to difficult samples:
    • downweights the loss of well-classified samples
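Focal loss for a single example can be sketched as follows (per Lin et al., 2017; alpha and gamma are tunable hyperparameters):

```python
import math

def focal_loss(p_t, gamma=2.0, alpha=1.0):
    """Focal loss for one example. p_t is the predicted probability of the
    true class; the (1 - p_t)^gamma factor shrinks the loss of
    well-classified examples (p_t near 1) so training focuses on hard ones.
    With gamma = 0 it reduces to plain cross-entropy."""
    return -alpha * (1 - p_t) ** gamma * math.log(p_t)
```

A confidently correct prediction (p_t = 0.9) contributes far less loss than a badly wrong one (p_t = 0.1), and the gap widens as gamma grows.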


33 of 56

Focal loss


34 of 56

Data Augmentation


“Data augmentation is the new feature engineering”

- Josh Wills, prev Director of Data Engineering @ Slack

35 of 56

Data augmentation: Goals

  • Improve model’s performance overall or on certain classes
  • Generalize better
  • Enforce certain behaviors


36 of 56

Data augmentation

  1. Simple label-preserving transformation
  2. Perturbation
  3. Data synthesis


37 of 56

Label-preserving: Computer Vision

Random cropping, flipping, erasing, etc.


38 of 56

Label-preserving: NLP


Original sentences

I’m so happy to see you.

Generated sentences

I’m so glad to see you.

I’m so happy to see y’all.

I’m very happy to see you.
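Synonym replacement like the examples above can be sketched with a tiny hand-made synonym table (illustrative only; real systems use a thesaurus such as WordNet or a language model):

```python
import random

# Tiny hand-written synonym table (illustrative, not a real thesaurus).
SYNONYMS = {"happy": ["glad", "pleased"], "see": ["meet"]}

def synonym_replace(sentence):
    """Label-preserving augmentation: swap words for synonyms where we have one."""
    return " ".join(random.choice(SYNONYMS[w]) if w in SYNONYMS else w
                    for w in sentence.split())
```

Because the substitutions preserve meaning, the sentiment/misinformation label of the original sentence is assumed to carry over unchanged.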

39 of 56

Perturbation: Neural networks can be sensitive to noise

  • 67.97% of Kaggle CIFAR-10 test images and 16.04% of ImageNet test images can be misclassified by changing just one pixel (Su et al., 2017)


40 of 56

Perturbation: Computer Vision

  • Random noise
  • Search strategy
    • DeepFool (Moosavi-Dezfooli et al., 2016): find the minimal noise injection needed to cause a misclassification with high confidence.
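The random-noise variant is almost a one-liner (a sketch; assumes pixel values normalized to [0, 1], and sigma is illustrative):

```python
import random

def perturb(pixels, sigma=0.01):
    """Random-noise perturbation: add small Gaussian noise to each pixel
    and clamp back into [0, 1]; the label is assumed unchanged."""
    return [min(1.0, max(0.0, p + random.gauss(0.0, sigma))) for p in pixels]
```

Search strategies like DeepFool replace the random draw with an optimization that finds the smallest noise that flips the prediction.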


[Figure: a whale image plus near-imperceptible noise, found either by DeepFool or by the fast gradient sign method, is misclassified as a turtle]

41 of 56

Perturbation: NLP

  • Random replacement
    • e.g. BERT selects 15% of tokens for masking and replaces 10% of those with random tokens, so 10% × 15% = 1.5% of all tokens are randomly replaced


42 of 56

Data Synthesis: NLP

  • Template-based
    • Very common in conversational AI
  • Language model-based


Template

Find me a [CUISINE] restaurant within [NUMBER] miles of [LOCATION].

Generated queries

  • Find me a Vietnamese restaurant within 2 miles of my office.
  • Find me a Thai restaurant within 5 miles of my home.
  • Find me a Mexican restaurant within 3 miles of Google headquarters.
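Template filling is a cross-product over slot values; a sketch (the slot vocabularies below are illustrative, and real systems would use much larger ones):

```python
import itertools

TEMPLATE = "Find me a {cuisine} restaurant within {number} miles of {location}."
SLOTS = {
    "cuisine": ["Vietnamese", "Thai", "Mexican"],
    "number": ["2", "5"],
    "location": ["my office", "my home"],
}

def fill_templates(template, slots):
    """Generate every combination of slot values for one template."""
    keys = list(slots)
    return [template.format(**dict(zip(keys, combo)))
            for combo in itertools.product(*(slots[k] for k in keys))]

queries = fill_templates(TEMPLATE, SLOTS)   # 3 * 2 * 2 = 12 queries
```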

43 of 56

Data Synthesis: Computer Vision

  • Mixup
    • Create convex combinations of samples of different classes
      • Labels: cat [3], dog [4]
      • Mixup: 30% dog, 70% cat [0.3 × 4 + 0.7 × 3 = 3.3]
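Mixup itself is only a few lines (sketched on plain lists; drawing the mixing weight from Beta(alpha, alpha) follows the original recipe of Zhang et al., 2018):

```python
import random

def mixup(x1, y1, x2, y2, alpha=0.2):
    """Mixup: a new sample is a convex combination of two inputs, and its
    label mixes with the same weight. lam ~ Beta(alpha, alpha)."""
    lam = random.betavariate(alpha, alpha)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = lam * y1 + (1 - lam) * y2
    return x, y
```

With small alpha, lam concentrates near 0 or 1, so most mixed samples stay close to one of the originals.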


44 of 56

Data Synthesis: Computer Vision

  • Mixup
    • Incentivize models to learn linear relationships
    • Improves generalization on speech and tabular data
    • Can be used to stabilize the training of GANs


https://forums.fast.ai/t/mixup-data-augmentation/22764

45 of 56

Data Augmentation: GAN

Example: kidney segmentation with data augmentation by CycleGAN


46 of 56

Data Augmentation: Tabular Dataset


  1. Limited data: Tabular datasets often contain a limited number of examples, making it difficult to augment the data without introducing too much noise.

  2. High dimensionality: Tabular data can have many features with complex interactions, so it is challenging to find meaningful augmentations that do not introduce unrealistic or irrelevant information.

  3. Categorical features: Tabular data often contains categorical features, such as gender or occupation, which cannot easily be interpolated or perturbed the way numerical values can.

47 of 56

Data Augmentation: Tabular Dataset


48 of 56


49 of 56


50 of 56

Customs Import Declaration Datasets → CTGAN + Maintaining Correlations

50

51 of 56

WCO BACUDA Conference 2022


52 of 56

Group of 4, 20 minutes

  • Imagine you’re building an ML model to imitate how humans reason through an intelligence-test problem. What augmentation strategy would you use?


53 of 56


54 of 56


A survey on Image Data Augmentation for Deep Learning (Connor Shorten & Taghi M. Khoshgoftaar, 2019)

55 of 56

MLOps at Naver Shopping (Apr 6, 16:00, S7, 2F)

Video: will be uploaded
Slides: will be uploaded

56 of 56

Data Engineering

Next class: Feature Engineering (Ch.5)