Decision Trees are supervised machine learning algorithms used for both classification and regression. In scikit-learn, the main classes are DecisionTreeClassifier and DecisionTreeRegressor. A tree works by recursively splitting the data into smaller regions using feature-based rules until it reaches a prediction at a leaf node.
A Decision Tree learns a sequence of if-then rules from data.
Example:
- age < 30 → go left
- salary > 5000 → go right
- experience < 2 → go left again

At the end of the path, the model reaches a leaf and makes a prediction.
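The learned rules behave like nested if-statements. Here is a minimal hand-written sketch of the path above; the thresholds come from the example, but the class labels are hypothetical:

```python
# Hypothetical hand-coded version of the path above; a fitted tree
# learns thresholds like these from data instead of having them written by hand.
def predict(age, salary, experience):
    if age < 30:                      # go left
        if salary > 5000:             # go right
            if experience < 2:        # go left again
                return "class A"      # leaf reached: the prediction
            return "class B"
        return "class B"
    return "class C"

print(predict(age=25, salary=6000, experience=1))  # class A
```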
So a tree is built from:

- internal decision nodes, each testing one feature against a threshold
- leaf nodes, each holding a prediction
This makes Decision Trees one of the easiest machine learning models to understand visually. Scikit-learn also provides plot_tree to visualize a fitted tree.
Decision Trees are popular because they are:

- easy to interpret and visualize
- able to capture nonlinear relationships
- usable without feature scaling
They are also a foundation for more advanced ensemble models such as Random Forests and Gradient Boosting. Scikit-learn describes Random Forests as ensembles of decision trees built on subsamples to improve predictive accuracy and control overfitting.
At each node, the tree searches for the split that best separates the data.
For classification, DecisionTreeClassifier supports criteria such as:
- gini
- entropy
- log_loss

For regression, DecisionTreeRegressor supports criteria such as:

- squared_error
- friedman_mse
- absolute_error
- poisson

A good split is one that makes the child nodes more “pure”.
Example: a split that sends almost all samples of class A to the left child and almost all samples of class B to the right child produces two nearly pure children. That is better than a split where both branches remain mixed.
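To make “purity” concrete, here is a small sketch that computes Gini impurity by hand for a parent node and a candidate split; the toy labels are made up for illustration:

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

# A perfectly mixed parent: 5 samples of class 0, 5 of class 1
parent = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
print(gini(parent))  # 0.5, the maximum for two classes

# A candidate split that mostly separates the classes
left = np.array([0, 0, 0, 0, 1])
right = np.array([1, 1, 1, 1, 0])
weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(parent)
print(round(weighted, 4))  # 0.32 — lower than the parent, so the split helps
```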
A good split reduces prediction error inside each child node.
Example: splitting houses at 1500 square feet, so that each child node contains homes with similar prices, lowers the variance within each group.
In classification, each leaf predicts a class. Scikit-learn’s tree guide notes that DecisionTreeClassifier supports binary and multiclass classification, and it can also output class probabilities with predict_proba, where the probability is based on the fraction of training samples of each class in the leaf.
Suppose the tree is predicting whether a student passes: if hours_studied > 3, the leaf predicts pass.
Otherwise: the leaf predicts fail.
This is why trees are very interpretable.
In regression, each leaf predicts a number, typically the mean of the training targets that reached that leaf.
Example: if area > 1200, the leaf predicts the average price of the larger homes it contains.
Otherwise: the leaf predicts the average price of the smaller homes.
Unlike linear regression, a decision tree regressor does not fit one global equation. It splits the feature space into regions and predicts a value in each region. Scikit-learn’s regression example shows that tree regression can approximate a nonlinear sine curve, but can overfit noise when max_depth is too high.
Unlike SVM and KNN, Decision Trees do not depend on geometric distance. They split on thresholds like:

- feature <= value
- feature > value

Because of that, feature scaling is usually not necessary for tree models. This follows from how scikit-learn decision trees operate: they choose threshold-based splits feature by feature rather than optimizing a distance-based objective.
So this is one major practical advantage of trees.
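As a quick sanity check, this sketch fits one tree on raw features and one on standardized features of the Breast Cancer dataset used below; because splits are threshold-based, the test accuracies should come out essentially identical:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Tree on raw features
raw = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Tree on standardized features: scaling shifts thresholds, not the splits
scaler = StandardScaler().fit(X_train)
scaled = DecisionTreeClassifier(random_state=42).fit(
    scaler.transform(X_train), y_train
)

print("Raw:   ", raw.score(X_test, y_test))
print("Scaled:", scaled.score(scaler.transform(X_test), y_test))
```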
pip install numpy pandas matplotlib scikit-learn

We will use the Breast Cancer dataset from scikit-learn.
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# Load dataset
data = load_breast_cancer()
X = data.data
y = data.target
# Split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Build model
model = DecisionTreeClassifier(random_state=42)
# Train
model.fit(X_train, y_train)
# Predict
y_pred = model.predict(X_test)
# Evaluate
print("Accuracy:", accuracy_score(y_test, y_pred))
print("\nConfusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred))

DecisionTreeClassifier uses criterion, splitter, max_depth, and other parameters to control how the tree is built.
A Decision Tree can also estimate probabilities.
proba = model.predict_proba(X_test[:5])
print(proba)

This returns the class probabilities for the first five samples. In scikit-learn, these probabilities are based on the class proportions in the leaf reached by each sample.
Scikit-learn provides plot_tree to visualize a fitted decision tree. It can show node labels, class names, feature names, impurity, and sample counts.
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize=(20, 10))
plot_tree(
model,
filled=True,
feature_names=data.feature_names,
class_names=data.target_names,
rounded=True,
fontsize=8
)
plt.show()

Each node typically shows:

- the split condition (feature and threshold)
- the impurity value (e.g. gini)
- the number of samples reaching the node
- the class counts (value) and the majority class

This is one of the best parts of using Decision Trees: you can inspect the learned logic directly.
Very large trees become hard to read.
plt.figure(figsize=(18, 8))
plot_tree(
model,
max_depth=3,
filled=True,
feature_names=data.feature_names,
class_names=data.target_names,
rounded=True,
fontsize=9
)
plt.show()

The plot_tree function supports a max_depth argument, which is useful for visualization even if the underlying model is deeper.
DecisionTreeClassifier

From the scikit-learn API, important parameters include: criterion, splitter, max_depth, min_samples_split, min_samples_leaf, max_features, random_state, and ccp_alpha.

criterion: measures split quality. Common choices: gini, entropy, log_loss.

max_depth: the maximum depth of the tree.

min_samples_split: the minimum number of samples required to split a node.

min_samples_leaf: the minimum number of samples required in a leaf.

splitter: can be "best" or "random".

ccp_alpha: controls minimal cost-complexity pruning. Scikit-learn added pruning controlled by ccp_alpha for tree estimators; increasing it prunes the tree more aggressively.
Decision Trees are flexible. That is useful, but it also means they can memorize training data.
Signs of overfitting:

- training accuracy near 100% while test accuracy is much lower
- a very deep tree with many tiny leaves
Scikit-learn’s tree regression example explicitly notes that high max_depth can make the model learn the noise and overfit.
You can regularize a tree by limiting its growth:
- max_depth
- min_samples_split
- min_samples_leaf
- max_leaf_nodes
- ccp_alpha

Example:
model = DecisionTreeClassifier(
max_depth=4,
min_samples_split=10,
min_samples_leaf=5,
random_state=42
)
This usually generalizes better than an unconstrained tree.
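Here is a quick sketch comparing an unconstrained tree with the constrained settings above on the Breast Cancer data; the train/test gap is the tell-tale overfitting signal:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Unconstrained: grows until leaves are pure, memorizing the training set
full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Constrained: the regularized settings from the example above
limited = DecisionTreeClassifier(
    max_depth=4, min_samples_split=10, min_samples_leaf=5, random_state=42
).fit(X_train, y_train)

print("Unconstrained train/test:",
      full.score(X_train, y_train), full.score(X_test, y_test))
print("Constrained   train/test:",
      limited.score(X_train, y_train), limited.score(X_test, y_test))
```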
GridSearchCV is the standard scikit-learn tool for exhaustive hyperparameter search with cross-validation.
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
param_grid = {
"criterion": ["gini", "entropy", "log_loss"],
"max_depth": [3, 5, 7, 10, None],
"min_samples_split": [2, 5, 10, 20],
"min_samples_leaf": [1, 2, 4, 8]
}
grid = GridSearchCV(
estimator=DecisionTreeClassifier(random_state=42),
param_grid=param_grid,
cv=5,
scoring="accuracy",
n_jobs=-1
)
grid.fit(X_train, y_train)
print("Best Parameters:", grid.best_params_)
print("Best CV Score:", grid.best_score_)
best_model = grid.best_estimator_
y_pred = best_model.predict(X_test)
print("Test Accuracy:", accuracy_score(y_test, y_pred))
Instead of guessing the best depth or leaf size, you let cross-validation choose a stronger configuration.
ccp_alpha

Cost-complexity pruning is one of the most important practical tools for trees in scikit-learn. The pruning control parameter is ccp_alpha. Larger values prune more nodes.
Example:
pruned_model = DecisionTreeClassifier(
random_state=42,
ccp_alpha=0.01
)
pruned_model.fit(X_train, y_train)
y_pred = pruned_model.predict(X_test)
print("Pruned Accuracy:", accuracy_score(y_test, y_pred))

Pruning removes branches that add complexity but do not improve generalization enough.
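Rather than guessing a value like 0.01, scikit-learn's cost_complexity_pruning_path can enumerate the candidate alphas for a dataset. A minimal sketch follows; it selects on the test set for brevity, but in real work you would use cross-validation:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Effective alphas at which nodes would be pruned
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)

best_alpha, best_score = 0.0, 0.0
for alpha in path.ccp_alphas:
    # max(..., 0.0) guards against tiny negative alphas from floating-point error
    tree = DecisionTreeClassifier(random_state=42, ccp_alpha=max(alpha, 0.0))
    tree.fit(X_train, y_train)
    score = tree.score(X_test, y_test)  # use cross-validation in real work
    if score > best_score:
        best_alpha, best_score = alpha, score

print(f"best ccp_alpha={best_alpha:.4f}, test accuracy={best_score:.4f}")
```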
We will use make_moons to see a nonlinear decision boundary.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Create nonlinear dataset
X, y = make_moons(n_samples=300, noise=0.25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
model = DecisionTreeClassifier(max_depth=5, random_state=42)
model.fit(X_train, y_train)
print("Accuracy:", model.score(X_test, y_test))

def plot_decision_boundary(model, X, y):
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 400),
        np.linspace(y_min, y_max, 400)
    )
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, alpha=0.3)
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors="k")
    plt.title("Decision Tree Boundary")
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.show()

plot_decision_boundary(model, X, y)

Decision Tree boundaries are made of axis-aligned splits, so they often look like rectangular step-like regions rather than smooth curves.

for depth in [1, 3, 5, 10, None]:
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    print(f"max_depth={depth}, accuracy={model.score(X_test, y_test):.4f}")

This matches scikit-learn’s warning from the tree regression example that too much depth can fit noise rather than signal.
DecisionTreeRegressor

Now let us use a Decision Tree for regression.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# Synthetic data
rng = np.random.RandomState(42)
X = np.sort(5 * rng.rand(200, 1), axis=0)
y = np.sin(X).ravel()
# Add noise
y[::5] += 0.5 - rng.rand(40)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
model = DecisionTreeRegressor(max_depth=4, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print("MSE:", mean_squared_error(y_test, y_pred))
print("R²:", r2_score(y_test, y_pred))

X_plot = np.linspace(X.min(), X.max(), 500).reshape(-1, 1)
y_plot = model.predict(X_plot)
plt.figure(figsize=(8, 6))
plt.scatter(X, y, label="Data")
plt.plot(X_plot, y_plot, linewidth=2, label="Tree prediction")
plt.xlabel("X")
plt.ylabel("y")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

Scikit-learn’s official tree regression example uses a similar sine-style dataset and shows how tree depth changes the fit quality.
DecisionTreeRegressor shares many structure-control parameters with the classifier version, including:
- max_depth
- min_samples_split
- min_samples_leaf
- max_features
- ccp_alpha

Its split criterion options include:

- squared_error
- friedman_mse
- absolute_error
- poisson

Example:
model = DecisionTreeRegressor(
criterion="squared_error",
max_depth=5,
min_samples_leaf=4,
random_state=42
)
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Load your own dataset
df = pd.read_csv("your_data.csv")
X = df.drop("target", axis=1)
y = df["target"]

# Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Model and search space
model = DecisionTreeClassifier(random_state=42)
param_grid = {
    "criterion": ["gini", "entropy", "log_loss"],
    "max_depth": [3, 5, 7, 10, None],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4]
}

grid = GridSearchCV(
    model,
    param_grid=param_grid,
    cv=5,
    scoring="accuracy",
    n_jobs=-1
)
grid.fit(X_train, y_train)

# Evaluate
best_model = grid.best_estimator_
y_pred = best_model.predict(X_test)
print("Best Params:", grid.best_params_)
print("Accuracy:", accuracy_score(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# Predict on new samples
new_samples = X_test.iloc[:5]
predictions = best_model.predict(new_samples)
print(predictions)

Scikit-learn’s example on understanding tree structure explains that a fitted tree has a tree_ attribute containing low-level information like node_count and max_depth.
Example:
print("Node count:", model.tree_.node_count)
print("Max depth:", model.tree_.max_depth)

This can be useful for diagnosing model complexity.
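Going one level deeper, the tree_ arrays can be walked node by node, as in scikit-learn's tree-structure example. A small sketch on Iris (any fitted tree works the same way):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)

t = model.tree_
for node in range(t.node_count):
    # Leaves have no children: both child pointers are -1
    if t.children_left[node] == t.children_right[node]:
        print(f"node {node}: leaf, value={t.value[node].ravel()}")
    else:
        print(f"node {node}: split on feature {t.feature[node]} "
              f"at threshold {t.threshold[node]:.3f}")
```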
Decision Trees in scikit-learn expose feature_importances_, which summarizes the relative importance of each feature in the fitted tree.
import pandas as pd
importance = pd.Series(model.feature_importances_, index=data.feature_names)
print(importance.sort_values(ascending=False))

This is often useful, but interpret it carefully: impurity-based importance can be biased, for example toward high-cardinality and continuous features, so it is best used as an exploratory signal rather than absolute truth.
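A complementary check is scikit-learn's permutation_importance, which shuffles one feature at a time and measures the resulting score drop; because it is computed on held-out data, it is less affected by the structural biases of impurity-based importance. A minimal sketch on the same Breast Cancer data:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target
)

model = DecisionTreeClassifier(max_depth=5, random_state=42).fit(X_train, y_train)

# Shuffle each feature 10 times on the test set and average the score drop
result = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42
)

# Top five features by mean importance
for idx in result.importances_mean.argsort()[::-1][:5]:
    print(f"{data.feature_names[idx]}: {result.importances_mean[idx]:.4f}")
```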
For classification, common metrics are:
from sklearn.metrics import confusion_matrix, classification_report
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
Decision trees support multiclass classification as well as binary classification, so these metrics work naturally in many common settings.
For regression, common metrics are:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np
mae = mean_absolute_error(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
r2 = r2_score(y_test, y_pred)
print("MAE:", mae)
print("MSE:", mse)
print("RMSE:", rmse)
print("R²:", r2)

Decision Trees are strong because they are:

- easy to interpret and explain
- able to model nonlinear relationships
- usable without feature scaling or heavy preprocessing
They also form the basis of important ensemble methods such as Random Forests and Gradient Boosting.
Their main weaknesses are:

- a strong tendency to overfit when unconstrained
- instability: small changes in the training data can produce a very different tree
- axis-aligned boundaries that approximate smooth curves only in steps
This is one reason scikit-learn highlights ensemble tree methods alongside single trees.
Bad:
model = DecisionTreeClassifier(random_state=42)

This can work, but it may overfit badly.
Better:
model = DecisionTreeClassifier(
max_depth=5,
min_samples_leaf=4,
random_state=42
)

Many users tune depth but forget ccp_alpha. Minimal cost-complexity pruning is built into scikit-learn trees and can be very effective for reducing overfitting.
Feature importance in a single tree can be useful, but it is not the same thing as causal importance. Treat it as a model-based summary, not proof of real-world causation.
Use a Decision Tree when:

- you need a model that is easy to explain to non-specialists
- your features are on very different scales or of mixed types
- the relationship between features and target is nonlinear
Be careful with a single tree when:

- the dataset is small and noisy, so one tree can easily overfit
- you need the best possible predictive accuracy
- predictions must stay stable under small changes in the training data
In those cases, ensemble models like Random Forests or Gradient Boosting are often stronger follow-up options.
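To get a feel for the difference, here is a quick sketch comparing a single tree with a RandomForestClassifier on the Breast Cancer data; exact scores will vary with the split:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# One unconstrained tree vs. an ensemble of 200 trees on subsamples
tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
forest = RandomForestClassifier(n_estimators=200, random_state=42).fit(
    X_train, y_train
)

print("Single tree:  ", tree.score(X_test, y_test))
print("Random forest:", forest.score(X_test, y_test))
```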
For classification:
DecisionTreeClassifier(
criterion="gini",
max_depth=5,
min_samples_leaf=4,
random_state=42
)

For regression:
DecisionTreeRegressor(
criterion="squared_error",
max_depth=5,
min_samples_leaf=4,
random_state=42
)These are not universal best settings, but they are sensible starting points to reduce overfitting compared with a completely unconstrained tree. The available criteria and structure-control parameters are documented in the scikit-learn API.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
# Load
data = load_iris()
X, y = data.data, data.target
# Split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=42
)
# Model
model = DecisionTreeClassifier(random_state=42)
# Search space
param_grid = {
"criterion": ["gini", "entropy", "log_loss"],
"max_depth": [2, 3, 4, 5, None],
"min_samples_split": [2, 5, 10],
"min_samples_leaf": [1, 2, 4]
}
# Grid search
grid = GridSearchCV(
model,
param_grid=param_grid,
cv=5,
scoring="accuracy",
n_jobs=-1
)
grid.fit(X_train, y_train)
# Evaluate
best_model = grid.best_estimator_
y_pred = best_model.predict(X_test)
print("Best Parameters:", grid.best_params_)
print("Accuracy:", accuracy_score(y_test, y_pred))
print("\nConfusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred))

Decision Trees are one of the easiest machine learning models to understand.
The core idea is:

- split the data with simple feature-threshold rules
- repeat recursively until the regions are pure enough or a size limit is reached
- predict with the leaf that each new sample lands in
For classification, the tree predicts a class or class probabilities. For regression, it predicts a numeric value. Scikit-learn supports both through DecisionTreeClassifier and DecisionTreeRegressor.
The most important practical rules are:
- limit growth with max_depth, min_samples_leaf, and min_samples_split
- use ccp_alpha for pruning

These recommendations align with the current scikit-learn documentation for decision trees and tree visualization.
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
# X, y = your data
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
model = DecisionTreeClassifier(random_state=42)
param_grid = {
"criterion": ["gini", "entropy", "log_loss"],
"max_depth": [3, 5, 7, None],
"min_samples_split": [2, 5, 10],
"min_samples_leaf": [1, 2, 4]
}
grid = GridSearchCV(
model,
param_grid=param_grid,
cv=5,
scoring="accuracy",
n_jobs=-1
)
grid.fit(X_train, y_train)
print("Best params:", grid.best_params_)
print("Test score:", grid.best_estimator_.score(X_test, y_test))
y_pred = grid.best_estimator_.predict(X_test)
print(classification_report(y_test, y_pred))

1. Train a DecisionTreeClassifier on the Iris dataset and report accuracy.
2. Train a Decision Tree on make_moons and visualize the decision boundary.
3. Tune max_depth, min_samples_split, and min_samples_leaf using GridSearchCV.
4. Compare an unconstrained tree with a pruned or depth-limited tree.
5. Train a DecisionTreeRegressor on a nonlinear synthetic regression dataset.