Scikit-learn Visualization Guide: Making Models Speak

Use the Display API to replace complex Matplotlib code


Introduction

In the journey of machine learning, explaining models with visualization is as important as training them.

A good chart can show us what a model is doing in an easy-to-understand way. Here's an example:

Decision boundaries of two different generalization performances. Image by Author

This graph makes it clear that for the same dataset, the model on the right is better at generalizing.

Most machine learning books prefer to use raw Matplotlib code for visualization, which leads to issues:

  1. You have to learn a lot about drawing with Matplotlib.
  2. Plotting code fills up your notebook, making it hard to read.
  3. Sometimes you need third-party libraries, which isn't ideal in business settings.

Good news! Scikit-learn now offers Display classes that let us use methods like from_estimator and from_predictions to make drawing graphs for different situations much easier.

Curious? Let me show you these cool APIs.


Scikit-learn Display API Introduction

Use utils.discovery.all_displays to find available APIs

Scikit-learn (sklearn) keeps adding Display APIs in new releases, so it's worth knowing which ones are available in your version.

Sklearn's utils.discovery.all_displays lets you see which classes you can use.

from sklearn.utils.discovery import all_displays

displays = all_displays()
displays

For example, in my Scikit-learn 1.4.0, these classes are available:

[('CalibrationDisplay', sklearn.calibration.CalibrationDisplay),
 ('ConfusionMatrixDisplay',
  sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay),
 ('DecisionBoundaryDisplay',
  sklearn.inspection._plot.decision_boundary.DecisionBoundaryDisplay),
 ('DetCurveDisplay', sklearn.metrics._plot.det_curve.DetCurveDisplay),
 ('LearningCurveDisplay', sklearn.model_selection._plot.LearningCurveDisplay),
 ('PartialDependenceDisplay',
  sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay),
 ('PrecisionRecallDisplay',
  sklearn.metrics._plot.precision_recall_curve.PrecisionRecallDisplay),
 ('PredictionErrorDisplay',
  sklearn.metrics._plot.regression.PredictionErrorDisplay),
 ('RocCurveDisplay', sklearn.metrics._plot.roc_curve.RocCurveDisplay),
 ('ValidationCurveDisplay',
  sklearn.model_selection._plot.ValidationCurveDisplay)]

Using inspection.DecisionBoundaryDisplay for decision boundaries

Since we mentioned it, let's start with decision boundaries.

If you use raw Matplotlib to draw them, it's a hassle (see the sketch after this list):

  • Use np.linspace to set the coordinate ranges;
  • Use np.meshgrid to calculate the grid;
  • Use plt.contourf to draw the decision boundary fill;
  • Then use plt.scatter to plot the data points.
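For reference, the manual Matplotlib version looks roughly like this. It's only a sketch: it assumes you already have a fitted two-feature classifier clf and NumPy arrays X (shape (n_samples, 2)) and y, which are hypothetical names here.

import numpy as np
import matplotlib.pyplot as plt

# Assumes clf is a fitted two-feature classifier and X, y are NumPy arrays.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                     np.linspace(y_min, y_max, 200))

# Predict on every grid point, then reshape back to the grid.
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='w')
plt.show()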

Now, with inspection.DecisionBoundaryDisplay, you can simplify this process:

from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

iris = load_iris(as_frame=True)
X = iris.data[['petal length (cm)', 'petal width (cm)']]
y = iris.target


svc_clf = make_pipeline(StandardScaler(), 
                        SVC(kernel='linear', C=1))
svc_clf.fit(X, y)

display = DecisionBoundaryDisplay.from_estimator(svc_clf, X, 
                                                 grid_resolution=1000,
                                                 xlabel="Petal length (cm)",
                                                 ylabel="Petal width (cm)")
plt.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors='w')
plt.title("Decision Boundary")
plt.show()

See the final effect in the figure:

Use DecisionBoundaryDisplay to draw a three-class classification model. Image by Author

Remember, Display classes can only draw in 2D, so make sure your data has just two features, or reduce it to two dimensions first, as in the sketch below.
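If your dataset has more than two features, one common workaround is to project it down to two dimensions first, for example with PCA. Here's a minimal sketch, assuming a feature matrix X with more than two columns and labels y (the names are only for illustration), and reusing the imports from the previous block:

from sklearn.decomposition import PCA

# Project the data onto its first two principal components,
# then fit and plot exactly as in the two-feature case above.
X_2d = PCA(n_components=2).fit_transform(X)

clf_2d = make_pipeline(StandardScaler(), SVC(kernel='linear', C=1))
clf_2d.fit(X_2d, y)

DecisionBoundaryDisplay.from_estimator(clf_2d, X_2d,
                                       xlabel="PC 1", ylabel="PC 2")
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y, edgecolors='w')
plt.show()

Keep in mind the boundary is then drawn in the projected space, not the original feature space.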

Using calibration.CalibrationDisplay for probability calibration

When comparing classification models, probability calibration curves show whether a model's predicted probabilities match the outcomes actually observed, in other words, how trustworthy its confidence is.

Note that CalibrationDisplay uses the model's predict_proba, so if you use a support vector machine, set probability=True:

from sklearn.calibration import CalibrationDisplay
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=1000,
                           n_classes=2, n_features=5,
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size=0.3, random_state=42)
proba_clf = make_pipeline(StandardScaler(), 
                          SVC(kernel="rbf", gamma="auto", 
                              C=10, probability=True))
proba_clf.fit(X_train, y_train)

CalibrationDisplay.from_estimator(proba_clf,
                                  X_test, y_test)

hist_clf = HistGradientBoostingClassifier()
hist_clf.fit(X_train, y_train)

ax = plt.gca()
CalibrationDisplay.from_estimator(hist_clf,
                                  X_test, y_test,
                                  ax=ax)
plt.show()
Charts drawn by CalibrationDisplay. Image by Author

Using metrics.ConfusionMatrixDisplay for confusion matrices

When assessing classification models and dealing with imbalanced data, we look at precision and recall.

These are computed from TP, FP, TN, and FN, the four cells of a confusion matrix.

To draw one, use metrics.ConfusionMatrixDisplay. It's well-known, so I'll skip the details.

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay

digits = fetch_openml('mnist_784', version=1)
X, y = digits.data, digits.target
rf_clf = RandomForestClassifier(max_depth=5, random_state=42)
rf_clf.fit(X, y)

ConfusionMatrixDisplay.from_estimator(rf_clf, X, y)
plt.show()
Charts drawn with ConfusionMatrixDisplay. Image by Author
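By the way, ConfusionMatrixDisplay (like several other metrics Displays) also offers from_predictions, so if you already have predictions in hand you can plot them directly instead of passing the estimator. A quick sketch, reusing rf_clf and the data above:

# If predictions are already computed, from_predictions skips
# calling the estimator again inside the Display class.
y_pred = rf_clf.predict(X)

ConfusionMatrixDisplay.from_predictions(y, y_pred)
plt.show()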

Using metrics.RocCurveDisplay and metrics.DetCurveDisplay

These two are grouped together because they're often used to evaluate models side by side.

RocCurveDisplay plots the model's TPR (true positive rate) against its FPR (false positive rate).

For binary classification, you want a low FPR and a high TPR, so the upper left corner is best; the ROC curve bends towards that corner.

Because the ROC curve hugs the upper left and leaves the lower right empty, it can be hard to tell models apart.

So we also use DetCurveDisplay to draw a DET curve, which plots FNR (false negative rate) against FPR. It spreads the curves over more of the plot, making differences clearer than on the ROC curve.

The ideal point for a DET curve is the lower left corner.

from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import DetCurveDisplay

X, y = make_classification(n_samples=10_000, n_features=5,
                           n_classes=2, n_informative=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size=0.3, random_state=42,
                                                    stratify=y)


classifiers = {
    "SVC": make_pipeline(StandardScaler(), SVC(kernel="linear", C=0.1, random_state=42)),
    "Random Forest": RandomForestClassifier(max_depth=5, random_state=42)
}

fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 4))
for name, clf in classifiers.items():
    clf.fit(X_train, y_train)
    
    RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)
    DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)
plt.show()
Comparison chart of RocCurveDisplay and DetCurveDisplay. Image by Author

Using metrics.PrecisionRecallDisplay to adjust thresholds

With imbalanced data, you often want to trade precision off against recall.

  • For email fraud, you want high precision.
  • For disease screening, you want high recall to catch more cases.

You can adjust the threshold, but what's the right amount?

Here, metrics.PrecisionRecallDisplay can help.

from xgboost import XGBClassifier
from sklearn.datasets import load_wine
from sklearn.metrics import PrecisionRecallDisplay

wine = load_wine()
X, y = wine.data[wine.target<=1], wine.target[wine.target<=1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                    stratify=y, random_state=42)

xgb_clf = XGBClassifier()
xgb_clf.fit(X_train, y_train)

PrecisionRecallDisplay.from_estimator(xgb_clf, X_test, y_test)
plt.show()
Charting xgboost model evaluation using PrecisionRecallDisplay. Image by Author

This also shows that any model following Scikit-learn's estimator conventions can be plotted this way, like XGBoost here. Handy, right?
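Once the curve points you to a trade-off you like, applying it means thresholding predict_proba yourself instead of relying on the default predict. Here's a minimal sketch; the 0.8 cutoff is only an illustrative value:

from sklearn.metrics import precision_score, recall_score

# Probability of the positive class on the test set.
y_scores = xgb_clf.predict_proba(X_test)[:, 1]

# Use a custom threshold (0.8 is just an example) instead of the default 0.5.
threshold = 0.8
y_pred_custom = (y_scores >= threshold).astype(int)

print("precision:", precision_score(y_test, y_pred_custom))
print("recall:   ", recall_score(y_test, y_pred_custom))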

Using metrics.PredictionErrorDisplay for regression models

We've talked about classification; now let's talk about regression.

Scikit-learn's metrics.PredictionErrorDisplay helps assess regression models.

import numpy as np
from sklearn.svm import SVR
from sklearn.metrics import PredictionErrorDisplay

rng = np.random.default_rng(42)
X = rng.random(size=(200, 2)) * 10
y = X[:, 0]**2 + 5 * X[:, 1] + 10 + rng.normal(loc=0.0, scale=0.1, size=(200,))

reg = make_pipeline(StandardScaler(), SVR(kernel='linear', C=10))
reg.fit(X, y)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[0], kind="actual_vs_predicted")
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[1], kind="residual_vs_predicted")
plt.show()
Two charts were drawn by PredictionErrorDisplay. Image by Author

As shown, it can draw two kinds of charts. The left one compares actual and predicted values, which is useful when the relationship is roughly linear.

However, not all data is perfectly linear; for those cases, use the chart on the right.

It plots the residuals, i.e., the differences between the real and predicted values, against the predicted values.

The banana shape of this residual plot suggests our data doesn't fit a linear model well.

Switching the kernel from linear to rbf can help.

reg = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10))
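To reproduce the figure below, refit and redraw with the same calls as before (a minimal continuation of the snippet above):

reg.fit(X, y)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[0], kind="actual_vs_predicted")
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[1], kind="residual_vs_predicted")
plt.show()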
A visual demonstration of the improved model performance. Image by Author

See, with rbf, the residual plot looks better.

Using model_selection.LearningCurveDisplay for learning curves

After assessing performance, let's look at optimization with LearningCurveDisplay.

First up, learning curves: they show how well a model generalizes as the amount of training data changes, and whether it suffers more from variance or from bias.

As shown below, we compare a DecisionTreeClassifier and a GradientBoostingClassifier to see how they do as training data changes.

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import LearningCurveDisplay

X, y = make_classification(n_samples=1000, n_classes=2, n_features=10,
                           n_informative=2, n_redundant=0, n_repeated=0)

tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
gb_clf = GradientBoostingClassifier(n_estimators=50, max_depth=3, tol=1e-3)

train_sizes = np.linspace(0.4, 1.0, 10)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
LearningCurveDisplay.from_estimator(tree_clf, X, y,
                                    train_sizes=train_sizes,
                                    ax=axes[0],
                                    scoring='accuracy')
axes[0].set_title('DecisionTreeClassifier')
LearningCurveDisplay.from_estimator(gb_clf, X, y,
                                    train_sizes=train_sizes,
                                    ax=axes[1],
                                    scoring='accuracy')
axes[1].set_title('GradientBoostingClassifier')
plt.show()
Comparison of the learning curves of two different models. Image by Author

The graph shows that although the tree-based GradientBoostingClassifier maintains good accuracy on the training data, its generalization capability on test data does not have a significant advantage over the DecisionTreeClassifier.

Using model_selection.ValidationCurveDisplay for visualizing parameter tuning

So, for models that don't generalize well, you might try adjusting their regularization parameters to tweak performance.

The traditional approach is to use tools like GridSearchCV or Optuna to tune the model, but these methods only give you the overall best-performing model and the tuning process is not very intuitive.

For scenarios where you want to adjust a specific parameter to test its effect on the model, I recommend using model_selection.ValidationCurveDisplay to visualize how the model performs as the parameter changes.

from sklearn.model_selection import ValidationCurveDisplay
from sklearn.linear_model import LogisticRegression

param_name, param_range = "C", np.logspace(-8, 3, 10)
lr_clf = LogisticRegression()

ValidationCurveDisplay.from_estimator(lr_clf, X, y,
                                      param_name=param_name,
                                      param_range=param_range,
                                      scoring='f1_weighted',
                                      cv=5, n_jobs=-1)
plt.show()
Fine-tuning of model parameters plotted with ValidationCurveDisplay. Image by Author
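If you also want the best value numerically instead of just reading it off the chart, one option is the underlying model_selection.validation_curve function, which returns the raw cross-validated scores. A small sketch, reusing the same lr_clf, data, and param_range from above:

from sklearn.model_selection import validation_curve

# Compute the cross-validated scores for each value of C.
train_scores, test_scores = validation_curve(lr_clf, X, y,
                                             param_name=param_name,
                                             param_range=param_range,
                                             scoring='f1_weighted',
                                             cv=5, n_jobs=-1)

# Pick the C with the highest mean validation score.
best_idx = np.argmax(test_scores.mean(axis=1))
print(f"Best C: {param_range[best_idx]:.2e}")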

Some regrets

After trying out all these Displays, I must admit some regrets:

  • The biggest one is that most of these APIs lack detailed tutorials, which is probably why they're little known compared with the rest of Scikit-learn's otherwise thorough documentation.
  • These APIs are scattered across various packages, making it hard to reference them from a single place.
  • The code is still pretty basic. You often need to pair it with Matplotlib's APIs to get the job done. A typical example is DecisionBoundaryDisplay, where after plotting the decision boundary, you still need Matplotlib to plot the data distribution.
  • They're hard to extend. Apart from a few parameter-validation helpers, it's tough to build on them to simplify my own model visualization workflow; I end up rewriting a lot.

I hope these APIs get more attention, and as versions upgrade, visualization APIs become even easier to use.


Conclusion

In the journey of machine learning, explaining models with visualization is as important as training them.

This article introduced various plotting APIs in the current version of scikit-learn.

With these APIs, you can simplify some Matplotlib code, ease your learning curve, and streamline your model evaluation process.

Due to length, I didn't expand on each API. If interested, you can check the official documentation for more details.

Now it's your turn. What are your expectations for visualizing machine learning methods? Feel free to leave a comment and discuss.

