Skip to content

Commit

Permalink
pass error context in the usual way
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Oct 23, 2024
1 parent 528d383 commit d51dd69
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 41 deletions.
38 changes: 18 additions & 20 deletions R/add_candidates.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ add_candidates.workflow_set <- function(data_stack, candidates,
cli_abort(
"The supplied workflow set must be fitted to resamples with
{.help [`workflow_map()`](workflowsets::workflow_map)} before being added to a data stack.",
call = caller_env(0),
class = "wf_set_unfitted"
)
}
Expand Down Expand Up @@ -161,18 +160,17 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {
"The second argument to {.help [`add_candidates()`](stacks::add_candidates)} should inherit from one of
{.help [`tune_results`](tune::tune_grid)} or
{.help [`workflow_set`](workflowsets::workflow_set)}, but its class
is {.var {class(candidates)}}.",
call = caller_env(0)
is {.var {class(candidates)}}."
)
}

.set_outcome <- function(stack, candidates) {
.set_outcome <- function(stack, candidates, call = caller_env()) {
if (!.get_outcome(stack) %in% c("init_", tune::.get_tune_outcome_names(candidates))) {
cli_abort(
"The model definition you've tried to add to the stack has
outcome variable {.var {tune::.get_tune_outcome_names(candidates)}},
while the stack's outcome variable is {.var {.get_outcome(stack)}}.",
call = caller_env(1)
call = call
)
}

Expand All @@ -183,7 +181,7 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {

# checks that the hash for the resampling object
# is appropriate and then sets it
.set_rs_hash <- function(stack, candidates, name) {
.set_rs_hash <- function(stack, candidates, name, call = caller_env()) {
new_hash <- tune::.get_fingerprint(candidates)

hash_matches <- .get_rs_hash(stack) %in% c("init_", new_hash)
Expand All @@ -192,7 +190,7 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {
cli_abort(
"It seems like the new candidate member '{name}' doesn't make use
of the same resampling object as the existing candidates.",
call = caller_env()
call = call
)
}

Expand All @@ -214,7 +212,7 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {
}

# note whether classification or regression
.set_mode_ <- function(stack, candidates, name) {
.set_mode_ <- function(stack, candidates, name, call = caller_env()) {
wf_spec <-
attr(candidates, "workflow") %>%
workflows::extract_spec_parsnip()
Expand All @@ -226,7 +224,7 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {
cli_abort(
"The {.pkg stacks} package does not support stacking models with mode
{.val {new_mode}}.",
call = NULL
call = call
)
}

Expand All @@ -238,12 +236,12 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {
# check to make sure that the supplied model def name
# doesn't have the same name or hash as an existing model def
# and then appends the model definition, hash, and metrics
.set_model_defs_candidates <- function(stack, candidates, name) {
.set_model_defs_candidates <- function(stack, candidates, name, call = caller_env()) {
if (name %in% .get_model_def_names(stack)) {
cli_abort(
"The new model definition has the
same name '{name}' as an existing model definition.",
call = caller_env(1)
call = call
)
}

Expand All @@ -263,7 +261,7 @@ add_candidates.default <- function(data_stack, candidates, name, ...) {
"The supplied candidates were tuned/fitted using only metrics that
rely on hard class predictions. Please tune/fit with at least one
class probability-based metric, such as {.help [`roc_auc`](yardstick::roc_auc)}.",
call = caller_env(1)
call = call
)
}
}
Expand All @@ -287,7 +285,7 @@ class_1 <- function(.x) {
# checks that the training data in a newly added candidate
# is the same is that from existing candidates, and sets the
# training data if the new candidate is the first in the stack
.set_training_data <- function(stack, candidates, name) {
.set_training_data <- function(stack, candidates, name, call = caller_env()) {
training_data <- attr(stack, "train")
new_data <- tibble::as_tibble(candidates[["splits"]][[1]][["data"]])

Expand All @@ -296,7 +294,7 @@ class_1 <- function(.x) {
cli_abort(
"The newly added candidate member, `{name}`,
uses different training data than the existing candidates.",
call = caller_env(1)
call = call
)
}

Expand Down Expand Up @@ -402,7 +400,7 @@ update_stack_data <- function(stack, new_data) {
)
}

check_add_data_stack <- function(data_stack) {
check_add_data_stack <- function(data_stack, call = caller_env()) {
if (rlang::inherits_any(
data_stack,
c("tune_results", "tune_bayes", "resample_results")
Expand All @@ -414,14 +412,14 @@ check_add_data_stack <- function(data_stack) {
If so, please supply the output of {.help [`stacks()`](stacks::stacks)} or another
{.help [`add_candidates()`](stacks::add_candidates)} call as
the argument to {.arg data_stack}.",
call = caller_env()
call = call
)
} else {
check_inherits(data_stack, "data_stack", call = caller_env())
}
}

check_candidates <- function(candidates, name) {
check_candidates <- function(candidates, name, call = caller_env()) {
if (nrow(tune::collect_notes(candidates)) != 0) {
cli_warn(
"The inputted {.arg candidates} argument {.var {name}} generated notes during
Expand All @@ -435,12 +433,12 @@ check_candidates <- function(candidates, name) {
cli_abort(
"The inputted {.arg candidates} argument was not generated with the
appropriate control settings. Please see {.help [`control_stack()`](stacks::control_stack)}.",
call = caller_env()
call = call
)
}
}

check_name <- function(name) {
check_name <- function(name, call = caller_env()) {
if (rlang::inherits_any(
name,
c("tune_results", "tune_bayes", "resample_results")
Expand All @@ -449,7 +447,7 @@ check_name <- function(name) {
"The inputted {.arg name} argument looks like a tuning/fitting results object
that might be supplied as a {.arg candidates} argument. Did you try to add
more than one set of candidates in one {.help [`add_candidates()`](stacks::add_candidates)} call?",
call = caller_env()
call = call
)
} else {
check_inherits(name, "character", call = caller_env())
Expand Down
22 changes: 11 additions & 11 deletions R/blend_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -239,31 +239,31 @@ blend_predictions <- function(data_stack,
if (model_stack_constr(model_stack)) {model_stack}
}

check_regularization <- function(x, arg) {
check_regularization <- function(x, arg, call = caller_env()) {
if (!is.numeric(x)) {
cli_abort(
"The argument to '{arg}' must be a numeric, but the supplied {arg}'s
class is {.var {class(x)}}.",
call = caller_env()
call = call
)
}

if (length(x) == 0) {
cli_abort("Please supply one or more {arg} values.",
call = caller_env())
call = call)
}

if (arg == "penalty") {
if (any(x < 0)) {
cli_abort("Please supply only nonnegative values to the {arg} argument.",
call = caller_env())
call = call)
}
}

if (arg == "mixture") {
if (any(x < 0) || any(x > 1)) {
cli_abort("Please supply only values in [0, 1] to the {arg} argument.",
call = caller_env())
call = call)
}
}
}
Expand Down Expand Up @@ -329,33 +329,33 @@ safe_attr <- function(x, new_attr) {
res
}

check_blend_data_stack <- function(data_stack) {
check_blend_data_stack <- function(data_stack, call = caller_env()) {
# many possible checks we could do here are redundant with those we
# carry out in fit_members() -- just check for bare stacks, 1-candidate
# stacks, and non-stack objects
if (!inherits(data_stack, "data_stack")) {
check_inherits(data_stack, "data_stack", call = caller_env())
check_inherits(data_stack, "data_stack", call = call)
} else if (ncol(data_stack) == 0) {
cli_abort(
"The data stack supplied as the argument to `data_stack` has no
candidate members. Please first add candidates with
the {.help [`add_candidates()`](stacks::add_candidates)} function.",
call = caller_env()
call = call
)
} else if ((ncol(data_stack) == 2 && attr(data_stack, "mode") == "regression") ||
ncol(data_stack) == length(levels(data_stack[[1]])) + 1) {
cli_abort(
"The supplied data stack only contains one candidate member. Please
add more candidate members using
{.help [`add_candidates()`](stacks::add_candidates)} before blending.",
call = caller_env()
call = call
)
}

invisible(NULL)
}

process_data_stack <- function(data_stack) {
process_data_stack <- function(data_stack, call = caller_env()) {
dat <- tibble::as_tibble(data_stack) %>% na.omit()

# retain only the tbl_df attributes (#214)
Expand All @@ -367,7 +367,7 @@ process_data_stack <- function(data_stack) {
cli_abort(
"All rows in the data stack have at least one missing value.
Please ensure that all candidates supply predictions.",
call = caller_env()
call = call
)
}

Expand Down
7 changes: 3 additions & 4 deletions R/collect_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ collect_parameters <- function(stack, candidates, ...) {
collect_parameters.default <- function(stack, candidates, ...) {
cli_abort(
"There is no `collect_parameters()` method currently implemented
for {.var {class(stack)}} objects.",
call = caller_env(0)
for {.var {class(stack)}} objects."
)
}

Expand Down Expand Up @@ -155,13 +154,13 @@ collect_params <- function(cols_map, model_metrics, candidates, workflows, blend
res
}

check_for_candidates <- function(model_metrics, candidates) {
check_for_candidates <- function(model_metrics, candidates, call = caller_env()) {
if ((!inherits(candidates, "character")) ||
(!candidates %in% names(model_metrics))) {
cli_abort(
"The `candidates` argument to `collect_parameters()` must be the name
given to a set of candidates added with `add_candidates()`.",
call = caller_env()
call = call
)
}
}
Expand Down
4 changes: 2 additions & 2 deletions R/fit_members.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ sanitize_classification_names <- function(model_stack, member_names) {
}


check_model_stack <- function(model_stack) {
check_model_stack <- function(model_stack, call = caller_env()) {
if (inherits(model_stack, "model_stack")) {
if (!is.null(model_stack[["member_fits"]])) {
cli_warn(
Expand All @@ -221,7 +221,7 @@ check_model_stack <- function(model_stack) {
a model stack. Did you forget to first evaluate the ensemble's
stacking coefficients with
{.help [`blend_predictions()`](stacks::blend_predictions)}?",
call = caller_env()
call = call
)
} else {
check_inherits(model_stack, "model_stack", call = caller_env())
Expand Down
7 changes: 3 additions & 4 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ predict.data_stack <- function(object, ...) {
cli_abort(
"To predict with a stacked ensemble, the supplied data stack must be
evaluated with `blend_predictions()` and its member models fitted with
`fit_members()` to predict on new data.",
call = caller_env(0)
`fit_members()` to predict on new data."
)
}

Expand Down Expand Up @@ -190,12 +189,12 @@ parse_member_probs <- function(member_name, member_probs, levels) {
)
}

check_fitted <- function(model_stack) {
check_fitted <- function(model_stack, call = caller_env()) {
if (is.null(model_stack[["member_fits"]])) {
cli_abort(
"The supplied model stack hasn't been fitted yet.
Please fit the necessary members with fit_members() to predict on new data.",
call = caller_env()
call = call
)
}
}
Expand Down

0 comments on commit d51dd69

Please sign in to comment.