diff --git a/client/api/src/execution_extensions.rs b/client/api/src/execution_extensions.rs index 4f2ddb77e6653..dec8557a7ed23 100644 --- a/client/api/src/execution_extensions.rs +++ b/client/api/src/execution_extensions.rs @@ -168,16 +168,20 @@ impl ExecutionExtensions { if capabilities.has(offchain::Capability::TransactionPool) { if let Some(pool) = self.transaction_pool.read().as_ref().and_then(|x| x.upgrade()) { - extensions.register(TransactionPoolExt(Box::new(TransactionPoolAdapter { - at: *at, - pool, - }) as _)); + extensions.register( + TransactionPoolExt( + Box::new(TransactionPoolAdapter { + at: *at, + pool, + }) as _ + ), + ); } } if let ExecutionContext::OffchainCall(Some(ext)) = context { extensions.register( - OffchainExt::new(offchain::LimitedExternalities::new(capabilities, ext.0)) + OffchainExt::new(offchain::LimitedExternalities::new(capabilities, ext.0)), ); } diff --git a/primitives/externalities/src/extensions.rs b/primitives/externalities/src/extensions.rs index d79f99d3344ea..ee45148487d21 100644 --- a/primitives/externalities/src/extensions.rs +++ b/primitives/externalities/src/extensions.rs @@ -123,14 +123,25 @@ impl Extensions { } /// Register the given extension. - pub fn register(&mut self, ext: E) { - self.extensions.insert(ext.type_id(), Box::new(ext)); + pub fn register( + &mut self, + ext: E, + ) { + let type_id = ext.type_id(); + self.extensions.insert(type_id, Box::new(ext)); } - /// Register extension `ext`. - pub fn register_with_type_id(&mut self, type_id: TypeId, extension: Box) -> Result<(), Error> { + /// Register extension `extension` using the given `type_id`. + pub fn register_with_type_id( + &mut self, + type_id: TypeId, + extension: Box, + ) -> Result<(), Error> { match self.extensions.entry(type_id) { - Entry::Vacant(vacant) => { vacant.insert(extension); Ok(()) }, + Entry::Vacant(vacant) => { + vacant.insert(extension); + Ok(()) + }, Entry::Occupied(_) => Err(Error::ExtensionAlreadyRegistered), } } @@ -140,9 +151,16 @@ impl Extensions { self.extensions.get_mut(&ext_type_id).map(DerefMut::deref_mut).map(Extension::as_mut_any) } - /// Deregister extension of type `E`. - pub fn deregister(&mut self, type_id: TypeId) -> Option> { - self.extensions.remove(&type_id) + /// Deregister extension for the given `type_id`. + /// + /// Returns `true` when the extension was registered. + pub fn deregister(&mut self, type_id: TypeId) -> bool { + self.extensions.remove(&type_id).is_some() + } + + /// Returns a mutable iterator over all extensions. + pub fn iter_mut<'a>(&'a mut self) -> impl Iterator)> { + self.extensions.iter_mut() } } diff --git a/primitives/state-machine/src/basic.rs b/primitives/state-machine/src/basic.rs index 3db7a54750a02..76d53659db64b 100644 --- a/primitives/state-machine/src/basic.rs +++ b/primitives/state-machine/src/basic.rs @@ -348,10 +348,11 @@ impl sp_externalities::ExtensionStore for BasicExternalities { } fn deregister_extension_by_type_id(&mut self, type_id: TypeId) -> Result<(), sp_externalities::Error> { - self.extensions - .deregister(type_id) - .ok_or(sp_externalities::Error::ExtensionIsNotRegistered(type_id)) - .map(drop) + if self.extensions.deregister(type_id) { + Ok(()) + } else { + Err(sp_externalities::Error::ExtensionIsNotRegistered(type_id)) + } } } diff --git a/primitives/state-machine/src/ext.rs b/primitives/state-machine/src/ext.rs index e9259f9a10bc1..53aab42999d5e 100644 --- a/primitives/state-machine/src/ext.rs +++ b/primitives/state-machine/src/ext.rs @@ -19,7 +19,7 @@ use crate::{ StorageKey, StorageValue, OverlayedChanges, - backend::Backend, + backend::Backend, overlayed_changes::OverlayedExtensions, }; use hash_db::Hasher; use sp_core::{ @@ -27,8 +27,9 @@ use sp_core::{ hexdisplay::HexDisplay, }; use sp_trie::{trie_types::Layout, empty_child_trie_root}; -use sp_externalities::{Externalities, Extensions, Extension, - ExtensionStore}; +use sp_externalities::{ + Externalities, Extensions, Extension, ExtensionStore, +}; use codec::{Decode, Encode, EncodeAppend}; use sp_std::{fmt, any::{Any, TypeId}, vec::Vec, vec, boxed::Box}; @@ -115,7 +116,7 @@ pub struct Ext<'a, H, N, B> _phantom: sp_std::marker::PhantomData, /// Extensions registered with this instance. #[cfg(feature = "std")] - extensions: Option<&'a mut Extensions>, + extensions: Option>, } @@ -159,7 +160,7 @@ impl<'a, H, N, B> Ext<'a, H, N, B> storage_transaction_cache, id: rand::random(), _phantom: Default::default(), - extensions, + extensions: extensions.map(OverlayedExtensions::new), } } @@ -753,7 +754,7 @@ where extension: Box, ) -> Result<(), sp_externalities::Error> { if let Some(ref mut extensions) = self.extensions { - extensions.register_with_type_id(type_id, extension) + extensions.register(type_id, extension) } else { Err(sp_externalities::Error::ExtensionsAreNotSupported) } @@ -761,9 +762,10 @@ where fn deregister_extension_by_type_id(&mut self, type_id: TypeId) -> Result<(), sp_externalities::Error> { if let Some(ref mut extensions) = self.extensions { - match extensions.deregister(type_id) { - Some(_) => Ok(()), - None => Err(sp_externalities::Error::ExtensionIsNotRegistered(type_id)) + if extensions.deregister(type_id) { + Ok(()) + } else { + Err(sp_externalities::Error::ExtensionIsNotRegistered(type_id)) } } else { Err(sp_externalities::Error::ExtensionsAreNotSupported) diff --git a/primitives/state-machine/src/lib.rs b/primitives/state-machine/src/lib.rs index 5b86640aa7d0e..28148b6411a13 100644 --- a/primitives/state-machine/src/lib.rs +++ b/primitives/state-machine/src/lib.rs @@ -31,7 +31,7 @@ mod ext; mod testing; #[cfg(feature = "std")] mod basic; -mod overlayed_changes; +pub(crate) mod overlayed_changes; #[cfg(feature = "std")] mod proving_backend; mod trie_backend; @@ -907,7 +907,7 @@ mod tests { _method: &str, _data: &[u8], use_native: bool, - _native_call: Option, + native_call: Option, ) -> (CallResult, bool) { if self.change_changes_trie_config { ext.place_storage( @@ -922,8 +922,15 @@ mod tests { } let using_native = use_native && self.native_available; - match (using_native, self.native_succeeds, self.fallback_succeeds) { - (true, true, _) | (false, _, true) => { + match (using_native, self.native_succeeds, self.fallback_succeeds, native_call) { + (true, true, _, Some(call)) => { + let res = sp_externalities::set_and_run_with_externalities(ext, || call()); + ( + res.map(NativeOrEncoded::Native).map_err(|_| 0), + true + ) + }, + (true, true, _, None) | (false, _, true, None) => { ( Ok( NativeOrEncoded::Encoded( @@ -1473,4 +1480,51 @@ mod tests { overlay.commit_transaction().unwrap(); assert_eq!(overlay.storage(b"ccc"), Some(None)); } + + #[test] + fn runtime_registered_extensions_are_removed_after_execution() { + use sp_externalities::ExternalitiesExt; + sp_externalities::decl_extension! { + struct DummyExt(u32); + } + + let backend = trie_backend::tests::test_trie(); + let mut overlayed_changes = Default::default(); + let mut offchain_overlayed_changes = Default::default(); + let wasm_code = RuntimeCode::empty(); + + let mut state_machine = StateMachine::new( + &backend, + changes_trie::disabled_state::<_, u64>(), + &mut overlayed_changes, + &mut offchain_overlayed_changes, + &DummyCodeExecutor { + change_changes_trie_config: false, + native_available: true, + native_succeeds: true, + fallback_succeeds: false, + }, + "test", + &[], + Default::default(), + &wasm_code, + TaskExecutor::new(), + ); + + let run_state_machine = |state_machine: &mut StateMachine<_, _, _, _>| { + state_machine.execute_using_consensus_failure_handler:: _, _, _>( + ExecutionManager::NativeWhenPossible, + Some(|| { + sp_externalities::with_externalities(|mut ext| { + ext.register_extension(DummyExt(2)).unwrap(); + }).unwrap(); + + Ok(()) + }), + ).unwrap(); + }; + + run_state_machine(&mut state_machine); + run_state_machine(&mut state_machine); + } } diff --git a/primitives/state-machine/src/overlayed_changes/mod.rs b/primitives/state-machine/src/overlayed_changes/mod.rs index 992f7b3519299..6ef09fc81505d 100644 --- a/primitives/state-machine/src/overlayed_changes/mod.rs +++ b/primitives/state-machine/src/overlayed_changes/mod.rs @@ -23,7 +23,7 @@ use crate::{ backend::Backend, stats::StateMachineStats, }; -use sp_std::vec::Vec; +use sp_std::{vec::Vec, any::{TypeId, Any}, boxed::Box}; use self::changeset::OverlayedChangeSet; #[cfg(feature = "std")] @@ -36,9 +36,9 @@ use crate::{ }; use crate::changes_trie::BlockNumber; #[cfg(feature = "std")] -use std::collections::HashMap as Map; +use std::collections::{HashMap as Map, hash_map::Entry as MapEntry}; #[cfg(not(feature = "std"))] -use sp_std::collections::btree_map::BTreeMap as Map; +use sp_std::collections::btree_map::{BTreeMap as Map, Entry as MapEntry}; use sp_std::collections::btree_set::BTreeSet; use codec::{Decode, Encode}; use sp_core::storage::{well_known_keys::EXTRINSIC_INDEX, ChildInfo}; @@ -46,6 +46,7 @@ use sp_core::storage::{well_known_keys::EXTRINSIC_INDEX, ChildInfo}; use sp_core::offchain::storage::OffchainOverlayedChanges; use hash_db::Hasher; use crate::DefaultError; +use sp_externalities::{Extensions, Extension}; pub use self::changeset::{OverlayedValue, NoOpenTransaction, AlreadyInRuntime, NotInRuntime}; @@ -638,7 +639,7 @@ fn retain_map(map: &mut Map, f: F) { map.retain(f); } - + #[cfg(not(feature = "std"))] fn retain_map(map: &mut Map, mut f: F) where @@ -652,7 +653,67 @@ fn retain_map(map: &mut Map, mut f: F) } } } - + +/// An overlayed extension is either a mutable reference +/// or an owned extension. +pub enum OverlayedExtension<'a> { + MutRef(&'a mut Box), + Owned(Box), +} + +/// Overlayed extensions which are sourced from [`Extensions`]. +/// +/// The sourced extensions will be stored as mutable references, +/// while extensions that are registered while execution are stored +/// as owned references. After the execution of a runtime function, we +/// can safely drop this object while not having modified the original +/// list. +pub struct OverlayedExtensions<'a> { + extensions: Map>, +} + +impl<'a> OverlayedExtensions<'a> { + /// Create a new instance of overalyed extensions from the given extensions. + pub fn new(extensions: &'a mut Extensions) -> Self { + Self { + extensions: extensions + .iter_mut() + .map(|(k, v)| (*k, OverlayedExtension::MutRef(v))) + .collect(), + } + } + + /// Return a mutable reference to the requested extension. + pub fn get_mut(&mut self, ext_type_id: TypeId) -> Option<&mut dyn Any> { + self.extensions.get_mut(&ext_type_id).map(|ext| match ext { + OverlayedExtension::MutRef(ext) => ext.as_mut_any(), + OverlayedExtension::Owned(ext) => ext.as_mut_any(), + }) + } + + /// Register extension `extension` with the given `type_id`. + pub fn register( + &mut self, + type_id: TypeId, + extension: Box, + ) -> Result<(), sp_externalities::Error> { + match self.extensions.entry(type_id) { + MapEntry::Vacant(vacant) => { + vacant.insert(OverlayedExtension::Owned(extension)); + Ok(()) + }, + MapEntry::Occupied(_) => Err(sp_externalities::Error::ExtensionAlreadyRegistered), + } + } + + /// Deregister extension with the given `type_id`. + /// + /// Returns `true` when there was an extension registered for the given `type_id`. + pub fn deregister(&mut self, type_id: TypeId) -> bool { + self.extensions.remove(&type_id).is_some() + } +} + #[cfg(test)] mod tests { use hex_literal::hex; diff --git a/primitives/state-machine/src/testing.rs b/primitives/state-machine/src/testing.rs index be7dc6df9de9a..de68d7e415cdd 100644 --- a/primitives/state-machine/src/testing.rs +++ b/primitives/state-machine/src/testing.rs @@ -233,10 +233,11 @@ impl sp_externalities::ExtensionStore for TestExternalities where } fn deregister_extension_by_type_id(&mut self, type_id: TypeId) -> Result<(), sp_externalities::Error> { - self.extensions - .deregister(type_id) - .expect("There should be an extension we try to remove in TestExternalities"); - Ok(()) + if self.extensions.deregister(type_id) { + Ok(()) + } else { + Err(sp_externalities::Error::ExtensionIsNotRegistered(type_id)) + } } }