diff --git a/Cargo.lock b/Cargo.lock index 9811cb8cb1b5..d191338937ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4119,6 +4119,7 @@ name = "wasmtime-wasi-http" version = "0.0.1" dependencies = [ "anyhow", + "async-trait", "bytes", "http", "http-body", diff --git a/crates/wasi-http/Cargo.toml b/crates/wasi-http/Cargo.toml index e5b88e9f6ff3..858f7c80d6ef 100644 --- a/crates/wasi-http/Cargo.toml +++ b/crates/wasi-http/Cargo.toml @@ -10,6 +10,7 @@ readme = "readme.md" [dependencies] anyhow = { workspace = true } +async-trait = { workspace = true } bytes = "1.1.0" hyper = { version = "1.0.0-rc.3", features = ["full"] } tokio = { version = "1", default-features = false, features = ["net", "rt-multi-thread", "time"] } diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index 5b0e3862d909..1703933d206d 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -1,22 +1,15 @@ -use crate::r#struct::ActiveResponse; -use crate::r#struct::{Stream, WasiHttp}; +use crate::r#struct::{ActiveResponse, WasiHttp}; use crate::types::{RequestOptions, Scheme}; -#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] -use anyhow::anyhow; -use anyhow::bail; -use bytes::{BufMut, Bytes, BytesMut}; +use anyhow::{bail, Context}; +use bytes::{BufMut, BytesMut}; +use http::Uri; use http_body_util::{BodyExt, Full}; use hyper::Method; use hyper::Request; use std::collections::HashMap; -#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] -use std::sync::Arc; use std::time::Duration; -use tokio::net::TcpStream; use tokio::runtime::Runtime; use tokio::time::timeout; -#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] -use tokio_rustls::rustls::{self, OwnedTrustAnchor}; impl crate::default_outgoing_http::Host for WasiHttp { fn handle( @@ -42,39 +35,12 @@ impl crate::default_outgoing_http::Host for WasiHttp { } } -fn port_for_scheme(scheme: &Option) -> &str { - match scheme { - Some(s) => match s { - Scheme::Http => ":80", - Scheme::Https => ":443", - // This should never happen. - _ => panic!("unsupported scheme!"), - }, - None => ":443", - } -} - impl WasiHttp { async fn handle_async( &mut self, request_id: crate::default_outgoing_http::OutgoingRequest, options: Option, ) -> wasmtime::Result { - let opts = options.unwrap_or( - // TODO: Configurable defaults here? - RequestOptions { - connect_timeout_ms: Some(600 * 1000), - first_byte_timeout_ms: Some(600 * 1000), - between_bytes_timeout_ms: Some(600 * 1000), - }, - ); - let connect_timeout = - Duration::from_millis(opts.connect_timeout_ms.unwrap_or(600 * 1000).into()); - let first_bytes_timeout = - Duration::from_millis(opts.first_byte_timeout_ms.unwrap_or(600 * 1000).into()); - let between_bytes_timeout = - Duration::from_millis(opts.between_bytes_timeout_ms.unwrap_or(600 * 1000).into()); - let request = match self.requests.get(&request_id) { Some(r) => r, None => bail!("not found!"), @@ -92,104 +58,66 @@ impl WasiHttp { crate::types::Method::Patch => Method::PATCH, _ => bail!("unknown method!"), }; + let mut uri = Uri::builder() + .authority(request.authority.as_str()) + // NOTE: this is broken, but will be fixed by `wasi-http` dependency update + .path_and_query(request.path.to_owned() + &request.query); + match &request.scheme { + Some(Scheme::Http) => uri = uri.scheme("http"), + Some(Scheme::Https) => uri = uri.scheme("https"), + Some(scheme) => bail!("unsupported scheme `{scheme:?}`"), + _ => {} + } + // NOTE: This does not belong here, the complete struct should have been constructed + // on request creation + let uri = uri.build().context("failed to build URI")?; - let scheme = match request.scheme.as_ref().unwrap_or(&Scheme::Https) { - Scheme::Http => "http://", - Scheme::Https => "https://", - // TODO: this is wrong, fix this. - _ => panic!("Unsupported scheme!"), - }; + let mut req = Request::builder() + .method(method) + .uri(uri) + .header(hyper::header::HOST, &request.authority); + for (key, val) in request.headers.iter() { + for item in val { + req = req.header(key, item.clone()); + } + } + let body = self + .streams + .get(&request.body) + .map(|stream| stream.clone().into()) + .unwrap_or_default(); + let req = req.body(Full::new(body))?; - // Largely adapted from https://hyper.rs/guides/1/client/basic/ - let authority = match request.authority.find(":") { - Some(_) => request.authority.clone(), - None => request.authority.clone() + port_for_scheme(&request.scheme), + let connect_timeout = if let Some(RequestOptions { + connect_timeout_ms: Some(connect_timeout_ms), + .. + }) = options + { + Duration::from_millis(connect_timeout_ms.into()) + } else { + // TODO: Configurable default + Duration::from_millis(600) }; - let mut sender = if scheme == "https://" { - #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] - { - let stream = TcpStream::connect(authority.clone()).await?; - //TODO: uncomment this code and make the tls implementation a feature decision. - //let connector = tokio_native_tls::native_tls::TlsConnector::builder().build()?; - //let connector = tokio_native_tls::TlsConnector::from(connector); - //let host = authority.split(":").next().unwrap_or(&authority); - //let stream = connector.connect(&host, stream).await?; - - // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add_server_trust_anchors( - webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }), - ); - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = - rustls::ServerName::try_from(host).map_err(|_| anyhow!("invalid dnsname"))?; - let stream = connector.connect(domain, stream).await?; - let t = timeout( - connect_timeout, - hyper::client::conn::http1::handshake(stream), - ) - .await?; - let (s, conn) = t?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); - } - }); - s - } - #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] - bail!("unsupported architecture for SSL") + let first_byte_timeout = if let Some(RequestOptions { + first_byte_timeout_ms: Some(first_byte_timeout_ms), + .. + }) = options + { + Duration::from_millis(first_byte_timeout_ms.into()) } else { - let tcp = TcpStream::connect(authority).await?; - let t = timeout(connect_timeout, hyper::client::conn::http1::handshake(tcp)).await?; - let (s, conn) = t?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); - } - }); - s + // TODO: Configurable default + Duration::from_millis(600) }; - let url = scheme.to_owned() + &request.authority + &request.path + &request.query; - - let mut call = Request::builder() - .method(method) - .uri(url) - .header(hyper::header::HOST, request.authority.as_str()); - - for (key, val) in request.headers.iter() { - for item in val { - call = call.header(key, item.clone()); - } - } + let res = self + .outgoing_handler + .handle(req, connect_timeout, first_byte_timeout) + .await?; let response_id = self.response_id_base; self.response_id_base = self.response_id_base + 1; let mut response = ActiveResponse::new(response_id); - let body = Full::::new( - self.streams - .get(&request.body) - .unwrap_or(&Stream::default()) - .data - .clone() - .freeze(), - ); - let t = timeout(first_bytes_timeout, sender.send_request(call.body(body)?)).await?; - let mut res = t?; response.status = res.status().try_into()?; for (key, value) in res.headers().iter() { let mut vec = std::vec::Vec::new(); @@ -198,8 +126,21 @@ impl WasiHttp { .response_headers .insert(key.as_str().to_string(), vec); } + + let between_bytes_timeout = if let Some(RequestOptions { + between_bytes_timeout_ms: Some(between_bytes_timeout_ms), + .. + }) = options + { + Duration::from_millis(between_bytes_timeout_ms.into()) + } else { + // TODO: Configurable default + Duration::from_millis(600) + }; + let body = res.into_body(); + let mut body = body.lock().await; let mut buf = BytesMut::new(); - while let Some(next) = timeout(between_bytes_timeout, res.frame()).await? { + while let Some(next) = timeout(between_bytes_timeout, body.frame()).await? { let frame = next?; if let Some(chunk) = frame.data_ref() { buf.put(chunk.clone()); diff --git a/crates/wasi-http/src/struct.rs b/crates/wasi-http/src/struct.rs index 574be6c8e5fa..c1d5a3990688 100644 --- a/crates/wasi-http/src/struct.rs +++ b/crates/wasi-http/src/struct.rs @@ -1,6 +1,16 @@ use crate::types::{Method, Scheme}; -use bytes::{BufMut, Bytes, BytesMut}; + use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{bail, Context}; +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; +use http_body_util::Full; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio::time; #[derive(Clone, Default)] pub struct Stream { @@ -8,8 +18,15 @@ pub struct Stream { pub data: BytesMut, } +impl From for Bytes { + fn from(Stream { data, .. }: Stream) -> Self { + data.freeze() + } +} + #[derive(Clone)] -pub struct WasiHttp { +pub struct WasiHttp>> { + pub outgoing_handler: Arc>>, pub request_id_base: u32, pub response_id_base: u32, pub fields_id_base: u32, @@ -89,17 +106,126 @@ impl From for Stream { } } -impl WasiHttp { - pub fn new() -> Self { +#[async_trait] +pub trait OutgoingHandler: Sync + Send { + type Body; + + async fn handle( + &self, + request: http::Request>, + connect_timeout: Duration, + first_byte_timeout: Duration, + ) -> anyhow::Result>; +} + +/// Default [OutgoingHandler], which relies on Tokio and Hyper to handle both HTTP and HTTPS +/// requests. +pub struct DefaultOutgoingHandler; + +#[async_trait] +impl OutgoingHandler for DefaultOutgoingHandler { + type Body = Arc>; + + async fn handle( + &self, + request: http::Request>, + connect_timeout: Duration, + first_byte_timeout: Duration, + ) -> anyhow::Result> { + let uri = request.uri(); + let authority = uri.authority().context("unknown authority")?; + let stream = TcpStream::connect(authority.as_str()) + .await + .with_context(|| format!("failed to connect to `{authority}`"))?; + let mut sender = match uri.scheme_str() { + Some("http") => { + let (sender, conn) = time::timeout( + connect_timeout, + hyper::client::conn::http1::handshake(stream), + ) + .await + .context("connection timed out")? + .context("handshake failed")?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {err:?}"); + } + }); + sender + } + #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] + None | Some("https") => bail!("unsupported architecture for SSL"), + #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] + None | Some("https") => { + use tokio_rustls::rustls::{self, OwnedTrustAnchor}; + + //TODO: uncomment this code and make the tls implementation a feature decision. + //let connector = tokio_native_tls::native_tls::TlsConnector::builder().build()?; + //let connector = tokio_native_tls::TlsConnector::from(connector); + //let host = authority.split(":").next().unwrap_or(&authority); + //let stream = connector.connect(&host, stream).await?; + + // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }), + ); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); + let domain = + rustls::ServerName::try_from(authority.host()).context("invalid dnsname")?; + let stream = connector.connect(domain, stream).await?; + let (sender, conn) = time::timeout( + connect_timeout, + hyper::client::conn::http1::handshake(stream), + ) + .await + .context("connection timed out")? + .context("handshake failed")?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {err:?}"); + } + }); + sender + } + Some(scheme) => bail!("unsupported scheme `{scheme}`"), + }; + time::timeout(first_byte_timeout, sender.send_request(request)) + .await + .context("request timed out")? + .map(|res| res.map(|b| Arc::new(Mutex::new(b)))) + .context("failed to send request") + } +} + +impl Default for WasiHttp { + fn default() -> Self { Self { + outgoing_handler: Arc::new(Box::new(DefaultOutgoingHandler)), request_id_base: 1, response_id_base: 1, fields_id_base: 1, streams_id_base: 1, - requests: HashMap::new(), - responses: HashMap::new(), - fields: HashMap::new(), - streams: HashMap::new(), + requests: HashMap::default(), + responses: HashMap::default(), + fields: HashMap::default(), + streams: HashMap::default(), } } } + +impl WasiHttp { + pub fn new() -> Self { + Self::default() + } +}