Skip to content
Merged
Prev Previous commit
Next Next commit
Cleanup user API.
- Remove extension ID from call signature.
- Make extension mutable.
- Add example showing how to implement an extension with a struct.

Signed-off-by: David Calavera <[email protected]>
  • Loading branch information
calavera committed Dec 3, 2021
commit af95f363070646c914d7a1c38620c2e591062ca4
4 changes: 2 additions & 2 deletions lambda-extension/examples/basic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use lambda_extension::{extension_fn, Error, ExtensionId, NextEvent};
use lambda_extension::{extension_fn, Error, NextEvent};
use log::LevelFilter;
use simple_logger::SimpleLogger;

async fn my_extension(_extension_id: ExtensionId, event: NextEvent) -> Result<(), Error> {
async fn my_extension(event: NextEvent) -> Result<(), Error> {
match event {
NextEvent::Shutdown(_e) => {
// do something with the shutdown event
Expand Down
4 changes: 2 additions & 2 deletions lambda-extension/examples/custom_events.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use lambda_extension::{extension_fn, Error, ExtensionId, NextEvent, Runtime};
use lambda_extension::{extension_fn, Error, NextEvent, Runtime};
use log::LevelFilter;
use simple_logger::SimpleLogger;

async fn my_extension(_extension_id: ExtensionId, event: NextEvent) -> Result<(), Error> {
async fn my_extension(event: NextEvent) -> Result<(), Error> {
match event {
NextEvent::Shutdown(_e) => {
// do something with the shutdown event
Expand Down
34 changes: 34 additions & 0 deletions lambda-extension/examples/custom_trait_implementation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use lambda_extension::{run, Error, NextEvent, Extension};
use log::LevelFilter;
use simple_logger::SimpleLogger;
use std::future::{Future, ready};
use std::pin::Pin;

struct MyExtension {}

impl Extension for MyExtension
{
type Fut = Pin<Box<dyn Future<Output = Result<(), Error>>>>;
fn call(&mut self, event: NextEvent) -> Self::Fut {
match event {
NextEvent::Shutdown(_e) => {
// do something with the shutdown event
}
_ => {
// ignore any other event
// because we've registered the extension
// only to receive SHUTDOWN events
}
}
Box::pin(ready(Ok(())))
}
}

#[tokio::main]
async fn main() -> Result<(), Error> {
// required to enable CloudWatch error logging by the runtime
// can be replaced with any other method of initializing `log`
SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap();

run(MyExtension {}).await
}
15 changes: 7 additions & 8 deletions lambda-extension/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use tracing::trace;
pub mod requests;

pub type Error = lambda_runtime_api_client::Error;
pub type ExtensionId = String;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -59,7 +58,7 @@ pub trait Extension {
/// Response of this Extension.
type Fut: Future<Output = Result<(), Error>>;
/// Handle the incoming event.
fn call(&self, extension_id: ExtensionId, event: NextEvent) -> Self::Fut;
fn call(&mut self, event: NextEvent) -> Self::Fut;
}

/// Returns a new [`ExtensionFn`] with the given closure.
Expand All @@ -79,17 +78,17 @@ pub struct ExtensionFn<F> {

impl<F, Fut> Extension for ExtensionFn<F>
where
F: Fn(ExtensionId, NextEvent) -> Fut,
F: Fn(NextEvent) -> Fut,
Fut: Future<Output = Result<(), Error>>,
{
type Fut = Fut;
fn call(&self, extension_id: ExtensionId, event: NextEvent) -> Self::Fut {
(self.f)(extension_id, event)
fn call(&mut self, event: NextEvent) -> Self::Fut {
(self.f)(event)
}
}

pub struct Runtime<C: Service<http::Uri> = HttpConnector> {
extension_id: ExtensionId,
extension_id: String,
client: Client<C>,
}

Expand All @@ -106,7 +105,7 @@ where
<C as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
<C as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
pub async fn run(&self, extension: impl Extension) -> Result<(), Error> {
pub async fn run(&self, mut extension: impl Extension) -> Result<(), Error> {
let client = &self.client;

let incoming = async_stream::stream! {
Expand All @@ -129,7 +128,7 @@ where
let event: NextEvent = serde_json::from_slice(&body)?;
let is_invoke = event.is_invoke();

let res = extension.call(self.extension_id.clone(), event).await;
let res = extension.call(event).await;
if let Err(error) = res {
let req = if is_invoke {
requests::init_error(&self.extension_id, &error.to_string(), None)?
Expand Down