diff --git a/Cargo.toml b/Cargo.toml index 95dc2dd..759a39e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "redis-async" -version = "0.8.1" +version = "0.9.0" authors = ["Ben Ashford "] license = "MIT/Apache-2.0" readme = "README.md" @@ -10,30 +10,35 @@ keywords = ["redis", "tokio"] edition = "2018" [dependencies] +async-global-executor = { version = "1.4", optional = true } +async-net = { version = "1.5", optional = true } bytes_05 = { package = "bytes", version = "0.5", optional = true } -bytes_06 = { package = "bytes", version = "^0.6.0", optional = true } bytes_10 = { package = "bytes", version = "1.0", optional = true } log = "^0.4.11" +lwactors = "0.2" futures-channel = "^0.3.7" futures-sink = "^0.3.7" -futures-util = "^0.3.7" +futures-util = { version = "^0.3.7", features = ["sink"] } +thiserror = "1.0" tokio_02 = { package = "tokio", version = "0.2", features = ["rt-core", "net", "time"], optional = true} -tokio_03 = { package = "tokio", version = "^0.3.2", features = ["rt", "net", "time"], optional = true } tokio_10 = { package = "tokio", version = "1.0", features = ["rt", "net", "time"], optional = true } tokio-util_03 = { package = "tokio-util", version = "0.3", features = ["codec"], optional = true } -tokio-util_05 = { package = "tokio-util", version = "^0.5", features = ["codec"], optional = true } tokio-util_06 = { package = "tokio-util", version = "0.6", features = ["codec"], optional = true } [dev-dependencies] env_logger = "^0.8.1" futures = "^0.3.7" +async-std = { version = "1.8", features = ["attributes"] } tokio_02 = { package = "tokio", version = "0.2", features = ["full"] } -tokio_03 = { package = "tokio", version = "^0.3.2", features = ["full"] } tokio_10 = { package = "tokio", version = "1.0", features = ["full"] } [features] -default = ["tokio03"] +default = ["tokio10"] -tokio02 = ["bytes_05", "tokio_02", "tokio-util_03"] -tokio03 = ["bytes_06", "tokio_03", "tokio-util_05"] -tokio10 = ["bytes_10", "tokio_10", "tokio-util_06"] \ No newline at end of file +async-std18 = ["bytes_10", "async-net", "async-global-executor", "with_async_std"] +tokio02 = ["bytes_05", "tokio_02", "tokio-util_03", "tokio_codec", "lwactors/with_tokio02", "with_tokio"] +tokio10 = ["bytes_10", "tokio_10", "tokio-util_06", "tokio_codec", "lwactors/with_tokio10", "with_tokio"] + +tokio_codec = [] +with_async_std = ["lwactors/with_async_global_executor14"] +with_tokio = ["lwactors/tokio10"] \ No newline at end of file diff --git a/examples/monitor.rs b/examples/monitor.rs index 3256d56..8769e6c 100644 --- a/examples/monitor.rs +++ b/examples/monitor.rs @@ -11,9 +11,6 @@ #[cfg(feature = "tokio02")] extern crate tokio_02 as tokio; -#[cfg(feature = "tokio03")] -extern crate tokio_03 as tokio; - #[cfg(feature = "tokio10")] extern crate tokio_10 as tokio; @@ -23,8 +20,19 @@ use futures::{sink::SinkExt, stream::StreamExt}; use redis_async::{client, resp_array}; +#[cfg(feature = "with_tokio")] #[tokio::main] async fn main() { + do_main().await; +} + +#[cfg(feature = "with_async_std")] +#[async_std::main] +async fn main() { + do_main().await; +} + +async fn do_main() { let addr = env::args() .nth(1) .unwrap_or_else(|| "127.0.0.1:6379".to_string()) diff --git a/examples/psubscribe.rs b/examples/psubscribe.rs index f8c47a7..10f1c2f 100644 --- a/examples/psubscribe.rs +++ b/examples/psubscribe.rs @@ -11,9 +11,6 @@ #[cfg(feature = "tokio02")] extern crate tokio_02 as tokio; -#[cfg(feature = "tokio03")] -extern crate tokio_03 as tokio; - #[cfg(feature = "tokio10")] extern crate tokio_10 as tokio; @@ -21,21 +18,34 @@ use std::env; use futures::StreamExt; -use redis_async::{client, resp::FromResp}; +use redis_async::{client::ConnectionBuilder, protocol::FromResp}; +#[cfg(feature = "with_tokio")] #[tokio::main] async fn main() { + do_main().await; +} + +#[cfg(feature = "with_async_std")] +#[async_std::main] +async fn main() { + do_main().await; +} + +async fn do_main() { env_logger::init(); let topic = env::args().nth(1).unwrap_or_else(|| "test.*".to_string()); + let addr = env::args() .nth(2) - .unwrap_or_else(|| "127.0.0.1:6379".to_string()) - .parse() - .unwrap(); + .unwrap_or_else(|| "127.0.0.1:6379".to_string()); - let pubsub_con = client::pubsub_connect(addr) + let pubsub_con = ConnectionBuilder::new(addr) + .expect("Cannot parse address") + .pubsub_connect() .await - .expect("Cannot connect to Redis"); + .expect("Cannot open connection"); + let mut msgs = pubsub_con .psubscribe(&topic) .await diff --git a/examples/realistic.rs b/examples/realistic.rs index 6a799f6..4d3f6e0 100644 --- a/examples/realistic.rs +++ b/examples/realistic.rs @@ -11,9 +11,6 @@ #[cfg(feature = "tokio02")] extern crate tokio_02 as tokio; -#[cfg(feature = "tokio03")] -extern crate tokio_03 as tokio; - #[cfg(feature = "tokio10")] extern crate tokio_10 as tokio; @@ -23,22 +20,34 @@ use futures_util::future; // use futures::{future, Future}; -use redis_async::{client, resp_array}; +use redis_async::{client::ConnectionBuilder, resp_array}; -/// An artificial "realistic" non-trivial example to demonstrate usage +#[cfg(feature = "with_tokio")] #[tokio::main] async fn main() { + do_main().await; +} + +#[cfg(feature = "with_async_std")] +#[async_std::main] +async fn main() { + do_main().await; +} + +/// An artificial "realistic" non-trivial example to demonstrate usage +async fn do_main() { // Create some completely arbitrary "test data" let test_data_size = 10; let test_data: Vec<_> = (0..test_data_size).map(|x| (x, x.to_string())).collect(); let addr = env::args() .nth(1) - .unwrap_or_else(|| "127.0.0.1:6379".to_string()) - .parse() - .unwrap(); + .unwrap_or_else(|| "127.0.0.1:6379".to_string()); + + let connection_builder = ConnectionBuilder::new(addr).expect("Cannot parse address"); - let connection = client::paired_connect(addr) + let connection = connection_builder + .paired_connect() .await .expect("Cannot open connection"); @@ -46,7 +55,7 @@ async fn main() { let connection_inner = connection.clone(); let incr_f = connection.send(resp_array!["INCR", "realistic_test_ctr"]); async move { - let ctr: String = incr_f.await.expect("Cannot increment"); + let ctr: i64 = incr_f.await.expect("Cannot increment"); let key = format!("rt_{}", ctr); let d_val = data.0.to_string(); diff --git a/examples/subscribe.rs b/examples/subscribe.rs index 602b5d3..5f85318 100644 --- a/examples/subscribe.rs +++ b/examples/subscribe.rs @@ -11,33 +11,44 @@ #[cfg(feature = "tokio02")] extern crate tokio_02 as tokio; -#[cfg(feature = "tokio03")] -extern crate tokio_03 as tokio; - #[cfg(feature = "tokio10")] extern crate tokio_10 as tokio; use std::env; +use client::ConnectionBuilder; use futures::StreamExt; -use redis_async::{client, resp::FromResp}; +use redis_async::{client, protocol::FromResp}; +#[cfg(feature = "with_tokio")] #[tokio::main] async fn main() { + do_main().await; +} + +#[cfg(feature = "with_async_std")] +#[async_std::main] +async fn main() { + do_main().await; +} + +async fn do_main() { env_logger::init(); let topic = env::args() .nth(1) .unwrap_or_else(|| "test-topic".to_string()); + let addr = env::args() .nth(2) - .unwrap_or_else(|| "127.0.0.1:6379".to_string()) - .parse() - .unwrap(); + .unwrap_or_else(|| "127.0.0.1:6379".to_string()); - let pubsub_con = client::pubsub_connect(addr) + let pubsub_con = ConnectionBuilder::new(addr) + .expect("Cannot parse address") + .pubsub_connect() .await - .expect("Cannot connect to Redis"); + .expect("Cannot open connection"); + let mut msgs = pubsub_con .subscribe(&topic) .await diff --git a/src/client/connect/async_std.rs b/src/client/connect/async_std.rs new file mode 100644 index 0000000..a8e6b84 --- /dev/null +++ b/src/client/connect/async_std.rs @@ -0,0 +1,166 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +//! Experimental support for a non-Tokio runtime. This hasn't been tested as much as Tokio, so +//! should be considered an unstable feature for the time being. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use async_net::TcpStream; + +use bytes::{Buf, BytesMut}; + +use futures_sink::Sink; +use futures_util::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, +}; + +use crate::{ + error::Error, + protocol::{ + codec::{decode, encode}, + resp::RespValue, + }, +}; + +const TCP_PACKET_SIZE: usize = 1500; +const DEFAULT_BUF_LEN: usize = TCP_PACKET_SIZE; +const MAX_PACKETS: usize = 100; +const MAX_BUF_LEN: usize = TCP_PACKET_SIZE * MAX_PACKETS; +const BUF_INC_STEP: usize = TCP_PACKET_SIZE * 4; + +pub struct RespTcpStream { + tcp_stream: TcpStream, + out_buf: BytesMut, + in_buf: BytesMut, +} + +impl RespTcpStream { + pub(crate) fn new(tcp_stream: TcpStream) -> Self { + RespTcpStream { + tcp_stream, + out_buf: BytesMut::with_capacity(DEFAULT_BUF_LEN), + in_buf: BytesMut::with_capacity(DEFAULT_BUF_LEN), + } + } +} + +impl RespTcpStream { + fn attempt_push(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match Pin::new(&mut self.tcp_stream).poll_write(cx, &self.out_buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => return Poll::Ready(Ok(())), + Poll::Ready(Ok(bytes_written)) => { + self.out_buf.advance(bytes_written); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + } + } + } + + fn pull_into_buffer(&mut self, cx: &mut Context<'_>) -> Poll> { + self.in_buf.reserve(BUF_INC_STEP); + let mut old_len = self.in_buf.len(); + let new_len = old_len + BUF_INC_STEP; + unsafe { + self.in_buf.set_len(new_len); + } + let result = match Pin::new(&mut self.tcp_stream) + .poll_read(cx, &mut self.in_buf[old_len..new_len]) + { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(bytes_read)) => { + old_len += bytes_read; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + }; + unsafe { + self.in_buf.set_len(old_len); + } + result + } +} + +impl Sink for RespTcpStream { + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut_self = self.get_mut(); + + if mut_self.out_buf.len() == 0 { + return Poll::Ready(Ok(())); + } + + if let Poll::Ready(Err(e)) = mut_self.attempt_push(cx) { + return Poll::Ready(Err(e)); + } + + if mut_self.out_buf.len() >= MAX_BUF_LEN { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: RespValue) -> Result<(), Self::Error> { + let mut_self = self.get_mut(); + encode(item, &mut mut_self.out_buf); + + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().attempt_push(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut_self = self.get_mut(); + while mut_self.out_buf.len() > 0 { + match mut_self.attempt_push(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(())) => (), + } + } + + match Pin::new(&mut mut_self.tcp_stream).poll_close(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + } + } +} + +impl Stream for RespTcpStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut_self = self.get_mut(); + loop { + // Result, Error> + match decode(&mut mut_self.in_buf, 0) { + Ok(Some((pos, thing))) => { + mut_self.in_buf.advance(pos); + return Poll::Ready(Some(Ok(thing))); + } + Ok(None) => match mut_self.pull_into_buffer(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + }, + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + } +} diff --git a/src/client/connect.rs b/src/client/connect/mod.rs similarity index 81% rename from src/client/connect.rs rename to src/client/connect/mod.rs index 671683d..8eca6d0 100644 --- a/src/client/connect.rs +++ b/src/client/connect/mod.rs @@ -8,16 +8,29 @@ * except according to those terms. */ +#[cfg(feature = "with_async_std")] +mod async_std; + use std::net::SocketAddr; +#[cfg(feature = "with_async_std")] +use async_net::TcpStream; + use futures_util::{SinkExt, StreamExt}; +#[cfg(feature = "with_tokio")] use tokio::net::TcpStream; +#[cfg(feature = "with_tokio")] use tokio_util::codec::{Decoder, Framed}; -use crate::{error, resp}; +#[cfg(feature = "tokio_codec")] +use crate::protocol::RespCodec; +use crate::{error, protocol::FromResp}; -pub type RespConnection = Framed; +#[cfg(feature = "with_tokio")] +pub type RespConnection = Framed; +#[cfg(feature = "with_async_std")] +pub type RespConnection = async_std::RespTcpStream; /// Connect to a Redis server and return a Future that resolves to a /// `RespConnection` for reading and writing asynchronously. @@ -34,11 +47,19 @@ pub type RespConnection = Framed; /// /// But since most Redis usages involve issue commands that result in one /// single result, this library also implements `paired_connect`. +#[cfg(feature = "with_tokio")] +pub async fn connect(addr: &SocketAddr) -> Result { + let tcp_stream = TcpStream::connect(addr).await?; + Ok(RespCodec.framed(tcp_stream)) +} + +#[cfg(feature = "with_async_std")] pub async fn connect(addr: &SocketAddr) -> Result { let tcp_stream = TcpStream::connect(addr).await?; - Ok(resp::RespCodec.framed(tcp_stream)) + Ok(RespConnection::new(tcp_stream)) } +/// Connect with optional authentication pub async fn connect_with_auth( addr: &SocketAddr, username: Option<&str>, @@ -57,7 +78,7 @@ pub async fn connect_with_auth( connection.send(auth).await?; match connection.next().await { - Some(Ok(value)) => match resp::FromResp::from_resp(value) { + Some(Ok(value)) => match FromResp::from_resp(value) { Ok(()) => (), Err(e) => return Err(e), }, @@ -80,7 +101,7 @@ mod test { stream::{self, StreamExt}, }; - use crate::resp; + use crate::protocol::resp; #[tokio::test] async fn can_connect() { diff --git a/src/client/mod.rs b/src/client/mod.rs index 68fc214..31bafa9 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -20,15 +20,13 @@ //! in one response. //! * `pubsub_connect` is used for Redis's PUBSUB functionality. -pub mod connect; -#[macro_use] -pub mod paired; mod builder; -pub mod pubsub; +pub mod connect; +pub(crate) mod paired; +pub(crate) mod pubsub; +mod reconnect; pub use self::{ - builder::ConnectionBuilder, - connect::connect, - paired::{paired_connect, PairedConnection}, - pubsub::{pubsub_connect, PubsubConnection}, + builder::ConnectionBuilder, connect::connect, paired::PairedConnection, + pubsub::PubsubConnection, }; diff --git a/src/client/paired.rs b/src/client/paired.rs index 623ac9e..a0e7698 100644 --- a/src/client/paired.rs +++ b/src/client/paired.rs @@ -19,20 +19,17 @@ use std::task::{Context, Poll}; use futures_channel::{mpsc, oneshot}; use futures_sink::Sink; use futures_util::{ - future::{self, TryFutureExt}, + future::{self, Either}, stream::StreamExt, }; use super::{ connect::{connect_with_auth, RespConnection}, + reconnect::{ActionWork, Reconnectable, ReconnectableActions, ReconnectableConnectionFuture}, ConnectionBuilder, }; -use crate::{ - error, - reconnect::{reconnect, Reconnect}, - resp, -}; +use crate::{error, protocol::resp, task::spawn}; /// The state of sending messages to a Redis server enum SendStatus { @@ -56,18 +53,21 @@ enum ReceiveStatus { NotReady, } -type Responder = oneshot::Sender; -type SendPayload = (resp::RespValue, Responder); +#[derive(Debug)] +enum SendPayload { + One(resp::RespValue, oneshot::Sender), + Batch(Vec<(resp::RespValue, oneshot::Sender)>), +} -// /// The PairedConnectionInner is a spawned future that is responsible for pairing commands and -// /// results onto a `RespConnection` that is otherwise unpaired +/// The PairedConnectionInner is a spawned future that is responsible for pairing commands and +/// results onto a `RespConnection` that is otherwise unpaired struct PairedConnectionInner { /// The underlying connection that talks the RESP protocol connection: RespConnection, /// The channel upon which commands are received - out_rx: mpsc::UnboundedReceiver, + out_rx: mpsc::UnboundedReceiver<(resp::RespValue, oneshot::Sender)>, /// The queue of waiting oneshot's for commands sent but results not yet received - waiting: VecDeque, + waiting: VecDeque>, /// The status of the underlying connection send_status: SendStatus, @@ -193,23 +193,85 @@ impl Future for PairedConnectionInner { } } +impl ActionWork for SendPayload { + type ConnectionType = + mpsc::UnboundedSender<(resp::RespValue, oneshot::Sender)>; + + fn call(self, con: &Self::ConnectionType) -> Result<(), error::Error> { + match self { + SendPayload::One(value, receiver) => { + con.unbounded_send((value, receiver))?; + } + SendPayload::Batch(batches) => { + for (value, receiver) in batches { + con.unbounded_send((value, receiver))?; + } + } + } + + Ok(()) + } +} + +#[derive(Debug)] +struct PairedConnectionActions { + addr: SocketAddr, + username: Option>, + password: Option>, +} + +impl ReconnectableActions for PairedConnectionActions { + type WorkPayload = SendPayload; + + fn do_connection( + &self, + ) -> ReconnectableConnectionFuture< + mpsc::UnboundedSender<(resp::RespValue, oneshot::Sender)>, + error::Error, + > { + let con_f = inner_conn_fn(self.addr, self.username.clone(), self.password.clone()); + Box::pin(con_f) + } +} + /// A shareable and cheaply cloneable connection to which Redis commands can be sent #[derive(Debug, Clone)] pub struct PairedConnection { - out_tx_c: Arc>>, + out_tx_c: Arc>, +} + +impl PairedConnection { + async fn init( + addr: SocketAddr, + username: Option>, + password: Option>, + ) -> Result { + Ok(PairedConnection { + out_tx_c: Arc::new( + Reconnectable::init(PairedConnectionActions { + addr, + username, + password, + }) + .await?, + ), + }) + } } async fn inner_conn_fn( addr: SocketAddr, username: Option>, password: Option>, -) -> Result, error::Error> { +) -> Result)>, error::Error> +{ let username = username.as_ref().map(|u| u.as_ref()); let password = password.as_ref().map(|p| p.as_ref()); let connection = connect_with_auth(&addr, username, password).await?; let (out_tx, out_rx) = mpsc::unbounded(); let paired_connection_inner = PairedConnectionInner::new(connection, out_rx); - tokio::spawn(paired_connection_inner); + spawn(paired_connection_inner); + Ok(out_tx) } @@ -219,37 +281,10 @@ impl ConnectionBuilder { let username = self.username.clone(); let password = self.password.clone(); - let work_fn = |con: &mpsc::UnboundedSender, act| { - con.unbounded_send(act).map_err(|e| e.into()) - }; - - let conn_fn = move || { - let con_f = inner_conn_fn(addr, username.clone(), password.clone()); - Box::pin(con_f) as Pin> + Send + Sync>> - }; - - let reconnecting_con = reconnect(work_fn, conn_fn); - reconnecting_con.map_ok(|con| PairedConnection { - out_tx_c: Arc::new(con), - }) + PairedConnection::init(addr, username, password) } } -/// The default starting point to use most default Redis functionality. -/// -/// Returns a future that resolves to a `PairedConnection`. The future will complete when the -/// initial connection is established. -/// -/// Once the initial connection is established, the connection will attempt to reconnect should -/// the connection be broken (e.g. the Redis server being restarted), but reconnections occur -/// asynchronously, so all commands issued while the connection is unavailable will error, it is -/// the client's responsibility to retry commands as applicable. Also, at least one command needs -/// to be tried against the connection to trigger the re-connection attempt; this means at least -/// one command will definitely fail in a disconnect/reconnect scenario. -pub async fn paired_connect(addr: SocketAddr) -> Result { - ConnectionBuilder::new(addr)?.paired_connect().await -} - impl PairedConnection { /// Sends a command to Redis. /// @@ -263,41 +298,71 @@ impl PairedConnection { /// Behind the scenes the message is queued up and sent to Redis asynchronously before the /// future is realised. As such, it is guaranteed that messages are sent in the same order /// that `send` is called. - pub fn send(&self, msg: resp::RespValue) -> impl Future> + pub fn send<'a, T>( + &'a self, + msg: resp::RespValue, + ) -> impl Future> + 'a where - T: resp::FromResp, + T: resp::FromResp + 'a, { match &msg { resp::RespValue::Array(_) => (), _ => { - return future::Either::Right(future::ready(Err(error::internal( + return Either::Left(future::err(error::internal( "Command must be a RespValue::Array", - )))); + ))); } } let (tx, rx) = oneshot::channel(); - match self.out_tx_c.do_work((msg, tx)) { - Ok(()) => future::Either::Left(async move { - match rx.await { - Ok(v) => Ok(T::from_resp(v)?), - Err(_) => Err(error::internal( - "Connection closed before response received", - )), + let work_f = self.out_tx_c.do_work(SendPayload::One(msg, tx)); + + Either::Right(async { + let _ = work_f.await?; + match rx.await { + Ok(v) => Ok(T::from_resp(v)?), + Err(_) => Err(error::internal( + "Connection closed before response received", + )), + } + }) + } + + pub fn send_batch( + &self, + msgs: Vec, + ) -> impl Future, error::Error>> + '_ { + let batch_size = msgs.len(); + let mut work = Vec::with_capacity(batch_size); + let mut receivers = Vec::with_capacity(batch_size); + + for msg in msgs { + let (tx, rx) = oneshot::channel(); + work.push((msg, tx)); + receivers.push(rx); + } + + let work_f = self.out_tx_c.do_work(SendPayload::Batch(work)); + + async move { + let _ = work_f.await?; + let mut results = Vec::with_capacity(batch_size); + for receiver in receivers { + match receiver.await { + Ok(v) => results.push(v), + Err(_) => { + return Err(error::internal( + "Connection closed before response received", + )) + } } - }), - Err(e) => future::Either::Right(future::ready(Err(e))), + } + Ok(results) } } pub fn send_and_forget(&self, msg: resp::RespValue) { - let send_f = self.send::(msg); - let forget_f = async { - if let Err(e) = send_f.await { - log::error!("Error in send_and_forget: {}", e); - } - }; - tokio::spawn(forget_f); + let _ = self.send::(msg); } } @@ -307,9 +372,9 @@ mod test { #[tokio::test] async fn can_paired_connect() { - let addr = "127.0.0.1:6379".parse().unwrap(); - - let connection = super::paired_connect(addr) + let connection = ConnectionBuilder::new("127.0.0.1:6379") + .expect("Cannot build builder") + .paired_connect() .await .expect("Cannot establish connection"); @@ -326,18 +391,18 @@ mod test { #[tokio::test] async fn complex_paired_connect() { - let addr = "127.0.0.1:6379".parse().unwrap(); - - let connection = super::paired_connect(addr) + let connection = ConnectionBuilder::new("127.0.0.1:6379") + .expect("Cannot build builder") + .paired_connect() .await .expect("Cannot establish connection"); - let value: String = connection + let value: u64 = connection .send(resp_array!["INCR", "CTR"]) .await .expect("Cannot increment counter"); let result: String = connection - .send(resp_array!["SET", "LASTCTR", value]) + .send(resp_array!["SET", "LASTCTR", value.to_string()]) .await .expect("Cannot set value"); @@ -346,11 +411,12 @@ mod test { #[tokio::test] async fn sending_a_lot_of_data_test() { - let addr = "127.0.0.1:6379".parse().unwrap(); - - let connection = super::paired_connect(addr) + let connection = ConnectionBuilder::new("127.0.0.1:6379") + .expect("Cannot build builder") + .paired_connect() .await - .expect("Cannot connect to Redis"); + .expect("Cannot establish connection"); + let mut futures = Vec::with_capacity(1000); for i in 0..1000 { let key = format!("X_{}", i); diff --git a/src/client/pubsub.rs b/src/client/pubsub.rs index ee086e2..6162610 100644 --- a/src/client/pubsub.rs +++ b/src/client/pubsub.rs @@ -17,10 +17,7 @@ use std::task::{Context, Poll}; use futures_channel::{mpsc, oneshot}; use futures_sink::Sink; -use futures_util::{ - future::TryFutureExt, - stream::{Fuse, Stream, StreamExt}, -}; +use futures_util::stream::{Fuse, Stream, StreamExt}; use super::{ connect::{connect_with_auth, RespConnection}, @@ -29,8 +26,12 @@ use super::{ use crate::{ error::{self, ConnectionReason}, - reconnect::{reconnect, Reconnect}, - resp::{self, FromResp}, + protocol::resp::{self, FromResp}, + task::spawn, +}; + +use super::reconnect::{ + ActionWork, Reconnectable, ReconnectableActions, ReconnectableConnectionFuture, }; #[derive(Debug)] @@ -323,10 +324,57 @@ impl Future for PubsubConnectionInner { } } +impl ActionWork for PubsubEvent { + type ConnectionType = mpsc::UnboundedSender; + + fn call(self, con: &Self::ConnectionType) -> Result<(), error::Error> { + con.unbounded_send(self).map_err(|e| e.into()) + } +} + +// PubsubEvent, mpsc::UnboundedSender + +#[derive(Debug)] +struct PubsubConnectionActions { + addr: SocketAddr, + username: Option>, + password: Option>, +} + +impl ReconnectableActions for PubsubConnectionActions { + type WorkPayload = PubsubEvent; + + fn do_connection( + &self, + ) -> ReconnectableConnectionFuture, error::Error> { + let con_f = inner_conn_fn(self.addr, self.username.clone(), self.password.clone()); + Box::pin(con_f) + } +} + /// A shareable reference to subscribe to PUBSUB topics #[derive(Debug, Clone)] pub struct PubsubConnection { - out_tx_c: Arc>>, + out_tx_c: Arc>, +} + +impl PubsubConnection { + async fn init( + addr: SocketAddr, + username: Option>, + password: Option>, + ) -> Result { + Ok(PubsubConnection { + out_tx_c: Arc::new( + Reconnectable::init(PubsubConnectionActions { + addr, + username, + password, + }) + .await?, + ), + }) + } } async fn inner_conn_fn( @@ -339,7 +387,7 @@ async fn inner_conn_fn( let connection = connect_with_auth(&addr, username, password).await?; let (out_tx, out_rx) = mpsc::unbounded(); - tokio::spawn(async { + spawn(async { match PubsubConnectionInner::new(connection, out_rx).await { Ok(_) => (), Err(e) => log::error!("Pub/Sub error: {:?}", e), @@ -349,36 +397,21 @@ async fn inner_conn_fn( } impl ConnectionBuilder { + /// Used for Redis's PUBSUB functionality. + /// + /// Returns a future that resolves to a `PubsubConnection`. The future will only resolve once the + /// connection is established; after the intial establishment, if the connection drops for any + /// reason (e.g. Redis server being restarted), the connection will attempt re-connect, however + /// any subscriptions will need to be re-subscribed. pub fn pubsub_connect(&self) -> impl Future> { let addr = self.addr; let username = self.username.clone(); let password = self.password.clone(); - let reconnecting_f = reconnect( - |con: &mpsc::UnboundedSender, act| { - con.unbounded_send(act).map_err(|e| e.into()) - }, - move || { - let con_f = inner_conn_fn(addr, username.clone(), password.clone()); - Box::pin(con_f) - }, - ); - reconnecting_f.map_ok(|con| PubsubConnection { - out_tx_c: Arc::new(con), - }) + PubsubConnection::init(addr, username, password) } } -/// Used for Redis's PUBSUB functionality. -/// -/// Returns a future that resolves to a `PubsubConnection`. The future will only resolve once the -/// connection is established; after the intial establishment, if the connection drops for any -/// reason (e.g. Redis server being restarted), the connection will attempt re-connect, however -/// any subscriptions will need to be re-subscribed. -pub async fn pubsub_connect(addr: SocketAddr) -> Result { - ConnectionBuilder::new(addr)?.pubsub_connect().await -} - impl PubsubConnection { /// Subscribes to a particular PUBSUB topic. /// @@ -394,7 +427,8 @@ impl PubsubConnection { let (tx, rx) = mpsc::unbounded(); let (signal_t, signal_r) = oneshot::channel(); self.out_tx_c - .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?; + .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t)) + .await?; match signal_r.await { Ok(_) => Ok(PubsubStream { @@ -410,7 +444,8 @@ impl PubsubConnection { let (tx, rx) = mpsc::unbounded(); let (signal_t, signal_r) = oneshot::channel(); self.out_tx_c - .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?; + .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t)) + .await?; match signal_r.await { Ok(_) => Ok(PubsubStream { @@ -466,13 +501,18 @@ impl Drop for PubsubStream { mod test { use futures::{try_join, StreamExt, TryStreamExt}; - use crate::{client, resp}; + use super::ConnectionBuilder; + + use crate::protocol::resp; #[tokio::test] async fn subscribe_test() { - let addr = "127.0.0.1:6379".parse().unwrap(); - let paired_c = client::paired_connect(addr); - let pubsub_c = super::pubsub_connect(addr); + let connection_builder = + ConnectionBuilder::new("127.0.0.1:6379").expect("Cannot build builder"); + + let paired_c = connection_builder.paired_connect(); + let pubsub_c = connection_builder.pubsub_connect(); + let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); let topic_messages = pubsub @@ -500,9 +540,12 @@ mod test { #[tokio::test] async fn psubscribe_test() { - let addr = "127.0.0.1:6379".parse().unwrap(); - let paired_c = client::paired_connect(addr); - let pubsub_c = super::pubsub_connect(addr); + let connection_builder = + ConnectionBuilder::new("127.0.0.1:6379").expect("Cannot build builder"); + + let paired_c = connection_builder.paired_connect(); + let pubsub_c = connection_builder.pubsub_connect(); + let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis"); let topic_messages = pubsub diff --git a/src/client/reconnect/holder.rs b/src/client/reconnect/holder.rs new file mode 100644 index 0000000..4a41657 --- /dev/null +++ b/src/client/reconnect/holder.rs @@ -0,0 +1,218 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +use std::{ + future::Future, + time::{Duration, Instant}, +}; + +use futures_util::{future, TryFutureExt}; + +use lwactors::{actor, Action, ActorSender}; + +use crate::error; + +use super::ActionWork; + +/// A standalone actor which holds a Redis connection +#[derive(Debug)] +pub(crate) struct ConnectionHolder +where + F: ActionWork, +{ + queue: + ActorSender, ConnectionHolderResult, error::Error>, +} + +impl ConnectionHolder +where + F: ActionWork + Send + 'static, +{ + pub(crate) fn new(t: F::ConnectionType) -> Self { + ConnectionHolder { + queue: actor(ConnectionHolderState::new(t)), + } + } + + /// Perform a chunk of work on the available connection, if available. + /// + /// Returns a boolean. True means the work was done. False means the work was not done, and the + /// caller must attempt a reconnection. Any other failure will return an error, the caller + /// should not attempt a reconnection. + pub(crate) fn do_work(&self, f: F) -> impl Future> { + self.queue + .invoke(ConnectionHolderAction::DoWork(f)) + .and_then(|result| match result { + ConnectionHolderResult::DoWork(DoWorkState::Connecting) => future::err( + error::Error::Connection(error::ConnectionReason::Connecting), + ), + ConnectionHolderResult::DoWork(DoWorkState::NotConnected) => future::ok(false), + ConnectionHolderResult::DoWork(DoWorkState::ConnectedErr(e)) => future::err(e), + ConnectionHolderResult::DoWork(DoWorkState::ConnectedOk(())) => future::ok(true), + _ => panic!("Not a DoWork result"), + }) + } + + /// Set a new connection if previously advised to attempt re-connection. + pub(crate) async fn set_connection(&self, con: F::ConnectionType) -> Result<(), error::Error> { + match self + .queue + .invoke(ConnectionHolderAction::SetConnection(con)) + .await? + { + ConnectionHolderResult::SetConnection => Ok(()), + _ => panic!("Wrong response"), + } + } + + pub(crate) async fn set_connection_failed(&self) -> Result<(), error::Error> { + match self + .queue + .invoke(ConnectionHolderAction::SetConnectionFailed) + .await? + { + ConnectionHolderResult::SetConnectionFailed => Ok(()), + _ => panic!("Wrong response"), + } + } +} + +impl Clone for ConnectionHolder +where + F: ActionWork, +{ + fn clone(&self) -> Self { + ConnectionHolder { + queue: self.queue.clone(), + } + } +} + +// TODO - should probably be configurable... +const MAX_CONNECTION_DUR: Duration = Duration::from_secs(10); + +#[derive(Debug)] +enum ConnectionHolderAction +where + F: ActionWork, +{ + DoWork(F), + SetConnection(F::ConnectionType), + SetConnectionFailed, +} + +impl Action for ConnectionHolderAction +where + F: ActionWork, +{ + type State = ConnectionHolderState; + type Result = ConnectionHolderResult; + type Error = error::Error; + + fn act(self, state: &mut Self::State) -> Result { + let res = match self { + ConnectionHolderAction::DoWork(work_f) => { + let dws: DoWorkState = match state { + ConnectionHolderState::Connected(ref con) => match work_f.call(con) { + Ok(()) => DoWorkState::ConnectedOk(()), + Err(e) => { + if e.is_io() || e.is_unexpected() { + *state = ConnectionHolderState::Connecting(Instant::now()); + DoWorkState::NotConnected + } else { + DoWorkState::ConnectedErr(e) + } + } + }, + ConnectionHolderState::NotConnected => { + *state = ConnectionHolderState::Connecting(Instant::now()); + DoWorkState::NotConnected + } + ConnectionHolderState::Connecting(ref mut inst) => { + let now = Instant::now(); + let dur = now - *inst; + if dur > MAX_CONNECTION_DUR { + *inst = now; + DoWorkState::NotConnected + } else { + DoWorkState::Connecting + } + } + }; + ConnectionHolderResult::DoWork(dws) + } + ConnectionHolderAction::SetConnection(con) => { + match state { + ConnectionHolderState::Connected(_) => { + log::warn!("Cannot set state when in Connected state"); + } + ConnectionHolderState::Connecting(_) => { + *state = ConnectionHolderState::Connected(con) + } + ConnectionHolderState::NotConnected => { + log::warn!("This is a valid, but rare sequence of events"); + *state = ConnectionHolderState::Connected(con) + } + } + ConnectionHolderResult::SetConnection + } + ConnectionHolderAction::SetConnectionFailed => { + match state { + ConnectionHolderState::Connected(_) => { + log::warn!("Cannot set state when in Connected state"); + } + ConnectionHolderState::Connecting(_) => { + *state = ConnectionHolderState::NotConnected + } + ConnectionHolderState::NotConnected => { + log::warn!("Suspicious series of events..."); + } + } + ConnectionHolderResult::SetConnectionFailed + } + }; + + Ok(res) + } +} + +#[derive(Debug)] +enum ConnectionHolderState +where + T: Send, +{ + Connecting(Instant), + Connected(T), + NotConnected, +} + +impl ConnectionHolderState +where + T: Send + 'static, +{ + fn new(t: T) -> Self { + ConnectionHolderState::Connected(t) + } +} + +#[derive(Debug)] +pub(crate) enum ConnectionHolderResult { + DoWork(DoWorkState), + SetConnection, + SetConnectionFailed, +} + +#[derive(Debug)] +pub(crate) enum DoWorkState { + NotConnected, + Connecting, + ConnectedOk(()), + ConnectedErr(E), +} diff --git a/src/client/reconnect/mod.rs b/src/client/reconnect/mod.rs new file mode 100644 index 0000000..efa0c13 --- /dev/null +++ b/src/client/reconnect/mod.rs @@ -0,0 +1,102 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +mod holder; + +use std::{fmt, future::Future, pin::Pin}; + +use crate::{error, task::spawn}; + +use holder::ConnectionHolder; + +/// A trait to be implemented by the chunks of work that are sent to a Redis connection +pub(crate) trait ActionWork { + type ConnectionType: Send + fmt::Debug; + + fn call(self, con: &Self::ConnectionType) -> Result<(), error::Error>; +} + +pub(crate) type ReconnectableConnectionFuture = + Pin> + Send>>; + +/// A trait to be implemented to allow a connection to be re-established should it be lost +pub(crate) trait ReconnectableActions { + type WorkPayload: ActionWork + 'static; + + fn do_connection( + &self, + ) -> ReconnectableConnectionFuture< + ::ConnectionType, + error::Error, + >; +} + +/// A wrapper around a Redis connection that will automatically try and re-connect should the +/// connection be lost +#[derive(Debug)] +pub(crate) struct Reconnectable +where + A: ReconnectableActions, +{ + con: ConnectionHolder, + actions: A, +} + +impl Reconnectable +where + A: ReconnectableActions, + A::WorkPayload: Send, +{ + pub(crate) async fn init(actions: A) -> Result { + let t = actions.do_connection().await?; + Ok(Reconnectable { + con: ConnectionHolder::new(t), + actions, + }) + } + + pub(crate) fn do_work( + &self, + work: A::WorkPayload, + ) -> impl Future> + '_ { + let work_f = self.con.do_work(work); + + async move { + if work_f.await? { + Ok(()) + } else { + self.reconnect(); + Err(error::Error::Connection( + error::ConnectionReason::NotConnected, + )) + } + } + } + + fn reconnect(&self) { + let con = self.con.clone(); + let con_f = self.actions.do_connection(); + spawn(async move { + match con_f.await { + Ok(new_con) => match con.set_connection(new_con).await { + Ok(()) => (), + Err(e) => log::warn!("Couldn't set new connection: {}", e), + }, + Err(e) => { + log::error!("Could not open connection: {}", e); + match con.set_connection_failed().await { + Ok(()) => (), + Err(e) => log::warn!("Couldn't set connection failure: {}", e), + } + } + } + }) + } +} diff --git a/src/error.rs b/src/error.rs index beeeda0..fb770c7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,5 @@ /* - * Copyright 2017-2019 Ben Ashford + * Copyright 2017-2020 Ben Ashford * * Licensed under the Apache License, Version 2.0 or the MIT license @@ -10,27 +10,34 @@ //! Error handling -use std::{error, fmt, io}; +use std::{fmt, io}; use futures_channel::mpsc; -use crate::resp; +use thiserror::Error; -#[derive(Debug)] +use crate::protocol::resp; + +#[derive(Debug, Error)] pub enum Error { /// A non-specific internal error that prevented an operation from completing + #[error("Internal Error: {0}")] Internal(String), /// An IO error occurred - IO(io::Error), + #[error("IO Error: {0}")] + IO(#[from] io::Error), /// A RESP parsing/serialising error occurred + #[error("{0}, {1:?}")] RESP(String, Option), /// A remote error + #[error("Remote Redis error: {0}")] Remote(String), /// Error creating a connection, or an error with a connection being closed unexpectedly + #[error("Connection error: {0}")] Connection(ConnectionReason), /// An unexpected error. In this context "unexpected" means @@ -40,9 +47,20 @@ pub enum Error { /// /// If any error is propagated this way that needs to be handled, then it should be made into /// a proper option. + #[error("Unexpected error: {0}")] Unexpected(String), } +impl Error { + pub(crate) fn is_io(&self) -> bool { + matches!(self, Error::IO(_)) + } + + pub(crate) fn is_unexpected(&self) -> bool { + matches!(self, Error::Unexpected(_)) + } +} + pub(crate) fn internal(msg: impl Into) -> Error { Error::Internal(msg.into()) } @@ -55,46 +73,15 @@ pub(crate) fn resp(msg: impl Into, resp: resp::RespValue) -> Error { Error::RESP(msg.into(), Some(resp)) } -impl From for Error { - fn from(err: io::Error) -> Error { - Error::IO(err) - } -} - impl From> for Error { fn from(err: mpsc::TrySendError) -> Error { Error::Unexpected(format!("Cannot write to channel: {}", err)) } } -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match self { - Error::IO(err) => Some(err), - _ => None, - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Internal(s) => write!(f, "{}", s), - Error::IO(err) => write!(f, "{}", err), - Error::RESP(s, _) => write!(f, "{}", s), - Error::Remote(s) => write!(f, "{}", s), - Error::Connection(ConnectionReason::Connected) => { - write!(f, "Connection already established") - } - Error::Connection(ConnectionReason::Connecting) => write!(f, "Connection in progress"), - Error::Connection(ConnectionReason::ConnectionFailed) => { - write!(f, "The last attempt to establish a connection failed") - } - Error::Connection(ConnectionReason::NotConnected) => { - write!(f, "Connection has been closed") - } - Error::Unexpected(err) => write!(f, "{}", err), - } +impl From for Error { + fn from(err: lwactors::ActorError) -> Error { + Error::Internal(format!("Actor error: {}", err)) } } @@ -114,3 +101,14 @@ pub enum ConnectionReason { /// clients should try again NotConnected, } + +impl fmt::Display for ConnectionReason { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + ConnectionReason::Connecting => "Connecting", + ConnectionReason::Connected => "Connected", + ConnectionReason::ConnectionFailed => "ConnectionFailed", + ConnectionReason::NotConnected => "NotConnected", + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index 350d498..e972173 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,14 +60,7 @@ extern crate tokio_02 as tokio; #[cfg(feature = "tokio02")] extern crate tokio_util_03 as tokio_util; -#[cfg(feature = "tokio03")] -extern crate bytes_06 as bytes; -#[cfg(feature = "tokio03")] -extern crate tokio_03 as tokio; -#[cfg(feature = "tokio03")] -extern crate tokio_util_05 as tokio_util; - -#[cfg(feature = "tokio10")] +#[cfg(feature = "bytes_10")] extern crate bytes_10 as bytes; #[cfg(feature = "tokio10")] extern crate tokio_10 as tokio; @@ -75,11 +68,9 @@ extern crate tokio_10 as tokio; extern crate tokio_util_06 as tokio_util; #[macro_use] -pub mod resp; +pub mod protocol; -#[macro_use] pub mod client; - pub mod error; -pub(crate) mod reconnect; +mod task; diff --git a/src/protocol/codec/decode.rs b/src/protocol/codec/decode.rs new file mode 100644 index 0000000..d824be2 --- /dev/null +++ b/src/protocol/codec/decode.rs @@ -0,0 +1,182 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +use std::str; + +use bytes::BytesMut; + +use crate::{error::Error, protocol::resp::RespValue}; + +type DecodeResult = Result, Error>; + +#[inline] +fn parse_error(message: String) -> Error { + Error::RESP(message, None) +} + +/// Many RESP types have their length (which is either bytes or "number of elements", depending on context) +/// encoded as a string, terminated by "\r\n", this looks for them. +/// +/// Only return the string if the whole sequence is complete, including the terminator bytes (but those final +/// two bytes will not be returned) +/// +/// TODO - rename this function potentially, it's used for simple integers too +fn scan_integer(buf: &mut BytesMut, idx: usize) -> Result, Error> { + let length = buf.len(); + let mut at_end = false; + let mut pos = idx; + loop { + if length <= pos { + return Ok(None); + } + match (at_end, buf[pos]) { + (true, b'\n') => return Ok(Some((pos + 1, &buf[idx..pos - 1]))), + (false, b'\r') => at_end = true, + (false, b'0'..=b'9') => (), + (false, b'-') => (), + (_, val) => { + return Err(parse_error(format!( + "Unexpected byte in size_string: {}", + val + ))); + } + } + pos += 1; + } +} + +fn scan_string(buf: &mut BytesMut, idx: usize) -> Option<(usize, String)> { + let length = buf.len(); + let mut at_end = false; + let mut pos = idx; + loop { + if length <= pos { + return None; + } + match (at_end, buf[pos]) { + (true, b'\n') => { + let value = String::from_utf8_lossy(&buf[idx..pos - 1]).into_owned(); + return Some((pos + 1, value)); + } + (true, _) => at_end = false, + (false, b'\r') => at_end = true, + (false, _) => (), + } + pos += 1; + } +} + +fn decode_raw_integer(buf: &mut BytesMut, idx: usize) -> Result, Error> { + match scan_integer(buf, idx) { + Ok(None) => Ok(None), + Ok(Some((pos, int_str))) => { + // Redis integers are transmitted as strings, so we first convert the raw bytes into a string... + match str::from_utf8(int_str) { + Ok(string) => { + // ...and then parse the string. + match string.parse() { + Ok(int) => Ok(Some((pos, int))), + Err(_) => Err(parse_error(format!("Not an integer: {}", string))), + } + } + Err(_) => Err(parse_error(format!("Not a valid string: {:?}", int_str))), + } + } + Err(e) => Err(e), + } +} + +fn decode_bulk_string(buf: &mut BytesMut, idx: usize) -> DecodeResult { + match decode_raw_integer(buf, idx) { + Ok(None) => Ok(None), + Ok(Some((pos, -1))) => Ok(Some((pos, RespValue::Nil))), + Ok(Some((pos, size))) if size >= 0 => { + let size = size as usize; + let remaining = buf.len() - pos; + let required_bytes = size + 2; + + if remaining < required_bytes { + return Ok(None); + } + + let bulk_string = RespValue::BulkString(buf[pos..(pos + size)].to_vec()); + Ok(Some((pos + required_bytes, bulk_string))) + } + Ok(Some((_, size))) => Err(parse_error(format!("Invalid string size: {}", size))), + Err(e) => Err(e), + } +} + +fn decode_array(buf: &mut BytesMut, idx: usize) -> DecodeResult { + match decode_raw_integer(buf, idx) { + Ok(None) => Ok(None), + Ok(Some((pos, -1))) => Ok(Some((pos, RespValue::Nil))), + Ok(Some((pos, size))) if size >= 0 => { + let size = size as usize; + let mut pos = pos; + let mut values = Vec::with_capacity(size); + for _ in 0..size { + match decode(buf, pos) { + Ok(None) => return Ok(None), + Ok(Some((new_pos, value))) => { + values.push(value); + pos = new_pos; + } + Err(e) => return Err(e), + } + } + Ok(Some((pos, RespValue::Array(values)))) + } + Ok(Some((_, size))) => Err(parse_error(format!("Invalid array size: {}", size))), + Err(e) => Err(e), + } +} + +fn decode_integer(buf: &mut BytesMut, idx: usize) -> DecodeResult { + match decode_raw_integer(buf, idx) { + Ok(None) => Ok(None), + Ok(Some((pos, int))) => Ok(Some((pos, RespValue::Integer(int)))), + Err(e) => Err(e), + } +} + +/// A simple string is any series of bytes that ends with `\r\n` +#[allow(clippy::unknown_clippy_lints, clippy::unnecessary_wraps)] +fn decode_simple_string(buf: &mut BytesMut, idx: usize) -> DecodeResult { + match scan_string(buf, idx) { + None => Ok(None), + Some((pos, string)) => Ok(Some((pos, RespValue::SimpleString(string)))), + } +} + +#[allow(clippy::unknown_clippy_lints, clippy::unnecessary_wraps)] +fn decode_error(buf: &mut BytesMut, idx: usize) -> DecodeResult { + match scan_string(buf, idx) { + None => Ok(None), + Some((pos, string)) => Ok(Some((pos, RespValue::Error(string)))), + } +} + +pub(crate) fn decode(buf: &mut BytesMut, idx: usize) -> DecodeResult { + let length = buf.len(); + if length <= idx { + return Ok(None); + } + + let first_byte = buf[idx]; + match first_byte { + b'$' => decode_bulk_string(buf, idx + 1), + b'*' => decode_array(buf, idx + 1), + b':' => decode_integer(buf, idx + 1), + b'+' => decode_simple_string(buf, idx + 1), + b'-' => decode_error(buf, idx + 1), + _ => Err(parse_error(format!("Unexpected byte: {}", first_byte))), + } +} diff --git a/src/protocol/codec/encode.rs b/src/protocol/codec/encode.rs new file mode 100644 index 0000000..2c6191d --- /dev/null +++ b/src/protocol/codec/encode.rs @@ -0,0 +1,91 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +use std::cmp; + +use bytes::{BufMut, BytesMut}; + +use crate::protocol::resp::RespValue; + +const DEFAULT_MESSAGE_SIZE: usize = 1024; + +fn check_and_reserve(buf: &mut BytesMut, amt: usize) { + let remaining_bytes = buf.remaining_mut(); + if remaining_bytes < amt { + buf.reserve(cmp::max(amt, DEFAULT_MESSAGE_SIZE)); + } +} + +fn write_rn(buf: &mut BytesMut) { + buf.put_u8(b'\r'); + buf.put_u8(b'\n'); +} + +fn write_simple_string(symb: u8, string: &str, buf: &mut BytesMut) { + let bytes = string.as_bytes(); + let size = 1 + bytes.len() + 2; + check_and_reserve(buf, size); + buf.put_u8(symb); + buf.extend(bytes); + write_rn(buf); +} + +fn write_header(symb: u8, len: i64, buf: &mut BytesMut) { + let len_as_string = len.to_string(); + let len_as_bytes = len_as_string.as_bytes(); + let header_bytes = 1 + len_as_bytes.len() + 2; + check_and_reserve(buf, header_bytes); + buf.put_u8(symb); + buf.extend(len_as_bytes); + write_rn(buf); +} + +fn encode_nil(buf: &mut BytesMut) { + write_header(b'$', -1, buf); +} + +fn encode_array(ary: Vec, buf: &mut BytesMut) { + write_header(b'*', ary.len() as i64, buf); + for v in ary { + encode(v, buf); + } +} + +fn encode_bulkstring(bstr: Vec, buf: &mut BytesMut) { + let len = bstr.len(); + write_header(b'$', len as i64, buf); + check_and_reserve(buf, len + 2); + buf.extend(bstr); + write_rn(buf); +} + +fn encode_error(err: &str, buf: &mut BytesMut) { + write_simple_string(b'-', err, buf); +} + +fn encode_integer(val: i64, buf: &mut BytesMut) { + // Simple integer are just the header + write_header(b':', val, buf); +} + +fn encode_simple_string(string: &str, buf: &mut BytesMut) { + write_simple_string(b'+', string, buf); +} + +pub(crate) fn encode(msg: RespValue, buf: &mut BytesMut) { + match msg { + RespValue::Nil => encode_nil(buf), + RespValue::Array(ary) => encode_array(ary, buf), + RespValue::BulkString(bstr) => encode_bulkstring(bstr, buf), + RespValue::Error(ref string) => encode_error(string, buf), + RespValue::Integer(val) => encode_integer(val, buf), + RespValue::SimpleString(ref string) => encode_simple_string(string, buf), + } +} diff --git a/src/protocol/codec/mod.rs b/src/protocol/codec/mod.rs new file mode 100644 index 0000000..56d202b --- /dev/null +++ b/src/protocol/codec/mod.rs @@ -0,0 +1,21 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +mod decode; +mod encode; + +#[cfg(feature = "tokio_codec")] +pub(crate) mod tokio; + +#[cfg(feature = "with_async_std")] +pub(crate) use encode::encode; + +#[cfg(feature = "with_async_std")] +pub(crate) use decode::decode; diff --git a/src/protocol/codec/tokio/decode.rs b/src/protocol/codec/tokio/decode.rs new file mode 100644 index 0000000..f082fed --- /dev/null +++ b/src/protocol/codec/tokio/decode.rs @@ -0,0 +1,36 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +use bytes::{Buf, BytesMut}; + +use tokio_util::codec::Decoder; + +use crate::{ + error::Error, + protocol::{codec::decode::decode, resp::RespValue}, +}; + +use super::RespCodec; + +impl Decoder for RespCodec { + type Item = RespValue; + type Error = Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { + match decode(buf, 0) { + Ok(None) => Ok(None), + Ok(Some((pos, item))) => { + buf.advance(pos); + Ok(Some(item)) + } + Err(e) => Err(e), + } + } +} diff --git a/src/protocol/codec/tokio/encode.rs b/src/protocol/codec/tokio/encode.rs new file mode 100644 index 0000000..1b31dc3 --- /dev/null +++ b/src/protocol/codec/tokio/encode.rs @@ -0,0 +1,28 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +use std::io; + +use bytes::BytesMut; + +use tokio_util::codec::Encoder; + +use crate::protocol::{codec::encode::encode, resp::RespValue}; + +use super::RespCodec; + +impl Encoder for RespCodec { + type Error = io::Error; + + fn encode(&mut self, msg: RespValue, buf: &mut BytesMut) -> Result<(), Self::Error> { + encode(msg, buf); + Ok(()) + } +} diff --git a/src/protocol/codec/tokio/mod.rs b/src/protocol/codec/tokio/mod.rs new file mode 100644 index 0000000..aae8b8f --- /dev/null +++ b/src/protocol/codec/tokio/mod.rs @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +mod decode; +mod encode; + +/// Codec to read frames +pub struct RespCodec; + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + + use tokio_util::codec::{Decoder, Encoder}; + + use crate::protocol::resp::RespValue; + + use super::RespCodec; + + fn obj_to_bytes(obj: RespValue) -> Vec { + let mut bytes = BytesMut::new(); + let mut codec = RespCodec; + codec.encode(obj, &mut bytes).unwrap(); + bytes.to_vec() + } + + #[test] + fn test_resp_array_macro() { + let resp_object = resp_array!["SET", "x"]; + let bytes = obj_to_bytes(resp_object); + assert_eq!(b"*2\r\n$3\r\nSET\r\n$1\r\nx\r\n", bytes.as_slice()); + + let resp_object = resp_array!["RPUSH", "wyz"].append(vec!["a", "b"]); + let bytes = obj_to_bytes(resp_object); + assert_eq!( + &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nwyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..], + bytes.as_slice() + ); + + let vals = vec![String::from("a"), String::from("b")]; + let resp_object = resp_array!["RPUSH", "xyz"].append(&vals); + let bytes = obj_to_bytes(resp_object); + assert_eq!( + &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nxyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..], + bytes.as_slice() + ); + } + + #[test] + fn test_bulk_string() { + let resp_object = RespValue::BulkString(b"THISISATEST".to_vec()); + let mut bytes = BytesMut::new(); + let mut codec = RespCodec; + codec.encode(resp_object.clone(), &mut bytes).unwrap(); + assert_eq!(b"$11\r\nTHISISATEST\r\n".to_vec(), bytes.to_vec()); + + let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); + assert_eq!(deserialized, resp_object); + } + + #[test] + fn test_array() { + let resp_object = RespValue::Array(vec!["TEST1".into(), "TEST2".into()]); + let mut bytes = BytesMut::new(); + let mut codec = RespCodec; + codec.encode(resp_object.clone(), &mut bytes).unwrap(); + assert_eq!( + b"*2\r\n$5\r\nTEST1\r\n$5\r\nTEST2\r\n".to_vec(), + bytes.to_vec() + ); + + let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); + assert_eq!(deserialized, resp_object); + } + + #[test] + fn test_nil_string() { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&b"$-1\r\n"[..]); + + let mut codec = RespCodec; + let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); + assert_eq!(deserialized, RespValue::Nil); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..c4eed6b --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,19 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +#[macro_use] +pub(crate) mod resp; + +pub(crate) mod codec; + +#[cfg(feature = "tokio_codec")] +pub(crate) use codec::tokio::RespCodec; + +pub use resp::{FromResp, RespValue}; diff --git a/src/resp.rs b/src/protocol/resp.rs similarity index 56% rename from src/resp.rs rename to src/protocol/resp.rs index 2a5f678..aa9c8c4 100644 --- a/src/resp.rs +++ b/src/protocol/resp.rs @@ -12,15 +12,10 @@ use std::collections::HashMap; use std::hash::{BuildHasher, Hash}; -use std::io; use std::str; use std::sync::Arc; -use bytes::{Buf, BufMut, BytesMut}; - -use tokio_util::codec::{Decoder, Encoder}; - -use super::error::{self, Error}; +use crate::error::{self, Error}; /// A single RESP value, this owns the data that is read/to-be written to Redis. /// @@ -109,7 +104,6 @@ impl FromResp for String { fn from_resp_int(resp: RespValue) -> Result { match resp { RespValue::BulkString(ref bytes) => Ok(String::from_utf8_lossy(bytes).into_owned()), - RespValue::Integer(i) => Ok(i.to_string()), RespValue::SimpleString(string) => Ok(string), _ => Err(error::resp("Cannot convert into a string", resp)), } @@ -331,7 +325,7 @@ where macro_rules! resp_array { ($($e:expr),* $(,)?) => { { - $crate::resp::RespValue::Array(vec![ + $crate::protocol::RespValue::Array(vec![ $( $e.into(), )* @@ -413,341 +407,36 @@ macro_rules! integer_into_resp { }; } -impl ToRespInteger for usize { +impl ToRespInteger for i64 { fn to_resp_integer(self) -> RespValue { - RespValue::Integer(self as i64) - } -} -integer_into_resp!(usize); - -/// Codec to read frames -pub struct RespCodec; - -fn write_rn(buf: &mut BytesMut) { - buf.put_u8(b'\r'); - buf.put_u8(b'\n'); -} - -fn check_and_reserve(buf: &mut BytesMut, amt: usize) { - let remaining_bytes = buf.remaining_mut(); - if remaining_bytes < amt { - buf.reserve(amt); + RespValue::Integer(self) } } +integer_into_resp!(i64); -fn write_header(symb: u8, len: i64, buf: &mut BytesMut) { - let len_as_string = len.to_string(); - let len_as_bytes = len_as_string.as_bytes(); - let header_bytes = 1 + len_as_bytes.len() + 2; - check_and_reserve(buf, header_bytes); - buf.put_u8(symb); - buf.extend(len_as_bytes); - write_rn(buf); -} - -fn write_simple_string(symb: u8, string: &str, buf: &mut BytesMut) { - let bytes = string.as_bytes(); - let size = 1 + bytes.len() + 2; - check_and_reserve(buf, size); - buf.put_u8(symb); - buf.extend(bytes); - write_rn(buf); -} - -impl Encoder for RespCodec { - type Error = io::Error; - - fn encode(&mut self, msg: RespValue, buf: &mut BytesMut) -> Result<(), Self::Error> { - match msg { - RespValue::Nil => { - write_header(b'$', -1, buf); - } - RespValue::Array(ary) => { - write_header(b'*', ary.len() as i64, buf); - for v in ary { - self.encode(v, buf)?; - } - } - RespValue::BulkString(bstr) => { - let len = bstr.len(); - write_header(b'$', len as i64, buf); - check_and_reserve(buf, len + 2); - buf.extend(bstr); - write_rn(buf); - } - RespValue::Error(ref string) => { - write_simple_string(b'-', string, buf); - } - RespValue::Integer(val) => { - // Simple integer are just the header - write_header(b':', val, buf); - } - RespValue::SimpleString(ref string) => { - write_simple_string(b'+', string, buf); - } - } - Ok(()) - } -} - -#[inline] -fn parse_error(message: String) -> Error { - Error::RESP(message, None) -} - -/// Many RESP types have their length (which is either bytes or "number of elements", depending on context) -/// encoded as a string, terminated by "\r\n", this looks for them. -/// -/// Only return the string if the whole sequence is complete, including the terminator bytes (but those final -/// two bytes will not be returned) -/// -/// TODO - rename this function potentially, it's used for simple integers too -fn scan_integer(buf: &mut BytesMut, idx: usize) -> Result, Error> { - let length = buf.len(); - let mut at_end = false; - let mut pos = idx; - loop { - if length <= pos { - return Ok(None); - } - match (at_end, buf[pos]) { - (true, b'\n') => return Ok(Some((pos + 1, &buf[idx..pos - 1]))), - (false, b'\r') => at_end = true, - (false, b'0'..=b'9') => (), - (false, b'-') => (), - (_, val) => { - return Err(parse_error(format!( - "Unexpected byte in size_string: {}", - val - ))); - } - } - pos += 1; - } -} - -fn scan_string(buf: &mut BytesMut, idx: usize) -> Option<(usize, String)> { - let length = buf.len(); - let mut at_end = false; - let mut pos = idx; - loop { - if length <= pos { - return None; - } - match (at_end, buf[pos]) { - (true, b'\n') => { - let value = String::from_utf8_lossy(&buf[idx..pos - 1]).into_owned(); - return Some((pos + 1, value)); - } - (true, _) => at_end = false, - (false, b'\r') => at_end = true, - (false, _) => (), - } - pos += 1; - } -} - -fn decode_raw_integer(buf: &mut BytesMut, idx: usize) -> Result, Error> { - match scan_integer(buf, idx) { - Ok(None) => Ok(None), - Ok(Some((pos, int_str))) => { - // Redis integers are transmitted as strings, so we first convert the raw bytes into a string... - match str::from_utf8(int_str) { - Ok(string) => { - // ...and then parse the string. - match string.parse() { - Ok(int) => Ok(Some((pos, int))), - Err(_) => Err(parse_error(format!("Not an integer: {}", string))), - } - } - Err(_) => Err(parse_error(format!("Not a valid string: {:?}", int_str))), - } - } - Err(e) => Err(e), - } -} - -type DecodeResult = Result, Error>; - -fn decode_bulk_string(buf: &mut BytesMut, idx: usize) -> DecodeResult { - match decode_raw_integer(buf, idx) { - Ok(None) => Ok(None), - Ok(Some((pos, -1))) => Ok(Some((pos, RespValue::Nil))), - Ok(Some((pos, size))) if size >= 0 => { - let size = size as usize; - let remaining = buf.len() - pos; - let required_bytes = size + 2; - - if remaining < required_bytes { - return Ok(None); - } - - let bulk_string = RespValue::BulkString(buf[pos..(pos + size)].to_vec()); - Ok(Some((pos + required_bytes, bulk_string))) - } - Ok(Some((_, size))) => Err(parse_error(format!("Invalid string size: {}", size))), - Err(e) => Err(e), - } -} - -fn decode_array(buf: &mut BytesMut, idx: usize) -> DecodeResult { - match decode_raw_integer(buf, idx) { - Ok(None) => Ok(None), - Ok(Some((pos, -1))) => Ok(Some((pos, RespValue::Nil))), - Ok(Some((pos, size))) if size >= 0 => { - let size = size as usize; - let mut pos = pos; - let mut values = Vec::with_capacity(size); - for _ in 0..size { - match decode(buf, pos) { - Ok(None) => return Ok(None), - Ok(Some((new_pos, value))) => { - values.push(value); - pos = new_pos; - } - Err(e) => return Err(e), +macro_rules! impl_toresp_integers { + ($($int_ty:ident),* $(,)*) => { + $( + impl ToRespInteger for $int_ty { + fn to_resp_integer(self) -> RespValue { + let new_self = self as i64; + new_self.to_resp_integer() } } - Ok(Some((pos, RespValue::Array(values)))) - } - Ok(Some((_, size))) => Err(parse_error(format!("Invalid array size: {}", size))), - Err(e) => Err(e), - } -} - -fn decode_integer(buf: &mut BytesMut, idx: usize) -> DecodeResult { - match decode_raw_integer(buf, idx) { - Ok(None) => Ok(None), - Ok(Some((pos, int))) => Ok(Some((pos, RespValue::Integer(int)))), - Err(e) => Err(e), - } -} - -/// A simple string is any series of bytes that ends with `\r\n` -#[allow(clippy::unknown_clippy_lints, clippy::unnecessary_wraps)] -fn decode_simple_string(buf: &mut BytesMut, idx: usize) -> DecodeResult { - match scan_string(buf, idx) { - None => Ok(None), - Some((pos, string)) => Ok(Some((pos, RespValue::SimpleString(string)))), - } -} - -#[allow(clippy::unknown_clippy_lints, clippy::unnecessary_wraps)] -fn decode_error(buf: &mut BytesMut, idx: usize) -> DecodeResult { - match scan_string(buf, idx) { - None => Ok(None), - Some((pos, string)) => Ok(Some((pos, RespValue::Error(string)))), - } -} - -fn decode(buf: &mut BytesMut, idx: usize) -> DecodeResult { - let length = buf.len(); - if length <= idx { - return Ok(None); - } - - let first_byte = buf[idx]; - match first_byte { - b'$' => decode_bulk_string(buf, idx + 1), - b'*' => decode_array(buf, idx + 1), - b':' => decode_integer(buf, idx + 1), - b'+' => decode_simple_string(buf, idx + 1), - b'-' => decode_error(buf, idx + 1), - _ => Err(parse_error(format!("Unexpected byte: {}", first_byte))), - } + integer_into_resp!($int_ty); + )* + }; } -impl Decoder for RespCodec { - type Item = RespValue; - type Error = Error; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - match decode(buf, 0) { - Ok(None) => Ok(None), - Ok(Some((pos, item))) => { - buf.advance(pos); - Ok(Some(item)) - } - Err(e) => Err(e), - } - } -} +impl_toresp_integers!(isize, i32, u32); #[cfg(test)] mod tests { use std::collections::HashMap; - use bytes::BytesMut; - - use tokio_util::codec::{Decoder, Encoder}; - - use super::{Error, FromResp, RespCodec, RespValue}; - - fn obj_to_bytes(obj: RespValue) -> Vec { - let mut bytes = BytesMut::new(); - let mut codec = RespCodec; - codec.encode(obj, &mut bytes).unwrap(); - bytes.to_vec() - } - - #[test] - fn test_resp_array_macro() { - let resp_object = resp_array!["SET", "x"]; - let bytes = obj_to_bytes(resp_object); - assert_eq!(b"*2\r\n$3\r\nSET\r\n$1\r\nx\r\n", bytes.as_slice()); - - let resp_object = resp_array!["RPUSH", "wyz"].append(vec!["a", "b"]); - let bytes = obj_to_bytes(resp_object); - assert_eq!( - &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nwyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..], - bytes.as_slice() - ); - - let vals = vec![String::from("a"), String::from("b")]; - let resp_object = resp_array!["RPUSH", "xyz"].append(&vals); - let bytes = obj_to_bytes(resp_object); - assert_eq!( - &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nxyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..], - bytes.as_slice() - ); - } - - #[test] - fn test_bulk_string() { - let resp_object = RespValue::BulkString(b"THISISATEST".to_vec()); - let mut bytes = BytesMut::new(); - let mut codec = RespCodec; - codec.encode(resp_object.clone(), &mut bytes).unwrap(); - assert_eq!(b"$11\r\nTHISISATEST\r\n".to_vec(), bytes.to_vec()); - - let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); - assert_eq!(deserialized, resp_object); - } + use crate::error::Error; - #[test] - fn test_array() { - let resp_object = RespValue::Array(vec!["TEST1".into(), "TEST2".into()]); - let mut bytes = BytesMut::new(); - let mut codec = RespCodec; - codec.encode(resp_object.clone(), &mut bytes).unwrap(); - assert_eq!( - b"*2\r\n$5\r\nTEST1\r\n$5\r\nTEST2\r\n".to_vec(), - bytes.to_vec() - ); - - let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); - assert_eq!(deserialized, resp_object); - } - - #[test] - fn test_nil_string() { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&b"$-1\r\n"[..]); - - let mut codec = RespCodec; - let deserialized = codec.decode(&mut bytes).unwrap().unwrap(); - assert_eq!(deserialized, RespValue::Nil); - } + use super::{FromResp, RespValue}; #[test] fn test_integer_overflow() { diff --git a/src/reconnect.rs b/src/reconnect.rs deleted file mode 100644 index e61912d..0000000 --- a/src/reconnect.rs +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright 2018-2020 Ben Ashford - * - * Licensed under the Apache License, Version 2.0 or the MIT license - * , at your - * option. This file may not be copied, modified, or distributed - * except according to those terms. - */ - -use std::fmt; -use std::future::Future; -use std::mem; -use std::pin::Pin; -use std::sync::{Arc, Mutex, MutexGuard}; -use std::time::Duration; - -use futures_util::{ - future::{self, Either}, - TryFutureExt, -}; - -use tokio::time::timeout; - -use crate::error::{self, ConnectionReason}; - -type WorkFn = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync; -type ConnFn = - dyn Fn() -> Pin> + Send + Sync>> + Send + Sync; - -struct ReconnectInner { - state: Mutex>, - work_fn: Box>, - conn_fn: Box>, -} - -pub(crate) struct Reconnect(Arc>); - -impl Clone for Reconnect { - fn clone(&self) -> Self { - Reconnect(self.0.clone()) - } -} - -impl fmt::Debug for Reconnect -where - T: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Reconnect") - .field("state", &self.0.state) - .field("work_fn", &String::from("REDACTED")) - .field("conn_fn", &String::from("REDACTED")) - .finish() - } -} - -pub(crate) async fn reconnect(w: W, c: C) -> Result, error::Error> -where - A: Send + 'static, - W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static, - C: Fn() -> Pin> + Send + Sync>> - + Send - + Sync - + 'static, - T: Clone + Send + Sync + 'static, -{ - let r = Reconnect(Arc::new(ReconnectInner { - state: Mutex::new(ReconnectState::NotConnected), - - work_fn: Box::new(w), - conn_fn: Box::new(c), - })); - let rf = { - let state = r.0.state.lock().expect("Poisoned lock"); - r.reconnect(state) - }; - rf.await?; - Ok(r) -} - -enum ReconnectState { - NotConnected, - Connected(T), - ConnectionFailed(Mutex>), - Connecting, -} - -impl fmt::Debug for ReconnectState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ReconnectState::")?; - match self { - NotConnected => write!(f, "NotConnected"), - Connected(_) => write!(f, "Connected"), - ConnectionFailed(_) => write!(f, "ConnectionFailed"), - Connecting => write!(f, "Connecting"), - } - } -} - -use self::ReconnectState::*; - -const CONNECTION_TIMEOUT_SECONDS: u64 = 10; -const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS); - -impl Reconnect -where - A: Send + 'static, - T: Clone + Send + Sync + 'static, -{ - fn call_work(&self, t: &T, a: A) -> Result { - if let Err(e) = (self.0.work_fn)(t, a) { - match e { - error::Error::IO(_) | error::Error::Unexpected(_) => { - log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e); - return Ok(false); - } - _ => (), - } - Err(e) - } else { - Ok(true) - } - } - - pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> { - let mut state = self.0.state.lock().expect("Cannot obtain read lock"); - match *state { - NotConnected => { - self.reconnect_spawn(state); - Err(error::Error::Connection(ConnectionReason::NotConnected)) - } - Connected(ref t) => { - let success = self.call_work(t, a)?; - if !success { - *state = NotConnected; - self.reconnect_spawn(state); - } - Ok(()) - } - ConnectionFailed(ref e) => { - let mut lock = e.lock().expect("Poisioned lock"); - let e = match lock.take() { - Some(e) => e, - None => error::Error::Connection(ConnectionReason::NotConnected), - }; - mem::drop(lock); - - *state = NotConnected; - self.reconnect_spawn(state); - Err(e) - } - Connecting => Err(error::Error::Connection(ConnectionReason::Connecting)), - } - } - - /// Returns a future that completes when the connection is established or failed to establish - /// used only for timing. - fn reconnect( - &self, - mut state: MutexGuard>, - ) -> impl Future> + Send { - log::info!("Attempting to reconnect, current state: {:?}", *state); - - match *state { - Connected(_) => { - return Either::Right(future::err(error::Error::Connection( - ConnectionReason::Connected, - ))); - } - Connecting => { - return Either::Right(future::err(error::Error::Connection( - ConnectionReason::Connecting, - ))); - } - NotConnected | ConnectionFailed(_) => (), - } - *state = ReconnectState::Connecting; - - mem::drop(state); - - let reconnect = self.clone(); - - let connection_f = async move { - let connection = match timeout(CONNECTION_TIMEOUT, (reconnect.0.conn_fn)()).await { - Ok(con_r) => con_r, - Err(_) => Err(error::internal(format!( - "Connection timed-out after {} seconds", - CONNECTION_TIMEOUT_SECONDS - ))), - }; - - let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock"); - - match *state { - NotConnected | Connecting => match connection { - Ok(t) => { - log::info!("Connection established"); - *state = Connected(t); - Ok(()) - } - Err(e) => { - log::error!("Connection cannot be established: {}", e); - *state = ConnectionFailed(Mutex::new(Some(e))); - Err(error::Error::Connection(ConnectionReason::ConnectionFailed)) - } - }, - ConnectionFailed(_) => { - panic!("The connection state wasn't reset before connecting") - } - Connected(_) => panic!("A connected state shouldn't be attempting to reconnect"), - } - }; - - Either::Left(connection_f) - } - - fn reconnect_spawn(&self, state: MutexGuard>) { - let reconnect_f = self - .reconnect(state) - .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e)); - - tokio::spawn(reconnect_f); - } -} diff --git a/src/task.rs b/src/task.rs new file mode 100644 index 0000000..7b962bb --- /dev/null +++ b/src/task.rs @@ -0,0 +1,29 @@ +/* + * Copyright 2020 Ben Ashford + * + * Licensed under the Apache License, Version 2.0 or the MIT license + * , at your + * option. This file may not be copied, modified, or distributed + * except according to those terms. + */ + +use std::future::Future; + +#[cfg(feature = "with_tokio")] +pub(crate) fn spawn(f: F) +where + F: Future + Send + 'static, + F::Output: Send, +{ + tokio::spawn(f); +} + +#[cfg(feature = "with_async_std")] +pub(crate) fn spawn(f: F) +where + F: Future + Send + 'static, + F::Output: Send, +{ + async_global_executor::spawn(f).detach() +}