202 changes: 182 additions & 20 deletions lib/runtime/src/storage/key_value_store/etcd.rs
@@ -20,7 +20,7 @@ use std::time::Duration;
use crate::{slug::Slug, transports::etcd::Client};
use async_stream::stream;
use async_trait::async_trait;
-use etcd_client::{EventType, PutOptions, WatchOptions};
+use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};

use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};

@@ -158,31 +158,44 @@ impl EtcdBucket {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}");

-        // Does it already exist? For 'create' it shouldn't.
-        let kvs = self
-            .client
-            .kv_get(k.clone(), None)
-            .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
-        if !kvs.is_empty() {
-            let version = kvs.first().unwrap().version();
-            return Ok(StorageOutcome::Exists(version as u64));
-        }
-
-        // Write it
-        let mut put_resp = self
-            .client
-            .kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
-            .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
-        // Check whether we overwrote something
-        if put_resp.take_prev_key().is_some() {
-            // Key was created between our get and put
-            return Err(StorageError::Retry);
-        }
-
-        // version of a new key is always 1
-        Ok(StorageOutcome::Created(1))
+        // Use an atomic transaction to check and create in one operation
+        let put_options = PutOptions::new();
+
+        // Build a transaction that creates the key only if it doesn't exist
+        let txn = Txn::new()
+            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)]) // atomic check
+            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))]) // only if the check passes
+            .or_else(vec![
+                TxnOp::get(k.as_str(), None), // key already exists, fetch its info
+            ]);
+
+        // Execute the transaction
+        let result = self
+            .client
+            .etcd_client()
+            .kv_client()
+            .txn(txn)
+            .await
+            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+
+        if result.succeeded() {
+            // Key was created; the version of a new key is always 1
+            return Ok(StorageOutcome::Created(1));
+        }
+
+        // Key already existed; get its version from the get response
+        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
+            result.op_responses().into_iter().next()
+        {
+            if let Some(kv) = get_resp.kvs().first() {
+                let version = kv.version() as u64;
+                return Ok(StorageOutcome::Exists(version));
+            }
+        }
+        // Shouldn't happen, but handle the edge case
+        Err(StorageError::EtcdError(
+            "Unexpected transaction response".to_string(),
+        ))
}

async fn update(
@@ -241,3 +254,152 @@ fn make_key(bucket_name: &str, key: &str) -> String {
]
.join("/")
}

#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
use super::*;
use crate::{distributed::DistributedConfig, DistributedRuntime, Runtime};
use std::sync::Arc;
use tokio::sync::Barrier;

#[test]
fn test_concurrent_etcd_create_race_condition() {
let rt = Runtime::from_settings().unwrap();
let rt_clone = rt.clone();
let config = DistributedConfig::from_settings(false);

rt_clone.primary().block_on(async move {
let drt = DistributedRuntime::new(rt, config).await.unwrap();
test_concurrent_create(drt).await.unwrap();
});
}

async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> {
let etcd_client = drt.etcd_client().expect("etcd client should be available");
let storage = EtcdStorage::new(etcd_client);

// Create a bucket for testing
let bucket = Arc::new(tokio::sync::Mutex::new(
storage
.get_or_create_bucket("test_concurrent_bucket", None)
.await?,
));

// Number of concurrent workers
let num_workers = 10;
let barrier = Arc::new(Barrier::new(num_workers));

// Shared test data
let test_key = format!("concurrent_test_key_{}", uuid::Uuid::new_v4());
let test_value = "test_value";

// Spawn multiple tasks that will all try to create the same key simultaneously
let mut handles = Vec::new();
let success_count = Arc::new(tokio::sync::Mutex::new(0));
let exists_count = Arc::new(tokio::sync::Mutex::new(0));

for worker_id in 0..num_workers {
let bucket_clone = bucket.clone();
let barrier_clone = barrier.clone();
let key_clone = test_key.clone();
let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
let success_count_clone = success_count.clone();
let exists_count_clone = exists_count.clone();

let handle = tokio::spawn(async move {
// Wait for all workers to be ready
barrier_clone.wait().await;

// All workers try to create the same key at the same time
let result = bucket_clone
.lock()
.await
.insert(key_clone, value_clone, 0)
.await;

match result {
Ok(StorageOutcome::Created(version)) => {
println!(
"Worker {} successfully created key with version {}",
worker_id, version
);
let mut count = success_count_clone.lock().await;
*count += 1;
Ok(version)
}
Ok(StorageOutcome::Exists(version)) => {
println!(
"Worker {} found key already exists with version {}",
worker_id, version
);
let mut count = exists_count_clone.lock().await;
*count += 1;
Ok(version)
}
Err(e) => {
println!("Worker {} got error: {:?}", worker_id, e);
Err(e)
}
}
});

handles.push(handle);
}

// Wait for all workers to complete
let mut results = Vec::new();
for handle in handles {
let result = handle.await.unwrap();
if let Ok(version) = result {
results.push(version);
}
}

// Verify results
let final_success_count = *success_count.lock().await;
let final_exists_count = *exists_count.lock().await;

println!(
"Final counts - Created: {}, Exists: {}",
final_success_count, final_exists_count
);

// CRITICAL ASSERTIONS:
// 1. Exactly ONE worker should have successfully created the key
assert_eq!(
final_success_count, 1,
"Exactly one worker should create the key"
);

// 2. All other workers should have gotten "Exists" response
assert_eq!(
final_exists_count,
num_workers - 1,
"All other workers should see key exists"
);

// 3. Total successful operations should equal number of workers
assert_eq!(
results.len(),
num_workers,
"All workers should complete successfully"
);

// 4. Verify the key actually exists in etcd
let stored_value = bucket.lock().await.get(&test_key).await?;
assert!(stored_value.is_some(), "Key should exist in etcd");

// 5. The stored value should be from one of the workers
let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
assert!(
stored_str.starts_with(test_value),
"Stored value should match expected prefix"
);

// Clean up
bucket.lock().await.delete(&test_key).await?;

Ok(())
}
}
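
For readers who want to try the compare-and-create transaction outside of this repository's EtcdStorage wrapper, the sketch below exercises the same pattern directly against the etcd-client crate. It assumes a local etcd endpoint at localhost:2379, a Tokio runtime (the tokio crate with the macros and rt-multi-thread features), and placeholder key and value names; it illustrates the transaction shape, not the exact code merged here.

// Minimal standalone sketch of the compare-and-create pattern used above.
// The endpoint, key, and value are placeholders for illustration only.
use etcd_client::{Client, Compare, CompareOp, Error, Txn, TxnOp, TxnOpResponse};

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Assumed local etcd endpoint.
    let mut client = Client::connect(["localhost:2379"], None).await?;

    let key = "example/bucket/my-key";
    let value = "hello";

    // The `when` guard passes only if the key has version 0, i.e. it does not
    // exist yet; otherwise the `or_else` branch reads the existing key-value pair.
    let txn = Txn::new()
        .when(vec![Compare::version(key, CompareOp::Equal, 0)])
        .and_then(vec![TxnOp::put(key, value, None)])
        .or_else(vec![TxnOp::get(key, None)]);

    let resp = client.txn(txn).await?;
    if resp.succeeded() {
        // The put branch ran, so the key was created with version 1.
        println!("created {key} (version 1)");
    } else if let Some(TxnOpResponse::Get(get)) = resp.op_responses().into_iter().next() {
        // The get branch ran, so the key already existed.
        if let Some(kv) = get.kvs().first() {
            println!("{key} already exists at version {}", kv.version());
        }
    }
    Ok(())
}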