Skip to content

Commit

Permalink
Handle round completion on sends
Browse files Browse the repository at this point in the history
Also increase the channel buffer to 2 to allow for round completion
on both sends and receives
  • Loading branch information
pool2win committed Nov 29, 2024
1 parent 46dd5cc commit 316dc15
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 18 deletions.
13 changes: 8 additions & 5 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,15 @@ impl Node {
) {
log::debug!("Starting... {}", self.bind_address);
let node_id = self.get_node_id().clone();
let state = self.state.clone();
let (round_one_tx, round_one_rx) = mpsc::channel::<()>(1);
self.state.round_one_tx = Some(round_one_tx.clone());
let (round_two_tx, round_two_rx) = mpsc::channel::<()>(1);
self.state.round_two_tx = Some(round_two_tx.clone());
let echo_broadcast_handle = self.echo_broadcast_handle.clone();

// We can send message on the channel from both sending and receiving tasks
let (round_one_tx, round_one_rx) = mpsc::channel::<()>(2);
self.state.round_one_tx = Some(round_one_tx);
let (round_two_tx, round_two_rx) = mpsc::channel::<()>(2);
self.state.round_two_tx = Some(round_two_tx);

let state = self.state.clone();
tokio::spawn(async move {
dkg::trigger::run_dkg_trigger(
15000,
Expand Down
19 changes: 15 additions & 4 deletions src/node/protocol/dkg/round_one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl PackageMessage {
/// Builds a round one package using the frost-secp256k1 crate
async fn build_round1_package(
sender_id: String,
state: crate::node::state::State,
state: &crate::node::state::State,
) -> Result<Message, frost::Error> {
let (max_signers, min_signers) = get_max_min_signers(&state).await;

Expand Down Expand Up @@ -123,8 +123,19 @@ impl Service<Message> for Package {
_message_id,
) => {
log::debug!("Build round one package");
let response = build_round1_package(this_sender_id, state).await?;
let response = build_round1_package(this_sender_id, &state).await?;
log::info!("Sending round one package {:?}", response);
let finished = state
.dkg_state
.get_received_round1_packages()
.await
.unwrap()
.len()
== state.dkg_state.get_expected_members().await.unwrap();
if finished {
log::debug!("Round one finished, sending signal");
let _ = state.round_one_tx.unwrap().send(()).await;
}
Ok(Some(response))
}
Message::Broadcast(
Expand Down Expand Up @@ -207,7 +218,7 @@ mod round_one_package_tests {
let membership_handle = build_membership(3).await;
let state = State::new(membership_handle, message_id_generator).await;

let round1_package = build_round1_package("local".into(), state).await.unwrap();
let round1_package = build_round1_package("local".into(), &state).await.unwrap();

// Extract the public key package from the NetworkMessage
if let Message::Broadcast(BroadcastProtocol::DKGRoundOnePackage(pkg_msg), _message_id) =
Expand All @@ -232,7 +243,7 @@ mod round_one_package_tests {
let state_clone = state.clone();

// First create a round1 package that we'll pretend came from another node
let round1_package = build_round1_package("remote".into(), state).await.unwrap();
let round1_package = build_round1_package("remote".into(), &state).await.unwrap();

// Create our local package service
let mut pkg = Package::new("local".into(), state_clone);
Expand Down
20 changes: 18 additions & 2 deletions src/node/protocol/dkg/round_two.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ impl Service<Message> for Package {
log::error!("Failed to send round2 packages: {:?}", e);
return Err(e.into());
}
let finished = state
.dkg_state
.get_received_round2_packages()
.await
.unwrap()
.len()
== state.dkg_state.get_expected_members().await.unwrap();
if finished {
log::debug!("Round two finished on send, sending signal");
state.round_two_tx.unwrap().send(()).await?;
}
log::debug!("Sent round2 packages");
Ok(None)
}
Expand All @@ -155,15 +166,20 @@ impl Service<Message> for Package {
message: Some(message), // received a message
})) => {
// Received round2 message and save it in state
log::debug!(
"Received round two message from {} \n {:?}",
from_sender_id,
message
);
let identifier = frost::Identifier::derive(from_sender_id.as_bytes()).unwrap();
let finished = state
.dkg_state
.add_round2_package(identifier, message)
.await
.unwrap();
if finished {
log::debug!("Round two finished, sending signal");
let _ = state.round_two_tx.unwrap().send(()).await;
log::debug!("Round two finished on receive, sending signal");
state.round_two_tx.unwrap().send(()).await?;
}
Ok(None)
}
Expand Down
14 changes: 7 additions & 7 deletions src/node/protocol/dkg/trigger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ pub(crate) async fn trigger_dkg(
// TODO Improve this to allow round1 to finish as soon as all other parties have sent their round1 message
// This will mean moving the timeout into round1 service

// Wait for round1 to finish, give it 5 seconds
if round1_future.await.is_err() {
log::error!("Error running round 1");
return Err("Error running round 1".into());
// Start round1
if let Err(e) = round1_future.await {
log::error!("Error running round 1: {:?}", e);
return Err("Error running round 1: failed with error".into());
}
round_one_rx.recv().await.unwrap();
log::info!("Round 1 finished");
Expand All @@ -149,9 +149,9 @@ pub(crate) async fn trigger_dkg(
);

// start round2
if round2_future.await.is_err() {
log::error!("Error running round 2");
return Err("Error running round 2".into());
if let Err(e) = round2_future.await {
log::error!("Error running round 2: {:?}", e);
return Err("Error running round 2: failed with error".into());
}
round_two_rx.recv().await.unwrap();
log::info!("Round 2 finished");
Expand Down

0 comments on commit 316dc15

Please sign in to comment.