From 6dee6adafb352a9155c0e64b58d3c2eb62f91973 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 16 Dec 2024 11:49:24 -0300 Subject: [PATCH] Give users more control over search routine (#68) * add more control for users * lint * change tol value --- kulprit/projection/arviz_io.py | 8 ++++-- kulprit/projection/search_strategies.py | 16 ++++++++++- kulprit/projection/solver.py | 11 ++++++-- kulprit/projector.py | 36 ++++++++++++++++++++----- 4 files changed, 60 insertions(+), 11 deletions(-) diff --git a/kulprit/projection/arviz_io.py b/kulprit/projection/arviz_io.py index e498d99..5d58a00 100644 --- a/kulprit/projection/arviz_io.py +++ b/kulprit/projection/arviz_io.py @@ -9,10 +9,14 @@ def get_observed_data(idata, response_name): return convert_to_dataset(observed_data), idata.observed_data.get(response_name).values -def get_pps(model, idata, response_name, num_samples): - """Extract the posterior predictive samples from the reference model.""" +def compute_pps(model, idata): + """Compute posterior predictive samples from the reference model.""" if "posterior_predictive" not in idata.groups(): model.predict(idata, kind="response", inplace=True) + + +def get_pps(idata, response_name, num_samples): + """Extract posterior predictive samples from the reference model.""" pps = extract( idata, group="posterior_predictive", diff --git a/kulprit/projection/search_strategies.py b/kulprit/projection/search_strategies.py index 5108a3f..5b01bce 100644 --- a/kulprit/projection/search_strategies.py +++ b/kulprit/projection/search_strategies.py @@ -12,7 +12,7 @@ def user_path(_project, path): return submodels -def forward_search(_project, ref_terms, max_terms): +def forward_search(_project, ref_terms, max_terms, elpd_ref, early_stop): # initial intercept-only subset submodel_size = 0 term_names = [] @@ -35,6 +35,10 @@ def forward_search(_project, ref_terms, max_terms): # compute loo for the best candidate and update inplace compute_loo(submodel=submodel) + if 
early_stopping(submodel, elpd_ref, early_stop): + submodels.append(submodel) + break + # add best candidate to the list of selected submodels submodels.append(submodel) @@ -58,3 +62,13 @@ def get_candidates(prev_subset, ref_terms): candidate_additions = list(set(ref_terms).difference(prev_subset)) candidates = [prev_subset + [addition] for addition in candidate_additions] return candidates + + +def early_stopping(submodel, elpd_ref, early_stop): + if early_stop == "mean": + if elpd_ref.elpd_loo - submodel.elpd_loo <= 4: + return True + elif early_stop == "se": + if submodel.elpd_loo + submodel.elpd_se >= elpd_ref.elpd_loo: + return True + return False diff --git a/kulprit/projection/solver.py b/kulprit/projection/solver.py index 17a4b32..ecb526d 100644 --- a/kulprit/projection/solver.py +++ b/kulprit/projection/solver.py @@ -4,7 +4,7 @@ from scipy.optimize import minimize -def solve(model_log_likelihood, pps, initial_guess, var_info): +def solve(model_log_likelihood, pps, initial_guess, var_info, tolerance): """The primary projection method in the procedure. 
Parameters: @@ -28,12 +28,19 @@ def solve(model_log_likelihood, pps, initial_guess, var_info): posterior_dict = {} objectives = [] + opt = minimize( + model_log_likelihood, + args=(pps[0]), + x0=initial_guess, + ) + initial_guess = opt.x + for idx, obs in enumerate(pps): opt = minimize( model_log_likelihood, args=(obs), x0=initial_guess, - tol=0.001, + tol=tolerance, method="SLSQP", ) diff --git a/kulprit/projector.py b/kulprit/projector.py index 25452d1..2be1b3d 100644 --- a/kulprit/projector.py +++ b/kulprit/projector.py @@ -7,7 +7,7 @@ from bambi import formula from kulprit.plots.plots import plot_compare, plot_densities -from kulprit.projection.arviz_io import compute_loo, get_observed_data, get_pps +from kulprit.projection.arviz_io import compute_loo, get_observed_data, compute_pps, get_pps from kulprit.projection.pymc_io import ( compile_mllk, compute_llk, @@ -23,7 +23,7 @@ class ProjectionPredictive: Projection Predictive class from which we perform the model selection procedure. """ - def __init__(self, model, idata=None, num_samples=100): + def __init__(self, model, idata=None): """Builder for projection predictive model selection. 
This object initializes the reference model and handles the core projection, variable search @@ -60,10 +60,13 @@ def __init__(self, model, idata=None, num_samples=100): self.observed_dataset, self.observed_array = get_observed_data( self.idata, self.response_name ) - self.num_samples = num_samples - self.pps = get_pps(self.model, self.idata, self.response_name, self.num_samples) + self.num_samples = None + compute_pps(self.model, self.idata) self.elpd_ref = compute_loo(idata=self.idata) + self.tolerance = None + self.early_stop = None + self.pps = None self.list_of_submodels = [] def __repr__(self) -> str: @@ -79,7 +82,9 @@ def __repr__(self) -> str: ) return str_of_submodels - def project(self, max_terms=None, path="forward"): + def project( + self, max_terms=None, path="forward", num_samples=100, tolerance=0.01, early_stop=False + ): """Perform model projection. If ``max_terms`` is not provided, then the search path runs from the intercept-only model @@ -94,7 +99,23 @@ def project(self, max_terms=None, path="forward"): The search method to employ, either "forward" for a forward search, or "l1" for a L1-regularized search path. If a nested list of terms is provided, model with those terms will be projected directly. + num_samples : int + The number of samples to draw from the posterior predictive distribution for the + projection procedure. Defaults to 100. + tolerance : float + The tolerance for the optimization procedure. Defaults to 0.01. + early_stop : bool or str + Whether to stop the search when the difference in ELPD between the submodel and the + reference model is small. There are two criteria, "mean" and "se". The "mean" criterion + stops the search when the difference in ELPD is at most 4. The "se" criterion stops + the search when the ELPD of the submodel is within one standard error of the reference + model. Defaults to False. 
""" + self.num_samples = num_samples + self.tolerance = tolerance + self.early_stop = early_stop + self.pps = get_pps(self.idata, self.response_name, self.num_samples) + # test if path is a list of terms if isinstance(path, list): # check if the length of the path always increase @@ -124,7 +145,9 @@ def project(self, max_terms=None, path="forward"): max_terms = n_terms if path == "forward": - self.list_of_submodels = forward_search(self._project, self.ref_terms, max_terms) + self.list_of_submodels = forward_search( + self._project, self.ref_terms, max_terms, self.elpd_ref, self.early_stop + ) # else: # self.searcher_path = L1SearchPath(self.projector) @@ -146,6 +169,7 @@ def _project(self, term_names): self.pps, initial_guess, var_info, + self.tolerance, ) # restore obs_rvs value in the model new_model.rvs_to_values[obs_rvs] = old_y_value