Decision trees are powerful and widely used supervised learning algorithms for both classification and regression tasks.
Their simplicity and ease of interpretation make them valuable aids for data analysis and decision-making. Apart from their standalone usage, they often serve as a default choice for base learners in ensemble methods such as random forests and gradient boosting.
Decision trees are tree-like structures in which every internal node represents a decision based on a particular feature, and each terminal (leaf) node represents the outcome: a class label for classification or a predicted value for regression. They operate by recursively splitting the dataset into smaller subsets using the most significant feature at each step, with the goal of producing leaf nodes that hold the final predicted class or value.
One of the key concepts in decision trees is the calculation of impurity, which measures how heterogeneous (mixed) the samples in a node are. A common impurity measure is the Gini impurity, defined as $1 - \sum_i p_i^2$, where $p_i$ is the proportion of samples belonging to class $i$. Its value ranges from $0$ for a perfectly pure node to a maximum of $0.5$ in the binary (two-class) case.
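The short helper below (an optional sketch, not part of the lesson's code) shows this calculation with NumPy:
import numpy as np

def gini_impurity(labels):
    # Gini impurity: 1 - sum(p_i^2) over the class proportions p_i
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

print(gini_impurity([0, 0, 0, 0]))  # 0.0 -> perfectly pure node
print(gini_impurity([0, 0, 1, 1]))  # 0.5 -> maximally mixed two-class node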
Note: Apart from Gini impurity, several other criteria can be used to build decision trees:
- Entropy and information gain: Entropy measures the disorder of the class distribution in a node, and information gain is the reduction in entropy achieved by a split.
- Misclassification error: The proportion of incorrect predictions in a node; a smaller misclassification error indicates a better split.
- Gain ratio: An alternative to information gain for selecting the splitting attribute; it corrects information gain's bias toward attributes with many outcomes.
- Chi-square statistic: Assesses the significance of splits on categorical features by measuring the association between a feature and the target variable; higher chi-square values indicate stronger associations and therefore more effective splits.
- Mean squared error (MSE): Commonly used as the splitting criterion in regression trees.
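In scikit-learn, some of these criteria can be selected through the criterion argument of the tree estimators. The lines below are only a sketch of that option (the accepted values depend on your scikit-learn version):
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

clf_gini = DecisionTreeClassifier(criterion="gini")         # default: Gini impurity
clf_entropy = DecisionTreeClassifier(criterion="entropy")   # entropy / information gain
reg_mse = DecisionTreeRegressor(criterion="squared_error")  # MSE for regression trees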
The code below trains a decision tree classifier and plots the resulting tree. From the sklearn library, we import make_classification, train_test_split, DecisionTreeClassifier, plot_tree, and accuracy_score. We first create a synthetic dataset and then fit a decision tree to it.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score


# Generate a synthetic dataset for classification
n_features = 2

X, y = make_classification(n_samples=100, n_features=n_features,
                           n_classes=2, n_clusters_per_class=2, n_informative=n_features,
                           n_redundant=0, n_repeated=0, random_state=100)

feature_names = [f'$x_{str(i+1)}$' for i in range(n_features)]

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a decision tree classifier
clf = DecisionTreeClassifier()

# Train the classifier on the training data
clf.fit(X_train, y_train)

# Make predictions on the test data
y_pred = clf.predict(X_test)

# Calculate accuracy of the model
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

# Plot the decision tree
plt.figure(figsize=(20, 10))
plot_tree(clf, filled=True, rounded=True, feature_names=feature_names, class_names=['Class 0', 'Class 1'])
plt.show()
Lines 1–6: We import the necessary libraries and modules.
Line 10: We define the number of features, n_features, for the synthetic dataset; in this case, it's set to 2.
Lines 12–14: We generate a synthetic dataset for classification using the make_classification function. The generated dataset contains 100 samples, 2 features, 2 classes, 2 clusters per class, and as many informative features as total features. The random seed is set to 100 for reproducibility.
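As an optional sanity check (not part of the original listing), we can inspect the shapes and class balance of the generated arrays:
print(X.shape)         # (100, 2): 100 samples, 2 features
print(np.bincount(y))  # samples per class, roughly balanced, e.g. [50 50]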
Line 16: We create a list of feature names, feature_names, using a list comprehension. The names are in the format $x_i$, where i is the feature index, so with two features this yields ['$x_1$', '$x_2$'].
Line 19: We split the generated dataset into training and testing sets using the train_test_split function.
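Since test_size=0.2, this leaves 80 samples for training and 20 for testing; printing the shapes (an optional check, not in the original code) confirms it:
print(X_train.shape, X_test.shape)  # (80, 2) (20, 2)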
Lines 22–25: We create the decision tree classifier clf and train it on the training data X_train and y_train using the fit method.
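Because the classifier is created with default settings, the tree grows until its leaves are pure, which can overfit. If a shallower, more interpretable tree is needed, standard regularization parameters can be passed; the values below are only illustrative:
clf = DecisionTreeClassifier(
    max_depth=3,         # limit the depth of the tree
    min_samples_leaf=5,  # require at least 5 samples in each leaf
    random_state=42,     # make results reproducible
)
clf.fit(X_train, y_train)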
Line 28: We make predictions on the test data X_test using the trained classifier clf and store the predictions in the variable y_pred.
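Besides hard class labels, the fitted tree can also return per-class probabilities (the class fractions in each leaf). This optional line is not part of the lesson's code:
proba = clf.predict_proba(X_test)  # shape (20, 2): one probability column per class
print(proba[:5])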
Lines 31–32: We calculate the accuracy of the model's predictions by comparing them to the true labels y_test using the accuracy_score function. The accuracy score is then printed to the console.
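The accuracy is simply the fraction of matching labels, so the same value can be recomputed by hand as a quick check (reusing the NumPy import from the listing above):
manual_accuracy = np.mean(y_pred == y_test)  # fraction of correct predictions
print(f"Accuracy: {manual_accuracy:.2f}")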
Lines 35–36: We create a large figure and plot the decision tree. The plot_tree function from sklearn.tree generates the visualization of the fitted tree. The filled=True and rounded=True arguments color the nodes according to the majority class and draw rounded boxes, the feature_names argument specifies the names of the features, and class_names specifies the class labels.
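If a plot is inconvenient, for example in a terminal, scikit-learn can also print a plain-text rendering of the same fitted tree via export_text; this short sketch reuses the clf and feature_names defined earlier:
from sklearn.tree import export_text

print(export_text(clf, feature_names=feature_names))  # text version of the fitted tree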