diff --git a/pgdog/src/frontend/logical_session.rs b/pgdog/src/frontend/logical_session.rs new file mode 100644 index 00000000..d5d14435 --- /dev/null +++ b/pgdog/src/frontend/logical_session.rs @@ -0,0 +1,393 @@ +//! # Logical Session Management in PgDog +//! +//! This module provides a unified logical session interface to coordinate and +//! validate session variables across shards in PgDog. +//! +//! PgDog emulates a single-node PostgreSQL interface for clients, hiding the +//! underlying sharded topology. The `LogicalSession` struct maintains session +//! state to guarantee consistent behavior across horizontally sharded backend +//! PostgreSQL servers. +//! +//! Session variables configured via `SET` commands are logically tracked and +//! propagated (fanned out) to relevant shards during multi-shard query execution. +//! This avoids inconsistencies in query behavior caused by differing variable +//! settings across shards. +//! +//! Example (valid on single-node Postgres, fanned out by PgDog): +//! -- SET search_path TO public; +//! -- SELECT * FROM users; -- PgDog fans out the SET to all relevant shards before querying. +//! +//! Counterexample (invalid if not fanned out): +//! -- SET TimeZone = 'UTC'; +//! -- SELECT NOW(); -- Without fanout, shards might use different timezones. +//! +//! ## Future Improvements +//! - Optimize synchronization by tracking "synced" shards on a per-variable +//! basis to minimize redundant `SET` commands. + +use std::collections::{HashMap, HashSet}; +use std::error::Error; +use std::fmt; + +use super::router::parser::Shard; + +// ----------------------------------------------------------------------------- +// ----- LogicalSession -------------------------------------------------------- + +#[derive(Debug)] +pub struct LogicalSession<'a> { + configuration_parameters: HashMap, ConfigValue>, + synced_shards: HashSet<&'a Shard>, + swapped_values: HashMap<(&'a Shard, ConfigParameter<'a>), ConfigValue>, +} + +impl<'a> LogicalSession<'a> { + pub fn new() -> Self { + Self { + configuration_parameters: HashMap::new(), + synced_shards: HashSet::new(), + swapped_values: HashMap::new(), + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Named Struct: ConfigParameter(&str) ----------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ConfigParameter<'a>(&'a str); + +impl<'a> From<&'a str> for ConfigParameter<'a> { + fn from(s: &'a str) -> Self { + ConfigParameter(s) + } +} + +impl<'a> fmt::Display for ConfigParameter<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } +} + +// ----------------------------------------------------------------------------- +// ----- Named Struct: ConfigValue(Strig) -------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConfigValue(String); + +impl From for ConfigValue { + fn from(s: String) -> Self { + ConfigValue(s) + } +} + +impl From<&str> for ConfigValue { + fn from(s: &str) -> Self { + ConfigValue(s.to_owned()) + } +} + +impl fmt::Display for ConfigValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +// ----------------------------------------------------------------------------- +// ----- LogicalSession: Public methods ---------------------------------------- + +impl<'a> LogicalSession<'a> { + /// Set a logical configuration parameter to a new value, clearing shard sync state. + /// + /// # Arguments + /// + /// * `name` - The name of the configuration parameter (e.g., "TimeZone"). + /// * `value` - The desired value (e.g., "UTC"). + pub fn set_variable(&mut self, name: P, value: V) -> Result<(), SessionError> + where + P: Into>, + V: Into, + { + let key = name.into(); + let val = value.into(); + Self::verify_can_set(key, &val)?; + self.configuration_parameters.insert(key, val); + self.synced_shards.clear(); + Ok(()) + } + + /// Retrieve the current value of a configuration parameter, if set. + /// + /// # Arguments + /// + /// * `name` - The name of the configuration parameter to lookup. + pub fn get_variable

