dekaf: Improve timeout handling when refreshing tokens #2348
@@ -45,7 +45,10 @@ impl fmt::Display for SharedError {

 pub type Result<T> = core::result::Result<T, SharedError>;

 /// How long to keep a TaskManager alive without any listeners.
 const TASK_TIMEOUT: Duration = Duration::from_secs(60 * 3);
+/// How long before the end of an access token we should start trying to refresh it.
+const REFRESH_START_AT: Duration = Duration::from_secs(60 * 5);

 #[derive(Clone)]
 pub enum TaskState {

@@ -85,11 +88,24 @@ impl TaskStateListener {
         // Scope to force the borrow to end before awaiting
         {
             let current_value = temp_rx.borrow_and_update();
-            if let Some(ref result) = *current_value {
-                return result
-                    .as_ref()
-                    .map(|arc| Arc::clone(arc))
-                    .map_err(|e| anyhow::Error::from(e.clone()));
+            match &*current_value {
+                Some(Ok(state)) => match state.as_ref() {
+                    TaskState::Authorized {
+                        access_token_claims,
+                        ..
+                    } if access_token_claims.exp
+                        <= time::OffsetDateTime::now_utc().unix_timestamp() as u64 =>
+                    {
+                        anyhow::bail!("Access token has expired and the task manager has been unable to refresh it.");
+                    }
+                    _ => return Ok(state.clone()),
+                },
+                Some(res) => {
+                    return res.clone().map_err(anyhow::Error::from);
+                }
+                None => {
+                    tracing::debug!("No task state available yet, waiting for the next update");
+                }
             }
         }

@@ -125,20 +141,23 @@ pub struct TaskManager {
         >,
     >,
     interval: Duration,
+    timeout: Duration,
     client: flow_client::Client,
     data_plane_fqdn: String,
     data_plane_signer: jsonwebtoken::EncodingKey,
 }
 impl TaskManager {
     pub fn new(
         interval: Duration,
+        timeout: Duration,
         client: flow_client::Client,
         data_plane_fqdn: String,
         data_plane_signer: jsonwebtoken::EncodingKey,
     ) -> Self {
         TaskManager {
             tasks: std::sync::Mutex::new(HashMap::new()),
             interval,
+            timeout,
             client,
             data_plane_fqdn,
             data_plane_signer: data_plane_signer,

@@ -228,6 +247,8 @@ impl TaskManager {
         let mut cached_ops_stats_client: Option<Result<(journal::Client, proto_gazette::Claims)>> =
             None;

+        let mut cached_dekaf_auth: Option<DekafTaskAuth> = None;

+        let mut timeout_start = None;

         loop {

@@ -258,21 +279,23 @@ impl TaskManager {
         let mut has_been_migrated = false;

         let loop_result: Result<()> = async {
             // For the moment, let's just refresh this every tick in order to have relatively
             // fresh MaterializationSpecs, even if the access token may live for a while.
-            let dekaf_auth = fetch_dekaf_task_auth(
+            let dekaf_auth = get_or_refresh_dekaf_auth(
+                cached_dekaf_auth.take(),
                 &self.client,
                 &task_name,
                 &self.data_plane_fqdn,
                 &self.data_plane_signer,
+                self.timeout,
             )
             .await
-            .context("error fetching dekaf task auth")?;
+            .context("error fetching or refreshing dekaf task auth")?;
+            cached_dekaf_auth = Some(dekaf_auth.clone());

             match dekaf_auth {
                 DekafTaskAuth::Redirect {
                     target_dataplane_fqdn,
                     spec,
                     ..
                 } => {
                     if !has_been_migrated {
                         has_been_migrated = true;

@@ -306,6 +329,7 @@ impl TaskManager {
                         &task_name,
                         &spec,
                         std::mem::take(&mut partitions_and_clients),
+                        self.timeout,
                     )
                     .await?;

@@ -325,6 +349,7 @@ impl TaskManager {
                         cached_ops_logs_client
                             .as_ref()
                             .and_then(|r| r.as_ref().ok()),
+                        self.timeout,
                     )
                     .await
                     .map_err(SharedError::from);

@@ -346,6 +371,7 @@ impl TaskManager {
                         cached_ops_stats_client
                             .as_ref()
                             .and_then(|r| r.as_ref().ok()),
+                        self.timeout,
                     )
                     .await
                     .map_err(SharedError::from);

@@ -372,11 +398,11 @@ impl TaskManager {
                             .collect_vec(),
                         ops_logs_client: cached_ops_logs_client
                             .as_ref()
-                            .expect("this is guarinteed to be present")
+                            .expect("this is guaranteed to be present")
                             .clone(),
                         ops_stats_client: cached_ops_stats_client
                             .as_ref()
-                            .expect("this is guarinteed to be present")
+                            .expect("this is guaranteed to be present")
                             .clone(),
                     }))));

@@ -413,6 +439,7 @@ async fn update_partition_info(
     task_name: &str,
     spec: &MaterializationSpec,
     mut info: HashMap<String, Result<(journal::Client, proto_gazette::Claims, Vec<Partition>)>>,
+    timeout: Duration,
 ) -> anyhow::Result<HashMap<String, Result<(journal::Client, proto_gazette::Claims, Vec<Partition>)>>>
 {
     let mut tasks = Vec::with_capacity(spec.bindings.len());

@@ -452,6 +479,7 @@ async fn update_partition_info(
                 exclude: None,
             },
             existing_client.as_ref(),
+            timeout,
         )
         .await;

@@ -496,29 +524,55 @@ async fn get_or_refresh_journal_client(
     capability: u32,
     selector: broker::LabelSelector,
     cached_client_and_claims: Option<&(journal::Client, proto_gazette::Claims)>,
+    timeout: Duration,
 ) -> anyhow::Result<(journal::Client, proto_gazette::Claims)> {
     if let Some((cached_client, claims)) = cached_client_and_claims {
         let now_unix = time::OffsetDateTime::now_utc().unix_timestamp();
-        // Add a buffer to token expiry check
-        if claims.exp > now_unix as u64 + 60 {
+        // Refresh the client if its token is closer than REFRESH_START_AT to its expiration.
+        let refresh_from = (claims.exp - REFRESH_START_AT.as_secs()) as i64;
+        if now_unix < refresh_from {
             tracing::debug!(task=%task_name, "Re-using existing journal client.");
             return Ok((cached_client.clone(), claims.clone()));
+        } else {
+            tracing::debug!(task=%task_name, "Journal client token expired or nearing expiry.");
         }
     }

+    let timeouts_allowed_until = if let Some((client, claims)) = cached_client_and_claims {
+        // If we have a cached client, its expiration bounds how long we can keep
+        // serving it while attempts to fetch a new client time out.
+        Some((claims.exp, client, claims))
+    } else {
+        None
+    };

     tracing::debug!(task=%task_name, capability, "Fetching new task authorization for journal client.");
     metrics::counter!("dekaf_fetch_auth", "endpoint" => "/authorize/task", "task_name" => task_name.to_owned()).increment(1);
-    flow_client::fetch_task_authorization(
-        flow_client,
-        &crate::dekaf_shard_template_id(task_name),
-        data_plane_fqdn,
-        data_plane_signer,
-        capability,
-        selector,
+    match tokio::time::timeout(
+        timeout,
+        flow_client::fetch_task_authorization(
+            flow_client,
+            &crate::dekaf_shard_template_id(task_name),
+            data_plane_fqdn,
+            data_plane_signer,
+            capability,
+            selector,
+        ),
     )
     .await
+    {
+        Ok(resp) => resp,
+        Err(_) => {
+            if let Some((allowed_until, cached_client, cached_claims)) = timeouts_allowed_until {
+                if time::OffsetDateTime::now_utc().unix_timestamp() < allowed_until as i64 {
+                    tracing::warn!(task=%task_name, allowed_until, "Timed out while fetching task authorization for journal client; re-using cached client within the acceptable retry window.");
+                    return Ok((cached_client.clone(), cached_claims.clone()));
+                }
+            }
+            Err(anyhow::anyhow!(
+                "Timed out while fetching task authorization for journal client."
+            ))
+        }
+    }
 }

 /// Fetch the journals of a collection and map into stable-order partitions.

@@ -560,14 +614,17 @@ pub async fn fetch_partitions(
 // Claims returned by `/authorize/dekaf`
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct AccessTokenClaims {
     pub iat: u64,
     pub exp: u64,
 }

 #[derive(Debug, Clone)]
 pub enum DekafTaskAuth {
     /// Task has been migrated to a different dataplane, and the session should redirect to it.
     Redirect {
         target_dataplane_fqdn: String,
         spec: MaterializationSpec,
+        fetched_at: time::OffsetDateTime,
     },
     /// Task authorization data.
     Auth {

@@ -579,6 +636,79 @@ pub enum DekafTaskAuth {
     },
 }

+impl DekafTaskAuth {
+    fn exp(&self) -> u64 {
+        match self {
+            DekafTaskAuth::Redirect { fetched_at, .. } => {
+                // Redirects are valid for 10 minutes
Contributor (Author):
    I just made this up. Since redirects don't get a token, they don't get an `exp`.

Member:
    🤔 Does this mean that we'll now start to expire cached `Redirect` responses?

Contributor (Author):
    We used to re-fetch every 30s no matter what. Now we only fetch when the (made up)
    expiration is coming up. I don't think this substantively changes the behavior,
    just the amount of time a redirect response can be cached for.
+                fetched_at.unix_timestamp() as u64 + 60 * 10
+            }
+            DekafTaskAuth::Auth { claims, .. } => claims.exp,
+        }
+    }
+    fn refresh_at(&self) -> u64 {
+        // Refresh the client if its token is closer than REFRESH_START_AT to its expiration.
+        self.exp() - REFRESH_START_AT.as_secs()
+    }
+}

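Editor's note: a small sketch of the cache-window arithmetic discussed in the thread above, using the constants introduced by this PR (the ~30s tick is the cadence mentioned by the author, not a value shown in this diff):

```rust
use std::time::Duration;

// Constants from this PR.
const REFRESH_START_AT: Duration = Duration::from_secs(60 * 5);
// The made-up expiry applied to Redirect responses.
const REDIRECT_TTL: Duration = Duration::from_secs(60 * 10);

fn main() {
    // A cached Redirect is re-used until `refresh_at`, i.e. for
    // REDIRECT_TTL - REFRESH_START_AT after it was fetched...
    let cache_window = REDIRECT_TTL - REFRESH_START_AT;
    assert_eq!(cache_window, Duration::from_secs(300));
    // ...versus the previous behavior of unconditionally re-fetching
    // on every ~30s tick.
    println!("redirect cache window: {cache_window:?}");
}
```
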
+async fn get_or_refresh_dekaf_auth(
+    cached: Option<DekafTaskAuth>,
+    client: &flow_client::Client,
+    shard_template_id: &str,
+    data_plane_fqdn: &str,
+    data_plane_signer: &jsonwebtoken::EncodingKey,
+    timeout: Duration,
+) -> anyhow::Result<DekafTaskAuth> {
+    let now = time::OffsetDateTime::now_utc().unix_timestamp() as u64;
+
+    if let Some(cached_auth) = cached {
+        if now < cached_auth.refresh_at() {
+            tracing::debug!("DekafTaskAuth is still valid, no need to refresh.");
+            return Ok(cached_auth);
+        }
+
+        // Try to refresh, but on timeout fall back to the cached value if it's still valid.
+        match tokio::time::timeout(
+            timeout,
+            fetch_dekaf_task_auth(
+                client,
+                shard_template_id,
+                data_plane_fqdn,
+                data_plane_signer,
+            ),
+        )
+        .await
+        {
+            Ok(resp) => resp,
+            Err(_) => {
+                if time::OffsetDateTime::now_utc().unix_timestamp() < cached_auth.exp() as i64 {
+                    tracing::warn!(
+                        "Timed out while refreshing DekafTaskAuth, but the token is still valid."
+                    );
+                    return Ok(cached_auth);
+                }
+                anyhow::bail!(
+                    "Timed out while refreshing DekafTaskAuth, and the token is expired."
+                );
+            }
+        }
+    } else {
+        // No cached value; fetch a new one.
+        tokio::time::timeout(
+            timeout,
+            fetch_dekaf_task_auth(
+                client,
+                shard_template_id,
+                data_plane_fqdn,
+                data_plane_signer,
+            ),
+        )
+        .await
+        .map_err(|_| anyhow::anyhow!("Timed out while fetching dekaf task auth"))?
+    }
+}

 #[tracing::instrument(skip(client, data_plane_signer), err)]
 async fn fetch_dekaf_task_auth(
     client: &flow_client::Client,

@@ -647,6 +777,7 @@ async fn fetch_dekaf_task_auth(
             return Ok(DekafTaskAuth::Redirect {
                 target_dataplane_fqdn: redirect_fqdn,
                 spec: parsed_spec,
+                fetched_at: time::OffsetDateTime::now_utc(),
             });
         }

Review comment:
    We used to happily hand out expired tokens here, just assuming that the
    TaskManager loop could keep up. That's the real cause of the problem.
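Editor's note: a minimal, self-contained sketch of the overall pattern this PR adopts: serve a cached token until REFRESH_START_AT before expiry; refresh with a timeout; on timeout, fall back to the cached token only while it is still valid, never handing out an expired one. `CachedToken` and `fetch_token` are hypothetical stand-ins, not this crate's API:

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// How long before expiry we start trying to refresh (mirrors REFRESH_START_AT).
const REFRESH_START_AT: Duration = Duration::from_secs(60 * 5);

/// Hypothetical cached token: just an expiry and an opaque value.
#[derive(Clone)]
struct CachedToken {
    exp: u64, // unix seconds
    value: String,
}

fn now_unix() -> u64 {
    SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()
}

/// Hypothetical stand-in for the real authorization fetch.
async fn fetch_token() -> anyhow::Result<CachedToken> {
    Ok(CachedToken { exp: now_unix() + 3600, value: "new-token".into() })
}

async fn get_or_refresh(
    cached: Option<CachedToken>,
    timeout: Duration,
) -> anyhow::Result<CachedToken> {
    if let Some(token) = &cached {
        // Serve the cached token until we're within REFRESH_START_AT of expiry.
        if now_unix() < token.exp.saturating_sub(REFRESH_START_AT.as_secs()) {
            return Ok(token.clone());
        }
    }
    match tokio::time::timeout(timeout, fetch_token()).await {
        Ok(fresh) => fresh,
        // On timeout, fall back to the cached token only while it's still valid;
        // an expired token is never returned.
        Err(_) => match cached {
            Some(token) if now_unix() < token.exp => Ok(token),
            _ => anyhow::bail!("timed out refreshing token, and no valid cached token"),
        },
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let token = get_or_refresh(None, Duration::from_secs(5)).await?;
    println!("token expires at {}", token.exp);
    Ok(())
}
```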