Skip to content

Commit

Permalink
circuits: integration: Simplify integration tests
Browse files Browse the repository at this point in the history
This removes many unnecessary integration tests in light of the mocking
framework for MPC gadgets -- which has allowed most tests to be moved
into unit tests where they belond.

Integration tests will be reserved for connecting circuits and end-to
-end functionality.
  • Loading branch information
joeykraut committed Nov 23, 2023
1 parent 1e36d44 commit 2cdbee6
Show file tree
Hide file tree
Showing 21 changed files with 135 additions and 1,434 deletions.
176 changes: 73 additions & 103 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion circuit-types/src/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl BaseType for OrderSide {
fn from_scalars<I: Iterator<Item = Scalar>>(i: &mut I) -> Self {
match scalar_to_u64(&i.next().unwrap()) {
val @ 0..=1 => OrderSide::from(val),
_ => panic!("invalid value for OrderSide"),
x => panic!("invalid value for OrderSide({x})"),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion circuit-types/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub trait BaseType: Clone {
/// This method is added to the `BaseType` trait for maximum flexibility, so
/// that types may be shared without requiring them to implement the
/// full `MpcBaseType` trait
async fn share_public(&self, owning_party: PartyId, fabric: Fabric) -> Self {
async fn share_public(&self, owning_party: PartyId, fabric: &Fabric) -> Self {
let self_scalars = self.to_scalars();
let res_scalars = fabric
.batch_share_plaintext(self_scalars, owning_party)
Expand Down
1 change: 1 addition & 0 deletions circuits/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ stats = ["ark-mpc/stats"]
name = "integration"
path = "integration/main.rs"
harness = false
required-features = ["test_helpers"]

[[bench]]
name = "valid_wallet_create"
Expand Down
4 changes: 2 additions & 2 deletions circuits/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ RUN sed -i 's/main.rs/dummy-main.rs/g' Cargo.toml
ENV RUSTFLAGS=-Awarnings
ENV RUST_BACKTRACE=1

RUN cargo build --quiet --test integration
RUN cargo build --quiet --test integration --all-features

# Edit the Cargo.toml back to the original, build the full executable
RUN sed -i 's/dummy-lib.rs/lib.rs/g' Cargo.toml
Expand All @@ -58,6 +58,6 @@ RUN sed -i 's/dummy-main.rs/main.rs/g' Cargo.toml
COPY circuits/src ./src
COPY circuits/integration ./integration

RUN cargo build --quiet --test integration
RUN cargo build --quiet --test integration --all-features

CMD [ "cargo", "test" ]
4 changes: 2 additions & 2 deletions circuits/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ services:
ports:
- "8000:8000"
command: >
cargo test --test integration --
cargo test --test integration --all-features --
--party 0
--port1 8000
--port2 9000
Expand All @@ -22,7 +22,7 @@ services:
ports:
- "9000:9000"
command: >
cargo test --test integration --
cargo test --test integration --all-features --
--party 1
--port1 9000
--port2 8000
Expand Down
8 changes: 3 additions & 5 deletions circuits/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
#![feature(inherent_associated_types)]

mod mpc_circuits;
mod mpc_gadgets;
mod types;
mod zk_circuits;
mod zk_gadgets;
// mod zk_circuits;

use circuit_types::Fabric;
use clap::Parser;
use mpc_stark::MpcFabric;
use test_helpers::{integration_test_main, mpc_network::setup_mpc_fabric};
use util::logging::LevelFilter;

Expand Down Expand Up @@ -44,7 +42,7 @@ struct CliArgs {
#[derive(Debug, Clone)]
struct IntegrationTestArgs {
/// The MPC fabric to use during the course of the integration test
pub(crate) mpc_fabric: MpcFabric,
pub(crate) mpc_fabric: Fabric,
}

impl From<CliArgs> for IntegrationTestArgs {
Expand Down
284 changes: 27 additions & 257 deletions circuits/integration/mpc_circuits/match.rs
Original file line number Diff line number Diff line change
@@ -1,272 +1,42 @@
//! Groups integration tests for the match circuitry
use circuit_types::{
balance::Balance,
fixed_point::FixedPoint,
order::{Order, OrderSide},
r#match::MatchResult,
traits::{LinkableBaseType, MpcBaseType, MpcType, MultiproverCircuitBaseType},
};
use circuits::{
mpc_circuits::r#match::compute_match,
zk_circuits::valid_match_mpc::{AuthenticatedValidMatchMpcWitness, ValidMatchMpcCircuit},
};
use eyre::{eyre, Result};
use merlin::HashChainTranscript as Transcript;
use mpc_bulletproof::{r1cs_mpc::MpcProver, PedersenGens};
use mpc_stark::{PARTY0, PARTY1};
use num_bigint::BigUint;
use rand::thread_rng;
use test_helpers::integration_test_async;
use ark_mpc::PARTY0;
use circuit_types::traits::{BaseType, MpcBaseType, MpcType};
use circuits::{mpc_circuits::r#match::compute_match, test_helpers::random_orders_and_match};
use eyre::Result;
use test_helpers::{assert_eq_result, integration_test_async};

use crate::IntegrationTestArgs;

// --------------
// | Test Cases |
// --------------

/// Tests the match function with non overlapping orders for a variety of
/// failure cases
async fn test_match_no_match(test_args: IntegrationTestArgs) -> Result<()> {
// Convenience selector for brevity
let fabric = &test_args.mpc_fabric;
let mut rng = thread_rng();
let party_id = fabric.party_id();

/// Convenience selector between two party's values
macro_rules! sel {
($a:expr, $b:expr) => {
if party_id == 0 {
$a
} else {
$b
}
};
}

// Give a balance to each party and allocate it in the network
let my_balance = sel!(
Balance {
mint: BigUint::from(1u8),
amount: 200
},
Balance {
mint: BigUint::from(2u8),
amount: 200
}
)
.to_linkable();

let balance1 = my_balance.allocate(PARTY0, fabric);
let balance2 = my_balance.allocate(PARTY1, fabric);

// Build the test cases for different invalid match pairs
let test_cases: Vec<(Order, u64)> = vec![
// Quote mints different
(
Order {
quote_mint: sel!(0u8, 1u8).into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: sel!(20, 30),
timestamp: 0, // unused
},
10, // execution_price
),
// Base mints different
(
Order {
quote_mint: 1u8.into(),
base_mint: sel!(0u8, 1u8).into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: sel!(20, 30),
timestamp: 0, // unused
},
10, // execution_price
),
// Orders on the same side (buy side)
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: OrderSide::Buy,
worst_case_price: FixedPoint::from_integer(15),
amount: 20,
timestamp: 0, // unused
},
10, // execution_price
),
// Prices differ between orders
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: 30,
timestamp: 0, // unused
},
sel!(10, 11), // execution_price
),
];

for (my_order, my_price) in test_cases.into_iter() {
// Allocate the orders in the network
let order1 = my_order.to_linkable().allocate(PARTY0, fabric);
let order2 = my_order.to_linkable().allocate(PARTY1, fabric);

// Allocate the price in the network
let price1 = FixedPoint::from_integer(my_price).allocate(PARTY0, fabric);
let price2 = FixedPoint::from_integer(my_price).allocate(PARTY1, fabric);

// Compute matches
let res = compute_match(
&order1,
&order2,
order1.amount.value(),
order2.amount.value(),
&price1, // Use the first party's price
fabric,
);

// Assert that match verification fails
let pc_gens = PedersenGens::default();
let transcript = Transcript::new(b"test");
let mut dummy_prover =
MpcProver::new_with_fabric(test_args.mpc_fabric.clone(), transcript, pc_gens);

let witness = AuthenticatedValidMatchMpcWitness {
order1: order1.clone(),
amount1: order1.amount.value().clone(),
price1: price1.clone(),
order2: order2.clone(),
amount2: order2.amount.value().clone(),
price2: price2.clone(),
balance1: balance1.clone(),
balance2: balance2.clone(),
match_res: res.link_commitments(fabric),
};
let (witness_var, _) = witness.commit_shared(&mut rng, &mut dummy_prover).unwrap();

ValidMatchMpcCircuit::matching_engine_check(
witness_var,
test_args.mpc_fabric.clone(),
&mut dummy_prover,
)
.unwrap();

if dummy_prover.constraints_satisfied().await {
return Err(eyre!("Constraints satisfied"));
}
}

Ok(())
}

/// Tests that a valid match is found when one exists
async fn test_match_valid_match(test_args: IntegrationTestArgs) -> Result<()> {
// Convenience selector for brevity, simpler to redefine per test than to
// pass in party_id from the environment
async fn test_match(test_args: IntegrationTestArgs) -> Result<()> {
let fabric = &test_args.mpc_fabric;
let party_id = fabric.party_id();

/// Convenience selector for values that differ between parties
macro_rules! sel {
($a:expr, $b:expr) => {
if party_id == 0 {
$a
} else {
$b
}
};
}

let test_cases: Vec<(Order, u64)> = vec![
// Different amounts
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Buy, OrderSide::Sell),
worst_case_price: FixedPoint::from_integer(sel!(15, 5)),
amount: sel!(20, 30),
timestamp: 0, // unused
},
10, // execution_price
),
// Same amount
(
Order {
quote_mint: 1u8.into(),
base_mint: 2u8.into(),
side: sel!(OrderSide::Sell, OrderSide::Buy),
worst_case_price: FixedPoint::from_integer(sel!(5, 15)),
amount: 15,
timestamp: 0, // unused
},
10, // execution_price
),
];

// Stores the expected result for each test case as a vector
// [party1_buy_mint, party1_buy_amount, party2_buy_mint, party2_buy_amount]
let expected_results = vec![
MatchResult {
quote_mint: BigUint::from(1u8),
base_mint: BigUint::from(2u8),
quote_amount: 200,
base_amount: 20,
direction: 0,
max_minus_min_amount: 10,
min_amount_order_index: 0,
},
MatchResult {
quote_mint: BigUint::from(1u8),
base_mint: BigUint::from(2u8),
quote_amount: 150,
base_amount: 15,
direction: 1,
max_minus_min_amount: 0,
min_amount_order_index: 1,
},
];

for ((my_order, my_price), expected_res) in
test_cases.into_iter().zip(expected_results.into_iter())
{
// Allocate the prices in the network
let price1 = FixedPoint::from_integer(my_price).allocate(PARTY0, fabric);

// Allocate the orders in the network
let order1 = my_order.to_linkable().allocate(PARTY0, fabric);
let order2 = my_order.to_linkable().allocate(PARTY1, fabric);

// Compute matches
let res = compute_match(
&order1,
&order2,
order1.amount.value(),
order2.amount.value(),
&price1,
fabric,
)
.open_and_authenticate()
.await
.map_err(|e| eyre!("Error computing match: {e:?}"))?;

// Assert that no match occurred
if res != expected_res.clone() {
return Err(eyre!(
"Match result {res:?} does not match expected result {expected_res:?}",
));
}
}
// Compute a match on two random orders using the internal engine
let (o1, o2, price, expected) = random_orders_and_match();

// Compute a match in a circuit
let order1 = o1.allocate(PARTY0, fabric);
let order2 = o2.allocate(PARTY0, fabric);
let price_shared = price.allocate(PARTY0, fabric);
let res = compute_match(
&order1,
&order2,
&order1.amount,
&order2.amount,
&price_shared,
fabric,
)
.open_and_authenticate()
.await?;

Ok(())
// Party 0 shares their expected match
let expected = expected.share_public(PARTY0, fabric).await;
assert_eq_result!(res, expected)
}

// Take inventory
integration_test_async!(test_match_no_match);
integration_test_async!(test_match_valid_match);
integration_test_async!(test_match);
Loading

0 comments on commit 2cdbee6

Please sign in to comment.