Dealing with imbalanced datasets

Real world datasets are often imbalanced - some of the classes appear much more often in your data than others.
The problem? Your ML model will likely learn to only predict the dominant classes.
What can you do about it?
Thread

Example 
We will be dealing with an ML model to detect traffic lights for a self-driving car.

Traffic lights are small, so most of the image will contain no traffic lights at all.
Furthermore, yellow lights are much rarer than green or red ones.

The problem 
Imagine we train a model to classify the color of the traffic light. A typical distribution will be:
- Red: 56%
- Yellow: 3%
- Green: 41%
So, your model can get to 97% accuracy just by learning to distinguish red from green (56% + 41% = 97%) and never predicting yellow.
How can we deal with this?
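To see the trap in numbers, here is a minimal sketch (with simulated labels following the 56/3/41 split above) of a lazy model that never predicts yellow:

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulated ground truth following the 56% red / 3% yellow / 41% green split
y_true = rng.choice(["red", "yellow", "green"], size=10_000, p=[0.56, 0.03, 0.41])

# A lazy model: perfect on red vs. green, but it never predicts yellow
y_pred = np.where(y_true == "yellow",
                  rng.choice(["red", "green"], size=y_true.shape),
                  y_true)

print((y_true == y_pred).mean())  # ~0.97 accuracy without ever saying "yellow"
```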


Evaluation measures 
First, you need to start using different evaluation measures than accuracy:
- Precision per class
- Recall per class
- F1 score per class
I also like to look at the confusion matrix to get an overview. Always look at examples from the data as well!

In the traffic lights example above, we will see very poor recall for yellow (most real yellow lights were not recognized), while its precision will likely be high.
At the same time, the precision of red and green will be lower (yellow lights will be misclassified as red or green).
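Here is a minimal sketch of how to compute these, assuming scikit-learn is available (the labels below are toy data for illustration only):

```python
from sklearn.metrics import classification_report, confusion_matrix

# Toy data: the model never predicts yellow
y_true = ["red", "red", "green", "yellow", "green", "red", "yellow", "green"]
y_pred = ["red", "red", "green", "red",    "green", "red", "green",  "green"]
labels = ["red", "yellow", "green"]

# Per-class precision, recall and F1 - far more informative than accuracy
print(classification_report(y_true, y_pred, labels=labels, zero_division=0))

# Rows are true classes, columns are predicted classes
print(confusion_matrix(y_true, y_pred, labels=labels))
```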

Get more data 
The best thing you can do is to collect more data for the underrepresented classes. This may be hard or even impossible...
You can imagine ways to record more yellow lights, but what if you want to detect a very rare disease in CT images?

Balance your data 
The idea is to resample your dataset so it is better balanced.
Undersampling - throw away some examples of the dominant classes
Oversampling - get more samples of the underrepresented classes

Undersampling 
The easiest way is to just randomly throw away samples from the dominant class.
Even better, you can use some unsupervised clustering method and throw out only samples from the big clusters.
The problem, of course, is that you are throwing away valuable data...
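A minimal sketch of random undersampling with NumPy (the clustering variant would replace the random choice with cluster-aware sampling; the imbalanced-learn library also ships a ready-made RandomUnderSampler):

```python
import numpy as np

def undersample(X, y, seed=0):
    """Randomly drop samples so every class is as frequent as the rarest one."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    n_min = counts.min()  # size of the rarest class
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=n_min, replace=False)
        for c in classes
    ])
    return X[keep], y[keep]
```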

Oversampling 
This is more difficult. You can simply repeat samples, but that won't work very well.
You can use methods like SMOTE (Synthetic Minority Oversampling Technique) to generate new samples by interpolating between existing ones. This may not be easy for complex data like images.
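A minimal sketch using the imbalanced-learn library, which provides a ready-made SMOTE implementation (assuming imbalanced-learn is installed):

```python
import numpy as np
from imblearn.over_sampling import SMOTE

# Toy imbalanced dataset: 100 majority samples vs. 10 minority samples
rng = np.random.default_rng(0)
X = rng.normal(size=(110, 2))
y = np.array([0] * 100 + [1] * 10)

# SMOTE synthesizes new minority samples by interpolating between neighbors
X_res, y_res = SMOTE(random_state=0).fit_resample(X, y)
print(np.bincount(y_res))  # both classes now have 100 samples
```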

Data augmentation
If you are dealing with images, you can use data augmentation techniques to create new samples by modifying the existing ones (rotation, flipping, skewing, color filters...).
You can also use GANs or simulation to synthesize completely new images.
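A minimal sketch with torchvision (one common choice; TensorFlow has equivalent preprocessing layers). The idea is to apply random transforms to the rare-class images to generate extra training samples:

```python
from torchvision import transforms

# Random modifications applied on the fly, so each epoch sees
# slightly different versions of the same underlying images
augment = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=0, shear=10),           # skewing
    transforms.ColorJitter(brightness=0.2, contrast=0.2),   # color filters
    transforms.ToTensor(),
])
```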

Adapting your loss 
Another strategy is to modify your loss function to penalize misclassification of the underrepresented classes more than the dominant ones.
In the traffic light example, we can set the class weights inversely proportional to the class frequencies:
- Red: 1.8 (= 100 / 56)
- Yellow: 33.3 (= 100 / 3)
- Green: 2.4 (= 100 / 41)

If you are training a neural network with TensorFlow or PyTorch, you can do this very easily:
TensorFlow - use the class_weight parameter in the fit() function ( https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model)
PyTorch - use the weight parameter in CrossEntropyLoss ( https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
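A minimal sketch of both options with the weights from the traffic light example (the class index order red, yellow, green is an assumption here):

```python
import torch

# TensorFlow / Keras: pass a dict mapping class index -> weight to fit()
class_weight = {0: 1.8, 1: 33.3, 2: 2.4}  # red, yellow, green (assumed order)
# model.fit(X_train, y_train, class_weight=class_weight)

# PyTorch: pass a weight tensor when constructing the loss
loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.8, 33.3, 2.4]))
```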


Summary 
In practice, you will likely need to combine all of the strategies above to achieve good performance.
Look at different evaluation metrics and start playing with the parameters to find a good balance (pun intended).
