From c84d0263a0ba0294e571685de46e61960e3db9a1 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Mon, 16 Jan 2023 10:33:17 +0000 Subject: [PATCH 1/8] Have `KeyIterator` clone the `prefix` it receives --- client/api/src/backend.rs | 42 ++++++++++++++++++----------- client/service/src/client/client.rs | 12 ++++----- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/client/api/src/backend.rs b/client/api/src/backend.rs index 79cc0d7a16bcc..ed89f71645835 100644 --- a/client/api/src/backend.rs +++ b/client/api/src/backend.rs @@ -303,32 +303,44 @@ pub trait AuxStore { } /// An `Iterator` that iterates keys in a given block under a prefix. -pub struct KeyIterator<'a, State, Block> { +pub struct KeyIterator { state: State, child_storage: Option, - prefix: Option<&'a StorageKey>, + prefix: Option, current_key: Vec, _phantom: PhantomData, } -impl<'a, State, Block> KeyIterator<'a, State, Block> { +impl KeyIterator { /// create a KeyIterator instance - pub fn new(state: State, prefix: Option<&'a StorageKey>, current_key: Vec) -> Self { - Self { state, child_storage: None, prefix, current_key, _phantom: PhantomData } + pub fn new(state: State, prefix: Option<&StorageKey>, current_key: Vec) -> Self { + Self { + state, + child_storage: None, + prefix: prefix.map(|prefix| prefix.clone()), + current_key, + _phantom: PhantomData, + } } /// Create a `KeyIterator` instance for a child storage. pub fn new_child( state: State, child_info: ChildInfo, - prefix: Option<&'a StorageKey>, + prefix: Option<&StorageKey>, current_key: Vec, ) -> Self { - Self { state, child_storage: Some(child_info), prefix, current_key, _phantom: PhantomData } + Self { + state, + child_storage: Some(child_info), + prefix: prefix.map(|prefix| prefix.clone()), + current_key, + _phantom: PhantomData, + } } } -impl<'a, State, Block> Iterator for KeyIterator<'a, State, Block> +impl Iterator for KeyIterator where Block: BlockT, State: StateBackend>, @@ -344,7 +356,7 @@ where .ok() .flatten()?; // this terminates the iterator the first time it fails. - if let Some(prefix) = self.prefix { + if let Some(ref prefix) = self.prefix { if !next_key.starts_with(&prefix.0[..]) { return None } @@ -387,12 +399,12 @@ pub trait StorageProvider> { /// Given a block's `Hash` and a key prefix, return a `KeyIterator` iterates matching storage /// keys in that block. - fn storage_keys_iter<'a>( + fn storage_keys_iter( &self, hash: Block::Hash, - prefix: Option<&'a StorageKey>, + prefix: Option<&StorageKey>, start_key: Option<&StorageKey>, - ) -> sp_blockchain::Result>; + ) -> sp_blockchain::Result>; /// Given a block's `Hash`, a key and a child storage key, return the value under the key in /// that block. @@ -414,13 +426,13 @@ pub trait StorageProvider> { /// Given a block's `Hash` and a key `prefix` and a child storage key, /// return a `KeyIterator` that iterates matching storage keys in that block. - fn child_storage_keys_iter<'a>( + fn child_storage_keys_iter( &self, hash: Block::Hash, child_info: ChildInfo, - prefix: Option<&'a StorageKey>, + prefix: Option<&StorageKey>, start_key: Option<&StorageKey>, - ) -> sp_blockchain::Result>; + ) -> sp_blockchain::Result>; /// Given a block's `Hash`, a key and a child storage key, return the hash under the key in that /// block. 
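Illustrative sketch (not taken from the diff above): once `KeyIterator` owns its prefix, a caller can keep the returned iterator around without holding a borrow of the original `StorageKey`. The `client`, `block` and `key` bindings below are assumptions for the sketch; the loop mirrors the call site added to `state_full.rs` later in this series.

    // Assumed bindings, not part of the patch:
    // `client` implements `StorageProvider`, `block` is a `Block::Hash`, `key` is a `StorageKey`.
    let iter = client.storage_keys_iter(block, Some(&key), None)?;
    let mut total: u64 = 0;
    for storage_key in iter {
        // Only keys starting with the requested prefix are yielded; sum the sizes of their values.
        let value = client.storage(block, &storage_key).ok().flatten().unwrap_or_default();
        total += value.0.len() as u64;
    }
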
diff --git a/client/service/src/client/client.rs b/client/service/src/client/client.rs index 18012fc1931fe..582b462846baf 100644 --- a/client/service/src/client/client.rs +++ b/client/service/src/client/client.rs @@ -1432,24 +1432,24 @@ where Ok(keys) } - fn storage_keys_iter<'a>( + fn storage_keys_iter( &self, hash: ::Hash, - prefix: Option<&'a StorageKey>, + prefix: Option<&StorageKey>, start_key: Option<&StorageKey>, - ) -> sp_blockchain::Result> { + ) -> sp_blockchain::Result> { let state = self.state_at(hash)?; let start_key = start_key.or(prefix).map(|key| key.0.clone()).unwrap_or_else(Vec::new); Ok(KeyIterator::new(state, prefix, start_key)) } - fn child_storage_keys_iter<'a>( + fn child_storage_keys_iter( &self, hash: ::Hash, child_info: ChildInfo, - prefix: Option<&'a StorageKey>, + prefix: Option<&StorageKey>, start_key: Option<&StorageKey>, - ) -> sp_blockchain::Result> { + ) -> sp_blockchain::Result> { let state = self.state_at(hash)?; let start_key = start_key.or(prefix).map(|key| key.0.clone()).unwrap_or_else(Vec::new); Ok(KeyIterator::new_child(state, child_info, prefix, start_key)) From 83cd8e8d600a15416fe8c7648c314606de236a4e Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Mon, 16 Jan 2023 13:57:28 +0000 Subject: [PATCH 2/8] Stream keys in `storage_size` RPC and add a runtime limit --- client/rpc-api/src/state/mod.rs | 4 +- client/rpc/Cargo.toml | 4 +- client/rpc/src/state/mod.rs | 19 +++- client/rpc/src/state/state_full.rs | 67 ++++++++++----- client/rpc/src/state/tests.rs | 7 +- client/rpc/src/state/utils.rs | 134 +++++++++++++++++++++++++++++ 6 files changed, 205 insertions(+), 30 deletions(-) create mode 100644 client/rpc/src/state/utils.rs diff --git a/client/rpc-api/src/state/mod.rs b/client/rpc-api/src/state/mod.rs index 40e208c2eba8d..323e6ad1d41a3 100644 --- a/client/rpc-api/src/state/mod.rs +++ b/client/rpc-api/src/state/mod.rs @@ -71,8 +71,8 @@ pub trait StateApi { fn storage_hash(&self, key: StorageKey, hash: Option) -> RpcResult>; /// Returns the size of a storage entry at a block's state. - #[method(name = "state_getStorageSize", aliases = ["state_getStorageSizeAt"], blocking)] - fn storage_size(&self, key: StorageKey, hash: Option) -> RpcResult>; + #[method(name = "state_getStorageSize", aliases = ["state_getStorageSizeAt"])] + async fn storage_size(&self, key: StorageKey, hash: Option) -> RpcResult>; /// Returns the runtime metadata as an opaque blob. #[method(name = "state_getMetadata", blocking)] diff --git a/client/rpc/Cargo.toml b/client/rpc/Cargo.toml index d97170ddf42e6..c4cc2acc2be0f 100644 --- a/client/rpc/Cargo.toml +++ b/client/rpc/Cargo.toml @@ -36,7 +36,7 @@ sp-runtime = { version = "7.0.0", path = "../../primitives/runtime" } sp-session = { version = "4.0.0-dev", path = "../../primitives/session" } sp-version = { version = "5.0.0", path = "../../primitives/version" } -tokio = { version = "1.22.0", optional = true } +tokio = { version = "1.22.0" } [dev-dependencies] env_logger = "0.9" @@ -51,4 +51,4 @@ sp-io = { version = "7.0.0", path = "../../primitives/io" } substrate-test-runtime-client = { version = "2.0.0", path = "../../test-utils/runtime/client" } [features] -test-helpers = ["tokio"] +test-helpers = [] diff --git a/client/rpc/src/state/mod.rs b/client/rpc/src/state/mod.rs index fd802e5a80391..9ba4c8218dd79 100644 --- a/client/rpc/src/state/mod.rs +++ b/client/rpc/src/state/mod.rs @@ -19,6 +19,7 @@ //! Substrate state API. 
mod state_full; +mod utils; #[cfg(test)] mod tests; @@ -28,7 +29,7 @@ use std::sync::Arc; use crate::SubscriptionTaskExecutor; use jsonrpsee::{ - core::{server::rpc_module::SubscriptionSink, Error as JsonRpseeError, RpcResult}, + core::{async_trait, server::rpc_module::SubscriptionSink, Error as JsonRpseeError, RpcResult}, types::SubscriptionResult, }; @@ -53,6 +54,7 @@ use sp_blockchain::{HeaderBackend, HeaderMetadata}; const STORAGE_KEYS_PAGED_MAX_COUNT: u32 = 1000; /// State backend API. +#[async_trait] pub trait StateBackend: Send + Sync + 'static where Block: BlockT + 'static, @@ -107,10 +109,11 @@ where /// /// If data is available at `key`, it is returned. Else, the sum of values who's key has `key` /// prefix is returned, i.e. all the storage (double) maps that have this prefix. - fn storage_size( + async fn storage_size( &self, block: Option, key: StorageKey, + deny_unsafe: DenyUnsafe, ) -> Result, Error>; /// Returns the runtime metadata as an opaque blob. @@ -202,6 +205,7 @@ pub struct State { deny_unsafe: DenyUnsafe, } +#[async_trait] impl StateApiServer for State where Block: BlockT + 'static, @@ -262,8 +266,15 @@ where self.backend.storage_hash(block, key).map_err(Into::into) } - fn storage_size(&self, key: StorageKey, block: Option) -> RpcResult> { - self.backend.storage_size(block, key).map_err(Into::into) + async fn storage_size( + &self, + key: StorageKey, + block: Option, + ) -> RpcResult> { + self.backend + .storage_size(block, key, self.deny_unsafe) + .await + .map_err(Into::into) } fn metadata(&self, block: Option) -> RpcResult { diff --git a/client/rpc/src/state/state_full.rs b/client/rpc/src/state/state_full.rs index 58aeac66e5c79..38f1367e68e79 100644 --- a/client/rpc/src/state/state_full.rs +++ b/client/rpc/src/state/state_full.rs @@ -18,17 +18,20 @@ //! State API backend for full nodes. -use std::{collections::HashMap, marker::PhantomData, sync::Arc}; +use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration}; use super::{ client_err, error::{Error, Result}, ChildStateBackend, StateBackend, }; -use crate::SubscriptionTaskExecutor; +use crate::{DenyUnsafe, SubscriptionTaskExecutor}; use futures::{future, stream, FutureExt, StreamExt}; -use jsonrpsee::{core::Error as JsonRpseeError, SubscriptionSink}; +use jsonrpsee::{ + core::{async_trait, Error as JsonRpseeError}, + SubscriptionSink, +}; use sc_client_api::{ Backend, BlockBackend, BlockchainEvents, CallExecutor, ExecutorProvider, ProofProvider, StorageProvider, @@ -48,6 +51,9 @@ use sp_core::{ use sp_runtime::{generic::BlockId, traits::Block as BlockT}; use sp_version::RuntimeVersion; +/// The maximum time allowed for an RPC call when running without unsafe RPC enabled. +const MAXIMUM_SAFE_RPC_CALL_TIMEOUT: Duration = Duration::from_secs(30); + /// Ranges to query in state_queryStorage. struct QueryStorageRange { /// Hashes of all the blocks in the range. 
@@ -166,6 +172,7 @@ where } } +#[async_trait] impl StateBackend for FullState where Block: BlockT + 'static, @@ -251,33 +258,53 @@ where .map_err(client_err) } - fn storage_size( + async fn storage_size( &self, block: Option, key: StorageKey, + deny_unsafe: DenyUnsafe, ) -> std::result::Result, Error> { let block = match self.block_or_best(block) { Ok(b) => b, Err(e) => return Err(client_err(e)), }; - match self.client.storage(block, &key) { - Ok(Some(d)) => return Ok(Some(d.0.len() as u64)), - Err(e) => return Err(client_err(e)), - Ok(None) => {}, - } + let client = self.client.clone(); + let timeout = match deny_unsafe { + DenyUnsafe::Yes => Some(MAXIMUM_SAFE_RPC_CALL_TIMEOUT), + DenyUnsafe::No => None, + }; - self.client - .storage_pairs(block, &key) - .map(|kv| { - let item_sum = kv.iter().map(|(_, v)| v.0.len() as u64).sum::(); - if item_sum > 0 { - Some(item_sum) - } else { - None - } - }) - .map_err(client_err) + super::utils::spawn_blocking_with_timeout(timeout, move |is_cancelled| { + // Does the key point to a concrete entry in the database? + match client.storage(block, &key) { + Ok(Some(d)) => return Ok(Ok(Some(d.0.len() as u64))), + Err(e) => return Ok(Err(client_err(e))), + Ok(None) => {}, + } + + // The key doesn't point to anything, so it's probably a prefix. + let iter = match client.storage_keys_iter(block, Some(&key), None).map_err(client_err) { + Ok(iter) => iter, + Err(e) => return Ok(Err(e)), + }; + + let mut sum = 0; + for storage_key in iter { + let value = client.storage(block, &storage_key).ok().flatten().unwrap_or_default(); + sum += value.0.len() as u64; + + is_cancelled.check_if_cancelled()?; + } + + if sum > 0 { + Ok(Ok(Some(sum))) + } else { + Ok(Ok(None)) + } + }) + .await + .map_err(|error| Error::Client(Box::new(error)))? } fn storage_hash( diff --git a/client/rpc/src/state/tests.rs b/client/rpc/src/state/tests.rs index 3ef59e5ca9a7c..fe8bdf0ac2da8 100644 --- a/client/rpc/src/state/tests.rs +++ b/client/rpc/src/state/tests.rs @@ -70,9 +70,12 @@ async fn should_return_storage() { client.storage_hash(key.clone(), Some(genesis_hash).into()).map(|x| x.is_some()), Ok(true) ); - assert_eq!(client.storage_size(key.clone(), None).unwrap().unwrap() as usize, VALUE.len(),); assert_eq!( - client.storage_size(StorageKey(b":map".to_vec()), None).unwrap().unwrap() as usize, + client.storage_size(key.clone(), None).await.unwrap().unwrap() as usize, + VALUE.len(), + ); + assert_eq!( + client.storage_size(StorageKey(b":map".to_vec()), None).await.unwrap().unwrap() as usize, 2 + 3, ); assert_eq!( diff --git a/client/rpc/src/state/utils.rs b/client/rpc/src/state/utils.rs new file mode 100644 index 0000000000000..d075f1f42fa7c --- /dev/null +++ b/client/rpc/src/state/utils.rs @@ -0,0 +1,134 @@ +// This file is part of Substrate. + +// Copyright (C) 2017-2023 Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0 + +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . 
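+
+//! Utilities used by the state RPC backend.
+//!
+//! The main entry point here is `spawn_blocking_with_timeout`, which runs a blocking
+//! closure on tokio's blocking thread pool and, when a timeout is given, flips a shared
+//! flag once that timeout elapses so the closure can bail out cooperatively via
+//! `IsCancelled::check_if_cancelled`. A rough usage sketch (the `expensive_items`
+//! iterator is hypothetical; see the test at the bottom of this file for a real call):
+//!
+//! ```ignore
+//! let total = spawn_blocking_with_timeout(Some(Duration::from_secs(30)), |is_cancelled| {
+//!     let mut sum = 0u64;
+//!     for item in expensive_items {
+//!         // Returns `Err(Cancelled)` once the timeout has elapsed.
+//!         is_cancelled.check_if_cancelled()?;
+//!         sum += item;
+//!     }
+//!     Ok(sum)
+//! })
+//! .await;
+//! ```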
+ +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +/// An error signifying that a task has been cancelled due to a timeout. +#[derive(Debug)] +pub struct Cancelled; + +impl std::error::Error for Cancelled {} +impl std::fmt::Display for Cancelled { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.write_str("task has been running too long") + } +} + +/// A handle which can be used to check whether the task has been cancelled. +#[repr(transparent)] +pub struct IsCancelled(Arc); + +impl IsCancelled { + #[must_use] + pub fn check_if_cancelled(&self) -> std::result::Result<(), Cancelled> { + if self.0.load(Ordering::Relaxed) { + Err(Cancelled) + } else { + Ok(()) + } + } +} + +/// An error for a task which either panicked, or has been cancelled. +#[derive(Debug)] +pub enum SpawnWithTimeoutError { + JoinError(tokio::task::JoinError), + Cancelled, +} + +impl std::error::Error for SpawnWithTimeoutError {} +impl std::fmt::Display for SpawnWithTimeoutError { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt), + SpawnWithTimeoutError::Cancelled => Cancelled.fmt(fmt), + } + } +} + +struct CancelOnDrop(Arc); +impl Drop for CancelOnDrop { + fn drop(&mut self) { + self.0.store(true, Ordering::Relaxed) + } +} + +/// Spawns a new blocking task with a given `timeout`. +/// +/// The `callback` should continuously call [`IsCancelled::check_if_cancelled`], +/// which will return an error once the task runs for longer than `timeout`. +/// +/// If `timeout` is `None` then this works just as a regular `spawn_blocking`. +pub async fn spawn_blocking_with_timeout( + timeout: Option, + callback: impl FnOnce(IsCancelled) -> std::result::Result + Send + 'static, +) -> Result +where + R: Send + 'static, +{ + let is_cancelled_arc = Arc::new(AtomicBool::new(false)); + let is_cancelled = IsCancelled(is_cancelled_arc.clone()); + let _cancel_on_drop = CancelOnDrop(is_cancelled_arc); + let task = tokio::task::spawn_blocking(move || callback(is_cancelled)); + + let result; + if let Some(timeout) = timeout { + result = tokio::select! 
{ + biased; + + task_result = task => task_result, + _ = tokio::time::sleep(timeout) => Ok(Err(Cancelled)) + }; + } else { + result = task.await; + } + + match result { + Ok(Ok(result)) => Ok(result), + Ok(Err(Cancelled)) => Err(SpawnWithTimeoutError::Cancelled), + Err(error) => Err(SpawnWithTimeoutError::JoinError(error)), + } +} + +#[tokio::test] +async fn spawn_blocking_with_timeout_works() { + let task: Result<(), SpawnWithTimeoutError> = + spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_cancelled| { + std::thread::sleep(Duration::from_millis(200)); + is_cancelled.check_if_cancelled()?; + unreachable!(); + }) + .await; + + assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Cancelled)); + + let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_cancelled| { + std::thread::sleep(Duration::from_millis(20)); + is_cancelled.check_if_cancelled()?; + Ok(()) + }) + .await; + + assert_matches::assert_matches!(task, Ok(())); +} From 027a718e7800e1d6050fc14c3fd5339eebfb8bbd Mon Sep 17 00:00:00 2001 From: Koute Date: Tue, 17 Jan 2023 16:11:49 +0900 Subject: [PATCH 3/8] Update client/rpc/Cargo.toml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bastian Köcher --- client/rpc/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/rpc/Cargo.toml b/client/rpc/Cargo.toml index c4cc2acc2be0f..d39d8d8581ff6 100644 --- a/client/rpc/Cargo.toml +++ b/client/rpc/Cargo.toml @@ -36,7 +36,7 @@ sp-runtime = { version = "7.0.0", path = "../../primitives/runtime" } sp-session = { version = "4.0.0-dev", path = "../../primitives/session" } sp-version = { version = "5.0.0", path = "../../primitives/version" } -tokio = { version = "1.22.0" } +tokio = "1.22.0" [dev-dependencies] env_logger = "0.9" From b83d731b2f544d906475435b889ba7a47fe627bb Mon Sep 17 00:00:00 2001 From: Koute Date: Tue, 17 Jan 2023 16:17:10 +0900 Subject: [PATCH 4/8] Update client/rpc/src/state/utils.rs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bastian Köcher --- client/rpc/src/state/utils.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/client/rpc/src/state/utils.rs b/client/rpc/src/state/utils.rs index d075f1f42fa7c..8a431e9513623 100644 --- a/client/rpc/src/state/utils.rs +++ b/client/rpc/src/state/utils.rs @@ -92,17 +92,16 @@ where let _cancel_on_drop = CancelOnDrop(is_cancelled_arc); let task = tokio::task::spawn_blocking(move || callback(is_cancelled)); - let result; - if let Some(timeout) = timeout { - result = tokio::select! { + let result = if let Some(timeout) = timeout { + tokio::select! 
{ biased; task_result = task => task_result, _ = tokio::time::sleep(timeout) => Ok(Err(Cancelled)) - }; + } } else { - result = task.await; - } + task.await + }; match result { Ok(Ok(result)) => Ok(result), From cdbb894c5a3b725cdacdb57f2150cb93ab466c7e Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 17 Jan 2023 08:52:07 +0000 Subject: [PATCH 5/8] Rename the types to signify that the cancellation is due to a timeout --- client/rpc/src/state/state_full.rs | 4 +-- client/rpc/src/state/utils.rs | 48 +++++++++++++++--------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/client/rpc/src/state/state_full.rs b/client/rpc/src/state/state_full.rs index 38f1367e68e79..d8fe39030aa82 100644 --- a/client/rpc/src/state/state_full.rs +++ b/client/rpc/src/state/state_full.rs @@ -275,7 +275,7 @@ where DenyUnsafe::No => None, }; - super::utils::spawn_blocking_with_timeout(timeout, move |is_cancelled| { + super::utils::spawn_blocking_with_timeout(timeout, move |is_timed_out| { // Does the key point to a concrete entry in the database? match client.storage(block, &key) { Ok(Some(d)) => return Ok(Ok(Some(d.0.len() as u64))), @@ -294,7 +294,7 @@ where let value = client.storage(block, &storage_key).ok().flatten().unwrap_or_default(); sum += value.0.len() as u64; - is_cancelled.check_if_cancelled()?; + is_timed_out.check_if_timed_out()?; } if sum > 0 { diff --git a/client/rpc/src/state/utils.rs b/client/rpc/src/state/utils.rs index 8a431e9513623..714cf423fd747 100644 --- a/client/rpc/src/state/utils.rs +++ b/client/rpc/src/state/utils.rs @@ -26,35 +26,35 @@ use std::{ /// An error signifying that a task has been cancelled due to a timeout. #[derive(Debug)] -pub struct Cancelled; +pub struct Timeout; -impl std::error::Error for Cancelled {} -impl std::fmt::Display for Cancelled { +impl std::error::Error for Timeout {} +impl std::fmt::Display for Timeout { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fmt.write_str("task has been running too long") } } -/// A handle which can be used to check whether the task has been cancelled. +/// A handle which can be used to check whether the task has been cancelled due to a timeout. #[repr(transparent)] -pub struct IsCancelled(Arc); +pub struct IsTimedOut(Arc); -impl IsCancelled { +impl IsTimedOut { #[must_use] - pub fn check_if_cancelled(&self) -> std::result::Result<(), Cancelled> { + pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> { if self.0.load(Ordering::Relaxed) { - Err(Cancelled) + Err(Timeout) } else { Ok(()) } } } -/// An error for a task which either panicked, or has been cancelled. +/// An error for a task which either panicked, or has been cancelled due to a timeout. #[derive(Debug)] pub enum SpawnWithTimeoutError { JoinError(tokio::task::JoinError), - Cancelled, + Timeout, } impl std::error::Error for SpawnWithTimeoutError {} @@ -62,7 +62,7 @@ impl std::fmt::Display for SpawnWithTimeoutError { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { match self { SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt), - SpawnWithTimeoutError::Cancelled => Cancelled.fmt(fmt), + SpawnWithTimeoutError::Timeout => Timeout.fmt(fmt), } } } @@ -76,28 +76,28 @@ impl Drop for CancelOnDrop { /// Spawns a new blocking task with a given `timeout`. /// -/// The `callback` should continuously call [`IsCancelled::check_if_cancelled`], +/// The `callback` should continuously call [`IsTimedOut::check_if_timed_out`], /// which will return an error once the task runs for longer than `timeout`. 
/// /// If `timeout` is `None` then this works just as a regular `spawn_blocking`. pub async fn spawn_blocking_with_timeout( timeout: Option, - callback: impl FnOnce(IsCancelled) -> std::result::Result + Send + 'static, + callback: impl FnOnce(IsTimedOut) -> std::result::Result + Send + 'static, ) -> Result where R: Send + 'static, { - let is_cancelled_arc = Arc::new(AtomicBool::new(false)); - let is_cancelled = IsCancelled(is_cancelled_arc.clone()); - let _cancel_on_drop = CancelOnDrop(is_cancelled_arc); - let task = tokio::task::spawn_blocking(move || callback(is_cancelled)); + let is_timed_out_arc = Arc::new(AtomicBool::new(false)); + let is_timed_out = IsTimedOut(is_timed_out_arc.clone()); + let _cancel_on_drop = CancelOnDrop(is_timed_out_arc); + let task = tokio::task::spawn_blocking(move || callback(is_timed_out)); let result = if let Some(timeout) = timeout { tokio::select! { biased; task_result = task => task_result, - _ = tokio::time::sleep(timeout) => Ok(Err(Cancelled)) + _ = tokio::time::sleep(timeout) => Ok(Err(Timeout)) } } else { task.await @@ -105,7 +105,7 @@ where match result { Ok(Ok(result)) => Ok(result), - Ok(Err(Cancelled)) => Err(SpawnWithTimeoutError::Cancelled), + Ok(Err(Timeout)) => Err(SpawnWithTimeoutError::Timeout), Err(error) => Err(SpawnWithTimeoutError::JoinError(error)), } } @@ -113,18 +113,18 @@ where #[tokio::test] async fn spawn_blocking_with_timeout_works() { let task: Result<(), SpawnWithTimeoutError> = - spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_cancelled| { + spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { std::thread::sleep(Duration::from_millis(200)); - is_cancelled.check_if_cancelled()?; + is_timed_out.check_if_timed_out()?; unreachable!(); }) .await; - assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Cancelled)); + assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout)); - let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_cancelled| { + let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { std::thread::sleep(Duration::from_millis(20)); - is_cancelled.check_if_cancelled()?; + is_timed_out.check_if_timed_out()?; Ok(()) }) .await; From e3f65551b43e3d3df75283a8ddb0c06f3bf33042 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 17 Jan 2023 08:54:41 +0000 Subject: [PATCH 6/8] Move the test into a `mod tests` --- client/rpc/src/state/utils.rs | 37 ++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/client/rpc/src/state/utils.rs b/client/rpc/src/state/utils.rs index 714cf423fd747..775061db1668f 100644 --- a/client/rpc/src/state/utils.rs +++ b/client/rpc/src/state/utils.rs @@ -110,24 +110,29 @@ where } } -#[tokio::test] -async fn spawn_blocking_with_timeout_works() { - let task: Result<(), SpawnWithTimeoutError> = - spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { - std::thread::sleep(Duration::from_millis(200)); +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn spawn_blocking_with_timeout_works() { + let task: Result<(), SpawnWithTimeoutError> = + spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { + std::thread::sleep(Duration::from_millis(200)); + is_timed_out.check_if_timed_out()?; + unreachable!(); + }) + .await; + + assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout)); + + let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), 
|is_timed_out| { + std::thread::sleep(Duration::from_millis(20)); is_timed_out.check_if_timed_out()?; - unreachable!(); + Ok(()) }) .await; - assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout)); - - let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { - std::thread::sleep(Duration::from_millis(20)); - is_timed_out.check_if_timed_out()?; - Ok(()) - }) - .await; - - assert_matches::assert_matches!(task, Ok(())); + assert_matches::assert_matches!(task, Ok(())); + } } From 8ed99edbe8e5c756438e36635374c43b784684b8 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 17 Jan 2023 08:57:59 +0000 Subject: [PATCH 7/8] Add a comment regarding `biased` in `tokio::select` --- client/rpc/src/state/utils.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/rpc/src/state/utils.rs b/client/rpc/src/state/utils.rs index 775061db1668f..81476cdd34262 100644 --- a/client/rpc/src/state/utils.rs +++ b/client/rpc/src/state/utils.rs @@ -94,6 +94,8 @@ where let result = if let Some(timeout) = timeout { tokio::select! { + // Shouldn't really matter, but make sure the task is polled before the timeout, + // in case the task finishes after the timeout and the timeout is really short. biased; task_result = task => task_result, From 2868c174ec220976d250f202c0ab77edb7c6dd3d Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 17 Jan 2023 09:04:19 +0000 Subject: [PATCH 8/8] Make the `clone` explicit when calling `KeyIterator::{new, new_child}` --- client/api/src/backend.rs | 20 ++++---------------- client/service/src/client/client.rs | 4 ++-- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/client/api/src/backend.rs b/client/api/src/backend.rs index ed89f71645835..21d213ffb15cf 100644 --- a/client/api/src/backend.rs +++ b/client/api/src/backend.rs @@ -313,30 +313,18 @@ pub struct KeyIterator { impl KeyIterator { /// create a KeyIterator instance - pub fn new(state: State, prefix: Option<&StorageKey>, current_key: Vec) -> Self { - Self { - state, - child_storage: None, - prefix: prefix.map(|prefix| prefix.clone()), - current_key, - _phantom: PhantomData, - } + pub fn new(state: State, prefix: Option, current_key: Vec) -> Self { + Self { state, child_storage: None, prefix, current_key, _phantom: PhantomData } } /// Create a `KeyIterator` instance for a child storage. 
pub fn new_child( state: State, child_info: ChildInfo, - prefix: Option<&StorageKey>, + prefix: Option, current_key: Vec, ) -> Self { - Self { - state, - child_storage: Some(child_info), - prefix: prefix.map(|prefix| prefix.clone()), - current_key, - _phantom: PhantomData, - } + Self { state, child_storage: Some(child_info), prefix, current_key, _phantom: PhantomData } } } diff --git a/client/service/src/client/client.rs b/client/service/src/client/client.rs index 582b462846baf..8e10a7b2eda0a 100644 --- a/client/service/src/client/client.rs +++ b/client/service/src/client/client.rs @@ -1440,7 +1440,7 @@ where ) -> sp_blockchain::Result> { let state = self.state_at(hash)?; let start_key = start_key.or(prefix).map(|key| key.0.clone()).unwrap_or_else(Vec::new); - Ok(KeyIterator::new(state, prefix, start_key)) + Ok(KeyIterator::new(state, prefix.cloned(), start_key)) } fn child_storage_keys_iter( @@ -1452,7 +1452,7 @@ where ) -> sp_blockchain::Result> { let state = self.state_at(hash)?; let start_key = start_key.or(prefix).map(|key| key.0.clone()).unwrap_or_else(Vec::new); - Ok(KeyIterator::new_child(state, child_info, prefix, start_key)) + Ok(KeyIterator::new_child(state, child_info, prefix.cloned(), start_key)) } fn storage(