diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 0a68638f54..2cd8e3ca7b 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -60,6 +60,11 @@ pub type SyncMethod = Arc M /// Similar to [`SyncMethod`], but represents an asynchronous handler. pub type AsyncMethod<'a> = Arc, Params<'a>, ConnectionId, MaxResponseSize) -> BoxFuture<'a, MethodResponse>>; + +/// Similar to [`AsyncMethod`], but represents an asynchronous handler with connection details. +#[doc(hidden)] +pub type AsyncMethodWithDetails<'a> = + Arc, Params<'a>, ConnectionDetails, MaxResponseSize) -> BoxFuture<'a, MethodResponse>>; /// Method callback for subscriptions. pub type SubscriptionMethod<'a> = Arc BoxFuture<'a, MethodResponse>>; @@ -79,6 +84,27 @@ pub type MaxResponseSize = usize; /// - a [`mpsc::UnboundedReceiver`] to receive future subscription results pub type RawRpcResponse = (String, mpsc::Receiver); +/// The connection details exposed to the server methods. +#[derive(Debug, Clone)] +#[allow(missing_copy_implementations)] +#[doc(hidden)] +pub struct ConnectionDetails { + id: ConnectionId, +} + +impl ConnectionDetails { + /// Construct a new [`ConnectionDetails`]. + #[doc(hidden)] + pub fn _new(id: ConnectionId) -> ConnectionDetails { + Self { id } + } + + /// Get the connection ID. + pub fn id(&self) -> ConnectionId { + self.id + } +} + /// The error that can occur when [`Methods::call`] or [`Methods::subscribe`] is invoked. #[derive(thiserror::Error, Debug)] pub enum MethodsError { @@ -131,6 +157,9 @@ pub enum MethodCallback { Sync(SyncMethod), /// Asynchronous method handler. Async(AsyncMethod<'static>), + /// Asynchronous method handler with details. + #[doc(hidden)] + AsyncWithDetails(AsyncMethodWithDetails<'static>), /// Subscription method handler. Subscription(SubscriptionMethod<'static>), /// Unsubscription method handler. @@ -184,6 +213,7 @@ impl Debug for MethodCallback { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Async(_) => write!(f, "Async"), + Self::AsyncWithDetails(_) => write!(f, "AsyncWithDetails"), Self::Sync(_) => write!(f, "Sync"), Self::Subscription(_) => write!(f, "Subscription"), Self::Unsubscription(_) => write!(f, "Unsubscription"), @@ -355,6 +385,9 @@ impl Methods { None => MethodResponse::error(req.id, ErrorObject::from(ErrorCode::MethodNotFound)), Some(MethodCallback::Sync(cb)) => (cb)(id, params, usize::MAX), Some(MethodCallback::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), 0, usize::MAX).await, + Some(MethodCallback::AsyncWithDetails(cb)) => { + (cb)(id.into_owned(), params.into_owned(), ConnectionDetails::_new(0), usize::MAX).await + } Some(MethodCallback::Subscription(cb)) => { let conn_state = SubscriptionState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit }; @@ -598,6 +631,43 @@ impl RpcModule { Ok(callback) } + /// Register a new raw RPC method, which computes the response with the given callback. + /// + /// ## Examples + /// + /// ``` + /// use jsonrpsee_core::server::RpcModule; + /// + /// let mut module = RpcModule::new(()); + /// module.register_async_method_with_details("say_hello", |_params, _connection_details, _ctx| async { "lo" }).unwrap(); + /// ``` + #[doc(hidden)] + pub fn register_async_method_with_details( + &mut self, + method_name: &'static str, + callback: Fun, + ) -> Result<&mut MethodCallback, RegisterMethodError> + where + R: IntoResponse + 'static, + Fut: Future + Send, + Fun: (Fn(Params<'static>, ConnectionDetails, Arc) -> Fut) + Clone + Send + Sync + 'static, + { + let ctx = self.ctx.clone(); + self.methods.verify_and_insert( + method_name, + MethodCallback::AsyncWithDetails(Arc::new(move |id, params, connection_details, max_response_size| { + let ctx = ctx.clone(); + let callback = callback.clone(); + + let future = async move { + let rp = callback(params, connection_details, ctx).await.into_response(); + MethodResponse::response(id, rp, max_response_size) + }; + future.boxed() + })), + ) + } + /// Register a new publish/subscribe interface using JSON-RPC notifications. /// /// It implements the [ethereum pubsub specification](https://geth.ethereum.org/docs/rpc/pubsub) diff --git a/examples/examples/server_with_connection_details.rs b/examples/examples/server_with_connection_details.rs new file mode 100644 index 0000000000..26464e9504 --- /dev/null +++ b/examples/examples/server_with_connection_details.rs @@ -0,0 +1,126 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::net::SocketAddr; + +use jsonrpsee::core::{async_trait, client::Subscription}; +use jsonrpsee::proc_macros::rpc; +use jsonrpsee::server::{PendingSubscriptionSink, Server, SubscriptionMessage}; +use jsonrpsee::types::ErrorObjectOwned; +use jsonrpsee::ws_client::WsClientBuilder; +use jsonrpsee::ConnectionDetails; + +#[rpc(server, client)] +pub trait Rpc { + /// Raw method with connection ID. + #[method(name = "connectionIdMethod", raw_method)] + async fn raw_method(&self, first_param: usize, second_param: u16) -> Result; + + /// Normal method call example. + #[method(name = "normalMethod")] + fn normal_method(&self, first_param: usize, second_param: u16) -> Result; + + /// Subscriptions expose the connection ID on the subscription sink. + #[subscription(name = "subscribeSync" => "sync", item = usize)] + fn sub(&self, first_param: usize); +} + +pub struct RpcServerImpl; + +#[async_trait] +impl RpcServer for RpcServerImpl { + async fn raw_method( + &self, + connection_details: ConnectionDetails, + _first_param: usize, + _second_param: u16, + ) -> Result { + // Return the connection ID from which this method was called. + Ok(connection_details.id()) + } + + fn normal_method(&self, _first_param: usize, _second_param: u16) -> Result { + // The normal method does not have access to the connection ID. + Ok(usize::MAX) + } + + fn sub(&self, pending: PendingSubscriptionSink, _first_param: usize) { + tokio::spawn(async move { + // The connection ID can be obtained before or after accepting the subscription + let pending_connection_id = pending.connection_id(); + let sink = pending.accept().await.unwrap(); + let sink_connection_id = sink.connection_id(); + + assert_eq!(pending_connection_id, sink_connection_id); + + let msg = SubscriptionMessage::from_json(&sink_connection_id).unwrap(); + sink.send(msg).await.unwrap(); + }); + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + + let server_addr = run_server().await?; + let url = format!("ws://{}", server_addr); + + let client = WsClientBuilder::default().build(&url).await?; + let connection_id_first = client.raw_method(1, 2).await.unwrap(); + + // Second call from the same connection ID. + assert_eq!(client.raw_method(1, 2).await.unwrap(), connection_id_first); + + // Second client will increment the connection ID. + let client_second = WsClientBuilder::default().build(&url).await?; + let connection_id_second = client_second.raw_method(1, 2).await.unwrap(); + assert_ne!(connection_id_first, connection_id_second); + + let mut sub: Subscription = RpcClient::sub(&client, 0).await.unwrap(); + assert_eq!(connection_id_first, sub.next().await.transpose().unwrap().unwrap()); + + let mut sub: Subscription = RpcClient::sub(&client_second, 0).await.unwrap(); + assert_eq!(connection_id_second, sub.next().await.transpose().unwrap().unwrap()); + + Ok(()) +} + +async fn run_server() -> anyhow::Result { + let server = Server::builder().build("127.0.0.1:0").await?; + + let addr = server.local_addr()?; + let handle = server.start(RpcServerImpl.into_rpc()); + + // In this example we don't care about doing shutdown so let's it run forever. + // You may use the `ServerHandle` to shut it down or manage it yourself. + tokio::spawn(handle.stopped()); + + Ok(addr) +} diff --git a/proc-macros/src/render_server.rs b/proc-macros/src/render_server.rs index 5c4996ebde..2be84da7c8 100644 --- a/proc-macros/src/render_server.rs +++ b/proc-macros/src/render_server.rs @@ -61,7 +61,15 @@ impl RpcDescription { fn render_methods(&self) -> Result { let methods = self.methods.iter().map(|method| { let docs = &method.docs; - let method_sig = &method.signature; + let mut method_sig = method.signature.clone(); + + if method.raw_method { + let context_ty = self.jrps_server_item(quote! { ConnectionDetails }); + // Add `ConnectionDetails` as the second parameter to the signature. + let context: syn::FnArg = syn::parse_quote!(connection_details: #context_ty); + method_sig.sig.inputs.insert(1, context); + } + quote! { #docs #method_sig @@ -132,23 +140,32 @@ impl RpcDescription { check_name(&rpc_method_name, rust_method_name.span()); - if method.signature.sig.asyncness.is_some() { + if method.raw_method { handle_register_result(quote! { - rpc.register_async_method(#rpc_method_name, |params, context| async move { + rpc.register_async_method_with_details(#rpc_method_name, |params, connection_details, context| async move { #parsing - #into_response::into_response(context.as_ref().#rust_method_name(#params_seq).await) + #into_response::into_response(context.as_ref().#rust_method_name(connection_details, #params_seq).await) }) }) } else { - let register_kind = - if method.blocking { quote!(register_blocking_method) } else { quote!(register_method) }; + if method.signature.sig.asyncness.is_some() { + handle_register_result(quote! { + rpc.register_async_method(#rpc_method_name, |params, context| async move { + #parsing + #into_response::into_response(context.as_ref().#rust_method_name(#params_seq).await) + }) + }) + } else { + let register_kind = + if method.blocking { quote!(register_blocking_method) } else { quote!(register_method) }; - handle_register_result(quote! { - rpc.#register_kind(#rpc_method_name, |params, context| { - #parsing - #into_response::into_response(context.#rust_method_name(#params_seq)) + handle_register_result(quote! { + rpc.#register_kind(#rpc_method_name, |params, context| { + #parsing + #into_response::into_response(context.#rust_method_name(#params_seq)) + }) }) - }) + } } }) .collect::>(); diff --git a/proc-macros/src/rpc_macro.rs b/proc-macros/src/rpc_macro.rs index e500c72281..4fe5368565 100644 --- a/proc-macros/src/rpc_macro.rs +++ b/proc-macros/src/rpc_macro.rs @@ -48,17 +48,19 @@ pub struct RpcMethod { pub returns: Option, pub signature: syn::TraitItemFn, pub aliases: Vec, + pub raw_method: bool, } impl RpcMethod { pub fn from_item(attr: Attribute, mut method: syn::TraitItemFn) -> syn::Result { - let [aliases, blocking, name, param_kind] = - AttributeMeta::parse(attr)?.retain(["aliases", "blocking", "name", "param_kind"])?; + let [aliases, blocking, name, param_kind, raw_method] = + AttributeMeta::parse(attr)?.retain(["aliases", "blocking", "name", "param_kind", "raw_method"])?; let aliases = parse_aliases(aliases)?; let blocking = optional(blocking, Argument::flag)?.is_some(); let name = name?.string()?; let param_kind = parse_param_kind(param_kind)?; + let raw_method = optional(raw_method, Argument::flag)?.is_some(); let sig = method.sig.clone(); let docs = extract_doc_comments(&method.attrs); @@ -98,7 +100,18 @@ impl RpcMethod { // We've analyzed attributes and don't need them anymore. method.attrs.clear(); - Ok(Self { aliases, blocking, name, params, param_kind, returns, signature: method, docs, deprecated }) + Ok(Self { + aliases, + blocking, + name, + params, + param_kind, + returns, + signature: method, + docs, + deprecated, + raw_method, + }) } } @@ -212,7 +225,6 @@ impl RpcDescription { let namespace = optional(namespace, Argument::string)?; let client_bounds = optional(client_bounds, Argument::group)?; let server_bounds = optional(server_bounds, Argument::group)?; - if !needs_server && !needs_client { return Err(syn::Error::new_spanned(&item.ident, "Either 'server' or 'client' attribute must be applied")); } @@ -260,6 +272,28 @@ impl RpcDescription { is_method = true; let method_data = RpcMethod::from_item(attr.clone(), method.clone())?; + + if method_data.blocking && method_data.raw_method { + return Err(syn::Error::new_spanned( + method, + "Methods cannot be blocking when used with `raw_method`; remove `blocking` attribute or `raw_method` attribute", + )); + } + + if !needs_server && method_data.raw_method { + return Err(syn::Error::new_spanned( + &item.ident, + "Attribute 'raw_method' must be specified with 'server'", + )); + } + + if method.sig.asyncness.is_none() && method_data.raw_method { + return Err(syn::Error::new_spanned( + method, + "Methods must be asynchronous when used with `raw_method`; use `async fn` instead of `fn`", + )); + } + methods.push(method_data); } if let Some(attr) = find_attr(&method.attrs, "subscription") { diff --git a/proc-macros/tests/ui/correct/server_with_raw_methods.rs b/proc-macros/tests/ui/correct/server_with_raw_methods.rs new file mode 100644 index 0000000000..e8959ce8f7 --- /dev/null +++ b/proc-macros/tests/ui/correct/server_with_raw_methods.rs @@ -0,0 +1,14 @@ +//! Example of using proc macro to generate working client. + +use jsonrpsee::{core::RpcResult, proc_macros::rpc, types::ErrorObjectOwned}; + +#[rpc(server)] +pub trait Rpc { + #[method(name = "foo", raw_method)] + async fn async_method(&self, param_a: u8, param_b: String) -> RpcResult; + + #[method(name = "bar")] + fn sync_method(&self) -> Result; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_blocking_raw_incompatible.rs b/proc-macros/tests/ui/incorrect/method/method_blocking_raw_incompatible.rs new file mode 100644 index 0000000000..b87213d9ff --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_blocking_raw_incompatible.rs @@ -0,0 +1,9 @@ +use jsonrpsee::proc_macros::rpc; + +#[rpc(server)] +pub trait BlockingMethodCannotBeRaw { + #[method(name = "a", blocking, raw_method)] + fn a(&self, param: Vec); +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_blocking_raw_incompatible.stderr b/proc-macros/tests/ui/incorrect/method/method_blocking_raw_incompatible.stderr new file mode 100644 index 0000000000..7658d0df09 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_blocking_raw_incompatible.stderr @@ -0,0 +1,6 @@ +error: Methods cannot be blocking when used with `raw_method`; remove `blocking` attribute or `raw_method` attribute + --> tests/ui/incorrect/method/method_blocking_raw_incompatible.rs:5:2 + | +5 | / #[method(name = "a", blocking, raw_method)] +6 | | fn a(&self, param: Vec); + | |________________________________^ diff --git a/proc-macros/tests/ui/incorrect/method/method_sync_raw_incompatible.rs b/proc-macros/tests/ui/incorrect/method/method_sync_raw_incompatible.rs new file mode 100644 index 0000000000..37751827f2 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_sync_raw_incompatible.rs @@ -0,0 +1,9 @@ +use jsonrpsee::proc_macros::rpc; + +#[rpc(server)] +pub trait SyncMethodCannotBeRaw { + #[method(name = "a", raw_method)] + fn a(&self, param: Vec) -> RpcResult; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_sync_raw_incompatible.stderr b/proc-macros/tests/ui/incorrect/method/method_sync_raw_incompatible.stderr new file mode 100644 index 0000000000..572faafa4c --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_sync_raw_incompatible.stderr @@ -0,0 +1,6 @@ +error: Methods must be asynchronous when used with `raw_method`; use `async fn` instead of `fn` + --> tests/ui/incorrect/method/method_sync_raw_incompatible.rs:5:2 + | +5 | / #[method(name = "a", raw_method)] +6 | | fn a(&self, param: Vec) -> RpcResult; + | |__________________________________________________^ diff --git a/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr index 8fecc8437a..fc19a2111c 100644 --- a/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr +++ b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr @@ -1,5 +1,5 @@ -error: Unknown argument `magic`, expected one of: `aliases`, `blocking`, `name`, `param_kind` - --> $DIR/method_unexpected_field.rs:6:25 +error: Unknown argument `magic`, expected one of: `aliases`, `blocking`, `name`, `param_kind`, `raw_method` + --> tests/ui/incorrect/method/method_unexpected_field.rs:6:25 | 6 | #[method(name = "foo", magic = false)] | ^^^^^ diff --git a/server/src/middleware/rpc/layer/rpc_service.rs b/server/src/middleware/rpc/layer/rpc_service.rs index 061908801a..f5bae98ede 100644 --- a/server/src/middleware/rpc/layer/rpc_service.rs +++ b/server/src/middleware/rpc/layer/rpc_service.rs @@ -32,7 +32,7 @@ use std::sync::Arc; use crate::middleware::rpc::RpcServiceT; use futures_util::future::BoxFuture; use jsonrpsee_core::server::{ - BoundedSubscriptions, MethodCallback, MethodResponse, MethodSink, Methods, SubscriptionState, + BoundedSubscriptions, ConnectionDetails, MethodCallback, MethodResponse, MethodSink, Methods, SubscriptionState, }; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_types::error::{reject_too_many_subscriptions, ErrorCode}; @@ -94,6 +94,14 @@ impl<'a> RpcServiceT<'a> for RpcService { let fut = (callback)(id, params, conn_id, max_response_body_size); ResponseFuture::future(fut) } + MethodCallback::AsyncWithDetails(callback) => { + let params = params.into_owned(); + let id = id.into_owned(); + + // Note: Add the `Request::extensions` to the connection details when available here. + let fut = (callback)(id, params, ConnectionDetails::_new(conn_id), max_response_body_size); + ResponseFuture::future(fut) + } MethodCallback::Sync(callback) => { let rp = (callback)(id, params, max_response_body_size); ResponseFuture::ready(rp) diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 71765b9c76..88ed9ed799 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -137,6 +137,13 @@ pub async fn server() -> SocketAddr { let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _| "hello").unwrap(); + module + .register_async_method_with_details( + "raw_method", + |_, connection_details, _| async move { connection_details.id() }, + ) + .unwrap(); + module .register_async_method("slow_hello", |_, _| async { tokio::time::sleep(std::time::Duration::from_secs(1)).await; diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index b6991b6196..97a8739348 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -200,6 +200,25 @@ async fn ws_method_call_works_over_proxy_stream() { assert_eq!(&response, "hello"); } +#[tokio::test] +async fn raw_methods_with_different_ws_clients() { + init_logger(); + + let server_addr = server().await; + let server_url = format!("ws://{}", server_addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + // Connection ID does not change for the same client. + let connection_id: usize = client.request("raw_method", rpc_params![]).await.unwrap(); + let identical_connection_id: usize = client.request("raw_method", rpc_params![]).await.unwrap(); + assert_eq!(connection_id, identical_connection_id); + + // Connection ID is different for different clients. + let second_client = WsClientBuilder::default().build(&server_url).await.unwrap(); + let second_connection_id: usize = second_client.request("raw_method", rpc_params![]).await.unwrap(); + assert_ne!(connection_id, second_connection_id); +} + #[tokio::test] async fn ws_method_call_str_id_works() { init_logger(); diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index 182481d01e..d489d7a72d 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -49,6 +49,7 @@ mod rpc_impl { use jsonrpsee::core::{async_trait, SubscriptionResult}; use jsonrpsee::proc_macros::rpc; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; + use jsonrpsee::ConnectionDetails; pub struct CustomSubscriptionRet; @@ -74,6 +75,9 @@ mod rpc_impl { #[method(name = "bar")] fn sync_method(&self) -> Result; + #[method(name = "syncRaw", raw_method)] + async fn async_raw_method(&self) -> Result; + #[subscription(name = "sub", unsubscribe = "unsub", item = String)] async fn sub(&self) -> SubscriptionResult; #[subscription(name = "echo", unsubscribe = "unsubscribe_echo", aliases = ["alias_echo"], item = u32)] @@ -162,6 +166,10 @@ mod rpc_impl { Ok(10) } + async fn async_raw_method(&self, connection_details: ConnectionDetails) -> Result { + Ok(connection_details.id()) + } + async fn sub(&self, pending: PendingSubscriptionSink) -> SubscriptionResult { let sink = pending.accept().await?; sink.send("Response_A".into()).await?; @@ -239,6 +247,24 @@ async fn proc_macros_generic_ws_client_api() { assert_eq!(second_recv, 42); } +#[tokio::test] +async fn raw_methods_with_different_ws_clients() { + init_logger(); + + let server_addr = server().await; + let server_url = format!("ws://{}", server_addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + // Connection ID does not change for the same client. + let connection_id = client.async_raw_method().await.unwrap(); + assert_eq!(connection_id, client.async_raw_method().await.unwrap()); + + // Connection ID is different for different clients. + let second_client = WsClientBuilder::default().build(&server_url).await.unwrap(); + let second_connection_id = second_client.async_raw_method().await.unwrap(); + assert_ne!(connection_id, second_connection_id); +} + #[tokio::test] async fn macro_param_parsing() { let module = RpcServerImpl.into_rpc();