Skip to content

Commit

Permalink
Give users more control over search routine (#68)
Browse files Browse the repository at this point in the history
* add more control for users

* lint

* change tol value
  • Loading branch information
aloctavodia authored Dec 16, 2024
1 parent 127bd6a commit 6dee6ad
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 11 deletions.
8 changes: 6 additions & 2 deletions kulprit/projection/arviz_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ def get_observed_data(idata, response_name):
return convert_to_dataset(observed_data), idata.observed_data.get(response_name).values


def get_pps(model, idata, response_name, num_samples):
"""Extract the posterior predictive samples from the reference model."""
def compute_pps(model, idata):
    """Make sure posterior predictive samples exist for the reference model.

    Parameters
    ----------
    model :
        Reference model; only its ``predict`` method is used here.
    idata :
        Inference data object; only its ``groups()`` method is used here.

    Notes
    -----
    If ``"posterior_predictive"`` is already listed by ``idata.groups()`` nothing
    is done. Otherwise ``model.predict`` is called with ``kind="response"`` and
    ``inplace=True``. The function itself returns ``None``.
    """
    if "posterior_predictive" not in idata.groups():
        # inplace=True: presumably the new group is attached to ``idata`` itself
        # rather than returned -- confirm against the model's ``predict`` API.
        model.predict(idata, kind="response", inplace=True)


def get_pps(idata, response_name, num_samples):
"""Extract posterior predictive samples from the reference model."""
pps = extract(
idata,
group="posterior_predictive",
Expand Down
16 changes: 15 additions & 1 deletion kulprit/projection/search_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def user_path(_project, path):
return submodels


def forward_search(_project, ref_terms, max_terms):
def forward_search(_project, ref_terms, max_terms, elpd_ref, early_stop):
# initial intercept-only subset
submodel_size = 0
term_names = []
Expand All @@ -35,6 +35,10 @@ def forward_search(_project, ref_terms, max_terms):
# compute loo for the best candidate and update inplace
compute_loo(submodel=submodel)

if early_stopping(submodel, elpd_ref, early_stop):
submodels.append(submodel)
break

# add best candidate to the list of selected submodels
submodels.append(submodel)

Expand All @@ -58,3 +62,13 @@ def get_candidates(prev_subset, ref_terms):
candidate_additions = list(set(ref_terms).difference(prev_subset))
candidates = [prev_subset + [addition] for addition in candidate_additions]
return candidates


def early_stopping(submodel, elpd_ref, early_stop, elpd_threshold=4):
    """Decide whether the search can stop at ``submodel``.

    Parameters
    ----------
    submodel :
        Candidate submodel; must expose ``elpd_loo`` and, for the ``"se"``
        criterion, ``elpd_se``.
    elpd_ref :
        Reference-model result; must expose ``elpd_loo``.
    early_stop : bool or str
        Stopping criterion. ``"mean"`` stops when the reference ELPD exceeds the
        submodel ELPD by at most ``elpd_threshold``. ``"se"`` stops when the
        submodel ELPD plus its own standard error is at least the reference
        ELPD. Any other value (including ``False`` and ``True``) never stops.
    elpd_threshold : float
        Largest ELPD gap accepted by the ``"mean"`` criterion. Defaults to 4.

    Returns
    -------
    bool
        ``True`` if the selected criterion is met, ``False`` otherwise.
    """
    if early_stop == "mean":
        return bool(elpd_ref.elpd_loo - submodel.elpd_loo <= elpd_threshold)
    if early_stop == "se":
        return bool(submodel.elpd_loo + submodel.elpd_se >= elpd_ref.elpd_loo)
    # unrecognised criteria disable early stopping
    return False
11 changes: 9 additions & 2 deletions kulprit/projection/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.optimize import minimize


def solve(model_log_likelihood, pps, initial_guess, var_info):
def solve(model_log_likelihood, pps, initial_guess, var_info, tolerance):
"""The primary projection method in the procedure.
Parameters:
Expand All @@ -28,12 +28,19 @@ def solve(model_log_likelihood, pps, initial_guess, var_info):
posterior_dict = {}
objectives = []

opt = minimize(
model_log_likelihood,
args=(pps[0]),
x0=initial_guess,
)
initial_guess = opt.x

for idx, obs in enumerate(pps):
opt = minimize(
model_log_likelihood,
args=(obs),
x0=initial_guess,
tol=0.001,
tol=tolerance,
method="SLSQP",
)

Expand Down
36 changes: 30 additions & 6 deletions kulprit/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bambi import formula
from kulprit.plots.plots import plot_compare, plot_densities

from kulprit.projection.arviz_io import compute_loo, get_observed_data, get_pps
from kulprit.projection.arviz_io import compute_loo, get_observed_data, compute_pps, get_pps
from kulprit.projection.pymc_io import (
compile_mllk,
compute_llk,
Expand All @@ -23,7 +23,7 @@ class ProjectionPredictive:
Projection Predictive class from which we perform the model selection procedure.
"""

def __init__(self, model, idata=None, num_samples=100):
def __init__(self, model, idata=None):
"""Builder for projection predictive model selection.
This object initializes the reference model and handles the core projection, variable search
Expand Down Expand Up @@ -60,10 +60,13 @@ def __init__(self, model, idata=None, num_samples=100):
self.observed_dataset, self.observed_array = get_observed_data(
self.idata, self.response_name
)
self.num_samples = num_samples
self.pps = get_pps(self.model, self.idata, self.response_name, self.num_samples)
self.num_samples = None
compute_pps(self.model, self.idata)
self.elpd_ref = compute_loo(idata=self.idata)

self.tolerance = None
self.early_stop = None
self.pps = None
self.list_of_submodels = []

def __repr__(self) -> str:
Expand All @@ -79,7 +82,9 @@ def __repr__(self) -> str:
)
return str_of_submodels

def project(self, max_terms=None, path="forward"):
def project(
self, max_terms=None, path="forward", num_samples=100, tolerance=0.01, early_stop=False
):
"""Perform model projection.
If ``max_terms`` is not provided, then the search path runs from the intercept-only model
Expand All @@ -94,7 +99,23 @@ def project(self, max_terms=None, path="forward"):
The search method to employ, either "forward" for a forward search, or "l1" for
an L1-regularized search path. If a nested list of terms is provided, models with
those terms will be projected directly.
num_samples : int
The number of samples to draw from the posterior predictive distribution for the
projection procedure. Defaults to 100.
tolerance : float
The tolerance for the optimization procedure. Defaults to 0.01
early_stop : bool or str
Whether to stop the search when the difference in ELPD between the submodel and the
reference model is small. There are two criteria, "mean" and "se". The "mean" criterion
stops the search when the ELPD of the reference model exceeds that of the submodel by at
most 4. The "se" criterion stops the search when the ELPD of the submodel plus its
standard error reaches the ELPD of the reference model. Defaults to False.
"""
self.num_samples = num_samples
self.tolerance = tolerance
self.early_stop = early_stop
self.pps = get_pps(self.idata, self.response_name, self.num_samples)

# test if path is a list of terms
if isinstance(path, list):
# check that the length of the path always increases
Expand Down Expand Up @@ -124,7 +145,9 @@ def project(self, max_terms=None, path="forward"):
max_terms = n_terms

if path == "forward":
self.list_of_submodels = forward_search(self._project, self.ref_terms, max_terms)
self.list_of_submodels = forward_search(
self._project, self.ref_terms, max_terms, self.elpd_ref, self.early_stop
)
# else:
# self.searcher_path = L1SearchPath(self.projector)

Expand All @@ -146,6 +169,7 @@ def _project(self, term_names):
self.pps,
initial_guess,
var_info,
self.tolerance,
)
# restore obs_rvs value in the model
new_model.rvs_to_values[obs_rvs] = old_y_value
Expand Down

0 comments on commit 6dee6ad

Please sign in to comment.