diff --git a/examples/simulate_planar_hsa.py b/examples/simulate_planar_hsa.py index 293ce84..72bc033 100644 --- a/examples/simulate_planar_hsa.py +++ b/examples/simulate_planar_hsa.py @@ -26,7 +26,7 @@ # activate all strains (i.e. bending, shear, and axial) strain_selector = jnp.ones((3 * num_segments,), dtype=bool) -consider_hysteresis = False +consider_hysteresis = True params = PARAMS_FPU_HYSTERESIS_CONTROL if consider_hysteresis else PARAMS_FPU_CONTROL @@ -249,7 +249,7 @@ def chi2u(chi: Array) -> Array: inverse_kinematics_end_effector_fn, dynamical_matrices_fn, sys_helpers, - ) = planar_hsa.factory(sym_exp_filepath, strain_selector) + ) = planar_hsa.factory(sym_exp_filepath, strain_selector, consider_hysteresis=consider_hysteresis) batched_forward_kinematics_virtual_backbone_fn = vmap( forward_kinematics_virtual_backbone_fn, in_axes=(None, None, 0), out_axes=-1 @@ -298,7 +298,7 @@ def chi2u(chi: Array) -> Array: x0 = jnp.zeros((2 * q0.shape[0],)) # initial condition x0 = x0.at[: q0.shape[0]].set(q0) # set initial configuration - ode_fn = planar_hsa.ode_factory(dynamical_matrices_fn, params) + ode_fn = planar_hsa.ode_factory(dynamical_matrices_fn, params, consider_hysteresis=consider_hysteresis) ode_term = ODETerm(ode_fn) sol = diffeqsolve( diff --git a/src/jsrm/systems/planar_hsa.py b/src/jsrm/systems/planar_hsa.py index e0713fe..88f7683 100644 --- a/src/jsrm/systems/planar_hsa.py +++ b/src/jsrm/systems/planar_hsa.py @@ -3,7 +3,7 @@ from jax import numpy as jnp import sympy as sp from pathlib import Path -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from .utils import ( concatenate_params_syms, @@ -559,7 +559,7 @@ def dynamical_matrices_fn( params: Dict[str, Array], q: Array, q_d: Array, - z: Array = None, + z: Optional[Array] = None, phi: Array = jnp.zeros((num_segments * num_rods_per_segment,)), eps: float = 1e4 * global_eps, ) -> Tuple[Array, Array, Array, Array, Array, Array]: @@ -750,8 +750,8 @@ def ode_fn(t: float, x: Array, u: Array) -> Array: n_q = (x.shape[0] - n_z) // 2 q, q_d, z = x[:n_q], x[n_q:2*n_q], x[2*n_q:] - z_d = (B_z.T @ x_d) * ( - hys_params["A"] - jnp.abs(z)**hys_params["n"] * (hys_params["gamma"] + hys_params["beta"] * jnp.sign((B_z.T @ x_d) * z)) + z_d = (B_z.T @ q_d) * ( + hys_params["A"] - jnp.abs(z)**hys_params["n"] * (hys_params["gamma"] + hys_params["beta"] * jnp.sign((B_z.T @ q_d) * z)) ) else: n_q = x.shape[0] // 2