diff --git a/Cargo.toml b/Cargo.toml index 54651e7..3ea268d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,3 +65,4 @@ autosurgeon = "0.8.0" bolero = { version = "0.10.0", features = ["arbitrary"] } arbitrary = { version = "1.3.1", features = ["derive"] } bolero-generator = { version = "0.10.0", features = ["arbitrary"] } +rand = "0.8.5" diff --git a/src/network_connect.rs b/src/network_connect.rs index 3ed2b16..06ef7b1 100644 --- a/src/network_connect.rs +++ b/src/network_connect.rs @@ -22,17 +22,17 @@ impl RepoHandle { Str: Stream> + Send + 'static + Unpin, { let other_id = self.handshake(&mut stream, &mut sink, direction).await?; - tracing::trace!(?other_id, repo_id=?self.get_repo_id(), "Handshake complete"); + tracing::trace!(?other_id, repo_id=?self.get_repo_id(), "handshake complete"); let stream = stream.map({ let repo_id = self.get_repo_id().clone(); move |msg| match msg { Ok(Message::Repo(repo_msg)) => { - tracing::trace!(?repo_msg, repo_id=?repo_id, "Received repo message"); + tracing::trace!(?repo_msg, repo_id=?repo_id, "received repo message"); Ok(repo_msg) } Ok(m) => { - tracing::warn!(?m, repo_id=?repo_id, "Received non-repo message"); + tracing::warn!(?m, repo_id=?repo_id, "received non-repo message"); Err(NetworkError::Error( "unexpected non-repo message".to_string(), )) @@ -48,12 +48,9 @@ impl RepoHandle { }); let sink = sink - .with_flat_map::(|msg| match msg { - RepoMessage::Sync { .. } => futures::stream::iter(vec![Ok(Message::Repo(msg))]), - _ => futures::stream::iter(vec![]), - }) + .with::<_, _, _, SendErr>(move |msg| futures::future::ready(Ok(Message::Repo(msg)))) .sink_map_err(|e| { - tracing::error!(?e, "Error sending repo message"); + tracing::error!(?e, "error sending repo message"); NetworkError::Error(format!("error sending repo message: {}", e)) }); diff --git a/src/repo.rs b/src/repo.rs index af30d91..f40e291 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -4,7 +4,7 @@ use crate::interfaces::{NetworkError, RepoMessage, Storage, StorageError}; use crate::share_policy::ShareDecision; use crate::{share_policy, SharePolicy, SharePolicyError}; use automerge::sync::{Message as SyncMessage, State as SyncState, SyncDoc}; -use automerge::{Automerge, ChangeHash}; +use automerge::{Automerge, ChangeHash, ReadDoc}; use core::pin::Pin; use crossbeam_channel::{select, unbounded, Receiver, Sender}; use futures::future::{BoxFuture, Future}; @@ -23,6 +23,8 @@ use std::sync::Arc; use std::thread::{self, JoinHandle}; use uuid::Uuid; +mod request; + /// Front-end of the repo. 
#[derive(Debug, Clone)] pub struct RepoHandle { @@ -59,6 +61,18 @@ enum NetworkEvent { document_id: DocumentId, message: SyncMessage, }, + /// A repo requested a document + Request { + from_repo_id: RepoId, + to_repo_id: RepoId, + document_id: DocumentId, + message: SyncMessage, + }, + Unavailable { + from_repo_id: RepoId, + to_repo_id: RepoId, + document_id: DocumentId, + }, } impl std::fmt::Debug for NetworkEvent { @@ -75,6 +89,27 @@ impl std::fmt::Debug for NetworkEvent { .field("to_repo_id", to_repo_id) .field("document_id", document_id) .finish(), + NetworkEvent::Request { + from_repo_id, + to_repo_id, + document_id, + message: _, + } => f + .debug_struct("NetworkEvent::Request") + .field("from_repo_id", from_repo_id) + .field("to_repo_id", to_repo_id) + .field("document_id", document_id) + .finish(), + NetworkEvent::Unavailable { + from_repo_id, + to_repo_id, + document_id, + } => f + .debug_struct("NetworkEvent::Unavailable") + .field("from_repo_id", from_repo_id) + .field("to_repo_id", to_repo_id) + .field("document_id", document_id) + .finish(), } } } @@ -90,6 +125,55 @@ enum NetworkMessage { document_id: DocumentId, message: SyncMessage, }, + Request { + from_repo_id: RepoId, + to_repo_id: RepoId, + document_id: DocumentId, + message: SyncMessage, + }, + Unavailable { + from_repo_id: RepoId, + to_repo_id: RepoId, + document_id: DocumentId, + }, +} + +impl From for RepoMessage { + fn from(msg: NetworkMessage) -> Self { + match msg { + NetworkMessage::Sync { + from_repo_id, + to_repo_id, + document_id, + message, + } => RepoMessage::Sync { + from_repo_id, + to_repo_id, + document_id, + message: message.encode(), + }, + NetworkMessage::Request { + from_repo_id, + to_repo_id, + document_id, + message, + } => RepoMessage::Request { + sender_id: from_repo_id, + target_id: to_repo_id, + document_id, + sync_message: message.encode(), + }, + NetworkMessage::Unavailable { + from_repo_id, + to_repo_id, + document_id, + } => RepoMessage::Unavailable { + document_id, + sender_id: from_repo_id, + target_id: to_repo_id, + }, + } + } } /// Create a pair of repo future and resolver. @@ -529,7 +613,10 @@ enum PeerConnection { /// we've accepted the peer and are syncing with them Accepted(SyncState), /// We're waiting for a response from the share policy - PendingAuth { received_messages: Vec }, + PendingAuth { + /// Messages received while we were waiting for a response from the share policy + received_messages: Vec, + }, } impl PeerConnection { @@ -539,32 +626,34 @@ impl PeerConnection { } } - fn receive_sync_message( - &mut self, - doc: &mut Automerge, - msg: SyncMessage, - ) -> Result<(), automerge::AutomergeError> { - match self { - PeerConnection::Accepted(sync_state) => doc.receive_sync_message(sync_state, msg), - PeerConnection::PendingAuth { received_messages } => { - received_messages.push(msg); - Ok(()) - } - } - } - - fn generate_sync_message(&mut self, doc: &Automerge) -> Option { - match self { - Self::Accepted(sync_state) => doc.generate_sync_message(sync_state), - Self::PendingAuth { .. } => None, + fn up_to_date(&self, doc: &Automerge) -> bool { + if let Self::Accepted(SyncState { + their_heads: Some(their_heads), + .. 
+ }) = self + { + their_heads + .iter() + .all(|h| doc.get_change_by_hash(h).is_some()) + } else { + false } } } /// A change requested by a peer connection +#[derive(Debug)] enum PeerConnCommand { /// Request authorization from the share policy - RequestAuth(RepoId), + RequestAuth(RepoId, ShareType), + SendRequest { + message: SyncMessage, + to: RepoId, + }, + SendSyncMessage { + message: SyncMessage, + to: RepoId, + }, } impl DocumentInfo { @@ -768,98 +857,156 @@ impl DocumentInfo { /// /// # Returns /// - /// A tuple of `(has_changes, commands)` where `has_changes` is true if the document changed as - /// a result of applying the sync message and `commands` is a list of changes requested by the - /// peer connections for this document (e.g. requesting authorization from the share policy). - fn receive_sync_message( - &mut self, - per_remote: HashMap<RepoId, VecDeque<SyncMessage>>, - ) -> (bool, Vec<PeerConnCommand>) { + /// A `Vec<PeerConnCommand>` which is a list of changes requested by the peer connections for + /// this document (e.g. requesting authorization from the share policy). + fn receive_sync_message<P, I>(&mut self, per_remote: P) -> Vec<PeerConnCommand> + where + P: IntoIterator<Item = (RepoId, I)>, + I: IntoIterator<Item = SyncMessage>, + { let mut commands = Vec::new(); - let (start_heads, new_heads) = { - let mut document = self.document.write(); - let start_heads = document.automerge.get_heads(); - for (repo_id, messages) in per_remote { - let conn = match self.peer_connections.entry(repo_id.clone()) { - Entry::Vacant(entry) => { - // if this is a new peer, request authorization - commands.push(PeerConnCommand::RequestAuth(repo_id.clone())); - entry.insert(PeerConnection::pending()) + let mut document = self.document.write(); + for (repo_id, messages) in per_remote { + let conn = match self.peer_connections.entry(repo_id.clone()) { + Entry::Vacant(entry) => { + // if this is a new peer, request authorization + commands.push(PeerConnCommand::RequestAuth( + repo_id.clone(), + ShareType::Synchronize, + )); + entry.insert(PeerConnection::pending()) + } + Entry::Occupied(entry) => entry.into_mut(), + }; + match conn { + PeerConnection::PendingAuth { + ref mut received_messages, + } => { + received_messages.extend(messages); + } + PeerConnection::Accepted(ref mut sync_state) => { + for message in messages { + document + .automerge + .receive_sync_message(sync_state, message) + .expect("Failed to receive sync message."); + } + if let Some(msg) = document.automerge.generate_sync_message(sync_state) { + commands.push(PeerConnCommand::SendSyncMessage { + message: msg, + to: repo_id.clone(), + }); } - Entry::Occupied(entry) => entry.into_mut(), - }; - for message in messages { - conn.receive_sync_message(&mut document.automerge, message) - .expect("Failed to receive sync message."); } } - let new_heads = document.automerge.get_heads(); - (start_heads, new_heads) - }; - (start_heads != new_heads, commands) + } + commands } - /// Promote a peer awaiting authorization to a full peer - /// - /// Returns any messages which the peer sent while we were waiting for authorization - fn promote_pending_peer(&mut self, repo_id: &RepoId) -> Option<Vec<SyncMessage>> { - if let Some(PeerConnection::PendingAuth { received_messages }) = - self.peer_connections.remove(repo_id) - { - self.peer_connections - .insert(repo_id.clone(), PeerConnection::Accepted(SyncState::new())); - Some(received_messages) - } else { - tracing::warn!(remote=%repo_id, "Tried to promote a peer which was not pending authorization"); - None - } + /// Generate outgoing sync message for all repos we are syncing with.
+ fn generate_sync_messages(&mut self) -> Vec<(RepoId, SyncMessage)> { + let document = self.document.read(); + self.peer_connections + .iter_mut() + .filter_map(|(repo_id, conn)| { + if let PeerConnection::Accepted(ref mut sync_state) = conn { + let message = document.automerge.generate_sync_message(sync_state); + message.map(|msg| (repo_id.clone(), msg)) + } else { + None + } + }) + .collect() } - /// Potentially generate an outgoing sync message. - fn generate_first_sync_message(&mut self, repo_id: RepoId) -> Option { - match self.peer_connections.entry(repo_id) { + fn begin_request(&mut self, remote: &RepoId) -> BeginRequest { + match self.peer_connections.entry(remote.clone()) { Entry::Vacant(entry) => { - let mut sync_state = SyncState::new(); - let document = self.document.read(); - let message = document.automerge.generate_sync_message(&mut sync_state); - entry.insert(PeerConnection::Accepted(sync_state)); - message + entry.insert(PeerConnection::pending()); + BeginRequest::RequiresAuth } Entry::Occupied(mut entry) => match entry.get_mut() { - PeerConnection::PendingAuth { received_messages } => { - let mut document = self.document.write(); - let mut sync_state = SyncState::new(); - for msg in received_messages.drain(..) { - document - .automerge - .receive_sync_message(&mut sync_state, msg) - .expect("Failed to receive sync message."); - } - let message = document.automerge.generate_sync_message(&mut sync_state); - entry.insert(PeerConnection::Accepted(sync_state)); - message - } + PeerConnection::PendingAuth { .. } => BeginRequest::AwaitingAuth, PeerConnection::Accepted(ref mut sync_state) => { + if sync_state.in_flight || sync_state.have_responded { + return BeginRequest::AlreadySyncing; + } let document = self.document.read(); - document.automerge.generate_sync_message(sync_state) + let message = document.automerge.generate_sync_message(sync_state); + if let Some(msg) = message { + BeginRequest::Request(msg) + } else { + BeginRequest::AlreadySyncing + } } }, } } - /// Generate outgoing sync message for all repos we are syncing with. 
- fn generate_sync_messages(&mut self) -> Vec<(RepoId, SyncMessage)> { - let document = self.document.read(); + fn begin_requests<'a, I: Iterator<Item = &'a RepoId> + 'a>( + &'a mut self, + to_peers: I, + ) -> impl Iterator<Item = PeerConnCommand> + 'a { + to_peers.filter_map(|peer| match self.begin_request(peer) { + BeginRequest::AlreadySyncing => { + tracing::debug!(remote=%peer, "not sending request as we are already syncing"); + None + } + BeginRequest::Request(message) => Some(PeerConnCommand::SendRequest { + message, + to: peer.clone(), + }), + BeginRequest::AwaitingAuth => None, + BeginRequest::RequiresAuth => Some(PeerConnCommand::RequestAuth( + peer.clone(), + ShareType::Request, + )), + }) + } + + fn authorize_peer(&mut self, remote: &RepoId) -> Option<SyncMessage> { + if let Some(PeerConnection::PendingAuth { received_messages }) = + self.peer_connections.remove(remote) + { + let mut doc = self.document.write(); + let mut sync_state = SyncState::new(); + for msg in received_messages { + doc.automerge + .receive_sync_message(&mut sync_state, msg) + .expect("Failed to receive sync message."); + } + let msg = doc.automerge.generate_sync_message(&mut sync_state); + self.peer_connections + .insert(remote.clone(), PeerConnection::Accepted(sync_state)); + msg + } else if !self.peer_connections.contains_key(remote) { + let mut sync_state = SyncState::new(); + let doc = self.document.write(); + let msg = doc.automerge.generate_sync_message(&mut sync_state); + self.peer_connections + .insert(remote.clone(), PeerConnection::Accepted(sync_state)); + msg + } else { + tracing::warn!(remote=%remote, "tried to authorize a peer which was not pending authorization"); + None + } + } + + fn has_up_to_date_peer(&self) -> bool { + let doc = self.document.read(); self.peer_connections - .iter_mut() - .filter_map(|(repo_id, conn)| { - let message = conn.generate_sync_message(&document.automerge); - message.map(|msg| (repo_id.clone(), msg)) - }) - .collect() + .iter() + .any(|(_, conn)| conn.up_to_date(&doc.automerge)) } } +enum BeginRequest { + AlreadySyncing, + Request(SyncMessage), + RequiresAuth, + AwaitingAuth, +} + /// Signal that the stream or sink on the network adapter is ready to be polled. #[derive(Debug)] enum WakeSignal { @@ -978,6 +1125,9 @@ pub struct Repo { /// Pending share policy futures pending_share_decisions: HashMap<RepoId, Vec<PendingShareDecision>>, + + /// Outstanding requests + requests: HashMap<DocumentId, request::Request>, } impl Repo { @@ -1006,6 +1156,7 @@ share_policy, pending_share_decisions: HashMap::new(), share_decisions_to_poll: HashSet::new(), + requests: HashMap::new(), } } @@ -1021,12 +1172,45 @@ /// Save documents that have changed to storage, /// resolve change observers.
fn process_changed_document(&mut self) { + let mut commands_by_doc = HashMap::new(); for doc_id in mem::take(&mut self.documents_with_changes) { - if let Some(info) = self.documents.get_mut(&doc_id) { + let Some(info) = self.documents.get_mut(&doc_id) else { + continue; + }; + + if info.has_up_to_date_peer() && info.state.is_bootstrapping() { + tracing::trace!(%doc_id, "bootstrapping complete"); + info.handle_count.fetch_add(1, Ordering::SeqCst); + let handle = DocHandle::new( + self.repo_sender.clone(), + doc_id.clone(), + info.document.clone(), + info.handle_count.clone(), + self.repo_id.clone(), + ); + info.state.resolve_bootstrap_fut(Ok(handle)); + info.state = DocState::Sync(vec![]); + + if let Some(req) = self.requests.remove(&doc_id) { + tracing::trace!(%doc_id, "resolving request"); + let awaiting_response = req.fulfilled(); + let commands = info.receive_sync_message( + awaiting_response + .into_iter() + .map(|(repo, msg)| (repo, std::iter::once(msg))), + ); + commands_by_doc.insert(doc_id.clone(), commands); + } + } + + if info.note_changes() { info.resolve_change_observers(Ok(())); - info.save_document(doc_id, self.storage.as_ref(), &self.wake_sender); + info.save_document(doc_id.clone(), self.storage.as_ref(), &self.wake_sender); } } + for (doc_id, commands) in commands_by_doc { + self.dispatch_peer_conn_commands(&doc_id, commands) + } } /// Remove sync states for repos for which we do not have an adapter anymore. @@ -1122,16 +1306,37 @@ impl Repo { } }, Ok(RepoMessage::Request { - sender_id: _, - target_id: _, - document_id: _, - sync_message: _, - }) => {} + sender_id, + target_id, + document_id, + sync_message, + }) => match SyncMessage::decode(&sync_message) { + Ok(message) => { + let event = NetworkEvent::Request { + from_repo_id: sender_id, + to_repo_id: target_id, + document_id, + message, + }; + new_messages.push(event); + } + Err(e) => { + tracing::error!(error = ?e, "error decoding sync message"); + break true; + } + }, Ok(RepoMessage::Unavailable { - document_id: _, - sender_id: _, - target_id: _, - }) => {} + document_id, + sender_id, + target_id, + }) => { + let event = NetworkEvent::Unavailable { + document_id, + from_repo_id: sender_id, + to_repo_id: target_id, + }; + new_messages.push(event); + } Ok(RepoMessage::Ephemeral { from_repo_id: _, to_repo_id: _, @@ -1213,20 +1418,11 @@ impl Repo { Poll::Pending => break, Poll::Ready(Ok(())) => { let pinned_sink = Pin::new(&mut remote_repo.sink); - let NetworkMessage::Sync { - from_repo_id, - to_repo_id, - document_id, - message, - } = pending_messages + let msg = pending_messages .pop_front() .expect("Empty pending messages."); - let outgoing = RepoMessage::Sync { - from_repo_id, - to_repo_id, - document_id, - message: message.encode(), - }; + let outgoing = RepoMessage::from(msg); + tracing::debug!(message = ?outgoing, remote=%repo_id, "sending message."); let result = pinned_sink.start_send(outgoing); if let Err(e) = result { tracing::error!(error = ?e, "Error on network sink."); @@ -1345,27 +1541,22 @@ impl Repo { } } + let req = self.requests.entry(document_id.clone()).or_insert_with(|| { + tracing::trace!(%document_id, "creating new local request"); + request::Request::new(document_id.clone()) + }); + if info.state.is_bootstrapping() { - Self::enqueue_share_decisions( - self.remote_repos.keys(), - &mut self.pending_share_decisions, - &mut self.share_decisions_to_poll, - self.share_policy.as_ref(), - document_id.clone(), - ShareType::Request, - ); + let to_request = req.initiate_local(self.remote_repos.keys()); + 
let commands = info.begin_requests(to_request.iter()).collect::>(); + self.dispatch_peer_conn_commands(&document_id, commands); } } RepoEvent::DocChange(doc_id) => { // Handle doc changes: sync the document. let local_repo_id = self.get_repo_id().clone(); if let Some(info) = self.documents.get_mut(&doc_id) { - // only run the documents_with_changes workflow if there - // was a change, but always generate potential sync messages - // (below) - if info.note_changes() { - self.documents_with_changes.push(doc_id.clone()); - } + self.documents_with_changes.push(doc_id.clone()); for (to_repo_id, message) in info.generate_sync_messages().into_iter() { let outgoing = NetworkMessage::Sync { from_repo_id: local_repo_id.clone(), @@ -1536,14 +1727,32 @@ impl Repo { } } + fn new_document_info() -> DocumentInfo { + // Note: since the handle count is zero, + // the document will not be removed from memory until shutdown. + // Perhaps remove this and rely on `request_document` calls. + let shared_document = SharedDocument { + automerge: new_document(), + }; + let state = DocState::Bootstrap { + resolvers: vec![], + storage_fut: None, + }; + let document = Arc::new(RwLock::new(shared_document)); + let handle_count = Arc::new(AtomicUsize::new(0)); + DocumentInfo::new(state, document, handle_count) + } + /// Apply incoming sync messages, and generate outgoing ones. fn sync_documents(&mut self) { // Re-organize messages so as to acquire the write lock // on the document only once per document. let mut per_doc_messages: HashMap>> = Default::default(); + for event in mem::take(&mut self.pending_events) { tracing::trace!(message = ?event, "processing sync message"); + match event { NetworkEvent::Sync { from_repo_id, @@ -1552,38 +1761,99 @@ impl Repo { message, } => { assert_eq!(to_repo_id, self.repo_id); - - // If we don't know about the document, - // create a new sync state and start syncing. - // Note: this is the mirror of sending sync messages for - // all known documents when a remote repo connects. let info = self .documents .entry(document_id.clone()) - .or_insert_with(|| { - // Note: since the handle count is zero, - // the document will not be removed from memory until shutdown. - // Perhaps remove this and rely on `request_document` calls. 
- let shared_document = SharedDocument { - automerge: new_document(), - }; - let state = DocState::Bootstrap { - resolvers: vec![], - storage_fut: None, - }; - let document = Arc::new(RwLock::new(shared_document)); - let handle_count = Arc::new(AtomicUsize::new(0)); - DocumentInfo::new(state, document, handle_count) - }); + .or_insert_with(Self::new_document_info); if !info.state.should_sync() { continue; } - let per_doc = per_doc_messages.entry(document_id).or_default(); - let per_remote = per_doc.entry(from_repo_id).or_default(); + let per_doc = per_doc_messages.entry(document_id.clone()).or_default(); + let per_remote = per_doc.entry(from_repo_id.clone()).or_default(); per_remote.push_back(message.clone()); } + NetworkEvent::Request { + from_repo_id, + to_repo_id, + document_id, + message, + } => { + assert_eq!(to_repo_id, self.repo_id); + let info = self + .documents + .entry(document_id.clone()) + .or_insert_with(Self::new_document_info); + match info.state { + DocState::Sync(_) => { + tracing::trace!( + ?from_repo_id, + "responding to request with sync as we have the doc" + ); + // if we have this document then just start syncing + Self::enqueue_share_decisions( + std::iter::once(&from_repo_id), + &mut self.pending_share_decisions, + &mut self.share_decisions_to_poll, + self.share_policy.as_ref(), + document_id.clone(), + ShareType::Synchronize, + ); + } + _ => { + let req = + self.requests.entry(document_id.clone()).or_insert_with(|| { + tracing::trace!(?from_repo_id, "creating new remote request"); + request::Request::new(document_id.clone()) + }); + + let request_from = req.initiate_remote( + &from_repo_id, + message, + self.remote_repos.keys(), + ); + let commands = + info.begin_requests(request_from.iter()).collect::>(); + + if req.is_complete() { + let req = self.requests.remove(&document_id).unwrap(); + Self::fail_request( + req, + &mut self.documents, + &mut self.pending_messages, + &mut self.sinks_to_poll, + self.repo_id.clone(), + ); + } + + self.dispatch_peer_conn_commands(&document_id, commands.into_iter()); + } + } + } + NetworkEvent::Unavailable { + from_repo_id, + to_repo_id: _, + document_id, + } => match self.requests.entry(document_id.clone()) { + Entry::Occupied(mut entry) => { + let req = entry.get_mut(); + req.mark_unavailable(&from_repo_id); + if req.is_complete() { + let req = entry.remove(); + Self::fail_request( + req, + &mut self.documents, + &mut self.pending_messages, + &mut self.sinks_to_poll, + self.repo_id.clone(), + ); + } + } + Entry::Vacant(_) => { + tracing::trace!(?from_repo_id, "received unavailable for request we didnt send or are no longer tracking"); + } + }, } } @@ -1593,57 +1863,10 @@ impl Repo { .get_mut(&document_id) .expect("Doc should have an info by now."); - let (has_changes, peer_conn_commands) = info.receive_sync_message(per_remote); - if has_changes && info.note_changes() { - self.documents_with_changes.push(document_id.clone()); - } - - for cmd in peer_conn_commands { - match cmd { - PeerConnCommand::RequestAuth(peer_id) => Self::enqueue_share_decisions( - std::iter::once(&peer_id), - &mut self.pending_share_decisions, - &mut self.share_decisions_to_poll, - self.share_policy.as_ref(), - document_id.clone(), - ShareType::Synchronize, - ), - } - } + let peer_conn_commands = info.receive_sync_message(per_remote); + self.documents_with_changes.push(document_id.clone()); - // Note: since receiving and generating sync messages is done - // in two separate critical sections, - // local changes could be made in between those, - // which is 
a good thing(generated messages will include those changes). - let mut ready = true; - for (to_repo_id, message) in info.generate_sync_messages().into_iter() { - if message.heads.is_empty() && !message.need.is_empty() { - ready = false; - } - let outgoing = NetworkMessage::Sync { - from_repo_id: self.repo_id.clone(), - to_repo_id: to_repo_id.clone(), - document_id: document_id.clone(), - message, - }; - self.pending_messages - .entry(to_repo_id.clone()) - .or_default() - .push_back(outgoing); - self.sinks_to_poll.insert(to_repo_id); - } - if ready && info.state.is_bootstrapping() { - info.handle_count.fetch_add(1, Ordering::SeqCst); - let handle = DocHandle::new( - self.repo_sender.clone(), - document_id.clone(), - info.document.clone(), - info.handle_count.clone(), - self.repo_id.clone(), - ); - info.state.resolve_bootstrap_fut(Ok(handle)); - info.state = DocState::Sync(vec![]); - } + self.dispatch_peer_conn_commands(&document_id, peer_conn_commands.into_iter()); } } @@ -1674,15 +1897,15 @@ impl Repo { } } - fn collect_sharepolicy_responses(&mut self) { + fn poll_sharepolicy_responses(&mut self) { let mut decisions = Vec::new(); for repo_id in mem::take(&mut self.share_decisions_to_poll) { if let Some(pending) = self.pending_share_decisions.remove(&repo_id) { let mut still_pending = Vec::new(); for PendingShareDecision { doc_id, - mut future, share_type, + mut future, } in pending { let waker = Arc::new(RepoWaker::ShareDecision( @@ -1697,8 +1920,8 @@ impl Repo { Poll::Pending => { still_pending.push(PendingShareDecision { doc_id, - future, share_type, + future, }); } Poll::Ready(Ok(res)) => { @@ -1722,53 +1945,70 @@ impl Repo { return; }; if share_decision == ShareDecision::Share { - match share_type { - ShareType::Announce | ShareType::Request => { - tracing::debug!(%doc, remote=%peer, "sharing document with remote"); - if let Some(pending_messages) = info.promote_pending_peer(&peer) { - tracing::trace!(remote=%peer, %doc, "we already had pending messages for this peer when announcing so we just wait to generate a sync message"); - for message in pending_messages { - self.pending_events.push_back(NetworkEvent::Sync { - from_repo_id: peer.clone(), - to_repo_id: our_id.clone(), - document_id: doc.clone(), - message, - }); - } - } else if let Some(message) = info.generate_first_sync_message(peer.clone()) - { - tracing::trace!(remote=%peer, %doc, "sending first sync message"); - let outgoing = NetworkMessage::Sync { - from_repo_id: our_id.clone(), - to_repo_id: peer.clone(), - document_id: doc.clone(), - message, - }; - self.pending_messages - .entry(peer.clone()) - .or_default() - .push_back(outgoing); - self.sinks_to_poll.insert(peer); + let message = info.authorize_peer(&peer); + self.documents_with_changes.push(doc.clone()); + let outgoing = message.map(|message| match share_type { + ShareType::Announce => { + tracing::trace!(remote=%peer, %doc, "announcing document to remote"); + NetworkMessage::Sync { + from_repo_id: our_id.clone(), + to_repo_id: peer.clone(), + document_id: doc.clone(), + message, + } + } + ShareType::Request => { + tracing::trace!(remote=%peer, %doc, "requesting document from remote"); + NetworkMessage::Request { + from_repo_id: our_id.clone(), + to_repo_id: peer.clone(), + document_id: doc.clone(), + message, } } ShareType::Synchronize => { tracing::debug!(%doc, remote=%peer, "synchronizing document with remote"); - if let Some(pending_messages) = info.promote_pending_peer(&peer) { - let events = - pending_messages - .into_iter() - .map(|message| NetworkEvent::Sync 
{ - from_repo_id: peer.clone(), - to_repo_id: our_id.clone(), - document_id: doc.clone(), - message, - }); - self.pending_events.extend(events); + NetworkMessage::Sync { + from_repo_id: our_id.clone(), + to_repo_id: peer.clone(), + document_id: doc.clone(), + message, } } + }); + if let Some(outgoing) = outgoing { + self.pending_messages + .entry(peer.clone()) + .or_default() + .push_back(outgoing); + self.sinks_to_poll.insert(peer); } } else { - tracing::debug!(?doc, ?peer, "refusing to share document with remote"); + match share_type { + ShareType::Request => { + tracing::debug!(%doc, remote=%peer, "refusing to request document from remote"); + } + ShareType::Announce => { + tracing::debug!(%doc, remote=%peer, "refusing to announce document to remote"); + } + ShareType::Synchronize => { + tracing::debug!(%doc, remote=%peer, "refusing to synchronize document with remote"); + } + } + if let Some(req) = self.requests.get_mut(&doc) { + tracing::trace!(request=?req, "marking request as unavailable due to rejected authorization"); + req.mark_unavailable(&peer); + if req.is_complete() { + let req = self.requests.remove(&doc).unwrap(); + Self::fail_request( + req, + &mut self.documents, + &mut self.pending_messages, + &mut self.sinks_to_poll, + self.repo_id.clone(), + ); + } + } } } } @@ -1789,7 +2029,7 @@ impl Repo { let handle = thread::spawn(move || { let _entered = span.entered(); loop { - self.collect_sharepolicy_responses(); + self.poll_sharepolicy_responses(); self.collect_network_events(); self.sync_documents(); self.process_outgoing_network_messages(); @@ -1797,6 +2037,9 @@ impl Repo { self.remove_unused_sync_states(); self.remove_unused_pending_messages(); self.gc_docs(); + if !self.share_decisions_to_poll.is_empty() { + continue; + } select! { recv(self.repo_receiver) -> repo_event => { if let Ok(event) = repo_event { @@ -1957,6 +2200,9 @@ impl Repo { share_type: ShareType, ) { let remote_repos = remote_repos.collect::>(); + if remote_repos.is_empty() { + return; + } match share_type { ShareType::Request => { tracing::debug!(remotes=?remote_repos, ?document_id, "checking if we should request this document from remotes"); @@ -1979,10 +2225,98 @@ impl Repo { .or_default() .push(PendingShareDecision { doc_id: document_id.clone(), - future, share_type, + future, }); share_decisions_to_poll.insert(repo_id.clone()); } } + + fn fail_request( + request: request::Request, + documents: &mut HashMap, + pending_messages: &mut HashMap>, + sinks_to_poll: &mut HashSet, + our_repo_id: RepoId, + ) { + tracing::debug!(?request, "request is complete"); + + match documents.entry(request.document_id().clone()) { + Entry::Occupied(entry) => { + if entry.get().state.is_bootstrapping() { + let info = entry.remove(); + if let DocState::Bootstrap { mut resolvers, .. } = info.state { + for mut resolver in resolvers.drain(..) 
{ + tracing::trace!("resolving local process waiting for request to None"); + resolver.resolve_fut(Ok(None)); + } + } + } + } + Entry::Vacant(_) => { + tracing::trace!("no local process is waiting for this request to complete"); + } + } + + let document_id = request.document_id().clone(); + for repo_id in request.unavailable() { + let outgoing = NetworkMessage::Unavailable { + from_repo_id: our_repo_id.clone(), + to_repo_id: repo_id.clone(), + document_id: document_id.clone(), + }; + pending_messages + .entry(repo_id.clone()) + .or_default() + .push_back(outgoing); + sinks_to_poll.insert(repo_id.clone()); + } + } + + fn dispatch_peer_conn_commands<I: IntoIterator<Item = PeerConnCommand>>( + &mut self, + document_id: &DocumentId, + commands: I, + ) { + for command in commands { + match command { + PeerConnCommand::RequestAuth(peer_id, share_type) => { + Self::enqueue_share_decisions( + std::iter::once(&peer_id), + &mut self.pending_share_decisions, + &mut self.share_decisions_to_poll, + self.share_policy.as_ref(), + document_id.clone(), + share_type, + ); + } + PeerConnCommand::SendRequest { message, to } => { + let outgoing = NetworkMessage::Request { + from_repo_id: self.repo_id.clone(), + to_repo_id: to.clone(), + document_id: document_id.clone(), + message, + }; + self.pending_messages + .entry(to.clone()) + .or_default() + .push_back(outgoing); + self.sinks_to_poll.insert(to); + } + PeerConnCommand::SendSyncMessage { message, to } => { + let outgoing = NetworkMessage::Sync { + from_repo_id: self.repo_id.clone(), + to_repo_id: to.clone(), + document_id: document_id.clone(), + message, + }; + self.pending_messages + .entry(to.clone()) + .or_default() + .push_back(outgoing); + self.sinks_to_poll.insert(to); + } + } + } + } } diff --git a/src/repo/request.rs b/src/repo/request.rs new file mode 100644 index 0000000..cdcc189 --- /dev/null +++ b/src/repo/request.rs @@ -0,0 +1,83 @@ +use std::collections::{HashMap, HashSet}; + +use automerge::sync::Message as SyncMessage; + +use crate::{DocumentId, RepoId}; + +#[derive(Debug)] +pub(super) struct Request { + document_id: DocumentId, + awaiting_response_from: HashSet<RepoId>, + awaiting_our_response: HashMap<RepoId, SyncMessage>, +} + +impl Request { + pub(super) fn new(doc_id: DocumentId) -> Self { + Request { + document_id: doc_id, + awaiting_response_from: HashSet::new(), + awaiting_our_response: HashMap::new(), + } + } + + pub(super) fn document_id(&self) -> &DocumentId { + &self.document_id + } + + pub(super) fn mark_unavailable(&mut self, repo_id: &RepoId) { + self.awaiting_our_response.remove(repo_id); + self.awaiting_response_from.remove(repo_id); + } + + pub(super) fn is_complete(&self) -> bool { + self.awaiting_response_from.is_empty() + } + + pub(super) fn initiate_local<'a, I: Iterator<Item = &'a RepoId>>( + &mut self, + connected_peers: I, + ) -> HashSet<RepoId> { + self.initiate_inner(None, connected_peers) + } + + pub(super) fn initiate_remote<'a, I: Iterator<Item = &'a RepoId>>( + &mut self, + from_peer: &RepoId, + request_sync_message: SyncMessage, + connected_peers: I, + ) -> HashSet<RepoId> { + self.initiate_inner(Some((from_peer, request_sync_message)), connected_peers) + } + + fn initiate_inner<'a, I: Iterator<Item = &'a RepoId>>( + &mut self, + from_repo_id: Option<(&RepoId, SyncMessage)>, + connected_peers: I, + ) -> HashSet<RepoId> { + if let Some((from_peer, initial_message)) = from_repo_id { + self.awaiting_our_response + .insert(from_peer.clone(), initial_message); + } + connected_peers + .filter(|remote| { + if self.awaiting_our_response.contains_key(remote) + || self.awaiting_response_from.contains(remote) + { + false + } else { +
self.awaiting_response_from.insert((*remote).clone()); + true + } + }) + .cloned() + .collect() + } + + pub(super) fn fulfilled(self) -> HashMap { + self.awaiting_our_response + } + + pub(super) fn unavailable(self) -> impl Iterator { + self.awaiting_our_response.into_keys() + } +} diff --git a/tests/interop/main.rs b/tests/interop/main.rs index f36b57c..eec0899 100644 --- a/tests/interop/main.rs +++ b/tests/interop/main.rs @@ -2,15 +2,14 @@ use std::{panic::catch_unwind, path::PathBuf, process::Child, thread::sleep, tim use automerge::{transaction::Transactable, ReadDoc}; use automerge_repo::{ConnDirection, Repo}; +use test_log::test; use test_utils::storage_utils::InMemoryStorage; -//use test_log::test; const INTEROP_SERVER_PATH: &str = "interop-test-server"; const PORT: u16 = 8099; #[test] fn interop_test() { - env_logger::init(); tracing::trace!("we're starting up"); let mut server_process = start_js_server(); let result = catch_unwind(|| sync_two_repos(PORT)); @@ -27,7 +26,7 @@ fn sync_two_repos(port: u16) { let runtime = tokio::runtime::Runtime::new().unwrap(); runtime.block_on(async { let storage1 = Box::::default(); - let repo1 = Repo::new(None, storage1); + let repo1 = Repo::new(Some("repo1".to_string()), storage1); let repo1_handle = repo1.run(); let (conn, _) = tokio_tungstenite::connect_async(format!("ws://localhost:{}", port)) .await @@ -37,13 +36,13 @@ fn sync_two_repos(port: u16) { .connect_tungstenite(conn, ConnDirection::Outgoing) .await .expect("error connecting connection 1"); - tracing::trace!("connecting conn1"); + tokio::spawn(async { if let Err(e) = conn1_driver.await { tracing::error!("Error running repo 1 connection: {}", e); } + tracing::trace!("conn1 finished"); }); - tracing::trace!("connected conn1"); let doc_handle_repo1 = repo1_handle.new_document().await; doc_handle_repo1 @@ -55,8 +54,10 @@ fn sync_two_repos(port: u16) { }) .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + let storage2 = Box::::default(); - let repo2 = Repo::new(None, storage2); + let repo2 = Repo::new(Some("repo2".to_string()), storage2); let repo2_handle = repo2.run(); let (conn2, _) = tokio_tungstenite::connect_async(format!("ws://localhost:{}", port)) @@ -71,10 +72,9 @@ fn sync_two_repos(port: u16) { if let Err(e) = conn2_driver.await { tracing::error!("Error running repo 2 connection: {}", e); } + tracing::trace!("conn2 finished"); }); - tokio::time::sleep(Duration::from_millis(100)).await; - tracing::info!("Requesting"); //tokio::time::sleep(Duration::from_secs(1)).await; let doc_handle_repo2 = repo2_handle diff --git a/tests/network/document_request.rs b/tests/network/document_request.rs index 870ad09..1248b86 100644 --- a/tests/network/document_request.rs +++ b/tests/network/document_request.rs @@ -3,11 +3,15 @@ extern crate test_utils; use std::time::Duration; use automerge::{transaction::Transactable, ReadDoc}; -use automerge_repo::{DocumentId, Repo, RepoHandle, RepoId}; +use automerge_repo::{ + share_policy::ShareDecision, DocumentId, Repo, RepoHandle, RepoId, SharePolicy, + SharePolicyError, +}; +use futures::{future::BoxFuture, FutureExt}; use test_log::test; use test_utils::storage_utils::{InMemoryStorage, SimpleStorage}; -use crate::tincans::connect_repos; +use crate::tincans::{connect_repos, connect_to_nowhere}; #[test(tokio::test)] async fn test_requesting_document_connected_peers() { @@ -44,17 +48,16 @@ async fn test_requesting_document_connected_peers() { tokio::spawn(repo_handle_2.request_document(document_handle_1.document_id())); let _load = 
repo_handle_2.load(document_handle_1.document_id()); - assert_eq!( - doc_handle_future - .await - .expect("load future timed out") - .unwrap() - .expect("document should be found") - .document_id(), - document_handle_1.document_id() - ); + let doc_handle = tokio::time::timeout(Duration::from_millis(100), doc_handle_future) + .await + .expect("load timed out") + .expect("doc handle spawn failed") + .expect("doc handle future failed") + .expect("doc handle should exist"); + + assert_eq!(doc_handle.document_id(), document_handle_1.document_id()); - let _ = tokio::task::spawn(async move { + let storage_complete = tokio::task::spawn(async move { // Check that the document has been saved in storage. // TODO: replace the loop with an async notification mechanism. loop { @@ -63,8 +66,11 @@ async fn test_requesting_document_connected_peers() { } tokio::time::sleep(Duration::from_millis(100)).await; } - }) - .await; + }); + tokio::time::timeout(Duration::from_millis(100), storage_complete) + .await + .expect("storage complete timed out") + .expect("storage complete spawn failed"); // Stop the repos. tokio::task::spawn_blocking(|| { @@ -378,6 +384,223 @@ async fn request_doc_which_is_not_shared_does_not_announce() { assert!(doc_handle.is_none()); } +struct DontAnnounce; + +impl SharePolicy for DontAnnounce { + fn should_announce( + &self, + _doc_id: &DocumentId, + _with_peer: &RepoId, + ) -> BoxFuture<'static, Result> { + futures::future::ready(Ok(ShareDecision::DontShare)).boxed() + } + + fn should_sync( + &self, + _document_id: &DocumentId, + _with_peer: &RepoId, + ) -> BoxFuture<'static, Result> { + futures::future::ready(Ok(ShareDecision::Share)).boxed() + } + + fn should_request( + &self, + _document_id: &DocumentId, + _from_peer: &RepoId, + ) -> BoxFuture<'static, Result> { + futures::future::ready(Ok(ShareDecision::Share)).boxed() + } +} + +#[test(tokio::test)] +async fn request_document_transitive() { + // Test that requesting a document from a peer who doesn't have that document but who is + // connected to another peer that does have the document eventually resolves + + let repo_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage)); + let repo_2 = Repo::new(Some("repo2".to_string()), Box::new(SimpleStorage)); + let repo_3 = Repo::new(Some("repo3".to_string()), Box::new(SimpleStorage)) + .with_share_policy(Box::new(DontAnnounce)); + + let repo_handle_1 = repo_1.run(); + let repo_handle_2 = repo_2.run(); + let repo_handle_3 = repo_3.run(); + + let document_id = create_doc_with_contents(&repo_handle_3, "peer", "repo3").await; + + connect_repos(&repo_handle_1, &repo_handle_2); + connect_repos(&repo_handle_2, &repo_handle_3); + + let doc_handle = match tokio::time::timeout( + Duration::from_millis(100), + repo_handle_1.request_document(document_id), + ) + .await + { + Ok(d) => d.unwrap(), + Err(_e) => { + panic!("Request timed out"); + } + }; + + //tokio::time::sleep(Duration::from_millis(100)).await; + + doc_handle.expect("doc should exist").with_doc(|doc| { + let val = doc.get(&automerge::ROOT, "peer").unwrap(); + assert_eq!(val.unwrap().0.into_string().unwrap(), "repo3"); + }); + + tokio::task::spawn_blocking(|| { + repo_handle_1.stop().unwrap(); + repo_handle_2.stop().unwrap(); + repo_handle_3.stop().unwrap(); + }) + .await + .unwrap(); +} + +#[test(tokio::test)] +async fn request_document_which_no_peer_has_returns_unavailable() { + let repo_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage)); + let repo_2 = Repo::new(Some("repo2".to_string()), 
Box::new(SimpleStorage)); + let repo_3 = Repo::new(Some("repo3".to_string()), Box::new(SimpleStorage)); + + let repo_handle_1 = repo_1.run(); + let repo_handle_2 = repo_2.run(); + let repo_handle_3 = repo_3.run(); + + connect_repos(&repo_handle_1, &repo_handle_2); + connect_repos(&repo_handle_2, &repo_handle_3); + + let document_id = DocumentId::random(); + + let doc_handle = match tokio::time::timeout( + Duration::from_millis(1000), + repo_handle_1.request_document(document_id), + ) + .await + { + Ok(d) => d.unwrap(), + Err(_e) => { + panic!("Request timed out"); + } + }; + + assert!(doc_handle.is_none()); + + tokio::task::spawn_blocking(|| { + repo_handle_1.stop().unwrap(); + repo_handle_2.stop().unwrap(); + repo_handle_3.stop().unwrap(); + }) + .await + .unwrap(); +} + +#[test(tokio::test)] +async fn request_document_which_no_peer_has_but_peer_appears_after_request_starts_resolves_to_some() +{ + let repo_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage)); + let repo_2 = Repo::new(Some("repo2".to_string()), Box::new(SimpleStorage)); + let repo_3 = Repo::new(Some("repo3".to_string()), Box::new(SimpleStorage)) + .with_share_policy(Box::new(DontAnnounce)); + + let repo_handle_1 = repo_1.run(); + let repo_handle_2 = repo_2.run(); + let repo_handle_3 = repo_3.run(); + + // note: repo 3 is not connected + connect_repos(&repo_handle_1, &repo_handle_2); + // This connection will never respond and so we will hang around waiting until someone has the + // document + connect_to_nowhere(&repo_handle_1); + + let document_id = create_doc_with_contents(&repo_handle_3, "peer", "repo3").await; + + let doc_handle_fut = repo_handle_1.request_document(document_id); + + // wait a little bit + tokio::time::sleep(Duration::from_millis(100)).await; + + //connect repo3 + connect_repos(&repo_handle_1, &repo_handle_3); + + let handle = match tokio::time::timeout(Duration::from_millis(100), doc_handle_fut).await { + Ok(d) => d.unwrap(), + Err(_e) => { + panic!("Request timed out"); + } + }; + + handle.expect("doc should exist").with_doc(|doc| { + let val = doc.get(&automerge::ROOT, "peer").unwrap(); + assert_eq!(val.unwrap().0.into_string().unwrap(), "repo3"); + }); + + tokio::task::spawn_blocking(|| { + repo_handle_1.stop().unwrap(); + repo_handle_2.stop().unwrap(); + repo_handle_3.stop().unwrap(); + }) + .await + .unwrap(); +} + +#[test(tokio::test)] +async fn request_document_which_no_peer_has_but_transitive_peer_appears_after_request_starts_resolves_to_some( +) { + let repo_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage)); + let repo_2 = Repo::new(Some("repo2".to_string()), Box::new(SimpleStorage)); + let repo_3 = Repo::new(Some("repo3".to_string()), Box::new(SimpleStorage)) + .with_share_policy(Box::new(DontAnnounce)); + + let repo_handle_1 = repo_1.run(); + let repo_handle_2 = repo_2.run(); + let repo_handle_3 = repo_3.run(); + + // note: repo 3 is not connected + connect_repos(&repo_handle_1, &repo_handle_2); + // This connection will never respond and so we will hang around waiting until someone has the + // document + connect_to_nowhere(&repo_handle_2); + + let document_id = create_doc_with_contents(&repo_handle_3, "peer", "repo3").await; + + let doc_handle_fut = repo_handle_1.request_document(document_id); + + // wait a little bit + tokio::time::sleep(Duration::from_millis(100)).await; + + //connect repo3 + connect_repos(&repo_handle_2, &repo_handle_3); + + let handle = match tokio::time::timeout(Duration::from_millis(1000), doc_handle_fut).await { + Ok(d) => d.unwrap(), + 
Err(_e) => { + panic!("Request timed out"); + } + }; + + let handle = handle.expect("doc should exist"); + + // wait for the doc to sync up + // TODO: add an API for saying "wait until we're in sync with " + tokio::time::sleep(Duration::from_millis(100)).await; + + handle.with_doc(|doc| { + let val = doc.get(&automerge::ROOT, "peer").unwrap(); + assert_eq!(val.unwrap().0.into_string().unwrap(), "repo3"); + }); + + tokio::task::spawn_blocking(|| { + repo_handle_1.stop().unwrap(); + repo_handle_2.stop().unwrap(); + repo_handle_3.stop().unwrap(); + }) + .await + .unwrap(); +} + async fn create_doc_with_contents(handle: &RepoHandle, key: &str, value: &str) -> DocumentId { let document_handle = handle.new_document().await; document_handle.with_doc_mut(|doc| { diff --git a/tests/network/tincans.rs b/tests/network/tincans.rs index 209a3b6..599c479 100644 --- a/tests/network/tincans.rs +++ b/tests/network/tincans.rs @@ -1,6 +1,6 @@ use std::sync::{atomic::AtomicBool, Arc}; -use automerge_repo::{NetworkError, RepoHandle, RepoMessage}; +use automerge_repo::{NetworkError, RepoHandle, RepoId, RepoMessage}; use futures::{Sink, SinkExt, Stream, StreamExt}; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::PollSender; @@ -115,6 +115,13 @@ pub(crate) fn connect_repos(left: &RepoHandle, right: &RepoHandle) { right.new_remote_repo(left.get_repo_id().clone(), right_recv, right_send); } +pub(crate) fn connect_to_nowhere(handle: &RepoHandle) { + let TinCan { send, recv, .. } = tincan_to_nowhere(); + let random_suffix = rand::random::(); + let repo_id = RepoId::from(format!("nowhere-{}", random_suffix).as_str()); + handle.new_remote_repo(repo_id, recv, send); +} + /// A wrapper around a `Sink` which records whether `poll_close` has ever been called struct RecordCloseSink { inner: S,