diff --git a/circuit-types/src/lib.rs b/circuit-types/src/lib.rs index d4aa624a2..e14bdac63 100644 --- a/circuit-types/src/lib.rs +++ b/circuit-types/src/lib.rs @@ -91,6 +91,12 @@ impl From for AuthenticatedBool { } } +impl From for AuthenticatedScalar { + fn from(value: AuthenticatedBool) -> Self { + value.0 + } +} + // ----------- // | Helpers | // ----------- diff --git a/circuits/src/mpc_circuits/match.rs b/circuits/src/mpc_circuits/match.rs index f3cb23f2e..4b3cad43c 100644 --- a/circuits/src/mpc_circuits/match.rs +++ b/circuits/src/mpc_circuits/match.rs @@ -11,8 +11,6 @@ use crate::mpc_gadgets::{comparators::min, fixed_point::FixedPointMpcGadget}; /// Executes a match computation that returns matches from a given order /// intersection /// -/// If no match is found, the values are opened to a zero'd list -/// /// We do not check whether the orders are valid and overlapping, this is left /// to the `VALID MATCH MPC` circuit, which both parties verify before opening /// the match result. So, if the match result is invalid, the orders don't diff --git a/circuits/src/mpc_circuits/mod.rs b/circuits/src/mpc_circuits/mod.rs index 862dbbb34..6672fb174 100644 --- a/circuits/src/mpc_circuits/mod.rs +++ b/circuits/src/mpc_circuits/mod.rs @@ -1,3 +1,4 @@ //! Defines circuitry for specific multiparty computations performed by a //! relayer pub mod r#match; +pub mod settle; diff --git a/circuits/src/mpc_circuits/settle.rs b/circuits/src/mpc_circuits/settle.rs new file mode 100644 index 000000000..cfbbe9349 --- /dev/null +++ b/circuits/src/mpc_circuits/settle.rs @@ -0,0 +1,214 @@ +//! Settles a match into secret shared wallets + +use circuit_types::{r#match::AuthenticatedMatchResult, wallet::AuthenticatedWalletShare}; + +use crate::{ + mpc_gadgets::comparators::cond_select_vec, + zk_circuits::valid_commitments::OrderSettlementIndices, +}; + +/// Settles a match into two wallets and returns the updated wallet shares +/// +/// We settle directly into the public shares both for efficiency and to avoid +/// the need to share private shares +pub fn settle_match( + party0_settle_indices: OrderSettlementIndices, + party1_settle_indices: OrderSettlementIndices, + party0_public_share: &AuthenticatedWalletShare, + party1_public_share: &AuthenticatedWalletShare, + match_res: &AuthenticatedMatchResult, +) -> ( + AuthenticatedWalletShare, + AuthenticatedWalletShare, +) +where + [(); MAX_BALANCES + MAX_ORDERS + MAX_FEES]: Sized, +{ + let mut party0_new_shares = party0_public_share.clone(); + let mut party1_new_shares = party1_public_share.clone(); + + // Subtract the amount of base token exchanged from each party's order + let party0_order = &mut party0_new_shares.orders[party0_settle_indices.order as usize]; + let party1_order = &mut party1_new_shares.orders[party1_settle_indices.order as usize]; + party0_order.amount = &party0_order.amount - &match_res.base_amount; + party1_order.amount = &party1_order.amount - &match_res.base_amount; + + // Select the correct ordering of the two mints for buy and sell side + let mut amounts = cond_select_vec( + &match_res.direction, + &[ + match_res.quote_amount.clone(), + match_res.base_amount.clone(), + ], + &[ + match_res.base_amount.clone(), + match_res.quote_amount.clone(), + ], + ); + + let party0_buy = amounts.remove(0); + let party0_sell = amounts.remove(0); + + // Update the balances of the two parties + let party0_buy_balance = + &mut party0_new_shares.balances[party0_settle_indices.balance_receive as usize]; + party0_buy_balance.amount = &party0_buy_balance.amount + &party0_buy; + + let party1_buy_balance = + &mut party1_new_shares.balances[party1_settle_indices.balance_receive as usize]; + party1_buy_balance.amount = &party1_buy_balance.amount + &party0_sell; + + let party0_sell_balance = + &mut party0_new_shares.balances[party0_settle_indices.balance_send as usize]; + party0_sell_balance.amount = &party0_sell_balance.amount - &party0_sell; + + let party1_sell_balance = + &mut party1_new_shares.balances[party1_settle_indices.balance_send as usize]; + party1_sell_balance.amount = &party1_sell_balance.amount - &party0_buy; + + (party0_new_shares, party1_new_shares) +} + +#[cfg(test)] +mod test { + use std::iter; + + use ark_mpc::{PARTY0, PARTY1}; + use circuit_types::{ + order::OrderSide, + r#match::MatchResult, + traits::{BaseType, MpcBaseType, MpcType}, + SizedWalletShare, + }; + use constants::{Scalar, MAX_BALANCES, MAX_ORDERS}; + use rand::{thread_rng, Rng}; + use renegade_crypto::fields::scalar_to_biguint; + use test_helpers::mpc_network::execute_mock_mpc; + + use crate::{ + mpc_circuits::settle::settle_match, + zk_circuits::{ + valid_commitments::OrderSettlementIndices, + valid_match_settle::test_helpers::apply_match_to_shares, + }, + }; + + /// The parameterization of a test + #[derive(Clone)] + struct SettlementTest { + /// The match result to settle + match_res: MatchResult, + /// The shares of the first party before settlement + party0_pre_shares: SizedWalletShare, + /// The indices of the first party's order and balances to settle + party0_indices: OrderSettlementIndices, + /// The shares of the second party before settlement + party1_pre_shares: SizedWalletShare, + /// The indices of the second party's order and balances to settle + party1_indices: OrderSettlementIndices, + /// The shares of the first party after settlement + party0_post_shares: SizedWalletShare, + /// The shares of the second party after settlement + party1_post_shares: SizedWalletShare, + } + + /// Get a dummy set of inputs for a settlement circuit + fn generate_test_params() -> SettlementTest { + let mut rng = thread_rng(); + let quote_mint = scalar_to_biguint(&Scalar::random(&mut rng)); + let base_mint = scalar_to_biguint(&Scalar::random(&mut rng)); + + let match_res = MatchResult { + quote_mint: quote_mint.clone(), + base_mint: base_mint.clone(), + quote_amount: rng.gen(), + base_amount: rng.gen(), + direction: rng.gen_bool(0.5), + max_minus_min_amount: 0, // Unused + min_amount_order_index: false, // Unused + }; + + let party0_pre_shares = random_shares(); + let party0_indices = random_indices(); + let party0_side = OrderSide::from(match_res.direction as u64); + let party0_post_shares = + apply_match_to_shares(&party0_pre_shares, &party0_indices, &match_res, party0_side); + + let party1_pre_shares = random_shares(); + let party1_indices = random_indices(); + let party1_side = party0_side.opposite(); + let party1_post_shares = + apply_match_to_shares(&party1_pre_shares, &party1_indices, &match_res, party1_side); + + SettlementTest { + match_res, + party0_pre_shares, + party0_indices, + party0_post_shares, + party1_pre_shares, + party1_indices, + party1_post_shares, + } + } + + /// Generate a random set of wallet shares + fn random_shares() -> SizedWalletShare { + let mut rng = thread_rng(); + SizedWalletShare::from_scalars(&mut iter::from_fn(|| Some(Scalar::random(&mut rng)))) + } + + /// Generate a random set of settlement indices + fn random_indices() -> OrderSettlementIndices { + let balance_send = random_index(MAX_BALANCES); + let mut balance_receive = random_index(MAX_BALANCES); + + while balance_send == balance_receive { + balance_receive = random_index(MAX_BALANCES); + } + + OrderSettlementIndices { + order: random_index(MAX_ORDERS), + balance_send, + balance_receive, + } + } + + // Generate a random index bounded by a max + fn random_index(max: usize) -> u64 { + let mut rng = thread_rng(); + rng.gen_range(0..max) as u64 + } + + /// Tests settlement of a match into two wallets + #[tokio::test] + async fn test_settle() { + // Generate a randomized test + let params = generate_test_params(); + + let (res, _) = execute_mock_mpc(move |fabric| { + let params = params.clone(); + + async move { + let party0_shares = params.party0_pre_shares.allocate(PARTY1, &fabric); + let party1_shares = params.party1_pre_shares.allocate(PARTY0, &fabric); + let match_res = params.match_res.allocate(PARTY0, &fabric); + + let (party0_post_shares, party1_post_shares) = settle_match( + params.party0_indices, + params.party1_indices, + &party0_shares, + &party1_shares, + &match_res, + ); + + let party0_res = party0_post_shares.open_and_authenticate().await.unwrap(); + let party1_res = party1_post_shares.open_and_authenticate().await.unwrap(); + + party0_res == params.party0_post_shares && party1_res == params.party1_post_shares + } + }) + .await; + + assert!(res); + } +} diff --git a/circuits/src/mpc_gadgets/comparators.rs b/circuits/src/mpc_gadgets/comparators.rs index ac682bc51..2dcea7821 100644 --- a/circuits/src/mpc_gadgets/comparators.rs +++ b/circuits/src/mpc_gadgets/comparators.rs @@ -2,8 +2,9 @@ use std::iter; -use circuit_types::Fabric; +use circuit_types::{AuthenticatedBool, Fabric}; use constants::{AuthenticatedScalar, Scalar}; +use itertools::Itertools; use crate::SCALAR_BITS_MINUS_TWO; @@ -201,19 +202,22 @@ pub fn min( /// Computes res = a if s else b pub fn cond_select( - s: &AuthenticatedScalar, + s: &AuthenticatedBool, a: &AuthenticatedScalar, b: &AuthenticatedScalar, ) -> AuthenticatedScalar { - let selectors = - AuthenticatedScalar::batch_mul(&[a.clone(), b.clone()], &[s.clone(), Scalar::one() - s]); + let selector: AuthenticatedScalar = s.clone().into(); + let terms = AuthenticatedScalar::batch_mul( + &[a.clone(), b.clone()], + &[selector.clone(), Scalar::one() - &selector], + ); - &selectors[0] + &selectors[1] + &terms[0] + &terms[1] } /// Computes res = [a] if s else [b] where a and b are slices pub fn cond_select_vec( - s: &AuthenticatedScalar, + s: &AuthenticatedBool, a: &[AuthenticatedScalar], b: &[AuthenticatedScalar], ) -> Vec { @@ -222,24 +226,27 @@ pub fn cond_select_vec( b.len(), "cond_select_vec requires equal length vectors" ); + // Batch mul each a value with `s` and each `b` value with 1 - s - let selectors = AuthenticatedScalar::batch_mul( + let n = a.len(); + let selector: AuthenticatedScalar = s.clone().into(); + let terms = AuthenticatedScalar::batch_mul( &a.iter() .cloned() .chain(b.iter().cloned()) .collect::>(), - &iter::repeat(s.clone()) - .take(a.len()) - .chain(iter::repeat(Scalar::one() - s).take(b.len())) - .collect::>(), + &iter::repeat(selector.clone()) + .take(n) + .chain(iter::repeat(Scalar::one() - &selector).take(n)) + .collect_vec(), ); // Destruct the vector by zipping its first half with its second half let mut result = Vec::with_capacity(a.len()); - for (a_selected, b_selected) in selectors[..a.len()] + for (a_selected, b_selected) in terms[..a.len()] .as_ref() .iter() - .zip(selectors[a.len()..].iter()) + .zip(terms[a.len()..].iter()) { result.push(a_selected + b_selected) } @@ -252,10 +259,11 @@ mod test { use std::ops::Neg; use ark_mpc::PARTY0; + use circuit_types::traits::MpcBaseType; use constants::Scalar; use itertools::Itertools; use num_bigint::RandBigInt; - use rand::{thread_rng, RngCore}; + use rand::{thread_rng, Rng, RngCore}; use renegade_crypto::fields::biguint_to_scalar; use test_helpers::mpc_network::execute_mock_mpc; @@ -391,14 +399,14 @@ mod test { let mut rng = thread_rng(); let a = Scalar::random(&mut rng); let b = Scalar::random(&mut rng); - let s = Scalar::from(rng.next_u64() % 2); + let s = rng.gen_bool(0.5); - let expected = if s == Scalar::one() { a } else { b }; + let expected = if s { a } else { b }; let (res, _) = execute_mock_mpc(move |fabric| async move { - let a = fabric.share_scalar(a, PARTY0); - let b = fabric.share_scalar(b, PARTY0); - let s = fabric.share_scalar(s, PARTY0); + let a = a.allocate(PARTY0, &fabric); + let b = b.allocate(PARTY0, &fabric); + let s = s.allocate(PARTY0, &fabric); let res = cond_select(&s, &a, &b); @@ -417,13 +425,9 @@ mod test { let a = (0..N).map(|_| Scalar::random(&mut rng)).collect_vec(); let b = (0..N).map(|_| Scalar::random(&mut rng)).collect_vec(); - let s = Scalar::from(rng.next_u64() % 2); + let s = rng.gen_bool(0.5); - let expected = if s == Scalar::one() { - a.clone() - } else { - b.clone() - }; + let expected = if s { a.clone() } else { b.clone() }; let (res, _) = execute_mock_mpc(move |fabric| { let a = a.clone(); @@ -431,7 +435,7 @@ mod test { async move { let a = fabric.batch_share_scalar(a, PARTY0); let b = fabric.batch_share_scalar(b, PARTY0); - let s = fabric.share_scalar(s, PARTY0); + let s = s.allocate(PARTY0, &fabric); let res = cond_select_vec(&s, &a, &b); diff --git a/circuits/src/zk_circuits/valid_match_settle/mod.rs b/circuits/src/zk_circuits/valid_match_settle/mod.rs index 331bc48e4..5c6846609 100644 --- a/circuits/src/zk_circuits/valid_match_settle/mod.rs +++ b/circuits/src/zk_circuits/valid_match_settle/mod.rs @@ -199,6 +199,7 @@ pub mod test_helpers { balance::Balance, order::{Order, OrderSide}, r#match::MatchResult, + wallet::WalletShare, }; use constants::Scalar; use rand::{distributions::uniform::SampleRange, thread_rng, RngCore}; @@ -207,8 +208,8 @@ pub mod test_helpers { test_helpers::random_orders_and_match, zk_circuits::{ test_helpers::{ - create_wallet_shares, SizedWallet, SizedWalletShare, INITIAL_WALLET, MAX_BALANCES, - MAX_FEES, MAX_ORDERS, + create_wallet_shares, SizedWallet, INITIAL_WALLET, MAX_BALANCES, MAX_FEES, + MAX_ORDERS, }, valid_commitments::OrderSettlementIndices, }, @@ -332,12 +333,19 @@ pub mod test_helpers { /// Applies a match to the shares of a wallet /// /// Returns a new wallet share with the match applied - fn apply_match_to_shares( - shares: &SizedWalletShare, + pub(crate) fn apply_match_to_shares< + const MAX_BALANCES: usize, + const MAX_ORDERS: usize, + const MAX_FEES: usize, + >( + shares: &WalletShare, indices: &OrderSettlementIndices, match_res: &MatchResult, side: OrderSide, - ) -> SizedWalletShare { + ) -> WalletShare + where + [(); MAX_BALANCES + MAX_ORDERS + MAX_FEES]: Sized, + { let (send_amt, recv_amt) = match side { // Buy side; send quote, receive base OrderSide::Buy => (match_res.quote_amount, match_res.base_amount),