
7 Scikit-learn Tricks for Optimized Cross-Validation
Introduction
Validating machine learning models requires careful testing on unseen data to ensure robust, unbiased estimates of their performance. One of the most well-established validation approaches is cross-validation, which splits the dataset into several subsets, called folds, and iteratively trains on some of them while testing on the rest. While scikit-learn offers standard components and functions to perform cross-validation the traditional way, several additional tricks can make the process more efficient, insightful, or flexible.
This article reveals seven of these tricks, along with code examples of their implementation. The code examples below use the scikit-learn library, so make sure it is installed.
I recommend that you first acquaint yourself with the basics of cross-validation by checking out this article. Also, for a quick refresher, a basic cross-validation implementation (no tricks yet!) in scikit-learn would look like this:
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=200)

# Basic cross-validation strategy with k=5 folds
scores = cross_val_score(model, X, y, cv=5)

# Cross-validation results: per fold + aggregated
print("Cross-validation scores:", scores)
print("Mean score:", scores.mean())
The following examples assume that the basic libraries and functions, like cross_val_score, have already been imported.
1. Stratified Cross-Validation for Imbalanced Classification
In classification tasks involving imbalanced datasets, standard cross-validation does not guarantee that each fold reflects the overall class proportions. Stratified k-fold cross-validation addresses this by preserving the class distribution in every fold. It is implemented as follows:
from sklearn.model_selection import cross_val_score, StratifiedKFold

cv = StratifiedKFold(n_splits=5)
scores = cross_val_score(model, X, y, cv=cv)
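If you want to confirm what stratification is doing, you can inspect the label counts in each test fold. The following is a minimal sketch (assuming the X, y, and cv objects from the snippets above, and using NumPy for counting):

import numpy as np

for i, (train_idx, test_idx) in enumerate(cv.split(X, y)):
    # Each test fold should contain roughly the same class proportions as the full dataset
    print(f"Fold {i}: test-set class counts =", np.bincount(y[test_idx]))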
2. Shuffled K-Fold for Robust Splits
By using a KFold object with the shuffle=True option, we can shuffle the instances in the dataset to create more robust splits, preventing accidental bias when the dataset is ordered by some criterion or the instances are grouped by class label, time, season, and so on. Applying this strategy is very simple:
from sklearn.model_selection import KFold

cv = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=cv)
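Since the iris dataset loaded earlier is stored with its rows ordered by class, shuffling can actually matter here. The sketch below (reusing the model, X, and y from the opening example) compares the mean score with and without shuffling; the exact numbers will depend on the data and model, but it illustrates how ordering can interact with fold boundaries:

# Unshuffled folds follow the original row order of the dataset
cv_plain = KFold(n_splits=5, shuffle=False)
# Shuffled folds mix the rows first; random_state keeps the result reproducible
cv_shuffled = KFold(n_splits=5, shuffle=True, random_state=42)

print("No shuffle:", cross_val_score(model, X, y, cv=cv_plain).mean())
print("Shuffled:  ", cross_val_score(model, X, y, cv=cv_shuffled).mean())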
3. Parallelized Cross-Validation
This trick improves computational efficiency through an optional argument of the cross_val_score function. Simply set n_jobs=-1 to distribute the work at the fold level across all available CPU cores. This can result in a significant speed boost, especially when the dataset is large.
scores = cross_val_score(model, X, y, cv=5, n_jobs=-1)
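To see whether parallelization pays off on your machine, a quick timing check with Python's standard time module is enough. Note that on a tiny dataset like iris the overhead of spawning worker processes can outweigh the gains, so treat this as an illustrative sketch rather than a benchmark:

import time

start = time.perf_counter()
cross_val_score(model, X, y, cv=5)             # sequential (default n_jobs=None)
print("Sequential:", round(time.perf_counter() - start, 3), "s")

start = time.perf_counter()
cross_val_score(model, X, y, cv=5, n_jobs=-1)  # one job per fold on all CPU cores
print("Parallel:  ", round(time.perf_counter() - start, 3), "s")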
4. Cross-Validated Predictions
By default, cross-validation in scikit-learn yields one accuracy score per fold, which is then aggregated into an overall score. If we instead want a prediction for every instance, for example to later build a confusion matrix or an ROC curve, we can use cross_val_predict as a substitute for cross_val_score, as follows:
from sklearn.model_selection import cross_val_predict
y_pred = cross_val_predict(model, X, y, cv=5)
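Because y_pred now holds one out-of-fold prediction per instance (each made by a model that never saw that instance during training), it can be passed straight to the usual evaluation utilities, for example:

from sklearn.metrics import confusion_matrix, classification_report

# Aggregate confusion matrix and per-class report over all cross-validated predictions
print(confusion_matrix(y, y_pred))
print(classification_report(y, y_pred))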
5. Beyond Accuracy: Custom Scoring
It is also possible to replace the default accuracy metric used in cross-validation with other metrics like recall or F1-score. It all depends on the nature of your dataset and your predictive problem's needs. The make_scorer() function, along with the specific metric (which must also be imported), achieves this:
from sklearn.metrics import make_scorer, f1_score, recall_score

f1 = make_scorer(f1_score, average="macro")  # You can use recall_score too
scores = cross_val_score(model, X, y, cv=5, scoring=f1)
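The same pattern works for any metric that accepts (y_true, y_pred). As a small illustrative follow-up, the sketch below builds a macro-averaged recall scorer alongside the F1 scorer defined above and compares the two; scikit-learn also accepts string shortcuts such as scoring="f1_macro" if you prefer to skip make_scorer():

# Macro-averaged recall scorer, built the same way as the F1 scorer above
recall = make_scorer(recall_score, average="macro")

print("Macro F1:    ", cross_val_score(model, X, y, cv=5, scoring=f1).mean())
print("Macro recall:", cross_val_score(model, X, y, cv=5, scoring=recall).mean())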
6. Leave One Out (LOO) Cross-Validation
This strategy is essentially k-fold cross-validation taken to the extreme, providing an exhaustive evaluation for very small datasets: each instance serves as the test set exactly once. It is useful mostly for building simpler models on small datasets like the iris one shown at the beginning of this article, and is generally not advisable for larger datasets or complex models like ensembles, mainly due to the computational cost. For a little extra boost, it can optionally be combined with trick #3 shown earlier, as in the sketch after the code below:
from sklearn.model_selection import LeaveOneOut

cv = LeaveOneOut()
scores = cross_val_score(model, X, y, cv=cv)
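Combining LOO with trick #3 is just a matter of adding the n_jobs argument, which helps absorb the cost of the many fits involved (one per instance, so 150 model fits on iris):

cv = LeaveOneOut()
# One fit per instance, distributed across all available CPU cores
scores = cross_val_score(model, X, y, cv=cv, n_jobs=-1)
print("LOO mean accuracy:", scores.mean())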
7. Cross-Validation Inside Pipelines
The last strategy consists of applying cross-validation to a machine learning pipeline that encapsulates model training together with prior data preprocessing steps, such as scaling. This is done by first using make_pipeline() to build a pipeline that chains the preprocessing and model training steps. The pipeline object is then passed to the cross-validation function:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

pipeline = make_pipeline(StandardScaler(), LogisticRegression(max_iter=200))
scores = cross_val_score(pipeline, X, y, cv=5)
Integrating preprocessing within the cross-validation pipeline is crucial for preventing data leakage: the scaler is refit on the training folds of each iteration, so no information from the test fold leaks into preprocessing.
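To make the leakage point concrete, the sketch below contrasts a deliberately flawed setup, where the scaler is fit on the full dataset before cross-validation, with the pipeline version, where the scaler only ever sees the training folds. On a small, well-behaved dataset like iris the scores may barely differ, but only the pipeline version gives a methodologically clean estimate:

# Leaky: the scaler is fit on all rows, including the future test folds
X_leaky = StandardScaler().fit_transform(X)
leaky_scores = cross_val_score(LogisticRegression(max_iter=200), X_leaky, y, cv=5)

# Safe: the scaler is refit on the training folds inside every CV iteration
safe_scores = cross_val_score(pipeline, X, y, cv=5)

print("Scaled before CV (leaky):", leaky_scores.mean())
print("Scaled inside pipeline:  ", safe_scores.mean())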
Wrapping Up
Applying the seven scikit-learn tricks from this article helps optimize cross-validation for different scenarios and specific needs. Below is a quick recap of what we learned.
| Trick | Explanation |
|---|---|
| Stratified cross-validation | Preserves class proportions in each fold for imbalanced classification datasets. |
| Shuffled k-fold | Shuffles the data so that splits are more robust against possible ordering bias. |
| Parallelized cross-validation | Uses all available CPU cores to boost efficiency. |
| Cross-validated predictions | Returns instance-level predictions instead of per-fold scores, useful for building confusion matrices, ROC curves, and similar analyses. |
| Custom scoring | Allows custom evaluation metrics like F1-score or recall instead of accuracy. |
| Leave One Out (LOO) | Exhaustive evaluation suitable for smaller datasets and simpler models. |
| Cross-validation on pipelines | Integrates data preprocessing steps into the cross-validation process to prevent data leakage. |