Test for too many workers + comments
jthomson04 committed Jun 8, 2025
commit 6f825e55e667946ef53016f873a16778f7696b99
131 changes: 106 additions & 25 deletions lib/runtime/src/utils/leader_worker_barrier.rs
@@ -19,7 +19,7 @@ use crate::{
};
use serde::{de::DeserializeOwned, Serialize};

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use std::time::{Duration, Instant};

@@ -113,21 +113,24 @@ async fn create_barrier_key<T: Serialize>(
/// Creates a worker-specific key in etcd
async fn create_worker_key(
client: &Client,
key: String,
key: &str,
lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> {
client
.kv_create(key, serde_json::to_vec(&()).unwrap(), lease_id)
.kv_create(key.to_owned(), serde_json::to_vec(&()).unwrap(), lease_id)
.await
.map_err(|_| LeaderWorkerBarrierError::BarrierWorkerIdNotUnique)?;

Ok(())
}

/// Waits for a single key to appear and returns its deserialized value (used for completion/abort signals)
async fn wait_for_signal(client: &Client, key: String) -> Result<(), LeaderWorkerBarrierError> {
wait_for_key_count::<()>(client, key, 1, None).await?;
Ok(())
async fn wait_for_signal<T: DeserializeOwned>(
client: &Client,
key: String,
) -> Result<T, LeaderWorkerBarrierError> {
let data = wait_for_key_count::<T>(client, key, 1, None).await?;
Ok(data.into_values().next().unwrap())
}

#[derive(Debug)]
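
An aside on the change above: `wait_for_signal` is now generic and returns the deserialized payload of the single matching key instead of discarding it. A minimal usage sketch, assuming a connected `Client` and an illustrative key (neither is part of this diff):

use std::collections::HashSet;

// Hypothetical call site: the completion signal now carries typed data,
// here the set of accepted worker keys published by the leader.
let workers: HashSet<String> =
    wait_for_signal(&client, "demo-barrier/complete".to_string()).await?;
println!("{} workers were accepted", workers.len());
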
@@ -139,8 +142,10 @@ pub enum LeaderWorkerBarrierError {
SerdeError(serde_json::Error),
Timeout,
Aborted,
AlreadyCompleted,
}

/// A barrier for a leader to wait for a specific number of workers to join.
pub struct LeaderBarrier<T> {
barrier_id: String,
num_workers: usize,
@@ -158,6 +163,10 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
}
}

/// Synchronize the leader with the workers.
///
/// The leader publishes the barrier data and waits for `num_workers` workers to register.
/// It then signals completion, publishing the set of accepted worker keys, or abort;
/// the workers wait on that signal.
pub async fn sync(
self,
rt: &DistributedRuntime,
@@ -177,7 +186,7 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
let worker_result = self.wait_for_workers(&etcd_client).await;

// Signal completion or abort
self.signal_completion(&etcd_client, worker_result.is_ok(), lease_id)
self.signal_completion(&etcd_client, &worker_result, lease_id)
.await?;

worker_result.map(|_| ())
@@ -193,28 +202,34 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
create_barrier_key(client, key, data, Some(lease_id)).await
}

async fn wait_for_workers(&self, client: &Client) -> Result<(), LeaderWorkerBarrierError> {
async fn wait_for_workers(
&self,
client: &Client,
) -> Result<HashSet<String>, LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_WORKER);
wait_for_key_count::<()>(client, key, self.num_workers, self.timeout).await?;
Ok(())
let workers = wait_for_key_count::<()>(client, key, self.num_workers, self.timeout).await?;
Ok(workers.into_keys().collect())
}

async fn signal_completion(
&self,
client: &Client,
success: bool,
worker_result: &Result<HashSet<String>, LeaderWorkerBarrierError>,
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
let suffix = if success {
BARRIER_COMPLETE
if let Ok(worker_result) = worker_result {
let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
create_barrier_key(client, key, worker_result, Some(lease_id)).await?;
} else {
BARRIER_ABORT
};
let key = barrier_key(&self.barrier_id, suffix);
create_barrier_key(client, key, (), Some(lease_id)).await
let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
create_barrier_key(client, key, (), Some(lease_id)).await?;
}

Ok(())
}
}

/// A barrier to synchronize a worker with a leader.
pub struct WorkerBarrier<T> {
barrier_id: String,
worker_id: String,
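
Note the shape change in `signal_completion` above: on success it now publishes the accepted worker-key set instead of `()`, which is what lets a late-registering worker detect that it was not counted. A self-contained sketch of the payload round-trip, assuming `create_barrier_key` serializes with serde_json as `create_worker_key` does (the set contents are illustrative):

use std::collections::HashSet;

// The completion payload is plain serde_json.
let mut accepted: HashSet<String> = HashSet::new();
accepted.insert("demo-barrier/worker/worker1".to_string());
let bytes = serde_json::to_vec(&accepted).unwrap();
let decoded: HashSet<String> = serde_json::from_slice(&bytes).unwrap();
assert_eq!(accepted, decoded);
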
@@ -230,6 +245,13 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
}
}

/// Synchronize the worker with the leader.
///
/// The worker waits for the barrier data to appear, registers itself, and then
/// waits for the completion or abort signal.
///
/// If the leader signals completion and this worker is in the accepted set, the
/// barrier data is returned. A worker that registered too late to be counted gets
/// `AlreadyCompleted`; if the leader aborts, `Aborted` is returned.
pub async fn sync(
self,
rt: &DistributedRuntime,
@@ -244,10 +266,10 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
let barrier_data = self.get_barrier_data(&etcd_client).await?;

// Register as a worker
self.register_worker(&etcd_client, lease_id).await?;
let worker_key = self.register_worker(&etcd_client, lease_id).await?;

// Wait for completion or abort signal
self.wait_for_completion(&etcd_client).await?;
self.wait_for_completion(&etcd_client, worker_key).await?;

Ok(barrier_data)
}
@@ -261,7 +283,7 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
result?.into_values().next()
.ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found")))
}
_ = wait_for_signal(client, abort_key) => {
_ = wait_for_signal::<()>(client, abort_key) => {
Err(LeaderWorkerBarrierError::Aborted)
}
}
@@ -271,21 +293,33 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
&self,
client: &Client,
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
) -> Result<String, LeaderWorkerBarrierError> {
let key = barrier_key(
&self.barrier_id,
&format!("{}/{}", BARRIER_WORKER, self.worker_id),
);
create_worker_key(client, key, Some(lease_id)).await
create_worker_key(client, &key, Some(lease_id))
.await
.map(|_| key)
}

async fn wait_for_completion(&self, client: &Client) -> Result<(), LeaderWorkerBarrierError> {
async fn wait_for_completion(
&self,
client: &Client,
worker_key: String,
) -> Result<(), LeaderWorkerBarrierError> {
let complete_key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);

tokio::select! {
_ = wait_for_signal(client, complete_key) => Ok(()),
_ = wait_for_signal(client, abort_key) => Err(LeaderWorkerBarrierError::Aborted),
Ok(workers) = wait_for_signal::<HashSet<String>>(client, complete_key) => {
if workers.contains(&worker_key) {
Ok(())
} else {
Err(LeaderWorkerBarrierError::AlreadyCompleted)
}
},
_ = wait_for_signal::<()>(client, abort_key) => Err(LeaderWorkerBarrierError::Aborted),
}
}
}
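
Since worker `sync` can now fail with `AlreadyCompleted` even when nothing is operationally wrong, callers may want to branch on it explicitly. A hedged sketch of one possible policy, inside a function returning `Result<(), LeaderWorkerBarrierError>` (the stand-down decision and the `handle` callback are assumptions, not part of this diff):

match worker.sync(&drt).await {
    Ok(data) => {
        // Counted toward num_workers: proceed with the leader's data.
        handle(data); // `handle` is a hypothetical application callback
    }
    Err(LeaderWorkerBarrierError::AlreadyCompleted) => {
        // Registered after the barrier filled: stand down gracefully.
    }
    Err(e) => return Err(e),
}
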
@@ -523,4 +557,51 @@ mod tests {
assert!(matches!(leader_res, Ok(Ok(_))));
assert!(matches!(worker_res, Ok(Ok(_))));
}

#[tokio::test]
async fn test_too_many_workers() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

let id = unique_id();

let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string());

let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader.sync(&drt_clone, &()).await?;
Ok(())
});

let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let drt_clone = drt.clone();
let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone).await });

let worker2_join = tokio::spawn(async move { worker2.sync(&drt).await });

let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join);

let mut num_successes = 0;
for worker_res in [worker1_res, worker2_res] {
if let Ok(Ok(_)) = worker_res {
num_successes += 1;
} else if let Ok(Err(LeaderWorkerBarrierError::AlreadyCompleted)) = worker_res {
// Expected for the extra worker: it registered after the leader
// had already completed the barrier with the first worker.
} else {
panic!("unexpected worker result: {:?}", worker_res);
}
}

assert_eq!(num_successes, 1);
Ok(())
});

let (leader_res, worker_res) = tokio::join!(leader_join, worker_join);

assert!(matches!(leader_res, Ok(Ok(_))));
assert!(matches!(worker_res, Ok(Ok(_))));
}
}
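
Finally, for readers new to this module, a minimal sketch of how the two halves pair up, assuming an etcd instance reachable via the `DistributedRuntime` settings and mirroring the tests above (the barrier id, worker name, and payload are illustrative):

let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::from_settings(rt.clone()).await.unwrap();

let leader = LeaderBarrier::<String>::new("demo-barrier".to_string(), 1, None);
let worker = WorkerBarrier::<String>::new("demo-barrier".to_string(), "worker1".to_string());

let drt_leader = drt.clone();
let leader_task = tokio::spawn(async move {
    // Publishes "hello", waits for one worker, then signals completion.
    leader.sync(&drt_leader, &"hello".to_string()).await
});
let worker_task = tokio::spawn(async move {
    // Waits for the data, registers, then waits for the completion signal.
    worker.sync(&drt).await
});

let (leader_res, worker_res) = tokio::join!(leader_task, worker_task);
assert!(leader_res.unwrap().is_ok());
assert_eq!(worker_res.unwrap().unwrap(), "hello");
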