(&self, name: P) -> Option + where + P: Into>, + { + let key = name.into(); + self.configuration_parameters.get(&key).cloned() + } + + /// Mark a shard as having been synced with the latest parameter state. + /// + /// # Arguments + /// + /// * `shard` - Reference to the shard that has been updated. + pub fn sync_shard(&mut self, shard: &'a Shard) { + self.synced_shards.insert(shard); + } + + /// Check if a given shard is already synced with current parameters. + /// + /// # Arguments + /// + /// * `shard` - The shard to check sync status for. + pub fn is_shard_synced(&self, shard: &Shard) -> bool { + self.synced_shards.contains(shard) + } + + /// Store the previous value pulled from a shard before overwriting it. + /// + /// # Arguments + /// + /// * `shard` - The shard where the swap occurred. + /// * `name` - The configuration parameter name. + /// * `value` - The old value returned by the shard. + /// + /// # Example + /// ```sql + /// -- 1. Retrieve current TimeZone from shard 0: + /// SHOW TimeZone; -- returns 'UTC' + /// + /// -- 2. Internally store the old value before changing: + /// -- LogicalSession::new().store_swapped_value(Shard::Direct(0), 'TimeZone', 'UTC'); + /// + /// -- 3. Apply new setting: + /// SET TimeZone = 'America/New_York'; + pub fn store_swapped_value(&mut self, shard: &'a Shard, name: P, value: V) + where + P: Into>, + V: Into, + { + let key = (shard, name.into()); + self.swapped_values.insert(key, value.into()); + } + + /// Remove and return the stored swapped value for a shard+parameter, if any. + /// + /// # Arguments + /// + /// * `shard` - The shard to retrieve the swapped value from. + /// * `name` - The configuration parameter name. + pub fn take_swapped_value

