- Install CatBoost: Use pip install catboost or conda install -c conda-forge catboost.
- Import libraries: Import CatBoostClassifier from catboost and other required libraries like sklearn.
- Load dataset: Use any dataset (e.g., the breast cancer dataset from sklearn).
- Split data: Divide the data into training and testing sets using train_test_split.
- Create model: Initialize CatBoostClassifier with parameters like iterations, depth, and learning_rate.
- Train model: Fit the model using fit() on the training dataset.
- Make predictions: Use predict() on the test data.
- Evaluate: Calculate accuracy and generate a classification report.
How to use CatBoostClassifier in Python
Key takeaways:
CatBoost is a machine learning library that excels at handling categorical data using techniques like ordered boosting and gradient-based optimization.
It supports both regression and classification tasks efficiently.
Installation is simple using pip or conda commands.
The process involves importing necessary libraries, loading a dataset, and understanding key parameters like iterations, depth, and learning_rate.
Categorical features are handled through the cat_features parameter.
Train the model using CatBoostClassifier.fit() and make predictions with predict().
Model evaluation is done by calculating accuracy and generating a classification report.
CatBoost provides high performance and flexibility for working with diverse datasets.
Installation
The CatBoost module can be easily installed using the pip command provided below:
pip install catboost
We can also use the conda command provided below to install the CatBoost module.
conda install -c conda-forge catboost
Import the libraries
The first step is to import the required libraries.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.metrics import accuracy_score, classification_report
from catboost import CatBoostClassifier
Load the dataset
The next step is to load the dataset. We will use the breast cancer dataset provided by the sklearn library.
Disclaimer: The examples provided in this Answer, including any references to the breast cancer dataset, are purely for demonstration purposes to showcase the capabilities of CatBoostClassifier. Educative does not endorse or associate with any external data, organizations, or individuals mentioned.
data = load_breast_cancer()
X = data.data
y = data.target
Understand the parameters
The CatBoostClassifier constructor accepts many parameters. All of them are optional because each has a default value, so the model can be customized as much or as little as needed. The five most commonly used parameters are described below.
Argument | Description |
iterations | The number of boosting trees to build. |
depth | The maximum depth of the trees in the ensemble. |
learning_rate | The step size used for gradient descent during training. |
loss_function | The loss function to optimize during training, such as log loss (Logloss) for binary classification. |
cat_features | A list of indices or names of categorical features in the dataset. It is only used if there are categorical features in the dataset. |
Train the model
Now, we will split the data into training and test sets, create a CatBoostClassifier, and fit it on the training set.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20)
model = CatBoostClassifier(iterations=100, depth=6, learning_rate=0.1, loss_function='Logloss')
model.fit(X_train, y_train)
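Note that train_test_split shuffles the data randomly, so the split above (and therefore the reported accuracy) can differ from run to run. A minimal sketch of a reproducible split, assuming we are free to fix random_state and to stratify on the labels (neither is part of the original example):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)

# random_state fixes the shuffle; stratify keeps the class ratio
# roughly the same in the training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)
```

The value 42 is arbitrary; any fixed integer gives a repeatable split.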
Make a prediction
Now, we will use our trained classifier to make a prediction using X_test.
y_pred = model.predict(X_test)
Evaluate the model
Finally, let's evaluate the performance of our classifier.
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: {:.2f}%".format(accuracy * 100))
report = classification_report(y_test, y_pred)
print("Classification Report:\n", report)
Example
The following code shows how we can use the CatBoostClassifier in Python:
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.metrics import accuracy_score, classification_report
from catboost import CatBoostClassifier

# Load the breast cancer dataset
data = load_breast_cancer()

# Extract the features (X) and target (y)
X = data.data
y = data.target

# Splitting the dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20)

# Training the model
model = CatBoostClassifier(iterations=100, depth=6, learning_rate=0.1, loss_function='Logloss', verbose=False)

# Fit the model on the training data
model.fit(X_train, y_train)

# Make predictions on the test data
y_pred = model.predict(X_test)

# Calculate and print the accuracy of the model
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: {:.2f}%".format(accuracy * 100))

# Print the classification report of the model
report = classification_report(y_test, y_pred)
print("Classification Report:\n", report)
Explanation
The code above is explained in detail below:
Lines 1–5: We import the required libraries.
Line 8: We load the breast cancer dataset from sklearn and store it in the data variable.
Lines 11–12: We extract the feature matrix X and the target vector y from the loaded dataset. X contains the input data, and y contains the binary classification labels.
Line 15: The dataset is split into training (X_train and y_train) and testing (X_test and y_test) sets using the train_test_split function. Here, 20% of the data is reserved for testing, and 80% is used for training.
Line 18: An instance of the CatBoostClassifier is created with specified parameters.
Line 21: The model is trained on the training data using the fit method.
Line 24: The trained model is used to make predictions on the test data.
Lines 27–28: We calculate the accuracy of the model’s predictions by comparing them with the true labels in the test set. The accuracy is printed as a percentage.
Lines 31–32: We generate and print the classification report for the model.
Conclusion
To sum up, CatBoost stands out as a powerful machine learning library. With its unique ordered boosting and gradient-based optimization techniques, it is quite good at managing categorical data. Its ability to deliver high performance in both regression and classification tasks across diverse datasets makes it a great tool.