Skip to content

Commit

Permalink
Merge pull request #612 from guillaume-vignal/feature/refacto_report_…
Browse files Browse the repository at this point in the history
…plot

Enhance Plot Functionality and Consistency for Additional Visualizations
  • Loading branch information
guillaume-vignal authored Dec 9, 2024
2 parents 455fdc3 + 2264563 commit 18ffab6
Show file tree
Hide file tree
Showing 15 changed files with 1,878 additions and 311 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ report = [
"nbconvert>=6.0.7",
"papermill>=2.0.0",
"jupyter-client>=7.4.0",
"seaborn==0.12.2",
"notebook",
"Jinja2>=2.11.0",
"phik",
Expand Down
120 changes: 119 additions & 1 deletion shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import math
import random
from typing import Optional

import numpy as np
import pandas as pd
Expand All @@ -16,11 +17,12 @@
from shapash.plots.plot_bar_chart import plot_bar_chart
from shapash.plots.plot_contribution import plot_scatter, plot_violin
from shapash.plots.plot_correlations import plot_correlations
from shapash.plots.plot_evaluation_metrics import plot_confusion_matrix, plot_scatter_prediction
from shapash.plots.plot_feature_importance import plot_feature_importance
from shapash.plots.plot_interactions import plot_interactions_scatter, plot_interactions_violin, update_interactions_fig
from shapash.plots.plot_line_comparison import plot_line_comparison
from shapash.plots.plot_scatter_prediction import plot_scatter_prediction
from shapash.plots.plot_stability import plot_amplitude_vs_stability, plot_stability_distribution
from shapash.plots.plot_univariate import plot_distribution
from shapash.style.style_utils import colors_loading, define_style, select_palette
from shapash.utils.sampling import subset_sampling
from shapash.utils.utils import (
Expand Down Expand Up @@ -1852,3 +1854,119 @@ def scatter_plot_prediction(
)

return fig

def confusion_matrix_plot(
    self,
    width: int = 700,
    height: int = 500,
    file_name=None,
    auto_open=False,
):
    """
    Build an interactive Plotly confusion matrix for the explained classifier.

    The true and predicted labels are read from the first column of the
    explainer's ``y_target`` and ``y_pred``. When the explainer carries a
    ``label_dict``, both series are mapped through it before plotting.

    Parameters
    ----------
    width : int, optional, default=700
        The width of the generated figure, in pixels.
    height : int, optional, default=500
        The height of the generated figure, in pixels.
    file_name : str, optional
        Path used to save the plot as an HTML file. If None, nothing is saved.
    auto_open : bool, optional, default=False
        Forwarded to ``plot_confusion_matrix``; whether the saved plot is opened automatically.

    Returns
    -------
    go.Figure
        The figure produced by ``plot_confusion_matrix``.

    Raises
    ------
    ValueError
        If the explainer is not in the classification case, or if ``y_target``
        or ``y_pred`` is not available on the explainer.
    """
    explainer = self._explainer

    # Guard on "not classification" rather than "is regression" so that an
    # unexpected case never reaches the plotting call with unbound labels.
    if explainer._case != "classification":
        raise ValueError("Confusion matrix is only available for classification case")

    if explainer.y_target is None or explainer.y_pred is None:
        raise ValueError("Confusion matrix requires both y_target and y_pred to be defined")

    y_true = explainer.y_target.iloc[:, 0]
    y_pred = explainer.y_pred.iloc[:, 0]
    if explainer.label_dict is not None:
        y_true = y_true.map(explainer.label_dict)
        y_pred = y_pred.map(explainer.label_dict)

    return plot_confusion_matrix(
        y_true=y_true,
        y_pred=y_pred,
        colors_dict=self._style_dict,
        width=width,
        height=height,
        file_name=file_name,
        auto_open=auto_open,
    )

def distribution_plot(
    self,
    col: str,
    hue: Optional[str] = None,
    width: int = 700,
    height: int = 500,
    nb_cat_max: int = 7,
    nb_hue_max: int = 7,
    file_name=None,
    auto_open=False,
) -> go.Figure:
    """
    Plot the univariate distribution of one column of the explainer's data.

    The data handed to ``plot_distribution`` is the explainer's ``x_init``,
    with ``y_target`` appended as extra column(s) when it is available, so
    that both features and the target can be used as ``col`` or ``hue``.
    The drawing itself is delegated to ``plot_distribution``.

    Parameters
    ----------
    col : str
        The name of the column of interest whose distribution is to be visualized.
    hue : Optional[str], optional
        The name of the column used to differentiate between groups.
    width : int, optional, default=700
        The width of the generated figure, in pixels.
    height : int, optional, default=500
        The height of the generated figure, in pixels.
    nb_cat_max : int, optional, default=7
        Maximum number of categories to display. Categories beyond this limit
        are grouped into a new 'Other' category (only for categorical features).
    nb_hue_max : int, optional, default=7
        Maximum number of hue categories to display. Categories beyond this limit
        are grouped into a new 'Other' category.
    file_name : str, optional
        Path to save the plot as an HTML file. If None, the plot will not be saved.
    auto_open : bool, optional, default=False
        If True, the plot will automatically open in a web browser after being generated.

    Returns
    -------
    go.Figure
        A Plotly figure object representing the distribution of the feature.
    """
    explainer = self._explainer
    target = explainer.y_target

    # Without a target, the features alone are plotted.
    if target is None:
        frame = explainer.x_init
    else:
        frame = pd.concat([explainer.x_init, target], axis=1)

    plot_options = dict(
        hue=hue,
        colors_dict=self._style_dict,
        width=width,
        height=height,
        nb_cat_max=nb_cat_max,
        nb_hue_max=nb_hue_max,
        file_name=file_name,
        auto_open=auto_open,
    )
    return plot_distribution(frame, col, **plot_options)
21 changes: 18 additions & 3 deletions shapash/plots/plot_correlations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
Expand All @@ -6,12 +8,14 @@
from plotly.subplots import make_subplots

from shapash.manipulation.summarize import compute_corr
from shapash.style.style_utils import define_style, get_palette
from shapash.utils.utils import adjust_title_height, compute_top_correlations_features, suffix_duplicates


def plot_correlations(
df,
style_dict,
style_dict: Optional[dict] = None,
palette_name: str = "default",
features_dict=None,
optimized=False,
max_features=20,
Expand All @@ -35,6 +39,8 @@ def plot_correlations(
DataFrame for which we want to compute correlations.
style_dict: dict
the different styles used in the different outputs of Shapash
palette_name : str, optional, default="default"
The name of the color palette to be used if `colors_dict` is not provided.
features_dict: dict (default: None)
Dictionary mapping technical feature names to domain names.
optimized : boolean, optional
Expand Down Expand Up @@ -123,6 +129,15 @@ def prepare_corr_matrix(df_subset):
list_features_shorten = suffix_duplicates(list_features_shorten)
return corr, list_features, list_features_shorten

if style_dict:
style_dict_default = {}
keys = ["dict_title", "init_contrib_colorscale"]
if any(key not in style_dict for key in keys):
style_dict_default = define_style(get_palette(palette_name))
style_dict_default.update(style_dict)
else:
style_dict_default = define_style(get_palette(palette_name))

if features_dict is None:
features_dict = {}

Expand Down Expand Up @@ -203,10 +218,10 @@ def prepare_corr_matrix(df_subset):
if len(list_features) < len(df.drop(features_to_hide, axis=1).columns):
subtitle = f"Top {len(list_features)} correlations"
title += f"<span style='font-size: 12px;'><br />{subtitle}</span>"
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_t = style_dict_default["dict_title"] | {"text": title, "y": adjust_title_height(height)}

fig.update_layout(
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict["init_contrib_colorscale"][5:-1]),
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict_default["init_contrib_colorscale"][5:-1]),
showlegend=True,
title=dict_t,
width=width,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Optional, Union

import numpy as np
import pandas as pd
from plotly import graph_objs as go
from plotly.offline import plot

from shapash.style.style_utils import define_style, get_palette
from shapash.utils.sampling import subset_sampling
from shapash.utils.utils import adjust_title_height, truncate_str, tuning_colorscale

Expand Down Expand Up @@ -356,7 +359,6 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
fig = go.Figure()

subtitle = None
prediction_error = prediction_error
if prediction_error is not None:
if (y_target == 0).any().iloc[0]:
subtitle = "Prediction Error = abs(True Values - Predicted Values)"
Expand Down Expand Up @@ -458,8 +460,8 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
"y": 1.1,
}
range_axis = [
min(min(y_target_values), min(y_pred_flatten)),
max(max(y_target_values), max(y_pred_flatten)),
min(y_target_values.min(), y_pred_flatten.min()),
max(y_target_values.max(), y_pred_flatten.max()),
]
fig.update_xaxes(range=range_axis)
fig.update_yaxes(range=range_axis)
Expand All @@ -479,3 +481,152 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
)

return fig, subtitle


def _shorten_label(label, k):
    """Return `label` with its middle replaced by "..." if it is a string longer than ``2 * k + 3``."""
    if isinstance(label, str) and len(label) > 2 * k + 3:
        # Build the result from the head and tail explicitly: str.replace would
        # also rewrite any earlier occurrence of the middle substring.
        return label[: k + k // 2] + "..." + label[-k + k // 2 :]
    return label


def _shorten_labels(labels):
    """Shorten long string labels for axis display; fewer than 6 labels leaves room for longer text."""
    k = 10 if len(labels) < 6 else 6
    return [_shorten_label(label, k) for label in labels]


def plot_confusion_matrix(
    y_true: Union[np.ndarray, list],
    y_pred: Union[np.ndarray, list],
    colors_dict: Optional[dict] = None,
    width: int = 700,
    height: int = 500,
    palette_name: str = "default",
    file_name=None,
    auto_open=False,
) -> go.Figure:
    """
    Creates an interactive confusion matrix using Plotly.

    Rows are the actual classes and columns the predicted classes; both axes
    use the sorted union of the labels found in `y_true` and `y_pred`, so a
    class that is never predicted (or never observed) still gets a row/column
    filled with zeros.

    Parameters
    ----------
    y_true : array-like
        Ground truth (correct) target values.
    y_pred : array-like
        Estimated targets as returned by a classifier.
    colors_dict : dict, optional
        Custom style entries. Keys used here are "dict_title",
        "init_confusion_matrix_colorscale", "dict_xaxis" and "dict_yaxis";
        any missing key is taken from the style built from `palette_name`.
    width : int, optional
        The width of the figure in pixels.
    height : int, optional
        The height of the figure in pixels.
    palette_name : str, optional
        The color palette used to build the default style.
    file_name: string, optional
        Specify the save path of html files. If None, no file will be saved.
    auto_open: bool, optional
        Automatically open the plot.

    Returns
    -------
    go.Figure
        The generated confusion matrix as a Plotly figure.
    """
    # Create a confusion matrix as a DataFrame
    labels = sorted(set(y_true).union(set(y_pred)))
    se_y_true = pd.Series(y_true, name="Actual")
    se_y_pred = pd.Series(y_pred, name="Predicted")
    df_cm = pd.crosstab(se_y_true, se_y_pred).reindex(index=labels, columns=labels, fill_value=0)

    # Use the caller's style as-is when it is complete, otherwise overlay it on the palette defaults.
    required_keys = ("dict_title", "init_confusion_matrix_colorscale", "dict_xaxis", "dict_yaxis")
    if colors_dict and all(key in colors_dict for key in required_keys):
        style_dict = dict(colors_dict)
    else:
        style_dict = define_style(get_palette(palette_name))
        if colors_dict:
            style_dict.update(colors_dict)

    init_colorscale = style_dict["init_confusion_matrix_colorscale"]
    linspace = np.linspace(0, 1, len(init_colorscale))
    col_scale = list(zip(linspace, init_colorscale))

    # Rows of z follow the index (actual), columns follow the columns (predicted).
    x_labels = list(df_cm.columns)
    y_labels = list(df_cm.index)
    z = df_cm.values

    title = "Confusion Matrix"
    dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
    dict_xaxis = style_dict["dict_xaxis"] | {"text": se_y_pred.name}
    dict_yaxis = style_dict["dict_yaxis"] | {"text": se_y_true.name}

    # Determine if labels are numeric (non-negative integer-like)
    x_numeric = all(str(label).isdigit() for label in x_labels)
    y_numeric = all(str(label).isdigit() for label in y_labels)

    # Hover text is built before shortening so it always shows the full labels.
    hv_text = [
        [f"Actual: {y}<br>Predicted: {x}<br>Count: {value}" for x, value in zip(x_labels, row)]
        for y, row in zip(y_labels, z)
    ]

    # Only string labels are shortened; other non-digit labels (negative ints,
    # floats, booleans) are passed through unchanged instead of raising.
    # NOTE(review): distinct long labels sharing the same head and tail shorten
    # to the same text and would then share one axis category.
    if not x_numeric:
        x_labels = _shorten_labels(x_labels)
    if not y_numeric:
        y_labels = _shorten_labels(y_labels)

    # Create the heatmap using go.Heatmap
    heatmap = go.Heatmap(
        z=z,
        x=x_labels,
        y=y_labels,
        colorscale=col_scale,
        hovertext=hv_text,
        hovertemplate="%{hovertext}<extra></extra>",
        showscale=True,
    )

    fig = go.Figure(data=[heatmap])

    # Cells in the darker upper half of the scale get white text for contrast.
    half_max = z.max() / 2 if z.size else 0
    annotations = [
        dict(
            x=x_label,
            y=y_label,
            text=str(z[i][j]),
            showarrow=False,
            font=dict(color="black" if z[i][j] < half_max else "white"),
        )
        for i, y_label in enumerate(y_labels)
        for j, x_label in enumerate(x_labels)
    ]

    # Update layout
    fig.update_layout(
        annotations=annotations,
        title=dict_t,
        xaxis=dict(
            title=dict_xaxis,
            tickangle=45,
            tickmode="array" if x_numeric else "linear",
            tickvals=[int(label) for label in x_labels] if x_numeric else None,
            ticktext=x_labels if x_numeric else None,
        ),
        yaxis=dict(
            title=dict_yaxis,
            autorange="reversed",  # Reverse y-axis to match conventional confusion matrix
            tickmode="array" if y_numeric else "linear",
            tickvals=[int(label) for label in y_labels] if y_numeric else None,
            ticktext=y_labels if y_numeric else None,
        ),
        width=width,
        height=height,
        margin=dict(l=150, r=20, t=100, b=70),
    )

    if file_name:
        plot(fig, filename=file_name, auto_open=auto_open)

    return fig
Loading

0 comments on commit 18ffab6

Please sign in to comment.