KNeighborsClassifier in scikit-learn
Scikit-learn is a Python library for machine learning whose algorithms we can use for various tasks, including classification, regression, clustering, and model selection. In this Answer, our main focus will be on the classification domain of scikit-learn.
Classification
Classification is a fundamental ML task that involves assigning predefined labels or categories to input data based on patterns or features. The main goal here is to create a mapping between input features and the corresponding target labels.
The algorithm to be discussed in this answer, KNeighborsClassifier, is a classification algorithm.
KNeighborsClassifier algorithm
KNeighborsClassifier is an algorithm that categorizes a data point according to the labels of that point's nearest data points, or neighbors. Let's take the already labeled data points below.
Now, if a new data point, highlighted using the arrow, is to be added to the set, its neighbors will be explored first. Based on the above example, it is highly likely that it will be assigned the beige class.
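To make the voting idea concrete, here is a minimal sketch; the two-feature points, labels, and the query point are all made up for illustration:

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical 2D points: class 0 clusters near the origin,
# class 1 clusters around (5, 5).
X = np.array([[0, 0], [1, 0], [0, 1], [5, 5], [6, 5], [5, 6]])
y = np.array([0, 0, 0, 1, 1, 1])

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

# A new point near the first cluster: its 3 nearest neighbors
# are all labeled 0, so the majority vote predicts class 0.
print(knn.predict([[0.5, 0.5]]))  # prints [0]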
Fun astronomical scenario
Imagine we have a dataset of stars in the night sky, each represented by features such as brightness, distance from Earth, and other spectral characteristics. Further, each star has a label indicating its type, e.g., main sequence, red giant, or white dwarf.
The question at hand
Suppose we want to classify a newly discovered star based on its features. For this, we can use the KNeighborsClassifier algorithm. How? Let's find out.
The solution
First, we'll gather a training dataset with labeled stars, including their features and corresponding types.
A crucial step is to select a value for k, representing the number of nearest neighbors to consider. Let's say we set it to five.
When a new star is observed and its features are measured, the algorithm calculates the distances between this star and every already labeled star based on their feature values. It then identifies the five nearest neighbors in the training data.
To see KNeighborsClassifier in action, let's check the labels of these five nearest neighbors. Assuming four are labeled as main sequence stars and one as a red giant, we can begin with our prediction phase.
Note: Based on this, the algorithm will predict that the new star is likely a main sequence star.
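As a rough sketch of this scenario in code (the feature values and labels below are invented for illustration, not real astronomical measurements):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical training stars: [brightness, distance, spectral index]
X_train = np.array([
    [0.8, 10.0, 0.5],   # main sequence
    [0.7, 12.0, 0.6],   # main sequence
    [0.9,  9.0, 0.4],   # main sequence
    [0.6, 11.0, 0.5],   # main sequence
    [2.5, 40.0, 1.8],   # red giant
    [2.7, 38.0, 1.9],   # red giant
    [0.1,  5.0, 0.1],   # white dwarf
])
y_train = np.array(["main sequence", "main sequence", "main sequence",
                    "main sequence", "red giant", "red giant", "white dwarf"])

# k = 5, as in the walkthrough above.
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

# A newly observed star: most of its five nearest neighbors are
# main sequence stars, so that label wins the vote.
new_star = [[0.75, 10.5, 0.55]]
print(knn.predict(new_star))  # prints ['main sequence']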
Predictions
Taking a bird's-eye view of the scenario, we can say that stars with similar features, like brightness, tend to belong to the same type.
By finding the nearest labeled stars in the training data and letting them decide the new star's type, the algorithm uses patterns observed in known stars to classify the unknown one.
What if the count of the neighbors is found to be equal? With an even k, the vote among the nearest neighbors can split evenly between two classes. Common remedies are to choose a k that avoids ties (for example, an odd k for a two-class problem) or to weight each neighbor's vote by its distance so that closer neighbors count more.
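As a small sketch of the distance-weighting remedy (weights='distance' is a real KNeighborsClassifier parameter; the points here are made up):

from sklearn.neighbors import KNeighborsClassifier

X = [[0, 0], [1, 1], [4, 4], [5, 5]]
y = [0, 0, 1, 1]

# With k=4 and plain majority voting, a central query point would face
# a 2-2 tie. Weighting votes by inverse distance breaks the tie in
# favor of the closer cluster.
knn = KNeighborsClassifier(n_neighbors=4, weights="distance")
knn.fit(X, y)

# (2, 2) lies nearer to the class-0 cluster, so it wins the weighted vote.
print(knn.predict([[2, 2]]))  # prints [0]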
Code sample
To demonstrate the usage of this algorithm, consider the following code. It trains a KNeighborsClassifier on a synthetic dataset and then visualizes the classification using a scatter plot drawn over a pseudocolor plot.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier

X, y = make_blobs(n_samples=550, centers=4, random_state=50)

knn = KNeighborsClassifier(n_neighbors=5)

knn.fit(X, y)

h = 0.05

minimumX = X[:, 0].min() - 1
maximumX = X[:, 0].max() + 1
minimumY = X[:, 1].min() - 1
maximumY = X[:, 1].max() + 1

xx, yy = np.meshgrid(np.arange(minimumX, maximumX, h), np.arange(minimumY, maximumY, h))

Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])

cmapOne = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#FFCCAA'])
cmapTwo = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#FF8800'])

Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx, yy, Z, cmap=cmapOne)

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmapTwo, edgecolor='k', s=20)

plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())

plt.title("KNeighborsClassifier")

plt.show()

Note: Click the "Run" button to see the code in action. You can also experiment with the code to understand it better!
Code explanation
Lines 1–2: Prior to writing our main code, we import the necessary libraries. We import numpy as np, which provides numerical operations. Then we import matplotlib.pyplot as plt, which is a plotting library.
Line 3: We import the ListedColormap class from the matplotlib.colors module. This class is used to create a custom colormap for our plot.
Line 4: We import the make_blobs function from the sklearn.datasets module. It generates a synthetic dataset of clustered points that we can make use of.
Line 5: We import the KNeighborsClassifier class from the sklearn.neighbors module for classification tasks.
Line 7: We generate a dataset using the make_blobs function. We set the number of samples n_samples to 550, the number of centers centers to 4, and the random state random_state to 50. The generated dataset is assigned to the variables X and y, representing the input features and corresponding labels, respectively.
Line 9: We create an instance of the KNeighborsClassifier class, knn. Here we can customize the number of neighbors n_neighbors according to our own needs. Let's say we set it to 5.
Line 11: We fit the model using the fit method of the knn object so that the KNeighborsClassifier model is trained on the dataset (X, y).
Line 13: Moving on to the data visualization part, we set the value of h to 0.05. This variable defines the step size for the meshgrid.
Lines 15–18: We calculate the minimum and maximum values for the x and y coordinates of our plot, padding by one unit below the minimum and above the maximum values in the original dataset. These are saved in minimumX, maximumX, minimumY, and maximumY.
Line 20: We use the np.meshgrid function to create a grid of coordinates. For this purpose, we pass the range of x-coordinates from minimumX to maximumX with a step size of h, and the range of y-coordinates from minimumY to maximumY with a step size of h. The final grid is stored in xx and yy.
Line 22: We use the predict method of knn to classify each point on the meshgrid. The flattened coordinates of the meshgrid are paired up and passed as parameters; a small demo of this ravel-and-stack pattern appears after this explanation. The predicted labels are stored in the variable Z.
Line 24: We create a ListedColormap object called cmapOne for the different class regions.
Line 25: For our data points, we create cmapTwo with different colors.
Line 27: We reshape the predicted labels Z to match the shape of the meshgrid xx.
Line 28: We create a new figure for the plot using plt.figure().
Line 29: We use the plt.pcolormesh function to create a pseudocolor plot of the predicted labels Z on the meshgrid. For customization, we pass the meshgrid coordinates xx and yy, the predicted labels Z, and the colormap cmapOne as parameters.
Line 31: We use the plt.scatter function to create a scatter plot of the input features X. For customization, we pass the x-coordinates X[:, 0], the y-coordinates X[:, 1], the labels y, the colormap cmapTwo, the edgecolor 'k', i.e., black, and the marker size s set to 20.
Lines 33–34: We set the x-axis limits of the plot to the minimum and maximum values of xx, and the y-axis limits to those of yy.
Line 36: We set the plot title to "KNeighborsClassifier".
Line 38: Yay, our plot is now ready! We finally display it using the plt.show() function.
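To see what the ravel-and-stack pattern on lines 20 and 22 produces, here is a tiny standalone demo on a 2×2 grid (the grid values are chosen only for illustration):

import numpy as np

# A tiny grid over x in {0, 1} and y in {0, 1}.
xx, yy = np.meshgrid(np.arange(0, 2), np.arange(0, 2))
# xx = [[0, 1], [0, 1]]    yy = [[0, 0], [1, 1]]

# Flatten both grids and pair them column-wise: one (x, y) row per
# grid point, which is exactly the shape knn.predict expects.
print(np.c_[xx.ravel(), yy.ravel()])
# [[0 0]
#  [1 0]
#  [0 1]
#  [1 1]]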
Output
Our model is now trained to check the five closest neighbors and classify new data points accordingly. That's the magic of the KNeighborsClassifier model!
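The code above visualizes the decision regions but never measures accuracy. As an optional extension (a sketch, not part of the original example), we could hold out a test split and score the model:

from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Same synthetic dataset as the main example.
X, y = make_blobs(n_samples=550, centers=4, random_state=50)

# Hold out 25% of the points as unseen test data.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=50)

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

# Mean accuracy on the held-out points.
print(knn.score(X_test, y_test))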
Food for thought
Considering the above graph, if three of a point's five closest neighbors were green, it would be colored green too, since green holds the majority of the vote.
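To verify such a claim programmatically, the classifier's kneighbors method returns the distances to and indices of a query point's nearest neighbors, so we can inspect their labels directly. Continuing from the trained knn in the code sample above (the query point is arbitrary):

import numpy as np

point = np.array([[0.0, 0.0]])  # an arbitrary query point

# Distances to and indices of the 5 training points closest to `point`.
distances, indices = knn.kneighbors(point)

print(y[indices[0]])       # the labels of the five nearest neighbors
print(knn.predict(point))  # the majority class among them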
Test your KNeighborsClassifier knowledge!
Question: How do we specify the neighbor count to be considered?

Answer: We make use of n_neighbors.

Question: If a data point has 3 yellow neighbors and 2 blue neighbors, which class will the new point be assigned: the yellow class or the blue class?

Answer: Yellow class, since yellow holds the majority among the five nearest neighbors.