diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 0b868776c..5282863d8 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -118,8 +118,13 @@ path = "src/tracing/server.rs" required-features = ["tracing"] [[bin]] -name = "uds-client" -path = "src/uds/client.rs" +name = "uds-client-standard" +path = "src/uds/client_standard.rs" +required-features = ["uds"] + +[[bin]] +name = "uds-client-with-connector" +path = "src/uds/client_with_connector.rs" required-features = ["uds"] [[bin]] diff --git a/examples/src/uds/client_standard.rs b/examples/src/uds/client_standard.rs new file mode 100644 index 000000000..264d41cfc --- /dev/null +++ b/examples/src/uds/client_standard.rs @@ -0,0 +1,34 @@ +#![cfg_attr(not(unix), allow(unused_imports))] + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +use hello_world::{greeter_client::GreeterClient, HelloRequest}; + +#[cfg(unix)] +#[tokio::main] +async fn main() -> Result<(), Box> { + // Unix socket URI follows [RFC-3986](https://datatracker.ietf.org/doc/html/rfc3986) + // which is aligned with [the gRPC naming convention](https://github.com/grpc/grpc/blob/master/doc/naming.md). + // - unix:relative_path + // - unix:///absolute_path + let path = "unix:///tmp/tonic/helloworld"; + + let mut client = GreeterClient::connect(path).await?; + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = client.say_hello(request).await?; + + println!("RESPONSE={:?}", response); + + Ok(()) +} + +#[cfg(not(unix))] +fn main() { + panic!("The `uds` example only works on unix systems!"); +} diff --git a/examples/src/uds/client.rs b/examples/src/uds/client_with_connector.rs similarity index 100% rename from examples/src/uds/client.rs rename to examples/src/uds/client_with_connector.rs diff --git a/tests/default_stubs/Cargo.toml b/tests/default_stubs/Cargo.toml index 9a786f116..2207e8fb2 100644 --- a/tests/default_stubs/Cargo.toml +++ b/tests/default_stubs/Cargo.toml @@ -8,6 +8,7 @@ name = "default_stubs" tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tokio-stream = {version = "0.1", features = ["net"]} prost = "0.13" +rand = "0.8" tonic = {path = "../../tonic"} [build-dependencies] diff --git a/tests/default_stubs/src/test_defaults.rs b/tests/default_stubs/src/test_defaults.rs index 32bed1be1..7b84d9e80 100644 --- a/tests/default_stubs/src/test_defaults.rs +++ b/tests/default_stubs/src/test_defaults.rs @@ -1,8 +1,13 @@ #![allow(unused_imports)] +use crate::test_client::TestClient; use crate::*; +use rand::Rng as _; +use std::env; +use std::fs; use std::net::SocketAddr; use tokio::net::TcpListener; +use tonic::transport::Channel; use tonic::transport::Server; #[cfg(test)] @@ -10,16 +15,14 @@ fn echo_requests_iter() -> impl Stream { tokio_stream::iter(1..usize::MAX).map(|_| ()) } -#[tokio::test()] -async fn test_default_stubs() { +#[cfg(test)] +async fn test_default_stubs( + mut client: TestClient, + mut client_default_stubs: TestClient, +) { use tonic::Code; - let addrs = run_services_in_background().await; - // First validate pre-existing functionality (trait has no default implementation, we explicitly return PermissionDenied in lib.rs). - let mut client = test_client::TestClient::connect(format!("http://{}", addrs.0)) - .await - .unwrap(); assert_eq!( client.unary(()).await.unwrap_err().code(), Code::PermissionDenied @@ -46,9 +49,6 @@ async fn test_default_stubs() { ); // Then validate opt-in new functionality (trait has default implementation of returning Unimplemented). - let mut client_default_stubs = test_client::TestClient::connect(format!("http://{}", addrs.1)) - .await - .unwrap(); assert_eq!( client_default_stubs.unary(()).await.unwrap_err().code(), Code::Unimplemented @@ -79,6 +79,27 @@ async fn test_default_stubs() { ); } +#[tokio::test()] +async fn test_default_stubs_tcp() { + let addrs = run_services_in_background().await; + let client = test_client::TestClient::connect(format!("http://{}", addrs.0)) + .await + .unwrap(); + let client_default_stubs = test_client::TestClient::connect(format!("http://{}", addrs.1)) + .await + .unwrap(); + test_default_stubs(client, client_default_stubs).await; +} + +#[tokio::test()] +#[cfg(not(target_os = "windows"))] +async fn test_default_stubs_uds() { + let addrs = run_services_in_background_uds().await; + let client = test_client::TestClient::connect(addrs.0).await.unwrap(); + let client_default_stubs = test_client::TestClient::connect(addrs.1).await.unwrap(); + test_default_stubs(client, client_default_stubs).await; +} + #[cfg(test)] async fn run_services_in_background() -> (SocketAddr, SocketAddr) { let svc = test_server::TestServer::new(Svc {}); @@ -110,3 +131,48 @@ async fn run_services_in_background() -> (SocketAddr, SocketAddr) { (addr, addr_default_stubs) } + +#[cfg(all(test, not(target_os = "windows")))] +async fn run_services_in_background_uds() -> (String, String) { + use tokio::net::UnixListener; + + let svc = test_server::TestServer::new(Svc {}); + let svc_default_stubs = test_default_server::TestDefaultServer::new(Svc {}); + + let mut rng = rand::thread_rng(); + let suffix: String = (0..8) + .map(|_| rng.sample(rand::distributions::Alphanumeric) as char) + .collect(); + let tmpdir = fs::canonicalize(env::temp_dir()) + .unwrap() + .join(format!("tonic_test_{}", suffix)); + fs::create_dir(&tmpdir).unwrap(); + + let uds_filepath = tmpdir.join("impl.sock").to_str().unwrap().to_string(); + let listener = UnixListener::bind(uds_filepath.as_str()).unwrap(); + let uds_addr = format!("unix://{}", uds_filepath); + + let uds_default_stubs_filepath = tmpdir.join("stub.sock").to_str().unwrap().to_string(); + let listener_default_stubs = UnixListener::bind(uds_default_stubs_filepath.as_str()).unwrap(); + let uds_default_stubs_addr = format!("unix://{}", uds_default_stubs_filepath); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::UnixListenerStream::new(listener)) + .await + .unwrap(); + }); + + tokio::spawn(async move { + Server::builder() + .add_service(svc_default_stubs) + .serve_with_incoming(tokio_stream::wrappers::UnixListenerStream::new( + listener_default_stubs, + )) + .await + .unwrap(); + }); + + (uds_addr, uds_default_stubs_addr) +} diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 16934f34f..3a919e93d 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -1,23 +1,33 @@ #[cfg(feature = "_tls-any")] use super::service::TlsConnector; use super::service::{self, Executor, SharedExec}; +use super::uds_connector::UdsConnector; use super::Channel; #[cfg(feature = "_tls-any")] use super::ClientTlsConfig; +#[cfg(feature = "_tls-any")] +use crate::transport::error; use crate::transport::Error; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use hyper::rt; use hyper_util::client::legacy::connect::HttpConnector; -use std::{fmt, future::Future, net::IpAddr, pin::Pin, str::FromStr, time::Duration}; +use std::{fmt, future::Future, net::IpAddr, pin::Pin, str, str::FromStr, time::Duration}; use tower_service::Service; +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum EndpointType { + Uri(Uri), + Uds(String), +} + /// Channel builder. /// /// This struct is used to build and configure HTTP/2 channels. #[derive(Clone)] pub struct Endpoint { - pub(crate) uri: Uri, + pub(crate) uri: EndpointType, + fallback_uri: Uri, pub(crate) origin: Option, pub(crate) user_agent: Option, pub(crate) timeout: Option, @@ -51,13 +61,68 @@ impl Endpoint { { let me = dst.try_into().map_err(|e| Error::from_source(e.into()))?; #[cfg(feature = "_tls-any")] - if me.uri.scheme() == Some(&http::uri::Scheme::HTTPS) { - return me.tls_config(ClientTlsConfig::new().with_enabled_roots()); + if let EndpointType::Uri(uri) = &me.uri { + if uri.scheme() == Some(&http::uri::Scheme::HTTPS) { + return me.tls_config(ClientTlsConfig::new().with_enabled_roots()); + } } - Ok(me) } + fn new_uri(uri: Uri) -> Self { + Self { + uri: EndpointType::Uri(uri.clone()), + fallback_uri: uri, + origin: None, + user_agent: None, + concurrency_limit: None, + rate_limit: None, + timeout: None, + #[cfg(feature = "_tls-any")] + tls: None, + buffer_size: None, + init_stream_window_size: None, + init_connection_window_size: None, + tcp_keepalive: None, + tcp_nodelay: true, + http2_keep_alive_interval: None, + http2_keep_alive_timeout: None, + http2_keep_alive_while_idle: None, + http2_max_header_list_size: None, + connect_timeout: None, + http2_adaptive_window: None, + executor: SharedExec::tokio(), + local_address: None, + } + } + + fn new_uds(uds_filepath: &str) -> Self { + Self { + uri: EndpointType::Uds(uds_filepath.to_string()), + fallback_uri: Uri::from_static("http://tonic"), + origin: None, + user_agent: None, + concurrency_limit: None, + rate_limit: None, + timeout: None, + #[cfg(feature = "_tls-any")] + tls: None, + buffer_size: None, + init_stream_window_size: None, + init_connection_window_size: None, + tcp_keepalive: None, + tcp_nodelay: true, + http2_keep_alive_interval: None, + http2_keep_alive_timeout: None, + http2_keep_alive_while_idle: None, + http2_max_header_list_size: None, + connect_timeout: None, + http2_adaptive_window: None, + executor: SharedExec::tokio(), + local_address: None, + } + } + /// Convert an `Endpoint` from a static string. /// /// # Panics @@ -69,8 +134,16 @@ impl Endpoint { /// Endpoint::from_static("https://example.com"); /// ``` pub fn from_static(s: &'static str) -> Self { - let uri = Uri::from_static(s); - Self::from(uri) + if s.starts_with("unix:") { + let uds_filepath = s + .strip_prefix("unix://") + .or_else(|| s.strip_prefix("unix:")) + .expect("Invalid unix domain socket URI"); + Self::new_uds(uds_filepath) + } else { + let uri = Uri::from_static(s); + Self::new_uri(uri) + } } /// Convert an `Endpoint` from shared bytes. @@ -80,8 +153,19 @@ impl Endpoint { /// Endpoint::from_shared("https://example.com".to_string()); /// ``` pub fn from_shared(s: impl Into) -> Result { - let uri = Uri::from_maybe_shared(s.into()).map_err(|e| Error::new_invalid_uri().with(e))?; - Ok(Self::from(uri)) + let s = str::from_utf8(&s.into()) + .map_err(|e| Error::new_invalid_uri().with(e))? + .to_string(); + if s.starts_with("unix:") { + let uds_filepath = s + .strip_prefix("unix://") + .or_else(|| s.strip_prefix("unix:")) + .ok_or(Error::new_invalid_uri())?; + Ok(Self::new_uds(uds_filepath)) + } else { + let uri = Uri::from_maybe_shared(s).map_err(|e| Error::new_invalid_uri().with(e))?; + Ok(Self::from(uri)) + } } /// Set a custom user-agent header. @@ -247,14 +331,17 @@ impl Endpoint { /// Configures TLS for the endpoint. #[cfg(feature = "_tls-any")] pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result { - Ok(Endpoint { - tls: Some( - tls_config - .into_tls_connector(&self.uri) - .map_err(Error::from_source)?, - ), - ..self - }) + match &self.uri { + EndpointType::Uri(uri) => Ok(Endpoint { + tls: Some( + tls_config + .into_tls_connector(uri) + .map_err(Error::from_source)?, + ), + ..self + }), + EndpointType::Uds(_) => Err(Error::new(error::Kind::InvalidTlsConfigForUds)), + } } /// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default. @@ -346,9 +433,18 @@ impl Endpoint { self.connector(http) } + pub(crate) fn uds_connector(&self, uds_filepath: &str) -> service::Connector { + self.connector(UdsConnector::new(uds_filepath)) + } + /// Create a channel from this config. pub async fn connect(&self) -> Result { - Channel::connect(self.http_connector(), self.clone()).await + match &self.uri { + EndpointType::Uri(_) => Channel::connect(self.http_connector(), self.clone()).await, + EndpointType::Uds(uds_filepath) => { + Channel::connect(self.uds_connector(uds_filepath.as_str()), self.clone()).await + } + } } /// Create a channel from this config. @@ -356,7 +452,12 @@ impl Endpoint { /// The channel returned by this method does not attempt to connect to the endpoint until first /// use. pub fn connect_lazy(&self) -> Channel { - Channel::new(self.http_connector(), self.clone()) + match &self.uri { + EndpointType::Uri(_) => Channel::new(self.http_connector(), self.clone()), + EndpointType::Uds(uds_filepath) => { + Channel::new(self.uds_connector(uds_filepath.as_str()), self.clone()) + } + } } /// Connect with a custom connector. @@ -418,7 +519,10 @@ impl Endpoint { /// assert_eq!(endpoint.uri(), &Uri::from_static("https://example.com")); /// ``` pub fn uri(&self) -> &Uri { - &self.uri + match &self.uri { + EndpointType::Uri(uri) => uri, + EndpointType::Uds(_) => &self.fallback_uri, + } } /// Get the value of `TCP_NODELAY` option for accepted connections. @@ -443,29 +547,7 @@ impl Endpoint { impl From for Endpoint { fn from(uri: Uri) -> Self { - Self { - uri, - origin: None, - user_agent: None, - concurrency_limit: None, - rate_limit: None, - timeout: None, - #[cfg(feature = "_tls-any")] - tls: None, - buffer_size: None, - init_stream_window_size: None, - init_connection_window_size: None, - tcp_keepalive: None, - tcp_nodelay: true, - http2_keep_alive_interval: None, - http2_keep_alive_timeout: None, - http2_keep_alive_while_idle: None, - http2_max_header_list_size: None, - connect_timeout: None, - http2_adaptive_window: None, - executor: SharedExec::tokio(), - local_address: None, - } + Self::new_uri(uri) } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 85f1ee51c..fe8458fab 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -4,6 +4,7 @@ mod endpoint; pub(crate) mod service; #[cfg(feature = "_tls-any")] mod tls; +mod uds_connector; pub use self::service::Change; pub use endpoint::Endpoint; diff --git a/tonic/src/transport/channel/service/connection.rs b/tonic/src/transport/channel/service/connection.rs index 4e84ac92e..c4ce9408e 100644 --- a/tonic/src/transport/channel/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -57,7 +57,7 @@ impl Connection { let stack = ServiceBuilder::new() .layer_fn(|s| { - let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone(); + let origin = endpoint.origin.as_ref().unwrap_or(endpoint.uri()).clone(); AddOrigin::new(s, origin) }) @@ -70,7 +70,7 @@ impl Connection { let make_service = MakeSendRequestService::new(connector, endpoint.executor.clone(), settings); - let conn = Reconnect::new(make_service, endpoint.uri.clone(), is_lazy); + let conn = Reconnect::new(make_service, endpoint.uri().clone(), is_lazy); Self { inner: BoxService::new(stack.layer(conn)), diff --git a/tonic/src/transport/channel/uds_connector.rs b/tonic/src/transport/channel/uds_connector.rs new file mode 100644 index 000000000..a67c4a47c --- /dev/null +++ b/tonic/src/transport/channel/uds_connector.rs @@ -0,0 +1,80 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use http::Uri; +use hyper_util::rt::TokioIo; + +use tower::Service; + +use crate::status::ConnectError; + +#[cfg(not(target_os = "windows"))] +use tokio::net::UnixStream; + +#[cfg(not(target_os = "windows"))] +async fn connect_uds(uds_path: String) -> Result { + UnixStream::connect(uds_path) + .await + .map_err(|err| ConnectError(From::from(err))) +} + +// Dummy type that will allow us to compile and match trait bounds +// but is never used. +#[cfg(target_os = "windows")] +#[allow(dead_code)] +type UnixStream = tokio::io::DuplexStream; + +#[cfg(target_os = "windows")] +async fn connect_uds(_uds_path: String) -> Result { + Err(ConnectError( + "uds connections are not allowed on windows".into(), + )) +} + +pub(crate) struct UdsConnector { + uds_filepath: String, +} + +impl UdsConnector { + pub(crate) fn new(uds_filepath: &str) -> Self { + UdsConnector { + uds_filepath: uds_filepath.to_string(), + } + } +} + +impl Service for UdsConnector { + type Response = TokioIo; + type Error = ConnectError; + type Future = UdsConnecting; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Uri) -> Self::Future { + let uds_path = self.uds_filepath.clone(); + let fut = async move { + let stream = connect_uds(uds_path).await?; + Ok(TokioIo::new(stream)) + }; + UdsConnecting { + inner: Box::pin(fut), + } + } +} + +type ConnectResult = Result, ConnectError>; + +pub(crate) struct UdsConnecting { + inner: Pin + Send>>, +} + +impl Future for UdsConnecting { + type Output = ConnectResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.get_mut().inner.as_mut().poll(cx) + } +} diff --git a/tonic/src/transport/error.rs b/tonic/src/transport/error.rs index cdc7a6c54..31b317521 100644 --- a/tonic/src/transport/error.rs +++ b/tonic/src/transport/error.rs @@ -19,6 +19,8 @@ pub(crate) enum Kind { InvalidUri, #[cfg(feature = "channel")] InvalidUserAgent, + #[cfg(all(feature = "_tls-any", feature = "channel"))] + InvalidTlsConfigForUds, } impl Error { @@ -54,6 +56,8 @@ impl Error { Kind::InvalidUri => "invalid URI", #[cfg(feature = "channel")] Kind::InvalidUserAgent => "user agent is not a valid header value", + #[cfg(all(feature = "_tls-any", feature = "channel"))] + Kind::InvalidTlsConfigForUds => "cannot apply TLS config for unix domain socket", } } }