What is the xgb.plot_tree() function in Python?
XGBoost (eXtreme Gradient Boosting) is a well-known machine learning library that implements gradient boosting, a powerful ensemble learning approach that combines the predictions of many weak learners (often decision trees) to produce a strong learner.
The xgb.plot_tree() function
The xgb.plot_tree() function is an invaluable tool that XGBoost provides for visualizing the individual decision trees that make up the ensemble.
Decision trees can become complex, and visualizing them can help us better comprehend the model's decision-making process, feature relevance, and possible overfitting.
Syntax
Here, we will show the basic syntax for the xgb.plot_tree() function:
xgb.plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs)
- booster is a required parameter representing the model (XGBRegressor or XGBClassifier) to be visualized.
- fmap is the name of the feature map file.
- num_trees represents the index of the tree to be plotted. The default value is 0.
- rankdir is an optional parameter representing the direction of the graph layout. The value can be "TB" for top-to-bottom or "LR" for left-to-right.
- ax is an optional parameter representing the matplotlib axes object on which to plot the tree.
- **kwargs is an optional parameter for additional keyword arguments that can be passed to the plot function.
Note: Make sure you have the XGBoost library installed before running the example below.
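A typical installation uses pip (the package names below are the standard PyPI ones; on some systems the Graphviz system binaries must also be installed separately for tree rendering):

```shell
# Install XGBoost plus the plotting dependencies used in this article.
# Note: the `graphviz` Python package also needs the Graphviz system
# binaries (e.g., via `apt install graphviz` or `brew install graphviz`).
pip install xgboost matplotlib graphviz
```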
Code
Let's look at a code example that implements the function xgb.plot_tree() given below:
import xgboost as xgb
from xgboost import plot_tree
import numpy as np
import matplotlib.pyplot as plt

# Creating a synthetic dataset
np.random.seed(42)
X = np.random.rand(100, 3)
y = np.random.randint(0, 2, 100)

# Creating an XGBoost classifier
model = xgb.XGBClassifier()

# Training the model on the dataset
model.fit(X, y)

# Visualizing the first decision tree in the ensemble
plot_tree(model, num_trees=0)
plt.show()
Code explanation
Line 1–2: Firstly, we import the xgb library and the plot_tree function to visualize decision trees.
Line 3–4: Next, we import the numpy library and the pyplot module from the matplotlib library.
Line 7–9: Now, we create a small synthetic dataset with 100 samples and 3 features using the random.rand() and random.randint() functions. The target variable y is binary, holding values 0 or 1.
Line 12: In this line, we create an XGBoost classifier with default hyperparameters and store it in the variable model.
Line 15: Moving on, we train the model on the entire synthetic dataset X and y.
Line 18: Now, we visualize the first decision tree in the ensemble using the plot_tree function. The parameter num_trees=0 specifies that the first tree in the ensemble should be plotted.
Line 19: Finally, we display the plot using plt.show().
Output
Upon execution, the code uses the plot_tree() function to visualize the first decision tree in the XGBoost ensemble model.
The output or decision tree looks like this:
In the plot above, we can see that the tree's internal nodes reflect splitting conditions on certain features, while the leaves hold the prediction scores (leaf values) that XGBoost combines across trees to form the final class prediction. This helps in understanding how the model makes decisions based on the features in the dataset.
Conclusion
In summary, the XGBoost function xgb.plot_tree() is useful for visualizing individual decision trees in an ensemble model. Using this function, we can learn about the model's decision-making process, feature relevance, and potential overfitting. This deepens our understanding of the XGBoost model, making communication easier and enabling better model debugging and tuning.