Skip to content

Commit

Permalink
STY: more typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Sep 27, 2023
1 parent e4ff4c4 commit c2c4fd3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
7 changes: 5 additions & 2 deletions src/sparselm/model/_adaptive_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,16 @@ def _validate_params(self, X: NDArray, y: NDArray) -> None:
def _set_param_values(self) -> None:
"""Set parameter values."""
super()._set_param_values()
assert self.canonicals_.parameters is not None
length = len(self.canonicals_.parameters.adaptive_weights.value)
self.canonicals_.parameters.adaptive_weights.value = self.alpha * np.ones(
length
)

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
"""Generate parameters for the problem."""
parameters = super()._generate_params(X, y)
assert parameters is not None
parameters.adaptive_weights = cp.Parameter(
shape=X.shape[1], nonneg=True, value=self.alpha * np.ones(X.shape[1])
)
Expand Down Expand Up @@ -335,7 +337,7 @@ def __init__(
**kwargs,
)

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
# skip AdaptiveLasso in super
parameters = super(AdaptiveLasso, self)._generate_params(X, y)
n_groups = X.shape[1] if self.groups is None else len(np.unique(self.groups))
Expand All @@ -353,6 +355,7 @@ def _generate_regularization(
parameters: SimpleNamespace,
auxiliaries: SimpleNamespace | None = None,
) -> cp.Expression:
assert auxiliaries is not None
return parameters.adaptive_weights @ auxiliaries.group_norms

def _iterative_update(
Expand Down
2 changes: 1 addition & 1 deletion src/sparselm/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _set_param_values(self) -> None:
value = np.asarray(value)
cvx_parameter.value = value

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
"""Return the named tuple of cvxpy parameters for optimization problem.
The cvxpy Parameters must be given values when generating.
Expand Down
15 changes: 9 additions & 6 deletions src/sparselm/model/_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _set_param_values(self) -> None:
if self.group_weights is not None:
self.canonicals_.parameters.group_weights = self.group_weights # type: ignore

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
parameters = super()._generate_params(X, y)
n_groups = X.shape[1] if self.groups is None else len(np.unique(self.groups))
group_weights = (
Expand Down Expand Up @@ -387,7 +387,7 @@ def _validate_params(self, X: NDArray, y: NDArray) -> None:

_check_group_weights(self.group_weights, n_groups)

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
parameters = super()._generate_params(X, y)

if self.group_list is None:
Expand Down Expand Up @@ -608,13 +608,13 @@ def _set_param_values(self) -> None:
self.canonicals_.parameters.lambda1.value = self.l1_ratio * self.alpha # type: ignore
self.canonicals_.parameters.lambda2.value = (1 - self.l1_ratio) * self.alpha # type: ignore

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
"""Generate parameters."""
parameters = super()._generate_params(X, y)
# save for information purposes
parameters.l1_ratio = self.l1_ratio # type: ignore
parameters.lambda1 = cp.Parameter(nonneg=True, value=self.l1_ratio * self.alpha)
parameters.lambda2 = cp.Parameter(
parameters.lambda1 = cp.Parameter(nonneg=True, value=self.l1_ratio * self.alpha) # type: ignore
parameters.lambda2 = cp.Parameter( # type: ignore
nonneg=True, value=(1 - self.l1_ratio) * self.alpha
)
return parameters
Expand All @@ -626,6 +626,7 @@ def _generate_regularization(
parameters: SimpleNamespace,
auxiliaries: SimpleNamespace | None = None,
) -> cp.Expression:
assert auxiliaries is not None
group_regularization = parameters.lambda2 * (
parameters.group_weights @ auxiliaries.group_norms
)
Expand Down Expand Up @@ -746,7 +747,7 @@ def _validate_params(self, X: NDArray, y: NDArray) -> None:
f"delta must be an array of length 1 or equal to the number of groups {n_groups}."
)

def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace | None:
def _generate_params(self, X: NDArray, y: NDArray) -> SimpleNamespace:
"""Generate parameters."""
parameters = super()._generate_params(X, y)
# force cvxpy delta to be an array of n_groups!
Expand All @@ -768,6 +769,7 @@ def _generate_group_norms(
) -> cp.Expression:
group_masks = [groups == i for i in np.sort(np.unique(groups))]
if standardize:
assert parameters is not None
group_norms = cp.hstack(
[
cp.norm2(
Expand All @@ -793,6 +795,7 @@ def _generate_regularization(
auxiliaries: SimpleNamespace | None = None,
) -> cp.Expression:
# repetitive code...
assert auxiliaries is not None
groups = np.arange(X.shape[1]) if self.groups is None else self.groups
group_masks = [groups == i for i in np.sort(np.unique(groups))]
ridge = cp.hstack([cp.sum_squares(beta[mask]) for mask in group_masks])
Expand Down
1 change: 1 addition & 0 deletions src/sparselm/model/_miqp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _generate_constraints(
auxiliaries: SimpleNamespace | None = None,
) -> list[cp.Constraint]:
"""Generate the constraints used to solve l0 regularization."""
assert auxiliaries is not None and parameters is not None
groups = np.arange(X.shape[1]) if self.groups is None else self.groups
group_masks = [groups == i for i in np.sort(np.unique(groups))]
constraints = []
Expand Down

0 comments on commit c2c4fd3

Please sign in to comment.