@ogrisel
Last active December 9, 2024 10:17
Impact of the use of stratified cross-validation on the assessment of epistemic uncertainty in ML performance metrics
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Impact of the use of stratified cross-validation on the assessment of epistemic uncertainty in ML performance metrics\n",
"\n",
"The goal of this notebook is to explore how the choice of cross-validation strategy, in particular stratification, can severely mislead us when measuring the performance of ML models on imbalanced data.\n",
"\n",
"We generate a large synthetic \"population\" dataset with significant class imbalance and randomly subsample a small to medium-sized observed dataset from it. Everything is i.i.d. in the data generating and subsampling processes.\n",
"\n",
"We run cross-validation on the observed dataset and assess whether our cross-validation strategy can meaningfully inform us about the \"true\" performance of the model for various classification metrics, namely: the negative log-likelihood, the Brier score, ROC AUC and average precision.\n",
"\n",
"In particular, we want to check whether the variability of the scores measured across cross-validation folds can meaningfully quantify the epistemic uncertainty of our ML performance evaluation, and how stratification can be useful or detrimental in that respect.\n",
"\n",
"Spoiler alert: **stratification on the target variable can be highly detrimental as it can lead to degenerately narrow uncertainty estimates** of the true model performance. Vanilla (non-stratified) cross-validation, on the other hand, does not have this problem. Its only drawback is operational: some metrics are undefined (and trigger warnings) on CV folds whose validation set, by chance, contains no occurrence of the minority class(es)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data generation and sampling"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(0.0149803)"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"population_size = 10_000_000\n",
"observed_sample_size = 5_000\n",
"n_features = 10\n",
"n_informative = int(0.3 * n_features)\n",
"expected_positive_fraction = 0.015\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_all = rng.uniform(0, 1, size=(population_size, n_features))\n",
"y_all = rng.binomial(\n",
"    n=1, p=expected_positive_fraction * X_all[:, 0:n_informative].mean(axis=1) * 2\n",
")\n",
"\n",
"assert np.abs(y_all.mean() - expected_positive_fraction) < 0.001\n",
"y_all.mean() # high imbalance: far fewer 1s than 0s"
]
},
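{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the data generating process (a sketch added for illustration, not part of the original study): each informative feature is uniform on [0, 1] with expectation 0.5, so `2 * X_all[:, 0:n_informative].mean(axis=1)` has expectation 1. The per-sample positive probability therefore has expectation `expected_positive_fraction` and ranges between 0 and `2 * expected_positive_fraction`. The cell below copies the constants from the cell above and uses a smaller, arbitrarily chosen sample size so that it runs quickly on its own:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Constants copied from the data generation cell above.\n",
"expected_positive_fraction = 0.015\n",
"n_informative = 3  # int(0.3 * n_features) with n_features = 10\n",
"\n",
"# E[mean of i.i.d. uniforms on [0, 1]] = 0.5, and the factor 2 brings it to 1.\n",
"analytical_mean_p = expected_positive_fraction * 2 * 0.5\n",
"\n",
"# Monte Carlo check on a smaller sample (size chosen arbitrarily for speed).\n",
"rng_check = np.random.default_rng(42)\n",
"X_check = rng_check.uniform(0, 1, size=(1_000_000, n_informative))\n",
"p_check = expected_positive_fraction * X_check.mean(axis=1) * 2\n",
"\n",
"print(f\"analytical E[p]: {analytical_mean_p:.4f}\")\n",
"print(f\"empirical E[p]:  {p_check.mean():.4f}\")\n",
"print(f\"range of p:      [{p_check.min():.4f}, {p_check.max():.4f}]\")"
]
},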
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_observed, X_unobserved, y_observed, y_unobserved = train_test_split(\n",
"    X_all, y_all, train_size=observed_sample_size, random_state=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.int64(86)"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check that this study is not completely degenerate. We cannot do any\n",
"# meaningful ML if we have too few positive examples.\n",
"assert y_observed.sum() > 30\n",
"y_observed.sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Measuring performance metrics via cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"def neg_log_likelihood(y_true, y_pred_proba):\n",
"    \"\"\"Average negative log-likelihood for binary classification.\n",
"\n",
"    This implementation tolerates y_true containing only one class, contrary to\n",
"    the log_loss function in scikit-learn (unless its labels argument is set).\n",
"    \"\"\"\n",
"    nll = 0\n",
"    pos_mask = y_true == 1\n",
"    if pos_mask.sum() > 0:\n",
"        nll += -np.log(y_pred_proba[pos_mask]).sum()\n",
"\n",
"    neg_mask = ~pos_mask\n",
"    if neg_mask.sum() > 0:\n",
"        nll += -np.log(1 - y_pred_proba[neg_mask]).sum()\n",
"\n",
"    return nll / len(y_true)"
]
},
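{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate the docstring claim above, here is a small sketch (not from the original notebook) comparing a vectorized version of the same formula, $-\\frac{1}{n}\\sum_i \\left[y_i \\log p_i + (1 - y_i) \\log(1 - p_i)\\right]$, with scikit-learn's `log_loss`. With a single-class `y_true`, `log_loss` needs its `labels` argument to know that the task is binary. `vectorized_nll` is a hypothetical helper introduced only for this comparison:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import log_loss\n",
"\n",
"\n",
"def vectorized_nll(y_true, y_pred_proba):\n",
"    \"\"\"Hypothetical vectorized equivalent of neg_log_likelihood.\"\"\"\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred_proba = np.asarray(y_pred_proba, dtype=float)\n",
"    return -np.mean(\n",
"        y_true * np.log(y_pred_proba) + (1 - y_true) * np.log(1 - y_pred_proba)\n",
"    )\n",
"\n",
"\n",
"proba = np.linspace(0.01, 0.99, 3)\n",
"for y_eval in [np.array([0, 1, 1]), np.ones(3, dtype=int), np.zeros(3, dtype=int)]:\n",
"    # labels=[0, 1] tells log_loss that the task is binary even when y_eval\n",
"    # contains a single class; without it, log_loss raises a ValueError.\n",
"    reference = log_loss(y_eval, proba, labels=[0, 1])\n",
"    print(y_eval, round(vectorized_nll(y_eval, proba), 6), round(reference, 6))"
]
},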
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import brier_score_loss\n",
"from sklearn.metrics import roc_auc_score\n",
"from sklearn.metrics import average_precision_score\n",
"\n",
"shared_cv_params = dict(n_splits=100, test_size=0.2, random_state=0)\n",
"stratified_cv = StratifiedShuffleSplit(**shared_cv_params)\n",
"vanilla_cv = ShuffleSplit(**shared_cv_params)\n",
"\n",
"model = LogisticRegression()\n",
"\n",
"\n",
"cv_results = []\n",
"for cv_strategy in [\"stratified\", \"vanilla\"]:\n",
"    cv = vanilla_cv if cv_strategy == \"vanilla\" else stratified_cv\n",
"    for split_idx, (train, test) in enumerate(cv.split(X_observed, y_observed)):\n",
"\n",
"        model.fit(X_observed[train], y_observed[train])\n",
"        y_pred_proba = model.predict_proba(X_observed[test])[:, 1]\n",
"        cv_results.append(\n",
"            {\n",
"                \"cv_strategy\": cv_strategy,\n",
"                \"split_idx\": split_idx,\n",
"                \"positive_rate_train\": y_observed[train].mean(),\n",
"                \"positive_rate_test\": y_observed[test].mean(),\n",
"                \"neg_log_likelihood_test\": neg_log_likelihood(y_observed[test], y_pred_proba),\n",
"                \"brier_score_test\": brier_score_loss(y_observed[test], y_pred_proba),\n",
"                \"roc_auc_test\": roc_auc_score(y_observed[test], y_pred_proba),\n",
"                \"average_precision_test\": average_precision_score(y_observed[test], y_pred_proba),\n",
"            }\n",
"        )\n",
"cv_results = pd.DataFrame(cv_results)\n",
"cv_results_summary = cv_results.groupby(\"cv_strategy\").describe().round(4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Depending on the sample size, some warnings may be emitted when computing ROC AUC and average precision with vanilla cross-validation. These warnings can be ignored relatively safely, as explained in the appendix."
]
},
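{
"cell_type": "markdown",
"metadata": {},
"source": [
"If these warnings clutter the output, one option is to silence only that warning category around the metric computations, using the standard `warnings.catch_warnings` context manager. This is a sketch that assumes a scikit-learn version which, like the one used to produce the outputs of this notebook, emits an `UndefinedMetricWarning` and returns nan for single-class ROC AUC (older versions raise a `ValueError` instead):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import numpy as np\n",
"from sklearn.exceptions import UndefinedMetricWarning\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"y_without_positives = np.zeros(3, dtype=int)  # single-class evaluation set\n",
"proba = np.linspace(0.01, 0.99, 3)\n",
"\n",
"with warnings.catch_warnings():\n",
"    # Ignore UndefinedMetricWarning only: other warnings remain visible.\n",
"    warnings.simplefilter(\"ignore\", category=UndefinedMetricWarning)\n",
"    score = roc_auc_score(y_without_positives, proba)\n",
"\n",
"print(score)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same context manager could wrap the metric calls of the cross-validation loop. Note that the average precision warning shown in the appendix for evaluation sets without positives is a plain `UserWarning`, which this filter does not silence."
]
},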
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Impact of the stratification on the positive class fraction in CV folds\n",
"\n",
"The empirical fraction of positive data points in the observed data can be quite different from the population fraction because of the high imbalance in the data generating process and the finite size of the observed sample:"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(np.float64(0.0172), np.float64(0.015))"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_observed.mean().round(4), y_unobserved.mean().round(4)"
]
},
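{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put the gap between 0.0172 and 0.015 in perspective, here is a back-of-the-envelope sketch (not part of the original study). It treats the number of positives in the observed sample as Binomial(n, p), with n the observed sample size and p the nominal population positive rate, and computes the standard error $\\sqrt{p (1 - p) / n}$ of the observed positive fraction. The numbers are copied from the outputs above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Values copied from the outputs above.\n",
"n_observed = 5_000\n",
"population_rate = 0.015  # nominal expected_positive_fraction\n",
"observed_rate = 86 / n_observed\n",
"\n",
"# Standard error of an empirical proportion under binomial sampling.\n",
"standard_error = np.sqrt(population_rate * (1 - population_rate) / n_observed)\n",
"z_score = (observed_rate - population_rate) / standard_error\n",
"\n",
"print(f\"standard error of the observed positive fraction: {standard_error:.4f}\")\n",
"print(f\"observed minus population rate: {z_score:.1f} standard errors\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With these numbers, the deviation is roughly 1.3 standard errors, i.e. well within the range of ordinary sampling noise: a single observed dataset of this size cannot pin down the population positive rate precisely."
]
},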
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two CV strategies handle this very differently (as expected):"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cv_strategy</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>stratified</th>\n",
" <td>100.0</td>\n",
" <td>0.0172</td>\n",
" <td>0.000</td>\n",
" <td>0.0173</td>\n",
" <td>0.0173</td>\n",
" <td>0.0173</td>\n",
" <td>0.0173</td>\n",
" <td>0.0173</td>\n",
" </tr>\n",
" <tr>\n",
" <th>vanilla</th>\n",
" <td>100.0</td>\n",
" <td>0.0172</td>\n",
" <td>0.001</td>\n",
" <td>0.0145</td>\n",
" <td>0.0165</td>\n",
" <td>0.0173</td>\n",
" <td>0.0178</td>\n",
" <td>0.0190</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"cv_strategy \n",
"stratified 100.0 0.0172 0.000 0.0173 0.0173 0.0173 0.0173 0.0173\n",
"vanilla 100.0 0.0172 0.001 0.0145 0.0165 0.0173 0.0178 0.0190"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_results_summary[\"positive_rate_train\"]"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cv_strategy</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>stratified</th>\n",
" <td>100.0</td>\n",
" <td>0.0170</td>\n",
" <td>0.0000</td>\n",
" <td>0.017</td>\n",
" <td>0.0170</td>\n",
" <td>0.017</td>\n",
" <td>0.017</td>\n",
" <td>0.017</td>\n",
" </tr>\n",
" <tr>\n",
" <th>vanilla</th>\n",
" <td>100.0</td>\n",
" <td>0.0172</td>\n",
" <td>0.0038</td>\n",
" <td>0.010</td>\n",
" <td>0.0148</td>\n",
" <td>0.017</td>\n",
" <td>0.020</td>\n",
" <td>0.028</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"cv_strategy \n",
"stratified 100.0 0.0170 0.0000 0.017 0.0170 0.017 0.017 0.017\n",
"vanilla 100.0 0.0172 0.0038 0.010 0.0148 0.017 0.020 0.028"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_results_summary[\"positive_rate_test\"]"
]
},
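{
"cell_type": "markdown",
"metadata": {},
"source": [
"The spread of the vanilla test positive rate can be anticipated analytically. Under `ShuffleSplit`, each test set is a uniform random draw of 1,000 of the 5,000 observed samples without replacement, so its number of positives follows a hypergeometric distribution (86 positives among 5,000 samples, 1,000 draws). By contrast, `StratifiedShuffleSplit` preserves the class proportions, so every test set gets about `round(0.2 * 86) = 17` positives. The following sketch uses `scipy.stats.hypergeom` (not used elsewhere in this notebook) and a simplified rounding of the stratified allocation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import hypergeom\n",
"\n",
"# Values copied from the outputs above (test_size=0.2 of 5,000 samples).\n",
"n_observed, n_positive, n_test = 5_000, 86, 1_000\n",
"\n",
"# scipy parametrization: M = population size, n = number of \"successes\"\n",
"# in the population, N = number of draws.\n",
"vanilla_test_positives = hypergeom(M=n_observed, n=n_positive, N=n_test)\n",
"print(\n",
"    f\"vanilla: mean rate {vanilla_test_positives.mean() / n_test:.4f}, \"\n",
"    f\"std {vanilla_test_positives.std() / n_test:.4f}\"\n",
")\n",
"\n",
"# Simplified view of stratification: the positive count in each test set is\n",
"# (approximately) fixed to the rounded proportional allocation.\n",
"stratified_test_positives = round(n_test * n_positive / n_observed)\n",
"print(f\"stratified: rate {stratified_test_positives / n_test:.4f}, std 0\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The analytical standard deviation of the vanilla test positive rate should come out close to the empirical 0.0038 reported in the table above, while the stratified strategy leaves no room for variability at all."
]
},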
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In particular, we observe that:\n",
"\n",
"- the stratified CV generates data folds with the same class imbalance, as expected, but this fixed positive fraction is not necessarily a good approximation of the true rate, which cannot be observed;\n",
"- the vanilla strategy displays some variability in the positive class fraction of the train and test splits. Note that the true fraction does not necessarily lie within the interquartile range (IQR), between the 25th and the 75th percentiles, but at least that IQR is not degenerately narrow and informs us about our uncertainty on the fraction of positives in the unobserved population distribution.\n",
"\n",
"In conclusion, **the IQR of the vanilla strategy better captures the epistemic uncertainty about the positive class rate**.\n",
"\n",
"Let's see how this transfers to cross-validated machine learning metrics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Epistemic uncertainty on the Log-Likelihood estimate\n",
"\n",
"Let's fit a model on all the observed data and measure the \"true\" negative log likelihood of the unobserved population data under this model:"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True negative log-likelihood on unobserved data: 0.0777\n"
]
}
],
"source": [
"model.fit(X_observed, y_observed)\n",
"y_pred_proba = model.predict_proba(X_unobserved)[:, 1]\n",
"print(f\"True negative log-likelihood on unobserved data: {neg_log_likelihood(y_unobserved, y_pred_proba):.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's compare it with the estimates obtained by cross-validation on the observed data:"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cv_strategy</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>stratified</th>\n",
" <td>100.0</td>\n",
" <td>0.0856</td>\n",
" <td>0.0018</td>\n",
" <td>0.0816</td>\n",
" <td>0.0845</td>\n",
" <td>0.0855</td>\n",
" <td>0.0865</td>\n",
" <td>0.0907</td>\n",
" </tr>\n",
" <tr>\n",
" <th>vanilla</th>\n",
" <td>100.0</td>\n",
" <td>0.0868</td>\n",
" <td>0.0153</td>\n",
" <td>0.0587</td>\n",
" <td>0.0752</td>\n",
" <td>0.0856</td>\n",
" <td>0.0972</td>\n",
" <td>0.1291</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"cv_strategy \n",
"stratified 100.0 0.0856 0.0018 0.0816 0.0845 0.0855 0.0865 0.0907\n",
"vanilla 100.0 0.0868 0.0153 0.0587 0.0752 0.0856 0.0972 0.1291"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_results_summary[\"neg_log_likelihood_test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly to the previous conclusion, the IQR of the estimate from the stratified CV method is degenerate (much too narrow): it fails to capture any meaningful uncertainty on the estimate of the negative log-likelihood. As a result, the true negative log-likelihood is well outside that range.\n",
"\n",
"By contrast, the IQR of the vanilla strategy seems to be informative: the true value lies either within it or close to its boundaries (depending on the values of the parameters at the beginning of the notebook)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Epistemic uncertainty on the Brier score estimate\n",
"\n",
"Let's do the same for the Brier score:"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True Brier score on unobserved data: 0.0148\n"
]
}
],
"source": [
"print(f\"True Brier score on unobserved data: {brier_score_loss(y_unobserved, y_pred_proba):.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cv_strategy</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>stratified</th>\n",
" <td>100.0</td>\n",
" <td>0.0167</td>\n",
" <td>0.0001</td>\n",
" <td>0.0165</td>\n",
" <td>0.0166</td>\n",
" <td>0.0167</td>\n",
" <td>0.0167</td>\n",
" <td>0.0169</td>\n",
" </tr>\n",
" <tr>\n",
" <th>vanilla</th>\n",
" <td>100.0</td>\n",
" <td>0.0169</td>\n",
" <td>0.0036</td>\n",
" <td>0.0100</td>\n",
" <td>0.0145</td>\n",
" <td>0.0167</td>\n",
" <td>0.0196</td>\n",
" <td>0.0273</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"cv_strategy \n",
"stratified 100.0 0.0167 0.0001 0.0165 0.0166 0.0167 0.0167 0.0169\n",
"vanilla 100.0 0.0169 0.0036 0.0100 0.0145 0.0167 0.0196 0.0273"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_results_summary[\"brier_score_test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We reach a similar conclusion: the IQR of the stratified CV estimate is degenerately narrow, while the vanilla estimate is more informative."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Epistemic uncertainty on the ROC AUC estimate\n"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True ROC AUC on unobserved data: 0.5781\n"
]
}
],
"source": [
"print(\n",
"    \"True ROC AUC on unobserved data: \"\n",
"    f\"{roc_auc_score(y_unobserved, y_pred_proba):.4f}\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cv_strategy</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>stratified</th>\n",
" <td>100.0</td>\n",
" <td>0.5888</td>\n",
" <td>0.0576</td>\n",
" <td>0.4387</td>\n",
" <td>0.5512</td>\n",
" <td>0.5854</td>\n",
" <td>0.6225</td>\n",
" <td>0.7617</td>\n",
" </tr>\n",
" <tr>\n",
" <th>vanilla</th>\n",
" <td>100.0</td>\n",
" <td>0.5884</td>\n",
" <td>0.0597</td>\n",
" <td>0.3978</td>\n",
" <td>0.5558</td>\n",
" <td>0.5925</td>\n",
" <td>0.6222</td>\n",
" <td>0.7393</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"cv_strategy \n",
"stratified 100.0 0.5888 0.0576 0.4387 0.5512 0.5854 0.6225 0.7617\n",
"vanilla 100.0 0.5884 0.0597 0.3978 0.5558 0.5925 0.6222 0.7393"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_results_summary[\"roc_auc_test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This time, both evaluation methods lead to similar results that are consistent with the true value. The uncertainty of the stratified method is probably still underestimated, but not as dramatically as for the two previous metrics.\n",
"\n",
"This is probably a consequence of the fact that ROC AUC only assesses the ranking power of the model and is completely blind to a lack of probabilistic calibration of its predictions. As a result, it is less sensitive to the degenerate estimation of the true positive class rate."
]
},
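{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small synthetic sketch (not from the original notebook) can illustrate this explanation. Applying a strictly increasing distortion to the predicted probabilities (here `p ** 3`, an arbitrary choice that makes a hypothetical model badly calibrated) leaves the ranking of the samples, and hence ROC AUC and average precision, unchanged, while the Brier score and the negative log-likelihood both change:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import (\n",
"    average_precision_score,\n",
"    brier_score_loss,\n",
"    log_loss,\n",
"    roc_auc_score,\n",
")\n",
"\n",
"rng_demo = np.random.default_rng(0)\n",
"y_demo = rng_demo.binomial(n=1, p=0.1, size=1_000)\n",
"\n",
"# Hypothetical model: noisy logits that are higher on average for positives.\n",
"logits = -2.5 + 1.5 * y_demo + rng_demo.normal(size=y_demo.shape[0])\n",
"p_original = 1 / (1 + np.exp(-logits))\n",
"p_distorted = p_original ** 3  # strictly increasing on [0, 1]: same ranking\n",
"\n",
"for name, p in [(\"original\", p_original), (\"distorted\", p_distorted)]:\n",
"    print(\n",
"        f\"{name:>9}: \"\n",
"        f\"ROC AUC={roc_auc_score(y_demo, p):.4f}, \"\n",
"        f\"AP={average_precision_score(y_demo, p):.4f}, \"\n",
"        f\"Brier={brier_score_loss(y_demo, p):.4f}, \"\n",
"        f\"NLL={log_loss(y_demo, p):.4f}\"\n",
"    )"
]
},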
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Epistemic uncertainty on the average precision estimate\n"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True average precision on unobserved data: 0.0193\n"
]
}
],
"source": [
"print(\n",
"    \"True average precision on unobserved data: \"\n",
"    f\"{average_precision_score(y_unobserved, y_pred_proba):.4f}\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" <tr>\n",
" <th>cv_strategy</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>stratified</th>\n",
" <td>100.0</td>\n",
" <td>0.0387</td>\n",
" <td>0.0230</td>\n",
" <td>0.0151</td>\n",
" <td>0.0234</td>\n",
" <td>0.0309</td>\n",
" <td>0.0457</td>\n",
" <td>0.1298</td>\n",
" </tr>\n",
" <tr>\n",
" <th>vanilla</th>\n",
" <td>100.0</td>\n",
" <td>0.0372</td>\n",
" <td>0.0235</td>\n",
" <td>0.0093</td>\n",
" <td>0.0215</td>\n",
" <td>0.0290</td>\n",
" <td>0.0447</td>\n",
" <td>0.1172</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"cv_strategy \n",
"stratified 100.0 0.0387 0.0230 0.0151 0.0234 0.0309 0.0457 0.1298\n",
"vanilla 100.0 0.0372 0.0235 0.0093 0.0215 0.0290 0.0447 0.1172"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_results_summary[\"average_precision_test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
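"The conclusions are similar to those for ROC AUC: both strategies give similar estimates. Note however that, in this run, the true value lies slightly below the IQR of both methods."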
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusions\n",
"\n",
"Stratified cross-validation on the target variable can be **very misleading when evaluating the performance of machine learning models on imbalanced data**, as this method can **severely underestimate the epistemic uncertainty** in the performance measure that naturally stems from the finite sample size of the observed dataset with imbalanced classes. This effect is highly dependent on the choice of the performance metric, though: strictly proper scoring rules such as the Brier score and the log-likelihood are sensitive to calibration problems. As a result, the uncertainty of their estimates is significantly more sensitive to the use of stratification than that of metrics that only focus on the ranking power of the model.\n",
"\n",
"To confirm this, we should rerun this study many times with different random seeds, observed sample sizes, class imbalance levels and model hyperparameters, to quantify how often the score of the model measured on the (unobserved) population dataset lies within the IQR of the scores measured by cross-validation."
]
},
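{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of such a study, restricted to the Brier score and to a handful of random seeds, could look as follows. It makes several simplifying assumptions that are not part of the original notebook: a smaller population (200,000 samples) to keep it fast, 20 CV splits instead of 100, the first `observed_size` rows used as the observed sample (valid because rows are i.i.d.), and default `LogisticRegression` hyperparameters. `iqr_contains_true_score` is a hypothetical helper name. With so few seeds, the printed coverage fractions are only indicative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import brier_score_loss\n",
"from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit\n",
"\n",
"\n",
"def iqr_contains_true_score(seed, cv, population_size=200_000,\n",
"                            observed_size=5_000, n_features=10,\n",
"                            positive_fraction=0.015):\n",
"    \"\"\"Hypothetical helper: is the population Brier score in the CV IQR?\"\"\"\n",
"    rng = np.random.default_rng(seed)\n",
"    n_informative = int(0.3 * n_features)\n",
"    X = rng.uniform(0, 1, size=(population_size, n_features))\n",
"    p = positive_fraction * X[:, :n_informative].mean(axis=1) * 2\n",
"    y = rng.binomial(n=1, p=p)\n",
"\n",
"    # Rows are i.i.d., so the first rows form a random observed sample.\n",
"    X_obs, y_obs = X[:observed_size], y[:observed_size]\n",
"    X_rest, y_rest = X[observed_size:], y[observed_size:]\n",
"\n",
"    model = LogisticRegression()\n",
"    cv_scores = []\n",
"    for train, test in cv.split(X_obs, y_obs):\n",
"        model.fit(X_obs[train], y_obs[train])\n",
"        y_pred = model.predict_proba(X_obs[test])[:, 1]\n",
"        cv_scores.append(brier_score_loss(y_obs[test], y_pred))\n",
"    q25, q75 = np.percentile(cv_scores, [25, 75])\n",
"\n",
"    # \"True\" score: model fit on all observed data, evaluated on the rest.\n",
"    model.fit(X_obs, y_obs)\n",
"    true_score = brier_score_loss(y_rest, model.predict_proba(X_rest)[:, 1])\n",
"    return q25 <= true_score <= q75\n",
"\n",
"\n",
"study_cv_params = dict(n_splits=20, test_size=0.2, random_state=0)\n",
"study_cv_strategies = {\n",
"    \"stratified\": StratifiedShuffleSplit(**study_cv_params),\n",
"    \"vanilla\": ShuffleSplit(**study_cv_params),\n",
"}\n",
"seeds = range(10)\n",
"for name, cv in study_cv_strategies.items():\n",
"    coverage = np.mean([iqr_contains_true_score(seed, cv) for seed in seeds])\n",
"    print(f\"{name}: true Brier score inside the CV IQR for {coverage:.0%} of seeds\")"
]
},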
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Appendix: binary classification metrics computed on single class evaluation sets\n",
"\n",
"ROC AUC and average precision can be undefined when the evaluation set only contains examples of one of the two classes. In that case, scikit-learn typically warns and returns nan for ROC AUC, while average precision returns 1.0 or 0.0 (with a warning when there are no positives), as shown below. The nans are then ignored by pandas when summarizing the performance scores across CV folds."
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ogrisel/code/scikit-learn/sklearn/metrics/_ranking.py:375: UndefinedMetricWarning: Only one class is present in y_true. ROC AUC score is not defined in that case.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"nan"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fake_pred_proba = np.linspace(0.01, 0.99, 3)\n",
"roc_auc_score(np.ones(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ogrisel/code/scikit-learn/sklearn/metrics/_ranking.py:375: UndefinedMetricWarning: Only one class is present in y_true. ROC AUC score is not defined in that case.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"nan"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"roc_auc_score(np.zeros(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(1.0)"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"average_precision_score(np.ones(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ogrisel/code/scikit-learn/sklearn/metrics/_ranking.py:1027: UserWarning: No positive class found in y_true, recall is set to one for all thresholds.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"average_precision_score(np.zeros(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(1.7694559008005124)"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"neg_log_likelihood(np.ones(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(1.7694559008005124)"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"neg_log_likelihood(np.zeros(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(0.41006666666666663)"
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"brier_score_loss(np.ones(3), fake_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(0.41006666666666663)"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"brier_score_loss(np.zeros(3), fake_pred_proba)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that pandas automatically ignores nan values when computing aggregate statistics:"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 2.000000\n",
"mean 0.500000\n",
"std 0.707107\n",
"min 0.000000\n",
"25% 0.250000\n",
"50% 0.500000\n",
"75% 0.750000\n",
"max 1.000000\n",
"dtype: float64"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.Series([0, np.nan, 1]).describe()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}