August 23, 2022

How Does Data Augmentation Reduce Image Classification Overfitting?

Learn how data augmentation can reduce image classification overfitting & how VisionERA can help craft unique solutions for you.

Valuable information can be extracted from almost any type of data. Images containing objects, for example, can be used to teach machines to correctly identify an image's label or category. Image classification is one of the most popular techniques in computer vision, a field that aims to enable computers to extract meaningful information from digital media such as images and videos. This is essential in specialized applications requiring image recognition, such as

  • Self-driving cars for detecting obstacles or identifying signals.
  • Geospatial analysis to understand the geography of an area.
  • Medical imaging to understand the presence of a disease.
  • Media analytics for identifying objectionable content.
  • And many more.

Image recognition or Image classification is the basis for all the above applications. Hence, an image classification model with higher accuracy is desired for the success of these applications in real-world scenarios.

What Are The Issues Faced By Image Classification Models?

Image classification models are deep learning models with specified neural networks that can analyze and extract the important features of an image. Based on these attributes, the algorithm enables the model to classify the image into a specific category. In general, if we train the model with a sufficient variety and number of images, the model will be better at identifying the correct category.

However, such a diverse and huge amount of relevant data is hard to find, and acquiring it would be expensive and time-consuming even for companies that try. So these models frequently face challenges due to the scarcity of large datasets and a lack of adequately labeled data. Additionally, intra-class variation, differences in scale and illumination, and the presence of noise or clutter in the images hurt accuracy, as the model cannot correctly label such images. This is likely to occur when the deep learning model is fed real-world images, which can vary in brightness, contrast, zoom, or rotation. One way to ensure that a model can handle these differences in the same object is to train it on an expanded dataset of augmented images. The existing dataset can be enlarged using Data Augmentation techniques; the augmented images add much-needed variety, and the model's generalization can improve.

Concept Of Overfitting And Underfitting Of Machine Learning Models

Each machine learning model's primary goal is to generalize well. In this context, generalization refers to a Machine Learning (ML) model's ability to produce suitable outputs for inputs it has never seen before. In other words, after training on the dataset, the model is expected to produce reliable and accurate results on new data. However, when we evaluate model performance, two problematic scenarios are commonly encountered while training Machine Learning models: underfitting and overfitting. Either of these can degrade the performance of a machine learning model.

In simple terms, underfitting means the trained model performs poorly even on the training data, making many incorrect predictions because it has not captured the underlying patterns. Overfitting, on the other hand, occurs when the trained model fails to make accurate predictions on new data: the training accuracy is high, but the validation accuracy is poor. A high training accuracy indicates that the training error is very small, while a poor validation accuracy means the validation error is very large. Ideally, neither should be present in a model, although they are often challenging to remove. Underfitting is less common in ML models than overfitting, but it should not be neglected. Before delving into the details of overfitting, let us first understand some key terms used for evaluating model performance.

Key Terms To Understand Overfitting

Following are four important concepts essential to understanding the machine learning model performance.

  • Signal: This refers to the true underlying pattern in the data that the machine learning model is meant to learn.
  • Noise: Unnecessary and irrelevant data points that degrade an ML model's performance.
  • Bias: The prediction error introduced when an ML technique oversimplifies the underlying patterns in the data. A high-bias model fails to capture the signal and is likely to perform poorly on both the training and the test data.
  • Variance: The error introduced when a model learns the training data too closely, including its noise. A high-variance model performs well on the training dataset, i.e., has high accuracy, but fares poorly on the test dataset, producing too many incorrect predictions.

[Image: variance]

So, when a model appears to be overfitting, its 'bias' is very low and its 'variance' is high, whereas an underfitting model has high bias and low variance.

The "goodness of fit" is an optimum fit for the model that indicates the model performs well on unknown data. This can be achieved when variance and bias are kept as low as possible. In reality, there is always a slight trade-off between bias and variance that needs to be considered to make the model acceptable. But then, how do we know if a model is overfitting? Let us understand this in the next section.

How To Identify & Avoid Model Overfitting?

Overfitting can only be detected by evaluating a trained model on data it has not seen. The model therefore needs to be trained on a split dataset (generally an 80:20 split) with distinct training and testing sets. Next, the model's performance on the training and validation/test sets can be plotted at each epoch; this is often called a learning curve plot. If the model's performance on the training set is substantially better than on the test set, it is clearly overfitting the training data. Learning curve plots often show that performance on the training set continues to improve, i.e., loss continues to fall or accuracy continues to rise, whereas performance on the validation/test set improves only up to a certain point and then begins to degrade. Training should be stopped whenever such a pattern is observed in order to avoid overfitting. After understanding how to spot overfitting, let us explore some techniques to reduce it.
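
To make this concrete, here is a minimal plotting sketch. It assumes a Keras-style history object returned by model.fit with 'accuracy' and 'val_accuracy' recorded (the function name plot_learning_curves is our own, for illustration):

import matplotlib.pyplot as plt

def plot_learning_curves(history):
    # Plot training vs. validation accuracy per epoch from a Keras History object.
    epochs = range(1, len(history.history['accuracy']) + 1)
    plt.plot(epochs, history.history['accuracy'], label='Training accuracy')
    plt.plot(epochs, history.history['val_accuracy'], label='Validation accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

# A widening gap between the two curves (training keeps improving while
# validation stalls or degrades) is the classic signature of overfitting.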

Techniques To Reduce Model Overfitting

Almost all image classification models exhibit a tendency to overfit the training data. If a classification model appears to be overfitting, the following strategies can help reduce overfitting, improve generalization, and move the model toward the goodness of fit.

  1. Data Augmentation: Training with more data can reduce overfitting. Data Augmentation is a technique often used by Data Scientists to increase the size of an existing dataset by applying image transformations like flipping, cropping, and zooming. It is especially useful when the available dataset is small, lacks diversity, and acquiring new data would be expensive and time-consuming.
  2. L1 / L2 Regularization: Regularization prevents the neural network (DL model) from learning an overly complex model that could result in an overfit. L1 Regularization (Lasso regression) adds the "absolute value of magnitude" of the coefficients as a penalty term to the loss function, whereas L2 Regularization (Ridge regression) adds the "squared magnitude" of the coefficients as the penalty term. (A combined sketch of regularization, dropout, and early stopping appears after this list.)
  3. Feature Selection: To prevent the model from learning too many features and eventually being an overfit, it is essential to choose only the most crucial attributes for training when only a small number of training samples are available. Experimenting with different model training features can help assess the model generalization and reduce the overfit.
  4. Hold-out & Cross-validation: In the 'Hold-out' strategy, the dataset is split into separate subsets: one subset is used for model training, and the remaining subsets are held out for validating and testing the model. In the cross-validation strategy, the dataset is divided into k groups called folds; the model is trained and evaluated k times so that each fold is used as the testing set exactly once, with the remaining folds used for training. Note that cross-validation is more computationally expensive than hold-out, since a model must be trained for every fold, but it does eventually allow all the data to be used for training, unlike hold-out. (A k-fold sketch appears at the end of this section.)
  5. Simplifying the Model: Overfitting can arise from a model's complexity; even with a large dataset, a sufficiently complex model can manage to overfit the training data. Reducing the model's complexity until it is simple enough not to overfit can therefore reduce overfitting. This includes techniques like pruning a decision tree, lowering the number of parameters in a neural network, and applying dropout to a neural network.
  6. Early Stopping: Plotting the validation loss curve over the training epochs shows when the validation loss starts to worsen (e.g., it starts rising instead of falling). Model training can be halted at this point by monitoring the loss curve and setting an early stopping trigger. The saved model can then be improved further by retraining it with the other techniques mentioned in this section.
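
To illustrate techniques 2, 5, and 6 together, here is a minimal Keras sketch. The layer sizes, regularization strength, dropout rate, and patience value are illustrative assumptions, not tuned settings:

import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    # L2 regularization penalizes large weights in the dense layer (technique 2).
    tf.keras.layers.Dense(512, activation='relu',
                          kernel_regularizer=tf.keras.regularizers.l2(0.001)),
    # Dropout randomly disables 50% of the units at each training step (technique 5).
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Early stopping halts training once the validation loss stops improving (technique 6).
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5,
                                              restore_best_weights=True)
# Pass the callback when training: model.fit(..., callbacks=[early_stop])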

All the above strategies can be used to address overfitting in classification models. A minimal k-fold cross-validation sketch follows; after that, let us explore how "Data Augmentation" reduces overfitting.
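
Here is the promised k-fold sketch (technique 4). It assumes in-memory arrays x and y and a build_model() factory that returns a freshly compiled model; all three names are our own, for illustration:

import numpy as np
from sklearn.model_selection import KFold

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
scores = []
for train_idx, test_idx in kfold.split(x):
    model = build_model()  # a fresh, untrained model for each fold
    model.fit(x[train_idx], y[train_idx], epochs=10, verbose=0)
    _, acc = model.evaluate(x[test_idx], y[test_idx], verbose=0)
    scores.append(acc)  # each fold serves as the test set exactly once
print(f"Mean accuracy across folds: {np.mean(scores):.3f}")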

Applying Data Augmentation To Reduce Overfitting

To understand how Data Augmentation reduces overfitting, we will use the filtered version of the Kaggle dataset 'cats_and_dogs' (original dataset provided by Microsoft) to build an image classifier. There are 2,000 training and 1,000 validation images for two labels, 'cats' and 'dogs'. Here are some sample images from the dataset.

[Image: sample images from the cats_and_dogs dataset]
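
The loading code is not shown here; as a minimal sketch, assume the filtered dataset has been downloaded and unpacked into 'cats_and_dogs_filtered/train' and 'cats_and_dogs_filtered/validation' (both paths are assumptions):

import tensorflow as tf

# Rescale pixel values to [0, 1]; no augmentation yet for the baseline model.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    'cats_and_dogs_filtered/train',  # assumed path
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary')  # binary labels: cats vs. dogs
validation_generator = val_datagen.flow_from_directory(
    'cats_and_dogs_filtered/validation',  # assumed path
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary')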

Since this is a relatively small image dataset, the trained model can experience overfitting. We will build the image classifier using a Convolutional Neural Network (CNN) with the following code -

import tensorflow as tf

# A simple CNN: two convolution/pooling blocks followed by a dense classifier.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    # Single sigmoid unit for binary classification (cat vs. dog).
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
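
The training call itself is not shown in the original; here is a minimal sketch reusing the generators from the loading sketch above (the step counts assume 2,000 training and 1,000 validation images with a batch size of 20):

history = model.fit(
    train_generator,
    steps_per_epoch=100,  # 2000 training images / batch size 20
    epochs=100,
    validation_data=validation_generator,
    validation_steps=50)  # 1000 validation images / batch size 20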

Training the above model for 100 epochs, we get a training accuracy of 100%, while the validation accuracy is 67.6%. After plotting the curves for these two accuracies, as shown in the figure below, we can clearly see that the model is overfitting: the validation accuracy stops increasing after the first few epochs, while the validation loss, instead of falling, rises.

[Image: training and validation accuracy and loss curves for the baseline model]

Next, we will augment the dataset using selective transformations such as width_shift_range, height_shift_range, rotation, zoom, and flipping, and then retrain the model on the augmented images. Here are some sample augmented images, followed by a sketch of the augmentation setup -

[Image: sample augmented images]
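
Here is a minimal sketch of that augmentation setup, using Keras' ImageDataGenerator; the specific parameter values are illustrative assumptions:

import tensorflow as tf

augmented_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    width_shift_range=0.2,  # random horizontal shifts
    height_shift_range=0.2,  # random vertical shifts
    rotation_range=40,  # random rotations up to 40 degrees
    zoom_range=0.2,  # random zoom in/out
    horizontal_flip=True)  # random horizontal flips

augmented_train_generator = augmented_datagen.flow_from_directory(
    'cats_and_dogs_filtered/train',  # assumed path, as before
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary')

# Retrain the model with augmented_train_generator in place of train_generator;
# the validation generator is left unaugmented.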

The training accuracy reduces to 69.4% while the validation accuracy increases to 71%. Plotting the training and validation accuracies for the model trained on augmented images, we get -

[Image: training and validation accuracy and loss curves for the model trained on augmented images]

The model does a better job when trained on augmented images, as the training and validation accuracies now track each other closely. Also, the training and validation losses keep falling (except for one epoch). This indicates that Data Augmentation has helped reduce overfitting for this image classification model.

Conclusion

In this article, we saw an overview of overfitting and how Data Augmentation can reduce overfitting in an image classification model.

Here are some key takeaways from this article -

  • Underfitting and Overfitting are two important unwanted scenarios affecting Machine Learning (ML) models as both tend to degrade their performance.
  • Deep Learning (DL) models for image classification are likely to overfit the training data, as indicated by a training accuracy that is much higher than the validation accuracy.
  • Several techniques are available to reduce overfitting in ML models, especially image classification models.
  • Data Augmentation increases the training dataset size and adds much-needed variety for model generalization. However, it needs to be used wisely, as adding irrelevant samples to the training dataset can instead degrade the DL model's performance.

About us: VisionERA is an Intelligent Document Processing (IDP) platform capable of handling various types of documents, aided by Data Augmentation for image classification. It can extract and validate data in bulk volumes with minimal intervention. The platform can also be molded to the requirements of any industry and use case through its custom DIY workflow feature. It is a scalable and flexible platform providing end-to-end document automation for any organization.

Looking for a document processing solution that uses the enhanced capabilities of image classification with deep learning? Set up a demo today by clicking the CTA below, or simply send us a query through the contact us page!
