Skip to content

Commit

Permalink
fix: adding loopsolver optional depencency + admm solver option
Browse files Browse the repository at this point in the history
  • Loading branch information
lachlangrose committed Jul 3, 2024
1 parent ab1fa90 commit 26edd3f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
34 changes: 33 additions & 1 deletion LoopStructural/interpolators/_discrete_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,8 @@ def build_inequality_matrix(self):
for c in self.ineq_constraints.values():
mats.append(c['matrix'])
bounds.append(c['bounds'])
if len(mats) == 0:
return None, None
Q = sparse.vstack(mats)
bounds = np.hstack(bounds)
return Q, bounds
Expand Down Expand Up @@ -514,6 +516,7 @@ def solve_system(
self.c = np.zeros(self.support.n_nodes)
self.c[:] = np.nan
A, b = self.build_matrix()
Q, bounds = self.build_inequality_matrix()
if callable(solver):
logger.warning('Using custom solver')
self.c = solver(A.tocsr(), b)
Expand All @@ -522,7 +525,7 @@ def solve_system(
return True
## solve with lsmr
if isinstance(solver, str):
if solver not in ['cg', 'lsmr']:
if solver not in ['cg', 'lsmr', 'admm']:
logger.warning(
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
)
Expand Down Expand Up @@ -557,6 +560,35 @@ def solve_system(
self.up_to_date = True
logger.info("Interpolation took %f seconds" % (time() - starttime))
return True
elif solver == 'admm':
logger.info("Solving using admm")
if Q is None:
logger.warning("No inequality constraints, using lsmr")
return self.solve_system('lsmr', solver_kwargs)

try:
from loopsolver import admm_solve
except ImportError:
logger.warning(
"Cannot import admm solver. Please install loopsolver or use lsmr or cg"
)
return False
try:
res = admm_solve(
A,
b,
Q,
bounds,
x0=solver_kwargs.pop('x0', np.zeros(A.shape[1])),
admm_weight=solver_kwargs.pop('admm_weight', 0.01),
nmajor=solver_kwargs.pop('nmajor', 200),
linsys_solver_kwargs=solver_kwargs,
)
self.c = res
self.up_to_date = True
except ValueError as e:
logger.error(f"ADMM solver failed: {e}")
return False
return False

def update(self) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
dynamic = ['version']

[project.optional-dependencies]
all = ['loopstructural[visualisation,export]']
all = ['loopstructural[visualisation,inequalities,export]']
visualisation = [
"matplotlib",
"pyvista",
Expand All @@ -56,6 +56,8 @@ jupyter = [
"pyvista[all]",
"tqdm"
]
inequalities = [
"loopsolver"]
[project.urls]
Documentation = 'https://Loop3d.org/LoopStructural/'
"Bug Tracker" = 'https://github.com/loop3d/loopstructural/issues'
Expand Down

0 comments on commit 26edd3f

Please sign in to comment.