Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

circuits: integration: Simplify integration tests #267

Merged
merged 2 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 73 additions & 103 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion circuit-types/src/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl BaseType for OrderSide {
fn from_scalars<I: Iterator<Item = Scalar>>(i: &mut I) -> Self {
match scalar_to_u64(&i.next().unwrap()) {
val @ 0..=1 => OrderSide::from(val),
_ => panic!("invalid value for OrderSide"),
x => panic!("invalid value for OrderSide({x})"),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion circuit-types/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub trait BaseType: Clone {
/// This method is added to the `BaseType` trait for maximum flexibility, so
/// that types may be shared without requiring them to implement the
/// full `MpcBaseType` trait
async fn share_public(&self, owning_party: PartyId, fabric: Fabric) -> Self {
async fn share_public(&self, owning_party: PartyId, fabric: &Fabric) -> Self {
let self_scalars = self.to_scalars();
let res_scalars = fabric
.batch_share_plaintext(self_scalars, owning_party)
Expand Down
1 change: 1 addition & 0 deletions circuits/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ stats = ["ark-mpc/stats"]
name = "integration"
path = "integration/main.rs"
harness = false
required-features = ["test_helpers"]

[[bench]]
name = "valid_wallet_create"
Expand Down
4 changes: 2 additions & 2 deletions circuits/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ RUN sed -i 's/main.rs/dummy-main.rs/g' Cargo.toml
ENV RUSTFLAGS=-Awarnings
ENV RUST_BACKTRACE=1

RUN cargo build --quiet --test integration
RUN cargo build --quiet --test integration --all-features

# Edit the Cargo.toml back to the original, build the full executable
RUN sed -i 's/dummy-lib.rs/lib.rs/g' Cargo.toml
Expand All @@ -58,6 +58,6 @@ RUN sed -i 's/dummy-main.rs/main.rs/g' Cargo.toml
COPY circuits/src ./src
COPY circuits/integration ./integration

RUN cargo build --quiet --test integration
RUN cargo build --quiet --test integration --all-features

CMD [ "cargo", "test" ]
4 changes: 2 additions & 2 deletions circuits/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ services:
ports:
- "8000:8000"
command: >
cargo test --test integration --
cargo test --test integration --all-features --
--party 0
--port1 8000
--port2 9000
Expand All @@ -22,7 +22,7 @@ services:
ports:
- "9000:9000"
command: >
cargo test --test integration --
cargo test --test integration --all-features --
--party 1
--port1 9000
--port2 8000
Expand Down
8 changes: 3 additions & 5 deletions circuits/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
#![feature(inherent_associated_types)]

mod mpc_circuits;
mod mpc_gadgets;
mod types;
mod zk_circuits;
mod zk_gadgets;
// mod zk_circuits;

use circuit_types::Fabric;
use clap::Parser;
use mpc_stark::MpcFabric;
use test_helpers::{integration_test_main, mpc_network::setup_mpc_fabric};
use util::logging::LevelFilter;

Expand Down Expand Up @@ -44,7 +42,7 @@ struct CliArgs {
#[derive(Debug, Clone)]
struct IntegrationTestArgs {
/// The MPC fabric to use during the course of the integration test
pub(crate) mpc_fabric: MpcFabric,
pub(crate) mpc_fabric: Fabric,
}

impl From<CliArgs> for IntegrationTestArgs {
Expand Down
284 changes: 27 additions & 257 deletions circuits/integration/mpc_circuits/match.rs
Original file line number Diff line number Diff line change
@@ -1,272 +1,42 @@
//! Groups integration tests for the match circuitry

use circuit_types::{
balance::Balance,
fixed_point::FixedPoint,
order::{Order, OrderSide},
r#match::MatchResult,
traits::{LinkableBaseType, MpcBaseType, MpcType, MultiproverCircuitBaseType},
};
use circuits::{
mpc_circuits::r#match::compute_match,
zk_circuits::valid_match_mpc::{AuthenticatedValidMatchMpcWitness, ValidMatchMpcCircuit},
};
use eyre::{eyre, Result};
use merlin::HashChainTranscript as Transcript;
use mpc_bulletproof::{r1cs_mpc::MpcProver, PedersenGens};
use mpc_stark::{PARTY0, PARTY1};
use num_bigint::BigUint;
use rand::thread_rng;
use test_helpers::integration_test_async;
use ark_mpc::PARTY0;
use circuit_types::traits::{BaseType, MpcBaseType, MpcType};
use circuits::{mpc_circuits::r#match::compute_match, test_helpers::random_orders_and_match};
use eyre::Result;
use test_helpers::{assert_eq_result, integration_test_async};

use crate::IntegrationTestArgs;

// --------------
// | Test Cases |
// --------------

/// Tests the match function with non overlapping orders for a variety of
/// failure cases
async fn test_match_no_match(test_args: IntegrationTestArgs) -> Result<()> {
// Convenience selector for brevity
let fabric = &test_args.mpc_fabric;
let mut rng = thread_rng();
let party_id = fabric.party_id();

/// Convenience selector between two party's values
macro_rules! sel {
($a:expr, $b:expr) => {
if party_id == 0 {
$a
} else {
$b
}
};
}

// Give a balance to each party and allocate it in the network
let my_balance = sel!(
Balance {
mint: BigUint::from(1u8),
amount: 200
},
Balance {
mint: BigUint::from(2u8),
amount: 200
}
)
.to_linkable();

let balance1 = my_balance.allocate(PARTY0, fabric);
let balance2 = my_balance.allocate(PARTY1, fabric);

// Build the test cases for different invalid match pairs
let test_cases: Vec<(Order, u64)> = vec![
// Quote mints different
(
Order {
quote_mint: sel!(0u8, 1u8).into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: sel!(20, 30),
timestamp: 0, // unused
},
10, // execution_price
),
// Base mints different
(
Order {
quote_mint: 1u8.into(),
base_mint: sel!(0u8, 1u8).into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: sel!(20, 30),
timestamp: 0, // unused
},
10, // execution_price
),
// Orders on the same side (buy side)
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: OrderSide::Buy,
worst_case_price: FixedPoint::from_integer(15),
amount: 20,
timestamp: 0, // unused
},
10, // execution_price
),
// Prices differ between orders
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: 30,
timestamp: 0, // unused
},
sel!(10, 11), // execution_price
),
];

for (my_order, my_price) in test_cases.into_iter() {
// Allocate the orders in the network
let order1 = my_order.to_linkable().allocate(PARTY0, fabric);
let order2 = my_order.to_linkable().allocate(PARTY1, fabric);

// Allocate the price in the network
let price1 = FixedPoint::from_integer(my_price).allocate(PARTY0, fabric);
let price2 = FixedPoint::from_integer(my_price).allocate(PARTY1, fabric);

// Compute matches
let res = compute_match(
&order1,
&order2,
order1.amount.value(),
order2.amount.value(),
&price1, // Use the first party's price
fabric,
);

// Assert that match verification fails
let pc_gens = PedersenGens::default();
let transcript = Transcript::new(b"test");
let mut dummy_prover =
MpcProver::new_with_fabric(test_args.mpc_fabric.clone(), transcript, pc_gens);

let witness = AuthenticatedValidMatchMpcWitness {
order1: order1.clone(),
amount1: order1.amount.value().clone(),
price1: price1.clone(),
order2: order2.clone(),
amount2: order2.amount.value().clone(),
price2: price2.clone(),
balance1: balance1.clone(),
balance2: balance2.clone(),
match_res: res.link_commitments(fabric),
};
let (witness_var, _) = witness.commit_shared(&mut rng, &mut dummy_prover).unwrap();

ValidMatchMpcCircuit::matching_engine_check(
witness_var,
test_args.mpc_fabric.clone(),
&mut dummy_prover,
)
.unwrap();

if dummy_prover.constraints_satisfied().await {
return Err(eyre!("Constraints satisfied"));
}
}

Ok(())
}

/// Tests that a valid match is found when one exists
async fn test_match_valid_match(test_args: IntegrationTestArgs) -> Result<()> {
// Convenience selector for brevity, simpler to redefine per test than to
// pass in party_id from the environment
async fn test_match(test_args: IntegrationTestArgs) -> Result<()> {
let fabric = &test_args.mpc_fabric;
let party_id = fabric.party_id();

/// Convenience selector for values that differ between parties
macro_rules! sel {
($a:expr, $b:expr) => {
if party_id == 0 {
$a
} else {
$b
}
};
}

let test_cases: Vec<(Order, u64)> = vec![
// Different amounts
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: sel!(20, 30),
timestamp: 0, // unused
},
10, // execution_price
),
// Same amount
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Sell, OrderSide::Buy),
worst_case_price: FixedPoint::from_integer(sel!(5, 15)),
amount: 15,
timestamp: 0, // unused
},
10, // execution_price
),
];

// Stores the expected result for each test case as a vector
// [party1_buy_mint, party1_buy_amount, party2_buy_mint, party2_buy_amount]
let expected_results = vec![
MatchResult {
quote_mint: BigUint::from(1u8),
base_mint: BigUint::from(2u8),
quote_amount: 200,
base_amount: 20,
direction: 0,
max_minus_min_amount: 10,
min_amount_order_index: 0,
},
MatchResult {
quote_mint: BigUint::from(1u8),
base_mint: BigUint::from(2u8),
quote_amount: 150,
base_amount: 15,
direction: 1,
max_minus_min_amount: 0,
min_amount_order_index: 1,
},
];

for ((my_order, my_price), expected_res) in
test_cases.into_iter().zip(expected_results.into_iter())
{
// Allocate the prices in the network
let price1 = FixedPoint::from_integer(my_price).allocate(PARTY0, fabric);

// Allocate the orders in the network
let order1 = my_order.to_linkable().allocate(PARTY0, fabric);
let order2 = my_order.to_linkable().allocate(PARTY1, fabric);

// Compute matches
let res = compute_match(
&order1,
&order2,
order1.amount.value(),
order2.amount.value(),
&price1,
fabric,
)
.open_and_authenticate()
.await
.map_err(|e| eyre!("Error computing match: {e:?}"))?;

// Assert that no match occurred
if res != expected_res.clone() {
return Err(eyre!(
"Match result {res:?} does not match expected result {expected_res:?}",
));
}
}
// Compute a match on two random orders using the internal engine
let (o1, o2, price, expected) = random_orders_and_match();

// Compute a match in a circuit
let order1 = o1.allocate(PARTY0, fabric);
let order2 = o2.allocate(PARTY0, fabric);
let price_shared = price.allocate(PARTY0, fabric);
let res = compute_match(
&order1,
&order2,
&order1.amount,
&order2.amount,
&price_shared,
fabric,
)
.open_and_authenticate()
.await?;

Ok(())
// Party 0 shares their expected match
let expected = expected.share_public(PARTY0, fabric).await;
assert_eq_result!(res, expected)
}

// Take inventory
integration_test_async!(test_match_no_match);
integration_test_async!(test_match_valid_match);
integration_test_async!(test_match);
Loading
Loading