Skip to content

Commit

Permalink
Improve plot_compare (#66)
Browse files Browse the repository at this point in the history
* improve plots

* improve plots
  • Loading branch information
aloctavodia authored Dec 10, 2024
1 parent a3916e8 commit fc82d74
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 43 deletions.
60 changes: 45 additions & 15 deletions docs/examples/quick-start.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion kulprit/data/submodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def __repr__(self) -> str:
else:
intercept = []

return f"model_size {self.size}, terms {intercept + self.term_names}"
return f"{intercept + self.term_names}"
44 changes: 23 additions & 21 deletions kulprit/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np


def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None):
def plot_compare(cmp_df, label_terms=None, legend=True, title=True, figsize=None, plot_kwargs=None):
"""
Plot model comparison.
Expand All @@ -14,6 +14,8 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
Dataframe containing the comparison data. Should have columns
`elpd_loo` and `elpd_diff` containing the ELPD values and the
differences to the reference model.
label_terms : list
List of the labels for the submodels.
legend : bool
Flag for plotting the legend, default True.
title : bool
Expand All @@ -31,20 +33,18 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None

figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, None, 1, 1)

xticks_pos, step = np.linspace(0, -1, ((cmp_df.shape[0]) * 2) - 2, retstep=True)
xticks_pos[1::2] = xticks_pos[1::2] - step * 1.5

labels = cmp_df.index.values[1:]
xticks_labels = [""] * len(xticks_pos)
xticks_labels[0] = labels[0]
xticks_labels[2::2] = labels[1:]
xticks_pos = np.linspace(0, 1, cmp_df.shape[0] - 1)[::-1]
xticks_num_labels = cmp_df.index.values[1:]
xticks_name_labels = [f"\n\n{term}" for term in label_terms[::-1]]
elpd_loo = cmp_df["elpd_loo"][1:].values
elpd_se = cmp_df["se"][1:].values

fig, axes = plt.subplots(1, figsize=figsize)

axes.errorbar(
y=cmp_df["elpd_loo"][1:],
x=xticks_pos[::2],
yerr=cmp_df.se[1:],
y=elpd_loo,
x=xticks_pos,
yerr=elpd_se,
label="Submodels",
color=plot_kwargs.get("color_eldp", "k"),
fmt=plot_kwargs.get("marker_eldp", "o"),
Expand All @@ -63,7 +63,7 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
)

axes.fill_between(
[-2, 1],
[-0.15, 1.15],
cmp_df["elpd_loo"].iloc[0] + cmp_df["se"].iloc[0],
cmp_df["elpd_loo"].iloc[0] - cmp_df["se"].iloc[0],
alpha=0.1,
Expand All @@ -72,7 +72,7 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None

if legend:
fig.legend(
bbox_to_anchor=(0.9, 0.1),
bbox_to_anchor=(0.9, 0.3),
loc="lower right",
ncol=1,
fontsize=ax_labelsize * 0.6,
Expand All @@ -84,15 +84,18 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
fontsize=ax_labelsize * 0.6,
)

# remove double ticks
xticks_pos, xticks_labels = xticks_pos[::2], xticks_labels[::2]
sec0 = axes.secondary_xaxis(location=0)
sec0.set_xticks(xticks_pos, xticks_num_labels)
sec0.tick_params("x", length=0, labelsize=xt_labelsize * 0.6)

sec1 = axes.secondary_xaxis(location=0)
sec1.set_xticks(xticks_pos, xticks_name_labels, rotation=plot_kwargs.get("xlabel_rotation", 0))
sec1.tick_params("x", length=0, labelsize=xt_labelsize * 0.6)
sec1.set_xlabel("Submodels", fontsize=ax_labelsize * 0.6)

# set axes
axes.set_xticks(xticks_pos)
axes.set_xticks([])
axes.set_ylabel("ELPD", fontsize=ax_labelsize * 0.6)
axes.set_xlabel("Submodel size", fontsize=ax_labelsize * 0.6)
axes.set_xticklabels(xticks_labels)
axes.set_xlim(-1 + step, 0 - step)
axes.set_xlim(-0.1, 1.1)
axes.tick_params(labelsize=xt_labelsize * 0.6)

return axes
Expand Down Expand Up @@ -130,7 +133,6 @@ def plot_densities(
if include_reference:
data = [idata]
l_labels = ["Reference"]
var_names.append(f"~{model.family.likelihood.parent}")
else:
data = []
l_labels = []
Expand Down
2 changes: 1 addition & 1 deletion kulprit/projection/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def project_names(self, term_names: Sequence[str]) -> SubModel:
model=new_model,
idata=new_idata,
loss=loss,
size=len(new_model.free_RVs) - len(self.base_terms),
size=len(term_names),
term_names=term_names,
has_intercept=self.has_intercept,
)
Expand Down
14 changes: 9 additions & 5 deletions kulprit/reference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=undefined-loop-variable
"""Core reference model class."""

from typing import Tuple, Union, Optional, List
Expand Down Expand Up @@ -197,7 +198,8 @@ def compare(
If None, size is (10, num of submodels) inches
plot_kwargs : dict
Optional arguments for plot elements. Currently accepts 'color_elpd', 'marker_elpd',
'marker_fc_elpd', 'color_dse', 'marker_dse', 'ls_reference', 'color_ls_reference'.
'marker_fc_elpd', 'color_dse', 'marker_dse', 'ls_reference', 'color_ls_reference',
'xlabel_rotation'.
Returns:
--------
Expand Down Expand Up @@ -237,9 +239,11 @@ def compare(
for k, submodel in self.searcher_path.k_submodel.items():
if k >= min_model_size:
self.searcher_idatas[k] = submodel.idata
self.searcher_idatas[
k + 1 # pylint: disable=undefined-loop-variable
] = self.projector.idata

self.searcher_idatas[k + 1] = self.projector.idata

label_terms = ["Intercept"] if self.has_intercept else []
label_terms.extend(submodel.term_names)

# compare the submodels using loo (other criteria may be added in the future)
comparison = az.compare(self.searcher_idatas)
Expand All @@ -248,7 +252,7 @@ def compare(
# plot the comparison if requested
axes = None
if plot:
axes = plot_compare(comparison, legend, title, figsize, **plot_kwargs)
axes = plot_compare(comparison, label_terms, legend, title, figsize, plot_kwargs)

return comparison, axes

Expand Down

0 comments on commit fc82d74

Please sign in to comment.