Skip to content

Commit 548237e

Browse files
committed
examples: add auth middleware example
1 parent db2d43d commit 548237e

File tree

1 file changed

+352
-0
lines changed

1 file changed

+352
-0
lines changed

iroh/examples/auth-middleware.rs

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
//! Implementation of authentication using iroh middlewares
2+
//!
3+
//! This implements an auth protocol that works with iroh middlewares.
4+
//! It allows to put authentication in front of iroh protocols. The protocols don't need any special support.
5+
//! Authentication is handled prior to establishing the connections, over a separate connection.
6+
7+
use iroh::{Endpoint, EndpointAddr, protocol::Router};
8+
use n0_error::{Result, StdResultExt};
9+
10+
use crate::echo::Echo;
11+
12+
#[tokio::main]
13+
async fn main() -> Result<()> {
14+
tracing_subscriber::fmt::init();
15+
let server_router = accept_side(b"secret!!").await?;
16+
server_router.endpoint().online().await;
17+
let server_addr = server_router.endpoint().addr();
18+
19+
println!("-- no --");
20+
let res = connect_side_no_auth(server_addr.clone()).await;
21+
println!("echo without auth: {:#}", res.unwrap_err());
22+
23+
println!("-- wrong --");
24+
let res = connect_side(server_addr.clone(), b"dunno").await;
25+
println!("echo with wrong auth: {:#}", res.unwrap_err());
26+
27+
println!("-- correct --");
28+
let res = connect_side(server_addr.clone(), b"secret!!").await;
29+
println!("echo with correct auth: {res:?}");
30+
31+
server_router.shutdown().await.anyerr()?;
32+
33+
Ok(())
34+
}
35+
36+
async fn connect_side(remote_addr: EndpointAddr, token: &[u8]) -> Result<()> {
37+
let (auth_middleware, auth_connector) = auth::connect(token.to_vec());
38+
let endpoint = Endpoint::builder()
39+
.middleware(auth_middleware)
40+
.bind()
41+
.await?;
42+
let _guard = auth_connector.spawn(endpoint.clone());
43+
Echo::connect(&endpoint, remote_addr, b"hello there!").await
44+
}
45+
46+
async fn connect_side_no_auth(remote_addr: EndpointAddr) -> Result<()> {
47+
let endpoint = Endpoint::bind().await?;
48+
Echo::connect(&endpoint, remote_addr, b"hello there!").await
49+
}
50+
51+
async fn accept_side(token: &[u8]) -> Result<Router> {
52+
let (auth_middleware, auth_protocol) = auth::accept(token.to_vec());
53+
let endpoint = Endpoint::builder()
54+
.middleware(auth_middleware)
55+
.bind()
56+
.await?;
57+
58+
let router = Router::builder(endpoint)
59+
.accept(auth::ALPN, auth_protocol)
60+
.accept(echo::ALPN, Echo)
61+
.spawn();
62+
63+
Ok(router)
64+
}
65+
66+
mod echo {
67+
//! A bare-bones protocol with no knowledge of auth whatsoever.
68+
69+
use iroh::{
70+
Endpoint, EndpointAddr,
71+
endpoint::Connection,
72+
protocol::{AcceptError, ProtocolHandler},
73+
};
74+
use n0_error::{Result, StdResultExt, anyerr};
75+
76+
#[derive(Debug, Clone)]
77+
pub struct Echo;
78+
79+
pub const ALPN: &[u8] = b"iroh-example/echo/0";
80+
81+
impl Echo {
82+
pub async fn connect(
83+
endpoint: &Endpoint,
84+
remote: impl Into<EndpointAddr>,
85+
message: &[u8],
86+
) -> Result<()> {
87+
let conn = endpoint.connect(remote, ALPN).await?;
88+
let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
89+
send.write_all(message).await.anyerr()?;
90+
send.finish().anyerr()?;
91+
let response = recv.read_to_end(1000).await.anyerr()?;
92+
conn.close(0u32.into(), b"bye!");
93+
if response == message {
94+
Ok(())
95+
} else {
96+
Err(anyerr!("Received invalid response"))
97+
}
98+
}
99+
}
100+
101+
impl ProtocolHandler for Echo {
102+
async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
103+
let (mut send, mut recv) = connection.accept_bi().await?;
104+
tokio::io::copy(&mut recv, &mut send).await?;
105+
send.finish()?;
106+
connection.closed().await;
107+
Ok(())
108+
}
109+
}
110+
}
111+
112+
mod auth {
113+
//! Authentication middleware
114+
115+
use std::{
116+
collections::{HashMap, HashSet, hash_map},
117+
sync::{Arc, Mutex},
118+
};
119+
120+
use iroh::{
121+
Endpoint, EndpointAddr, EndpointId,
122+
endpoint::{AfterHandshakeOutcome, BeforeConnectOutcome, Connection, Middleware},
123+
protocol::{AcceptError, ProtocolHandler},
124+
};
125+
use n0_error::{AnyError, Result, StackResultExt, StdResultExt, anyerr};
126+
use n0_future::task::AbortOnDropHandle;
127+
use quinn::ConnectionError;
128+
use tokio::{
129+
sync::{mpsc, oneshot},
130+
task::JoinSet,
131+
};
132+
use tracing::debug;
133+
134+
pub const ALPN: &[u8] = b"iroh-example/auth/0";
135+
136+
const CLOSE_ACCEPTED: u32 = 1;
137+
const CLOSE_DENIED: u32 = 403;
138+
139+
/// Connect side: Use this if you want to pre-auth outgoing connections.
140+
pub fn connect(token: Vec<u8>) -> (AuthConnectMiddleware, AuthConnectTask) {
141+
let (tx, rx) = mpsc::channel(16);
142+
let middleware = AuthConnectMiddleware { tx };
143+
let connector = AuthConnectTask {
144+
token,
145+
rx,
146+
allowed_remotes: Default::default(),
147+
pending_remotes: Default::default(),
148+
tasks: JoinSet::new(),
149+
};
150+
(middleware, connector)
151+
}
152+
153+
/// Middleware to mount on the endpoint builder.
154+
#[derive(Debug)]
155+
pub struct AuthConnectMiddleware {
156+
tx: mpsc::Sender<(EndpointId, oneshot::Sender<Result<(), Arc<AnyError>>>)>,
157+
}
158+
159+
impl AuthConnectMiddleware {
160+
async fn authenticate(&self, remote_id: EndpointId) -> Result<()> {
161+
let (tx, rx) = oneshot::channel();
162+
self.tx
163+
.send((remote_id, tx))
164+
.await
165+
.std_context("authenticator stopped")?;
166+
rx.await
167+
.std_context("authenticator stopped")?
168+
.context("failed to authenticate")
169+
}
170+
}
171+
172+
impl Middleware for AuthConnectMiddleware {
173+
async fn before_connect<'a>(
174+
&'a self,
175+
remote_addr: &'a EndpointAddr,
176+
alpn: &'a [u8],
177+
) -> BeforeConnectOutcome {
178+
// Don't intercept auth request themsevles
179+
if alpn == ALPN {
180+
BeforeConnectOutcome::Accept
181+
} else {
182+
match self.authenticate(remote_addr.id).await {
183+
Ok(()) => BeforeConnectOutcome::Accept,
184+
Err(err) => {
185+
debug!("authentication denied: {err:#}");
186+
BeforeConnectOutcome::Reject
187+
}
188+
}
189+
}
190+
}
191+
}
192+
193+
/// Connector task that initiates pre-auth request. Call [`Self::spawn`] once the endpoint is built.
194+
pub struct AuthConnectTask {
195+
token: Vec<u8>,
196+
rx: mpsc::Receiver<(EndpointId, oneshot::Sender<Result<(), Arc<AnyError>>>)>,
197+
allowed_remotes: HashSet<EndpointId>,
198+
pending_remotes: HashMap<EndpointId, Vec<oneshot::Sender<Result<(), Arc<AnyError>>>>>,
199+
tasks: JoinSet<(EndpointId, Result<()>)>,
200+
}
201+
202+
impl AuthConnectTask {
203+
pub fn spawn(self, endpoint: Endpoint) -> AbortOnDropHandle<()> {
204+
AbortOnDropHandle::new(tokio::spawn(self.run(endpoint)))
205+
}
206+
207+
async fn run(mut self, endpoint: Endpoint) {
208+
loop {
209+
tokio::select! {
210+
msg = self.rx.recv() => {
211+
let Some((remote_id, tx)) = msg else {
212+
break;
213+
};
214+
self.handle_msg(&endpoint, remote_id, tx);
215+
}
216+
Some(res) = self.tasks.join_next(), if !self.tasks.is_empty() => {
217+
let (remote_id, res) = res.expect("connect task panicked");
218+
let res = res.map_err(Arc::new);
219+
self.handle_task(remote_id, res);
220+
}
221+
}
222+
}
223+
}
224+
225+
fn handle_msg(
226+
&mut self,
227+
endpoint: &Endpoint,
228+
remote_id: EndpointId,
229+
tx: oneshot::Sender<Result<(), Arc<AnyError>>>,
230+
) {
231+
if self.allowed_remotes.contains(&remote_id) {
232+
tx.send(Ok(())).ok();
233+
} else {
234+
match self.pending_remotes.entry(remote_id) {
235+
hash_map::Entry::Occupied(mut entry) => {
236+
entry.get_mut().push(tx);
237+
}
238+
hash_map::Entry::Vacant(entry) => {
239+
let endpoint = endpoint.clone();
240+
let token = self.token.clone();
241+
self.tasks.spawn(async move {
242+
let res = Self::connect(endpoint, remote_id, token).await;
243+
(remote_id, res)
244+
});
245+
entry.insert(vec![tx]);
246+
}
247+
}
248+
}
249+
}
250+
251+
fn handle_task(&mut self, remote_id: EndpointId, res: Result<(), Arc<AnyError>>) {
252+
if res.is_ok() {
253+
self.allowed_remotes.insert(remote_id);
254+
}
255+
let senders = self.pending_remotes.remove(&remote_id);
256+
for tx in senders.into_iter().flatten() {
257+
tx.send(res.clone()).ok();
258+
}
259+
}
260+
261+
async fn connect(endpoint: Endpoint, remote_id: EndpointId, token: Vec<u8>) -> Result<()> {
262+
let conn = endpoint.connect(remote_id, ALPN).await?;
263+
let mut stream = conn.open_uni().await.anyerr()?;
264+
stream.write_all(&token).await.anyerr()?;
265+
stream.finish().anyerr()?;
266+
let reason = conn.closed().await;
267+
if let ConnectionError::ApplicationClosed(code) = &reason
268+
&& code.error_code.into_inner() as u32 == CLOSE_ACCEPTED
269+
{
270+
Ok(())
271+
} else if let ConnectionError::ApplicationClosed(code) = &reason
272+
&& code.error_code.into_inner() as u32 == CLOSE_DENIED
273+
{
274+
Err(anyerr!("authentication denied by remote"))
275+
} else {
276+
Err(AnyError::from_std(reason))
277+
}
278+
}
279+
}
280+
281+
/// Accept side: Use this if you want to only accept connections from peers with successful pre-auth requests.
282+
pub fn accept(token: Vec<u8>) -> (AuthAcceptMiddleware, AuthProtocol) {
283+
let allowed_remotes: Arc<Mutex<HashSet<EndpointId>>> = Default::default();
284+
let middleware = AuthAcceptMiddleware {
285+
allowed_remotes: allowed_remotes.clone(),
286+
};
287+
let protocol = AuthProtocol {
288+
allowed_remotes,
289+
token,
290+
};
291+
(middleware, protocol)
292+
}
293+
294+
/// Accept-side auth middleware: Mount this onto the endpoint.
295+
///
296+
/// This will reject incoming connections if the remote did not successfully authenticate before.
297+
#[derive(Debug)]
298+
pub struct AuthAcceptMiddleware {
299+
allowed_remotes: Arc<Mutex<HashSet<EndpointId>>>,
300+
}
301+
302+
impl Middleware for AuthAcceptMiddleware {
303+
async fn after_handshake<'a>(
304+
&'a self,
305+
conn: &'a iroh::endpoint::ConnectionInfo,
306+
) -> AfterHandshakeOutcome {
307+
if conn.alpn() == ALPN
308+
|| self
309+
.allowed_remotes
310+
.lock()
311+
.expect("poisoned")
312+
.contains(conn.remote_id())
313+
{
314+
AfterHandshakeOutcome::Accept
315+
} else {
316+
AfterHandshakeOutcome::Reject {
317+
error_code: 403u32.into(),
318+
reason: b"not authenticated".to_vec(),
319+
}
320+
}
321+
}
322+
}
323+
324+
/// Accept-side auth protocol. Mount this on the router to accept authentication requests.
325+
#[derive(Debug, Clone)]
326+
pub struct AuthProtocol {
327+
token: Vec<u8>,
328+
allowed_remotes: Arc<Mutex<HashSet<EndpointId>>>,
329+
}
330+
331+
impl ProtocolHandler for AuthProtocol {
332+
/// The `accept` method is called for each incoming connection for our ALPN.
333+
///
334+
/// The returned future runs on a newly spawned tokio task, so it can run as long as
335+
/// the connection lasts.
336+
async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
337+
let mut stream = connection.accept_uni().await?;
338+
let token = stream.read_to_end(256).await.anyerr()?;
339+
let remote_id = connection.remote_id();
340+
if token == self.token {
341+
self.allowed_remotes
342+
.lock()
343+
.expect("poisoned")
344+
.insert(remote_id);
345+
connection.close(CLOSE_ACCEPTED.into(), b"accepted");
346+
} else {
347+
connection.close(CLOSE_DENIED.into(), b"rejected");
348+
}
349+
Ok(())
350+
}
351+
}
352+
}

0 commit comments

Comments
 (0)