Dealing with Imbalanced Datasets in Python Programming
Learn about the fundamentals of imbalanced datasets and explore how to use SMOTE to effectively handle imbalanced datasets.
In this lesson, we will work with an imbalanced subset of the MNIST dataset, focusing on the digits 0 and 1. We will investigate what happens when one class has more examples than the other and how this affects the model’s performance. Then, we will apply SMOTE to balance the dataset and train a CNN on both the imbalanced and the balanced versions. Finally, we will compare the performance of these models using metrics such as accuracy, F1 score, precision, and recall.
This lesson is divided into the following three steps:
Step 1: Using a bar chart, we will visualize how many images of the numbers 0 and 1 are in the MNIST dataset.
Step 2: We will apply SMOTE to balance the imbalanced dataset and create a bar chart to show the updated distribution.
Step 3: We will train a CNN model on the imbalanced and balanced datasets and measure its performance on each using metrics such as accuracy, F1 score, precision, and recall.
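To see why accuracy alone is not enough on imbalanced data, the following is a minimal sketch in plain Python. The helper `binary_metrics` and the 10-versus-90 label split are hypothetical and chosen only for illustration; they are not part of the lesson's own code. The sketch scores a degenerate classifier that always predicts the majority digit.

```python
def binary_metrics(y_true, y_pred, positive=0):
    """Return accuracy, precision, recall, and F1 for the `positive` class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    correct = sum(1 for t, p in pairs if t == p)

    accuracy = correct / len(pairs)
    # Guard against division by zero when the class is never predicted or never present.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical labels: 10 samples of digit 0 (minority) and 90 of digit 1 (majority).
y_true = [0] * 10 + [1] * 90
# A degenerate classifier that always predicts the majority digit.
y_pred_majority = [1] * 100

majority_scores = binary_metrics(y_true, y_pred_majority, positive=0)
print(majority_scores)
```

The classifier never identifies a single minority sample, yet its accuracy is still high because the majority class dominates the count of correct predictions. Precision, recall, and F1 for the minority class expose the failure.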
Step 1: Visualizing the MNIST dataset (digits 0 and 1)
The code provided below generates a bar chart that displays the imbalanced distribution of the digits 0 and 1 in the dataset.
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()


# Keep only the digits 0 and 1
index = np.where((y_train == 0) | (y_train == 1))
x_filtered = x_train[index]
y_filtered = y_train[index]

# Delete 3000 samples of digit 0 to make the dataset imbalanced
samples_to_delete = 3000
index_to_delete = np.where(y_filtered == 0)[0][:samples_to_delete]
x_filtered = np.delete(x_filtered, index_to_delete, axis=0)
y_filtered = np.delete(y_filtered, index_to_delete, axis=0)

# Show the distribution before balancing
plt.bar([0, 1], [np.sum(y_filtered == 0), np.sum(y_filtered == 1)])
plt.xlabel('Digits')
plt.ylabel('Count')
plt.title('Digit Distribution (0 and 1) Before Balancing')
plt.show()
Lines 1–5: We import the necessary libraries, including numpy, matplotlib, and the mnist dataset loader.
Line 7: We load the MNIST dataset into the training and testing datasets along with their respective labels.
Lines 11–13: We filter the dataset to retain only the samples with the digits 0 and 1. The x_filtered variable holds the corresponding images, and the y_filtered variable contains their respective labels.
Lines 15–19: We remove 3000 samples of the digit 0 from the filtered dataset and store the resulting data in the x_filtered array, which consists of images, and the y_filtered array, which consists of labels.
Lines 22–26: We create a bar chart illustrating the number of samples of the digits 0 and 1 before balancing.
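The chart shows the imbalance visually, but it can also help to quantify it. The sketch below applies the same np.where and np.delete pattern to a small toy label array and then counts the samples per class. The array `labels`, its 20-versus-50 split, and the name `imbalance_ratio` are assumptions made for illustration, standing in for y_filtered so that the example runs without downloading MNIST.

```python
import numpy as np

# Hypothetical toy labels standing in for y_filtered: 20 zeros and 50 ones.
labels = np.array([0] * 20 + [1] * 50)

# Same pattern as the lesson code: drop the first N samples of digit 0.
samples_to_delete = 15
index_to_delete = np.where(labels == 0)[0][:samples_to_delete]
labels = np.delete(labels, index_to_delete, axis=0)

# Count the samples per class and compute the majority-to-minority ratio.
classes, counts = np.unique(labels, return_counts=True)
class_counts = {int(c): int(n) for c, n in zip(classes, counts)}
imbalance_ratio = counts.max() / counts.min()

print(class_counts)
print(f"Imbalance ratio: {imbalance_ratio:.1f}")
```

The same np.unique call could be run on y_filtered to get the exact counts behind the bar chart.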
Step 2: Using SMOTE
SMOTE (Synthetic Minority Over-sampling Technique) works by selecting a point from the minority class and finding its nearest neighbors within that class. It then adds new points between the chosen point ...
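The interpolation idea can be sketched in a few lines of NumPy. This is a simplified illustration that follows the usual description of SMOTE (pick a minority point, pick one of its k nearest minority neighbors, and place a new point at a random position on the segment between them). It is not the implementation used by any particular library. The function name `smote_sketch`, its parameters, and the 2-D sample points are hypothetical.

```python
import numpy as np


def smote_sketch(minority, n_new, k=5, seed=0):
    """Generate n_new synthetic points by interpolating between minority neighbors."""
    rng = np.random.default_rng(seed)
    minority = np.asarray(minority, dtype=float)
    k = min(k, len(minority) - 1)  # cannot have more neighbors than other points

    # Pairwise Euclidean distances between all minority points.
    diffs = minority[:, None, :] - minority[None, :, :]
    distances = np.linalg.norm(diffs, axis=2)
    np.fill_diagonal(distances, np.inf)  # a point is not its own neighbor
    neighbors = np.argsort(distances, axis=1)[:, :k]

    synthetic = np.empty((n_new, minority.shape[1]))
    for i in range(n_new):
        base = rng.integers(len(minority))      # pick a minority point
        neighbor = rng.choice(neighbors[base])  # pick one of its k nearest neighbors
        gap = rng.random()                      # position along the segment, in [0, 1)
        synthetic[i] = minority[base] + gap * (minority[neighbor] - minority[base])
    return synthetic


# Hypothetical 2-D minority samples; image data would need to be flattened
# to one feature vector per sample before being used this way.
minority_points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
new_points = smote_sketch(minority_points, n_new=6, k=2)
print(new_points.shape)
```

Because every synthetic point lies on a segment between two existing minority samples, the new points stay inside the region the minority class already occupies instead of being exact duplicates.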