From 839bf50cef5b95d804ea6dc643a1e0c100574252 Mon Sep 17 00:00:00 2001 From: Maximilian Stolzle Date: Wed, 20 Nov 2024 15:38:06 -0500 Subject: [PATCH] Fully implement pneumatic actuation model --- examples/simulate_pneumatic_planar_pcs.py | 73 ++++++++++++++++++++--- src/jsrm/systems/planar_pcs.py | 6 +- src/jsrm/systems/pneumatic_planar_pcs.py | 63 +++++++++++++------ 3 files changed, 113 insertions(+), 29 deletions(-) diff --git a/examples/simulate_pneumatic_planar_pcs.py b/examples/simulate_pneumatic_planar_pcs.py index b9825d1..98f1b9c 100644 --- a/examples/simulate_pneumatic_planar_pcs.py +++ b/examples/simulate_pneumatic_planar_pcs.py @@ -37,7 +37,7 @@ "r_cham_out": 2e-2 - 2e-3 * jnp.ones((num_segments,)), "varphi_cham": jnp.pi/2 * jnp.ones((num_segments,)), } -params["D"] = 1e-3 * jnp.diag( +params["D"] = 5e-4 * jnp.diag( (jnp.repeat( jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0 ) * params["l"][:, None]).flatten() @@ -52,29 +52,88 @@ ) # jit the functions dynamical_matrices_fn = jax.jit(dynamical_matrices_fn) -actuation_mapping_fn = auxiliary_fns["actuation_mapping_fn"] +actuation_mapping_fn = partial( + auxiliary_fns["actuation_mapping_fn"], + forward_kinematics_fn, + auxiliary_fns["jacobian_fn"], +) def sweep_actuation_mapping(): + # evaluate the actuation matrix for a straight backbone q = jnp.zeros((2 * num_segments,)) A = actuation_mapping_fn(params, B_xi, q) - print("A =\n", A) + print("Evaluating actuation matrix for straight backbone: A =\n", A) + + kappa_be_pts = jnp.linspace(-jnp.pi, jnp.pi, 500) + sigma_ax_pts = jnp.zeros_like(kappa_be_pts) + q_pts = jnp.stack([kappa_be_pts, sigma_ax_pts], axis=-1) + A_pts = vmap(actuation_mapping_fn, in_axes=(None, None, 0))(params, B_xi, q_pts) + # plot the mapping on the bending strain for various bending strains + fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_bending_torque_vs_bending_strain") + plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{be}$") + ax.plot(kappa_be_pts, A_pts[:, 0, 0], linewidth=2) + # shade the region where the actuation mapping is negative as we are not able to bend the robot further + ax.axhspan(A_pts[:, 0, 0].min(), 0.0, facecolor='red', alpha=0.5) + ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]") + ax.set_ylabel(r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$") + plt.grid(True) + plt.tight_layout() + plt.show() + + # create grid for bending and axial strains + kappa_be_grid, sigma_ax_grid = jnp.meshgrid( + jnp.linspace(-jnp.pi, jnp.pi, 20), + jnp.linspace(-0.2, 0.2, 20), + ) + q_pts = jnp.stack([kappa_be_grid.flatten(), sigma_ax_grid.flatten()], axis=-1) + + # evaluate the actuation mapping on the grid + A_pts = vmap(actuation_mapping_fn, in_axes=(None, None, 0))(params, B_xi, q_pts) + # reshape A_pts to match the grid shape + A_grid = A_pts.reshape(kappa_be_grid.shape[:2] + A_pts.shape[-2:]) + + # plot the mapping on the bending strain + fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_bending_torque_vs_axial_vs_bending_strain") + plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{be}$") + # contourf plot + c = ax.contourf(kappa_be_grid, sigma_ax_grid, A_grid[..., 0, 0], levels=100) + fig.colorbar(c, ax=ax, label=r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$") + # contour plot + ax.contour(kappa_be_grid, sigma_ax_grid, A_grid[..., 0, 0], levels=20, colors="k", linewidths=0.5) + ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]") + ax.set_ylabel(r"$\sigma_\mathrm{ax}$ [-]") + plt.tight_layout() + plt.show() + + # plot the mapping on the axial strain + fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_axial_torque_vs_axial_vs_bending_strain") + plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{ax}$") + # contourf plot + c = ax.contourf(kappa_be_grid, sigma_ax_grid, A_grid[..., 1, 0], levels=100) + fig.colorbar(c, ax=ax, label=r"$\frac{\partial \tau_\mathrm{ax}}{\partial u_1}$") + # contour plot + ax.contour(kappa_be_grid, sigma_ax_grid, A_grid[..., 1, 0], levels=20, colors="k", linewidths=0.5) + ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]") + ax.set_ylabel(r"$\sigma_\mathrm{ax}$ [-]") + plt.tight_layout() + plt.show() def simulate_robot(): # define initial configuration - q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.2])[None, :], num_segments, axis=0).flatten() + q0 = jnp.repeat(jnp.array([-5.0 * jnp.pi, -0.2])[None, :], num_segments, axis=0).flatten() # number of generalized coordinates n_q = q0.shape[0] # set simulation parameters dt = 1e-3 # time step sim_dt = 5e-5 # simulation time step - ts = jnp.arange(0.0, 2, dt) # time steps + ts = jnp.arange(0.0, 7.0, dt) # time steps x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition - tau = jnp.zeros_like(q0) # torques + u = jnp.array([1.2e3, 0e0]) # control inputs (pressures in the right and left chambers) - ode_fn = ode_factory(dynamical_matrices_fn, params, tau) + ode_fn = ode_factory(dynamical_matrices_fn, params, u) term = ODETerm(ode_fn) sol = diffeqsolve( diff --git a/src/jsrm/systems/planar_pcs.py b/src/jsrm/systems/planar_pcs.py index 7cfdbd7..54d84f5 100644 --- a/src/jsrm/systems/planar_pcs.py +++ b/src/jsrm/systems/planar_pcs.py @@ -256,7 +256,7 @@ def actuation_mapping_fn( Returns: A: actuation matrix of shape (n_xi, n_xi) where n_xi is the number of strains. """ - A = jnp.identity(n_xi) @ B_xi + A = B_xi.T @ jnp.identity(n_xi) @ B_xi return A @@ -359,7 +359,7 @@ def dynamical_matrices_fn( # compute the stiffness matrix K = stiffness_fn(params, B_xi, formulate_in_strain_space=True) # compute the actuation matrix - A = actuation_mapping_fn(forward_kinematics_fn, actuation_mapping_fn, params, B_xi, q) + A = actuation_mapping_fn(forward_kinematics_fn, jacobian_fn, params, B_xi, q) # dissipative matrix from the parameters D = params.get("D", jnp.zeros((n_xi, n_xi))) @@ -376,7 +376,7 @@ def dynamical_matrices_fn( D = B_xi.T @ D @ B_xi # apply the strain basis to the actuation matrix - alpha = B_xi.T @ A + alpha = A return B, C, G, K, D, alpha diff --git a/src/jsrm/systems/pneumatic_planar_pcs.py b/src/jsrm/systems/pneumatic_planar_pcs.py index 26302d5..387d71a 100644 --- a/src/jsrm/systems/pneumatic_planar_pcs.py +++ b/src/jsrm/systems/pneumatic_planar_pcs.py @@ -55,15 +55,8 @@ def actuation_mapping_fn( A: actuation matrix of shape (n_xi, n_act) where n_xi is the number of strains and n_act is the number of actuators """ - # map the configurations to strains - xi = B_xi @ q - - # number of strains - n_xi = xi.shape[0] - # all segment bases and tips sms = jnp.concat([jnp.zeros((1,)), jnp.cumsum(params["l"])], axis=0) - print("sms =\n", sms) # compute the poses of all segment tips chi_sms = vmap(forward_kinematics_fn, in_axes=(None, None, 0))(params, q, sms) @@ -74,10 +67,20 @@ def actuation_mapping_fn( def compute_actuation_matrix_for_segment( r_cham_in: Array, r_cham_out: Array, varphi_cham: Array, chi_pe: Array, chi_de: Array, - J_pe: Array, J_de: Array, xi: Array + J_pe: Array, J_de: Array, ) -> Array: """ Compute the actuation matrix for a single segment. + We assume that each segment contains four identical and symmetric pneumatic chambers with pressures + p1, p2, p3, and p4, where p1 and p3 are the right and left chamber pressures respectively, and + p2 and p4 are the back and front chamber pressures respectively. The front and back chambers + do not exert a level arm (i.e., a bending moment) on the segment. + We map the control inputs u1 and u2 as follows to the pressures: + p1 = u1 (right chamber) + p2 = (u1 + u2) / 2 + p3 = u2 (left chamber) + p4 = (u1 + u2) / 2 + Args: r_cham_in: inner radius of each segment chamber r_cham_out: outer radius of each segment chamber @@ -86,23 +89,45 @@ def compute_actuation_matrix_for_segment( chi_de: pose of the distal end (i.e., the tip) of the segment as array of shape (3,) J_pe: Jacobian of the proximal end of the segment as array of shape (3, n_q) J_de: Jacobian of the distal end of the segment as array of shape (3, n_q) - xi: strains of the segment Returns: A_sm: actuation matrix of shape (n_xi, 2) """ - # rotation matrix from the robot base to the segment base - R_pe = jnp.array([[jnp.cos(chi_pe[2]), -jnp.sin(chi_pe[2])], [jnp.sin(chi_pe[2]), jnp.cos(chi_pe[2])]]) - # rotation matrix from the robot base to the segment tip - R_de = jnp.array([[jnp.cos(chi_de[2]), -jnp.sin(chi_de[2])], [jnp.sin(chi_de[2]), jnp.cos(chi_de[2])]]) + # orientation of the proximal and distal ends of the segment + th_pe, th_de = chi_pe[2], chi_de[2] + + # compute the area of each pneumatic chamber (we assume identical chambers within a segment) + A_cham = 0.5 * varphi_cham * (r_cham_out ** 2 - r_cham_in ** 2) + # compute the center of pressure of the pneumatic chamber + r_cop = ( + 2 / 3 * jnp.sinc(0.5 * varphi_cham) * (r_cham_out ** 3 - r_cham_in ** 3) / (r_cham_out ** 2 - r_cham_in ** 2) + ) + + # compute the actuation matrix that collects the contributions of the pneumatic chambers in the given segment + # first we consider the contribution of the distal end + A_sm_de = J_de.T @ jnp.array([ + [-2 * A_cham * jnp.sin(th_de), -2 * A_cham * jnp.sin(th_de)], + [2 * A_cham * jnp.cos(th_de), 2 * A_cham * jnp.cos(th_de)], + [A_cham * r_cop, -A_cham * r_cop] + ]) + # then, we consider the contribution of the proximal end + A_sm_pe = J_pe.T @ jnp.array([ + [2 * A_cham * jnp.sin(th_pe), 2 * A_cham * jnp.sin(th_pe)], + [-2 * A_cham * jnp.cos(th_pe), -2 * A_cham * jnp.cos(th_pe)], + [-A_cham * r_cop, A_cham * r_cop] + ]) + + # sum the contributions of the distal and proximal ends + A_sm = A_sm_de + A_sm_pe - - # compute the actuation matrix for a single segment - A_sm = jnp.zeros((n_xi, 2)) return A_sm - A_sms = vmap(compute_actuation_matrix_for_segment)(chi_sms, J_sms, xi) - - A = jnp.zeros((n_xi, 2 * num_segments)) + A_sms = vmap(compute_actuation_matrix_for_segment)( + params["r_cham_in"], params["r_cham_out"], params["varphi_cham"], + chi_pe=chi_sms[:-1], chi_de=chi_sms[1:], + J_pe=J_sms[:-1], J_de=J_sms[1:], + ) + # we need to sum the contributions of the actuation of each segment + A = jnp.sum(A_sms, axis=0) # apply the actuation_basis A = A @ actuation_basis