diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index 1c0c0104..b44e3d2e 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -17,6 +17,9 @@ connect_timeout = 100 # How much time to give the health check query to return with a result (ms). healthcheck_timeout = 100 +# How much time to give clients during shutdown before forcibly killing client connections (ms). +shutdown_timeout = 5000 + # For how long to ban a server if it fails a health check (seconds). ban_time = 60 # Seconds diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 138dbecb..431f2d60 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -74,12 +74,12 @@ cd ../.. # # Python tests +# These tests will start and stop the pgcat server so it will need to be restarted after the tests # -cd tests/python -pip3 install -r requirements.txt -python3 tests.py -cd ../.. +pip3 install -r tests/python/requirements.txt +python3 tests/python/tests.py +start_pgcat "info" # Admin tests export PGPASSWORD=admin_pass diff --git a/.gitignore b/.gitignore index a4b78411..3c654539 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea /target *.deb +.vscode \ No newline at end of file diff --git a/README.md b/README.md index 5b54a7fc..c03982d2 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ psql -h 127.0.0.1 -p 6432 -c 'SELECT 1' | `pool_mode` | The pool mode to use, i.e. `session` or `transaction`. | `transaction` | | `connect_timeout` | Maximum time to establish a connection to a server (milliseconds). If reached, the server is banned and the next target is attempted. | `5000` | | `healthcheck_timeout` | Maximum time to pass a health check (`SELECT 1`, milliseconds). If reached, the server is banned and the next target is attempted. | `1000` | +| `shutdown_timeout` | Maximum time to give clients during shutdown before forcibly killing client connections (ms). | `60000` | | `ban_time` | Ban time for a server (seconds). 
It won't be allowed to serve transactions until the ban expires; failover targets will be used instead. | `60` | | | | | | **`user`** | | | @@ -250,6 +251,7 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu | `pool_mode` | no | | `connect_timeout` | yes | | `healthcheck_timeout` | no | +| `shutdown_timeout` | no | | `ban_time` | no | | `user` | yes | | `shards` | yes | diff --git a/examples/docker/pgcat.toml b/examples/docker/pgcat.toml index 874f737a..40a54928 100644 --- a/examples/docker/pgcat.toml +++ b/examples/docker/pgcat.toml @@ -17,6 +17,9 @@ connect_timeout = 5000 # How much time to give `SELECT 1` health check query to return with a result (ms). healthcheck_timeout = 1000 +# How much time to give clients during shutdown before forcibly killing client connections (ms). +shutdown_timeout = 60000 + # For how long to ban a server if it fails a health check (seconds). ban_time = 60 # seconds diff --git a/pgcat.toml b/pgcat.toml index a1937e6c..3d9d7df3 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -17,6 +17,9 @@ connect_timeout = 5000 # How much time to give `SELECT 1` health check query to return with a result (ms). healthcheck_timeout = 1000 +# How much time to give clients during shutdown before forcibly killing client connections (ms). +shutdown_timeout = 60000 + # For how long to ban a server if it fails a health check (seconds). 
ban_time = 60 # seconds diff --git a/src/client.rs b/src/client.rs index 1775ad22..cc912191 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,6 +4,7 @@ use log::{debug, error, info, trace}; use std::collections::HashMap; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; +use tokio::sync::broadcast::Receiver; use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::config::get_config; @@ -73,12 +74,15 @@ pub struct Client { last_server_id: Option, target_pool: ConnectionPool, + + shutdown_event_receiver: Receiver<()>, } /// Client entrypoint. pub async fn client_entrypoint( mut stream: TcpStream, client_server_map: ClientServerMap, + shutdown_event_receiver: Receiver<()>, ) -> Result<(), Error> { // Figure out if the client wants TLS or not. let addr = stream.peer_addr().unwrap(); @@ -97,7 +101,7 @@ pub async fn client_entrypoint( write_all(&mut stream, yes).await?; // Negotiate TLS. - match startup_tls(stream, client_server_map).await { + match startup_tls(stream, client_server_map, shutdown_event_receiver).await { Ok(mut client) => { info!("Client {:?} connected (TLS)", addr); @@ -121,7 +125,16 @@ pub async fn client_entrypoint( let (read, write) = split(stream); // Continue with regular startup. - match Client::startup(read, write, addr, bytes, client_server_map).await { + match Client::startup( + read, + write, + addr, + bytes, + client_server_map, + shutdown_event_receiver, + ) + .await + { Ok(mut client) => { info!("Client {:?} connected (plain)", addr); @@ -142,7 +155,16 @@ pub async fn client_entrypoint( let (read, write) = split(stream); // Continue with regular startup. 
- match Client::startup(read, write, addr, bytes, client_server_map).await { + match Client::startup( + read, + write, + addr, + bytes, + client_server_map, + shutdown_event_receiver, + ) + .await + { Ok(mut client) => { info!("Client {:?} connected (plain)", addr); @@ -157,7 +179,16 @@ pub async fn client_entrypoint( let (read, write) = split(stream); // Continue with cancel query request. - match Client::cancel(read, write, addr, bytes, client_server_map).await { + match Client::cancel( + read, + write, + addr, + bytes, + client_server_map, + shutdown_event_receiver, + ) + .await + { Ok(mut client) => { info!("Client {:?} issued a cancel query request", addr); @@ -214,6 +245,7 @@ where pub async fn startup_tls( stream: TcpStream, client_server_map: ClientServerMap, + shutdown_event_receiver: Receiver<()>, ) -> Result>, WriteHalf>>, Error> { // Negotiate TLS. let tls = Tls::new()?; @@ -237,7 +269,15 @@ pub async fn startup_tls( Ok((ClientConnectionType::Startup, bytes)) => { let (read, write) = split(stream); - Client::startup(read, write, addr, bytes, client_server_map).await + Client::startup( + read, + write, + addr, + bytes, + client_server_map, + shutdown_event_receiver, + ) + .await } // Bad Postgres client. @@ -258,6 +298,7 @@ where addr: std::net::SocketAddr, bytes: BytesMut, // The rest of the startup message. client_server_map: ClientServerMap, + shutdown_event_receiver: Receiver<()>, ) -> Result, Error> { let config = get_config(); let stats = get_reporter(); @@ -384,6 +425,7 @@ where last_address_id: None, last_server_id: None, target_pool: target_pool, + shutdown_event_receiver: shutdown_event_receiver, }); } @@ -394,6 +436,7 @@ where addr: std::net::SocketAddr, mut bytes: BytesMut, // The rest of the startup message. 
client_server_map: ClientServerMap, + shutdown_event_receiver: Receiver<()>, ) -> Result, Error> { let process_id = bytes.get_i32(); let secret_key = bytes.get_i32(); @@ -413,6 +456,7 @@ where last_address_id: None, last_server_id: None, target_pool: ConnectionPool::default(), + shutdown_event_receiver: shutdown_event_receiver, }); } @@ -467,7 +511,14 @@ where // We can parse it here before grabbing a server from the pool, // in case the client is sending some custom protocol messages, e.g. // SET SHARDING KEY TO 'bigint'; - let mut message = read_message(&mut self.read).await?; + + let mut message = tokio::select! { + _ = self.shutdown_event_receiver.recv() => { + error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?; + return Ok(()) + }, + message_result = read_message(&mut self.read) => message_result? + }; // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config diff --git a/src/config.rs b/src/config.rs index f1138f98..2a214260 100644 --- a/src/config.rs +++ b/src/config.rs @@ -119,6 +119,7 @@ pub struct General { pub port: i16, pub connect_timeout: u64, pub healthcheck_timeout: u64, + pub shutdown_timeout: u64, pub ban_time: i64, pub autoreload: bool, pub tls_certificate: Option, @@ -134,6 +135,7 @@ impl Default for General { port: 5432, connect_timeout: 5000, healthcheck_timeout: 1000, + shutdown_timeout: 60000, ban_time: 60, autoreload: false, tls_certificate: None, @@ -273,6 +275,10 @@ impl From<&Config> for std::collections::HashMap { "healthcheck_timeout".to_string(), config.general.healthcheck_timeout.to_string(), ), + ( + "shutdown_timeout".to_string(), + config.general.shutdown_timeout.to_string(), + ), ("ban_time".to_string(), config.general.ban_time.to_string()), ]; @@ -290,6 +296,7 @@ impl Config { self.general.healthcheck_timeout ); info!("Connection timeout: {}ms", self.general.connect_timeout); + info!("Shutdown timeout: {}ms", 
self.general.shutdown_timeout); match self.general.tls_certificate.clone() { Some(tls_certificate) => { info!("TLS certificate: {}", tls_certificate); diff --git a/src/main.rs b/src/main.rs index 3622398c..5e5c9248 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,13 +40,13 @@ use log::{debug, error, info}; use parking_lot::Mutex; use tokio::net::TcpListener; use tokio::{ - signal, signal::unix::{signal as unix_signal, SignalKind}, sync::mpsc, }; use std::collections::HashMap; use std::sync::Arc; +use tokio::sync::broadcast; mod admin; mod client; @@ -139,24 +139,52 @@ async fn main() { info!("Waiting for clients"); + let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1); + + let shutdown_event_tx_clone = shutdown_event_tx.clone(); + // Client connection loop. tokio::task::spawn(async move { + // Creates event subscriber for shutdown event, this is dropped when shutdown event is broadcast + let mut listener_shutdown_event_rx = shutdown_event_tx_clone.subscribe(); loop { let client_server_map = client_server_map.clone(); - let (socket, addr) = match listener.accept().await { - Ok((socket, addr)) => (socket, addr), - Err(err) => { - error!("{:?}", err); - continue; + // Listen for shutdown event and client connection at the same time + let (socket, addr) = tokio::select! { + _ = listener_shutdown_event_rx.recv() => { + // Exits client connection loop which drops listener, listener_shutdown_event_rx and shutdown_event_tx_clone + break; + } + + listener_response = listener.accept() => { + match listener_response { + Ok((socket, addr)) => (socket, addr), + Err(err) => { + error!("{:?}", err); + continue; + } + } } }; + // Used to signal shutdown + let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe(); + + // Used to signal that the task has completed + let dummy_tx = shutdown_event_tx_clone.clone(); + // Handle client. 
tokio::task::spawn(async move { let start = chrono::offset::Utc::now().naive_utc(); - match client::client_entrypoint(socket, client_server_map).await { + match client::client_entrypoint( + socket, + client_server_map, + client_shutdown_handler_rx, + ) + .await + { Ok(_) => { let duration = chrono::offset::Utc::now().naive_utc() - start; @@ -171,6 +199,8 @@ async fn main() { debug!("Client disconnected with error {:?}", err); } }; + // Drop this transmitter so receiver knows that the task is completed + drop(dummy_tx); }); } }); @@ -214,13 +244,41 @@ async fn main() { }); } - // Exit on Ctrl-C (SIGINT) and SIGTERM. let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); + let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap(); tokio::select! { - _ = signal::ctrl_c() => (), + // Initiate graceful shutdown sequence on SIGINT + _ = interrupt_signal.recv() => { + info!("Got SIGINT, waiting for client connection drain now"); + + // Broadcast that client tasks need to finish + shutdown_event_tx.send(()).unwrap(); + // Closes transmitter + drop(shutdown_event_tx); + + // This is in a loop because the first value the receiver yields is the shutdown event itself. + // That is not what we are waiting for; we are waiting for recv() to return an error, which happens once all senders have been dropped (each client task drops its sender when it finishes) + loop { + match tokio::time::timeout( + tokio::time::Duration::from_millis(config.general.shutdown_timeout), + shutdown_event_rx.recv(), + ) + .await + { + Ok(res) => match res { + Ok(_) => {} + Err(_) => break, + }, + Err(_) => { + info!("Timed out while waiting for clients to shutdown"); + break; + } + } + } + }, _ = term_signal.recv() => (), - }; + } info!("Shutting down..."); } diff --git a/src/messages.rs b/src/messages.rs index ba22a579..113e1ed5 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -98,7 +98,9 @@ pub async fn ready_for_query(stream: &mut S) -> Result<(), Error> where S:
tokio::io::AsyncWrite + std::marker::Unpin, { - let mut bytes = BytesMut::with_capacity(5); + let mut bytes = BytesMut::with_capacity( + mem::size_of::() + mem::size_of::() + mem::size_of::(), + ); bytes.put_u8(b'Z'); bytes.put_i32(5); @@ -252,18 +254,25 @@ where res.put_i32(len); res.put_slice(&set_complete[..]); - // ReadyForQuery (idle) - res.put_u8(b'Z'); - res.put_i32(5); - res.put_u8(b'I'); - - write_all_half(stream, res).await + write_all_half(stream, res).await?; + ready_for_query(stream).await } /// Send a custom error message to the client. /// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>. pub async fn error_response(stream: &mut S, message: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + error_response_terminal(stream, message).await?; + ready_for_query(stream).await +} + +/// Send a custom error message to the client. +/// Unlike `error_response`, no ReadyForQuery message follows, so the client is not told it may send another query. +/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>. +pub async fn error_response_terminal(stream: &mut S, message: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { @@ -288,21 +297,12 @@ where // No more fields follow. error.put_u8(0); - // Ready for query, no rollback needed (I = idle). - let mut ready_for_query = BytesMut::new(); - - ready_for_query.put_u8(b'Z'); - ready_for_query.put_i32(5); - ready_for_query.put_u8(b'I'); - // Compose the two message reply. - let mut res = BytesMut::with_capacity(error.len() + ready_for_query.len() + 5); + let mut res = BytesMut::with_capacity(error.len() + 5); res.put_u8(b'E'); res.put_i32(error.len() as i32 + 4); - res.put(error); - res.put(ready_for_query); Ok(write_all_half(stream, res).await?)
} @@ -366,12 +366,8 @@ where // CommandComplete res.put(command_complete("SELECT 1")); - // ReadyForQuery - res.put_u8(b'Z'); - res.put_i32(5); - res.put_u8(b'I'); - - write_all_half(stream, res).await + write_all_half(stream, res).await?; + ready_for_query(stream).await } pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index d7661d4d..eebd9c90 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -1 +1,2 @@ psycopg2==2.9.3 +psutil==5.9.1 \ No newline at end of file diff --git a/tests/python/tests.py b/tests/python/tests.py index 15e3822e..3ff99a09 100644 --- a/tests/python/tests.py +++ b/tests/python/tests.py @@ -1,22 +1,158 @@ +from typing import Tuple import psycopg2 +import psutil +import os +import signal +import subprocess +from threading import Thread +import time -def test_normal_db_access(): - conn = psycopg2.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") +SHUTDOWN_TIMEOUT = 5 + +PGCAT_HOST = "127.0.0.1" +PGCAT_PORT = "6432" + + +def pgcat_start(): + pg_cat_send_signal(signal.SIGTERM) + pgcat_start_command = "./target/debug/pgcat .circleci/pgcat.toml" + subprocess.Popen(pgcat_start_command.split()) + + +def pg_cat_send_signal(signal: signal.Signals): + for proc in psutil.process_iter(["pid", "name"]): + if "pgcat" == proc.name(): + os.kill(proc.pid, signal) + + +def connect_normal_db( + autocommit: bool = False, +) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]: + conn = psycopg2.connect( + f"postgres://sharding_user:sharding_user@{PGCAT_HOST}:{PGCAT_PORT}/sharded_db?application_name=testing_pgcat" + ) + conn.autocommit = autocommit cur = conn.cursor() + return (conn, cur) + + +def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor): + cur.close() + conn.close() + + +def test_normal_db_access(): + conn, cur = 
connect_normal_db() cur.execute("SELECT 1") res = cur.fetchall() print(res) + cleanup_conn(conn, cur) def test_admin_db_access(): - conn = psycopg2.connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") - conn.autocommit = True # BEGIN/COMMIT is not supported by admin db + conn = psycopg2.connect( + f"postgres://admin_user:admin_pass@{PGCAT_HOST}:{PGCAT_PORT}/pgcat" + ) + conn.autocommit = True # BEGIN/COMMIT is not supported by admin db cur = conn.cursor() cur.execute("SHOW POOLS") res = cur.fetchall() print(res) + cleanup_conn(conn, cur) + + +def test_shutdown_logic(): + + ##### NO ACTIVE QUERIES SIGINT HANDLING ##### + # Start pgcat + server = Thread(target=pgcat_start) + server.start() + + # Wait for server to fully start up + time.sleep(2) + + # Create client connection and send query (not in transaction) + conn, cur = connect_normal_db(True) + + cur.execute("BEGIN;") + cur.execute("SELECT 1;") + cur.execute("COMMIT;") + + # Send sigint to pgcat + pg_cat_send_signal(signal.SIGINT) + time.sleep(1) + + # Check that any new queries fail after sigint since server should close with no active transactions + try: + cur.execute("SELECT 1;") + except psycopg2.OperationalError as e: + pass + else: + # Fail if query execution succeeded + raise Exception("Server not closed after sigint") + cleanup_conn(conn, cur) + + ##### HANDLE TRANSACTION WITH SIGINT ##### + # Start pgcat + server = Thread(target=pgcat_start) + server.start() + + # Wait for server to fully start up + time.sleep(2) + + # Create client connection and begin transaction + conn, cur = connect_normal_db(True) + + cur.execute("BEGIN;") + cur.execute("SELECT 1;") + + # Send sigint to pgcat while still in transaction + pg_cat_send_signal(signal.SIGINT) + time.sleep(1) + + # Check that any new queries succeed after sigint since server should still allow transaction to complete + try: + cur.execute("SELECT 1;") + except psycopg2.OperationalError as e: + # Fail if query fails since server closed + raise 
Exception("Server closed while in transaction", e.pgerror) + + cleanup_conn(conn, cur) + + ##### HANDLE SHUTDOWN TIMEOUT WITH SIGINT ##### + # Start pgcat + server = Thread(target=pgcat_start) + server.start() + + # Wait for server to fully start up + time.sleep(3) + + # Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached + conn, cur = connect_normal_db(True) + + cur.execute("BEGIN;") + cur.execute("SELECT 1;") + + # Send sigint to pgcat while still in transaction + pg_cat_send_signal(signal.SIGINT) + + # pgcat shutdown timeout is set to SHUTDOWN_TIMEOUT seconds, so we sleep for SHUTDOWN_TIMEOUT + 1 seconds + time.sleep(SHUTDOWN_TIMEOUT + 1) + + # Check that any new queries fail after sigint since the shutdown timeout has elapsed and the server should have closed the connection + try: + cur.execute("SELECT 1;") + except psycopg2.OperationalError as e: + pass + else: + # Fail if query execution succeeded + raise Exception("Server not closed after sigint and expected timeout") + + cleanup_conn(conn, cur) + + test_normal_db_access() test_admin_db_access() +test_shutdown_logic()