Classification tree#

Iris dataset#

This is an example from the scikit-learn documentation.

Download and visualize the Iris dataset:

import seaborn as sns
iris = sns.load_dataset("iris")
sns.pairplot(iris, hue="species");
[figure: pairplot of the Iris features colored by species]

Fit decision tree classifier:

from sklearn import tree

y = iris['species']
X = iris.drop("species", axis=1)
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
clf.score(X, y)
1.0
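A training score of 1.0 only says that the unrestricted tree memorized the training set. As a quick check (not part of the original example), cross-validation gives a fairer estimate of generalization:

from sklearn.model_selection import cross_val_score

# accuracy estimated on held-out folds instead of on the training data
cross_val_score(tree.DecisionTreeClassifier(random_state=0), X, y, cv=5).mean()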

Plot the tree:

tree.plot_tree(clf, filled=True);
[figure: plot of the fitted decision tree]

A prettier tree can be drawn with graphviz:

import graphviz

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.columns[:-1],
                                class_names=['setosa', 'versicolor', 'virginica'],
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph
[figure: graphviz rendering of the decision tree]

A depth of \(2\) is enough for this toy dataset:

clf = tree.DecisionTreeClassifier(max_depth=2)
clf = clf.fit(X, y)
clf.score(X, y)
0.96
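The rendering below is presumably the same graphviz export as above applied to this shallower tree; for completeness, a sketch of that call:

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.columns[:-1],
                                class_names=['setosa', 'versicolor', 'virginica'],
                                filled=True, rounded=True,
                                special_characters=True)
graphviz.Source(dot_data)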
[figure: graphviz rendering of the depth-2 decision tree]

MNIST#

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

%config InlineBackend.figure_format = 'svg'

X, Y = fetch_openml('mnist_784', return_X_y=True, parser='auto')

X = X.astype(float).values / 255
Y = Y.astype(int).values

Visualize the data:
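The plotting code is not included here; a minimal sketch (assuming X and Y as loaded above) that shows the first ten digits:

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for ax, img, label in zip(axes.ravel(), X[:10], Y[:10]):
    ax.imshow(img.reshape(28, 28), cmap="gray")  # each row of X is a flattened 28x28 image
    ax.set_title(label)
    ax.axis("off")
plt.show()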

[figure: sample MNIST digit images]

Split into train and test:

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=10000)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((60000, 784), (10000, 784), (60000,), (10000,))

Fit a decision tree model:

from sklearn.tree import DecisionTreeClassifier

DT = DecisionTreeClassifier()
DT.fit(X_train, y_train)
DecisionTreeClassifier()

Accuracy score:

print("Train accuracy:", DT.score(X_train, y_train))
print("Test accuracy:", DT.score(X_test, y_test))
Train accuracy: 1.0
Test accuracy: 0.8735
plt.figure(figsize=(10, 8))
plt.title("Decision tree on MNIST")
sns.heatmap(confusion_matrix(y_test, DT.predict(X_test)), annot=True);
[figure: confusion matrix heatmap for the unrestricted tree]

Limit the tree depth and the minimum number of samples per leaf:

DT = DecisionTreeClassifier(max_depth=15, min_samples_leaf=3)
DT.fit(X_train, y_train)
DecisionTreeClassifier(max_depth=15, min_samples_leaf=3)
print("Train accuracy:", DT.score(X_train, y_train))
print("Test accuracy:", DT.score(X_test, y_test))
Train accuracy: 0.9615166666666667
Test accuracy: 0.8776
[figure: confusion matrix heatmap for the restricted tree]
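The values max_depth=15 and min_samples_leaf=3 look hand-picked; a small grid search (a sketch, not part of the original notebook) is one way to choose them. Searching on a subsample keeps it fast on MNIST:

from sklearn.model_selection import GridSearchCV

param_grid = {"max_depth": [10, 15, 20], "min_samples_leaf": [1, 3, 10]}
search = GridSearchCV(DecisionTreeClassifier(), param_grid, cv=3, n_jobs=-1)
search.fit(X_train[:10000], y_train[:10000])  # subsample of the training set
search.best_params_, search.best_score_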

Splitting conditions#

Each non-terminal node contains a splitting condition that determines whether a sample goes to the left or to the right subtree. The splitting condition usually consists of comparing the value of some feature \(x_j\) with a threshold \(t\):

\[ \mathbb I[x_j \leqslant t], \quad 1\leqslant j \leqslant d. \]

According to the splitting condition, the training sample \(X\) is split into two subsamples \(X_l\) and \(X_r\) such that \(X = X_l \cup X_r\).
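In code, such a split is just a boolean mask on one feature. A minimal NumPy sketch (the feature index j and threshold t are arbitrary here):

import numpy as np

X = np.array([[5.1, 3.5], [4.9, 3.0], [6.2, 3.4], [5.9, 3.0]])
j, t = 0, 5.5                  # feature index and threshold, chosen for illustration

mask = X[:, j] <= t            # the indicator I[x_j <= t] for every sample
X_l, X_r = X[mask], X[~mask]   # left and right subsamples, X = X_l ∪ X_r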

ChatGPT suggestions#

  1. Node Splitting

At each internal node, the tree algorithm selects a feature and a splitting criterion to divide the data into two or more child nodes. The goal is to create splits that maximize the purity or homogeneity of the class labels within each node.

  2. Leaf Nodes

The leaf nodes are the terminal nodes of the tree. Each leaf node contains a predicted class label, representing the majority class of the training samples in that node.

  3. Predictive Modeling

To make predictions for new data, you traverse the tree from the root to a leaf node based on the feature values of the new data point. The class label in the selected leaf node is the predicted class for that data point (a manual traversal of a fitted scikit-learn tree is sketched after this list).

  4. Recursive Partitioning

The process of building a classification tree is recursive. The algorithm starts with the entire dataset and recursively splits it into subsets by choosing the best feature and split criterion at each node, continuing until a stopping condition is met.

  5. Stopping Criteria

Stopping criteria are used to determine when to stop growing the tree. Common stopping criteria include limiting the tree depth, setting a minimum number of samples per leaf, or using a minimum impurity reduction threshold.

  6. Impurity Measures

In classification trees, impurity measures such as Gini impurity, entropy, or the misclassification rate are used to evaluate how much a split improves the purity (homogeneity) of the class labels. The split that minimizes the weighted impurity of the resulting child nodes is selected (a brute-force threshold search is sketched after this list).

  7. Pruning

After building a classification tree, it may be pruned to reduce overfitting. Pruning involves removing nodes that do not significantly improve the tree’s performance on a validation dataset (a cost-complexity pruning sketch follows the list).

  8. Visualization

Classification trees can be visualized graphically, making it easy to interpret and understand the model’s decision-making process.

  9. Ensemble Methods

Classification trees are often used as building blocks in ensemble methods like Random Forests and Gradient Boosting, which combine multiple trees to improve predictive accuracy and reduce overfitting (a minimal random forest example is sketched after this list).

  10. Advantages

Classification trees are interpretable, and their decision-making process is easy to understand. They can capture complex decision boundaries and interactions between features.

  11. Limitations

They can be prone to overfitting, especially if the tree is allowed to grow deep. Single trees may not generalize well on certain types of data. Ensembling methods can mitigate these limitations.
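To make the splitting and impurity items above concrete, here is a minimal sketch (my own illustration, not from the original notes) of Gini impurity, entropy, and a brute-force search for the best threshold on a single feature:

import numpy as np

def gini(y):
    # Gini impurity of a label array
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def best_threshold(x, y, impurity=gini):
    # exhaustively search thresholds on one feature, minimizing the
    # weighted impurity of the two child nodes
    best_t, best_score = None, np.inf
    for t in np.unique(x)[:-1]:          # candidate thresholds
        left, right = y[x <= t], y[x > t]
        score = (len(left) * impurity(left) + len(right) * impurity(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# toy example: one feature, two classes
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0, 0, 0, 1, 1, 1])
best_threshold(x, y)   # the threshold 3.0 separates the classes perfectly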
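Prediction by traversal can be reproduced by hand with the arrays scikit-learn stores in clf.tree_ (children_left, children_right, feature, threshold, value); a sketch on the Iris data:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
clf = DecisionTreeClassifier(max_depth=2).fit(data.data, data.target)
tr = clf.tree_

def predict_one(x):
    # walk from the root to a leaf, then return the majority class of that leaf
    node = 0
    while tr.children_left[node] != -1:               # -1 marks a leaf
        if x[tr.feature[node]] <= tr.threshold[node]:
            node = tr.children_left[node]
        else:
            node = tr.children_right[node]
    return np.argmax(tr.value[node])

x = data.data[0]
predict_one(x), clf.predict([x])[0]   # the two predictions should agree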
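scikit-learn implements pruning as cost-complexity pruning via the ccp_alpha parameter; a sketch of choosing alpha on a held-out validation set:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

# grow a full tree, then prune it back with increasing cost-complexity alpha
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=0).fit(X_train, y_train)
    print(f"alpha={alpha:.4f}  leaves={pruned.get_n_leaves()}  "
          f"val accuracy={pruned.score(X_val, y_val):.3f}")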
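And a minimal ensemble example: a random forest, which averages many randomized trees and usually outperforms a single tree:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# 100 randomized trees, evaluated with 5-fold cross-validation
rf = RandomForestClassifier(n_estimators=100, random_state=0)
cross_val_score(rf, X, y, cv=5).mean()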

Classification trees are widely used in various domains, including healthcare, finance, and natural language processing, for tasks such as spam email detection, disease diagnosis, and sentiment analysis. Proper tuning of hyperparameters and consideration of potential overfitting are essential when working with classification trees.