HOW TO AVOID OVERFITTING YOUR MODEL
Overfitting is one of the most common problems in machine learning and data science. When a model performs poorly, the cause is usually either overfitting or underfitting the data. A good machine learning model generalizes well from the training data to any data from the problem domain. In this post, you will learn what overfitting is and how to deal with it.
The Goodness of fit. / Signal Vs Noise
What is overfitting in machine learning?
Overfitting describes a model that fits the training data too closely: it has low training error but high test error. It occurs when a model learns the noise in the training data instead of the underlying signal, often because the model is overly complex or uses too many features. This hurts the model’s ability to generalize, so accuracy on test data drops.
A Goodness of Fit in Machine Learning
Our goal is to select a model that sits between underfitting and overfitting, with low bias and low variance. In statistical terms, goodness of fit describes how well the actual (observed) data points match the values the model predicts.
Noise and Signal
Separating the signal from the noise is an important task for data scientists, because failing to do so causes performance issues, including overfitting. An algorithm can mistake noise for a pattern and start generalizing from it. Noise is the unwanted data items, features, or records that don’t help explain the relationship we are trying to model. We want to find the “signal” in the data rather than fit the noise.
Example of overfitting
Take Simple Linear Regression as an example. Training is all about minimizing the cost between the best-fit line and the data points. An iterative training algorithm goes through a number of iterations to find the optimum fit, and this is where overfitting comes into the picture. When we run the training algorithm on the data set, we allow the cost to decrease with each iteration. Running the algorithm for too long will mean a reduced cost, but it will also fit the noisy data in the data set.
Let’s understand overfitting with a real-world example.
Let’s say we want to predict whether a player will be selected for the team based on his current performance. Now imagine we train the model on 10,000 such players with known outcomes. When we predict the outcomes for that same data set we get 99% accuracy, but accuracy on the test set is only around 55%. This means the model has not generalized well from the training data to unseen data.
How to detect overfitting?
Overfitting can be detected when the model produces very good results on the training data but performs poorly on the test data. We can also suspect overfitting if the fitted curve looks far too complex for the data points when we plot the predictions.
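The train/test gap can be illustrated with a small sketch. It is not taken from the article: the toy data, the 20% label noise, and the helper names (make_data, predict_1nn, accuracy) are all assumptions chosen for illustration. A 1-nearest-neighbour model is used because it memorizes the training set, noise included.

```python
import random

def make_data(n, rng, noise=0.2):
    """Toy data (assumed): label is 1 when x > 0.5, flipped with probability `noise`."""
    data = []
    for _ in range(n):
        x = rng.random()
        label = int(x > 0.5)
        if rng.random() < noise:
            label = 1 - label  # label noise that the model should NOT learn
        data.append((x, label))
    return data

def predict_1nn(train, x):
    """1-nearest-neighbour: copy the label of the closest training point."""
    return min(train, key=lambda p: abs(p[0] - x))[1]

def accuracy(train, data):
    """Fraction of points in `data` that the model built on `train` gets right."""
    hits = sum(predict_1nn(train, x) == y for x, y in data)
    return hits / len(data)

rng = random.Random(0)
train = make_data(200, rng)
test = make_data(200, rng)

train_acc = accuracy(train, train)  # every point is its own nearest neighbour
test_acc = accuracy(train, test)    # unseen points expose the memorized noise
print(f"train accuracy: {train_acc:.2f}")
print(f"test accuracy:  {test_acc:.2f}")
print(f"gap:            {train_acc - test_acc:.2f}")
```

A large positive gap between the two numbers is the warning sign described above. How large a gap is acceptable depends on the problem.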
How to prevent Overfitting?
Below are some of the ways to prevent overfitting:
1. Hold back a validation dataset.
We can simply split our dataset into training and testing sets (the testing set acting as a validation dataset) instead of using all the data for training. A common split ratio is 80:20 for training and testing. We train our model until it performs well on both the training set and the testing set. This indicates good generalization, since the testing set represents unseen data that was not used for training.
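As a minimal sketch of the 80:20 split, assuming the dataset fits in a Python list, the function below shuffles the samples and holds back a fraction of them. The name train_test_split and its parameters are illustrative here, not a reference to any particular library's API.

```python
import random

def train_test_split(samples, test_ratio=0.2, seed=0):
    """Shuffle a copy of `samples` and hold back `test_ratio` of it for testing."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # fixed seed keeps the split reproducible
    n_test = int(round(len(shuffled) * test_ratio))
    return shuffled[n_test:], shuffled[:n_test]

samples = list(range(100))  # stand-in for 100 real records
train, test = train_test_split(samples)
print(len(train), len(test))
```

Shuffling before splitting matters when the records are stored in some meaningful order, for example sorted by date or by class.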
2. Training with more data
Splitting the dataset requires a dataset large enough to train on. Training with more data makes it easier for the model to detect the signal and minimize errors. As we feed more training data into the model, it becomes unable to overfit all the samples and is forced to generalize to obtain results.
3. Simplifying the Model
One way to prevent overfitting is to simplify the model. Decreasing the complexity of the model makes it less likely to overfit and also makes it run faster. Possible actions include pruning a decision tree, reducing the number of parameters in a neural network, and removing layers or reducing the number of neurons to make the network smaller.
4. Early Stopping
Early stopping is a form of regularization used when training a model with an iterative method, such as gradient descent. Such a method updates the model to fit the training data better with each iteration. After some number of iterations, the test error typically starts to increase while the training error is still decreasing. At that point we stop the training and save the model, where we get low bias and low variance. Early stopping rules provide guidance on how many iterations can be run before the model begins to overfit.
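One common early stopping rule is "patience": stop once the held-out error has failed to improve for a set number of iterations in a row. The sketch below assumes the held-out losses have already been recorded per epoch. The function name, the patience value, and the loss numbers are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=2):
    """Return (best_epoch, best_loss, stopped_epoch) for a sequence of held-out losses.

    Training stops after `patience` consecutive epochs without a new best loss.
    """
    best_loss = float("inf")
    best_epoch = -1
    waited = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, waited = loss, epoch, 0  # new best: save this model
        else:
            waited += 1
            if waited >= patience:
                return best_epoch, best_loss, epoch  # stop training here
    return best_epoch, best_loss, len(val_losses) - 1  # never triggered

# Hypothetical held-out losses: they improve, then rise as overfitting sets in.
losses = [0.90, 0.70, 0.60, 0.55, 0.57, 0.60, 0.65]
best_epoch, best_loss, stopped_epoch = early_stopping_epoch(losses)
print(best_epoch, best_loss, stopped_epoch)  # 3 0.55 5
```

In a real training loop, the model weights from the best epoch are the ones to keep, not the weights from the epoch at which training stopped.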
5. Data Augmentation
In the case of neural networks, data augmentation means increasing the size of the dataset by generating many similar variants of the images already present in it, which reduces overfitting. Some popular image augmentation techniques are flipping, translation, rotation, scaling, changing brightness, and adding noise.
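A few of these transformations can be sketched on a tiny grayscale "image" stored as a list of pixel rows. This representation and the function names are assumptions for illustration. Real pipelines usually work on arrays and apply the transformations at random during training.

```python
def flip_horizontal(image):
    """Mirror each row left to right."""
    return [row[::-1] for row in image]

def flip_vertical(image):
    """Reverse the order of the rows."""
    return image[::-1]

def rotate_90_clockwise(image):
    """Rotate the image a quarter turn clockwise."""
    return [list(col) for col in zip(*image[::-1])]

def augment(image):
    """Return the original image plus three transformed variants."""
    return [image, flip_horizontal(image), flip_vertical(image), rotate_90_clockwise(image)]

image = [[1, 2, 3],
         [4, 5, 6]]
variants = augment(image)
print(len(variants))  # 4 training examples from one original
```

Each variant keeps the label of the original image, so the choice of transformations has to respect the task: a flipped digit, for example, may no longer be the same digit.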
6. Ensembling
Ensembling is a machine learning technique that works by combining predictions from two or more separate models. The most popular ensembling methods include boosting and bagging.
Boosting works by using simple base models and increasing their aggregate complexity. It trains weak learners in sequence, each one focusing on the errors of the previous ones, and combines them into one strong learner. Examples: AdaBoost, Gradient Boosting, XGBoost.
Bagging (Bootstrap Aggregation) is the opposite of boosting. It trains a large number of strong learners in parallel, each on a bootstrap sample of the data, and then combines them to smooth out their predictions. Example: Random Forest.
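The bagging idea can be sketched in a few lines: draw bootstrap samples, fit one model per sample, and combine the predictions by majority vote. The sketch is illustrative only. The base model is a one-threshold "stump" chosen purely to keep the code short (a Random Forest would use full decision trees), and the toy data and all function names are assumptions.

```python
import random
from collections import Counter

def fit_stump(sample):
    """Pick the threshold t (from the sample's x values) with the fewest errors
    for the rule: predict 1 when x > t, else 0."""
    def errors(t):
        return sum(int(x > t) != y for x, y in sample)
    return min(sorted({x for x, _ in sample}), key=errors)

def bootstrap_sample(data, rng):
    """Draw len(data) points with replacement."""
    return [rng.choice(data) for _ in data]

def majority_vote(votes):
    """Return the most common prediction among the base models."""
    return Counter(votes).most_common(1)[0][0]

def fit_bagging(data, n_models=25, seed=0):
    """Fit one stump per bootstrap sample; the ensemble is the list of thresholds."""
    rng = random.Random(seed)
    return [fit_stump(bootstrap_sample(data, rng)) for _ in range(n_models)]

def predict_bagging(thresholds, x):
    """Let every stump vote and return the majority label."""
    return majority_vote([int(x > t) for t in thresholds])

# Toy data (assumed): label is 1 when x > 0.5, with two labels deliberately flipped.
data = [(i / 20, int(i / 20 > 0.5)) for i in range(21)]
data[4] = (data[4][0], 1)    # noisy point at x = 0.2
data[16] = (data[16][0], 0)  # noisy point at x = 0.8

ensemble = fit_bagging(data)
print(predict_bagging(ensemble, 0.05), predict_bagging(ensemble, 0.95))
```

Because each model sees a slightly different sample, the individual models disagree about the noisy points, and the vote averages that disagreement away.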
7. Regularization
Regularization is a technique to reduce the complexity of the model. The most common techniques, L1 and L2 regularization, do so by adding a penalty term to the loss function. L1 penalizes the absolute values of the weights and can drive some of them to exactly zero, so it is a good choice when only a few features are expected to matter. L2 penalizes the squared weights and shrinks all of them toward zero without removing any, so it is usually the better choice when many features each contribute a little. Dropout is another regularization technique, used to prevent neural networks from overfitting: it randomly drops neurons from the network in each training iteration.
8. Cross-validation
The idea of cross-validation is to use the initial training data to generate multiple mini train-test splits. The most popular resampling technique is k-fold cross-validation. It lets you train and test your model k times on different subsets of the training data and build up an estimate of how the model will perform on unseen data. We can then tune the model’s hyperparameters based on its performance on the test folds.
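To make the penalty term concrete, here is a minimal sketch for the simplest possible case: one feature, one weight w, no intercept. It assumes the L2 loss is sum((y - w*x)^2) + lam * w^2 and the L1 loss is sum((y - w*x)^2) + lam * |w|. Both have closed-form minimizers for a single weight. The function names and the data are made up for illustration.

```python
def ridge_weight(xs, ys, lam):
    """Minimizer of sum((y - w*x)^2) + lam * w^2 for a single weight w (L2 penalty)."""
    s_xy = sum(x * y for x, y in zip(xs, ys))
    s_xx = sum(x * x for x in xs)
    return s_xy / (s_xx + lam)

def lasso_weight(xs, ys, lam):
    """Minimizer of sum((y - w*x)^2) + lam * |w| for a single weight w (L1 penalty)."""
    s_xy = sum(x * y for x, y in zip(xs, ys))
    s_xx = sum(x * x for x in xs)
    shrunk = max(abs(s_xy) - lam / 2, 0.0)  # soft-thresholding step
    return (shrunk if s_xy >= 0 else -shrunk) / s_xx

xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # exactly y = 2x, so the unpenalized weight is 2

for lam in (0.0, 14.0, 56.0):
    print(lam, ridge_weight(xs, ys, lam), lasso_weight(xs, ys, lam))
```

With these numbers the L2 weight shrinks from 2.0 to 1.0 to 0.4 as lam grows but never reaches zero, while the L1 weight goes from 2.0 to 1.5 and then to exactly 0.0. That difference is why an L1 penalty can remove features entirely and an L2 penalty only makes their weights smaller.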
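The fold bookkeeping behind k-fold cross-validation can be sketched as follows. The generator only produces index lists. Training and scoring a model on each split is left out, and the function name and parameters are assumptions for illustration.

```python
import random

def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_indices, test_indices) pairs, one per fold."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]  # k roughly equal, disjoint folds
    for i in range(k):
        test_idx = folds[i]
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_idx, test_idx

splits = list(k_fold_indices(n_samples=10, k=5))
for train_idx, test_idx in splits:
    print(sorted(test_idx), "held out;", len(train_idx), "used for training")
```

Every sample lands in exactly one test fold, so averaging the k test scores uses all of the data for evaluation without ever scoring a model on a sample it was trained on.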
Conclusion
In this article, we tried to understand what overfitting is, how to identify it and how to prevent it.