Overfitting in Deep Neural Networks & how to prevent it. | Analytics Vidhya
The Perfect Fit for a DNN.
The primary objective in deep learning is to have a network that performs well on both the training data & the test data/new data it hasn’t seen before. When a network overfits or underfits, this objective is not achieved. Overfitting & underfitting are common occurrences while training a deep neural network.
Deep neural networks aim to learn & generalize the patterns found in the training data so that they perform similarly on test data or new data. Although this is the ideal scenario, it is easier said than done, especially for networks like Convolutional Neural Networks, Recurrent Neural Networks etc., which consist of millions or billions of tunable parameters & are therefore vulnerable to overfitting.
What is overfitting & why does it occur?
Overfitting occurs when the network learns too many details of the training data, including its noise, which results in poor performance on unseen or test data. When this happens, the network fails to generalize the features/patterns found in the training data.
Overfitting during training can be spotted when the error on training data decreases to a very small value but the error on the new data or test data increases to a large value. The error vs iteration graph shows how a deep neural network overfits on training data. The blue curve indicates the error on training data & the red curve the error on test data. The point where the green line intersects is the instance the network begins to overfit. As you can see, the error on test data increases sharply while error on training data decreases.
The above fig illustrates how the model/network (just a simple linear regression model) tries to accommodate every single data point in the training set. The model will perform poorly on a new set of data points because it follows all the training points too closely, including those that are noise & outliers. The error on the training points is very small, but the error on new data points will be high.
One of the main reasons for a network to overfit is a small training dataset. When the network learns from a small dataset, it tends to have greater control over the dataset & will try to satisfy all the data points exactly. It can be thought of as the network memorizing every single data point & failing to capture the general trend in the data.
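The divergence between training & test error described in this section can be checked numerically. The sketch below is only an illustration: the error histories are made-up numbers, & the helper names `overfitting_onset` and `generalization_gap` are hypothetical, not from any library.

```python
def overfitting_onset(train_errors, test_errors):
    """Return the index of the lowest test error, i.e. the last iteration
    before the test error starts rising again."""
    if len(train_errors) != len(test_errors) or not test_errors:
        raise ValueError("error histories must be non-empty and equally long")
    return min(range(len(test_errors)), key=test_errors.__getitem__)


def generalization_gap(train_errors, test_errors):
    """Per-iteration difference between test error and training error."""
    return [te - tr for tr, te in zip(train_errors, test_errors)]


# Made-up error histories: training error keeps falling,
# test error bottoms out and then climbs.
train = [0.90, 0.60, 0.40, 0.25, 0.15, 0.08, 0.04]
test = [0.95, 0.70, 0.52, 0.45, 0.50, 0.60, 0.75]

onset = overfitting_onset(train, test)
gaps = generalization_gap(train, test)
print("test error is lowest at iteration", onset)
print("gap at onset:", round(gaps[onset], 2), "gap at end:", round(gaps[-1], 2))
```

A gap that keeps widening after the onset iteration is the numerical counterpart of the two curves moving apart.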
Underfitting
Underfitting happens when the network can model neither the training data nor the test data, which results in poor performance overall. In the graph on the left, the model doesn’t cover all the data points & has a high error on both training & test data.
Underfitting can be caused by the limited capacity of the network, a limited number of features provided as input to the network, noisy data etc. Underfitting is not as widely discussed because it is easy to detect, & the remedies are to try a different machine learning algorithm, provide more capacity to the deep neural network, remove noise from the input data, increase the training time etc.
Measures to prevent overfitting
1. Decrease the network complexity
Deep neural networks like CNNs are prone to overfitting because of the millions or billions of parameters they contain. A model with this many parameters can overfit on the training data because it has sufficient capacity to do so.
By removing certain layers or decreasing the number of neurons (filters in CNN) the network becomes less prone to overfitting as the neurons contributing to overfitting are removed or deactivated. The network also has a reduced number of parameters because of which it cannot memorize all the data points & will be forced to generalize.
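To get a feel for how much removing layers or neurons reduces the parameter count, here is a minimal sketch for fully connected layers only. The two architectures are hypothetical examples chosen for illustration, & `count_dense_parameters` is a helper written here, not a framework function.

```python
def count_dense_parameters(layer_sizes):
    """Number of weights and biases in a fully connected network.

    layer_sizes lists the width of every layer, input first, output last.
    """
    total = 0
    for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total


# Hypothetical architectures for a 784-input, 10-class problem.
large = [784, 512, 512, 10]
small = [784, 64, 10]

print("large network:", count_dense_parameters(large))
print("small network:", count_dense_parameters(small))
```

Under these assumptions the smaller network has roughly 13 times fewer parameters, which leaves it far less room to memorize individual data points.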
There is no general rule as to how many layers are to be removed or how many neurons a layer may have before the network overfits. Popular approaches for reducing the network complexity are:
- Grid search can be applied to find out the number of neurons and/or layers to reduce or remove overfitting.
- The overfit model can be pruned (trimmed) by removing nodes or connections until it reaches suitable performance on test data.
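The grid search idea from the list above can be sketched as follows. Note the assumptions: `validation_error` is a stand-in with an invented formula so that the example runs on its own; in practice it would train a network of the given size & return its error on held-out data.

```python
from itertools import product


def validation_error(num_layers, num_neurons):
    """Stand-in for 'train a network of this size and return its validation
    error'. The formula is invented purely so the sketch runs: it penalizes
    networks that are too small (underfit) or too large (overfit)."""
    capacity = num_layers * num_neurons
    return abs(capacity - 128) / 128


def grid_search(layer_options, neuron_options):
    """Try every (layers, neurons) combination and keep the best one."""
    best_config, best_error = None, float("inf")
    for config in product(layer_options, neuron_options):
        error = validation_error(*config)
        if error < best_error:  # ties keep the first configuration found
            best_config, best_error = config, error
    return best_config, best_error


best_config, best_error = grid_search([1, 2, 3], [16, 32, 64, 128])
print("best (layers, neurons):", best_config, "validation error:", best_error)
```

The cost of a real grid search grows with the number of combinations, since every combination means one full training run.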
2. Data Augmentation
One of the best strategies to avoid overfitting is to increase the size of the training dataset. As discussed, when the size of the training data is small the network tends to have greater control over the training data. But in real-world scenarios gathering large amounts of data is a tedious & time-consuming task, so collecting new data is often not a viable option.
Data augmentation provides techniques to increase the size of existing training data without any external addition. If our training data consists of images, image augmentation techniques like rotation, horizontal & vertical flipping, translation, increasing or decreasing the brightness or adding noise, cutouts etc can be applied to the existing training images to increase the number of instances.
By applying the above-mentioned data augmentation strategies, the network is trained on multiple instances of the same class of object in different perspectives. An augmented result of a lion’s photograph will have an instance of a lion being viewed in a rotated manner, a lion being viewed up-side-down or cutting out the portion of an image which encloses the mane of a lion. By applying the last augmentation (cutout) the network learns to associate the feature that male lions have a mane with its class.
Performance before & after Data Augmentation.
The concept behind data augmentation is that by increasing the size of the training dataset, the network is unable to overfit on all the input samples (original images + augmented images) & thus is forced to generalize. The overall training loss increases because the network doesn’t predict as accurately on the augmented images, & the optimizer (optimization algorithm) then tunes the network to capture the generalized trend in the training data.
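The image transformations mentioned in this section can be sketched in a few lines of plain Python. This is only an illustration on a tiny nested-list "image"; the function names are made up here, & real pipelines would normally use an image or deep learning library instead.

```python
import random


def horizontal_flip(image):
    """Mirror each row left to right."""
    return [row[::-1] for row in image]


def vertical_flip(image):
    """Reverse the order of the rows."""
    return [row[:] for row in image[::-1]]


def rotate_90(image):
    """Rotate the image 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def cutout(image, top, left, size, fill=0):
    """Blank out a size x size square starting at (top, left)."""
    result = [row[:] for row in image]
    for r in range(top, min(top + size, len(result))):
        for c in range(left, min(left + size, len(result[r]))):
            result[r][c] = fill
    return result


def add_noise(image, scale, rng):
    """Add uniform noise in [-scale, scale] to every pixel."""
    return [[pixel + rng.uniform(-scale, scale) for pixel in row] for row in image]


# A tiny 3x3 grayscale "image" used only for illustration.
image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]

rng = random.Random(0)  # fixed seed so the noise is repeatable
augmented = [
    horizontal_flip(image),
    vertical_flip(image),
    rotate_90(image),
    cutout(image, top=0, left=0, size=2),
    add_noise(image, scale=0.1, rng=rng),
]
print(len(augmented), "augmented copies from 1 original")
```

Each function returns a new image & leaves the original untouched, so the original & all augmented copies can be added to the training set together.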
3. Weight Regularization
Weight regularization is a technique which aims to stabilize an overfitted network by penalizing large weight values in the network. An overfitted network usually has large weight values, so a small change in the input can lead to large changes in the output. As a result, when the network is given new or test data, it makes incorrect predictions.
Weight regularization penalizes the network’s large weights & forces the optimization algorithm to reduce them to smaller values, which makes the network more stable & improves its performance. In weight regularization, the network configuration remains unchanged; only the values of the weights are modified.
How does it work?
Weight regularization reduces overfitting by adding a penalty, or constraint, to the loss function. Regularization terms are constraints the optimization algorithm (like Stochastic Gradient Descent) must adhere to when minimizing the loss function, in addition to minimizing the error between the predicted value & the actual value.
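In the usual textbook formulation, the two most common regularized losses can be written as below. The notation is an assumption made here for illustration: $E(y, \hat{y})$ stands for the error between actual & predicted targets, $w_i$ for the individual weights of the network & $\lambda$ for the strength of the penalty.

```latex
\begin{aligned}
\text{L1:}\quad & \text{Loss} = E(y, \hat{y}) + \lambda \sum_{i} |w_i| \\
\text{L2:}\quad & \text{Loss} = E(y, \hat{y}) + \lambda \sum_{i} w_i^{2}
\end{aligned}
```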
The above two equations represent two types of weight regularization, L1 & L2. Each equation has two parts: the first part is the error between the actual target & the predicted target (loss function), & the second part is the weight penalty or regularization term.
Without the regularization term
Without the regularization term, the overall loss of the network is the output value of the loss function. As discussed, when the network overfits on training data, the error between predicted & the actual value is very small.
If the training error is very small, then the error gradient is also very small. Then the change in weights is very small as
new-weight = old-weight - lr*(error gradient)
As the updated weight values are close to the old weight values, the network remains in an overfit condition.
With the regularization term
Weight update equation. (SGD)
By adding a weight penalty to the loss function, the overall loss/cost of the network increases. The optimizer is now forced to shrink the weights of the network, as they contribute directly to the overall loss.
The penalty also adds its own term to the gradient wrt each weight, so the total gradient, & hence the change in the weight update, is larger than it would be from the error alone. Without the weight penalty, the gradient value remains very small & thus the change in weights also remains small.
With this larger gradient, large weight values are reduced to smaller values by the weight update rule. Larger weights result in a larger penalty to the loss function, thus pushing the network towards smaller & stabilized weight values.
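The effect of the penalty on a single weight update can be sketched as follows. This assumes an L2 penalty of the form `weight_decay * sum(w ** 2)`, whose gradient with respect to each weight is `2 * weight_decay * w`; the weights, gradients & hyperparameter values are made-up numbers for illustration.

```python
def sgd_step(weights, error_gradients, lr, weight_decay=0.0):
    """One SGD update. With an L2 penalty of weight_decay * sum(w ** 2),
    the penalty contributes 2 * weight_decay * w to each weight's gradient."""
    return [
        w - lr * (g + 2.0 * weight_decay * w)
        for w, g in zip(weights, error_gradients)
    ]


weights = [4.0, -3.0, 0.5]
# Made-up, near-zero error gradients, as on an already well-fit training set.
error_gradients = [0.001, -0.002, 0.0005]

plain = sgd_step(weights, error_gradients, lr=0.1)
decayed = sgd_step(weights, error_gradients, lr=0.1, weight_decay=0.05)

print("without penalty:", plain)
print("with L2 penalty:", decayed)
```

Without the penalty the weights barely move; with it, every weight is pulled towards zero by an amount proportional to its own size, so the largest weights shrink the most.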
L1 regularization adds the sum of the absolute values of the weights in the network as the weight penalty. L2 regularization adds the sum of the squared values of the weights as the weight penalty.
The lambda term is a hyperparameter which controls the influence of the weight penalty on the loss function, i.e. how strongly the network’s weights are reflected in the loss.
If the data is too complex, L2 regularization is a better choice as it can model the inherent pattern in the data. If the data is simple, L1 regularization can be used. For most computer vision tasks, L2 regularization, aka weight decay, is applied.
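The two penalties & the role of lambda can be written out directly. In this minimal sketch `lam` stands for lambda, & the weights & the data loss value are made-up numbers; the function names are hypothetical.

```python
def l1_penalty(weights, lam):
    """lambda times the sum of absolute weight values."""
    return lam * sum(abs(w) for w in weights)


def l2_penalty(weights, lam):
    """lambda times the sum of squared weight values."""
    return lam * sum(w * w for w in weights)


def regularized_loss(data_loss, weights, lam, kind="l2"):
    """Total loss = error between prediction and target + weight penalty."""
    penalty = l1_penalty if kind == "l1" else l2_penalty
    return data_loss + penalty(weights, lam)


weights = [0.5, -2.0, 3.0]
data_loss = 0.02  # made-up error between predicted and actual values

for lam in (0.0, 0.01, 0.1):
    print(
        "lambda =", lam,
        "L1 loss =", round(regularized_loss(data_loss, weights, lam, "l1"), 4),
        "L2 loss =", round(regularized_loss(data_loss, weights, lam, "l2"), 4),
    )
```

With lambda set to 0 the penalty vanishes & only the data loss remains; larger values let the weights dominate the total loss.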
4. Dropouts
Dropout is a regularization strategy that prevents deep neural networks from overfitting. While L1 & L2 regularization reduce overfitting by modifying the loss function, dropouts, on the other hand, deactivate a certain number of neurons at a layer from firing during training.
At each iteration a different set of neurons is deactivated, & this results in a different set of outputs. Many deep learning frameworks implement dropout as a layer which receives inputs from the previous layer & randomly selects neurons whose outputs are not passed on to the next layer. By deactivating certain neurons which might contribute to overfitting, the performance of the network on test data improves.
Dropouts reduce overfitting in a variety of problems like image classification, image segmentation, word embedding etc.
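The behaviour of a dropout layer can be sketched in plain Python. One assumption goes beyond the description above: the surviving activations are rescaled by `1 / keep_prob` (so-called inverted dropout), which is the variant commonly used in frameworks so that nothing needs to change at test time. The function & variable names are illustrative only.

```python
import random


def dropout(activations, drop_prob, rng, training=True):
    """Inverted dropout: zero each activation with probability drop_prob and
    rescale the survivors so the expected output stays the same."""
    if not 0.0 <= drop_prob < 1.0:
        raise ValueError("drop_prob must be in [0, 1)")
    if not training or drop_prob == 0.0:
        return list(activations)  # at test time every neuron fires
    keep_prob = 1.0 - drop_prob
    return [
        a / keep_prob if rng.random() < keep_prob else 0.0
        for a in activations
    ]


rng = random.Random(42)  # fixed seed so runs are repeatable
layer_output = [0.2, 1.5, 0.7, 0.9, 1.1, 0.3]

print(dropout(layer_output, drop_prob=0.5, rng=rng))  # one random mask
print(dropout(layer_output, drop_prob=0.5, rng=rng))  # a freshly drawn mask
print(dropout(layer_output, drop_prob=0.5, rng=rng, training=False))
```

Because a new mask is drawn on every call, each training iteration effectively trains a different thinned version of the network.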
5. Early Stopping
While training a neural network using an optimization algorithm like Gradient Descent, the model parameters (weights) are updated to reduce the training error. After each forward & backward pass, the network parameters are updated to reduce the error in the next iteration.
Too much training can result in the network overfitting on the training data. Early stopping provides guidance as to how many iterations can be run before the network begins to overfit.
The above graph indicates the point after which the network begins to overfit. The network parameters at the point of early termination are the best fit for the model. The test error can be decreased beyond the point of early termination by:
- Decreasing the learning rate. Applying a learning rate scheduler algorithm would be recommended.
- Applying a different optimization algorithm.
- Applying L1 or L2 regularization.
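The early stopping rule described in this section can be sketched as a small helper. Two things here are assumptions rather than part of the description above: the error is measured on a separate validation set, & a `patience` parameter allows a few non-improving epochs before stopping. The error values are made up.

```python
def early_stopping_epoch(val_errors, patience=2):
    """Return (best_epoch, stop_epoch) for a sequence of validation errors.

    Training stops once the validation error has failed to improve for
    `patience` consecutive epochs; best_epoch is where weights would be kept.
    """
    best_error = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, error in enumerate(val_errors):
        if error < best_error:
            best_error, best_epoch = error, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_errors) - 1  # never triggered: ran to the end


# Made-up validation errors recorded after each epoch.
val_errors = [0.80, 0.61, 0.50, 0.46, 0.48, 0.47, 0.55, 0.70]
best_epoch, stop_epoch = early_stopping_epoch(val_errors, patience=2)
print("keep weights from epoch", best_epoch, "and stop at epoch", stop_epoch)
```

In a real training loop the same bookkeeping would run after every epoch, with the weights saved whenever the validation error improves.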
Conclusion
Applying machine learning algorithms to a dataset directly will not necessarily yield the desired results; the resulting model may overfit or underfit the training data.
This blog provides insight into how to identify overfitting & measures to apply to reduce overfitting & improve overall performance on test data.