(&mut self, shard: &'a Shard, name: P) -> Option + where + P: Into>, + { + let key = (shard, name.into()); + self.swapped_values.remove(&key) + } + + /// Reset the session state (e.g., on connection close or explicit RESET). + pub fn reset(&mut self) { + self.configuration_parameters.clear(); + self.synced_shards.clear(); + self.swapped_values.clear(); + } + + /// Reset the session state, returning all stored swapped values before clearing. + /// + /// # Returns + /// A map of `(Shard, ConfigParameter) -> ConfigValue` containing all swapped values. + pub fn reset_after_take(&mut self) -> HashMap<(&'a Shard, ConfigParameter<'a>), ConfigValue> { + // take swapped_values out and leave an empty map + let prev = std::mem::take(&mut self.swapped_values); + + // clear other session state + self.configuration_parameters.clear(); + self.synced_shards.clear(); + + prev + } +} + +// ----------------------------------------------------------------------------- +// ----- LogicalSession: Private methods --------------------------------------- + +impl<'a> LogicalSession<'a> { + /// Ensures the configuration parameters key and values are allowed. + /// Currently whitelists everything. + fn verify_can_set( + _name: ConfigParameter<'a>, + _value: &ConfigValue, + ) -> Result<(), SessionError> { + Ok(()) + } +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SessionError { + InvalidVariableName(String), +} + +impl fmt::Display for SessionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SessionError::InvalidVariableName(name) => { + write!( + f, + "invalid or disallowed session configuration parameter: {}", + name + ) + } + } + } +} + +impl Error for SessionError {} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + let session = LogicalSession::new(); + assert!(session.configuration_parameters.is_empty()); + assert!(session.synced_shards.is_empty()); + assert!(session.swapped_values.is_empty()); + } + + #[test] + fn test_set_and_get_variable() { + let mut session = LogicalSession::new(); + session.set_variable("TimeZone", "UTC").unwrap(); + let gotten = session.get_variable("TimeZone"); + + assert_eq!(gotten, Some(ConfigValue("UTC".to_owned()))); + assert_eq!(session.get_variable("NonExistent"), None); + } + + #[test] + fn test_set_clears_synced_shards() { + let mut session = LogicalSession::new(); + let shard1 = Shard::Direct(1); + session.sync_shard(&shard1); + assert!(session.is_shard_synced(&shard1)); + + session.set_variable("search_path", "public").unwrap(); + assert!(!session.is_shard_synced(&shard1)); + } + + #[test] + fn test_sync_and_check_shard() { + let mut session = LogicalSession::new(); + let shard1 = Shard::Direct(1); + let shard2 = Shard::Direct(2); + + session.sync_shard(&shard1); + assert!(session.is_shard_synced(&shard1)); + assert!(!session.is_shard_synced(&shard2)); + + session.sync_shard(&shard2); + assert!(session.is_shard_synced(&shard2)); + } + + #[test] + fn test_store_and_take_swapped_value() { + let shard1 = Shard::Direct(1); + let shard2 = Shard::Direct(2); + + let mut session = LogicalSession::new(); + session.store_swapped_value(&shard1, "TimeZone", "UTC"); + + // Value can be taken once + let taken = session.take_swapped_value(&shard1, "TimeZone"); + assert_eq!(taken, Some(ConfigValue("UTC".to_owned()))); + + // Value that has been taken is not there + let taken = session.take_swapped_value(&shard1, "TimeZone"); + assert_eq!(taken, None); + + // Value that has never been set is not there + let taken = session.take_swapped_value(&shard2, "TimeZone"); + assert_eq!(taken, None); + } + + #[test] + fn test_reset() { + let mut session = LogicalSession::new(); + let shard1 = Shard::Direct(1); + + session.set_variable("TimeZone", "UTC").unwrap(); + session.sync_shard(&shard1); + session.store_swapped_value(&shard1, "TimeZone", "OldUTC"); + + session.reset(); + assert!(session.configuration_parameters.is_empty()); + assert!(session.synced_shards.is_empty()); + assert!(session.swapped_values.is_empty()); + } + + #[test] + fn test_multiple_operations() { + let mut session = LogicalSession::new(); + let shard1 = Shard::Direct(1); + let shard2 = Shard::Direct(2); + + session.set_variable("TimeZone", "UTC").unwrap(); + session.set_variable("search_path", "public").unwrap(); + assert_eq!( + session.get_variable("TimeZone"), + Some(ConfigValue("UTC".to_owned())) + ); + + session.sync_shard(&shard1); + session.sync_shard(&shard2); + assert!(session.is_shard_synced(&shard1)); + + session.store_swapped_value(&shard1, "TimeZone", "America/New_York"); + assert_eq!( + session.take_swapped_value(&shard1, "TimeZone"), + Some(ConfigValue("America/New_York".to_owned())) + ); + + session.set_variable("TimeZone", "PST").unwrap(); // Should clear synced_shards + assert!(!session.is_shard_synced(&shard1)); + assert!(!session.is_shard_synced(&shard2)); + } + + #[test] + fn reset_after_take_returns_swapped_and_clears() { + let mut sess = LogicalSession::new(); + let shard1 = &Shard::Direct(1); + let shard2 = &Shard::Direct(2); + sess.set_variable("a", "1").unwrap(); + + // Caller has mapped over every shard, pulled their existing value and set "a" to "1". + sess.store_swapped_value(shard1, "a", "2"); + sess.store_swapped_value(shard2, "a", "2"); + sess.sync_shard(shard1); + + let swapped = sess.reset_after_take(); + assert_eq!(swapped.len(), 2); + assert_eq!(swapped.get(&(shard1, ConfigParameter("a"))).unwrap().0, "2"); + assert_eq!(swapped.get(&(shard2, ConfigParameter("a"))).unwrap().0, "2"); + + assert!(sess.configuration_parameters.is_empty()); + assert!(sess.synced_shards.is_empty()); + assert!(sess.swapped_values.is_empty()); + assert!(sess.get_variable("a").is_none()); + assert!(!sess.is_shard_synced(shard1)); + assert!(sess.take_swapped_value(shard1, "b").is_none()); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/frontend/logical_transaction.rs b/pgdog/src/frontend/logical_transaction.rs new file mode 100644 index 00000000..232b8034 --- /dev/null +++ b/pgdog/src/frontend/logical_transaction.rs @@ -0,0 +1,663 @@ +//! # Logical Transaction Management in PgDog +//! +//! Exposes a unified logical transaction interface while coordinating and validating transactions +//! across shards, preventing illegal behavior. +//! +//! PgDog presents a single-node PostgreSQL interface to clients, concealing the underlying shard topology. +//! `LogicalTransaction` tracks transaction state to ensure consistent transactional behavior across horizontally +//! sharded backend Postgres servers. +//! +//! Sharding hints and "dirty" shard tracking enforce single-shard constraints within transactions. +//! +//! Example (valid on single-node Postgres): +//! -- BEGIN; +//! -- INSERT INTO users (id) VALUES (123); +//! -- INSERT INTO users (id) VALUES (345); +//! -- COMMIT; +//! +//! Counterexample (invalid cross-shard sequence without 2PC): +//! -- BEGIN; +//! -- SET pgdog_shard = 0; +//! -- INSERT INTO users (id) VALUES (123); +//! -- SET pgdog_shard = 8; +//! -- INSERT INTO users (id) VALUES (345); +//! -- COMMIT; +//! +//! Future: `allow_cross_shard_transaction = true` may enable PgDog to manage 2PCs automatically with a performance hit. +//! + +use std::error::Error; +use std::fmt; + +use super::router::parser::Shard; + +// ----------------------------------------------------------------------------- +// ----- LogicalTransaction ---------------------------------------------------- + +#[derive(Debug)] +pub struct LogicalTransaction { + pub status: TransactionStatus, + manual_shard: Option, + dirty_shard: Option, +} + +impl LogicalTransaction { + pub fn new() -> Self { + Self { + status: TransactionStatus::Idle, + manual_shard: None, + dirty_shard: None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- LogicalTransaction: Public methods ------------------------------------ + +impl LogicalTransaction { + /// Return the shard to apply statements to. + /// If a manual shard is set, returns it. Otherwise returns the touched shard. + /// In practice, either only one value is set, or both values are the same. + pub fn active_shard(&self) -> Option { + self.dirty_shard + .clone() + .or_else(|| self.manual_shard.clone()) + } + + /// Mark that a `BEGIN` is pending. + /// + /// Transitions `Idle -> BeginPending`. + /// + /// # Errors + /// - `AlreadyInTransaction` if already `BeginPending` or `InProgress`. + /// - `AlreadyFinalized` if `Committed` or `RolledBack`.tx or finalized. + pub fn soft_begin(&mut self) -> Result<(), TransactionError> { + match self.status { + TransactionStatus::Idle => { + self.status = TransactionStatus::BeginPending; + Ok(()) + } + TransactionStatus::BeginPending | TransactionStatus::InProgress => { + Err(TransactionError::AlreadyInTransaction) + } + TransactionStatus::Committed | TransactionStatus::RolledBack => { + Err(TransactionError::AlreadyFinalized) + } + } + } + + /// Execute a query against `shard`, updating transactional state. + /// + /// - Touches the shard (enforcing the shard conflict rules). + /// - Transitions `BeginPending -> InProgress` on first statement. + /// - No-op state change when already `InProgress`. + /// + /// # Errors + /// - `NoPendingBegins` if status is `Idle`. + /// - `AlreadyFinalized` if `Committed` or `RolledBack`. + /// - `InvalidManualShardType` if `shard` is not `Shard::Direct(_)`. + /// - `ShardConflict` if `active_shard` is set to a different shard. + pub fn execute_query(&mut self, shard: Shard) -> Result<(), TransactionError> { + self.touch_shard(shard)?; + + match self.status { + TransactionStatus::BeginPending => { + self.status = TransactionStatus::InProgress; + Ok(()) + } + + TransactionStatus::Idle => Err(TransactionError::NoPendingBegins), + TransactionStatus::InProgress => Ok(()), + TransactionStatus::Committed => Err(TransactionError::AlreadyFinalized), + TransactionStatus::RolledBack => Err(TransactionError::AlreadyFinalized), + } + } + + /// Commit the transaction. + /// + /// Transitions `InProgress -> Committed`. + /// + /// # Errors + /// - `NoPendingBegins` if `Idle`. + /// - `NoActiveTransaction` if `BeginPending` (nothing ran). + /// - `AlreadyFinalized` if already `Committed` or `RolledBack`. + pub fn commit(&mut self) -> Result<(), TransactionError> { + match self.status { + TransactionStatus::InProgress => { + self.status = TransactionStatus::Committed; + Ok(()) + } + + TransactionStatus::Idle => Err(TransactionError::NoPendingBegins), + TransactionStatus::BeginPending => Err(TransactionError::NoActiveTransaction), + TransactionStatus::Committed => Err(TransactionError::AlreadyFinalized), + TransactionStatus::RolledBack => Err(TransactionError::AlreadyFinalized), + } + } + + /// Roll back the transaction. + /// + /// Transitions `InProgress -> RolledBack`. + /// + /// # Errors + /// - `NoPendingBegins` if `Idle`. + /// - `NoActiveTransaction` if `BeginPending` (nothing ran). + /// - `AlreadyFinalized` if already `Committed` or `RolledBack`. + pub fn rollback(&mut self) -> Result<(), TransactionError> { + match self.status { + TransactionStatus::InProgress => { + self.status = TransactionStatus::RolledBack; + Ok(()) + } + + TransactionStatus::Idle => Err(TransactionError::NoPendingBegins), + TransactionStatus::BeginPending => Err(TransactionError::NoActiveTransaction), + TransactionStatus::Committed => Err(TransactionError::AlreadyFinalized), + TransactionStatus::RolledBack => Err(TransactionError::AlreadyFinalized), + } + } + + /// Reset all transactional/session state. + /// + /// Sets status to `Idle`, clears manual and dirty shard + /// Safe to call in any state. + pub fn reset(&mut self) { + self.status = TransactionStatus::Idle; + self.manual_shard = None; + self.dirty_shard = None; + } + + /// Pin the transaction to a specific shard. + /// + /// Accepts only `Shard::Direct(_)`. + /// No-op if setting the same shard again. + /// If a different shard was already touched, fails. + /// + /// # Errors + /// - `InvalidManualShardType` unless `Shard::Direct(_)`. + /// - `ShardConflict` if `dirty_shard` is set to a different shard. + pub fn set_manual_shard(&mut self, shard: Shard) -> Result<(), TransactionError> { + // only Shard::Direct(n) is valid in a transaction + if !matches!(shard, Shard::Direct(_)) { + return Err(TransactionError::InvalidShardType); + } + + // no-op if unchanged + if self.manual_shard.as_ref().map_or(false, |h| h == &shard) { + return Ok(()); + } + + // if we already touched a different shard, error + if let Some(d) = &self.dirty_shard { + if *d != shard { + return Err(TransactionError::ShardConflict); + } + } + + self.manual_shard = Some(shard); + Ok(()) + } +} + +// ----------------------------------------------------------------------------- +// ----- LogicalTransaction: Private methods ----------------------------------- + +impl LogicalTransaction { + /// Record that this transaction touched `shard`. + /// Enforces single-shard discipline. + fn touch_shard(&mut self, shard: Shard) -> Result<(), TransactionError> { + // Shard must be of type Shard::Direct(n). + if !matches!(shard, Shard::Direct(_)) { + return Err(TransactionError::InvalidShardType); + } + + // Already pinned to a manual shard → forbid drift. + if let Some(hint) = &self.manual_shard { + if *hint != shard { + return Err(TransactionError::ShardConflict); + } + } + + // Already dirtied another shard → forbid drift. + if let Some(dirty) = &self.dirty_shard { + if *dirty != shard { + return Err(TransactionError::ShardConflict); + } + } + + // Nothing in conflict; mark the shard. + self.dirty_shard = Some(shard); + Ok(()) + } +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum TransactionError { + // Transaction lifecycle + AlreadyInTransaction, + NoActiveTransaction, + AlreadyFinalized, + NoPendingBegins, + + // Sharding policy + InvalidShardType, + ShardConflict, +} + +impl fmt::Display for TransactionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use TransactionError::*; + match self { + AlreadyInTransaction => write!(f, "transaction already started"), + NoActiveTransaction => write!(f, "no active transaction"), + AlreadyFinalized => write!(f, "transaction already finalized"), + NoPendingBegins => write!(f, "transaction not pending"), + InvalidShardType => write!(f, "sharding hints must be ::Direct(n)"), + ShardConflict => { + write!(f, "can't run a transaction on multiple shards") + } + } + } +} + +impl Error for TransactionError {} + +// ----------------------------------------------------------------------------- +// ----- SubStruct: TransactionStatus ------------------------------------------ + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransactionStatus { + /// No transaction started. + Idle, + /// BEGIN issued by client; waiting to relay it until first in-transaction query. + BeginPending, + /// Transaction active. + InProgress, + /// ROLLBACK issued. + RolledBack, + /// COMMIT issued. + Committed, +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_transaction_is_idle() { + let tx = LogicalTransaction::new(); + assert_eq!(tx.status, TransactionStatus::Idle); + assert_eq!(tx.manual_shard, None); + assert_eq!(tx.dirty_shard, None); + } + + #[test] + fn test_soft_begin_from_idle() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + assert_eq!(tx.status, TransactionStatus::BeginPending); + } + + #[test] + fn test_soft_begin_already_pending_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + let err = tx.soft_begin().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyInTransaction)); + } + + #[test] + fn test_soft_begin_in_progress_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + let err = tx.soft_begin().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyInTransaction)); + } + + #[test] + fn test_soft_begin_after_commit_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + let err = tx.soft_begin().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_soft_begin_after_rollback_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + let err = tx.soft_begin().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_execute_query_from_begin_pending() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + assert_eq!(tx.status, TransactionStatus::InProgress); + assert_eq!(tx.dirty_shard, Some(Shard::Direct(0))); + } + + #[test] + fn test_execute_query_from_idle_errors() { + let mut tx = LogicalTransaction::new(); + let err = tx.execute_query(Shard::Direct(0)).unwrap_err(); + assert!(matches!(err, TransactionError::NoPendingBegins)); + } + + #[test] + fn test_execute_query_after_commit_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + let err = tx.execute_query(Shard::Direct(0)).unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_execute_query_multiple_on_same_shard() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + assert_eq!(tx.dirty_shard, Some(Shard::Direct(0))); + assert_eq!(tx.status, TransactionStatus::InProgress); + } + + #[test] + fn test_execute_query_cross_shard_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + let err = tx.execute_query(Shard::Direct(1)).unwrap_err(); + assert!(matches!(err, TransactionError::ShardConflict)); + } + + #[test] + fn test_execute_query_invalid_shard_type_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + let err = tx.execute_query(Shard::All).unwrap_err(); + assert!(matches!(err, TransactionError::InvalidShardType)); + } + + #[test] + fn test_commit_from_in_progress() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + assert_eq!(tx.status, TransactionStatus::Committed); + } + + #[test] + fn test_commit_from_idle_errors() { + let mut tx = LogicalTransaction::new(); + let err = tx.commit().unwrap_err(); + assert!(matches!(err, TransactionError::NoPendingBegins)); + } + + #[test] + fn test_commit_from_begin_pending_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + let err = tx.commit().unwrap_err(); + assert!(matches!(err, TransactionError::NoActiveTransaction)); + } + + #[test] + fn test_commit_already_committed_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + let err = tx.commit().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_rollback_from_in_progress() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + assert_eq!(tx.status, TransactionStatus::RolledBack); + } + + #[test] + fn test_rollback_from_begin_pending_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + let err = tx.rollback().unwrap_err(); + assert!(matches!(err, TransactionError::NoActiveTransaction)); + } + + #[test] + fn test_reset_clears_state() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + tx.reset(); + assert_eq!(tx.status, TransactionStatus::Idle); + assert_eq!(tx.manual_shard, None); + assert_eq!(tx.dirty_shard, None); + } + + #[test] + fn test_set_manual_shard_before_touch() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + tx.execute_query(Shard::Direct(0)).unwrap(); // should succeed + } + + #[test] + fn test_set_manual_shard_after_touch_same_ok() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + } + + #[test] + fn test_set_manual_shard_after_touch_different_errors() { + let mut tx = LogicalTransaction::new(); + // touch shard 0 + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + // manually set shard 1 + let err = tx.set_manual_shard(Shard::Direct(1)).unwrap_err(); + assert!(matches!(err, TransactionError::ShardConflict)); + } + + #[test] + fn test_manual_then_dirty_conflict() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + // pin to shard 0 + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + // touching another shard must fail + let err = tx.execute_query(Shard::Direct(1)).unwrap_err(); + assert!(matches!(err, TransactionError::ShardConflict)); + } + + #[test] + fn test_set_manual_shard_invalid_type_errors() { + let mut tx = LogicalTransaction::new(); + let err = tx.set_manual_shard(Shard::All).unwrap_err(); + assert!(matches!(err, TransactionError::InvalidShardType)); + } + + #[test] + fn test_active_shard_dirty() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(69)).unwrap(); + assert_eq!(tx.active_shard(), Some(Shard::Direct(69))); + } + + #[test] + fn test_active_shard_manual() { + let mut tx = LogicalTransaction::new(); + tx.set_manual_shard(Shard::Direct(1)).unwrap(); + assert_eq!(tx.active_shard(), Some(Shard::Direct(1))); + } + + #[test] + fn test_rollback_from_idle_errors() { + let mut tx = LogicalTransaction::new(); + let err = tx.rollback().unwrap_err(); + assert!(matches!(err, TransactionError::NoPendingBegins)); + } + + #[test] + fn test_commit_after_rollback_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + let err = tx.commit().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_rollback_after_commit_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + let err = tx.rollback().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_rollback_already_rolledback_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + let err = tx.rollback().unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_execute_query_after_rollback_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + let err = tx.execute_query(Shard::Direct(0)).unwrap_err(); + assert!(matches!(err, TransactionError::AlreadyFinalized)); + } + + #[test] + fn test_set_manual_shard_multiple_changes_before_execute() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.set_manual_shard(Shard::Direct(1)).unwrap(); + tx.set_manual_shard(Shard::Direct(2)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(2))); + tx.execute_query(Shard::Direct(2)).unwrap(); + let err = tx.execute_query(Shard::Direct(1)).unwrap_err(); + assert!(matches!(err, TransactionError::ShardConflict)); + } + + #[test] + fn test_set_manual_shard_after_commit_same_ok() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + } + + #[test] + fn test_set_manual_shard_after_commit_different_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + let err = tx.set_manual_shard(Shard::Direct(1)).unwrap_err(); + assert!(matches!(err, TransactionError::ShardConflict)); + } + + #[test] + fn test_set_manual_shard_after_rollback_same_ok() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + } + + #[test] + fn test_set_manual_shard_after_rollback_different_errors() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.rollback().unwrap(); + let err = tx.set_manual_shard(Shard::Direct(1)).unwrap_err(); + assert!(matches!(err, TransactionError::ShardConflict)); + } + + #[test] + fn test_active_shard_none() { + let tx = LogicalTransaction::new(); + assert_eq!(tx.active_shard(), None); + } + + #[test] + fn test_set_manual_shard_in_idle() { + let mut tx = LogicalTransaction::new(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + } + + #[test] + fn test_soft_begin_after_reset_from_finalized() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.commit().unwrap(); + tx.reset(); + tx.soft_begin().unwrap(); + assert_eq!(tx.status, TransactionStatus::BeginPending); + } + + #[test] + fn test_active_shard_both_same() { + let mut tx = LogicalTransaction::new(); + tx.set_manual_shard(Shard::Direct(3)).unwrap(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(3)).unwrap(); + assert_eq!(tx.active_shard(), Some(Shard::Direct(3))); + } + + #[test] + fn test_statements_executed_remains_zero_after_execute() { + let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/frontend/mod.rs b/pgdog/src/frontend/mod.rs index e9228e85..2bb445ee 100644 --- a/pgdog/src/frontend/mod.rs +++ b/pgdog/src/frontend/mod.rs @@ -6,6 +6,8 @@ pub mod comms; pub mod connected_client; pub mod error; pub mod listener; +pub mod logical_session; +pub mod logical_transaction; pub mod prepared_statements; #[cfg(debug_assertions)] pub mod query_logger;