Skip to content
This repository was archived by the owner on Nov 15, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
rename: MsgFilter -> MessageInterceptor
  • Loading branch information
drahnr committed Sep 14, 2021
commit a5c1027a9bc80cfb40312a11b7deac99a93d7704
62 changes: 49 additions & 13 deletions node/malus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,28 @@ pub use polkadot_node_subsystem::{messages::AllMessages, overseer, FromOverseer}
use std::{future::Future, pin::Pin};

/// Filter incoming and outgoing messages.
pub trait MsgFilter: Send + Sync + Clone + 'static {
pub trait MessageInterceptor<Sender>: Send + Sync + Clone + 'static
where
Sender: overseer::SubsystemSender<Self::Message> + Clone + 'static,
{
/// The message type the original subsystem handles incoming.
type Message: Send + 'static;

/// Filter messages that are to be received by
/// the subsystem.
fn filter_in(&self, msg: FromOverseer<Self::Message>) -> Option<FromOverseer<Self::Message>> {
///
/// For non-trivial cases, the `sender` can be used to send
/// multiple messages after doing some additional processing.
fn intercept_incoming(
&self,
_sender: &mut Sender,
msg: FromOverseer<Self::Message>,
) -> Option<FromOverseer<Self::Message>> {
Some(msg)
}

/// Modify outgoing messages.
fn filter_out(&self, msg: AllMessages) -> Option<AllMessages> {
fn intercept_outgoing(&self, msg: AllMessages) -> Option<AllMessages> {
Some(msg)
}
}
Expand All @@ -51,11 +61,12 @@ pub struct FilteredSender<Sender, Fil> {
#[async_trait::async_trait]
impl<Sender, Fil> overseer::SubsystemSender<AllMessages> for FilteredSender<Sender, Fil>
where
Sender: overseer::SubsystemSender<AllMessages>,
Fil: MsgFilter,
Sender: overseer::SubsystemSender<AllMessages>
+ overseer::SubsystemSender<<Fil as MessageInterceptor<Sender>>::Message>,
Fil: MessageInterceptor<Sender>,
{
async fn send_message(&mut self, msg: AllMessages) {
if let Some(msg) = self.message_filter.filter_out(msg) {
if let Some(msg) = self.message_filter.intercept_outgoing(msg) {
self.inner.send_message(msg).await;
}
}
Expand All @@ -71,14 +82,21 @@ where
}

fn send_unbounded_message(&mut self, msg: AllMessages) {
if let Some(msg) = self.message_filter.filter_out(msg) {
if let Some(msg) = self.message_filter.intercept_outgoing(msg) {
self.inner.send_unbounded_message(msg);
}
}
}

/// A subsystem context, that filters the outgoing messages.
pub struct FilteredContext<Context: overseer::SubsystemContext + SubsystemContext, Fil: MsgFilter> {
pub struct FilteredContext<Context, Fil>
where
Context: overseer::SubsystemContext + SubsystemContext,
Fil: MessageInterceptor<<Context as overseer::SubsystemContext>::Sender>,
<Context as overseer::SubsystemContext>::Sender: overseer::SubsystemSender<
<Fil as MessageInterceptor<<Context as overseer::SubsystemContext>::Sender>>::Message,
>,
{
inner: Context,
message_filter: Fil,
sender: FilteredSender<<Context as overseer::SubsystemContext>::Sender, Fil>,
Expand All @@ -87,7 +105,13 @@ pub struct FilteredContext<Context: overseer::SubsystemContext + SubsystemContex
impl<Context, Fil> FilteredContext<Context, Fil>
where
Context: overseer::SubsystemContext + SubsystemContext,
Fil: MsgFilter<Message = <Context as overseer::SubsystemContext>::Message>,
Fil: MessageInterceptor<
<Context as overseer::SubsystemContext>::Sender,
Message = <Context as overseer::SubsystemContext>::Message,
>,
<Context as overseer::SubsystemContext>::Sender: overseer::SubsystemSender<
<Fil as MessageInterceptor<<Context as overseer::SubsystemContext>::Sender>>::Message,
>,
{
pub fn new(mut inner: Context, message_filter: Fil) -> Self {
let sender = FilteredSender::<<Context as overseer::SubsystemContext>::Sender, Fil> {
Expand All @@ -102,9 +126,15 @@ where
impl<Context, Fil> overseer::SubsystemContext for FilteredContext<Context, Fil>
where
Context: overseer::SubsystemContext + SubsystemContext,
Fil: MsgFilter<Message = <Context as overseer::SubsystemContext>::Message>,
Fil: MessageInterceptor<
<Context as overseer::SubsystemContext>::Sender,
Message = <Context as overseer::SubsystemContext>::Message,
>,
<Context as overseer::SubsystemContext>::AllMessages:
From<<Context as overseer::SubsystemContext>::Message>,
<Context as overseer::SubsystemContext>::Sender: overseer::SubsystemSender<
<Fil as MessageInterceptor<<Context as overseer::SubsystemContext>::Sender>>::Message,
>,
{
type Message = <Context as overseer::SubsystemContext>::Message;
type Sender = FilteredSender<<Context as overseer::SubsystemContext>::Sender, Fil>;
Expand All @@ -117,7 +147,7 @@ where
match self.inner.try_recv().await? {
None => return Ok(None),
Some(msg) =>
if let Some(msg) = self.message_filter.filter_in(msg) {
if let Some(msg) = self.message_filter.intercept_incoming(self.inner.sender(), msg) {
return Ok(Some(msg))
},
}
Expand All @@ -127,7 +157,7 @@ where
async fn recv(&mut self) -> SubsystemResult<FromOverseer<Self::Message>> {
loop {
let msg = self.inner.recv().await?;
if let Some(msg) = self.message_filter.filter_in(msg) {
if let Some(msg) = self.message_filter.intercept_incoming(self.inner.sender(), msg) {
return Ok(msg)
}
}
Expand Down Expand Up @@ -171,7 +201,13 @@ where
Context: overseer::SubsystemContext + SubsystemContext + Sync + Send,
Sub: overseer::Subsystem<FilteredContext<Context, Fil>, SubsystemError>,
FilteredContext<Context, Fil>: overseer::SubsystemContext + SubsystemContext,
Fil: MsgFilter<Message = <Context as overseer::SubsystemContext>::Message>,
Fil: MessageInterceptor<
<Context as overseer::SubsystemContext>::Sender,
Message = <Context as overseer::SubsystemContext>::Message,
>,
<Context as overseer::SubsystemContext>::Sender: overseer::SubsystemSender<
<Fil as MessageInterceptor<<Context as overseer::SubsystemContext>::Sender>>::Message,
>,
{
fn start(self, ctx: Context) -> SpawnedSubsystem {
let ctx = FilteredContext::new(ctx, self.message_filter);
Expand Down
13 changes: 10 additions & 3 deletions node/malus/src/variant-a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,24 @@ use structopt::StructOpt;
#[derive(Clone, Default, Debug)]
struct Skippy(Arc<AtomicUsize>);

impl MsgFilter for Skippy {
impl<Sender> MessageInterceptor<Sender> for Skippy
where
Sender: SubsystemSender<AllMessages> + SubsystemSender<Self::Message> + Clone + 'static,
{
type Message = CandidateValidationMessage;

fn filter_in(&self, msg: FromOverseer<Self::Message>) -> Option<FromOverseer<Self::Message>> {
fn intercept_incoming(
&self,
_sender: &mut S,
msg: FromOverseer<Self::Message>,
) -> Option<FromOverseer<Self::Message>> {
if self.0.fetch_add(1, Ordering::Relaxed) % 2 == 0 {
Some(msg)
} else {
None
}
}
fn filter_out(&self, msg: AllMessages) -> Option<AllMessages> {
fn intercept_outgoing(&self, _sender: &mut S, msg: AllMessages) -> Option<AllMessages> {
Some(msg)
}
}
Expand Down