Merged
7 changes: 6 additions & 1 deletion crates/dekaf/src/main.rs
@@ -87,10 +87,14 @@ pub struct Cli {
#[arg(long, env = "IDLE_SESSION_TIMEOUT", value_parser = humantime::parse_duration, default_value = "30s")]
idle_session_timeout: std::time::Duration,

/// How long to cache materialization specs and other task metadata for before re-refreshing
/// How long to cache materialization specs and other task metadata for before refreshing
#[arg(long, env = "TASK_REFRESH_INTERVAL", value_parser = humantime::parse_duration, default_value = "30s")]
task_refresh_interval: std::time::Duration,

/// How long before a request for materialization specs and other task metadata times out
#[arg(long, env = "TASK_REQUEST_TIMEOUT", value_parser = humantime::parse_duration, default_value = "30s")]
task_request_timeout: std::time::Duration,

/// Timeout for TLS handshake completion
#[arg(long, env = "TLS_HANDSHAKE_TIMEOUT", value_parser = humantime::parse_duration, default_value = "10s")]
tls_handshake_timeout: std::time::Duration,
@@ -223,6 +227,7 @@ async fn async_main(cli: Cli) -> anyhow::Result<()> {

let task_manager = Arc::new(TaskManager::new(
cli.task_refresh_interval,
cli.task_request_timeout,
client_base.clone(),
cli.data_plane_fqdn.clone(),
signing_token.clone(),
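For reference, a minimal sketch of the clap + humantime pattern the new flag follows, shown in isolation. The Cli struct below is a stand-in rather than the real dekaf one, and the 45s value is only an example; the flag can be set as --task-request-timeout 45s or via TASK_REQUEST_TIMEOUT=45s.

use clap::Parser;
use std::time::Duration;

// Stand-in Cli showing only the new flag; assumes clap's "derive" and "env" features.
#[derive(Parser, Debug)]
struct Cli {
    /// How long before a request for materialization specs and other task metadata times out.
    #[arg(long, env = "TASK_REQUEST_TIMEOUT", value_parser = humantime::parse_duration, default_value = "30s")]
    task_request_timeout: Duration,
}

fn main() {
    let cli = Cli::parse();
    println!("task request timeout: {:?}", cli.task_request_timeout);
}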
3 changes: 3 additions & 0 deletions crates/dekaf/src/read.rs
@@ -232,6 +232,9 @@ impl Read {
if timeout_at > self.stream_exp {
timeout_at = self.stream_exp;
}
if timeout_at < now {
anyhow::bail!("Encountered a read stream with token expiring in the past. This should not happen, cancelling the read.");
}
tokio::time::Instant::now() + timeout_at.duration_since(now)?
};

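The hunk above clamps the fetch deadline to the stream token's expiry and now bails out rather than computing a deadline that is already in the past. A self-contained sketch of that logic, with stream_exp and requested_timeout as illustrative inputs rather than the real Read fields:

use std::time::{Duration, SystemTime};

// Clamp a requested deadline to the stream token's expiry, and refuse to
// proceed if the clamped deadline is already in the past.
fn compute_deadline(
    stream_exp: SystemTime,
    requested_timeout: Duration,
) -> anyhow::Result<tokio::time::Instant> {
    let now = SystemTime::now();
    let mut timeout_at = now + requested_timeout;
    if timeout_at > stream_exp {
        timeout_at = stream_exp;
    }
    if timeout_at < now {
        anyhow::bail!("read stream token already expired");
    }
    Ok(tokio::time::Instant::now() + timeout_at.duration_since(now)?)
}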
171 changes: 151 additions & 20 deletions crates/dekaf/src/task_manager.rs
@@ -45,7 +45,10 @@ impl fmt::Display for SharedError {

pub type Result<T> = core::result::Result<T, SharedError>;

/// How long to keep a TaskManager alive without any listeners.
const TASK_TIMEOUT: Duration = Duration::from_secs(60 * 3);
/// How far before an access token expires we should start trying to refresh it
const REFRESH_START_AT: Duration = Duration::from_secs(60 * 5);

#[derive(Clone)]
pub enum TaskState {
@@ -85,11 +88,24 @@ impl TaskStateListener {
// Scope to force the borrow to end before awaiting
{
let current_value = temp_rx.borrow_and_update();
if let Some(ref result) = *current_value {
return result
.as_ref()
.map(|arc| Arc::clone(arc))
.map_err(|e| anyhow::Error::from(e.clone()));
match &*current_value {
jshearer (Contributor, Author): We used to happily hand out expired tokens here, just assuming that the TaskManager loop could keep up. That's the real cause of the problem.

Some(Ok(state)) => match state.as_ref() {
TaskState::Authorized {
access_token_claims,
..
} if access_token_claims.exp
<= time::OffsetDateTime::now_utc().unix_timestamp() as u64 =>
{
anyhow::bail!("Access token has expired and the task manager has been unable to refresh it.");
}
_ => return Ok(state.clone()),
},
Some(res) => {
return res.clone().map_err(anyhow::Error::from);
}
None => {
tracing::debug!("No task state available yet, waiting for the next update");
}
}
}

@@ -125,20 +141,23 @@ pub struct TaskManager {
>,
>,
interval: Duration,
timeout: Duration,
client: flow_client::Client,
data_plane_fqdn: String,
data_plane_signer: jsonwebtoken::EncodingKey,
}
impl TaskManager {
pub fn new(
interval: Duration,
timeout: Duration,
client: flow_client::Client,
data_plane_fqdn: String,
data_plane_signer: jsonwebtoken::EncodingKey,
) -> Self {
TaskManager {
tasks: std::sync::Mutex::new(HashMap::new()),
interval,
timeout,
client,
data_plane_fqdn,
data_plane_signer: data_plane_signer,
@@ -228,6 +247,8 @@ impl TaskManager {
let mut cached_ops_stats_client: Option<Result<(journal::Client, proto_gazette::Claims)>> =
None;

let mut cached_dekaf_auth: Option<DekafTaskAuth> = None;

let mut timeout_start = None;

loop {
@@ -258,21 +279,23 @@
let mut has_been_migrated = false;

let loop_result: Result<()> = async {
// For the moment, let's just refresh this every tick in order to have relatively
// fresh MaterializationSpecs, even if the access token may live for a while.
let dekaf_auth = fetch_dekaf_task_auth(
let dekaf_auth = get_or_refresh_dekaf_auth(
cached_dekaf_auth.take(),
&self.client,
&task_name,
&self.data_plane_fqdn,
&self.data_plane_signer,
self.timeout,
)
.await
.context("error fetching dekaf task auth")?;
.context("error fetching or refreshing dekaf task auth")?;
cached_dekaf_auth = Some(dekaf_auth.clone());

match dekaf_auth {
DekafTaskAuth::Redirect {
target_dataplane_fqdn,
spec,
..
} => {
if !has_been_migrated {
has_been_migrated = true;
@@ -306,6 +329,7 @@ impl TaskManager {
&task_name,
&spec,
std::mem::take(&mut partitions_and_clients),
self.timeout,
)
.await?;

@@ -325,6 +349,7 @@
cached_ops_logs_client
.as_ref()
.and_then(|r| r.as_ref().ok()),
self.timeout,
)
.await
.map_err(SharedError::from);
@@ -346,6 +371,7 @@
cached_ops_stats_client
.as_ref()
.and_then(|r| r.as_ref().ok()),
self.timeout,
)
.await
.map_err(SharedError::from);
@@ -372,11 +398,11 @@
.collect_vec(),
ops_logs_client: cached_ops_logs_client
.as_ref()
.expect("this is guarinteed to be present")
.expect("this is guaranteed to be present")
.clone(),
ops_stats_client: cached_ops_stats_client
.as_ref()
.expect("this is guarinteed to be present")
.expect("this is guaranteed to be present")
.clone(),
}))));

@@ -413,6 +439,7 @@ async fn update_partition_info(
task_name: &str,
spec: &MaterializationSpec,
mut info: HashMap<String, Result<(journal::Client, proto_gazette::Claims, Vec<Partition>)>>,
timeout: Duration,
) -> anyhow::Result<HashMap<String, Result<(journal::Client, proto_gazette::Claims, Vec<Partition>)>>>
{
let mut tasks = Vec::with_capacity(spec.bindings.len());
@@ -452,6 +479,7 @@
exclude: None,
},
existing_client.as_ref(),
timeout
)
.await;

@@ -496,29 +524,55 @@ async fn get_or_refresh_journal_client(
capability: u32,
selector: broker::LabelSelector,
cached_client_and_claims: Option<&(journal::Client, proto_gazette::Claims)>,
timeout: Duration,
) -> anyhow::Result<(journal::Client, proto_gazette::Claims)> {
if let Some((cached_client, claims)) = cached_client_and_claims {
let now_unix = time::OffsetDateTime::now_utc().unix_timestamp();
// Add a buffer to token expiry check
if claims.exp > now_unix as u64 + 60 {
// Refresh the client if its token is closer than REFRESH_START_AT to its expiration.
let refresh_from = (claims.exp - REFRESH_START_AT.as_secs()) as i64;
if now_unix < refresh_from {
tracing::debug!(task=%task_name, "Re-using existing journal client.");
return Ok((cached_client.clone(), claims.clone()));
} else {
tracing::debug!(task=%task_name, "Journal client token expired or nearing expiry.");
}
}

let timeouts_allowed_until = if let Some((client, claims)) = cached_client_and_claims {
// If we have a cached client, we can use its expiration time to determine how long we can wait for the new client to be fetched.
Some((claims.exp, client, claims))
} else {
None
};

tracing::debug!(task=%task_name, capability, "Fetching new task authorization for journal client.");
metrics::counter!("dekaf_fetch_auth", "endpoint" => "/authorize/task", "task_name" => task_name.to_owned()).increment(1);
flow_client::fetch_task_authorization(
flow_client,
&crate::dekaf_shard_template_id(task_name),
data_plane_fqdn,
data_plane_signer,
capability,
selector,
match tokio::time::timeout(
timeout,
flow_client::fetch_task_authorization(
flow_client,
&crate::dekaf_shard_template_id(task_name),
data_plane_fqdn,
data_plane_signer,
capability,
selector,
),
)
.await
{
Ok(resp) => resp,
Err(_) => {
if let Some((allowed_until, cached_client, cached_claims)) = timeouts_allowed_until {
if time::OffsetDateTime::now_utc().unix_timestamp() < allowed_until as i64 {
tracing::warn!(task=%task_name, allowed_until, "Timed out while fetching task authorization for journal client within acceptable retry window.");
return Ok((cached_client.clone(), cached_claims.clone()));
}
}
Err(anyhow::anyhow!(
"Timed out while fetching task authorization for journal client."
))
}
}
}

/// Fetch the journals of a collection and map into stable-order partitions.
@@ -560,14 +614,17 @@ pub async fn fetch_partitions(
// Claims returned by `/authorize/dekaf`
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AccessTokenClaims {
pub iat: u64,
pub exp: u64,
}

#[derive(Debug, Clone)]
pub enum DekafTaskAuth {
/// Task has been migrated to a different dataplane, and the session should redirect to it.
Redirect {
target_dataplane_fqdn: String,
spec: MaterializationSpec,
fetched_at: time::OffsetDateTime,
},
/// Task authorization data.
Auth {
Expand All @@ -579,6 +636,79 @@ pub enum DekafTaskAuth {
},
}

impl DekafTaskAuth {
fn exp(&self) -> u64 {
match self {
DekafTaskAuth::Redirect { fetched_at, .. } => {
// Redirects are valid for 10 minutes
jshearer (Contributor, Author) on Aug 20, 2025: I just made this up. Since redirects don't get a token, they don't get an agent-api-specified expiration. 10 minutes seems more than fine, as the only time where this would matter is if a task were migrated such that it is no longer a redirect.

Reviewer (Member): 🤔 Does this mean that we'll now start to expire Redirects, where previously we haven't? Just wanting to double check whether that poses any risk for existing tasks that may not have encountered that before.

jshearer (Contributor, Author): We used to re-fetch every 30s no matter what. Now we only fetch when the (made up) expiration is coming up. I don't think this substantively changes the behavior, just the amount of time a redirect response can be cached for.

fetched_at.unix_timestamp() as u64 + 60 * 10
}
DekafTaskAuth::Auth { claims, .. } => claims.exp,
}
}
fn refresh_at(&self) -> u64 {
// Refresh the auth if it is closer than REFRESH_START_AT to its expiration.
self.exp() - REFRESH_START_AT.as_secs()
}
}

async fn get_or_refresh_dekaf_auth(
cached: Option<DekafTaskAuth>,
client: &flow_client::Client,
shard_template_id: &str,
data_plane_fqdn: &str,
data_plane_signer: &jsonwebtoken::EncodingKey,
timeout: Duration,
) -> anyhow::Result<DekafTaskAuth> {
let now = time::OffsetDateTime::now_utc().unix_timestamp() as u64;

if let Some(cached_auth) = cached {
if now < cached_auth.refresh_at() {
tracing::debug!("DekafTaskAuth is still valid, no need to refresh.");
return Ok(cached_auth);
}

// Try to refresh, but fall back to cached if timeout and still valid
match tokio::time::timeout(
timeout,
fetch_dekaf_task_auth(
client,
shard_template_id,
data_plane_fqdn,
data_plane_signer,
),
)
.await
{
Ok(resp) => resp,
Err(_) => {
if time::OffsetDateTime::now_utc().unix_timestamp() < cached_auth.exp() as i64 {
tracing::warn!(
"Timed out while refreshing DekafTaskAuth, but the token is still valid."
);
return Ok(cached_auth);
}
anyhow::bail!(
"Timed out while refreshing DekafTaskAuth, and the token is expired."
);
}
}
} else {
// No cached value, fetch new one
tokio::time::timeout(
timeout,
fetch_dekaf_task_auth(
client,
shard_template_id,
data_plane_fqdn,
data_plane_signer,
),
)
.await
.map_err(|_| anyhow::anyhow!("Timed out while fetching dekaf task auth"))?
}
}

#[tracing::instrument(skip(client, data_plane_signer), err)]
async fn fetch_dekaf_task_auth(
client: &flow_client::Client,
@@ -647,6 +777,7 @@ async fn fetch_dekaf_task_auth(
return Ok(DekafTaskAuth::Redirect {
target_dataplane_fqdn: redirect_fqdn,
spec: parsed_spec,
fetched_at: time::OffsetDateTime::now_utc(),
});
}

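Taken together, the strategy this file now applies to both journal clients and Dekaf task auth is: reuse a cached credential until it is within REFRESH_START_AT of its expiry, refresh it under the new request timeout, and fall back to the cached value only while it has not actually expired. A condensed sketch of that flow follows; Cred, fetch, and now_unix are illustrative names, not the actual dekaf types.

use std::time::Duration;

#[derive(Clone)]
struct Cred {
    exp: u64, // unix seconds, as in the access-token and gazette claims
}

fn now_unix() -> u64 {
    time::OffsetDateTime::now_utc().unix_timestamp() as u64
}

// Reuse `cached` while it is comfortably valid; otherwise refresh under `timeout`,
// falling back to the cached value only if it has not yet expired.
async fn get_or_refresh<F>(
    cached: Option<Cred>,
    refresh_start_at: Duration,
    timeout: Duration,
    fetch: F,
) -> anyhow::Result<Cred>
where
    F: std::future::Future<Output = anyhow::Result<Cred>>,
{
    if let Some(cred) = &cached {
        if now_unix() + refresh_start_at.as_secs() < cred.exp {
            return Ok(cred.clone());
        }
    }
    match tokio::time::timeout(timeout, fetch).await {
        Ok(result) => result,
        Err(_elapsed) => {
            if let Some(cred) = cached {
                if now_unix() < cred.exp {
                    // Refresh timed out, but the old credential is still usable.
                    return Ok(cred);
                }
            }
            anyhow::bail!("timed out refreshing credential and no valid cached value")
        }
    }
}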
1 change: 1 addition & 0 deletions crates/dekaf/src/topology.rs
@@ -1,5 +1,6 @@
use crate::{connector, utils, SessionAuthentication, TaskState};
use anyhow::{anyhow, bail, Context};
use futures::StreamExt;
use gazette::{
broker::{self, journal_spec, ReadResponse},
journal, uuid,