Producing a valuable classification model from a heavily imbalanced training dataset is a common challenge in machine learning. In this article, I will present a simple but effective method ML engineers can use to build a useful classification model from an imbalanced dataset. I will also outline some of the most common mistakes ML engineers make when handling heavily imbalanced datasets.
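Machine learning engineers often encounter problems that require classifying an instance into one of two categories based on its available features. Common use cases include detecting spam emails, flagging fraudulent user activity, and determining whether a medical scan shows a tumor.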
These three examples have one thing in common: the positive class (spam, fraudulent, or tumor present) is often heavily underrepresented in the population. For example, far fewer than one percent of an application’s users will ever attempt fraudulent behavior; the vast majority engage only in normal day-to-day usage.
Example — Imbalanced Dataset Use Case at Weebly
Scenario: Imagine you’re a newly hired machine learning engineer with full access to the training dataset described below. Your first task is to develop a model that estimates, as accurately as possible, the probability that a user will pay, using two action-based features.
Primary Goal: Predict the probability that a recently created freemium account will become a subscriber to one of Weebly’s paid services in the near future.
Example Dataset: Let’s assume you’ve completed a quantitative analysis of the example dataset and learned that less than 2% of Weebly’s users are on paid plans. Figure 1 (below) shows what such an imbalanced dataset might look like. In this figure, paid users are shown as green markers, while users who remain on the free service are shown in red.
Each user (i.e., each data point) is embedded in a two-dimensional plane spanned by two potentially important features for predicting future conversion to a paid service: the total number of actions the user has taken and the time since the user’s last activity.
Only 1%, or 500 out of the 50,000 data points, are positive, resulting in a heavily imbalanced dataset.
NOTE: The model Weebly actually built takes into account a much more complex set of user properties and real data. The data shown in Figure 1 is an artificial projection for illustration purposes only; we will keep using this minimal projection in the following sections.
The Accuracy Trap
If you’re relatively new to machine learning, you may have only worked with well-balanced datasets in the past, and your first instinct might be to jump right in and throw a standard classifier at the training data without making any modifications.
When you train a random forest classifier on the data in Figure 1 and evaluate its performance on a test set, you find that the model is 99% accurate at predicting whether or not a user will convert. Excited, you immediately show the results to some of the model’s consumers. A bit later, they ask you for a confusion matrix to help them understand the types of errors your model makes. When you print the matrix out, you see the following:
You realize that the model never predicts that a user will convert anywhere in the test set of nearly 17,000 users. Instead, it takes every user, regardless of that user’s total number of actions and time since last activity, and lumps them into the ‘Not Converted’ category. Why did this happen? Doesn’t Figure 1 clearly show differences in the distributions of these two features between converters and non-converters?
The behavior of the model in the above scenario is a direct result of how its underlying algorithm is trained. Standard binary classifiers usually work best when the training set contains a balance of positive and negative instances. During training, the model tweaks internal parameters or generates decision boundaries to optimize overall predictive accuracy on the training set. Because each instance contributes equally to accuracy, and because negative instances far outnumber positive ones, the model finds it can maximize accuracy by simply outputting ‘Not Converted’ for every instance, regardless of its input features. In this way, all the algorithm ‘learns’ is that it’s never worth trying to identify positive instances as positive, because doing so risks misclassifying negative instances, which collectively contribute roughly 99 times as much to overall accuracy as positive instances.
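To make this concrete, the sketch below scores a hypothetical constant "classifier" that ignores its inputs and always answers ‘Not Converted’. It uses plain Python rather than any ML library; the label counts (169 converters, 16,497 non-converters) come from the test-set figures quoted in this section, and the helper name `confusion_counts` is illustrative, not from any library.

```python
from collections import Counter


def confusion_counts(y_true, y_pred):
    """Return (tp, fn, fp, tn) for binary labels where 1 = converted."""
    counts = Counter(zip(y_true, y_pred))  # missing pairs count as 0
    tp = counts[(1, 1)]
    fn = counts[(1, 0)]
    fp = counts[(0, 1)]
    tn = counts[(0, 0)]
    return tp, fn, fp, tn


# Labels mirroring the test set described above: 169 converters, 16,497 non-converters.
y_true = [1] * 169 + [0] * 16_497

# A "model" that ignores its inputs and always predicts 'Not Converted' (0).
y_pred = [0] * len(y_true)

tp, fn, fp, tn = confusion_counts(y_true, y_pred)
accuracy = (tp + tn) / len(y_true)
print(f"tp={tp} fn={fn} fp={fp} tn={tn}")
print(f"accuracy={accuracy:.4f}")
```

The constant predictor reaches roughly 99% accuracy while finding zero converters, which is exactly the solution a model optimizing overall accuracy drifts toward on this data.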
For any classification model, there is almost always a trade-off between identifying positive instances accurately, measured by ‘recall’ (also called sensitivity), and identifying negative instances accurately, measured by ‘specificity’. A model trained on imbalanced data will tend to favor the metric corresponding to the majority class. In the above example, the model’s recall ends up being 0/(0 + 169) = 0%, meaning that, when presented with positive instances, the model always fails to classify them as positive. On the other hand, its specificity is 16,497/(16,497 + 0) = 100%, meaning the model always correctly classifies a negative instance as negative. Now consider an idealized test set of 20,000 users, 200 of whom are converters. If the model traded a one-point increase in recall for a one-point decrease in specificity, the resulting accuracy would be (0.01 × 200 + 0.99 × 19,800)/20,000 = 19,604/20,000 = 98.02%, a drop of nearly one percentage point in overall accuracy.
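The arithmetic above can be checked with a few lines of Python. This is a minimal sketch with illustrative function names; "specificity" here means the true-negative rate, TN / (TN + FP).

```python
def recall(tp, fn):
    """Fraction of actual positives (converters) predicted as positive."""
    return tp / (tp + fn)


def specificity(tn, fp):
    """Fraction of actual negatives (non-converters) predicted as negative."""
    return tn / (tn + fp)


def accuracy_from_rates(n_pos, n_neg, rec, spec):
    """Overall accuracy implied by a recall/specificity pair on a given class split."""
    return (rec * n_pos + spec * n_neg) / (n_pos + n_neg)


# Counts for the always-negative model, as quoted in the paragraph above.
print(recall(tp=0, fn=169))          # 0.0
print(specificity(tn=16_497, fp=0))  # 1.0

# Trading one point of specificity for one point of recall on an idealized
# 20,000-user test set with 200 converters and 19,800 non-converters.
print(accuracy_from_rates(200, 19_800, rec=0.01, spec=0.99))
```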
If you’ve read this far and are still thinking, “Yeah, but 99% accuracy is pretty awesome,” consider that the model has provided no information beyond basic statistics. You don’t need ML engineers to divide the number of negative instances by the total number of instances (0.99) and report this as the probability of a user being a non-converter. Your analytics team did this on day one and has already been using this knowledge to inform various business decisions.
"The job of an ML engineer is to go a layer deeper."
A Quick Fix — Undersampling
Before your boss finds out that your first model adds no new value to the company, you need to train a new one that can provide some useful predictions.
In a frantic Google search for the words ‘classifier always outputs zero’, you find a Stack Overflow post suggesting that your binary classifier might perform better if you randomly undersample the majority class so that your training dataset contains equal numbers of positive and negative instances. To do this, you take the 500 converter instances in your full dataset and add 500 instances randomly sampled from the non-converters, forming a new training set. This balanced training set has 1,000 instances, exactly 50% of which belong to the positive class, as plotted in Figure 3.
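Random undersampling as described above can be sketched in plain Python. This is an illustration, not the code behind the article’s results: the row format (a list of dicts with a 0/1 `converted` field), the feature names in the usage example, and the function `undersample_majority` are all hypothetical.

```python
import random


def undersample_majority(rows, label_key="converted", seed=0):
    """Randomly drop majority-class rows until both classes have equal counts.

    `rows` is a list of dicts and `label_key` names a 0/1 label field
    (a hypothetical schema used only for this sketch).
    """
    rng = random.Random(seed)  # seeded so the draw is reproducible
    positives = [r for r in rows if r[label_key] == 1]
    negatives = [r for r in rows if r[label_key] == 0]
    minority, majority = sorted((positives, negatives), key=len)
    # Sample without replacement so no majority row appears twice.
    kept_majority = rng.sample(majority, len(minority))
    balanced = minority + kept_majority
    rng.shuffle(balanced)
    return balanced


# Usage: 500 converters among 50,000 users, with hypothetical feature values.
rows = [
    {"actions": i % 40, "days_since_active": i % 30, "converted": 1 if i < 500 else 0}
    for i in range(50_000)
]
balanced = undersample_majority(rows)
print(len(balanced), sum(r["converted"] for r in balanced))  # 1000 500
```

Only the training data should be rebalanced this way; evaluation should still happen on data with the original class mix.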
You train a second random forest classifier on this new dataset and, through cross-validation, discover that your model is 79% accurate at predicting whether or not a user will convert. This number doesn’t sound nearly as good as 99%, but at this point you’re happy with anything above 60%. You inspect the confusion matrix to see the distribution of errors:
You can tell from this matrix that the model is no longer lumping all instances into the ‘Not Converted’ category. Instead, it makes a balanced effort to identify both positive and negative instances, as evidenced by the nearly symmetric confusion matrix. You also notice that the recall has increased to 128/(36 + 128) = 78%, while the specificity has decreased to 135/(42 + 135) = 76%. As a final test, you measure the accuracy of your model on a hold-out dataset with the original, imbalanced class distribution. This seems wise: once in production, your model will see heavily imbalanced data quite unlike the balanced training set. The overall accuracy is around 76%, and one last time, you print out the confusion matrix:
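The move from the balanced cross-validation data to the imbalanced hold-out set changes accuracy because accuracy mixes the two per-class rates: with a positive rate p, accuracy = p × recall + (1 − p) × specificity. The sketch below assumes, as a simplification, that the rates measured on the balanced data (128 of 164 converters and 135 of 177 non-converters classified correctly) carry over unchanged to the hold-out set.

```python
def accuracy_at_prevalence(recall, specificity, p):
    """Expected accuracy when a fraction p of instances are positive,
    assuming recall and specificity stay fixed as the class mix changes."""
    return p * recall + (1 - p) * specificity


rec = 128 / (36 + 128)   # converters correctly identified on the balanced data
spec = 135 / (42 + 135)  # non-converters correctly identified on the balanced data

for p in (0.5, 0.01):
    acc = accuracy_at_prevalence(rec, spec, p)
    print(f"positive rate {p:.0%}: expected accuracy {acc:.1%}")
```

At a 1% positive rate the expected accuracy is about 76%, essentially the specificity, in line with the hold-out result: once positives are rare, accuracy mostly measures how well the model handles the majority class.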
From this matrix, you can see that even though your model now correctly picks out 414 of the 512 converting users, it incorrectly predicts that 10,491 non-converting users will convert. It seems like your model made a pretty bad trade-off: in the effort to correctly identify hundreds of users, it incorrectly identifies thousands! It’s important at this point not to be dismayed, but rather to step back and ask yourself, ‘What is the ultimate goal of my model?’ The answer to this question will inform the final step of working with your imbalanced dataset: setting the balance between correctly identifying positive cases and correctly identifying negative cases.
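One way to quantify this trade-off is precision, the fraction of users flagged as converters who actually convert. The article’s text doesn’t use this metric; it is added here as a standard companion to recall. A minimal sketch using the hold-out counts quoted above (the number of missed converters is inferred as 512 − 414):

```python
def precision(tp, fp):
    """Fraction of predicted converters who actually convert."""
    return tp / (tp + fp)


def recall(tp, fn):
    """Fraction of actual converters the model flags."""
    return tp / (tp + fn)


tp, fn, fp = 414, 512 - 414, 10_491  # hold-out counts quoted above
print(f"recall    = {recall(tp, fn):.1%}")     # share of converters the model finds
print(f"precision = {precision(tp, fp):.1%}")  # share of flagged users who convert
```

Recall is about 81%, but precision is under 4%: roughly 25 non-converters are flagged for every true converter found.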
The Inevitable Trade-off
When working on a classification problem, one can always make a trade-off between accurately predicting the positive class and accurately predicting the negative class. In the above example, undersampling the majority class made the model focus more on predicting positive instances correctly. It did this by effectively overrepresenting the positive class relative to the original sample. Undersampling is just one of many ways to shift this trade-off. How to set the balance depends on who will consume the model and for what purpose.
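Besides resampling, a common way to move along this trade-off is to change the decision threshold applied to a model’s predicted probabilities. The article doesn’t describe this technique; the sketch below illustrates it with made-up scores and labels (purely hypothetical, not Weebly data) and a hypothetical helper `rates_at_threshold`.

```python
def rates_at_threshold(scores, labels, threshold):
    """Recall and specificity when instances with score >= threshold are called positive."""
    pairs = list(zip(scores, labels))
    tp = sum(1 for s, y in pairs if y == 1 and s >= threshold)
    fn = sum(1 for s, y in pairs if y == 1 and s < threshold)
    tn = sum(1 for s, y in pairs if y == 0 and s < threshold)
    fp = sum(1 for s, y in pairs if y == 0 and s >= threshold)
    return tp / (tp + fn), tn / (tn + fp)


# Hypothetical predicted conversion probabilities and true labels (1 = converted).
scores = [0.05, 0.10, 0.20, 0.35, 0.40, 0.55, 0.60, 0.70, 0.85, 0.90]
labels = [0, 0, 0, 1, 0, 0, 1, 0, 1, 1]

for t in (0.3, 0.5, 0.8):
    rec, spec = rates_at_threshold(scores, labels, t)
    print(f"threshold {t}: recall {rec:.2f}, specificity {spec:.2f}")
```

Across these three thresholds recall falls from 1.00 to 0.50 while specificity rises from 0.50 to 1.00, so the threshold is a dial for the same trade-off that undersampling shifts.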
Imagine the consumers of your model suddenly tell you they want to use it to deliver discount coupons to motivate users to buy the product. They specifically mention that they don’t want to send offers to users who already seem interested in purchasing, and that they need a model that gives them a large set of users who are very unlikely to ever enter a paid service.
This is good news for you, because the second model you built using a balanced dataset is about 80% successful at picking out the users who will convert. As a result of your work, you can now provide your stakeholders with a list of nearly 40,000 users who are unlikely to purchase the product. In fact, only around 96/40,000 = 0.24% of these users will end up converting. Even though your model incorrectly identifies around 24% of non-converters as converters, the model consumers don’t care about this type of error. They don’t need to reach every possible non-converter, just a large group who are very unlikely to pay without external motivation. Had you delivered your initial model trained on the imbalanced dataset, it would have labeled every user a non-converter, so coupons would have been sent to all users, even the 1% who were going to pay anyway, leading to some pretty angry stakeholders!
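For the coupon use case, what matters is the purity of the ‘unlikely to convert’ list rather than overall accuracy. A minimal sketch, assuming the model outputs a conversion probability per user; the user IDs, probabilities, 0.5 cutoff, and function names are all hypothetical:

```python
def pick_coupon_targets(user_probs, cutoff=0.5):
    """Return IDs of users whose predicted conversion probability is below cutoff."""
    return [uid for uid, p in user_probs.items() if p < cutoff]


def conversion_rate(user_ids, converted):
    """Share of the given users who actually converted (requires outcome data)."""
    return sum(1 for uid in user_ids if uid in converted) / len(user_ids)


# Hypothetical model outputs and, after the fact, the set of users who converted.
user_probs = {"u1": 0.02, "u2": 0.81, "u3": 0.10, "u4": 0.45, "u5": 0.67, "u6": 0.05}
converted = {"u2", "u4"}

targets = pick_coupon_targets(user_probs)
print(targets)                              # ['u1', 'u3', 'u4', 'u6']
print(conversion_rate(targets, converted))  # 0.25
```

Lowering the cutoff (for example to 0.3) would shrink the list and, in this toy data, remove the one converter from it: the same size-versus-purity decision the stakeholders face at scale.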
Working with heavily imbalanced data presents additional challenges as compared to balanced data, but an understanding of a few straightforward concepts can help one avoid common mistakes.
A standard model trained on an unmodified imbalanced dataset will likely lead to a classifier that focuses almost entirely on getting the majority class correct. A simple way to push the model to more accurately identify the minority class is to undersample the majority class in the training set.
There will always be a trade-off between overall accuracy and the classifier’s ability to identify the minority class, and the balance between these two goals should be dictated by the priorities of the model’s use case.