Skip to content

Commit

Permalink
Add option to use simplified actuation matrix for pneumatic actuatino
Browse files Browse the repository at this point in the history
  • Loading branch information
mstoelzle committed Nov 28, 2024
1 parent 96d4053 commit 2daa3b1
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 29 deletions.
1 change: 0 additions & 1 deletion examples/demo_planar_hsa_motor2ee_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
jax.config.update("jax_enable_x64", True) # double precision
from jax import Array, jacfwd, jacrev, jit, random, vmap
from jax import numpy as jnp
from jaxopt import GaussNewton, LevenbergMarquardt
from functools import partial
import numpy as onp
from pathlib import Path
Expand Down
5 changes: 4 additions & 1 deletion examples/simulate_pneumatic_planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
strain_selector = jnp.array([True, False, True])[None, :].repeat(num_segments, axis=0).flatten()

B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
pneumatic_planar_pcs.factory(num_segments, sym_exp_filepath, strain_selector)
pneumatic_planar_pcs.factory(
num_segments, sym_exp_filepath, strain_selector, # simplified_actuation_mapping=True
)
)
# jit the functions
dynamical_matrices_fn = jax.jit(dynamical_matrices_fn)
Expand All @@ -57,6 +59,7 @@
forward_kinematics_fn,
auxiliary_fns["jacobian_fn"],
)
print("A=", actuation_mapping_fn(params, B_xi, jnp.zeros((2 * num_segments,))))


def sweep_actuation_mapping():
Expand Down
10 changes: 5 additions & 5 deletions src/jsrm/systems/planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def stiffness_fn(
B_xi: Strain basis matrix
formulate_in_strain_space: whether to formulate the elastic matrix in the strain space
Returns:
K: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise
S: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise
"""
# length of the segments
l = params["l"]
Expand All @@ -227,14 +227,14 @@ def stiffness_fn(
# elastic and shear modulus
E, G = params["E"], params["G"]
# stiffness matrix of shape (num_segments, 3, 3)
S = compute_stiffness_matrix_for_all_segments_fn(l, A, Ib, E, G)
S_sms = compute_stiffness_matrix_for_all_segments_fn(l, A, Ib, E, G)
# we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = K @ xi where K is equal to
K = blk_diag(S)
S = blk_diag(S_sms)

if not formulate_in_strain_space:
K = B_xi.T @ K @ B_xi
S = B_xi.T @ S @ B_xi

return K
return S

if actuation_mapping_fn is None:
def actuation_mapping_fn(
Expand Down
53 changes: 31 additions & 22 deletions src/jsrm/systems/pneumatic_planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def factory(
num_segments: int,
*args,
segment_actuation_selector: Optional[Array] = None,
simplified_actuation_mapping: bool = False,
**kwargs
):
"""
Expand All @@ -19,6 +20,7 @@ def factory(
num_segments: number of segments
segment_actuation_selector: actuation selector for the segments as boolean array of shape (num_segments,)
True entries signify that the segment is actuated, False entries signify that the segment is passive
simplified_actuation_mapping: flag to use a simplified actuation mapping (i.e., a constant actuation matrix)
Returns:
"""
if segment_actuation_selector is None:
Expand Down Expand Up @@ -102,22 +104,29 @@ def compute_actuation_matrix_for_segment(
2 / 3 * jnp.sinc(0.5 * varphi_cham) * (r_cham_out ** 3 - r_cham_in ** 3) / (r_cham_out ** 2 - r_cham_in ** 2)
)

# compute the actuation matrix that collects the contributions of the pneumatic chambers in the given segment
# first we consider the contribution of the distal end
A_sm_de = J_de.T @ jnp.array([
[-2 * A_cham * jnp.sin(th_de), -2 * A_cham * jnp.sin(th_de)],
[2 * A_cham * jnp.cos(th_de), 2 * A_cham * jnp.cos(th_de)],
[A_cham * r_cop, -A_cham * r_cop]
])
# then, we consider the contribution of the proximal end
A_sm_pe = J_pe.T @ jnp.array([
[2 * A_cham * jnp.sin(th_pe), 2 * A_cham * jnp.sin(th_pe)],
[-2 * A_cham * jnp.cos(th_pe), -2 * A_cham * jnp.cos(th_pe)],
[-A_cham * r_cop, A_cham * r_cop]
])

# sum the contributions of the distal and proximal ends
A_sm = A_sm_de + A_sm_pe
if simplified_actuation_mapping:
A_sm = B_xi.T @ jnp.array([
[A_cham * r_cop, -A_cham * r_cop],
[0.0, 0.0],
[2 * A_cham, 2 * A_cham],
])
else:
# compute the actuation matrix that collects the contributions of the pneumatic chambers in the given segment
# first we consider the contribution of the distal end
A_sm_de = J_de.T @ jnp.array([
[-2 * A_cham * jnp.sin(th_de), -2 * A_cham * jnp.sin(th_de)],
[2 * A_cham * jnp.cos(th_de), 2 * A_cham * jnp.cos(th_de)],
[A_cham * r_cop, -A_cham * r_cop]
])
# then, we consider the contribution of the proximal end
A_sm_pe = J_pe.T @ jnp.array([
[2 * A_cham * jnp.sin(th_pe), 2 * A_cham * jnp.sin(th_pe)],
[-2 * A_cham * jnp.cos(th_pe), -2 * A_cham * jnp.cos(th_pe)],
[-A_cham * r_cop, A_cham * r_cop]
])

# sum the contributions of the distal and proximal ends
A_sm = A_sm_de + A_sm_pe

return A_sm

Expand Down Expand Up @@ -173,18 +182,18 @@ def stiffness_fn(
B_xi: Strain basis matrix
formulate_in_strain_space: whether to formulate the elastic matrix in the strain space
Returns:
K: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise
S: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise
"""
# stiffness matrix of shape (num_segments, 3, 3)
S = vmap(
S_sms = vmap(
_compute_stiffness_matrix_for_segment
)(
params["l"], params["r"], params["r_cham_in"], params["r_cham_out"], params["varphi_cham"], params["E"]
)
# we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = K @ xi where K is equal to
K = blk_diag(S)
# we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = S @ xi where K is equal to
S = blk_diag(S_sms)

if not formulate_in_strain_space:
K = B_xi.T @ K @ B_xi
S = B_xi.T @ S @ B_xi

return K
return S

0 comments on commit 2daa3b1

Please sign in to comment.