On different interpretations of class weighting and the notion of unbalanced classes
Class weighting is often described as a simple and powerful technique for classification problems with unbalanced classes. But how can it be interpreted in a probabilistic sense? We are going to answer this using a binary classification problem with the logistic regression model as an example. However, the more general question of this article is why unbalanced classes are an issue in the first place. Is imbalance a problem by itself, or is it a consequence of some underlying property of the data? We will figure this out over the course of the narrative.
Logistic regression with unbalanced classes
Imagine we have a binary classification problem. The logistic regression model defines the probability of class \(y = 1\) given an object's features \(x\) as
\[
p(y = 1 \mid w, x) = \sigma(\langle w, x \rangle) = \frac{1}{1 + \exp(-\langle w, x \rangle)},
\]
where \(\langle w, x \rangle\) is the scalar product and \(\sigma(\cdot)\) is the sigmoid function. The vector \(w\) is the model's parameter. It defines the separating hyperplane \(\langle w, x \rangle = 0\) in the space of \(x\). The further an object lies from this hyperplane into one half-space or the other, the more probable label \(1\) or label \(0\) becomes for it; on the hyperplane itself both labels are equally probable. This is illustrated in the figure below, where \(x = \begin{pmatrix} x_1 \\ x_2 \end{pmatrix} \in \mathbb{R}^2\) and \(w = \begin{pmatrix} w_1 \\ w_2 \end{pmatrix} = \begin{pmatrix} 0 \\ 1 \end{pmatrix}\). The color intensity is proportional to \(p(y = 1 \mid w, x)\).
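A minimal pure-Python sketch of the model with the parameter \(w = (0, 1)^\top\) used in the figure. The helper names `sigmoid` and `p_class1` are ours, chosen for illustration. With this \(w\) the probability depends only on \(x_2\), and it equals \(0.5\) on the line \(x_2 = 0\).

```python
import math

def sigmoid(t: float) -> float:
    """Logistic function sigma(t) = 1 / (1 + exp(-t)), written to avoid overflow."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    z = math.exp(t)
    return z / (1.0 + z)

def p_class1(w, x) -> float:
    """p(y = 1 | w, x) = sigma(<w, x>) for the logistic regression model."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))

w = (0.0, 1.0)  # separating line is x_2 = 0; class 1 is more probable above it
for x in [(-3.0, 0.0), (2.0, 0.0), (0.0, 1.0), (0.0, 3.0), (5.0, -3.0)]:
    print(x, round(p_class1(w, x), 3))
```

Points on the line get exactly \(0.5\) whatever their first coordinate; moving up to \(x_2 = 1\) and \(x_2 = 3\) gives about \(0.73\) and \(0.95\).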
Now imagine that we don't know the true \(w\) and are given the following dataset.
We can see that all objects are quite far from the (true) separation line, and each of them received the label that is more probable under the model. We also see that there are far fewer points in the upper half-plane than in the lower one. Consequently, the dataset has unbalanced classes: class \(1\) is significantly under-represented relative to class \(0\).
Since we know the true \(w\), we understand that this imbalance does not arise from the model \(p(y \mid w, x)\) itself but rather from the positions of the objects in the dataset. Had the objects been chosen closer to the separation line, the dataset would far more likely be balanced, because the class probabilities would be much closer to \(0.5\).
But in practice we know nothing about the true \(w\) or about how the objects were chosen: they were simply given. Should we worry about unbalanced classes? To begin with, let's just find the most likely model for this dataset, i.e. the maximum likelihood estimate of \(w\).
We see that the estimate is not very close to the truth, even though the sample should be large enough to learn such a simple model. It seems that the class imbalance is causing the problem.
We can deal with this issue using a technique called class weighting. Before a full explanation, let's just see its magic.
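To make the role of object positions concrete, here is a small sketch that computes the expected share of class-1 labels when a fixed fraction of objects lies above the line and the rest below, all at the same distance \(d\) from it. The 10%/90% split and the distances are illustrative assumptions, not the article's actual dataset; the true parameter is taken to be \(w = (0, 1)^\top\) as before.

```python
import math

def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))

def expected_class1_share(frac_above: float, d: float) -> float:
    """Expected fraction of label 1 under p(y=1|w,x) = sigma(<w,x>) with w = (0, 1),
    when a share `frac_above` of objects sits at x_2 = +d and the rest at x_2 = -d."""
    return frac_above * sigmoid(d) + (1.0 - frac_above) * sigmoid(-d)

for d in [0.05, 0.5, 1.0, 3.0, 5.0]:
    print(f"d = {d:4.2f}: expected share of class 1 = {expected_class1_share(0.1, d):.3f}")
```

With objects almost on the line the expected share is about \(0.49\), i.e. nearly balanced, although only 10% of the objects are above the line; at \(d = 3\) it drops to about \(0.14\), close to that 10%.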
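In general, the technique multiplies the loss term of every object by a custom weight of its class, and the model is then estimated by minimising the modified loss. In the case of logistic regression, the regular loss is the negative log-likelihood:
\[
L(w) = -\sum_{i=1}^{n} \Big[ y_i \log \sigma(\langle w, x_i \rangle) + (1 - y_i) \log \big(1 - \sigma(\langle w, x_i \rangle)\big) \Big],
\]
where \(n\) is the number of objects and \(y_i \in \{0, 1\}\).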
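If we weight class 1 with \(\omega_1 > 0\) and class 0 with \(\omega_0 > 0\), the modified loss is
\[
L_{\text{cw}}(w) = -\sum_{i=1}^{n} \Big[ \omega_1 y_i \log \sigma(\langle w, x_i \rangle) + \omega_0 (1 - y_i) \log \big(1 - \sigma(\langle w, x_i \rangle)\big) \Big].
\]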
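A minimal pure-Python sketch of \(L_{\text{cw}}\) and of its minimisation by plain gradient descent, for a model without intercept as in the article. The function names, step size and iteration count are our own illustrative choices, not a reference implementation. The gradient used is \(\sum_i c_i \, (\sigma(\langle w, x_i \rangle) - y_i) \, x_i\), where \(c_i\) is the weight of the class of object \(i\); dividing it by the total weight keeps the step size independent of the weights.

```python
import math

def softplus(z):
    """log(1 + e^z), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t)) if t >= 0 else math.exp(t) / (1.0 + math.exp(t))

def dot(w, x):
    return sum(wj * xj for wj, xj in zip(w, x))

def weighted_nll(w, X, y, omega1=1.0, omega0=1.0):
    """L_cw(w); omega1 = omega0 = 1 gives the plain negative log-likelihood L(w)."""
    loss = 0.0
    for xi, yi in zip(X, y):
        t = dot(w, xi)
        if yi == 1:
            loss += omega1 * softplus(-t)  # -log sigma(t)
        else:
            loss += omega0 * softplus(t)   # -log(1 - sigma(t))
    return loss

def fit_weighted_logreg(X, y, omega1=1.0, omega0=1.0, lr=0.5, n_iter=2000):
    """Minimise L_cw by gradient descent (no intercept, no regularisation)."""
    total_weight = sum(omega1 if yi == 1 else omega0 for yi in y)
    w = [0.0] * len(X[0])
    for _ in range(n_iter):
        grad = [0.0] * len(w)
        for xi, yi in zip(X, y):
            c = omega1 if yi == 1 else omega0
            r = c * (sigmoid(dot(w, xi)) - yi)
            for j, xij in enumerate(xi):
                grad[j] += r * xij
        w = [wj - lr * gj / total_weight for wj, gj in zip(w, grad)]
    return w

# Made-up, non-separable toy data; class 1 is the minority (3 of 8 objects).
X = [(0.5, 1.0), (-1.0, 0.8), (1.0, -0.3), (-0.5, -1.0),
     (0.3, -1.2), (-1.2, -0.6), (0.8, -0.9), (0.2, 0.4)]
y = [1, 1, 1, 0, 0, 0, 0, 0]

w_mle = fit_weighted_logreg(X, y)
w_cw = fit_weighted_logreg(X, y, omega1=2.0)
print("MLE:", w_mle)
print("class-weighted (omega = 2):", w_cw)
```

With `omega1=2.0` the fit coincides, up to rounding, with the unweighted fit on a copy of `X` in which each class-1 object appears twice, which is the replication view of integer class weights.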
Without loss of generality, we set \(\omega_0 = 1\) and \(\omega_1 = \omega\). Then \(\omega\) indicates how many times more "important" class 1 is to us than class 0. Let's look at the change in the loss from several perspectives.
- From the optimisation point of view, terms with a higher weight now contribute more to the loss. To compensate, the optimiser has to assign a higher probability to the more heavily weighted class, since that is what shrinks its \(-\log p\) terms.
- Assume \(\omega \ge 2\) is an integer. Then multiplying the losses of class-1 objects by \(\omega\) is the same as including every such object \(\omega\) times, i.e. increasing the sample size of the weighted class \(\omega\)-fold. Therefore, we directly address the class imbalance by balancing the class frequencies in the dataset. For any positive real \(\omega\) the effect is essentially the same (note that if \(\omega < 1\), it is the counter-class whose relative size grows).
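- Changing the loss always changes the probability model of the data. With \(\omega_0 = 1\) and \(\omega_1 = \omega\), let's rewrite the new loss function:
\[
L_{\text{cw}}(w) = -\sum_{i=1}^{n} \Big[ y_i \log \sigma(\langle w, x_i \rangle)^{\omega} + (1 - y_i) \log \big(1 - \sigma(\langle w, x_i \rangle)\big) \Big].
\]
Now the values under the logarithms do not sum up to one; let's normalise them:
\[
\begin{aligned}
L_{\text{cw}}(w) &= -\sum_{i=1}^{n} \Big[ y_i \log \frac{\sigma(\langle w, x_i \rangle)^{\omega}}{Z_i} + (1 - y_i) \log \frac{1 - \sigma(\langle w, x_i \rangle)}{Z_i} \Big] - \sum_{i=1}^{n} \log Z_i \\
&= -\sum_{i=1}^{n} \Big[ y_i \log \tilde{p}(y = 1 \mid w, x_i) + (1 - y_i) \log \tilde{p}(y = 0 \mid w, x_i) \Big] - \sum_{i=1}^{n} \log Z_i \\
&= \tilde{L}(w) - \sum_{i=1}^{n} \log Z_i \;\ge\; \tilde{L}(w),
\end{aligned}
\]
where \(Z_i = \sigma(\langle w, x_i \rangle)^{\omega} + 1 - \sigma(\langle w, x_i \rangle)\).
We denoted the new class probabilities by \(\tilde{p}\). Normalising the former probabilities forced us to add the new term \(-\sum_i \log Z_i\). In our primary case of \(\omega > 1\) we have \(\sigma^{\omega} \le \sigma\), so \(Z_i \le 1\) and every \(\log Z_i\) is nonpositive; hence the last inequality is valid. What we got in the end is the negative log-likelihood \(\tilde{L}\) of a new model whose class probabilities are proportional to \(p(y=1 \mid w, x)^{\omega}\) and \(p(y=0 \mid w, x)\). Hence \(L_{\text{cw}}\) is an upper bound on \(\tilde{L}\), and minimising \(L_{\text{cw}}\) also gives us an estimate of the new model's optimal parameters. In general, the new model is not a logistic regression.
Overall, we can interpret the situation as follows: we had our initial logistic model, but then someone decreased the probability of class \(1\) and generated the labels with the new probabilities.
The decrease of the class \(1\) probability changes the separation line. We can find its new equation from
\[
\tilde{p}(y = 1 \mid w, x) = \tilde{p}(y = 0 \mid w, x) \iff \sigma(\langle w, x \rangle)^{\omega} = 1 - \sigma(\langle w, x \rangle).
\]
For \(\omega > 0\) this equation has a unique solution \(t^*\) for \(\langle w, x \rangle\): as \(\langle w, x \rangle\) runs over the real line, the left-hand side increases from \(0\) to \(1\) while the right-hand side decreases from \(1\) to \(0\). In general it cannot be solved analytically. Therefore, the new separation line \(\langle w, x \rangle = t^*\) is still a line, parallel to the original one. For \(\omega > 1\) we get \(t^* > 0\), so under \(\tilde{p}\) the region of class \(1\) shrinks. When \(w\) is fitted with the weighted loss, it is roughly this line \(\langle w, x \rangle = t^*\) that gets aligned with the class boundary in the data, so the fitted logistic line \(\langle w, x \rangle = 0\) ends up shifted towards the less weighted class, as we saw in the example above.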
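A small numerical sketch of the renormalised model for general weights \(\omega_1, \omega_0\): normalising the weighted terms gives \(\tilde{p}(y=1) \propto \sigma^{\omega_1}\) and \(\tilde{p}(y=0) \propto (1-\sigma)^{\omega_0}\), and the boundary \(t^*\) where the two are equal is found by bisection. The function names are ours, and the bisection bracket \([-50, 50]\) is an arbitrary assumption that comfortably contains the root for moderate weights.

```python
import math

def sigmoid(t):
    """sigma(t) = 1 / (1 + e^{-t}), evaluated without overflow."""
    return 1.0 / (1.0 + math.exp(-t)) if t >= 0 else math.exp(t) / (1.0 + math.exp(t))

def normaliser(t, omega1, omega0):
    """Z = sigma(t)^omega1 + (1 - sigma(t))^omega0, with 1 - sigma(t) written as sigma(-t)."""
    return sigmoid(t) ** omega1 + sigmoid(-t) ** omega0

def p_tilde_class1(t, omega1, omega0):
    """Class-1 probability of the renormalised model at t = <w, x>."""
    return sigmoid(t) ** omega1 / normaliser(t, omega1, omega0)

def boundary(omega1, omega0, lo=-50.0, hi=50.0, tol=1e-12):
    """Solve sigma(t)^omega1 = (1 - sigma(t))^omega0 for t by bisection.
    For positive weights the difference below is strictly increasing in t,
    so the root is unique."""
    def f(t):
        return sigmoid(t) ** omega1 - sigmoid(-t) ** omega0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if f(mid) < 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

for omega in [0.5, 1.0, 2.0, 5.0]:
    # class 1 weighted by omega, class 0 by 1
    print(f"omega_1 = {omega}, omega_0 = 1: t* = {boundary(omega, 1.0):.4f}")
```

With \(\omega_1 = 2\) and \(\omega_0 = 1\) the equation becomes \(\sigma^2 + \sigma - 1 = 0\), so \(\sigma(t^*) = (\sqrt{5} - 1)/2\) and \(t^* = \ln\big((1 + \sqrt{5})/2\big) \approx 0.481\); swapping the two weights flips the sign of \(t^*\).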
Info
In practice, the weights are commonly chosen proportional to the inverse frequencies of the classes. That is what class_weight='balanced' does in several scikit-learn estimators.
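A pure-Python sketch of the "balanced" heuristic, following the formula given in the scikit-learn documentation: a class with \(n_c\) objects gets weight \(n_{\text{samples}} / (n_{\text{classes}} \cdot n_c)\). The helper name `balanced_class_weights` is ours; in scikit-learn itself the same numbers come from `sklearn.utils.class_weight.compute_class_weight(class_weight="balanced", classes=..., y=...)`. The 90/10 labels are made up for illustration.

```python
from collections import Counter

def balanced_class_weights(y):
    """Weights inversely proportional to class frequencies:
    weight_c = n_samples / (n_classes * n_c)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n_c) for c, n_c in sorted(counts.items())}

y = [0] * 90 + [1] * 10
weights = balanced_class_weights(y)
print(weights)
# In the article's notation (omega_0 = 1) the relative weight is the ratio:
print("omega =", weights[1] / weights[0])
```

For 90 objects of class 0 and 10 of class 1 this gives weights of about \(0.56\) and \(5\), so \(\omega\) equals \(9\), the inverse ratio of the class frequencies, and each class contributes the same total weight of \(50\).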
Do class weights always help with unbalanced classes?
Consider the following example under the previous model.
This time several points of class 1 appear in the lower half-plane, which is possible under logistic regression. The classes are still unbalanced. Now let's find the maximum likelihood estimate (MLE) with and without class weights.
We see that class weights do not do a good job here.
Is class imbalance evil?
To be honest, we can always find datasets on which class weights give a better or a worse estimate of \(w\). The point of the demonstration is different: any method should be chosen according to one's assumptions and knowledge about the data. If you are given a dataset with unbalanced classes and have no idea why they are unbalanced, just use MLE and rely on its asymptotic properties. Using class weights in this situation becomes a gamble. But if you have some prior knowledge about how the data were generated, then incorporating it into the loss is sound. We have already figured out what such knowledge means in the case of class weights.
But the knowledge can be different. For example, along with the data I could tell you that \(\| w \|\) is small and that the first component is more likely to be close to zero than the second. This knowledge can easily be turned into regularising terms on \(w\) in the loss function. Or I could tell you that the objects \(x_i\) were chosen randomly, but with a very high probability of lying at some given distance \(d\) from the separation line (which is exactly what happened in our demonstration). This amounts to knowing \(p(x) = p(x \mid w, d)\). Then, assuming independent objects, the full log-likelihood of the dataset would be
\[
\log p(y_1, \dots, y_n, x_1, \dots, x_n \mid w, d) = \sum_{i=1}^{n} \log p(y_i \mid w, x_i) + \sum_{i=1}^{n} \log p(x_i \mid w, d),
\]
and again we would come up with a regularisation of the initial loss: the second sum adds a term that depends on \(w\).
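As a hedged illustration of how the second sum acts as a regulariser, here is a sketch of the negative full log-likelihood under an assumed model of ours: the distance \(|\langle w, x_i \rangle| / \|w\|\) to the line has density proportional to \(\exp\big(-(\text{dist} - d)^2 / (2 s^2)\big)\), and the position along the line carries no information about \(w\), so \(-\log p(x_i \mid w, d)\) equals \((\text{dist}_i - d)^2 / (2 s^2)\) up to a constant. The spread `s`, the toy data and all function names are hypothetical.

```python
import math

def dot(w, x):
    return sum(wj * xj for wj, xj in zip(w, x))

def softplus(z):
    """log(1 + e^z), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def nll(w, X, y):
    """-sum_i log p(y_i | w, x_i) for logistic regression without intercept."""
    return sum(softplus(-dot(w, xi)) if yi == 1 else softplus(dot(w, xi))
               for xi, yi in zip(X, y))

def distance_penalty(w, X, d, s):
    """-sum_i log p(x_i | w, d) up to a w-independent constant, under the ASSUMED
    distance model described above (hypothetical, for illustration only)."""
    norm = math.sqrt(dot(w, w))
    return sum((abs(dot(w, xi)) / norm - d) ** 2 for xi in X) / (2.0 * s ** 2)

def full_nll(w, X, y, d, s):
    """-log p(y, X | w, d) up to a constant: likelihood term plus position term."""
    return nll(w, X, y) + distance_penalty(w, X, d, s)

# Made-up toy data: every object lies at distance 2 from the line x_2 = 0.
X = [(-1.0, 2.0), (0.5, 2.0), (1.5, -2.0), (-2.0, -2.0), (0.0, -2.0)]
y = [1, 1, 0, 0, 0]
for w in [(0.0, 1.0), (0.3, 1.0), (0.0, 3.0)]:
    print(w, round(nll(w, X, y), 4), round(distance_penalty(w, X, d=2.0, s=0.1), 4))
```

Under these assumptions the position term vanishes for any \(w\) pointing along \((0, 1)\) and grows as the direction rotates away, regardless of the scale of \(w\). On data this well separated the likelihood term alone keeps decreasing as \(\|w\|\) grows, so a penalty on \(\|w\|\), like the first kind of knowledge above, would still be needed to pin down the scale.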
Conclusion
The issue of unbalanced classes does not exist by itself. Otherwise balanced classes would be an issue too, wouldn't they? Any pattern in the data should be examined in light of expectations and prior knowledge about the field you are analysing, or about anything else connected to how the original data were obtained. Only then is applying the associated techniques justified and productive.