Skip to content

Commit

Permalink
Fix a few bugs in hysteresis implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mstoelzle committed May 2, 2024
1 parent 75a0460 commit d50f67d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions examples/simulate_planar_hsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# activate all strains (i.e. bending, shear, and axial)
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
consider_hysteresis = False
consider_hysteresis = True

params = PARAMS_FPU_HYSTERESIS_CONTROL if consider_hysteresis else PARAMS_FPU_CONTROL

Expand Down Expand Up @@ -249,7 +249,7 @@ def chi2u(chi: Array) -> Array:
inverse_kinematics_end_effector_fn,
dynamical_matrices_fn,
sys_helpers,
) = planar_hsa.factory(sym_exp_filepath, strain_selector)
) = planar_hsa.factory(sym_exp_filepath, strain_selector, consider_hysteresis=consider_hysteresis)

batched_forward_kinematics_virtual_backbone_fn = vmap(
forward_kinematics_virtual_backbone_fn, in_axes=(None, None, 0), out_axes=-1
Expand Down Expand Up @@ -298,7 +298,7 @@ def chi2u(chi: Array) -> Array:
x0 = jnp.zeros((2 * q0.shape[0],)) # initial condition
x0 = x0.at[: q0.shape[0]].set(q0) # set initial configuration

ode_fn = planar_hsa.ode_factory(dynamical_matrices_fn, params)
ode_fn = planar_hsa.ode_factory(dynamical_matrices_fn, params, consider_hysteresis=consider_hysteresis)
ode_term = ODETerm(ode_fn)

sol = diffeqsolve(
Expand Down
8 changes: 4 additions & 4 deletions src/jsrm/systems/planar_hsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import numpy as jnp
import sympy as sp
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from .utils import (
concatenate_params_syms,
Expand Down Expand Up @@ -559,7 +559,7 @@ def dynamical_matrices_fn(
params: Dict[str, Array],
q: Array,
q_d: Array,
z: Array = None,
z: Optional[Array] = None,
phi: Array = jnp.zeros((num_segments * num_rods_per_segment,)),
eps: float = 1e4 * global_eps,
) -> Tuple[Array, Array, Array, Array, Array, Array]:
Expand Down Expand Up @@ -750,8 +750,8 @@ def ode_fn(t: float, x: Array, u: Array) -> Array:
n_q = (x.shape[0] - n_z) // 2
q, q_d, z = x[:n_q], x[n_q:2*n_q], x[2*n_q:]

z_d = (B_z.T @ x_d) * (
hys_params["A"] - jnp.abs(z)**hys_params["n"] * (hys_params["gamma"] + hys_params["beta"] * jnp.sign((B_z.T @ x_d) * z))
z_d = (B_z.T @ q_d) * (
hys_params["A"] - jnp.abs(z)**hys_params["n"] * (hys_params["gamma"] + hys_params["beta"] * jnp.sign((B_z.T @ q_d) * z))
)
else:
n_q = x.shape[0] // 2
Expand Down

0 comments on commit d50f67d

Please sign in to comment.