diff --git a/Cargo.lock b/Cargo.lock index 63edf1966..98f9886f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1073,15 +1073,13 @@ dependencies = [ [[package]] name = "bollard" -version = "0.11.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c92fed694fd5a7468c971538351c61b9c115f1ae6ed411cd2800f0f299403a4b" +checksum = "d82e7850583ead5f8bbef247e2a3c37a19bd576e8420cd262a6711921827e1e5" dependencies = [ "base64 0.13.1", "bollard-stubs", "bytes", - "chrono", - "dirs-next", "futures-core", "futures-util", "hex 0.4.3", @@ -1089,25 +1087,24 @@ dependencies = [ "hyper", "hyperlocal", "log", - "pin-project", + "pin-project-lite", "serde", "serde_derive", "serde_json", "serde_urlencoded", "thiserror", "tokio", - "tokio-util 0.6.10", + "tokio-util 0.7.9", "url 2.4.1", "winapi", ] [[package]] name = "bollard-stubs" -version = "1.41.0" +version = "1.42.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2f2e73fffe9455141e170fb9c1feb0ac521ec7e7dcd47a7cab72a658490fb8" +checksum = "ed59b5c00048f48d7af971b71f800fdf23e858844a6f9e4d32ca72e9399e7864" dependencies = [ - "chrono", "serde", "serde_with 1.14.0", ] @@ -3627,6 +3624,7 @@ dependencies = [ "hyper", "mpc-contract", "mpc-recovery", + "mpc-recovery-node", "multi-party-eddsa", "near-crypto 0.17.0", "near-fetch", @@ -6708,9 +6706,9 @@ dependencies = [ [[package]] name = "testcontainers" -version = "0.14.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e2b1567ca8a2b819ea7b28c92be35d9f76fb9edb214321dcc86eb96023d1f87" +checksum = "f83d2931d7f521af5bae989f716c3fa43a6af9af7ec7a5e21b59ae40878cec00" dependencies = [ "async-trait", "bollard", diff --git a/integration-tests/Cargo.toml b/integration-tests/Cargo.toml index 0ca2865d5..270ad9d17 100644 --- a/integration-tests/Cargo.toml +++ b/integration-tests/Cargo.toml @@ -8,14 +8,16 @@ publish = false aes-gcm = "0.10" anyhow = { version = "1.0", features = ["backtrace"] } async-process = "1" -bollard = "0.11" +bollard = "0.13" clap = { version = "4.2", features = ["derive", "env"] } curv = { package = "curv-kzen", version = "0.9", default-features = false } ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3" hex = "0.4.3" hyper = { version = "0.14", features = ["full"] } +mpc-contract = { path = "../contract" } mpc-recovery = { path = "../mpc-recovery" } +mpc-recovery-node = { path = "../node" } multi-party-eddsa = { git = "https://github.com/DavidM-D/multi-party-eddsa.git", rev = "25ae4fdc5ff7819ae70e73ab4afacf1c24fc4da1" } tracing = "0.1" near-crypto = "0.17" @@ -26,7 +28,7 @@ near-units = "0.2.0" once_cell = "1" serde = "1" serde_json = "1" -testcontainers = { version = "0.14", features = ["experimental"] } +testcontainers = { version = "0.15", features = ["experimental"] } tokio = { version = "1.28", features = ["full"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } near-workspaces = "0.8.0" diff --git a/integration-tests/src/env/containers.rs b/integration-tests/src/env/containers.rs index f4ebb0b2d..071948a04 100644 --- a/integration-tests/src/env/containers.rs +++ b/integration-tests/src/env/containers.rs @@ -34,8 +34,7 @@ use once_cell::sync::Lazy; use testcontainers::{ clients::Cli, core::{ExecCommand, WaitFor}, - images::generic::GenericImage, - Container, Image, RunnableImage, + Container, GenericImage, Image, RunnableImage, }; use tokio::io::AsyncWriteExt; use tracing; @@ -972,78 +971,3 @@ impl LeaderNodeApi { .await } } - -pub struct Node<'a> { - pub container: Container<'a, GenericImage>, - pub address: String, - pub local_address: String, -} - -pub struct NodeApi { - pub address: String, - pub node_id: usize, - pub sk_share: ExpandedKeyPair, - pub cipher_key: GenericArray, - pub gcp_project_id: String, - pub gcp_datastore_local_url: String, -} - -impl<'a> Node<'a> { - // Container port used for the docker network, does not have to be unique - const CONTAINER_PORT: u16 = 3000; - - pub async fn run( - docker_client: &'a DockerClient, - network: &str, - node_id: u64, - near_rpc: &str, - signer_account: &AccountId, - account: &AccountId, - account_sk: &near_workspaces::types::SecretKey, - ) -> anyhow::Result> { - tracing::info!(node_id, "running node container"); - let image: GenericImage = GenericImage::new("near/mpc-recovery-node", "latest") - .with_wait_for(WaitFor::Nothing) - .with_exposed_port(Self::CONTAINER_PORT) - .with_env_var("RUST_LOG", "mpc_recovery_node=DEBUG") - .with_env_var("RUST_BACKTRACE", "1"); - let image: RunnableImage = ( - image, - vec![ - "start".to_string(), - "--node-id".to_string(), - node_id.to_string(), - "--near-rpc".to_string(), - near_rpc.to_string(), - "--mpc-contract-id".to_string(), - signer_account.to_string(), - "--account".to_string(), - account.to_string(), - "--account-sk".to_string(), - account_sk.to_string(), - "--web-port".to_string(), - Self::CONTAINER_PORT.to_string(), - ], - ) - .into(); - let image = image.with_network(network); - let container = docker_client.cli.run(image); - let ip_address = docker_client - .get_network_ip_address(&container, network) - .await?; - let host_port = container.get_host_port_ipv4(Self::CONTAINER_PORT); - - container.exec(ExecCommand { - cmd: format!("bash -c 'while [[ \"$(curl -s -o /dev/null -w ''%{{http_code}}'' localhost:{})\" != \"200\" ]]; do sleep 1; done'", Self::CONTAINER_PORT), - ready_conditions: vec![WaitFor::message_on_stdout("node is ready to accept connections")] - }); - - let full_address = format!("http://{ip_address}:{}", Self::CONTAINER_PORT); - tracing::info!(node_id, full_address, "node container is running"); - Ok(Node { - container, - address: full_address, - local_address: format!("http://localhost:{host_port}"), - }) - } -} diff --git a/integration-tests/src/lib.rs b/integration-tests/src/lib.rs index fd05925dc..443e3c0ee 100644 --- a/integration-tests/src/lib.rs +++ b/integration-tests/src/lib.rs @@ -10,6 +10,7 @@ use near_workspaces::{ use crate::env::containers; pub mod env; +pub mod multichain; pub mod sandbox; pub mod util; diff --git a/integration-tests/src/multichain/containers.rs b/integration-tests/src/multichain/containers.rs new file mode 100644 index 000000000..fbc9ebba2 --- /dev/null +++ b/integration-tests/src/multichain/containers.rs @@ -0,0 +1,72 @@ +use ed25519_dalek::ed25519::signature::digest::{consts::U32, generic_array::GenericArray}; +use multi_party_eddsa::protocols::ExpandedKeyPair; +use near_workspaces::AccountId; +use testcontainers::{ + core::{ExecCommand, WaitFor}, + Container, GenericImage, RunnableImage, +}; +use tracing; + +pub struct Node<'a> { + pub container: Container<'a, GenericImage>, + pub address: String, + pub local_address: String, +} + +pub struct NodeApi { + pub address: String, + pub node_id: usize, + pub sk_share: ExpandedKeyPair, + pub cipher_key: GenericArray, + pub gcp_project_id: String, + pub gcp_datastore_local_url: String, +} + +impl<'a> Node<'a> { + // Container port used for the docker network, does not have to be unique + const CONTAINER_PORT: u16 = 3000; + + pub async fn run( + ctx: &super::Context<'a>, + node_id: u32, + account: &AccountId, + account_sk: &near_workspaces::types::SecretKey, + ) -> anyhow::Result> { + tracing::info!(node_id, "running node container"); + let args = mpc_recovery_node::cli::Cli::Start { + node_id: node_id.into(), + near_rpc: ctx.sandbox.local_address.clone(), + mpc_contract_id: ctx.mpc_contract.id().clone(), + account: account.clone(), + account_sk: account_sk.to_string().parse()?, + web_port: Self::CONTAINER_PORT, + } + .into_str_args(); + let image: GenericImage = GenericImage::new("near/mpc-recovery-node", "latest") + .with_wait_for(WaitFor::Nothing) + .with_exposed_port(Self::CONTAINER_PORT) + .with_env_var("RUST_LOG", "mpc_recovery_node=DEBUG") + .with_env_var("RUST_BACKTRACE", "1"); + let image: RunnableImage = (image, args).into(); + let image = image.with_network(&ctx.docker_network); + let container = ctx.docker_client.cli.run(image); + let ip_address = ctx + .docker_client + .get_network_ip_address(&container, &ctx.docker_network) + .await?; + let host_port = container.get_host_port_ipv4(Self::CONTAINER_PORT); + + container.exec(ExecCommand { + cmd: format!("bash -c 'while [[ \"$(curl -s -o /dev/null -w ''%{{http_code}}'' localhost:{})\" != \"200\" ]]; do sleep 1; done'", Self::CONTAINER_PORT), + ready_conditions: vec![WaitFor::message_on_stdout("node is ready to accept connections")] + }); + + let full_address = format!("http://{ip_address}:{}", Self::CONTAINER_PORT); + tracing::info!(node_id, full_address, "node container is running"); + Ok(Node { + container, + address: full_address, + local_address: format!("http://localhost:{host_port}"), + }) + } +} diff --git a/integration-tests/src/multichain/local.rs b/integration-tests/src/multichain/local.rs new file mode 100644 index 000000000..999969c35 --- /dev/null +++ b/integration-tests/src/multichain/local.rs @@ -0,0 +1,50 @@ +use crate::util; +use async_process::Child; +use near_workspaces::AccountId; + +#[allow(dead_code)] +pub struct Node { + pub address: String, + node_id: usize, + account: AccountId, + account_sk: near_workspaces::types::SecretKey, + + // process held so it's not dropped. Once dropped, process will be killed. + #[allow(unused)] + process: Child, +} + +impl Node { + pub async fn run( + ctx: &super::Context<'_>, + node_id: u32, + account: &AccountId, + account_sk: &near_workspaces::types::SecretKey, + ) -> anyhow::Result { + let web_port = util::pick_unused_port().await?; + let args = mpc_recovery_node::cli::Cli::Start { + node_id: node_id.into(), + near_rpc: ctx.sandbox.local_address.clone(), + mpc_contract_id: ctx.mpc_contract.id().clone(), + account: account.clone(), + account_sk: account_sk.to_string().parse()?, + web_port, + } + .into_str_args(); + + let mpc_node_id = format!("multichain/{node_id}"); + let process = util::spawn_mpc_multichain(ctx.release, &mpc_node_id, &args)?; + let address = format!("http://127.0.0.1:{web_port}"); + tracing::info!("node is starting at {}", address); + util::ping_until_ok(&address, 60).await?; + tracing::info!("node started [node_id={node_id}, {address}]"); + + Ok(Self { + address, + node_id: node_id as usize, + account: account.clone(), + account_sk: account_sk.clone(), + process, + }) + } +} diff --git a/integration-tests/src/multichain/mod.rs b/integration-tests/src/multichain/mod.rs new file mode 100644 index 000000000..bf65ae189 --- /dev/null +++ b/integration-tests/src/multichain/mod.rs @@ -0,0 +1,186 @@ +pub mod containers; +pub mod local; + +use crate::env::containers::DockerClient; +use crate::{initialize_sandbox, SandboxCtx}; +use mpc_contract::ParticipantInfo; +use near_workspaces::network::Sandbox; +use near_workspaces::{AccountId, Contract, Worker}; +use serde_json::json; +use std::collections::HashMap; + +const NETWORK: &str = "mpc_it_network"; + +pub enum Nodes<'a> { + Local { + ctx: Context<'a>, + nodes: Vec, + }, + Docker { + ctx: Context<'a>, + nodes: Vec>, + }, +} + +impl Nodes<'_> { + pub fn ctx(&self) -> &Context { + match self { + Nodes::Local { ctx, .. } => ctx, + Nodes::Docker { ctx, .. } => ctx, + } + } + + pub async fn add_node( + &mut self, + node_id: u32, + account: &AccountId, + account_sk: &near_workspaces::types::SecretKey, + ) -> anyhow::Result<()> { + tracing::info!(%account, "adding one more node"); + match self { + Nodes::Local { ctx, nodes } => { + nodes.push(local::Node::run(&ctx, node_id, account, account_sk).await?) + } + Nodes::Docker { ctx, nodes } => { + nodes.push(containers::Node::run(&ctx, node_id, account, account_sk).await?) + } + } + + Ok(()) + } +} + +pub struct Context<'a> { + pub docker_client: &'a DockerClient, + pub docker_network: String, + pub release: bool, + + pub sandbox: crate::env::containers::Sandbox<'a>, + pub worker: Worker, + pub mpc_contract: Contract, +} + +pub async fn setup(docker_client: &DockerClient) -> anyhow::Result> { + let docker_network = NETWORK; + docker_client.create_network(docker_network).await?; + + let SandboxCtx { sandbox, worker } = initialize_sandbox(&docker_client, NETWORK).await?; + + let mpc_contract = worker + .dev_deploy(include_bytes!( + "../../../target/wasm32-unknown-unknown/release/mpc_contract.wasm" + )) + .await?; + tracing::info!(contract_id = %mpc_contract.id(), "deployed mpc contract"); + + Ok(Context { + docker_client, + docker_network: docker_network.to_string(), + release: true, + sandbox, + worker, + mpc_contract, + }) +} + +pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Result { + let ctx = setup(docker_client).await?; + + let accounts = futures::future::join_all((0..nodes).map(|_| ctx.worker.dev_create_account())) + .await + .into_iter() + .collect::, _>>()?; + let mut node_futures = Vec::new(); + for (i, account) in accounts.iter().enumerate() { + let node = containers::Node::run(&ctx, i as u32, account.id(), account.secret_key()); + node_futures.push(node); + } + let nodes = futures::future::join_all(node_futures) + .await + .into_iter() + .collect::, _>>()?; + let participants: HashMap = accounts + .iter() + .cloned() + .enumerate() + .zip(&nodes) + .map(|((i, account), node)| { + ( + account.id().clone(), + ParticipantInfo { + id: i as u32, + account_id: account.id().to_string().parse().unwrap(), + url: node.address.clone(), + }, + ) + }) + .collect(); + ctx.mpc_contract + .call("init") + .args_json(json!({ + "threshold": 2, + "participants": participants + })) + .transact() + .await? + .into_result()?; + + Ok(Nodes::Docker { ctx, nodes }) +} + +pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result { + let ctx = setup(docker_client).await?; + + let accounts = futures::future::join_all((0..nodes).map(|_| ctx.worker.dev_create_account())) + .await + .into_iter() + .collect::, _>>()?; + let mut node_futures = Vec::with_capacity(nodes); + for (i, account) in accounts.iter().enumerate().take(nodes) { + node_futures.push(local::Node::run( + &ctx, + i as u32, + account.id(), + account.secret_key(), + )); + } + let nodes = futures::future::join_all(node_futures) + .await + .into_iter() + .collect::, _>>()?; + let participants: HashMap = accounts + .iter() + .cloned() + .enumerate() + .zip(&nodes) + .map(|((i, account), node)| { + ( + account.id().clone(), + ParticipantInfo { + id: i as u32, + account_id: account.id().to_string().parse().unwrap(), + url: node.address.clone(), + }, + ) + }) + .collect(); + ctx.mpc_contract + .call("init") + .args_json(json!({ + "threshold": 2, + "participants": participants + })) + .transact() + .await? + .into_result()?; + + Ok(Nodes::Local { ctx, nodes }) +} + +pub async fn run(nodes: usize, docker_client: &DockerClient) -> anyhow::Result { + #[cfg(feature = "docker-test")] + return docker(nodes, docker_client).await; + + #[cfg(not(feature = "docker-test"))] + return host(nodes, docker_client).await; +} diff --git a/integration-tests/src/util.rs b/integration-tests/src/util.rs index a8eda42e8..3ddd84685 100644 --- a/integration-tests/src/util.rs +++ b/integration-tests/src/util.rs @@ -12,6 +12,7 @@ use std::{ use toml::Value; const EXECUTABLE: &str = "mpc-recovery"; +const EXECUTABLE_MULTICHAIN: &str = "mpc-recovery-node"; pub async fn post( uri: U, @@ -235,15 +236,15 @@ pub fn target_dir() -> Option { } } -pub fn executable(release: bool) -> Option { +pub fn executable(release: bool, executable: &str) -> Option { let executable = target_dir()? .join(if release { "release" } else { "debug" }) - .join(EXECUTABLE); + .join(executable); Some(executable) } pub fn spawn_mpc(release: bool, node: &str, args: &[String]) -> anyhow::Result { - let executable = executable(release) + let executable = executable(release, EXECUTABLE) .with_context(|| format!("could not find target dir while starting {node} node"))?; Command::new(&executable) @@ -256,3 +257,18 @@ pub fn spawn_mpc(release: bool, node: &str, args: &[String]) -> anyhow::Result anyhow::Result { + let executable = executable(release, EXECUTABLE_MULTICHAIN) + .with_context(|| format!("could not find target dir while starting {node} node"))?; + + Command::new(&executable) + .args(args) + .env("RUST_LOG", "mpc_recovery_node=INFO") + .envs(std::env::vars()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .kill_on_drop(true) + .spawn() + .with_context(|| format!("failed to run {node} node: {}", executable.display())) +} diff --git a/integration-tests/tests/lib.rs b/integration-tests/tests/lib.rs index dc8f84e8c..e74840518 100644 --- a/integration-tests/tests/lib.rs +++ b/integration-tests/tests/lib.rs @@ -2,6 +2,7 @@ mod mpc; mod multichain; use curv::elliptic::curves::{Ed25519, Point}; +use futures::future::BoxFuture; use hyper::StatusCode; use mpc_recovery::{ gcp::GcpService, @@ -58,95 +59,20 @@ where Ok(()) } -pub struct MultichainTestContext { - worker: Worker, +pub struct MultichainTestContext<'a> { + nodes: mpc_recovery_integration_tests::multichain::Nodes<'a>, rpc_client: near_fetch::Client, - mpc_contract: Contract, - near_rpc: String, -} - -#[derive(Serialize, Deserialize)] -pub struct Participant { - id: u32, - account_id: AccountId, - url: String, } -async fn with_multichain_nodes(nodes: usize, f: F) -> anyhow::Result<()> +async fn with_multichain_nodes(nodes: usize, f: F) -> anyhow::Result<()> where - F: FnOnce(MultichainTestContext) -> Fut, - Fut: core::future::Future>, + F: for<'a> FnOnce(MultichainTestContext<'a>) -> BoxFuture<'a, anyhow::Result<()>>, { - let docker_client = containers::DockerClient::default(); - docker_client.create_network(NETWORK).await?; - - let SandboxCtx { sandbox, worker } = - mpc_recovery_integration_tests::initialize_sandbox(&docker_client, NETWORK).await?; - - tracing::info!("deploying mpc contract"); - let mpc_contract = worker - .dev_deploy(include_bytes!( - "../../target/wasm32-unknown-unknown/release/mpc_contract.wasm" - )) - .await?; - tracing::info!("deployed mpc contract"); + let docker_client = DockerClient::default(); + let nodes = mpc_recovery_integration_tests::multichain::run(nodes, &docker_client).await?; - let accounts = futures::future::join_all((0..nodes).map(|_| worker.dev_create_account())) - .await - .into_iter() - .collect::, _>>()?; - let mut node_futures = Vec::new(); - for (i, account) in accounts.iter().enumerate() { - let node = containers::Node::run( - &docker_client, - NETWORK, - i as u64, - &sandbox.address, - mpc_contract.id(), - account.id(), - account.secret_key(), - ); - node_futures.push(node); - } - let nodes = futures::future::join_all(node_futures) - .await - .into_iter() - .collect::, _>>()?; - let participants: HashMap = accounts - .iter() - .cloned() - .enumerate() - .zip(&nodes) - .map(|((i, account), node)| { - ( - account.id().clone(), - Participant { - id: i as u32, - account_id: account.id().clone(), - url: node.address.clone(), - }, - ) - }) - .collect(); - mpc_contract - .call("init") - .args_json(json!({ - "threshold": 2, - "participants": participants - })) - .transact() - .await? - .into_result()?; - - let rpc_client = near_fetch::Client::new(&sandbox.local_address); - - f(MultichainTestContext { - worker, - rpc_client, - mpc_contract, - near_rpc: sandbox.address, - }) - .await?; + let rpc_client = near_fetch::Client::new(&nodes.ctx().sandbox.local_address); + f(MultichainTestContext { nodes, rpc_client }).await?; Ok(()) } @@ -259,14 +185,14 @@ mod wait_for { use mpc_contract::ProtocolContractState; use mpc_contract::RunningContractState; - pub async fn running_mpc( - ctx: &MultichainTestContext, + pub async fn running_mpc<'a>( + ctx: &MultichainTestContext<'a>, epoch: u64, ) -> anyhow::Result { let is_running = || async { let state: ProtocolContractState = ctx .rpc_client - .view(ctx.mpc_contract.id(), "state", ()) + .view(ctx.nodes.ctx().mpc_contract.id(), "state", ()) .await?; match state { diff --git a/integration-tests/tests/mpc/positive.rs b/integration-tests/tests/mpc/positive.rs index 247cb187c..c199ff18a 100644 --- a/integration-tests/tests/mpc/positive.rs +++ b/integration-tests/tests/mpc/positive.rs @@ -1,8 +1,6 @@ use crate::mpc::{add_pk_and_check_validity, fetch_recovery_pk, new_random_account}; use crate::{account, key, with_nodes, MpcCheck}; use hyper::StatusCode; -use near_workspaces::types::AccessKeyPermission; - use mpc_recovery::{ gcp::value::{FromValue, IntoValue}, sign_node::user_credentials::EncryptedUserCredentials, diff --git a/integration-tests/tests/multichain/mod.rs b/integration-tests/tests/multichain/mod.rs index fb8c5ccb4..fd78570a3 100644 --- a/integration-tests/tests/multichain/mod.rs +++ b/integration-tests/tests/multichain/mod.rs @@ -1,39 +1,30 @@ use crate::{wait_for, with_multichain_nodes}; -use mpc_recovery_integration_tests::containers; use test_log::test; #[test(tokio::test)] async fn test_multichain_reshare() -> anyhow::Result<()> { - with_multichain_nodes(3, |ctx| async move { - // Wait for network to complete key generation - let state_0 = wait_for::running_mpc(&ctx, 0).await?; - assert_eq!(state_0.participants.len(), 3); + with_multichain_nodes(3, |mut ctx| { + Box::pin(async move { + // Wait for network to complete key generation + let state_0 = wait_for::running_mpc(&ctx, 0).await?; + assert_eq!(state_0.participants.len(), 3); - let docker_client = containers::DockerClient::default(); - let account = ctx.worker.dev_create_account().await?; - let node = containers::Node::run( - &docker_client, - crate::NETWORK, - 3, - &ctx.near_rpc, - ctx.mpc_contract.id(), - account.id(), - account.secret_key(), - ) - .await?; + let account = ctx.nodes.ctx().worker.dev_create_account().await?; + ctx.nodes + .add_node(3, account.id(), account.secret_key()) + .await?; - // Wait for network to complete key reshare - let state_1 = wait_for::running_mpc(&ctx, 1).await?; - assert_eq!(state_1.participants.len(), 4); + // Wait for network to complete key reshare + let state_1 = wait_for::running_mpc(&ctx, 1).await?; + assert_eq!(state_1.participants.len(), 4); - assert_eq!( - state_0.public_key, state_1.public_key, - "public key must stay the same" - ); + assert_eq!( + state_0.public_key, state_1.public_key, + "public key must stay the same" + ); - drop(node); - - Ok(()) + Ok(()) + }) }) .await } diff --git a/node/src/cli.rs b/node/src/cli.rs new file mode 100644 index 000000000..7863417c0 --- /dev/null +++ b/node/src/cli.rs @@ -0,0 +1,150 @@ +use crate::protocol::MpcSignProtocol; +use crate::{indexer, web}; +use cait_sith::protocol::Participant; +use clap::Parser; +use local_ip_address::local_ip; +use near_crypto::{InMemorySigner, SecretKey}; +use near_primitives::types::AccountId; +use tokio::sync::mpsc; +use tracing_subscriber::EnvFilter; +use url::Url; + +#[derive(Parser, Debug)] +pub enum Cli { + Start { + /// Node ID + #[arg(long, value_parser = parse_participant, env("MPC_RECOVERY_NODE_ID"))] + node_id: Participant, + /// NEAR RPC address + #[arg( + long, + env("MPC_RECOVERY_NEAR_RPC"), + default_value("https://rpc.testnet.near.org") + )] + near_rpc: String, + /// MPC contract id + #[arg(long, env("MPC_RECOVERY_CONTRACT_ID"))] + mpc_contract_id: AccountId, + /// This node's account id + #[arg(long, env("MPC_RECOVERY_ACCOUNT"))] + account: AccountId, + /// This node's account ed25519 secret key + #[arg(long, env("MPC_RECOVERY_ACCOUNT_SK"))] + account_sk: SecretKey, + /// The web port for this server + #[arg(long, env("MPC_RECOVERY_WEB_PORT"))] + web_port: u16, + }, +} + +fn parse_participant(arg: &str) -> Result { + let participant_id: u32 = arg.parse()?; + Ok(participant_id.into()) +} + +impl Cli { + pub fn into_str_args(self) -> Vec { + match self { + Cli::Start { + node_id, + near_rpc, + mpc_contract_id, + account, + account_sk, + web_port, + } => { + vec![ + "start".to_string(), + "--node-id".to_string(), + u32::from(node_id).to_string(), + "--near-rpc".to_string(), + near_rpc, + "--mpc-contract-id".to_string(), + mpc_contract_id.to_string(), + "--account".to_string(), + account.to_string(), + "--account-sk".to_string(), + account_sk.to_string(), + "--web-port".to_string(), + web_port.to_string(), + ] + } + } + } +} + +pub fn run(cmd: Cli) -> anyhow::Result<()> { + // Install global collector configured based on RUST_LOG env var. + let mut subscriber = tracing_subscriber::fmt() + .with_thread_ids(true) + .with_env_filter(EnvFilter::from_default_env()); + // Check if running in Google Cloud Run: https://cloud.google.com/run/docs/container-contract#services-env-vars + if std::env::var("K_SERVICE").is_ok() { + // Disable colored logging as it messes up GCP's log formatting + subscriber = subscriber.with_ansi(false); + } + subscriber.init(); + let _span = tracing::trace_span!("cli").entered(); + + match cmd { + Cli::Start { + node_id, + near_rpc, + web_port, + mpc_contract_id, + account, + account_sk, + } => { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let (sender, receiver) = mpsc::channel(16384); + + let my_ip = local_ip()?; + let my_address = Url::parse(&format!("http://{my_ip}:{web_port}"))?; + tracing::info!(%my_address, "address detected"); + let rpc_client = near_fetch::Client::new(&near_rpc); + tracing::debug!(rpc_addr = rpc_client.rpc_addr(), "rpc client initialized"); + let signer = InMemorySigner::from_secret_key(account, account_sk); + let (protocol, protocol_state) = MpcSignProtocol::init( + node_id, + my_address, + mpc_contract_id.clone(), + rpc_client.clone(), + signer.clone(), + receiver, + ); + tracing::debug!("protocol initialized"); + let protocol_handle = tokio::spawn(async move { + protocol.run().await.unwrap(); + }); + tracing::debug!("protocol thread spawned"); + let mpc_contract_id_cloned = mpc_contract_id.clone(); + let web_handle = tokio::spawn(async move { + web::run( + web_port, + mpc_contract_id_cloned, + rpc_client, + signer, + sender, + protocol_state, + ) + .await + .unwrap(); + }); + tracing::debug!("protocol http server spawned"); + + protocol_handle.await?; + web_handle.await?; + tracing::debug!("spinning down"); + + anyhow::Ok(()) + })?; + indexer::run(&near_rpc, mpc_contract_id)?; + } + } + + Ok(()) +} diff --git a/node/src/lib.rs b/node/src/lib.rs index fffa16497..d59a02786 100644 --- a/node/src/lib.rs +++ b/node/src/lib.rs @@ -1,3 +1,4 @@ +pub mod cli; pub mod http_client; pub mod indexer; pub mod protocol; diff --git a/node/src/main.rs b/node/src/main.rs index 08aa533b0..8eaf8249f 100644 --- a/node/src/main.rs +++ b/node/src/main.rs @@ -1,118 +1,6 @@ -use cait_sith::protocol::Participant; use clap::Parser; -use local_ip_address::local_ip; -use mpc_recovery_node::protocol::MpcSignProtocol; -use near_crypto::{InMemorySigner, SecretKey}; -use near_primitives::types::AccountId; -use tokio::sync::mpsc; -use tracing_subscriber::EnvFilter; -use url::Url; - -#[derive(Parser, Debug)] -enum Cli { - Start { - /// Node ID - #[arg(long, value_parser = parse_participant, env("MPC_RECOVERY_NODE_ID"))] - node_id: Participant, - /// NEAR RPC address - #[arg( - long, - env("MPC_RECOVERY_NEAR_RPC"), - default_value("https://rpc.testnet.near.org") - )] - near_rpc: String, - /// MPC contract id - #[arg(long, env("MPC_RECOVERY_CONTRACT_ID"))] - mpc_contract_id: AccountId, - /// This node's account id - #[arg(long, env("MPC_RECOVERY_ACCOUNT"))] - account: AccountId, - /// This node's account ed25519 secret key - #[arg(long, env("MPC_RECOVERY_ACCOUNT_SK"))] - account_sk: SecretKey, - /// The web port for this server - #[arg(long, env("MPC_RECOVERY_WEB_PORT"))] - web_port: u16, - }, -} - -fn parse_participant(arg: &str) -> Result { - let participant_id: u32 = arg.parse()?; - Ok(participant_id.into()) -} +use mpc_recovery_node::cli::Cli; fn main() -> anyhow::Result<()> { - // Install global collector configured based on RUST_LOG env var. - let mut subscriber = tracing_subscriber::fmt() - .with_thread_ids(true) - .with_env_filter(EnvFilter::from_default_env()); - // Check if running in Google Cloud Run: https://cloud.google.com/run/docs/container-contract#services-env-vars - if std::env::var("K_SERVICE").is_ok() { - // Disable colored logging as it messes up GCP's log formatting - subscriber = subscriber.with_ansi(false); - } - subscriber.init(); - let _span = tracing::trace_span!("cli").entered(); - - match Cli::parse() { - Cli::Start { - node_id, - near_rpc, - web_port, - mpc_contract_id, - account, - account_sk, - } => { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on(async { - let (sender, receiver) = mpsc::channel(16384); - - let my_ip = local_ip()?; - let my_address = Url::parse(&format!("http://{my_ip}:{web_port}"))?; - tracing::info!(%my_address, "address detected"); - let rpc_client = near_fetch::Client::new(&near_rpc); - tracing::debug!(rpc_addr = rpc_client.rpc_addr(), "rpc client initialized"); - let signer = InMemorySigner::from_secret_key(account, account_sk); - let (protocol, protocol_state) = MpcSignProtocol::init( - node_id, - my_address, - mpc_contract_id.clone(), - rpc_client.clone(), - signer.clone(), - receiver, - ); - tracing::debug!("protocol initialized"); - let protocol_handle = tokio::spawn(async move { - protocol.run().await.unwrap(); - }); - tracing::debug!("protocol thread spawned"); - let mpc_contract_id_cloned = mpc_contract_id.clone(); - let web_handle = tokio::spawn(async move { - mpc_recovery_node::web::run( - web_port, - mpc_contract_id_cloned, - rpc_client, - signer, - sender, - protocol_state, - ) - .await - .unwrap(); - }); - tracing::debug!("protocol http server spawned"); - - protocol_handle.await?; - web_handle.await?; - tracing::debug!("spinning down"); - - anyhow::Ok(()) - })?; - mpc_recovery_node::indexer::run(&near_rpc, mpc_contract_id)?; - } - } - - Ok(()) + mpc_recovery_node::cli::run(Cli::parse()) }