Skip to content

Commit

Permalink
chore: improve clone behavior in axum (#306)
Browse files Browse the repository at this point in the history
* Removed the need for clone each time on a axum request for sign node

* Removed the need for clone each time on a axum request for leader node

* FRP signatures are now passed by ref
  • Loading branch information
ChaoticTempest authored Sep 27, 2023
1 parent f405e55 commit f7600ec
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 66 deletions.
4 changes: 2 additions & 2 deletions mpc-recovery/src/key_recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ pub async fn get_user_recovery_pk(
client: &reqwest::Client,
sign_nodes: &[String],
oidc_token: &OidcToken,
frp_signature: Signature,
frp_signature: &Signature,
frp_public_key: &PublicKey,
) -> Result<PublicKey, LeaderNodeError> {
let request = PublicKeyNodeRequest {
oidc_token: oidc_token.clone(),
frp_signature,
frp_signature: *frp_signature,
frp_public_key: frp_public_key.clone(),
};
let res = call_all_nodes(client, sign_nodes, "public_key", request).await?;
Expand Down
32 changes: 17 additions & 15 deletions mpc-recovery/src/leader_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::transaction::{
};
use crate::utils::{check_digest_signature, user_credentials_request_digest};
use crate::{metrics, nar};

use anyhow::Context;
use axum::extract::MatchedPath;
use axum::middleware::{self, Next};
Expand All @@ -34,7 +35,9 @@ use near_primitives::transaction::{Action, DeleteKeyAction};
use near_primitives::types::AccountId;
use prometheus::{Encoder, TextEncoder};
use rand::{distributions::Alphanumeric, Rng};

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;

pub struct Config {
Expand Down Expand Up @@ -87,7 +90,7 @@ pub async fn run<T: OAuthTokenVerifier + 'static>(config: Config) {
.unwrap();
}

let state = LeaderState {
let state = Arc::new(LeaderState {
env,
sign_nodes,
client,
Expand All @@ -96,7 +99,7 @@ pub async fn run<T: OAuthTokenVerifier + 'static>(config: Config) {
account_creator_id,
account_creator_sk,
partners,
};
});

// Get keys from all sign nodes, and broadcast them out as a set.
let pk_set = match gather_sign_node_pk_shares(&state).await {
Expand Down Expand Up @@ -206,7 +209,6 @@ async fn metrics() -> (StatusCode, String) {
}
}

#[derive(Clone)]
struct LeaderState {
env: String,
sign_nodes: Vec<String>,
Expand All @@ -220,7 +222,7 @@ struct LeaderState {
}

async fn mpc_public_key(
Extension(state): Extension<LeaderState>,
Extension(state): Extension<Arc<LeaderState>>,
WithRejection(Json(_), _): WithRejection<Json<MpcPkRequest>, MpcError>,
) -> (StatusCode, Json<MpcPkResponse>) {
// Getting MPC PK from sign nodes
Expand Down Expand Up @@ -253,7 +255,7 @@ async fn mpc_public_key(

#[tracing::instrument(level = "info", skip_all, fields(env = state.env))]
async fn claim_oidc(
Extension(state): Extension<LeaderState>,
Extension(state): Extension<Arc<LeaderState>>,
WithRejection(Json(claim_oidc_request), _): WithRejection<Json<ClaimOidcRequest>, MpcError>,
) -> (StatusCode, Json<ClaimOidcResponse>) {
tracing::info!(
Expand Down Expand Up @@ -287,7 +289,7 @@ async fn claim_oidc(

#[tracing::instrument(level = "info", skip_all, fields(env = state.env))]
async fn user_credentials<T: OAuthTokenVerifier>(
Extension(state): Extension<LeaderState>,
Extension(state): Extension<Arc<LeaderState>>,
WithRejection(Json(request), _): WithRejection<Json<UserCredentialsRequest>, MpcError>,
) -> (StatusCode, Json<UserCredentialsResponse>) {
tracing::info!(
Expand All @@ -313,7 +315,7 @@ async fn user_credentials<T: OAuthTokenVerifier>(
}

async fn process_user_credentials<T: OAuthTokenVerifier>(
state: LeaderState,
state: Arc<LeaderState>,
request: UserCredentialsRequest,
) -> Result<UserCredentialsResponse, LeaderNodeError> {
T::verify_token(&request.oidc_token, &state.partners.oidc_providers())
Expand All @@ -325,7 +327,7 @@ async fn process_user_credentials<T: OAuthTokenVerifier>(
&state.reqwest_client,
&state.sign_nodes,
&request.oidc_token,
request.frp_signature,
&request.frp_signature,
&request.frp_public_key,
)
.await?;
Expand All @@ -338,7 +340,7 @@ async fn process_user_credentials<T: OAuthTokenVerifier>(
}

async fn process_new_account<T: OAuthTokenVerifier>(
state: LeaderState,
state: Arc<LeaderState>,
request: NewAccountRequest,
) -> Result<NewAccountResponse, LeaderNodeError> {
// Create a transaction to create new NEAR account
Expand Down Expand Up @@ -373,7 +375,7 @@ async fn process_new_account<T: OAuthTokenVerifier>(
&state.reqwest_client,
&state.sign_nodes,
&request.oidc_token,
request.user_credentials_frp_signature,
&request.user_credentials_frp_signature,
&request.frp_public_key,
)
.await?;
Expand Down Expand Up @@ -454,7 +456,7 @@ async fn process_new_account<T: OAuthTokenVerifier>(

#[tracing::instrument(level = "info", skip_all, fields(env = state.env))]
async fn new_account<T: OAuthTokenVerifier>(
Extension(state): Extension<LeaderState>,
Extension(state): Extension<Arc<LeaderState>>,
WithRejection(Json(request), _): WithRejection<Json<NewAccountRequest>, MpcError>,
) -> (StatusCode, Json<NewAccountResponse>) {
tracing::info!(
Expand All @@ -477,7 +479,7 @@ async fn new_account<T: OAuthTokenVerifier>(
}

async fn process_sign<T: OAuthTokenVerifier>(
state: LeaderState,
state: Arc<LeaderState>,
request: SignRequest,
) -> Result<SignResponse, LeaderNodeError> {
// Deserialize the included delegate action via borsh
Expand Down Expand Up @@ -510,7 +512,7 @@ async fn process_sign<T: OAuthTokenVerifier>(
&state.reqwest_client,
&state.sign_nodes,
&request.oidc_token,
request.user_credentials_frp_signature,
&request.user_credentials_frp_signature,
&request.frp_public_key,
)
.await?;
Expand Down Expand Up @@ -543,7 +545,7 @@ async fn process_sign<T: OAuthTokenVerifier>(
&state.sign_nodes,
&request.oidc_token,
delegate_action.clone(),
request.frp_signature,
&request.frp_signature,
&request.frp_public_key,
)
.await?;
Expand All @@ -555,7 +557,7 @@ async fn process_sign<T: OAuthTokenVerifier>(

#[tracing::instrument(level = "info", skip_all, fields(env = state.env))]
async fn sign<T: OAuthTokenVerifier>(
Extension(state): Extension<LeaderState>,
Extension(state): Extension<Arc<LeaderState>>,
WithRejection(Json(request), _): WithRejection<Json<SignRequest>, MpcError>,
) -> (StatusCode, Json<SignResponse>) {
tracing::info!(
Expand Down
10 changes: 0 additions & 10 deletions mpc-recovery/src/relayer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ pub struct NearRpcAndRelayerClient {
cached_nonces: CachedAccessKeyNonces,
}

impl Clone for NearRpcAndRelayerClient {
fn clone(&self) -> Self {
Self {
rpc_client: self.rpc_client.clone(),
// all the cached nonces will not get cloned, and instead get invalidated:
cached_nonces: Default::default(),
}
}
}

impl NearRpcAndRelayerClient {
pub fn connect(near_rpc: &str) -> Self {
Self {
Expand Down
38 changes: 18 additions & 20 deletions mpc-recovery/src/sign_node/aggregate_signer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use curv::arithmetic::Converter;
use curv::cryptographic_primitives::commitments::{
Expand Down Expand Up @@ -68,7 +67,7 @@ impl SigningState {

pub async fn get_reveal(
&self,
node_info: NodeInfo,
node_info: &NodeInfo,
recieved_commitments: Vec<SignedCommitment>,
) -> Result<Reveal, AggregateSigningError> {
// TODO Factor this out
Expand All @@ -89,15 +88,15 @@ impl SigningState {
AggregateSigningError::CommitmentNotFound(format!("{:?}", our_c.commitment))
})?;

let (reveal, state) = state.reveal(&node_info, recieved_commitments).await?;
let (reveal, state) = state.reveal(node_info, recieved_commitments).await?;
let reveal = Reveal(reveal);
self.revealed.write().await.insert(reveal.clone(), state);
Ok(reveal)
}

pub async fn get_signature_share(
&self,
node_info: NodeInfo,
node_info: &NodeInfo,
signature_parts: Vec<Reveal>,
) -> Result<protocols::Signature, AggregateSigningError> {
let i = node_info.our_index;
Expand All @@ -115,7 +114,7 @@ impl SigningState {

let signature_parts = signature_parts.into_iter().map(|s| s.0).collect();

state.combine(signature_parts, &node_info)
state.combine(signature_parts, node_info)
}
}

Expand Down Expand Up @@ -233,17 +232,16 @@ impl Revealed {
}

// Stores info about the other nodes we're interacting with
#[derive(Clone)]
pub struct NodeInfo {
pub nodes_public_keys: Arc<RwLock<Option<Vec<Point<Ed25519>>>>>,
pub nodes_public_keys: RwLock<Option<Vec<Point<Ed25519>>>>,
pub our_index: usize,
}

impl NodeInfo {
pub fn new(our_index: usize, nodes_public_keys: Option<Vec<Point<Ed25519>>>) -> Self {
Self {
our_index,
nodes_public_keys: Arc::new(RwLock::new(nodes_public_keys)),
nodes_public_keys: RwLock::new(nodes_public_keys),
}
}

Expand Down Expand Up @@ -441,19 +439,19 @@ mod tests {
commitments.push(create_rogue_commit(&message, &commitments));

let reveals = vec![
s1.get_reveal(ni(0), commitments.clone()).await.unwrap(),
s2.get_reveal(ni(1), commitments.clone()).await.unwrap(),
s3.get_reveal(ni(2), commitments.clone()).await.unwrap(),
s1.get_reveal(&ni(0), commitments.clone()).await.unwrap(),
s2.get_reveal(&ni(1), commitments.clone()).await.unwrap(),
s3.get_reveal(&ni(2), commitments.clone()).await.unwrap(),
];

let sig_shares = vec![
s1.get_signature_share(ni(0), reveals.clone())
s1.get_signature_share(&ni(0), reveals.clone())
.await
.unwrap(),
s2.get_signature_share(ni(1), reveals.clone())
s2.get_signature_share(&ni(1), reveals.clone())
.await
.unwrap(),
s3.get_signature_share(ni(2), reveals).await.unwrap(),
s3.get_signature_share(&ni(2), reveals).await.unwrap(),
];

let signing_keys: Vec<_> = commitments
Expand Down Expand Up @@ -500,19 +498,19 @@ mod tests {
];

let reveals = vec![
s1.get_reveal(ni(0), commitments.clone()).await.unwrap(),
s2.get_reveal(ni(1), commitments.clone()).await.unwrap(),
s3.get_reveal(ni(2), commitments.clone()).await.unwrap(),
s1.get_reveal(&ni(0), commitments.clone()).await.unwrap(),
s2.get_reveal(&ni(1), commitments.clone()).await.unwrap(),
s3.get_reveal(&ni(2), commitments.clone()).await.unwrap(),
];

let sig_shares = vec![
s1.get_signature_share(ni(0), reveals.clone())
s1.get_signature_share(&ni(0), reveals.clone())
.await
.unwrap(),
s2.get_signature_share(ni(1), reveals.clone())
s2.get_signature_share(&ni(1), reveals.clone())
.await
.unwrap(),
s3.get_signature_share(ni(2), reveals).await.unwrap(),
s3.get_signature_share(&ni(2), reveals).await.unwrap(),
];

let signing_keys: Vec<_> = commitments
Expand Down
Loading

0 comments on commit f7600ec

Please sign in to comment.