Trusted answers to developer questions



How to build a decision tree with the IRIS dataset in Python

Harsh Jain

A decision tree is a machine learning algorithm that uses a tree-like model of decisions and their possible consequences to arrive at a prediction. It is a supervised machine learning model: the data is repeatedly split according to the value of a chosen feature, and a final decision is made once a leaf of the tree is reached.

Usually, a decision tree is drawn upside down, with the root node at the top and the leaf nodes at the bottom. A decision tree contains three types of nodes:

  1. Root node: The very top node that represents the entire population or sample.
  2. Decision nodes: Internal sub-nodes that are split further into other sub-nodes.
  3. Leaf nodes: Nodes with no children, also known as terminal nodes.
Structure of a Decision Tree
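To make the three node types concrete, here is a small sketch that fits a tree on the IRIS data with scikit-learn and then reads its low-level tree_ arrays. It relies on two scikit-learn conventions: node 0 is the root, and a children_left value of -1 marks a leaf. The settings max_depth=2 and random_state=0 are arbitrary and only serve to keep the tree small.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Fit a deliberately shallow tree (max_depth=2 is an arbitrary, illustrative choice)
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

tree = clf.tree_
# In scikit-learn's low-level tree arrays, node 0 is the root, and a node
# whose children_left entry is -1 has no children, i.e., it is a leaf.
is_leaf = tree.children_left == -1

root = 0
leaf_nodes = [int(n) for n in np.flatnonzero(is_leaf)]
# Internal nodes other than the root are the decision nodes
decision_nodes = [int(n) for n in np.flatnonzero(~is_leaf) if n != root]

print("Root node:", root)
print("Decision nodes:", decision_nodes)
print("Leaf nodes:", leaf_nodes)
```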


In supervised machine learning, problems generally fall into two types:

  • Regression
  • Classification

Decision trees can be used for both regression and classification problems.

  • Regression tree: These are used to predict continuous variables. For example, predicting rainfall in a region or predicting the revenue that a company might generate in the future.

  • Classification tree: These are used to predict discrete (categorical) variables. For example, predicting whether the temperature on a given day will be high or low, or whether a team will win a match.
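Here is a minimal sketch of both tree types using scikit-learn's DecisionTreeClassifier and DecisionTreeRegressor on the IRIS data. Using petal width as the regression target is an assumption made only for demonstration. The article's own example is a classification task.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

X, y = load_iris(return_X_y=True)

# Classification tree: predicts a discrete label (0, 1 or 2 = the three species)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
print("Predicted classes:", clf.predict(X[:3]))

# Regression tree: predicts a continuous value. As an illustrative choice,
# we predict petal width (column 3) from the other three measurements.
X_reg, y_reg = X[:, :3], X[:, 3]
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_reg, y_reg)
print("Predicted petal widths:", reg.predict(X_reg[:3]))
```

The classifier returns class labels, while the regressor returns the mean target value of the leaf each sample falls into.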

How decision trees work

Decision trees work step by step. At each node, the algorithm chooses the feature (and, for numeric features, the threshold) that best splits the data according to a defined criterion. It then repeats the process on each resulting subset until a stopping condition is met. The main splitting criteria are:

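  • Gini impurity: Measures the impurity of a node, computed as 1 minus the sum of the squared class proportions. A pure node (all samples in one class) has a Gini impurity of 0.

  • Entropy: Measures the randomness (disorder) of the class labels in a node, computed as the negative sum of p * log2(p) over the class proportions p.

  • Variance: Normally used for regression trees. It measures how far each data point's target value lies from the mean of the node.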
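Each of the three criteria takes only a few lines of NumPy. This is a simplified, single-feature sketch and not scikit-learn's actual implementation. The helper names gini, entropy, variance, and best_split are hypothetical. best_split tries each midpoint between sorted feature values and keeps the threshold whose two child nodes have the lowest weighted Gini impurity.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k in a node."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def entropy(labels):
    """Entropy in bits: -sum(p_k * log2(p_k)) over the class proportions p_k."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def variance(values):
    """Variance: mean squared deviation of each target value from the node mean."""
    values = np.asarray(values, dtype=float)
    return float(np.mean((values - values.mean()) ** 2))

def best_split(feature, labels):
    """Return (threshold, weighted Gini) for the best split of a single feature."""
    feature = np.asarray(feature, dtype=float)
    labels = np.asarray(labels)
    values = np.unique(feature)  # sorted unique values
    best_threshold, best_score = None, np.inf
    for lo, hi in zip(values[:-1], values[1:]):
        threshold = (lo + hi) / 2  # candidate split halfway between neighbours
        left = labels[feature <= threshold]
        right = labels[feature > threshold]
        # Weight each child's impurity by its share of the samples
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if score < best_score:
            best_threshold, best_score = float(threshold), float(score)
    return best_threshold, best_score

print(gini([0, 0, 1, 1]))     # 0.5
print(entropy([0, 0, 1, 1]))  # 1.0
print(gini([2, 2, 2, 2]))     # 0.0 (pure node)
print(variance([1.0, 2.0, 3.0]))
# Toy data: one feature that separates the two classes perfectly
print(best_split([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1]))  # (2.5, 0.0)
```

A pure node scores 0 under every criterion. A 50/50 two-class node reaches the maximum Gini impurity of 0.5 and the maximum entropy of 1 bit.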

Practical implementation

Let’s use a real-world dataset to apply a decision tree algorithm in Python. You can follow the steps below to build and evaluate a decision tree:

  • Gather the data.

  • Import the required Python libraries and load the dataset.

  • Split the data into training and test sets.

  • Create and train the model in Python (here, a decision tree classifier).

  • Use the test dataset to make predictions and check the model's accuracy score.

We will use the IRIS dataset to build a decision tree classifier. The dataset contains 150 samples, 50 from each of three classes of the IRIS plant (IRIS Setosa, IRIS Versicolour, and IRIS Virginica). Each sample is described by four attributes: sepal length, sepal width, petal length, and petal width.

Our aim is to predict the class of the IRIS plant based on the given attributes.
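Before modeling, it can help to inspect the dataset as a pandas DataFrame. This sketch assumes scikit-learn 0.23 or later, where load_iris(as_frame=True) returns the data as a DataFrame in its frame attribute. The species column is a hypothetical addition for readability.

```python
from sklearn.datasets import load_iris

# as_frame=True (scikit-learn 0.23+) returns the data as a pandas DataFrame
iris = load_iris(as_frame=True)
df = iris.frame  # four feature columns plus a numeric "target" column

# Map the numeric target (0, 1, 2) to readable species names
df["species"] = df["target"].map(dict(enumerate(iris.target_names)))

print(df.head())
print(df["species"].value_counts())  # 50 samples per species
```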


Let’s take a look at the code.

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# Loading the IRIS dataset bundled with scikit-learn
data = load_iris()

# Extracting Attributes / Features
X = data.data

# Extracting Target / Class Labels
y = data.target

# Import Library for splitting data
from sklearn.model_selection import train_test_split

# Creating Train and Test datasets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=50, test_size=0.25)

# Creating Decision Tree Classifier and fitting it to the training data
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)

# Predicting test labels and computing accuracy scores
y_pred = clf.predict(X_test)
print("Train data accuracy:", accuracy_score(y_true=y_train, y_pred=clf.predict(X_train)))
print("Test data accuracy:", accuracy_score(y_true=y_test, y_pred=y_pred))


  • In lines 1 to 4, we import the necessary libraries to read and analyze the dataset.

  • In line 7, we load the IRIS dataset into the variable data. Because scikit-learn ships with the IRIS dataset, you do not need to download or upload it separately.

  • In line 10, we store the four feature columns (data.data) in the variable X.

  • In line 13, we store the target class labels (data.target) in the variable y.

  • In line 16, we import the train_test_split function.

  • In line 19, we call the train_test_split() function. The random_state parameter can be set to any integer. Reusing the same value produces the same split on every run, which makes the results reproducible. The test_size parameter can also be adjusted as needed. Here, test_size=0.25 assigns 25% of the dataset to the test set and the remaining 75% to the training set.

  • From lines 22 to 24, we import the DecisionTreeClassifier class, create a classifier, and fit it to the training dataset. By default, the criterion parameter is set to "gini". In lines 27 to 29, we predict the labels of the test set and use the accuracy_score function (imported in line 4) to compute the accuracy on both the training and test data.

  • Lines 28 and 29 print an accuracy of 1.0 (100%) for the training data and 0.947 (approximately 95%) for the test data. Because the classifier's own random_state is not fixed, the exact test score may vary slightly between runs or scikit-learn versions.
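To see how the splitting criterion and tree depth affect the gap between training and test accuracy, this sketch reuses the same split (random_state=50, test_size=0.25). It compares criterion="gini" with criterion="entropy", each with and without a depth limit, and then prints the learned rules with export_text. The values max_depth=3 and random_state=0 for the classifier are illustrative assumptions. Exact scores may depend on your scikit-learn version, so no outputs are listed.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=50, test_size=0.25
)

# Compare both classification criteria, with and without a depth limit.
# max_depth=3 is an illustrative value, not a tuned one.
results = {}
for criterion in ("gini", "entropy"):
    for max_depth in (None, 3):
        model = DecisionTreeClassifier(
            criterion=criterion, max_depth=max_depth, random_state=0
        ).fit(X_train, y_train)
        train_acc = accuracy_score(y_train, model.predict(X_train))
        test_acc = accuracy_score(y_test, model.predict(X_test))
        results[(criterion, max_depth)] = (train_acc, test_acc)
        print(f"{criterion:7} max_depth={max_depth}: train={train_acc:.3f}, test={test_acc:.3f}")

# Print the rules of the last model trained (entropy, max_depth=3) as indented text
text = export_text(model, feature_names=data.feature_names)
print(text)
```

Limiting depth is one common way to reduce overfitting when a fully grown tree reaches 100% training accuracy. Whether it helps on this small test set should be checked rather than assumed.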


