Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'main' into zain/handle-statup-parameters
  • Loading branch information
zainkabani committed Jun 17, 2023
commit 3589859fdd4262f6f0f28ab5315972a4eec6494b
281 changes: 281 additions & 0 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,284 @@ impl BytesMutReader for BytesMut {
}
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Parse {
code: char,
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
}

impl TryFrom<&BytesMut> for Parse {
type Error = Error;

fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let name = cursor.read_string()?;
let query = cursor.read_string()?;
let num_params = cursor.get_i16();
let mut param_types = Vec::new();

for _ in 0..num_params {
param_types.push(cursor.get_i32());
}

Ok(Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
})
}
}

impl TryFrom<Parse> for BytesMut {
type Error = Error;

fn try_from(parse: Parse) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();

let name_binding = CString::new(parse.name)?;
let name = name_binding.as_bytes_with_nul();

let query_binding = CString::new(parse.query)?;
let query = query_binding.as_bytes_with_nul();

// Recompute length of the message.
let len = 4 // self
+ name.len()
+ query.len()
+ 2
+ 4 * parse.num_params as usize;

bytes.put_u8(parse.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(name);
bytes.put_slice(query);
bytes.put_i16(parse.num_params);
for param in parse.param_types {
bytes.put_i32(param);
}

Ok(bytes)
}
}

impl TryFrom<&Parse> for BytesMut {
type Error = Error;

fn try_from(parse: &Parse) -> Result<BytesMut, Error> {
parse.clone().try_into()
}
}

impl Parse {
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
self
}

pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}

/// Bind (B) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Bind {
code: char,
#[allow(dead_code)]
len: i64,
portal: String,
pub prepared_statement: String,
num_param_format_codes: i16,
param_format_codes: Vec<i16>,
num_param_values: i16,
param_values: Vec<(i32, BytesMut)>,
num_result_column_format_codes: i16,
result_columns_format_codes: Vec<i16>,
}

impl TryFrom<&BytesMut> for Bind {
type Error = Error;

fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
let num_param_format_codes = cursor.get_i16();
let mut param_format_codes = Vec::new();

for _ in 0..num_param_format_codes {
param_format_codes.push(cursor.get_i16());
}

let num_param_values = cursor.get_i16();
let mut param_values = Vec::new();

for _ in 0..num_param_values {
let param_len = cursor.get_i32();
let mut param = BytesMut::with_capacity(param_len as usize);
param.resize(param_len as usize, b'0');
cursor.copy_to_slice(&mut param);
param_values.push((param_len, param));
}

let num_result_column_format_codes = cursor.get_i16();
let mut result_columns_format_codes = Vec::new();

for _ in 0..num_result_column_format_codes {
result_columns_format_codes.push(cursor.get_i16());
}

Ok(Bind {
code,
len: len as i64,
portal,
prepared_statement,
num_param_format_codes,
param_format_codes,
num_param_values,
param_values,
num_result_column_format_codes,
result_columns_format_codes,
})
}
}

impl TryFrom<Bind> for BytesMut {
type Error = Error;

fn try_from(bind: Bind) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();

let portal_binding = CString::new(bind.portal)?;
let portal = portal_binding.as_bytes_with_nul();

let prepared_statement_binding = CString::new(bind.prepared_statement)?;
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();

let mut len = 4 // self
+ portal.len()
+ prepared_statement.len()
+ 2 // num_param_format_codes
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
+ 2; // num_param_values

for (param_len, _) in &bind.param_values {
len += 4 + *param_len as usize;
}
len += 2; // num_result_column_format_codes
len += 2 * bind.num_result_column_format_codes as usize;

bytes.put_u8(bind.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(portal);
bytes.put_slice(prepared_statement);
bytes.put_i16(bind.num_param_format_codes);
for param_format_code in bind.param_format_codes {
bytes.put_i16(param_format_code);
}
bytes.put_i16(bind.num_param_values);
for (param_len, param) in bind.param_values {
bytes.put_i32(param_len);
bytes.put_slice(&param);
}
bytes.put_i16(bind.num_result_column_format_codes);
for result_column_format_code in bind.result_columns_format_codes {
bytes.put_i16(result_column_format_code);
}

Ok(bytes)
}
}

impl Bind {
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
}

pub fn anonymous(&self) -> bool {
self.prepared_statement.is_empty()
}
}

#[derive(Debug, Clone)]
pub struct Describe {
code: char,

#[allow(dead_code)]
len: i32,
target: char,
pub statement_name: String,
}

impl TryFrom<&BytesMut> for Describe {
type Error = Error;

fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let target = cursor.get_u8() as char;
let statement_name = cursor.read_string()?;

Ok(Describe {
code,
len,
target,
statement_name,
})
}
}

impl TryFrom<Describe> for BytesMut {
type Error = Error;

fn try_from(describe: Describe) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let statement_name_binding = CString::new(describe.statement_name)?;
let statement_name = statement_name_binding.as_bytes_with_nul();
let len = 4 + 1 + statement_name.len();

bytes.put_u8(describe.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(describe.target as u8);
bytes.put_slice(statement_name);

Ok(bytes)
}
}

impl Describe {
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
}

pub fn anonymous(&self) -> bool {
self.statement_name.is_empty()
}
}

pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}
2 changes: 1 addition & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::mem;
use std::net::IpAddr;
use std::sync::{Arc, Once};
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.