diff --git a/client/network/src/protocol/sync.rs b/client/network/src/protocol/sync.rs index d98c0d2c04abe..bde7cc87fafa7 100644 --- a/client/network/src/protocol/sync.rs +++ b/client/network/src/protocol/sync.rs @@ -60,7 +60,10 @@ use std::{ fmt, ops::Range, pin::Pin, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, }; use warp::{WarpProofRequest, WarpSync, WarpSyncProvider}; @@ -234,10 +237,14 @@ pub struct ChainSync { import_existing: bool, /// Gap download process. gap_sync: Option>, + /// A lock to make sure only one download is in progress. + sync_lock: Lock, + /// A lock to make sure only one warp proof download is in progress. + warp_sync_lock: Lock, } /// All the data we have about a Peer that we are trying to sync with -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct PeerSync { /// Peer id of this peer. pub peer_id: PeerId, @@ -284,11 +291,44 @@ struct ForkTarget { peers: HashSet, } +/// An exclusive lock. +#[derive(Debug, Default)] +struct Lock(Arc); + +impl Lock { + fn lock(&self) -> Option { + if self.0.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_ok() { + Some(LockGuard(self.0.clone())) + } else { + None + } + } +} + +/// A guard for an actively locked lock. +/// +/// Will unlock the lock on drop. +#[derive(Debug)] +pub struct LockGuard(Arc); + +impl PartialEq for LockGuard { + fn eq(&self, rhs: &LockGuard) -> bool { + Arc::ptr_eq(&self.0, &rhs.0) + } +} + +impl Eq for LockGuard {} +impl Drop for LockGuard { + fn drop(&mut self) { + self.0.store(false, Ordering::SeqCst); + } +} + /// The state of syncing between a Peer and ourselves. /// /// Generally two categories, "busy" or `Available`. If busy, the enum /// defines what we are busy with. -#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[derive(Eq, PartialEq, Debug)] pub enum PeerSyncState { /// Available for sync requests. Available, @@ -303,9 +343,9 @@ pub enum PeerSyncState { /// Downloading justification for given block hash. DownloadingJustification(B::Hash), /// Downloading state. - DownloadingState, + DownloadingState(LockGuard), /// Downloading warp proof. - DownloadingWarpProof, + DownloadingWarpProof(LockGuard), /// Actively downloading block history after warp sync. DownloadingGap(NumberFor), } @@ -560,6 +600,8 @@ impl ChainSync { warp_sync_provider, import_existing: false, gap_sync: None, + sync_lock: Default::default(), + warp_sync_lock: Default::default(), }; sync.reset_sync_start_point()?; Ok(sync) @@ -994,17 +1036,15 @@ impl ChainSync { /// Get a state request, if any. pub fn state_request(&mut self) -> Option<(PeerId, StateRequest)> { - if self.peers.iter().any(|(_, peer)| peer.state == PeerSyncState::DownloadingState) { - // Only one pending state request is allowed. - return None - } + // Only one pending state request is allowed. + let lock = self.sync_lock.lock()?; if let Some(sync) = &self.state_sync { if sync.is_complete() { return None } for (id, peer) in self.peers.iter_mut() { if peer.state.is_available() && peer.common_number >= sync.target_block_num() { - peer.state = PeerSyncState::DownloadingState; + peer.state = PeerSyncState::DownloadingState(lock); let request = sync.next_request(); trace!(target: "sync", "New StateRequest for {}: {:?}", id, request); return Some((*id, request)) @@ -1021,7 +1061,7 @@ impl ChainSync { for (id, peer) in self.peers.iter_mut() { if peer.state.is_available() && peer.best_number >= target { trace!(target: "sync", "New StateRequest for {}: {:?}", id, request); - peer.state = PeerSyncState::DownloadingState; + peer.state = PeerSyncState::DownloadingState(lock); return Some((*id, request)) } } @@ -1032,14 +1072,8 @@ impl ChainSync { /// Get a warp sync request, if any. pub fn warp_sync_request(&mut self) -> Option<(PeerId, WarpProofRequest)> { - if self - .peers - .iter() - .any(|(_, peer)| peer.state == PeerSyncState::DownloadingWarpProof) - { - // Only one pending state request is allowed. - return None - } + // Only one pending state request is allowed. + let lock = self.warp_sync_lock.lock()?; if let Some(sync) = &self.warp_sync { if sync.is_complete() { return None @@ -1053,7 +1087,7 @@ impl ChainSync { for (id, peer) in self.peers.iter_mut() { if peer.state.is_available() && peer.best_number >= median { trace!(target: "sync", "New WarpProofRequest for {}", id); - peer.state = PeerSyncState::DownloadingWarpProof; + peer.state = PeerSyncState::DownloadingWarpProof(lock); return Some((*id, request)) } } @@ -1261,8 +1295,8 @@ impl ChainSync { }, PeerSyncState::Available | PeerSyncState::DownloadingJustification(..) | - PeerSyncState::DownloadingState | - PeerSyncState::DownloadingWarpProof => Vec::new(), + PeerSyncState::DownloadingState(..) | + PeerSyncState::DownloadingWarpProof(..) => Vec::new(), } } else { // When request.is_none() this is a block announcement. Just accept blocks. @@ -1304,7 +1338,7 @@ impl ChainSync { response: StateResponse, ) -> Result, BadPeer> { if let Some(peer) = self.peers.get_mut(&who) { - if let PeerSyncState::DownloadingState = peer.state { + if let PeerSyncState::DownloadingState(..) = peer.state { peer.state = PeerSyncState::Available; } } @@ -1366,7 +1400,7 @@ impl ChainSync { response: warp::EncodedProof, ) -> Result<(), BadPeer> { if let Some(peer) = self.peers.get_mut(&who) { - if let PeerSyncState::DownloadingWarpProof = peer.state { + if let PeerSyncState::DownloadingWarpProof(..) = peer.state { peer.state = PeerSyncState::Available; } } diff --git a/client/network/src/protocol/sync/extra_requests.rs b/client/network/src/protocol/sync/extra_requests.rs index d0bfebab66010..2386907a459fc 100644 --- a/client/network/src/protocol/sync/extra_requests.rs +++ b/client/network/src/protocol/sync/extra_requests.rs @@ -547,9 +547,38 @@ mod tests { // Some Arbitrary instances to allow easy construction of random peer sets: - #[derive(Debug, Clone)] + #[derive(Debug)] struct ArbitraryPeerSyncState(PeerSyncState); + fn clone_peer_sync_state(state: &PeerSyncState) -> PeerSyncState { + match state { + PeerSyncState::Available => PeerSyncState::Available, + PeerSyncState::DownloadingNew(ref block_number) => + PeerSyncState::DownloadingNew(block_number.clone()), + PeerSyncState::DownloadingStale(ref hash) => + PeerSyncState::DownloadingStale(hash.clone()), + PeerSyncState::DownloadingJustification(ref hash) => + PeerSyncState::DownloadingJustification(hash.clone()), + state => unimplemented!("unsupported peer sync state: {:?}", state), + } + } + + fn clone_peer_sync(peer_sync: &PeerSync) -> PeerSync { + PeerSync { + peer_id: peer_sync.peer_id.clone(), + common_number: peer_sync.common_number.clone(), + best_hash: peer_sync.best_hash.clone(), + best_number: peer_sync.best_number.clone(), + state: clone_peer_sync_state(&peer_sync.state), + } + } + + impl Clone for ArbitraryPeerSyncState { + fn clone(&self) -> Self { + Self(clone_peer_sync_state(&self.0)) + } + } + impl Arbitrary for ArbitraryPeerSyncState { fn arbitrary(g: &mut Gen) -> Self { let s = match u8::arbitrary(g) % 4 { @@ -563,9 +592,15 @@ mod tests { } } - #[derive(Debug, Clone)] + #[derive(Debug)] struct ArbitraryPeerSync(PeerSync); + impl Clone for ArbitraryPeerSync { + fn clone(&self) -> Self { + ArbitraryPeerSync(clone_peer_sync(&self.0)) + } + } + impl Arbitrary for ArbitraryPeerSync { fn arbitrary(g: &mut Gen) -> Self { let ps = PeerSync { @@ -579,9 +614,17 @@ mod tests { } } - #[derive(Debug, Clone)] + #[derive(Debug)] struct ArbitraryPeers(HashMap>); + impl Clone for ArbitraryPeers { + fn clone(&self) -> Self { + ArbitraryPeers( + self.0.iter().map(|(id, sync)| (id.clone(), clone_peer_sync(sync))).collect(), + ) + } + } + impl Arbitrary for ArbitraryPeers { fn arbitrary(g: &mut Gen) -> Self { let mut peers = HashMap::with_capacity(g.size());