6.2.1.5. Decision Tree Regression#

A decision tree makes predictions by asking a series of yes/no questions about the input features, following a flowchart from the root to a leaf node. At every internal node the tree asks: “Is feature \(j\) less than or equal to threshold \(t\)?” Samples satisfying the test go left, the rest go right, and the process repeats until a leaf is reached. The prediction at that leaf is the mean of all training targets that fell into the same region.

Think of it as partitioning the feature space into axis-aligned rectangles. Each rectangle gets one constant prediction - the average of the training points inside it.

Unlike linear models, a decision tree can capture any non-linear relationship without feature engineering. This power comes at a cost: deep trees memorise the training data perfectly but generalise poorly - the classic variance problem. Controlling tree depth is therefore the central challenge.


The Math#

At each node the algorithm searches over all features \(j\) and all possible thresholds \(t\) to find the split that most reduces the MSE over the training samples in that node:

\[\Delta\text{MSE} = \text{MSE}_{\text{parent}} - \left(\frac{n_L}{n}\,\text{MSE}_L + \frac{n_R}{n}\,\text{MSE}_R\right)\]

where \(n_L\) and \(n_R\) are the numbers of training samples routed to the left and right children and \(n = n_L + n_R\). The prediction at a leaf covering region \(\mathcal{R}_m\) is:

\[\hat{y}_{\text{leaf}} = \frac{1}{|\mathcal{R}_m|}\sum_{i \in \mathcal{R}_m} y_i\]
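The split search behind \(\Delta\text{MSE}\) can be sketched in a few lines of NumPy. This is an illustrative brute-force version for a single feature on a made-up step-shaped dataset - not scikit-learn's optimised implementation, and the function name `best_split` is my own:

```python
import numpy as np

def best_split(x, y):
    """Brute-force search for the threshold that maximises the MSE reduction."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    n = len(y)
    parent_mse = np.var(y)  # MSE of predicting the node mean for everyone
    best_t, best_gain = None, 0.0
    # Candidate thresholds: midpoints between consecutive sorted values
    for i in range(1, n):
        if x[i] == x[i - 1]:
            continue
        t = (x[i] + x[i - 1]) / 2
        yL, yR = y[:i], y[i:]
        child_mse = (i / n) * np.var(yL) + ((n - i) / n) * np.var(yR)
        gain = parent_mse - child_mse  # the Delta-MSE from the formula above
        if gain > best_gain:
            best_t, best_gain = t, gain
    return best_t, best_gain

# Toy data with a step at x = 4: the search should recover a threshold near it
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 200)
y = np.where(x < 4, 1.0, 5.0) + rng.normal(0, 0.3, 200)
t, gain = best_split(x, y)
print(t, gain)
```

A real tree repeats this search over every feature at every node, which is why training cost grows with both depth and feature count.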

Key hyperparameters:

| Hyperparameter | Effect |
|---|---|
| `max_depth` | Maximum levels of splits - primary control over complexity |
| `min_samples_leaf` | Minimum training samples per leaf - higher → smoother model |
| `min_samples_split` | Minimum samples required to attempt any split |
| `max_features` | Number of features considered at each split |
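The smoothing effect of `min_samples_leaf` can be seen directly by counting leaves on a small synthetic 1-D problem (an illustrative sketch; the dataset here is made up):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 10, 150)).reshape(-1, 1)
y = np.sin(X.ravel()) + rng.normal(0, 0.3, 150)

# Higher min_samples_leaf forces wider leaf regions, so fewer (smoother) steps
for leaf in (1, 5, 25):
    m = DecisionTreeRegressor(min_samples_leaf=leaf, random_state=0).fit(X, y)
    print(f"min_samples_leaf={leaf:2d} -> leaves={m.get_n_leaves()}")
```

With `min_samples_leaf=1` and no depth cap the tree grows one leaf per training point; at 25 it can have at most six leaves on 150 samples.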


In scikit-learn#

from sklearn.tree import DecisionTreeRegressor

tree = DecisionTreeRegressor(max_depth=5, min_samples_leaf=5, random_state=42)
tree.fit(X_train, y_train)

No feature scaling is required - trees split on rank-order, not magnitude.
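To see why, note that multiplying a feature by a positive constant only rescales the thresholds - it never changes which side of a split a sample falls on. A quick check on synthetic data (a minimal sketch; the scale factor is arbitrary):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=200, n_features=5, noise=10, random_state=0)

tree_raw = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
tree_scaled = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X * 1000.0, y)

# Predictions agree: rescaling features leaves the learned partition unchanged
same = np.allclose(tree_raw.predict(X), tree_scaled.predict(X * 1000.0))
print(same)
```

Contrast this with distance- or gradient-based models (k-NN, SVMs, linear models with regularisation), where unscaled features can dominate the fit.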


Example#


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from myst_nb import glue
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.datasets import make_regression

np.random.seed(42)

X, y = make_regression(n_samples=300, n_features=10, n_informative=6,
                        noise=25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42)
tree = DecisionTreeRegressor(max_depth=5, random_state=42)
tree.fit(X_train, y_train)

train_r2  = r2_score(y_train, tree.predict(X_train))
test_r2   = r2_score(y_test,  tree.predict(X_test))
test_rmse = np.sqrt(mean_squared_error(y_test, tree.predict(X_test)))

print(f"Train R²  : {train_r2:.3f}")
print(f"Test  R²  : {test_r2:.3f}")
print(f"Test  RMSE: {test_rmse:.1f}")
Train R²  : 0.882
Test  R²  : 0.300
Test  RMSE: 148.7

With max_depth=5 the tree achieves a test \(R^2\) of 0.300 and an RMSE of 148.7. The train \(R^2\) of 0.882 is far higher - a substantial train-test gap that signals overfitting, which is typical of decision trees.

Visualising the Tree Structure#

scikit-learn’s plot_tree renders the full decision flowchart. A max_depth=2 tree is trained here purely so the diagram stays readable - each node shows the split rule, the squared error at that node, the number of samples, and the predicted value.


from sklearn.tree import plot_tree

# Train a shallow tree purely for visualisation
tree_viz = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_viz.fit(X_train, y_train)

feature_names = [f"Feature {i}" for i in range(X.shape[1])]

fig, ax = plt.subplots(figsize=(20, 7))
plot_tree(
    tree_viz,
    feature_names=feature_names,
    filled=True,           # colour nodes by predicted value
    rounded=True,
    fontsize=9,
    ax=ax,
    precision=2,
)
ax.set_title("Decision Tree (max_depth=2) - learned structure",
             fontsize=14, fontweight="bold", pad=12)
plt.tight_layout()
plt.show()

Each node is colour-coded by its mean prediction (darker = higher value). The root split uses the feature-threshold pair with the largest immediate MSE reduction; subsequent splits refine the prediction within each sub-region. Leaf nodes display the constant prediction returned for any sample that reaches them.
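The same structure can also be dumped as plain text with export_text, which is handy when a figure is overkill. A minimal sketch on regenerated synthetic data (the tree here is refit on the full dataset, so its splits may differ slightly from the figure):

```python
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_text

X, y = make_regression(n_samples=300, n_features=10, n_informative=6,
                       noise=25, random_state=42)
tree_viz = DecisionTreeRegressor(max_depth=2, random_state=42).fit(X, y)

# Text rendering of the flowchart: one line per node, indented by depth,
# with "value: [...]" shown at each leaf
rules = export_text(tree_viz,
                    feature_names=[f"Feature {i}" for i in range(X.shape[1])])
print(rules)
```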

The Bias–Variance Trade-off Across Depths#

depths = [1, 2, 3, 5, 8, 12, None]

rows = []
for d in depths:
    m = DecisionTreeRegressor(max_depth=d, random_state=42)
    m.fit(X_train, y_train)
    rows.append({
        "max_depth":   str(d),
        "Train R²":    round(r2_score(y_train, m.predict(X_train)), 3),
        "Test R²":     round(r2_score(y_test,  m.predict(X_test)),  3),
        "Test RMSE":   round(np.sqrt(mean_squared_error(y_test, m.predict(X_test))), 1),
    })

depth_df = pd.DataFrame(rows)
depth_df
| max_depth | Train R² | Test R² | Test RMSE |
|---|---|---|---|
| 1 | 0.261 | 0.102 | 168.6 |
| 2 | 0.447 | 0.090 | 169.7 |
| 3 | 0.626 | 0.347 | 143.7 |
| 5 | 0.882 | 0.300 | 148.7 |
| 8 | 0.988 | 0.445 | 132.4 |
| 12 | 1.000 | 0.464 | 130.2 |
| None | 1.000 | 0.475 | 128.9 |


fig, axes = plt.subplots(1, 2, figsize=(13, 4))

depth_labels = depth_df["max_depth"].tolist()
x = np.arange(len(depth_labels))

axes[0].plot(x, depth_df["Train R²"], "o-", linewidth=2, label="Train R²")
axes[0].plot(x, depth_df["Test R²"],  "s--", linewidth=2, label="Test R²")
axes[0].set_xticks(x)
axes[0].set_xticklabels(depth_labels)
axes[0].set_xlabel("max_depth", fontsize=12)
axes[0].set_ylabel("R²", fontsize=12)
axes[0].set_title("R² vs Tree Depth", fontsize=12, fontweight="bold")
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

axes[1].plot(x, depth_df["Test RMSE"], "o-", linewidth=2, color="tomato")
axes[1].set_xticks(x)
axes[1].set_xticklabels(depth_labels)
axes[1].set_xlabel("max_depth", fontsize=12)
axes[1].set_ylabel("Test RMSE", fontsize=12)
axes[1].set_title("Test RMSE vs Tree Depth", fontsize=12, fontweight="bold")
axes[1].grid(True, alpha=0.3)

plt.suptitle("Decision Tree - Bias–Variance Trade-off", fontsize=13, fontweight="bold")
plt.tight_layout()
plt.show()

As depth increases, training \(R^2\) climbs toward 1.0 - the tree memorises the data - while the gap between train and test scores widens. On many datasets test error first decreases and then rises again; on this particular dataset the test scores keep improving modestly even at full depth, but the widening train-test gap from depth 8 onward is the overfitting signature to watch for. The practical lesson is to choose depth from held-out performance rather than the training fit.
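Rather than eyeballing the table, the depth can be selected by cross-validation. A minimal sketch with GridSearchCV over the same depth grid (the selected depth will depend on the data and folds):

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=300, n_features=10, n_informative=6,
                       noise=25, random_state=42)

# 5-fold CV over the candidate depths; refits the best model on all data
grid = GridSearchCV(
    DecisionTreeRegressor(random_state=42),
    param_grid={"max_depth": [1, 2, 3, 5, 8, 12, None]},
    cv=5,
    scoring="r2",
)
grid.fit(X, y)
print(grid.best_params_, round(grid.best_score_, 3))
```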

Visualising Splits on 1-D Data#


np.random.seed(0)
X_1d = np.sort(np.random.uniform(0, 10, 100)).reshape(-1, 1)
y_1d = 3 * np.sin(X_1d.ravel()) + np.random.normal(0, 0.5, 100)
X_1d_tr, X_1d_te, y_1d_tr, y_1d_te = train_test_split(
    X_1d, y_1d, test_size=0.3, random_state=0)

chosen_depths = [1, 3, 6, 12]
Xp = np.linspace(0, 10, 500).reshape(-1, 1)

fig, axes = plt.subplots(1, 4, figsize=(16, 4), sharey=True)
for ax, d in zip(axes, chosen_depths):
    m = DecisionTreeRegressor(max_depth=d, random_state=0)
    m.fit(X_1d_tr, y_1d_tr)
    tr_r2 = m.score(X_1d_tr, y_1d_tr)
    te_r2 = m.score(X_1d_te, y_1d_te)
    ax.scatter(X_1d_tr, y_1d_tr, s=18, alpha=0.5, label='Train')
    ax.plot(Xp, m.predict(Xp), 'r-', linewidth=2)
    ax.set_title(f'depth={d}\nTrain={tr_r2:.2f} / Test={te_r2:.2f}',
                 fontsize=10, fontweight='bold')
    ax.grid(True, alpha=0.3)

axes[0].set_ylabel('y', fontsize=11)
plt.suptitle('Decision Tree - Depth vs Fit Complexity',
             fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

The step-function predictions are clearly visible: shallow trees produce coarse approximations (high bias), while deep trees follow every wiggle in the training data (high variance).
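An alternative to a hard depth cap is cost-complexity pruning: grow a full tree and prune it back with the ccp_alpha parameter. A minimal sketch using the pruning path scikit-learn computes (the subsampling of alphas with `[::10]` is arbitrary, purely to keep the printout short):

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=300, n_features=10, n_informative=6,
                       noise=25, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=42)

# Effective alphas at which successive subtrees get pruned away
path = DecisionTreeRegressor(random_state=42).cost_complexity_pruning_path(X_tr, y_tr)

# Refit with increasing ccp_alpha: larger alpha -> smaller tree
for alpha in path.ccp_alphas[::10]:
    m = DecisionTreeRegressor(random_state=42, ccp_alpha=alpha).fit(X_tr, y_tr)
    print(f"alpha={alpha:10.2f}  leaves={m.get_n_leaves():3d}  "
          f"test R²={m.score(X_te, y_te):.3f}")
```

Pruning and `max_depth` address the same variance problem from opposite directions: one stops growth early, the other grows fully and cuts back.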

Feature Importance#


imp = pd.Series(tree.feature_importances_,
                index=[f'Feature {i}' for i in range(X.shape[1])])
imp = imp.sort_values(ascending=False)

plt.figure(figsize=(9, 4))
imp.plot(kind='bar', edgecolor='black', alpha=0.85, color='steelblue')
plt.xticks(rotation=45, ha='right')
plt.ylabel('Impurity Decrease (importance)', fontsize=11)
plt.title('Decision Tree - Feature Importances', fontsize=13, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
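Impurity-based importances like the ones plotted above are computed on the training data and can overstate features the tree has overfit to. Permutation importance on held-out data is a useful cross-check - it measures how much test \(R^2\) drops when each feature is shuffled (a minimal sketch, refitting the same model configuration):

```python
from sklearn.datasets import make_regression
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=300, n_features=10, n_informative=6,
                       noise=25, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=42)
tree = DecisionTreeRegressor(max_depth=5, random_state=42).fit(X_tr, y_tr)

# Shuffle each feature on the test set 20 times; report the mean R² drop
result = permutation_importance(tree, X_te, y_te, n_repeats=20, random_state=42)
for i in result.importances_mean.argsort()[::-1][:3]:
    print(f"Feature {i}: {result.importances_mean[i]:.3f} "
          f"± {result.importances_std[i]:.3f}")
```

Features that rank highly under both measures are the ones the model genuinely relies on.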

Strengths and Weaknesses#

**Strengths:** no feature scaling needed; captures non-linearity without feature engineering; interpretable (the tree can be visualised); handles mixed feature types.

**Weaknesses:** high variance - sensitive to small changes in the training data; prone to overfitting at large depths.

**Tip:** single decision trees are rarely competitive on their own. They are the core building block of Random Forests (bagging) and Gradient Boosting (boosting) - ensemble methods that correct their variance problem.