Image by Christina Brinza on Unsplash.

Classifying Iris Species with K-Means Clustering

Implementing the K-Means Clustering Algorithm on the Iris Dataset

Tenzin Migmar
4 min read · May 20, 2021


If you’re anything like me, you might have spent the past couple of months diving into the realm of neural networks, only to emerge from the depths of deep learning and realize you’ve neglected classical machine learning by casting it aside.

Although a lot of cutting-edge artificial intelligence research puts the limelight on deep learning, classical machine learning algorithms like support vector machines and logistic regression also have a lot to offer.

In this article, I’ll walk through one of the most popular algorithms, K-Means Clustering, and step through how you can implement it on one of machine learning’s “hello world!” datasets: the iconic Iris dataset.

I’ve always been a strong believer in putting the spotlight on the data. Data is the backbone of artificial intelligence, and in keeping with that belief, I’ll give a short introduction to the beloved Iris dataset.

Only interested in the code? That’s okay, I’m not too offended. I’ll save you the hassle of reading through the entire article because, to be honest, I’m not the biggest fan of wading through large amounts of information when I only need a specific piece of it either. Here’s the link to the Kaggle Notebook.

The Iris flower dataset, first introduced by British statistician and biologist Ronald Fisher, poses a multiclass classification problem: telling apart three species of Iris (Iris Setosa, Iris Virginica, and Iris Versicolor). The features are the length and width of the petals and sepals, and the target is the species. (If I’ve already convinced you to give this dataset a go, you can find it here; otherwise, hang tight and give me a couple more paragraphs and you might find yourself scrolling back up.)

With that dataset knowledge, let’s begin data exploration to see what golden nuggets of information we can uncover about this dataset.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.cluster import KMeans

df = pd.read_csv("../input/iris-flower-dataset/IRIS.csv")

# First look at the dataset table
df.head(5)

This should reveal a table of the first 5 rows of the dataset. Take a look. You likely won’t find anything too interesting about it. Only some columns and rows of numbers that don’t have much meaning just yet.

# Shape of data, # of samples working with

df.shape

Pretty self-explanatory! df.shape will fill us in on how many rows and columns we’re dealing with. In this case, we have 150 rows and 5 columns.

# Useful info like mean, min, max for features

df.describe()

Then, we’ll use .describe() to draw out useful information. Again, this is quite simple so I’ll drop the code for the next bit and move on to explaining the K-Means Clustering algorithm.

# Finding null values

df.isna().sum()

# Plot each feature against species

colors = {"Iris-virginica": "purple", "Iris-setosa": "blue", "Iris-versicolor": "green"}

def showDistributions(feature1):
    # One scatter plot per feature, with points colored by species
    plt.figure(figsize=(30,30))
    plt.subplot(6,6,1)
    plt.scatter(df[feature1], df['species'], c=df['species'].map(colors))
    plt.title("{} distribution".format(feature1))
    plt.xlabel(feature1)
    plt.ylabel('species')
    plt.show()

showDistributions('sepal_length')
showDistributions('petal_length')
showDistributions('sepal_width')
showDistributions('petal_width')

What is the K-Means Clustering Algorithm?

K-Means clustering is an unsupervised classical machine learning algorithm that groups data points into clusters based on similarities in their features. Because it is unsupervised, the model never sees the species labels; it has to find structure and patterns in the features on its own.

How K-Means Clustering works: The K in K-Means clustering is a variable that stands for the number of centroids, the points at the center of each cluster. After initializing k centroids, K-Means calculates the distance between each data point and each of the k centroids, and then groups each data point under the centroid it is closest to.

Then each centroid is recalculated. The formula for a new centroid is as follows:

sum of points grouped to centroid / # of points in group
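Written in symbols, the same update might look like this. The notation is my own choice for illustration: C_j is the set of points currently grouped under centroid j, and mu_j is that centroid's new position.

```latex
\mu_j = \frac{1}{|C_j|} \sum_{x_i \in C_j} x_i
```

In other words, each new centroid is simply the average of the points assigned to it.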

These two steps, assigning points and recalculating centroids, repeat until no data point changes cluster between iterations (or a maximum number of iterations is reached).
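To make the loop concrete, here is a minimal from-scratch sketch of the assignment and update steps in NumPy. It is not what scikit-learn does internally: the function name `kmeans`, the random initialization, and the toy points are all my own assumptions for illustration.

```python
import numpy as np


def kmeans(points, k, max_iter=100, seed=0):
    """Plain K-Means on an (n_samples, n_features) float array (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    # Initialization: pick k distinct data points as starting centroids.
    # (Assumption: simple random init; scikit-learn defaults to k-means++.)
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    labels = np.full(len(points), -1)

    for _ in range(max_iter):
        # Assignment step: distance from every point to every centroid,
        # then take the index of the nearest centroid.
        distances = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        new_labels = distances.argmin(axis=1)

        # Stop once no point changes cluster.
        if np.array_equal(new_labels, labels):
            break
        labels = new_labels

        # Update step: each centroid becomes the mean of its assigned points.
        for j in range(k):
            members = points[labels == j]
            if len(members) > 0:  # an empty cluster keeps its old centroid
                centroids[j] = members.mean(axis=0)

    return labels, centroids


# Toy data: two well-separated blobs (made-up values, not Iris measurements).
toy = np.array([[1.0, 1.0], [1.2, 0.8], [0.8, 1.1],
                [5.0, 5.0], [5.2, 4.9], [4.8, 5.1]])
labels, centroids = kmeans(toy, k=2)
print(labels)
print(centroids)
```

The first three toy points should land in one cluster and the last three in the other, though which cluster gets the ID 0 depends on the random initialization.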

With the intuition behind how the K-Means Clustering Algorithm now under our belt, let’s build the model.

# creating our feature matrix (x) and target labels (y)
x = df.drop(['species'], axis=1)
y = df['species']

from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters = 3, init = 'k-means++', max_iter = 500, n_init = 10, random_state = 0)
# fit_predict returns the cluster index assigned to each row
model = kmeans.fit_predict(x)

# Looking at the centroid values generated

kmeans.cluster_centers_

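We can then gauge the performance of the model by comparing the predicted and actual target values. Keep in mind that K-Means numbers its clusters arbitrarily, so the species-to-number mapping has to match the cluster IDs from your own run.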

species = {"Iris-versicolor": 0, "Iris-setosa": 1, "Iris-virginica": 2}

irisdf = df.copy()

irisdf["species"] = irisdf["species"].map(species)
irisdf["predicted"] = model
irisdf
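Since cluster IDs are arbitrary, one way to compare clusters with species without hand-picking the numbers is to map each cluster to its most common species and then count matches. The sketch below is only that, a sketch: the helper names `align_clusters` and `clustering_accuracy` are hypothetical, and the labels are made up so the block runs on its own.

```python
from collections import Counter


def align_clusters(true_labels, cluster_ids):
    """Map each cluster ID to the most common true label among its members."""
    mapping = {}
    for cluster in set(cluster_ids):
        members = [t for t, c in zip(true_labels, cluster_ids) if c == cluster]
        mapping[cluster] = Counter(members).most_common(1)[0][0]
    return mapping


def clustering_accuracy(true_labels, cluster_ids):
    """Share of points whose cluster's majority label matches their own label."""
    mapping = align_clusters(true_labels, cluster_ids)
    hits = sum(mapping[c] == t for t, c in zip(true_labels, cluster_ids))
    return hits / len(true_labels)


# Made-up labels for illustration only (not results from the Iris run).
true_labels = ["setosa", "setosa", "versicolor", "versicolor", "virginica", "virginica"]
cluster_ids = [1, 1, 0, 2, 2, 2]
print(align_clusters(true_labels, cluster_ids))
print(clustering_accuracy(true_labels, cluster_ids))
```

With the variables from this article, the call would presumably be clustering_accuracy(list(y), list(model)). One caveat: a majority vote can map two clusters to the same species, so it is worth printing the mapping before trusting the score.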

Closing Notes

This is a pretty standard dataset. There aren’t too many patterns or insights to extract because it only has 4 features, none of which stands out much from the others. It was nice to be able to reacquaint myself with K-Means clustering though! I’m looking forward to implementing more classical machine learning algorithms on some other datasets.

Check out some of my other notebooks!

Lunar Rocky Landscape Segmentation with U-Net

Designing Convolutional Neural Networks with Fashion MNIST

Galaxy Multi-Image Classification with LeNet-5

Tenzin Migmar

my personal blog! :D if you're looking for my other medium for math articles: https://medium.com/@t9nz