Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor, leaning more on PyMC and less on Bambi #64

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 32 additions & 45 deletions docs/examples/quick-start.ipynb

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions kulprit/data/submodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import arviz
import bambi
import bambi as bmb


@dataclass
Expand All @@ -15,19 +15,22 @@ class SubModel:
extract a built pymc model.
idata (InferenceData): The inference data object of the submodel containing the
projected posterior draws and log-likelihood.
loss (float): The final loss (negative log-likelihood) of the submodel following
projection predictive inference
size (int): The number of common terms in the model, not including the intercept
term_names (list): The names of the terms in the model, including the intercept
"""

model: bambi.models.Model
model: bmb.models.Model
idata: arviz.InferenceData
loss: float
size: int
term_names: list
has_intercept: bool

def __repr__(self) -> str:
"""String representation of the submodel."""
if self.has_intercept:
intercept = ["Intercept"]
else:
intercept = []

return ", ".join([self.model.formula.main] + list(self.model.formula.additionals))
return f"model_size {self.size}, terms {intercept + self.term_names}"
32 changes: 9 additions & 23 deletions kulprit/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
plot_kwargs = {}

if figsize is None:
figsize = (len(cmp_df) - 1, 10)
figsize = (10, 4)

figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, None, 1, 1)

Expand Down Expand Up @@ -105,26 +105,27 @@ def plot_densities(
var_names=None,
submodels=None,
include_reference=True,
labels="formula",
labels="size",
kind="density",
figsize=None,
plot_kwargs=None,
):
"""Compare the projected posterior densities of the submodels"""

if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs.setdefault("figsize", figsize)

if kind not in ["density", "forest"]:
raise ValueError("kind must be one of 'density' or 'forest'")

if submodels is None:
submodels = path.values()
else:
submodels = [path[key] for key in submodels]

# set default variable names to the reference model terms
if not var_names:
var_names = list(
set(model.components[model.family.likelihood.parent].common_terms.keys())
- set([model.response_component.term.name])
)
var_names = [fvar.name for fvar in model.backend.model.free_RVs]

if include_reference:
data = [idata]
Expand All @@ -134,13 +135,8 @@ def plot_densities(
data = []
l_labels = []

if submodels is None:
submodels = path.values()
else:
submodels = [path[key] for key in submodels]

if labels == "formula":
l_labels.extend([submodel.model.formula for submodel in submodels])
l_labels.extend([",".join(submodel.term_names) for submodel in submodels])
else:
l_labels.extend([submodel.size for submodel in submodels])

Expand All @@ -167,13 +163,3 @@ def plot_densities(
)

return axes


def align_yaxis(axes, v_1, ax2, v_2):
"""adjust ax2 ylimit so that v2 in ax2 is aligned to v1 in axes"""
_, y_1 = axes.transData.transform((0, v_1))
_, y_2 = ax2.transData.transform((0, v_2))
inv = ax2.transData.inverted()
_, d_y = inv.transform((0, 0)) - inv.transform((0, y_1 - y_2))
miny, maxy = ax2.get_ylim()
ax2.set_ylim(miny + d_y, maxy + d_y)
210 changes: 74 additions & 136 deletions kulprit/projection/projector.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
# pylint: disable=too-many-instance-attributes
"""Base projection class."""

import copy
from typing import Optional, List, Union, Sequence, Dict
import collections
from typing import Optional, Sequence, Union

import arviz as az
import bambi as bmb
from xarray_einstats.stats import XrContinuousRV, XrDiscreteRV

from scipy import stats

import numpy as np

from kulprit.data.submodel import SubModel
from kulprit.projection.solver import Solver
from kulprit.projection.pymc_io import (
compile_mllk,
compute_llk,
compute_new_model,
get_model_information,
)
from kulprit.projection.solver import solve


class Projector:
def __init__(
self,
model: bmb.Model,
idata: az.InferenceData,
has_intercept: bool,
num_samples: int,
path: Optional[dict] = None,
) -> None:
"""Reference model builder for projection predictive model selection.
Expand All @@ -44,14 +48,22 @@ def __init__(
# log reference model and reference inference data object
self.model = model
self.idata = idata
self.has_intercept = has_intercept
self.num_samples = num_samples

# log properties of the reference model
# log properties of the reference Bambi model
self.response_name = model.response_component.term.name
self.ref_family = self.model.family.name
self.priors = self.model.constant_components

# build solver
self.solver = Solver(model=self.model, idata=self.idata)
self.observed_data = self.get_observed_data()
self.pps = self.get_pps()
self.base_terms = self.get_base_terms()

# log properties of the reference PyMC model
self.pymc_model = model.backend.model
self.all_terms = [fvar.name for fvar in self.pymc_model.free_RVs]
self.ref_var_info = get_model_information(self.pymc_model)

# log search path
self.path = path
Expand Down Expand Up @@ -108,10 +120,12 @@ def project(
raise UserWarning("Please pass either a list, tuple, or integer.")

def project_names(self, term_names: Sequence[str]) -> SubModel:
"""Primary projection method for GLM reference model.
"""Primary projection method for reference model.

The projection is defined as the values of the submodel parameters minimizing the
Kullback-Leibler divergence between the submodel and the reference model.
This is achieved by maximizing the log-likelihood of the submodel wrt the predictions of
the reference model.

Parameters:
----------
Expand All @@ -124,150 +138,74 @@ def project_names(self, term_names: Sequence[str]) -> SubModel:
kulprit.data.ModelData: Projected submodel ``ModelData`` object
"""

# copy term names to avoid mutating the input
term_names_ = copy.copy(term_names)

# build restricted bambi model
new_model = self._build_restricted_model(term_names=term_names_)

# extract the design matrix from the model
d_component = new_model.distributional_components[new_model.family.likelihood.parent]

if d_component:
X = d_component.design.common.design_matrix
slices = d_component.design.common.slices

# Add offset columns to their own design matrix
# Remove them from the common design matrix.
if hasattr(new_model, "offset_terms"): # pragma: no cover
for term in new_model.offset_terms:
term_slice = slices[term]
X = np.delete(X, term_slice, axis=1)

# build new term_names (add dispersion parameter if included)
term_names_, slices = self._extend_term_names(
new_model=new_model,
term_names=term_names_,
slices=slices,
term_names_ = self.base_terms + term_names
new_model = compute_new_model(
self.pymc_model, self.ref_var_info, self.all_terms, term_names_
)

# compute projected posterior
projected_posterior, loss = self.solver.solve(term_names=term_names_, X=X, slices=slices)

# add observed data component of projected idata
observed_data = {
self.response_name: self.idata.observed_data.get(self.response_name)
.to_dict()
.get("data")
}

# build idata object for the projected model
new_idata = az.InferenceData(
posterior=projected_posterior,
observed_data=az.convert_to_dataset(observed_data),
model_log_likelihood, old_y_value, obs_rvs = compile_mllk(new_model)
initial_guess = np.concatenate(
[np.ravel(value) for value in new_model.initial_point().values()]
)
var_info = get_model_information(new_model)

# compute the log-likelihood of the new submodel and add to idata
log_likelihood = self.compute_model_log_likelihood(model=new_model, idata=new_idata)
new_idata.add_groups(
log_likelihood={self.response_name: log_likelihood},
dims={self.response_name: [f"{self.response_name}_dim_0"]},
new_idata, loss = solve(
model_log_likelihood,
self.pps,
initial_guess,
var_info,
)
# restore obs_rvs value in the model
new_model.rvs_to_values[obs_rvs] = old_y_value

# Add observed data to the projected InferenceData object
new_idata.add_groups(observed_data=self.observed_data)
# Add log-likelihood to the projected InferenceData object
new_idata.add_groups(log_likelihood=compute_llk(new_idata, new_model))

# build SubModel object and return
sub_model = SubModel(
model=new_model,
idata=new_idata,
loss=loss,
size=len(new_model.components[new_model.family.likelihood.parent].common_terms),
size=len(new_model.free_RVs) - len(self.base_terms),
term_names=term_names,
has_intercept=self.has_intercept,
)
return sub_model

def compute_model_log_likelihood(self, model, idata):
# extract observed data
obs_array = self.idata.observed_data[model.response_component.term.name]
obs_array = obs_array.expand_dims(
chain=idata.posterior.dims["chain"],
draw=idata.posterior.dims["draw"],
)

preds = model.predict(idata, kind="response_params", inplace=False).posterior[
f"{model.family.likelihood.parent}"
]
if model.family.name == "gaussian":
# initialise probability distribution object
dist = XrContinuousRV(
stats.norm,
preds.values,
idata.posterior["sigma"],
)
elif model.family.name == "binomial":
# initialise probability distribution object
dist = XrDiscreteRV(
stats.binom,
n=model.response.data[:, 1],
p=preds.values,
)
elif model.family.name == "bernoulli":
# initialise probability distribution object
dist = XrDiscreteRV(
stats.bernoulli,
p=preds.values,
)
elif model.family.name == "poisson":
# initialise probability distribution object
dist = XrDiscreteRV(
stats.poisson,
mu=preds.values,
)
else:
raise NotImplementedError(f"The {model.family.name} family is not yet implemented.")

# compute log likelihood of model
if isinstance(dist, XrContinuousRV):
log_likelihood = dist.logpdf(obs_array).transpose( # pylint: disable=no-member
*("chain", "draw", ...)
)
else:
log_likelihood = dist.logpmf(obs_array).transpose( # pylint: disable=no-member
*("chain", "draw", ...)
)
return log_likelihood

def _build_restricted_formula(self, term_names: List[str]) -> str:
"""Build the formula for the restricted model."""

formula = (
f"{self.response_name} ~ " + " + ".join(term_names)
if len(term_names) > 0
else f"{self.response_name} ~ 1"
)
return formula

def _build_restricted_model(self, term_names: List[str]) -> bmb.Model:
"""Build the restricted model in Bambi."""

new_formula = self._build_restricted_formula(term_names=term_names)
new_model = bmb.Model(new_formula, self.model.data, family=self.ref_family)
return new_model

def _extend_term_names(
self,
new_model: bmb.Model,
term_names: List[str],
slices: Dict[str, slice],
) -> List[str]:
def get_observed_data(self):
"""Extract the observed data from the reference model."""
observed_data = {
self.response_name: self.idata.observed_data.get(self.response_name)
.to_dict()
.get("data")
}
return az.convert_to_dataset(observed_data)

def get_pps(self):
"""Extract the posterior predictive samples from the reference model."""
if "posterior_predictive" not in self.idata.groups():
self.model.predict(self.idata, kind="response", inplace=True)

pps = az.extract(
self.idata,
group="posterior_predictive",
var_names=[self.response_name],
num_samples=self.num_samples,
).values.T
return pps

def get_base_terms(self):
"""Extend the model term names to include dispersion terms."""

base_terms = []
# add intercept term if present
if bmb.formula.formula_has_intercept(new_model.formula.main):
term_names.insert(0, "Intercept")
if self.has_intercept:
base_terms.append("Intercept")

# add the auxiliary parameters
if self.priors:
aux_params = [f"{str(k)}" for k in self.priors]
term_names += aux_params
# TODO generalize #pylint: disable=fixme
slices[aux_params[0]] = slice(-1, None, None)
return term_names, slices
base_terms += aux_params
return base_terms
Loading
Loading