From 6c7355271f77534758e7f6f9c3529093021a237f Mon Sep 17 00:00:00 2001
From: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com>
Date: Fri, 17 May 2024 10:48:12 +0200
Subject: [PATCH] ENH Improve wording in group-aware cross-validation notebook
 (#776)

Co-authored-by: ArturoAmorQ
Co-authored-by: Guillaume Lemaitre
---
 python_scripts/cross_validation_grouping.py | 76 +++++++++++++--------
 1 file changed, 47 insertions(+), 29 deletions(-)

diff --git a/python_scripts/cross_validation_grouping.py b/python_scripts/cross_validation_grouping.py
index 3c473ecdf..20347b0af 100644
--- a/python_scripts/cross_validation_grouping.py
+++ b/python_scripts/cross_validation_grouping.py
@@ -7,9 +7,8 @@
 # %% [markdown]
 # # Sample grouping
-# We are going to linger into the concept of sample groups. As in the previous
-# section, we will give an example to highlight some surprising results. This
-# time, we will use the handwritten digits dataset.
+# In this notebook we present the concept of **sample groups**. We use the
+# handwritten digits dataset to highlight some surprising results.

 # %%
 from sklearn.datasets import load_digits
@@ -18,8 +17,17 @@
 data, target = digits.data, digits.target

 # %% [markdown]
-# We will recreate the same model used in the previous notebook: a logistic
-# regression classifier with a preprocessor to scale the data.
+# We create a model consisting of a logistic regression classifier with a
+# preprocessor to scale the data.
+#
+# ```{note}
+# Here we use a `MinMaxScaler` as we know that each pixel's gray-scale value
+# is strictly bounded between 0 (white) and 16 (black). This makes
+# `MinMaxScaler` better suited to this case than `StandardScaler`: some pixels
+# consistently have low variance (pixels at the borders might almost always be
+# zero if most digits are centered in the image), and `StandardScaler` would
+# then divide by a very small standard deviation, which can result in very
+# high scaled values. We verify this claim with a short sketch further below.
+# ```

 # %%
 from sklearn.preprocessing import MinMaxScaler
@@ -29,8 +37,10 @@
 model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1_000))

 # %% [markdown]
-# We will use the same baseline model. We will use a `KFold` cross-validation
-# without shuffling the data at first.
+# The idea is to compare the estimated generalization performance using
+# different cross-validation techniques, and to see how such estimates are
+# affected by the underlying structure of the data. We first use a `KFold`
+# cross-validation without shuffling the data.

 # %%
 from sklearn.model_selection import cross_val_score, KFold
@@ -59,9 +69,9 @@
 )

 # %% [markdown]
-# We observe that shuffling the data improves the mean accuracy. We could go a
-# little further and plot the distribution of the testing score. We can first
-# concatenate the test scores.
+# We observe that shuffling the data improves the mean accuracy. We can go a
+# little further and plot the distribution of the testing scores. For that
+# purpose we first concatenate the test scores.

 # %%
 import pandas as pd
@@ -72,29 +82,29 @@
 ).T
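+
+# %% [markdown]
+# Before plotting, let's take a quick detour and verify the claim made in the
+# scaling note above, namely that some pixels barely vary across images. The
+# following minimal sketch inspects the per-pixel standard deviation of the
+# raw data:
+
+# %%
+# Standard deviation of each of the 64 pixel features across all images.
+pixel_std = data.std(axis=0)
+print(f"Smallest per-pixel standard deviation: {pixel_std.min():.3f}")
+print(f"Pixels with a standard deviation below 0.5: {(pixel_std < 0.5).sum()}")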

 # %% [markdown]
-# Let's plot the distribution now.
+# Let's now plot the score distributions.

 # %%
 import matplotlib.pyplot as plt

-all_scores.plot.hist(bins=10, edgecolor="black", alpha=0.7)
+all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
 plt.xlim([0.8, 1.0])
 plt.xlabel("Accuracy score")
 plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
 _ = plt.title("Distribution of the test scores")

 # %% [markdown]
-# The cross-validation testing error that uses the shuffling has less variance
-# than the one that does not impose any shuffling. It means that some specific
-# fold leads to a low score in this case.
+# Shuffling the data results in a higher cross-validated test accuracy with
+# less variance compared to when the data is not shuffled. This means that
+# some specific folds lead to a low score when the data is not shuffled.

 # %%
 print(test_score_no_shuffling)

 # %% [markdown]
-# Thus, there is an underlying structure in the data that shuffling will break
-# and get better results. To get a better understanding, we should read the
-# documentation shipped with the dataset.
+# Thus, shuffling the data breaks the underlying structure, which makes the
+# classification task easier for our model. To get a better understanding, we
+# can read the dataset description in more detail:

 # %%
 print(digits.DESCR)
@@ -165,7 +175,7 @@
 groups[lb:up] = group_id

 # %% [markdown]
-# We can check the grouping by plotting the indices linked to writer ids.
+# We can check the grouping by plotting the indices linked to writers' ids.

 # %%
 plt.plot(groups)
@@ -176,8 +186,9 @@
 _ = plt.title("Underlying writer groups existing in the target")

 # %% [markdown]
-# Once we group the digits by writer, we can use cross-validation to take this
-# information into account: the class containing `Group` should be used.
+# Once we group the digits by writer, we can incorporate this information into
+# the cross-validation process by using group-aware variations of the
+# strategies we have explored in this course, for example, the `GroupKFold`
+# strategy.

 # %%
 from sklearn.model_selection import GroupKFold
@@ -191,10 +202,12 @@
 )

 # %% [markdown]
-# We see that this strategy is less optimistic regarding the model
-# generalization performance. However, this is the most reliable if our goal is
-# to make handwritten digits recognition writers independent. Besides, we can as
-# well see that the standard deviation was reduced.
+# We see that this strategy leads to a lower estimate of the generalization
+# performance than the other two techniques. However, it is the most reliable
+# estimate if our goal is to evaluate the capability of the model to
+# generalize to new, unseen writers. In this sense, shuffling the dataset (or
+# alternatively using the writers' ids as a new feature) would lead the model
+# to memorize each writer's particular handwriting.

 # %%
 all_scores = pd.DataFrame(
@@ -207,13 +220,18 @@
 ).T

 # %%
-all_scores.plot.hist(bins=10, edgecolor="black", alpha=0.7)
+all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
 plt.xlim([0.8, 1.0])
 plt.xlabel("Accuracy score")
 plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
 _ = plt.title("Distribution of the test scores")
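+
+# %% [markdown]
+# As a sanity check, we can make explicit what "group-aware" means: with
+# `GroupKFold`, a given writer never appears in both the train set and the
+# test set of the same split. The following minimal sketch (reusing the
+# `groups` array defined above) prints the writers shared between train and
+# test for each split; all of these sets should be empty:
+
+# %%
+cv = GroupKFold()
+for split_idx, (train_indices, test_indices) in enumerate(
+    cv.split(data, target, groups=groups)
+):
+    shared_writers = set(groups[train_indices]) & set(groups[test_indices])
+    print(f"Writers in both train and test of split #{split_idx}: {shared_writers}")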

 # %% [markdown]
-# As a conclusion, it is really important to take any sample grouping pattern
-# into account when evaluating a model. Otherwise, the results obtained will be
-# over-optimistic in regards with reality.
+# In conclusion, accounting for any sample grouping patterns is crucial when
+# assessing a model’s ability to generalize to new groups. Without this
+# consideration, the results may appear overly optimistic compared to the
+# actual performance.
+#
+# The interested reader can learn about other group-aware cross-validation
+# techniques in the [scikit-learn user
+# guide](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data).
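+
+# %% [markdown]
+# As a minimal sketch of one such alternative, the `LeaveOneGroupOut` strategy
+# trains on all writers but one and evaluates on the held-out writer, yielding
+# one test score per writer:
+
+# %%
+from sklearn.model_selection import LeaveOneGroupOut
+
+test_score_logo = cross_val_score(
+    model, data, target, groups=groups, cv=LeaveOneGroupOut()
+)
+print(
+    "The mean accuracy over held-out writers is "
+    f"{test_score_logo.mean():.3f} ± {test_score_logo.std():.3f}"
+)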