Skip to content

Commit

Permalink
circuits: mpc-ciruits: settle: Add settlement ciruit
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Nov 24, 2023
1 parent 2cdbee6 commit 918ebcc
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 33 deletions.
6 changes: 6 additions & 0 deletions circuit-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ impl From<AuthenticatedScalar> for AuthenticatedBool {
}
}

impl From<AuthenticatedBool> for AuthenticatedScalar {
fn from(value: AuthenticatedBool) -> Self {
value.0
}
}

// -----------
// | Helpers |
// -----------
Expand Down
2 changes: 0 additions & 2 deletions circuits/src/mpc_circuits/match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ use crate::mpc_gadgets::{comparators::min, fixed_point::FixedPointMpcGadget};
/// Executes a match computation that returns matches from a given order
/// intersection
///
/// If no match is found, the values are opened to a zero'd list
///
/// We do not check whether the orders are valid and overlapping, this is left
/// to the `VALID MATCH MPC` circuit, which both parties verify before opening
/// the match result. So, if the match result is invalid, the orders don't
Expand Down
1 change: 1 addition & 0 deletions circuits/src/mpc_circuits/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Defines circuitry for specific multiparty computations performed by a
//! relayer
pub mod r#match;
pub mod settle;
214 changes: 214 additions & 0 deletions circuits/src/mpc_circuits/settle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
//! Settles a match into secret shared wallets
use circuit_types::{r#match::AuthenticatedMatchResult, wallet::AuthenticatedWalletShare};

use crate::{
mpc_gadgets::comparators::cond_select_vec,
zk_circuits::valid_commitments::OrderSettlementIndices,
};

/// Settles a match into two wallets and returns the updated wallet shares
///
/// We settle directly into the public shares both for efficiency and to avoid
/// the need to share private shares
pub fn settle_match<const MAX_BALANCES: usize, const MAX_ORDERS: usize, const MAX_FEES: usize>(
party0_settle_indices: OrderSettlementIndices,
party1_settle_indices: OrderSettlementIndices,
party0_public_share: &AuthenticatedWalletShare<MAX_BALANCES, MAX_ORDERS, MAX_FEES>,
party1_public_share: &AuthenticatedWalletShare<MAX_BALANCES, MAX_ORDERS, MAX_FEES>,
match_res: &AuthenticatedMatchResult,
) -> (
AuthenticatedWalletShare<MAX_BALANCES, MAX_ORDERS, MAX_FEES>,
AuthenticatedWalletShare<MAX_BALANCES, MAX_ORDERS, MAX_FEES>,
)
where
[(); MAX_BALANCES + MAX_ORDERS + MAX_FEES]: Sized,
{
let mut party0_new_shares = party0_public_share.clone();
let mut party1_new_shares = party1_public_share.clone();

// Subtract the amount of base token exchanged from each party's order
let party0_order = &mut party0_new_shares.orders[party0_settle_indices.order as usize];
let party1_order = &mut party1_new_shares.orders[party1_settle_indices.order as usize];
party0_order.amount = &party0_order.amount - &match_res.base_amount;
party1_order.amount = &party1_order.amount - &match_res.base_amount;

// Select the correct ordering of the two mints for buy and sell side
let mut amounts = cond_select_vec(
&match_res.direction,
&[
match_res.quote_amount.clone(),
match_res.base_amount.clone(),
],
&[
match_res.base_amount.clone(),
match_res.quote_amount.clone(),
],
);

let party0_buy = amounts.remove(0);
let party0_sell = amounts.remove(0);

// Update the balances of the two parties
let party0_buy_balance =
&mut party0_new_shares.balances[party0_settle_indices.balance_receive as usize];
party0_buy_balance.amount = &party0_buy_balance.amount + &party0_buy;

let party1_buy_balance =
&mut party1_new_shares.balances[party1_settle_indices.balance_receive as usize];
party1_buy_balance.amount = &party1_buy_balance.amount + &party0_sell;

let party0_sell_balance =
&mut party0_new_shares.balances[party0_settle_indices.balance_send as usize];
party0_sell_balance.amount = &party0_sell_balance.amount - &party0_sell;

let party1_sell_balance =
&mut party1_new_shares.balances[party1_settle_indices.balance_send as usize];
party1_sell_balance.amount = &party1_sell_balance.amount - &party0_buy;

(party0_new_shares, party1_new_shares)
}

#[cfg(test)]
mod test {
use std::iter;

use ark_mpc::{PARTY0, PARTY1};
use circuit_types::{
order::OrderSide,
r#match::MatchResult,
traits::{BaseType, MpcBaseType, MpcType},
SizedWalletShare,
};
use constants::{Scalar, MAX_BALANCES, MAX_ORDERS};
use rand::{thread_rng, Rng};
use renegade_crypto::fields::scalar_to_biguint;
use test_helpers::mpc_network::execute_mock_mpc;

use crate::{
mpc_circuits::settle::settle_match,
zk_circuits::{
valid_commitments::OrderSettlementIndices,
valid_match_settle::test_helpers::apply_match_to_shares,
},
};

/// The parameterization of a test
#[derive(Clone)]
struct SettlementTest {
/// The match result to settle
match_res: MatchResult,
/// The shares of the first party before settlement
party0_pre_shares: SizedWalletShare,
/// The indices of the first party's order and balances to settle
party0_indices: OrderSettlementIndices,
/// The shares of the second party before settlement
party1_pre_shares: SizedWalletShare,
/// The indices of the second party's order and balances to settle
party1_indices: OrderSettlementIndices,
/// The shares of the first party after settlement
party0_post_shares: SizedWalletShare,
/// The shares of the second party after settlement
party1_post_shares: SizedWalletShare,
}

/// Get a dummy set of inputs for a settlement circuit
fn generate_test_params() -> SettlementTest {
let mut rng = thread_rng();
let quote_mint = scalar_to_biguint(&Scalar::random(&mut rng));
let base_mint = scalar_to_biguint(&Scalar::random(&mut rng));

let match_res = MatchResult {
quote_mint: quote_mint.clone(),
base_mint: base_mint.clone(),
quote_amount: rng.gen(),
base_amount: rng.gen(),
direction: rng.gen_bool(0.5),
max_minus_min_amount: 0, // Unused
min_amount_order_index: false, // Unused
};

let party0_pre_shares = random_shares();
let party0_indices = random_indices();
let party0_side = OrderSide::from(match_res.direction as u64);
let party0_post_shares =
apply_match_to_shares(&party0_pre_shares, &party0_indices, &match_res, party0_side);

let party1_pre_shares = random_shares();
let party1_indices = random_indices();
let party1_side = party0_side.opposite();
let party1_post_shares =
apply_match_to_shares(&party1_pre_shares, &party1_indices, &match_res, party1_side);

SettlementTest {
match_res,
party0_pre_shares,
party0_indices,
party0_post_shares,
party1_pre_shares,
party1_indices,
party1_post_shares,
}
}

/// Generate a random set of wallet shares
fn random_shares() -> SizedWalletShare {
let mut rng = thread_rng();
SizedWalletShare::from_scalars(&mut iter::from_fn(|| Some(Scalar::random(&mut rng))))
}

/// Generate a random set of settlement indices
fn random_indices() -> OrderSettlementIndices {
let balance_send = random_index(MAX_BALANCES);
let mut balance_receive = random_index(MAX_BALANCES);

while balance_send == balance_receive {
balance_receive = random_index(MAX_BALANCES);
}

OrderSettlementIndices {
order: random_index(MAX_ORDERS),
balance_send,
balance_receive,
}
}

// Generate a random index bounded by a max
fn random_index(max: usize) -> u64 {
let mut rng = thread_rng();
rng.gen_range(0..max) as u64
}

/// Tests settlement of a match into two wallets
#[tokio::test]
async fn test_settle() {
// Generate a randomized test
let params = generate_test_params();

let (res, _) = execute_mock_mpc(move |fabric| {
let params = params.clone();

async move {
let party0_shares = params.party0_pre_shares.allocate(PARTY1, &fabric);
let party1_shares = params.party1_pre_shares.allocate(PARTY0, &fabric);
let match_res = params.match_res.allocate(PARTY0, &fabric);

let (party0_post_shares, party1_post_shares) = settle_match(
params.party0_indices,
params.party1_indices,
&party0_shares,
&party1_shares,
&match_res,
);

let party0_res = party0_post_shares.open_and_authenticate().await.unwrap();
let party1_res = party1_post_shares.open_and_authenticate().await.unwrap();

party0_res == params.party0_post_shares && party1_res == params.party1_post_shares
}
})
.await;

assert!(res);
}
}
56 changes: 30 additions & 26 deletions circuits/src/mpc_gadgets/comparators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
use std::iter;

use circuit_types::Fabric;
use circuit_types::{AuthenticatedBool, Fabric};
use constants::{AuthenticatedScalar, Scalar};
use itertools::Itertools;

use crate::SCALAR_BITS_MINUS_TWO;

Expand Down Expand Up @@ -201,19 +202,22 @@ pub fn min<const D: usize>(

/// Computes res = a if s else b
pub fn cond_select(
s: &AuthenticatedScalar,
s: &AuthenticatedBool,
a: &AuthenticatedScalar,
b: &AuthenticatedScalar,
) -> AuthenticatedScalar {
let selectors =
AuthenticatedScalar::batch_mul(&[a.clone(), b.clone()], &[s.clone(), Scalar::one() - s]);
let selector: AuthenticatedScalar = s.clone().into();
let terms = AuthenticatedScalar::batch_mul(
&[a.clone(), b.clone()],
&[selector.clone(), Scalar::one() - &selector],
);

&selectors[0] + &selectors[1]
&terms[0] + &terms[1]
}

/// Computes res = [a] if s else [b] where a and b are slices
pub fn cond_select_vec(
s: &AuthenticatedScalar,
s: &AuthenticatedBool,
a: &[AuthenticatedScalar],
b: &[AuthenticatedScalar],
) -> Vec<AuthenticatedScalar> {
Expand All @@ -222,24 +226,27 @@ pub fn cond_select_vec(
b.len(),
"cond_select_vec requires equal length vectors"
);

// Batch mul each a value with `s` and each `b` value with 1 - s
let selectors = AuthenticatedScalar::batch_mul(
let n = a.len();
let selector: AuthenticatedScalar = s.clone().into();
let terms = AuthenticatedScalar::batch_mul(
&a.iter()
.cloned()
.chain(b.iter().cloned())
.collect::<Vec<_>>(),
&iter::repeat(s.clone())
.take(a.len())
.chain(iter::repeat(Scalar::one() - s).take(b.len()))
.collect::<Vec<_>>(),
&iter::repeat(selector.clone())
.take(n)
.chain(iter::repeat(Scalar::one() - &selector).take(n))
.collect_vec(),
);

// Destruct the vector by zipping its first half with its second half
let mut result = Vec::with_capacity(a.len());
for (a_selected, b_selected) in selectors[..a.len()]
for (a_selected, b_selected) in terms[..a.len()]
.as_ref()
.iter()
.zip(selectors[a.len()..].iter())
.zip(terms[a.len()..].iter())
{
result.push(a_selected + b_selected)
}
Expand All @@ -252,10 +259,11 @@ mod test {
use std::ops::Neg;

use ark_mpc::PARTY0;
use circuit_types::traits::MpcBaseType;
use constants::Scalar;
use itertools::Itertools;
use num_bigint::RandBigInt;
use rand::{thread_rng, RngCore};
use rand::{thread_rng, Rng, RngCore};
use renegade_crypto::fields::biguint_to_scalar;
use test_helpers::mpc_network::execute_mock_mpc;

Expand Down Expand Up @@ -391,14 +399,14 @@ mod test {
let mut rng = thread_rng();
let a = Scalar::random(&mut rng);
let b = Scalar::random(&mut rng);
let s = Scalar::from(rng.next_u64() % 2);
let s = rng.gen_bool(0.5);

let expected = if s == Scalar::one() { a } else { b };
let expected = if s { a } else { b };

let (res, _) = execute_mock_mpc(move |fabric| async move {
let a = fabric.share_scalar(a, PARTY0);
let b = fabric.share_scalar(b, PARTY0);
let s = fabric.share_scalar(s, PARTY0);
let a = a.allocate(PARTY0, &fabric);
let b = b.allocate(PARTY0, &fabric);
let s = s.allocate(PARTY0, &fabric);

let res = cond_select(&s, &a, &b);

Expand All @@ -417,21 +425,17 @@ mod test {

let a = (0..N).map(|_| Scalar::random(&mut rng)).collect_vec();
let b = (0..N).map(|_| Scalar::random(&mut rng)).collect_vec();
let s = Scalar::from(rng.next_u64() % 2);
let s = rng.gen_bool(0.5);

let expected = if s == Scalar::one() {
a.clone()
} else {
b.clone()
};
let expected = if s { a.clone() } else { b.clone() };

let (res, _) = execute_mock_mpc(move |fabric| {
let a = a.clone();
let b = b.clone();
async move {
let a = fabric.batch_share_scalar(a, PARTY0);
let b = fabric.batch_share_scalar(b, PARTY0);
let s = fabric.share_scalar(s, PARTY0);
let s = s.allocate(PARTY0, &fabric);

let res = cond_select_vec(&s, &a, &b);

Expand Down
Loading

0 comments on commit 918ebcc

Please sign in to comment.