From f70fd6250b0533b1f1995c8042932aa4d6ae2959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:28:12 +0100 Subject: [PATCH 01/30] feat(proto): create proto crate --- proto/Cargo.toml | 10 ++++++++++ proto/src/lib.rs | 0 2 files changed, 10 insertions(+) create mode 100644 proto/Cargo.toml create mode 100644 proto/src/lib.rs diff --git a/proto/Cargo.toml b/proto/Cargo.toml new file mode 100644 index 0000000..6e5d2fb --- /dev/null +++ b/proto/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "proto" +version = "0.1.0" +edition = "2021" + +[dependencies] +bincode = "2.0.0-rc.3" +anyhow = "1.0.75" +tokio = { version = "1.34.0", features = ["io-util", "macros", "test-util"] } +async-trait = "0.1.74" diff --git a/proto/src/lib.rs b/proto/src/lib.rs new file mode 100644 index 0000000..e69de29 From aa649769d265e06e3fdcb16e2479b8967b067007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:36:52 +0100 Subject: [PATCH 02/30] feat(proto): add protocol primitives --- proto/src/lib.rs | 1 + proto/src/message/mod.rs | 1 + proto/src/message/primitive/config.rs | 7 +++ proto/src/message/primitive/data.rs | 22 +++++++ proto/src/message/primitive/mod.rs | 4 ++ proto/src/message/primitive/pglist.rs | 84 +++++++++++++++++++++++++ proto/src/message/primitive/pgstring.rs | 54 ++++++++++++++++ 7 files changed, 173 insertions(+) create mode 100644 proto/src/message/mod.rs create mode 100644 proto/src/message/primitive/config.rs create mode 100644 proto/src/message/primitive/data.rs create mode 100644 proto/src/message/primitive/mod.rs create mode 100644 proto/src/message/primitive/pglist.rs create mode 100644 proto/src/message/primitive/pgstring.rs diff --git a/proto/src/lib.rs b/proto/src/lib.rs index e69de29..e216a50 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -0,0 +1 @@ +pub mod message; diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs new file mode 100644 index 0000000..2d8afe5 --- /dev/null +++ b/proto/src/message/mod.rs @@ -0,0 +1 @@ +pub mod primitive; diff --git a/proto/src/message/primitive/config.rs b/proto/src/message/primitive/config.rs new file mode 100644 index 0000000..5aa74f5 --- /dev/null +++ b/proto/src/message/primitive/config.rs @@ -0,0 +1,7 @@ +use bincode::config::{BigEndian, Configuration, Fixint}; + +pub fn pg_proto_config() -> Configuration { + bincode::config::standard() + .with_big_endian() + .with_fixed_int_encoding() +} diff --git a/proto/src/message/primitive/data.rs b/proto/src/message/primitive/data.rs new file mode 100644 index 0000000..41af0f6 --- /dev/null +++ b/proto/src/message/primitive/data.rs @@ -0,0 +1,22 @@ +use crate::message::primitive::config::pg_proto_config; +use bincode::{Decode, Encode}; + +pub trait MessageData: Sized { + fn deserialize(data: &[u8]) -> anyhow::Result; + fn serialize(&self) -> anyhow::Result>; +} + +impl MessageData for T +where + T: Encode + Decode, +{ + #[inline] + fn deserialize(data: &[u8]) -> anyhow::Result { + Ok(bincode::decode_from_slice(data, pg_proto_config())?.0) + } + + #[inline] + fn serialize(&self) -> anyhow::Result> { + Ok(bincode::encode_to_vec(self, pg_proto_config())?) + } +} diff --git a/proto/src/message/primitive/mod.rs b/proto/src/message/primitive/mod.rs new file mode 100644 index 0000000..e275e6e --- /dev/null +++ b/proto/src/message/primitive/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod config; +pub(crate) mod data; +pub mod pglist; +pub mod pgstring; diff --git a/proto/src/message/primitive/pglist.rs b/proto/src/message/primitive/pglist.rs new file mode 100644 index 0000000..aa95ca2 --- /dev/null +++ b/proto/src/message/primitive/pglist.rs @@ -0,0 +1,84 @@ +use bincode::de::Decoder; +use bincode::enc::Encoder; +use bincode::error::{DecodeError, EncodeError}; +use bincode::{BorrowDecode, Decode, Encode}; +use std::marker::PhantomData; + +#[derive(Debug, Clone, PartialEq, BorrowDecode)] +pub struct PgList(Vec, PhantomData); + +impl PgList { + pub fn as_slice(&self) -> &[T] { + &self.0 + } +} + +impl From> for Vec { + fn from(pg_list: PgList) -> Self { + pg_list.0 + } +} + +impl From> for PgList { + fn from(list: Vec) -> Self { + PgList(list, PhantomData) + } +} + +impl Encode for PgList +where + T: Encode, +{ + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let length = self.0.len() as i16; + length.encode(encoder)?; + for item in &self.0 { + item.encode(encoder)?; + } + Ok(()) + } +} + +impl Decode for PgList +where + T: Decode, +{ + fn decode(decoder: &mut D) -> Result { + let length = i16::decode(decoder)?; + let mut list = Vec::new(); + for _ in 0..length { + list.push(T::decode(decoder)?); + } + + Ok(PgList(list, PhantomData)) + } +} + +impl Encode for PgList +where + T: Encode, +{ + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let length = self.0.len() as i32; + length.encode(encoder)?; + for item in &self.0 { + item.encode(encoder)?; + } + Ok(()) + } +} + +impl Decode for PgList +where + T: Decode, +{ + fn decode(decoder: &mut D) -> Result { + let length = i32::decode(decoder)?; + let mut list = Vec::new(); + for _ in 0..length { + list.push(T::decode(decoder)?); + } + + Ok(PgList(list, PhantomData)) + } +} diff --git a/proto/src/message/primitive/pgstring.rs b/proto/src/message/primitive/pgstring.rs new file mode 100644 index 0000000..2c6cd7b --- /dev/null +++ b/proto/src/message/primitive/pgstring.rs @@ -0,0 +1,54 @@ +use bincode::de::Decoder; +use bincode::enc::write::Writer; +use bincode::enc::Encoder; +use bincode::error::{DecodeError, EncodeError}; +use bincode::{BorrowDecode, Decode, Encode}; + +#[derive(Debug, Clone, BorrowDecode)] +pub struct PgString(String); + +impl PgString { + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl From<&str> for PgString { + fn from(string: &str) -> Self { + PgString(string.to_string()) + } +} + +impl From for String { + fn from(pg_string: PgString) -> Self { + pg_string.0 + } +} + +impl From for PgString { + fn from(string: String) -> Self { + PgString(string) + } +} + +impl Encode for PgString { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + encoder.writer().write(self.0.as_bytes())?; + encoder.writer().write(b"\0") + } +} + +impl Decode for PgString { + fn decode(decoder: &mut D) -> Result { + let mut string = String::new(); + loop { + let byte = u8::decode(decoder)?; + if byte == 0 { + break; + } + string.push(byte as char); + } + + Ok(PgString(string)) + } +} From 4a9bc44a0f2dc9d5dfe0d7f86a7f6fa975f07c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:42:25 +0100 Subject: [PATCH 03/30] feat(proto): add proto message trait --- proto/src/message/mod.rs | 1 + proto/src/message/proto_message.rs | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 proto/src/message/proto_message.rs diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 2d8afe5..79c2f96 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -1 +1,2 @@ pub mod primitive; +pub mod proto_message; diff --git a/proto/src/message/proto_message.rs b/proto/src/message/proto_message.rs new file mode 100644 index 0000000..f2a1a4d --- /dev/null +++ b/proto/src/message/proto_message.rs @@ -0,0 +1,5 @@ +pub trait ProtoMessage: Sized { + fn variant(&self) -> u8; + fn serialize(&self) -> anyhow::Result>; + fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result; +} From 3512067c4ec5fbf77f4168b5bef9dc6aaec34c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:43:19 +0100 Subject: [PATCH 04/30] feat(proto): add backend messages --- proto/src/message/backend.rs | 357 +++++++++++++++++++++++++++++++++++ proto/src/message/mod.rs | 1 + 2 files changed, 358 insertions(+) create mode 100644 proto/src/message/backend.rs diff --git a/proto/src/message/backend.rs b/proto/src/message/backend.rs new file mode 100644 index 0000000..1d96833 --- /dev/null +++ b/proto/src/message/backend.rs @@ -0,0 +1,357 @@ +use crate::message::primitive::data::MessageData; +use crate::message::primitive::pglist::PgList; +use crate::message::primitive::pgstring::PgString; +use crate::message::proto_message::ProtoMessage; +use bincode::{Decode, Encode}; + +#[derive(Debug)] +pub enum BackendMessage { + AuthenticationOk(AuthenticationOkData), + BackendKeyData(BackendKeyDataData), + CommandComplete(CommandCompleteData), + DataRow(DataRowData), + EmptyQueryResponse, + ErrorResponse(ErrorResponseData), + NoData, + ParameterStatus(ParameterStatusData), + ReadyForQuery(ReadyForQueryData), + RowDescription(RowDescriptionData), +} + +impl ProtoMessage for BackendMessage { + fn variant(&self) -> u8 { + match self { + BackendMessage::AuthenticationOk(_) => b'R', + BackendMessage::BackendKeyData(_) => b'K', + BackendMessage::CommandComplete(_) => b'C', + BackendMessage::DataRow(_) => b'D', + BackendMessage::EmptyQueryResponse => b'I', + BackendMessage::ErrorResponse(_) => b'E', + BackendMessage::NoData => b'n', + BackendMessage::ParameterStatus(_) => b'S', + BackendMessage::ReadyForQuery(_) => b'Z', + BackendMessage::RowDescription(_) => b'T', + } + } + + fn serialize(&self) -> anyhow::Result> { + match self { + BackendMessage::AuthenticationOk(data) => data.serialize(), + BackendMessage::BackendKeyData(data) => data.serialize(), + BackendMessage::CommandComplete(data) => data.serialize(), + BackendMessage::DataRow(data) => data.serialize(), + BackendMessage::EmptyQueryResponse => Ok(Vec::with_capacity(0)), + BackendMessage::ErrorResponse(data) => data.serialize(), + BackendMessage::NoData => Ok(Vec::with_capacity(0)), + BackendMessage::ParameterStatus(data) => data.serialize(), + BackendMessage::ReadyForQuery(data) => data.serialize(), + BackendMessage::RowDescription(data) => data.serialize(), + } + } + + fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result { + match variant { + b'R' => Ok(BackendMessage::AuthenticationOk( + AuthenticationOkData::deserialize(data)?, + )), + b'K' => { + let data = BackendKeyDataData::deserialize(data)?; + Ok(BackendMessage::BackendKeyData(data)) + } + b'C' => { + let data = CommandCompleteData::deserialize(data)?; + Ok(BackendMessage::CommandComplete(data)) + } + b'D' => { + let data = DataRowData::deserialize(data)?; + Ok(BackendMessage::DataRow(data)) + } + b'I' => Ok(BackendMessage::EmptyQueryResponse), + b'E' => { + let data = ErrorResponseData::deserialize(data)?; + Ok(BackendMessage::ErrorResponse(data)) + } + b'n' => Ok(BackendMessage::NoData), + b'S' => { + let data = ParameterStatusData::deserialize(data)?; + Ok(BackendMessage::ParameterStatus(data)) + } + b'Z' => { + let data = ReadyForQueryData::deserialize(data)?; + Ok(BackendMessage::ReadyForQuery(data)) + } + b'T' => { + let data = RowDescriptionData::deserialize(data)?; + Ok(BackendMessage::RowDescription(data)) + } + v => Err(anyhow::anyhow!("Unknown backend message variant: {v}")), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct AuthenticationOkData { + pub status: i32, +} + +impl From for BackendMessage { + fn from(data: AuthenticationOkData) -> Self { + BackendMessage::AuthenticationOk(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct BackendKeyDataData { + pub process: i32, + pub secret: i32, +} + +impl From for BackendMessage { + fn from(data: BackendKeyDataData) -> Self { + BackendMessage::BackendKeyData(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct CommandCompleteData { + pub tag: PgString, +} + +impl From for BackendMessage { + fn from(data: CommandCompleteData) -> Self { + BackendMessage::CommandComplete(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct DataRowData { + pub columns: PgList, i16>, +} + +impl From for BackendMessage { + fn from(data: DataRowData) -> Self { + BackendMessage::DataRow(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ErrorResponseData { + pub code: u8, + pub message: PgString, +} + +impl From for BackendMessage { + fn from(data: ErrorResponseData) -> Self { + BackendMessage::ErrorResponse(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ParameterStatusData { + pub name: PgString, + pub value: PgString, +} + +impl From for BackendMessage { + fn from(data: ParameterStatusData) -> Self { + BackendMessage::ParameterStatus(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ReadyForQueryData { + pub status: u8, +} + +impl From for BackendMessage { + fn from(data: ReadyForQueryData) -> Self { + BackendMessage::ReadyForQuery(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct RowDescriptionData { + pub columns: PgList, +} + +impl From for BackendMessage { + fn from(data: RowDescriptionData) -> Self { + BackendMessage::RowDescription(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ColumnDescription { + pub name: PgString, + pub table_oid: i32, + pub column_index: i16, + pub type_oid: i32, + pub type_size: i16, + pub type_modifier: i32, + pub format_code: i16, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symmetric_authentication_ok() { + let backend = BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 }) + )); + } + + #[test] + fn test_symmetric_backend_key_data() { + let backend = BackendMessage::BackendKeyData(BackendKeyDataData { + process: 123, + secret: 456, + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::BackendKeyData(BackendKeyDataData { + process: 123, + secret: 456 + }) + )); + } + + #[test] + fn test_symmetric_command_complete() { + let backend = BackendMessage::CommandComplete(CommandCompleteData { + tag: PgString::from("SELECT 1"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::CommandComplete(CommandCompleteData { tag }) if tag.as_str() == "SELECT 1" + )); + } + + #[test] + fn test_symmetric_data_row() { + let backend = BackendMessage::DataRow(DataRowData { + columns: PgList::from(vec![PgList::from(vec![1, 2, 3])]), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::DataRow(DataRowData { columns }) if columns == PgList::from(vec![PgList::from(vec![1, 2, 3])]) + )); + } + + #[test] + fn test_symmetric_empty_query_response() { + let backend = BackendMessage::EmptyQueryResponse; + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, BackendMessage::EmptyQueryResponse)); + } + + #[test] + fn test_symmetric_error_response() { + let backend = BackendMessage::ErrorResponse(ErrorResponseData { + code: b'X', + message: PgString::from("Some error"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ErrorResponse(ErrorResponseData { code, message }) if code == b'X' && message.as_str() == "Some error" + )); + } + + #[test] + fn test_symmetric_no_data() { + let backend = BackendMessage::NoData; + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, BackendMessage::NoData)); + } + + #[test] + fn test_symmetric_parameter_status() { + let backend = BackendMessage::ParameterStatus(ParameterStatusData { + name: PgString::from("Some name"), + value: PgString::from("Some value"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ParameterStatus(ParameterStatusData { name, value }) if name.as_str() == "Some name" && value.as_str() == "Some value" + )); + } + + #[test] + fn test_symmetric_ready_for_query() { + let backend = BackendMessage::ReadyForQuery(ReadyForQueryData { status: b'I' }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ReadyForQuery(ReadyForQueryData { status }) if status == b'I' + )); + } + + #[test] + fn test_symmetric_row_description() { + let backend = BackendMessage::RowDescription(RowDescriptionData { + columns: PgList::from(vec![ColumnDescription { + name: PgString::from("Some name"), + table_oid: 123, + column_index: 456, + type_oid: 789, + type_size: 101, + type_modifier: 112, + format_code: 113, + }]), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(match message { + BackendMessage::RowDescription(RowDescriptionData { columns }) => { + let columns: Vec = columns.into(); + let column = &columns[0]; + column.name.as_str() == "Some name" + && column.table_oid == 123 + && column.column_index == 456 + && column.type_oid == 789 + && column.type_size == 101 + && column.type_modifier == 112 + && column.format_code == 113 + } + _ => false, + },) + } +} diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 79c2f96..5b3cd7a 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -1,2 +1,3 @@ +pub mod backend; pub mod primitive; pub mod proto_message; From ee2742ea03eb5facdeb586cfee9ea821aa7932cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:43:46 +0100 Subject: [PATCH 05/30] feat(proto): add frontend messages --- proto/src/message/frontend.rs | 68 +++++++++++++++++++++++++++++++++++ proto/src/message/mod.rs | 1 + 2 files changed, 69 insertions(+) create mode 100644 proto/src/message/frontend.rs diff --git a/proto/src/message/frontend.rs b/proto/src/message/frontend.rs new file mode 100644 index 0000000..4914fc2 --- /dev/null +++ b/proto/src/message/frontend.rs @@ -0,0 +1,68 @@ +use crate::message::primitive::data::MessageData; +use crate::message::primitive::pgstring::PgString; +use crate::message::proto_message::ProtoMessage; +use bincode::{Decode, Encode}; + +#[derive(Debug)] +pub enum FrontendMessage { + Query(QueryData), + Terminate, +} + +impl ProtoMessage for FrontendMessage { + fn variant(&self) -> u8 { + match self { + FrontendMessage::Query(_) => b'Q', + FrontendMessage::Terminate => b'X', + } + } + + fn serialize(&self) -> anyhow::Result> { + match self { + FrontendMessage::Query(data) => data.serialize(), + FrontendMessage::Terminate => Ok(Vec::with_capacity(0)), + } + } + + fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result { + match variant { + b'Q' => Ok(FrontendMessage::Query(QueryData::deserialize(data)?)), + b'X' => Ok(FrontendMessage::Terminate), + v => Err(anyhow::anyhow!("Unknown frontend message variant: {v}")), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct QueryData { + pub query: PgString, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symmetric_query() { + let frontend = FrontendMessage::Query(QueryData { + query: PgString::from("SELECT * FROM foo WHERE bar = $1"), + }); + let raw = frontend.serialize().unwrap(); + let variant = frontend.variant(); + + let message = FrontendMessage::deserialize(variant, &raw).unwrap(); + assert!( + matches!(message, FrontendMessage::Query(QueryData { query }) if query.as_str() == "SELECT * FROM foo WHERE bar = $1") + ); + } + + #[test] + fn test_symmetric_terminate() { + let frontend = FrontendMessage::Terminate; + let raw = frontend.serialize().unwrap(); + let variant = frontend.variant(); + + let message = FrontendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, FrontendMessage::Terminate)); + } +} diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 5b3cd7a..30ae1a3 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -1,3 +1,4 @@ pub mod backend; +pub mod frontend; pub mod primitive; pub mod proto_message; From 65f90ba600edbbf7e5cfc671631ff6851fd1e839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:44:35 +0100 Subject: [PATCH 06/30] feat(proto): add special messages --- proto/src/message/mod.rs | 1 + proto/src/message/special.rs | 60 ++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 proto/src/message/special.rs diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 30ae1a3..527e584 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -2,3 +2,4 @@ pub mod backend; pub mod frontend; pub mod primitive; pub mod proto_message; +pub mod special; diff --git a/proto/src/message/special.rs b/proto/src/message/special.rs new file mode 100644 index 0000000..166320f --- /dev/null +++ b/proto/src/message/special.rs @@ -0,0 +1,60 @@ +use crate::message::primitive::pgstring::PgString; +use bincode::de::Decoder; +use bincode::enc::Encoder; +use bincode::error::{DecodeError, EncodeError}; +use bincode::{Decode, Encode}; + +#[derive(Debug)] +pub enum SpecialMessage { + CancelRequest(CancelRequestData), + SSLRequest, + StartupMessage(StartupMessageData), +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct CancelRequestData { + pub pid: i32, + pub secret: i32, +} + +#[derive(Debug, Clone)] +pub struct StartupMessageData { + pub version: i32, + pub params: Vec<(PgString, PgString)>, +} + +impl Encode for StartupMessageData { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + self.version.encode(encoder)?; + for (key, value) in &self.params { + key.encode(encoder)?; + value.encode(encoder)?; + } + Ok(()) + } +} + +impl Decode for StartupMessageData { + fn decode(decoder: &mut D) -> Result { + let version = i32::decode(decoder)?; + let mut params = Vec::new(); + loop { + let maybe_key = PgString::decode(decoder); + match maybe_key { + Ok(_) => {} + Err(DecodeError::UnexpectedEnd { .. }) => break, + Err(e) => return Err(e), + } + + let maybe_value = PgString::decode(decoder); + match maybe_value { + Ok(_) => {} + Err(DecodeError::UnexpectedEnd { .. }) => break, + Err(e) => return Err(e), + } + + params.push((maybe_key.unwrap(), maybe_value.unwrap())); + } + Ok(StartupMessageData { version, params }) + } +} From 0a3683e2fa750a83f8b56d450a8e9f12adb72dd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:49:30 +0100 Subject: [PATCH 07/30] feat(proto): add generic proto writer --- proto/src/lib.rs | 1 + proto/src/writer/mod.rs | 3 +++ proto/src/writer/oneway.rs | 31 ++++++++++++++++++++++++++++++ proto/src/writer/protowriter.rs | 34 +++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+) create mode 100644 proto/src/writer/mod.rs create mode 100644 proto/src/writer/oneway.rs create mode 100644 proto/src/writer/protowriter.rs diff --git a/proto/src/lib.rs b/proto/src/lib.rs index e216a50..c65ece0 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -1 +1,2 @@ pub mod message; +pub mod writer; diff --git a/proto/src/writer/mod.rs b/proto/src/writer/mod.rs new file mode 100644 index 0000000..fac68e8 --- /dev/null +++ b/proto/src/writer/mod.rs @@ -0,0 +1,3 @@ + +pub mod oneway; +pub mod protowriter; diff --git a/proto/src/writer/oneway.rs b/proto/src/writer/oneway.rs new file mode 100644 index 0000000..1649fd8 --- /dev/null +++ b/proto/src/writer/oneway.rs @@ -0,0 +1,31 @@ +use crate::message::proto_message::ProtoMessage; +use crate::writer::protowriter::ProtoWriter; +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +#[async_trait] +pub trait OneWayProtoWriter +where + T: ProtoMessage, +{ + async fn write_proto(&mut self, message: T) -> anyhow::Result<()>; +} + +#[async_trait] +impl OneWayProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send, + T: ProtoMessage + Send + 'static, +{ + async fn write_proto(&mut self, message: T) -> anyhow::Result<()> { + let variant = message.variant(); + let mut data = message.serialize()?; + let length = data.len() as i32 + 4; + + self.inner.write_u8(variant).await?; + self.inner.write_i32(length).await?; + self.inner.write_all(&mut data).await?; + + Ok(()) + } +} diff --git a/proto/src/writer/protowriter.rs b/proto/src/writer/protowriter.rs new file mode 100644 index 0000000..848a727 --- /dev/null +++ b/proto/src/writer/protowriter.rs @@ -0,0 +1,34 @@ +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +pub struct ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + pub(super) inner: W, +} + +impl ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + pub fn new(writer: W) -> ProtoWriter { + ProtoWriter { inner: writer } + } +} + +#[async_trait] +pub trait ProtoFlush { + async fn flush(&mut self) -> anyhow::Result<()>; +} + +#[async_trait] +impl ProtoFlush for ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + async fn flush(&mut self) -> anyhow::Result<()> { + self.inner.flush().await?; + Ok(()) + } +} From 225f9e43d30f4475c58e2e6bd0257722dbefb500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:50:06 +0100 Subject: [PATCH 08/30] feat(proto): add backend message writer --- proto/src/writer/backend.rs | 59 +++++++++++++++++++++++++++++++++++++ proto/src/writer/mod.rs | 1 + 2 files changed, 60 insertions(+) create mode 100644 proto/src/writer/backend.rs diff --git a/proto/src/writer/backend.rs b/proto/src/writer/backend.rs new file mode 100644 index 0000000..4cbbdc8 --- /dev/null +++ b/proto/src/writer/backend.rs @@ -0,0 +1,59 @@ +use crate::message::backend::BackendMessage; +use crate::writer::oneway::OneWayProtoWriter; +use crate::writer::protowriter::ProtoWriter; +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +#[async_trait] +pub trait BackendProtoWriter: OneWayProtoWriter { + async fn write_ssl_reject(&mut self) -> anyhow::Result<()>; +} + +#[async_trait] +impl BackendProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + async fn write_ssl_reject(&mut self) -> anyhow::Result<()> { + self.inner.write_u8(b'N').await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::backend::AuthenticationOkData; + use crate::writer::protowriter::ProtoWriter; + use tokio::io::BufWriter; + + #[tokio::test] + async fn test_message_sequence() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer + .write_proto(BackendMessage::AuthenticationOk(AuthenticationOkData { + status: 123, + })) + .await + .unwrap(); + + writer.write_proto(BackendMessage::NoData).await.unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![b'R', 0, 0, 0, 8, 0, 0, 0, 123, b'n', 0, 0, 0, 4] + ); + } + + #[tokio::test] + async fn test_ssl_reject() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer.write_ssl_reject().await.unwrap(); + + assert_eq!(writer.inner.buffer(), vec![b'N']); + } +} diff --git a/proto/src/writer/mod.rs b/proto/src/writer/mod.rs index fac68e8..58e55a6 100644 --- a/proto/src/writer/mod.rs +++ b/proto/src/writer/mod.rs @@ -1,3 +1,4 @@ +pub mod backend; pub mod oneway; pub mod protowriter; From 67af05ea4233a53760b8eb5d41c3149adbaddc92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:50:33 +0100 Subject: [PATCH 09/30] feat(proto): add frontend message writer --- proto/src/writer/frontend.rs | 40 ++++++++++++++++++++++++++++++++++++ proto/src/writer/mod.rs | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 proto/src/writer/frontend.rs diff --git a/proto/src/writer/frontend.rs b/proto/src/writer/frontend.rs new file mode 100644 index 0000000..ba634cc --- /dev/null +++ b/proto/src/writer/frontend.rs @@ -0,0 +1,40 @@ +use crate::message::frontend::FrontendMessage; +use crate::writer::oneway::OneWayProtoWriter; +use async_trait::async_trait; + +#[async_trait] +pub trait FrontendProtoWriter: OneWayProtoWriter {} + +#[async_trait] +impl FrontendProtoWriter for W where W: OneWayProtoWriter {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::frontend::QueryData; + use crate::writer::protowriter::ProtoWriter; + use tokio::io::BufWriter; + + #[tokio::test] + async fn test_message_sequence() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer + .write_proto(FrontendMessage::Query(QueryData { + query: "SLIME".into(), + })) + .await + .unwrap(); + + writer + .write_proto(FrontendMessage::Terminate) + .await + .unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4] + ); + } +} diff --git a/proto/src/writer/mod.rs b/proto/src/writer/mod.rs index 58e55a6..3f16f05 100644 --- a/proto/src/writer/mod.rs +++ b/proto/src/writer/mod.rs @@ -1,4 +1,4 @@ - pub mod backend; +pub mod frontend; pub mod oneway; pub mod protowriter; From 413e0216e367b47741523e95a16bbab3537680c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:53:21 +0100 Subject: [PATCH 10/30] feat(proto): add generic proto reader --- proto/src/lib.rs | 1 + proto/src/reader/mod.rs | 3 +++ proto/src/reader/oneway.rs | 37 +++++++++++++++++++++++++++++++++ proto/src/reader/protoreader.rs | 22 ++++++++++++++++++++ proto/src/reader/utils.rs | 24 +++++++++++++++++++++ 5 files changed, 87 insertions(+) create mode 100644 proto/src/reader/mod.rs create mode 100644 proto/src/reader/oneway.rs create mode 100644 proto/src/reader/protoreader.rs create mode 100644 proto/src/reader/utils.rs diff --git a/proto/src/lib.rs b/proto/src/lib.rs index c65ece0..c964d21 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -1,2 +1,3 @@ pub mod message; +pub mod reader; pub mod writer; diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs new file mode 100644 index 0000000..450e7d4 --- /dev/null +++ b/proto/src/reader/mod.rs @@ -0,0 +1,3 @@ +pub mod oneway; +pub mod protoreader; +mod utils; diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs new file mode 100644 index 0000000..4cca72b --- /dev/null +++ b/proto/src/reader/oneway.rs @@ -0,0 +1,37 @@ +use crate::message::proto_message::ProtoMessage; +use crate::reader::protoreader::ProtoReader; +use crate::reader::utils::AsyncPeek; +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncReadExt}; + +#[async_trait] +pub trait OneWayProtoReader +where + T: ProtoMessage, +{ + async fn read_proto(&mut self) -> anyhow::Result; +} + +#[async_trait] +impl OneWayProtoReader for ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, + T: ProtoMessage, +{ + async fn read_proto(&mut self) -> anyhow::Result { + let variant = self.inner.read_u8().await?; + let length = self.inner.read_i32().await?; + + if length < 4 { + return Err(anyhow::anyhow!("Invalid message length")); + } + if length > self.msg_len_limit { + return Err(anyhow::anyhow!("Message length over limit")); + } + + let mut data = vec![0u8; (length - 4) as usize]; + self.inner.read_exact(&mut data).await?; + + T::deserialize(variant, &data) + } +} diff --git a/proto/src/reader/protoreader.rs b/proto/src/reader/protoreader.rs new file mode 100644 index 0000000..5e3f572 --- /dev/null +++ b/proto/src/reader/protoreader.rs @@ -0,0 +1,22 @@ +use crate::reader::utils::AsyncPeek; +use tokio::io::AsyncBufRead; + +pub struct ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + pub(super) inner: R, + pub(super) msg_len_limit: i32, +} + +impl ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + pub fn new(reader: R, msg_len_limit: i32) -> ProtoReader { + ProtoReader { + inner: reader, + msg_len_limit, + } + } +} diff --git a/proto/src/reader/utils.rs b/proto/src/reader/utils.rs new file mode 100644 index 0000000..e4a70eb --- /dev/null +++ b/proto/src/reader/utils.rs @@ -0,0 +1,24 @@ +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +#[async_trait] +pub trait AsyncPeek { + async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result; +} + +#[async_trait] +impl AsyncPeek for T +where + T: AsyncBufRead + Unpin + Send, +{ + async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result { + let filled = self.fill_buf().await?; + if filled.len() >= buf.len() { + buf.copy_from_slice(&filled[..buf.len()]); + Ok(buf.len()) + } else { + buf[..filled.len()].copy_from_slice(filled); + Ok(filled.len()) + } + } +} From 0a6e486005510ab89cd059548f1f5ffc25bf403c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:54:02 +0100 Subject: [PATCH 11/30] feat(proto): add frontend message reader --- proto/src/reader/frontend.rs | 279 +++++++++++++++++++++++++++++++++++ proto/src/reader/mod.rs | 1 + 2 files changed, 280 insertions(+) create mode 100644 proto/src/reader/frontend.rs diff --git a/proto/src/reader/frontend.rs b/proto/src/reader/frontend.rs new file mode 100644 index 0000000..fc6ceca --- /dev/null +++ b/proto/src/reader/frontend.rs @@ -0,0 +1,279 @@ +use crate::message::frontend::FrontendMessage; +use crate::message::primitive::config::pg_proto_config; +use crate::message::special::{CancelRequestData, SpecialMessage}; +use crate::reader::oneway::OneWayProtoReader; +use crate::reader::protoreader::ProtoReader; +use crate::reader::utils::AsyncPeek; +use anyhow::anyhow; +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +#[async_trait] +pub trait FrontendProtoReader: OneWayProtoReader { + async fn peek_special_message(&mut self) -> anyhow::Result>; + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>; +} + +#[async_trait] +impl FrontendProtoReader for ProtoReader +where + R: AsyncBufRead + Unpin + Send, +{ + async fn peek_special_message(&mut self) -> anyhow::Result> { + if let Some(cancel) = try_get_cancel_request(&mut self).await? { + return Ok(Some(cancel)); + } + + if let Some(ssl) = try_get_ssl_request(&mut self).await? { + return Ok(Some(ssl)); + } + + if let Some(startup) = try_get_startup_message(&mut self).await? { + return Ok(Some(startup)); + } + + Ok(None) + } + + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> { + Ok(match msg { + SpecialMessage::CancelRequest(_) => consume_cancel_request(self), + SpecialMessage::SSLRequest => consume_ssl_request(self), + SpecialMessage::StartupMessage(_) => consume_startup_message(self).await?, + }) + } +} + +async fn try_get_cancel_request( + reader: &mut ProtoReader, +) -> anyhow::Result> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 16]; + if reader.inner.peek(&mut header).await? != 16 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length != 16 { + return Ok(None); + } + + let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if code != 80877102 { + return Ok(None); + } + + let pid = i32::from_be_bytes([header[8], header[9], header[10], header[11]]); + let secret = i32::from_be_bytes([header[12], header[13], header[14], header[15]]); + + Ok(Some(SpecialMessage::CancelRequest(CancelRequestData { + pid, + secret, + }))) +} + +fn consume_cancel_request(reader: &mut ProtoReader) +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + reader.inner.consume(16); +} + +async fn try_get_ssl_request( + reader: &mut ProtoReader, +) -> anyhow::Result> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 8]; + if reader.inner.peek(&mut header).await? != 8 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length != 8 { + return Ok(None); + } + + let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if code != 80877103 { + return Ok(None); + } + + Ok(Some(SpecialMessage::SSLRequest)) +} + +fn consume_ssl_request(reader: &mut ProtoReader) +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + reader.inner.consume(8); +} + +async fn try_get_startup_message( + reader: &mut ProtoReader, +) -> anyhow::Result> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 8]; + if reader.inner.peek(&mut header).await? != 8 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length < 8 { + return Ok(None); + } + if length > reader.msg_len_limit { + return Err(anyhow!("Message length is over the limit")); + } + + let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if version != 196608 { + return Ok(None); + } + + let length = length as usize; + let mut data = vec![0u8; length]; + if reader.inner.peek(&mut data).await? != length { + return Ok(None); + } + + let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0; + + Ok(Some(SpecialMessage::StartupMessage(data))) +} + +async fn consume_startup_message(reader: &mut ProtoReader) -> anyhow::Result<()> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 4]; + if reader.inner.peek(&mut header).await? != 4 { + return Err(anyhow!("Invalid header peek length")); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize; + if length < 8 { + return Err(anyhow!("Invalid startup message length")); + } + + reader.inner.consume(length); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::frontend::QueryData; + use crate::message::special::StartupMessageData; + use std::io::Cursor; + use tokio::io::{AsyncBufReadExt, BufReader}; + + #[tokio::test] + async fn test_message_sequence() { + let data = [ + b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let msg = reader.read_proto().await; + assert!( + match &msg { + Ok(FrontendMessage::Query(QueryData { query })) => query.as_str() == "SLIME", + _ => false, + }, + "{msg:?}" + ); + + let msg = reader.read_proto().await; + assert!(matches!(msg, Ok(FrontendMessage::Terminate)), "{msg:?}"); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_cancel_request() { + let data = [ + 0, 0, 0, 16, 0x04, 0xD2, 0x16, 0x2E, 0, 0, 0, 111, 0, 0, 0, 222, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(matches!( + peeked, + Some(SpecialMessage::CancelRequest(CancelRequestData { + pid: 111, + secret: 222 + })) + )); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_ssl_request() { + let data = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(matches!(peeked, Some(SpecialMessage::SSLRequest))); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_startup_message() { + let data = [ + 0, 0, 0, 26, 0, 3, 0, 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'b', b'r', + b'a', b'n', b'i', b'k', 0, 0, 0, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(match &peeked { + Some(SpecialMessage::StartupMessage(StartupMessageData { + version: 196608, + params, + })) => + params.len() == 2 + && params[0].0.as_str() == "database" + && params[0].1.as_str() == "branik" + && params[1].0.as_str() == "" + && params[1].1.as_str() == "", + _ => false, + }); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } +} diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs index 450e7d4..53bf75c 100644 --- a/proto/src/reader/mod.rs +++ b/proto/src/reader/mod.rs @@ -1,3 +1,4 @@ +pub mod frontend; pub mod oneway; pub mod protoreader; mod utils; From 393bc0a75146ab73b1b33e5ad86e762172c046c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:54:12 +0100 Subject: [PATCH 12/30] feat(proto): add backend message reader --- proto/src/reader/backend.rs | 58 +++++++++++++++++++++++++++++++++++++ proto/src/reader/mod.rs | 1 + 2 files changed, 59 insertions(+) create mode 100644 proto/src/reader/backend.rs diff --git a/proto/src/reader/backend.rs b/proto/src/reader/backend.rs new file mode 100644 index 0000000..33db099 --- /dev/null +++ b/proto/src/reader/backend.rs @@ -0,0 +1,58 @@ +use crate::message::backend::BackendMessage; +use crate::reader::oneway::OneWayProtoReader; +use async_trait::async_trait; + +#[async_trait] +pub trait BackendProtoReader: OneWayProtoReader {} + +#[async_trait] +impl BackendProtoReader for R where R: OneWayProtoReader {} + +#[cfg(test)] +mod tests { + use crate::message::backend::{ + AuthenticationOkData, BackendKeyDataData, BackendMessage, CommandCompleteData, + }; + use crate::reader::oneway::OneWayProtoReader; + use crate::reader::protoreader::ProtoReader; + use std::io::Cursor; + use tokio::io::{AsyncBufReadExt, BufReader}; + + #[tokio::test] + async fn test_message_sequence() { + let data = [ + b'R', 0, 0, 0, 8, 0, 0, 0, 123, b'K', 0, 0, 0, 12, 0, 0, 0, 111, 0, 0, 0, 222, b'C', 0, + 0, 0, 8, b'A', b'B', b'C', 0, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let msg = reader.read_proto().await; + assert!(matches!( + msg, + Ok(BackendMessage::AuthenticationOk(AuthenticationOkData { + status: 123 + })) + )); + + let msg = reader.read_proto().await; + assert!(matches!( + msg, + Ok(BackendMessage::BackendKeyData(BackendKeyDataData { + process: 111, + secret: 222 + })) + )); + + let msg = reader.read_proto().await; + assert!(match msg { + Ok(BackendMessage::CommandComplete(CommandCompleteData { tag })) => + tag.as_str() == "ABC", + _ => false, + }); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } +} diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs index 53bf75c..3c85a0f 100644 --- a/proto/src/reader/mod.rs +++ b/proto/src/reader/mod.rs @@ -1,3 +1,4 @@ +pub mod backend; pub mod frontend; pub mod oneway; pub mod protoreader; From dbd0ef397015fa2acd65758975a72ffb7b35a36e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:54:41 +0100 Subject: [PATCH 13/30] feat(proto): add server handshake handler --- proto/src/handshake/mod.rs | 1 + proto/src/handshake/server.rs | 59 +++++++++++++++++++++++++++++++++++ proto/src/lib.rs | 1 + 3 files changed, 61 insertions(+) create mode 100644 proto/src/handshake/mod.rs create mode 100644 proto/src/handshake/server.rs diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs new file mode 100644 index 0000000..74f47ad --- /dev/null +++ b/proto/src/handshake/mod.rs @@ -0,0 +1 @@ +pub mod server; diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs new file mode 100644 index 0000000..660ea0c --- /dev/null +++ b/proto/src/handshake/server.rs @@ -0,0 +1,59 @@ +use crate::message::backend::{ + AuthenticationOkData, BackendKeyDataData, BackendMessage, ParameterStatusData, + ReadyForQueryData, +}; +use crate::message::special::{SpecialMessage, StartupMessageData}; +use crate::reader::frontend::FrontendProtoReader; +use crate::writer::backend::BackendProtoWriter; +use crate::writer::protowriter::ProtoFlush; + +pub async fn do_server_handshake( + writer: &mut (impl BackendProtoWriter + ProtoFlush), + reader: &mut impl FrontendProtoReader, + name: &str, + process: i32, + secret: i32, +) -> anyhow::Result { + match &reader.peek_special_message().await? { + Some(msg @ SpecialMessage::SSLRequest) => { + reader.consume_special_message(msg).await?; + writer.write_ssl_reject().await?; + writer.flush().await?; + } + _ => { + // No SSL request + } + } + + let startup_message = match &reader.peek_special_message().await? { + Some(msg @ SpecialMessage::StartupMessage(data)) => { + reader.consume_special_message(msg).await?; + data.clone() + } + _ => { + return Err(anyhow::anyhow!("Expected Startup Message")); + } + }; + + writer + .write_proto(BackendMessage::from(AuthenticationOkData { status: 0 })) + .await?; + + writer + .write_proto(BackendMessage::from(ParameterStatusData { + name: "server_version".to_string().into(), + value: format!("16.0 ({name})").into(), + })) + .await?; + + writer + .write_proto(BackendMessage::from(BackendKeyDataData { process, secret })) + .await?; + + writer + .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) + .await?; + + writer.flush().await?; + Ok(startup_message) +} diff --git a/proto/src/lib.rs b/proto/src/lib.rs index c964d21..69afed6 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -1,3 +1,4 @@ +pub mod handshake; pub mod message; pub mod reader; pub mod writer; From bb39d138d89a50a025ab591712fd6daaef705743 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Tue, 12 Dec 2023 19:31:27 +0100 Subject: [PATCH 14/30] refactor(proto): start replacing anyhow with thiserror --- proto/Cargo.toml | 1 + proto/src/message/backend.rs | 16 +++++++++++++--- proto/src/message/errors.rs | 16 ++++++++++++++++ proto/src/message/frontend.rs | 17 ++++++++++++++--- proto/src/message/mod.rs | 1 + proto/src/message/primitive/data.rs | 13 +++++++------ proto/src/message/proto_message.rs | 6 ++++-- proto/src/reader/oneway.rs | 2 +- 8 files changed, 57 insertions(+), 15 deletions(-) create mode 100644 proto/src/message/errors.rs diff --git a/proto/Cargo.toml b/proto/Cargo.toml index 6e5d2fb..9fb51ce 100644 --- a/proto/Cargo.toml +++ b/proto/Cargo.toml @@ -8,3 +8,4 @@ bincode = "2.0.0-rc.3" anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["io-util", "macros", "test-util"] } async-trait = "0.1.74" +thiserror = "1.0.50" diff --git a/proto/src/message/backend.rs b/proto/src/message/backend.rs index 1d96833..7ba0004 100644 --- a/proto/src/message/backend.rs +++ b/proto/src/message/backend.rs @@ -3,6 +3,7 @@ use crate::message::primitive::pglist::PgList; use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; #[derive(Debug)] pub enum BackendMessage { @@ -34,7 +35,7 @@ impl ProtoMessage for BackendMessage { } } - fn serialize(&self) -> anyhow::Result> { + fn serialize(&self) -> Result, ProtoSerializeError> { match self { BackendMessage::AuthenticationOk(data) => data.serialize(), BackendMessage::BackendKeyData(data) => data.serialize(), @@ -49,7 +50,7 @@ impl ProtoMessage for BackendMessage { } } - fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result { + fn deserialize(variant: u8, data: &[u8]) -> Result { match variant { b'R' => Ok(BackendMessage::AuthenticationOk( AuthenticationOkData::deserialize(data)?, @@ -84,7 +85,7 @@ impl ProtoMessage for BackendMessage { let data = RowDescriptionData::deserialize(data)?; Ok(BackendMessage::RowDescription(data)) } - v => Err(anyhow::anyhow!("Unknown backend message variant: {v}")), + v => Err(ProtoDeserializeError::InvalidVariant(v)), } } } @@ -354,4 +355,13 @@ mod tests { _ => false, },) } + + #[test] + fn test_unknown_variant() { + let variant = 0; + let data = vec![1, 2, 3]; + + let message = BackendMessage::deserialize(variant, &data); + assert!(matches!(message, Err(ProtoDeserializeError::InvalidVariant(0)))); + } } diff --git a/proto/src/message/errors.rs b/proto/src/message/errors.rs new file mode 100644 index 0000000..7a3e04e --- /dev/null +++ b/proto/src/message/errors.rs @@ -0,0 +1,16 @@ +use bincode::error::{DecodeError, EncodeError}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ProtoDeserializeError { + #[error("invalid message variant: {0}")] + InvalidVariant(u8), + #[error("decoding of inner data failed")] + DecodeData(#[from] DecodeError), +} + +#[derive(Debug, Error)] +pub enum ProtoSerializeError { + #[error("encoding of inner data failed")] + EncodeData(#[from] EncodeError), +} diff --git a/proto/src/message/frontend.rs b/proto/src/message/frontend.rs index 4914fc2..6d7f993 100644 --- a/proto/src/message/frontend.rs +++ b/proto/src/message/frontend.rs @@ -2,6 +2,7 @@ use crate::message::primitive::data::MessageData; use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; #[derive(Debug)] pub enum FrontendMessage { @@ -17,18 +18,18 @@ impl ProtoMessage for FrontendMessage { } } - fn serialize(&self) -> anyhow::Result> { + fn serialize(&self) -> Result, ProtoSerializeError> { match self { FrontendMessage::Query(data) => data.serialize(), FrontendMessage::Terminate => Ok(Vec::with_capacity(0)), } } - fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result { + fn deserialize(variant: u8, data: &[u8]) -> Result { match variant { b'Q' => Ok(FrontendMessage::Query(QueryData::deserialize(data)?)), b'X' => Ok(FrontendMessage::Terminate), - v => Err(anyhow::anyhow!("Unknown frontend message variant: {v}")), + v => Err(ProtoDeserializeError::InvalidVariant(v)), } } } @@ -40,6 +41,7 @@ pub struct QueryData { #[cfg(test)] mod tests { + use crate::message::backend::BackendMessage; use super::*; #[test] @@ -65,4 +67,13 @@ mod tests { let message = FrontendMessage::deserialize(variant, &raw).unwrap(); assert!(matches!(message, FrontendMessage::Terminate)); } + + #[test] + fn test_unknown_variant() { + let variant = 0; + let data = vec![1, 2, 3]; + + let message = BackendMessage::deserialize(variant, &data); + assert!(matches!(message, Err(ProtoDeserializeError::InvalidVariant(0)))); + } } diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 527e584..2c11140 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -3,3 +3,4 @@ pub mod frontend; pub mod primitive; pub mod proto_message; pub mod special; +pub mod errors; diff --git a/proto/src/message/primitive/data.rs b/proto/src/message/primitive/data.rs index 41af0f6..37b491a 100644 --- a/proto/src/message/primitive/data.rs +++ b/proto/src/message/primitive/data.rs @@ -1,9 +1,10 @@ use crate::message::primitive::config::pg_proto_config; use bincode::{Decode, Encode}; +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; pub trait MessageData: Sized { - fn deserialize(data: &[u8]) -> anyhow::Result; - fn serialize(&self) -> anyhow::Result>; + fn serialize(&self) -> Result, ProtoSerializeError>; + fn deserialize(data: &[u8]) -> Result; } impl MessageData for T @@ -11,12 +12,12 @@ where T: Encode + Decode, { #[inline] - fn deserialize(data: &[u8]) -> anyhow::Result { - Ok(bincode::decode_from_slice(data, pg_proto_config())?.0) + fn serialize(&self) -> Result, ProtoSerializeError> { + Ok(bincode::encode_to_vec(self, pg_proto_config())?) } #[inline] - fn serialize(&self) -> anyhow::Result> { - Ok(bincode::encode_to_vec(self, pg_proto_config())?) + fn deserialize(data: &[u8]) -> Result { + Ok(bincode::decode_from_slice(data, pg_proto_config())?.0) } } diff --git a/proto/src/message/proto_message.rs b/proto/src/message/proto_message.rs index f2a1a4d..13986e1 100644 --- a/proto/src/message/proto_message.rs +++ b/proto/src/message/proto_message.rs @@ -1,5 +1,7 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; + pub trait ProtoMessage: Sized { fn variant(&self) -> u8; - fn serialize(&self) -> anyhow::Result>; - fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result; + fn serialize(&self) -> Result, ProtoSerializeError>; + fn deserialize(variant: u8, data: &[u8]) -> Result; } diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs index 4cca72b..16508cb 100644 --- a/proto/src/reader/oneway.rs +++ b/proto/src/reader/oneway.rs @@ -32,6 +32,6 @@ where let mut data = vec![0u8; (length - 4) as usize]; self.inner.read_exact(&mut data).await?; - T::deserialize(variant, &data) + Ok(T::deserialize(variant, &data)?) } } From 58c69928a155b0f6fce9b6fb994392a60e6be36c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Fri, 15 Dec 2023 15:58:11 +0100 Subject: [PATCH 15/30] refactor(proto): replace anyhow with thiserror in writers --- proto/src/writer/backend.rs | 5 +++-- proto/src/writer/errors.rs | 11 +++++++++++ proto/src/writer/mod.rs | 1 + proto/src/writer/oneway.rs | 5 +++-- proto/src/writer/protowriter.rs | 5 +++-- 5 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 proto/src/writer/errors.rs diff --git a/proto/src/writer/backend.rs b/proto/src/writer/backend.rs index 4cbbdc8..c4203ab 100644 --- a/proto/src/writer/backend.rs +++ b/proto/src/writer/backend.rs @@ -3,10 +3,11 @@ use crate::writer::oneway::OneWayProtoWriter; use crate::writer::protowriter::ProtoWriter; use async_trait::async_trait; use tokio::io::{AsyncWrite, AsyncWriteExt}; +use crate::writer::errors::ProtoWriteError; #[async_trait] pub trait BackendProtoWriter: OneWayProtoWriter { - async fn write_ssl_reject(&mut self) -> anyhow::Result<()>; + async fn write_ssl_reject(&mut self) -> Result<(), ProtoWriteError>; } #[async_trait] @@ -14,7 +15,7 @@ impl BackendProtoWriter for ProtoWriter where W: AsyncWrite + Unpin + Send, { - async fn write_ssl_reject(&mut self) -> anyhow::Result<()> { + async fn write_ssl_reject(&mut self) -> Result<(), ProtoWriteError> { self.inner.write_u8(b'N').await?; Ok(()) } diff --git a/proto/src/writer/errors.rs b/proto/src/writer/errors.rs new file mode 100644 index 0000000..f014a69 --- /dev/null +++ b/proto/src/writer/errors.rs @@ -0,0 +1,11 @@ +use thiserror::Error; +use tokio::io; +use crate::message::errors::ProtoSerializeError; + +#[derive(Debug, Error)] +pub enum ProtoWriteError { + #[error("writing to socket failed")] + Io(#[from] io::Error), + #[error("serialization of inner data failed")] + Serialize(#[from] ProtoSerializeError), +} diff --git a/proto/src/writer/mod.rs b/proto/src/writer/mod.rs index 3f16f05..f5cd408 100644 --- a/proto/src/writer/mod.rs +++ b/proto/src/writer/mod.rs @@ -2,3 +2,4 @@ pub mod backend; pub mod frontend; pub mod oneway; pub mod protowriter; +pub mod errors; diff --git a/proto/src/writer/oneway.rs b/proto/src/writer/oneway.rs index 1649fd8..17bb5ee 100644 --- a/proto/src/writer/oneway.rs +++ b/proto/src/writer/oneway.rs @@ -2,13 +2,14 @@ use crate::message::proto_message::ProtoMessage; use crate::writer::protowriter::ProtoWriter; use async_trait::async_trait; use tokio::io::{AsyncWrite, AsyncWriteExt}; +use crate::writer::errors::ProtoWriteError; #[async_trait] pub trait OneWayProtoWriter where T: ProtoMessage, { - async fn write_proto(&mut self, message: T) -> anyhow::Result<()>; + async fn write_proto(&mut self, message: T) -> Result<(), ProtoWriteError>; } #[async_trait] @@ -17,7 +18,7 @@ where W: AsyncWrite + Unpin + Send, T: ProtoMessage + Send + 'static, { - async fn write_proto(&mut self, message: T) -> anyhow::Result<()> { + async fn write_proto(&mut self, message: T) -> Result<(), ProtoWriteError> { let variant = message.variant(); let mut data = message.serialize()?; let length = data.len() as i32 + 4; diff --git a/proto/src/writer/protowriter.rs b/proto/src/writer/protowriter.rs index 848a727..27aa9e4 100644 --- a/proto/src/writer/protowriter.rs +++ b/proto/src/writer/protowriter.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use tokio::io; use tokio::io::{AsyncWrite, AsyncWriteExt}; pub struct ProtoWriter @@ -19,7 +20,7 @@ where #[async_trait] pub trait ProtoFlush { - async fn flush(&mut self) -> anyhow::Result<()>; + async fn flush(&mut self) -> Result<(), io::Error>; } #[async_trait] @@ -27,7 +28,7 @@ impl ProtoFlush for ProtoWriter where W: AsyncWrite + Unpin + Send, { - async fn flush(&mut self) -> anyhow::Result<()> { + async fn flush(&mut self) -> Result<(), io::Error> { self.inner.flush().await?; Ok(()) } From da6410ce051268b155b55fc17dc72fb7e9ac58f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Fri, 15 Dec 2023 16:21:51 +0100 Subject: [PATCH 16/30] refactor(proto): replace anyhow with thiserror in readers --- proto/src/reader/errors.rs | 42 ++++++++++++++++++++++++++++++++++ proto/src/reader/frontend.rs | 44 ++++++++++++++++++++++-------------- proto/src/reader/mod.rs | 1 + proto/src/reader/oneway.rs | 12 ++++++---- proto/src/reader/utils.rs | 4 ++-- 5 files changed, 80 insertions(+), 23 deletions(-) create mode 100644 proto/src/reader/errors.rs diff --git a/proto/src/reader/errors.rs b/proto/src/reader/errors.rs new file mode 100644 index 0000000..5ecd8ae --- /dev/null +++ b/proto/src/reader/errors.rs @@ -0,0 +1,42 @@ +use thiserror::Error; +use tokio::io; +use crate::message::errors::{ProtoDeserializeError}; + +#[derive(Debug, Error)] +pub enum ProtoReadError { + #[error("message has invalid length, got {0}")] + InvalidLength(i32), + #[error("message has too much data, got {actual}, limit is {limit}")] + LengthOverflow { + limit: usize, + actual: usize + }, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), +} + +#[derive(Debug, Error)] +pub enum ProtoPeekError { + #[error("message has too much data, got {actual}, limit is {limit}")] + LengthOverflow { + limit: usize, + actual: usize + }, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), +} + +#[derive(Debug, Error)] +pub enum ProtoConsumeError { + #[error("unexpected data length, expected {expected}, got {actual}")] + UnexpectedDataLength { + expected: usize, + actual: usize + }, + #[error("reading from socket failed")] + Io(#[from] io::Error), +} diff --git a/proto/src/reader/frontend.rs b/proto/src/reader/frontend.rs index fc6ceca..45cf0db 100644 --- a/proto/src/reader/frontend.rs +++ b/proto/src/reader/frontend.rs @@ -1,17 +1,18 @@ use crate::message::frontend::FrontendMessage; -use crate::message::primitive::config::pg_proto_config; -use crate::message::special::{CancelRequestData, SpecialMessage}; +use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData}; use crate::reader::oneway::OneWayProtoReader; use crate::reader::protoreader::ProtoReader; use crate::reader::utils::AsyncPeek; -use anyhow::anyhow; use async_trait::async_trait; +use tokio::io; use tokio::io::{AsyncBufRead, AsyncBufReadExt}; +use crate::message::primitive::data::MessageData; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; #[async_trait] pub trait FrontendProtoReader: OneWayProtoReader { - async fn peek_special_message(&mut self) -> anyhow::Result>; - async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>; + async fn peek_special_message(&mut self) -> Result, ProtoPeekError>; + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError>; } #[async_trait] @@ -19,7 +20,7 @@ impl FrontendProtoReader for ProtoReader where R: AsyncBufRead + Unpin + Send, { - async fn peek_special_message(&mut self) -> anyhow::Result> { + async fn peek_special_message(&mut self) -> Result, ProtoPeekError> { if let Some(cancel) = try_get_cancel_request(&mut self).await? { return Ok(Some(cancel)); } @@ -35,7 +36,7 @@ where Ok(None) } - async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> { + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError> { Ok(match msg { SpecialMessage::CancelRequest(_) => consume_cancel_request(self), SpecialMessage::SSLRequest => consume_ssl_request(self), @@ -46,7 +47,7 @@ where async fn try_get_cancel_request( reader: &mut ProtoReader, -) -> anyhow::Result> +) -> Result, io::Error> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { @@ -83,7 +84,7 @@ where async fn try_get_ssl_request( reader: &mut ProtoReader, -) -> anyhow::Result> +) -> Result, io::Error> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { @@ -114,7 +115,7 @@ where async fn try_get_startup_message( reader: &mut ProtoReader, -) -> anyhow::Result> +) -> Result, ProtoPeekError> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { @@ -128,7 +129,10 @@ where return Ok(None); } if length > reader.msg_len_limit { - return Err(anyhow!("Message length is over the limit")); + return Err(ProtoPeekError::LengthOverflow { + limit: reader.msg_len_limit as usize, + actual: length as usize, + }); } let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); @@ -142,23 +146,29 @@ where return Ok(None); } - let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0; - + let data = StartupMessageData::deserialize(&data[4..])?; Ok(Some(SpecialMessage::StartupMessage(data))) } -async fn consume_startup_message(reader: &mut ProtoReader) -> anyhow::Result<()> +async fn consume_startup_message(reader: &mut ProtoReader) -> Result<(), ProtoConsumeError> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { let mut header = [0u8; 4]; - if reader.inner.peek(&mut header).await? != 4 { - return Err(anyhow!("Invalid header peek length")); + let size = reader.inner.peek(&mut header).await?; + if size != 4 { + return Err(ProtoConsumeError::UnexpectedDataLength { + expected: 4, + actual: size + }) } let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize; if length < 8 { - return Err(anyhow!("Invalid startup message length")); + return Err(ProtoConsumeError::UnexpectedDataLength { + expected: 8, + actual: length + }) } reader.inner.consume(length); diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs index 3c85a0f..6f600c3 100644 --- a/proto/src/reader/mod.rs +++ b/proto/src/reader/mod.rs @@ -3,3 +3,4 @@ pub mod frontend; pub mod oneway; pub mod protoreader; mod utils; +pub mod errors; diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs index 16508cb..11937d7 100644 --- a/proto/src/reader/oneway.rs +++ b/proto/src/reader/oneway.rs @@ -3,13 +3,14 @@ use crate::reader::protoreader::ProtoReader; use crate::reader::utils::AsyncPeek; use async_trait::async_trait; use tokio::io::{AsyncBufRead, AsyncReadExt}; +use crate::reader::errors::ProtoReadError; #[async_trait] pub trait OneWayProtoReader where T: ProtoMessage, { - async fn read_proto(&mut self) -> anyhow::Result; + async fn read_proto(&mut self) -> Result; } #[async_trait] @@ -18,15 +19,18 @@ where R: AsyncBufRead + AsyncPeek + Unpin + Send, T: ProtoMessage, { - async fn read_proto(&mut self) -> anyhow::Result { + async fn read_proto(&mut self) -> Result { let variant = self.inner.read_u8().await?; let length = self.inner.read_i32().await?; if length < 4 { - return Err(anyhow::anyhow!("Invalid message length")); + return Err(ProtoReadError::InvalidLength(length)); } if length > self.msg_len_limit { - return Err(anyhow::anyhow!("Message length over limit")); + return Err(ProtoReadError::LengthOverflow { + limit: self.msg_len_limit as usize, + actual: length as usize, + }); } let mut data = vec![0u8; (length - 4) as usize]; diff --git a/proto/src/reader/utils.rs b/proto/src/reader/utils.rs index e4a70eb..0ca8f85 100644 --- a/proto/src/reader/utils.rs +++ b/proto/src/reader/utils.rs @@ -3,7 +3,7 @@ use tokio::io::{AsyncBufRead, AsyncBufReadExt}; #[async_trait] pub trait AsyncPeek { - async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result; + async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result; } #[async_trait] @@ -11,7 +11,7 @@ impl AsyncPeek for T where T: AsyncBufRead + Unpin + Send, { - async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result { + async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result { let filled = self.fill_buf().await?; if filled.len() >= buf.len() { buf.copy_from_slice(&filled[..buf.len()]); From 165f871324648b95bdf7a88594943edf5cf7b900 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Fri, 15 Dec 2023 16:31:10 +0100 Subject: [PATCH 17/30] refactor(proto): replace anyhow with thiserror in handshake --- proto/src/handshake/errors.rs | 21 +++++++++++++++++++++ proto/src/handshake/mod.rs | 1 + proto/src/handshake/server.rs | 5 +++-- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 proto/src/handshake/errors.rs diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs new file mode 100644 index 0000000..90565ac --- /dev/null +++ b/proto/src/handshake/errors.rs @@ -0,0 +1,21 @@ +use thiserror::Error; +use tokio::io; +use crate::message::errors::ProtoDeserializeError; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; +use crate::writer::errors::ProtoWriteError; + +#[derive(Debug, Error)] +pub enum ServerHandshakeError { + #[error("startup message not found")] + MissingStartupMessage, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), + #[error("peeking special message failed")] + Peek(#[from] ProtoPeekError), + #[error("consuming special message failed")] + Consume(#[from] ProtoConsumeError), + #[error("writing message to socket failed")] + Write(#[from] ProtoWriteError), +} diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs index 74f47ad..24a7408 100644 --- a/proto/src/handshake/mod.rs +++ b/proto/src/handshake/mod.rs @@ -1 +1,2 @@ pub mod server; +pub mod errors; diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index 660ea0c..b998c66 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -1,3 +1,4 @@ +use crate::handshake::errors::ServerHandshakeError; use crate::message::backend::{ AuthenticationOkData, BackendKeyDataData, BackendMessage, ParameterStatusData, ReadyForQueryData, @@ -13,7 +14,7 @@ pub async fn do_server_handshake( name: &str, process: i32, secret: i32, -) -> anyhow::Result { +) -> Result { match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::SSLRequest) => { reader.consume_special_message(msg).await?; @@ -31,7 +32,7 @@ pub async fn do_server_handshake( data.clone() } _ => { - return Err(anyhow::anyhow!("Expected Startup Message")); + return Err(ServerHandshakeError::MissingStartupMessage); } }; From b97f23764fb528ccf752494d13b7975bae00902e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Fri, 15 Dec 2023 16:32:02 +0100 Subject: [PATCH 18/30] refactor(proto): remove anyhow dependency --- proto/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/proto/Cargo.toml b/proto/Cargo.toml index 9fb51ce..c8c9134 100644 --- a/proto/Cargo.toml +++ b/proto/Cargo.toml @@ -5,7 +5,6 @@ edition = "2021" [dependencies] bincode = "2.0.0-rc.3" -anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["io-util", "macros", "test-util"] } async-trait = "0.1.74" thiserror = "1.0.50" From 7185c10979aa7fa46aa2cebd378670b37c31c921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Fri, 15 Dec 2023 16:41:02 +0100 Subject: [PATCH 19/30] feat(proto): add proto crate to workspace --- .gitignore | 1 + Cargo.lock | 243 +++++++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + 3 files changed, 245 insertions(+) diff --git a/.gitignore b/.gitignore index 8cf2bff..e28629e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +.idea /target tmp_repl.txt diff --git a/Cargo.lock b/Cargo.lock index e87788c..c064b5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,249 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "bincode" +version = "2.0.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f11ea1a0346b94ef188834a65c068a03aec181c94896d481d7a0a40d85b0ce95" +dependencies = [ + "bincode_derive", + "serde", +] + +[[package]] +name = "bincode_derive" +version = "2.0.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e30759b3b99a1b802a7a3aa21c85c3ded5c28e1c83170d82d70f08bbf7f3e4c" +dependencies = [ + "virtue", +] + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] +name = "libc" +version = "0.2.151" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + [[package]] name = "minisql" version = "0.1.0" + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "proc-macro2" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proto" +version = "0.1.0" +dependencies = [ + "async-trait", + "bincode", + "thiserror", + "tokio", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "syn" +version = "2.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio" +version = "1.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" +dependencies = [ + "backtrace", + "bytes", + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "virtue" +version = "0.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" diff --git a/Cargo.toml b/Cargo.toml index 3e0b7c2..18f1651 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,4 +2,5 @@ resolver = "2" members = [ "minisql", + "proto" ] From df4c4166d94ef729ebcc73935db237b063eadeaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Fri, 22 Dec 2023 23:56:52 +0100 Subject: [PATCH 20/30] feat(proto): add methods for sending startup and cancel messages --- proto/src/writer/frontend.rs | 76 +++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/proto/src/writer/frontend.rs b/proto/src/writer/frontend.rs index ba634cc..286f1ba 100644 --- a/proto/src/writer/frontend.rs +++ b/proto/src/writer/frontend.rs @@ -1,12 +1,43 @@ use crate::message::frontend::FrontendMessage; use crate::writer::oneway::OneWayProtoWriter; use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use crate::message::primitive::data::MessageData; +use crate::message::special::{CancelRequestData, StartupMessageData}; +use crate::writer::errors::ProtoWriteError; +use crate::writer::protowriter::ProtoWriter; #[async_trait] -pub trait FrontendProtoWriter: OneWayProtoWriter {} +pub trait FrontendProtoWriter: OneWayProtoWriter { + async fn write_startup_message(&mut self, startup_message: StartupMessageData) -> Result<(), ProtoWriteError>; + async fn write_cancel_request(&mut self, cancel_request: CancelRequestData) -> Result<(), ProtoWriteError>; +} #[async_trait] -impl FrontendProtoWriter for W where W: OneWayProtoWriter {} +impl FrontendProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send +{ + async fn write_startup_message(&mut self, startup_message: StartupMessageData) -> Result<(), ProtoWriteError> { + let data = startup_message.serialize()?; + let length = data.len() + 4; + + self.inner.write_i32(length as i32).await?; + self.inner.write_all(&data).await?; + + Ok(()) + } + + async fn write_cancel_request(&mut self, cancel_request: CancelRequestData) -> Result<(), ProtoWriteError> { + let data = cancel_request.serialize()?; + let length = data.len() + 4; + + self.inner.write_i32(length as i32).await?; + self.inner.write_all(&data).await?; + + Ok(()) + } +} #[cfg(test)] mod tests { @@ -37,4 +68,45 @@ mod tests { vec![b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4] ); } + + #[tokio::test] + async fn test_startup_message() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer.write_startup_message(StartupMessageData { + version: 196608, + params: vec![ + ("user".into(), "postgres".into()), + ("database".into(), "postgres".into()), + ], + }).await.unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![ + 0, 0, 0, 40, 0, 3, 0, 0, b'u', b's', b'e', b'r', 0, b'p', b'o', + b's', b't', b'g', b'r', b'e', b's', 0, b'd', b'a', b't', b'a', b'b', b'a', b's', + b'e', 0, b'p', b'o', b's', b't', b'g', b'r', b'e', b's', 0 + ] + ); + } + + #[tokio::test] + async fn test_cancel_request() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer.write_cancel_request(CancelRequestData { + pid: 123, + secret: 234, + }).await.unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![ + 0, 0, 0, 12, 0, 0, 0, 123, 0, 0, 0, 234 + ] + ); + } } From 7b2dce4dfb296cf88cf5fa4db39dab145986b1ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 00:29:40 +0100 Subject: [PATCH 21/30] feat(proto): add client handshake implementation --- proto/src/handshake/client/error.rs | 16 +++++++++++ proto/src/handshake/client/mod.rs | 4 +++ proto/src/handshake/client/request.rs | 24 +++++++++++++++++ proto/src/handshake/client/response.rs | 37 ++++++++++++++++++++++++++ proto/src/handshake/client/shaker.rs | 34 +++++++++++++++++++++++ proto/src/handshake/mod.rs | 1 + 6 files changed, 116 insertions(+) create mode 100644 proto/src/handshake/client/error.rs create mode 100644 proto/src/handshake/client/mod.rs create mode 100644 proto/src/handshake/client/request.rs create mode 100644 proto/src/handshake/client/response.rs create mode 100644 proto/src/handshake/client/shaker.rs diff --git a/proto/src/handshake/client/error.rs b/proto/src/handshake/client/error.rs new file mode 100644 index 0000000..ab6f23c --- /dev/null +++ b/proto/src/handshake/client/error.rs @@ -0,0 +1,16 @@ +use thiserror::Error; +use crate::message::backend::BackendMessage; +use crate::reader::errors::ProtoReadError; +use crate::writer::errors::ProtoWriteError; + +#[derive(Debug, Error)] +pub enum ClientHandshakeError { + #[error("unexpected response from server")] + UnexpectedResponse, + #[error("unexpected auth response")] + UnexpectedAuthResponse(BackendMessage), + #[error("writing message to socket failed")] + Write(#[from] ProtoWriteError), + #[error("reading message from socket failed")] + Read(#[from] ProtoReadError), +} \ No newline at end of file diff --git a/proto/src/handshake/client/mod.rs b/proto/src/handshake/client/mod.rs new file mode 100644 index 0000000..060973a --- /dev/null +++ b/proto/src/handshake/client/mod.rs @@ -0,0 +1,4 @@ +pub mod request; +pub mod response; +pub mod shaker; +pub mod error; \ No newline at end of file diff --git a/proto/src/handshake/client/request.rs b/proto/src/handshake/client/request.rs new file mode 100644 index 0000000..54998b5 --- /dev/null +++ b/proto/src/handshake/client/request.rs @@ -0,0 +1,24 @@ +use crate::message::primitive::pgstring::PgString; +use crate::message::special::StartupMessageData; + +pub struct ClientHandshakeRequest { + version: i32, + parameters: Vec<(PgString, PgString)>, +} + +impl ClientHandshakeRequest { + pub fn new(version: i32) -> Self { + Self { version, parameters: Vec::new() } + } + + pub fn parameter(mut self, key: &str, value: &str) -> Self { + self.parameters.push((key.into(), value.into())); + self + } +} + +impl From for StartupMessageData { + fn from(request: ClientHandshakeRequest) -> Self { + Self { version: request.version, params: request.parameters } + } +} diff --git a/proto/src/handshake/client/response.rs b/proto/src/handshake/client/response.rs new file mode 100644 index 0000000..777cb81 --- /dev/null +++ b/proto/src/handshake/client/response.rs @@ -0,0 +1,37 @@ +use crate::handshake::client::error::ClientHandshakeError; +use crate::message::backend::BackendMessage; + +pub struct ClientHandshakeResponse { + pub version: String, + pub process_id: i32, + pub secret_key: i32, +} + +impl ClientHandshakeResponse { + pub fn from_backend_messages(message: &[BackendMessage]) -> Result { + let mut version = None; + let mut process_id = None; + let mut secret_key = None; + for message in message { + match message { + BackendMessage::ParameterStatus(data) => { + if data.name.as_str() == "server_version" { + version = Some(String::from(data.value.as_str())); + } + } + BackendMessage::BackendKeyData(data) => { + process_id = Some(data.process); + secret_key = Some(data.secret); + } + _ => {} + } + } + + match (version, process_id, secret_key) { + (Some(version), Some(process_id), Some(secret_key)) => { + Ok(Self { version, process_id, secret_key }) + } + _ => Err(ClientHandshakeError::UnexpectedResponse), + } + } +} \ No newline at end of file diff --git a/proto/src/handshake/client/shaker.rs b/proto/src/handshake/client/shaker.rs new file mode 100644 index 0000000..893f0b7 --- /dev/null +++ b/proto/src/handshake/client/shaker.rs @@ -0,0 +1,34 @@ +use crate::handshake::client::error::ClientHandshakeError; +use crate::handshake::client::request::ClientHandshakeRequest; +use crate::handshake::client::response::ClientHandshakeResponse; +use crate::message::backend::{AuthenticationOkData, BackendMessage}; +use crate::message::special::StartupMessageData; +use crate::reader::backend::BackendProtoReader; +use crate::writer::frontend::FrontendProtoWriter; +use crate::writer::protowriter::ProtoFlush; + +pub async fn do_client_handshake( + writer: &mut (impl FrontendProtoWriter + ProtoFlush), + reader: &mut impl BackendProtoReader, + request: ClientHandshakeRequest, +) -> Result { + let startup_message: StartupMessageData = request.into(); + writer.write_startup_message(startup_message).await?; + + let auth = reader.read_proto().await?; + if !matches!(auth, BackendMessage::AuthenticationOk(AuthenticationOkData { status: 0 })) { + return Err(ClientHandshakeError::UnexpectedAuthResponse(auth)); + } + + let mut messages = Vec::new(); + loop { + let msg = reader.read_proto().await?; + if matches!(msg, BackendMessage::ReadyForQuery(_)) { + break; + } + + messages.push(msg); + } + + ClientHandshakeResponse::from_backend_messages(&messages) +} \ No newline at end of file diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs index 24a7408..0b12ab2 100644 --- a/proto/src/handshake/mod.rs +++ b/proto/src/handshake/mod.rs @@ -1,2 +1,3 @@ pub mod server; pub mod errors; +pub mod client; From c1744711d30d8b495ae4568621d35c6d90668f7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 00:52:53 +0100 Subject: [PATCH 22/30] refactor(proto): reuse code in handshakes --- .../handshake/{client/shaker.rs => client.rs} | 12 ++-- proto/src/handshake/client/error.rs | 16 ----- proto/src/handshake/client/mod.rs | 4 -- proto/src/handshake/client/response.rs | 37 ----------- proto/src/handshake/errors.rs | 19 +++++- proto/src/handshake/mod.rs | 6 +- proto/src/handshake/{client => }/request.rs | 14 ++-- proto/src/handshake/response.rs | 65 +++++++++++++++++++ proto/src/handshake/server.rs | 26 +++----- 9 files changed, 111 insertions(+), 88 deletions(-) rename proto/src/handshake/{client/shaker.rs => client.rs} (73%) delete mode 100644 proto/src/handshake/client/error.rs delete mode 100644 proto/src/handshake/client/mod.rs delete mode 100644 proto/src/handshake/client/response.rs rename proto/src/handshake/{client => }/request.rs (59%) create mode 100644 proto/src/handshake/response.rs diff --git a/proto/src/handshake/client/shaker.rs b/proto/src/handshake/client.rs similarity index 73% rename from proto/src/handshake/client/shaker.rs rename to proto/src/handshake/client.rs index 893f0b7..9226b2a 100644 --- a/proto/src/handshake/client/shaker.rs +++ b/proto/src/handshake/client.rs @@ -1,6 +1,6 @@ -use crate::handshake::client::error::ClientHandshakeError; -use crate::handshake::client::request::ClientHandshakeRequest; -use crate::handshake::client::response::ClientHandshakeResponse; +use crate::handshake::errors::ClientHandshakeError; +use crate::handshake::request::HandshakeRequest; +use crate::handshake::response::HandshakeResponse; use crate::message::backend::{AuthenticationOkData, BackendMessage}; use crate::message::special::StartupMessageData; use crate::reader::backend::BackendProtoReader; @@ -10,8 +10,8 @@ use crate::writer::protowriter::ProtoFlush; pub async fn do_client_handshake( writer: &mut (impl FrontendProtoWriter + ProtoFlush), reader: &mut impl BackendProtoReader, - request: ClientHandshakeRequest, -) -> Result { + request: HandshakeRequest, +) -> Result { let startup_message: StartupMessageData = request.into(); writer.write_startup_message(startup_message).await?; @@ -30,5 +30,5 @@ pub async fn do_client_handshake( messages.push(msg); } - ClientHandshakeResponse::from_backend_messages(&messages) + HandshakeResponse::try_from(messages.as_slice()) } \ No newline at end of file diff --git a/proto/src/handshake/client/error.rs b/proto/src/handshake/client/error.rs deleted file mode 100644 index ab6f23c..0000000 --- a/proto/src/handshake/client/error.rs +++ /dev/null @@ -1,16 +0,0 @@ -use thiserror::Error; -use crate::message::backend::BackendMessage; -use crate::reader::errors::ProtoReadError; -use crate::writer::errors::ProtoWriteError; - -#[derive(Debug, Error)] -pub enum ClientHandshakeError { - #[error("unexpected response from server")] - UnexpectedResponse, - #[error("unexpected auth response")] - UnexpectedAuthResponse(BackendMessage), - #[error("writing message to socket failed")] - Write(#[from] ProtoWriteError), - #[error("reading message from socket failed")] - Read(#[from] ProtoReadError), -} \ No newline at end of file diff --git a/proto/src/handshake/client/mod.rs b/proto/src/handshake/client/mod.rs deleted file mode 100644 index 060973a..0000000 --- a/proto/src/handshake/client/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod request; -pub mod response; -pub mod shaker; -pub mod error; \ No newline at end of file diff --git a/proto/src/handshake/client/response.rs b/proto/src/handshake/client/response.rs deleted file mode 100644 index 777cb81..0000000 --- a/proto/src/handshake/client/response.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::handshake::client::error::ClientHandshakeError; -use crate::message::backend::BackendMessage; - -pub struct ClientHandshakeResponse { - pub version: String, - pub process_id: i32, - pub secret_key: i32, -} - -impl ClientHandshakeResponse { - pub fn from_backend_messages(message: &[BackendMessage]) -> Result { - let mut version = None; - let mut process_id = None; - let mut secret_key = None; - for message in message { - match message { - BackendMessage::ParameterStatus(data) => { - if data.name.as_str() == "server_version" { - version = Some(String::from(data.value.as_str())); - } - } - BackendMessage::BackendKeyData(data) => { - process_id = Some(data.process); - secret_key = Some(data.secret); - } - _ => {} - } - } - - match (version, process_id, secret_key) { - (Some(version), Some(process_id), Some(secret_key)) => { - Ok(Self { version, process_id, secret_key }) - } - _ => Err(ClientHandshakeError::UnexpectedResponse), - } - } -} \ No newline at end of file diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs index 90565ac..d2bc1bd 100644 --- a/proto/src/handshake/errors.rs +++ b/proto/src/handshake/errors.rs @@ -1,14 +1,27 @@ use thiserror::Error; use tokio::io; +use crate::message::backend::BackendMessage; use crate::message::errors::ProtoDeserializeError; -use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; use crate::writer::errors::ProtoWriteError; +#[derive(Debug, Error)] +pub enum ClientHandshakeError { + #[error("unexpected response from server")] + UnexpectedResponse, + #[error("unexpected auth response")] + UnexpectedAuthResponse(BackendMessage), + #[error("writing message to socket failed")] + Write(#[from] ProtoWriteError), + #[error("reading message from socket failed")] + Read(#[from] ProtoReadError), +} + #[derive(Debug, Error)] pub enum ServerHandshakeError { #[error("startup message not found")] MissingStartupMessage, - #[error("reading from socket failed")] + #[error("socket communication failed")] Io(#[from] io::Error), #[error("deserialization of inner data failed")] Deserialize(#[from] ProtoDeserializeError), @@ -18,4 +31,4 @@ pub enum ServerHandshakeError { Consume(#[from] ProtoConsumeError), #[error("writing message to socket failed")] Write(#[from] ProtoWriteError), -} +} \ No newline at end of file diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs index 0b12ab2..86f92ab 100644 --- a/proto/src/handshake/mod.rs +++ b/proto/src/handshake/mod.rs @@ -1,3 +1,5 @@ -pub mod server; -pub mod errors; +pub mod response; +pub mod request; pub mod client; +pub mod server; +pub mod errors; \ No newline at end of file diff --git a/proto/src/handshake/client/request.rs b/proto/src/handshake/request.rs similarity index 59% rename from proto/src/handshake/client/request.rs rename to proto/src/handshake/request.rs index 54998b5..ec3e123 100644 --- a/proto/src/handshake/client/request.rs +++ b/proto/src/handshake/request.rs @@ -1,12 +1,12 @@ use crate::message::primitive::pgstring::PgString; use crate::message::special::StartupMessageData; -pub struct ClientHandshakeRequest { +pub struct HandshakeRequest { version: i32, parameters: Vec<(PgString, PgString)>, } -impl ClientHandshakeRequest { +impl HandshakeRequest { pub fn new(version: i32) -> Self { Self { version, parameters: Vec::new() } } @@ -17,8 +17,14 @@ impl ClientHandshakeRequest { } } -impl From for StartupMessageData { - fn from(request: ClientHandshakeRequest) -> Self { +impl From for StartupMessageData { + fn from(request: HandshakeRequest) -> Self { Self { version: request.version, params: request.parameters } } } + +impl From for HandshakeRequest { + fn from(data: StartupMessageData) -> Self { + Self { version: data.version, parameters: data.params } + } +} \ No newline at end of file diff --git a/proto/src/handshake/response.rs b/proto/src/handshake/response.rs new file mode 100644 index 0000000..84a908f --- /dev/null +++ b/proto/src/handshake/response.rs @@ -0,0 +1,65 @@ +use crate::handshake::errors::ClientHandshakeError; +use crate::message::backend::{BackendKeyDataData, BackendMessage, ParameterStatusData}; + +pub struct HandshakeResponse { + pub version: String, + pub process_id: i32, + pub secret_key: i32, +} + +impl HandshakeResponse { + pub fn new(name: &str, pid: i32, key: i32) -> Self { + Self { + version: format!("16.0 ({name})", name = name), + process_id: pid, + secret_key: key, + } + } +} + +impl TryFrom<&[BackendMessage]> for HandshakeResponse { + type Error = ClientHandshakeError; + + fn try_from(messages: &[BackendMessage]) -> Result { + let mut version = None; + let mut process_id = None; + let mut secret_key = None; + + for message in messages { + match message { + BackendMessage::ParameterStatus(data) => { + if data.name.as_str() == "server_version" { + version = Some(String::from(data.value.as_str())); + } + } + BackendMessage::BackendKeyData(data) => { + process_id = Some(data.process); + secret_key = Some(data.secret); + } + _ => {} + } + } + + match (version, process_id, secret_key) { + (Some(version), Some(process_id), Some(secret_key)) => { + Ok(Self { version, process_id, secret_key }) + } + _ => Err(ClientHandshakeError::UnexpectedResponse), + } + } +} + +impl From<&HandshakeResponse> for Vec { + fn from(response: &HandshakeResponse) -> Self { + vec![ + BackendMessage::ParameterStatus(ParameterStatusData { + name: "server_version".into(), + value: response.version.clone().into(), + }), + BackendMessage::BackendKeyData(BackendKeyDataData { + process: response.process_id, + secret: response.secret_key, + }), + ] + } +} \ No newline at end of file diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index b998c66..127407c 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -1,6 +1,8 @@ use crate::handshake::errors::ServerHandshakeError; +use crate::handshake::request::HandshakeRequest; +use crate::handshake::response::HandshakeResponse; use crate::message::backend::{ - AuthenticationOkData, BackendKeyDataData, BackendMessage, ParameterStatusData, + AuthenticationOkData, BackendMessage, ReadyForQueryData, }; use crate::message::special::{SpecialMessage, StartupMessageData}; @@ -11,10 +13,8 @@ use crate::writer::protowriter::ProtoFlush; pub async fn do_server_handshake( writer: &mut (impl BackendProtoWriter + ProtoFlush), reader: &mut impl FrontendProtoReader, - name: &str, - process: i32, - secret: i32, -) -> Result { + response: &HandshakeResponse, +) -> Result { match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::SSLRequest) => { reader.consume_special_message(msg).await?; @@ -40,21 +40,15 @@ pub async fn do_server_handshake( .write_proto(BackendMessage::from(AuthenticationOkData { status: 0 })) .await?; - writer - .write_proto(BackendMessage::from(ParameterStatusData { - name: "server_version".to_string().into(), - value: format!("16.0 ({name})").into(), - })) - .await?; - - writer - .write_proto(BackendMessage::from(BackendKeyDataData { process, secret })) - .await?; + let messages: Vec = response.into(); + for message in messages { + writer.write_proto(message).await?; + } writer .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) .await?; writer.flush().await?; - Ok(startup_message) + Ok(startup_message.into()) } From 505f59b3549a14842a95e164848b8652c98f34ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 01:27:15 +0100 Subject: [PATCH 23/30] fix(proto): flush written startup message --- proto/src/handshake/client.rs | 1 + proto/src/handshake/errors.rs | 2 ++ proto/src/handshake/request.rs | 1 + proto/src/handshake/response.rs | 17 ++++++++++------- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/proto/src/handshake/client.rs b/proto/src/handshake/client.rs index 9226b2a..aa6577b 100644 --- a/proto/src/handshake/client.rs +++ b/proto/src/handshake/client.rs @@ -14,6 +14,7 @@ pub async fn do_client_handshake( ) -> Result { let startup_message: StartupMessageData = request.into(); writer.write_startup_message(startup_message).await?; + writer.flush().await?; let auth = reader.read_proto().await?; if !matches!(auth, BackendMessage::AuthenticationOk(AuthenticationOkData { status: 0 })) { diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs index d2bc1bd..b5a31ed 100644 --- a/proto/src/handshake/errors.rs +++ b/proto/src/handshake/errors.rs @@ -11,6 +11,8 @@ pub enum ClientHandshakeError { UnexpectedResponse, #[error("unexpected auth response")] UnexpectedAuthResponse(BackendMessage), + #[error("socket communication failed")] + Io(#[from] io::Error), #[error("writing message to socket failed")] Write(#[from] ProtoWriteError), #[error("reading message from socket failed")] diff --git a/proto/src/handshake/request.rs b/proto/src/handshake/request.rs index ec3e123..1fc097f 100644 --- a/proto/src/handshake/request.rs +++ b/proto/src/handshake/request.rs @@ -1,6 +1,7 @@ use crate::message::primitive::pgstring::PgString; use crate::message::special::StartupMessageData; +#[derive(Debug)] pub struct HandshakeRequest { version: i32, parameters: Vec<(PgString, PgString)>, diff --git a/proto/src/handshake/response.rs b/proto/src/handshake/response.rs index 84a908f..38579b6 100644 --- a/proto/src/handshake/response.rs +++ b/proto/src/handshake/response.rs @@ -1,6 +1,7 @@ use crate::handshake::errors::ClientHandshakeError; use crate::message::backend::{BackendKeyDataData, BackendMessage, ParameterStatusData}; +#[derive(Debug)] pub struct HandshakeResponse { pub version: String, pub process_id: i32, @@ -41,20 +42,22 @@ impl TryFrom<&[BackendMessage]> for HandshakeResponse { } match (version, process_id, secret_key) { - (Some(version), Some(process_id), Some(secret_key)) => { - Ok(Self { version, process_id, secret_key }) - } + (Some(version), Some(process_id), Some(secret_key)) => Ok(Self { + version, + process_id, + secret_key, + }), _ => Err(ClientHandshakeError::UnexpectedResponse), } } } -impl From<&HandshakeResponse> for Vec { - fn from(response: &HandshakeResponse) -> Self { +impl From for Vec { + fn from(response: HandshakeResponse) -> Self { vec![ BackendMessage::ParameterStatus(ParameterStatusData { name: "server_version".into(), - value: response.version.clone().into(), + value: response.version.into(), }), BackendMessage::BackendKeyData(BackendKeyDataData { process: response.process_id, @@ -62,4 +65,4 @@ impl From<&HandshakeResponse> for Vec { }), ] } -} \ No newline at end of file +} From a08376766cd5ffb898eeb9091f309313e50a54e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 01:28:30 +0100 Subject: [PATCH 24/30] chore(proto): crate formatting --- proto/src/handshake/client.rs | 7 ++- proto/src/handshake/errors.rs | 6 +-- proto/src/handshake/mod.rs | 6 +-- proto/src/handshake/request.rs | 17 ++++++-- proto/src/handshake/server.rs | 5 +-- proto/src/message/backend.rs | 7 ++- proto/src/message/frontend.rs | 9 ++-- proto/src/message/mod.rs | 2 +- proto/src/message/primitive/data.rs | 2 +- proto/src/reader/errors.rs | 17 ++------ proto/src/reader/frontend.rs | 22 ++++++---- proto/src/reader/mod.rs | 2 +- proto/src/reader/oneway.rs | 2 +- proto/src/writer/backend.rs | 2 +- proto/src/writer/errors.rs | 2 +- proto/src/writer/frontend.rs | 66 ++++++++++++++++++----------- proto/src/writer/mod.rs | 2 +- proto/src/writer/oneway.rs | 2 +- 18 files changed, 103 insertions(+), 75 deletions(-) diff --git a/proto/src/handshake/client.rs b/proto/src/handshake/client.rs index aa6577b..2deb5bc 100644 --- a/proto/src/handshake/client.rs +++ b/proto/src/handshake/client.rs @@ -17,7 +17,10 @@ pub async fn do_client_handshake( writer.flush().await?; let auth = reader.read_proto().await?; - if !matches!(auth, BackendMessage::AuthenticationOk(AuthenticationOkData { status: 0 })) { + if !matches!( + auth, + BackendMessage::AuthenticationOk(AuthenticationOkData { status: 0 }) + ) { return Err(ClientHandshakeError::UnexpectedAuthResponse(auth)); } @@ -32,4 +35,4 @@ pub async fn do_client_handshake( } HandshakeResponse::try_from(messages.as_slice()) -} \ No newline at end of file +} diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs index b5a31ed..0811790 100644 --- a/proto/src/handshake/errors.rs +++ b/proto/src/handshake/errors.rs @@ -1,9 +1,9 @@ -use thiserror::Error; -use tokio::io; use crate::message::backend::BackendMessage; use crate::message::errors::ProtoDeserializeError; use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; use crate::writer::errors::ProtoWriteError; +use thiserror::Error; +use tokio::io; #[derive(Debug, Error)] pub enum ClientHandshakeError { @@ -33,4 +33,4 @@ pub enum ServerHandshakeError { Consume(#[from] ProtoConsumeError), #[error("writing message to socket failed")] Write(#[from] ProtoWriteError), -} \ No newline at end of file +} diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs index 86f92ab..61e9c24 100644 --- a/proto/src/handshake/mod.rs +++ b/proto/src/handshake/mod.rs @@ -1,5 +1,5 @@ -pub mod response; -pub mod request; pub mod client; +pub mod errors; +pub mod request; +pub mod response; pub mod server; -pub mod errors; \ No newline at end of file diff --git a/proto/src/handshake/request.rs b/proto/src/handshake/request.rs index 1fc097f..408238f 100644 --- a/proto/src/handshake/request.rs +++ b/proto/src/handshake/request.rs @@ -9,7 +9,10 @@ pub struct HandshakeRequest { impl HandshakeRequest { pub fn new(version: i32) -> Self { - Self { version, parameters: Vec::new() } + Self { + version, + parameters: Vec::new(), + } } pub fn parameter(mut self, key: &str, value: &str) -> Self { @@ -20,12 +23,18 @@ impl HandshakeRequest { impl From for StartupMessageData { fn from(request: HandshakeRequest) -> Self { - Self { version: request.version, params: request.parameters } + Self { + version: request.version, + params: request.parameters, + } } } impl From for HandshakeRequest { fn from(data: StartupMessageData) -> Self { - Self { version: data.version, parameters: data.params } + Self { + version: data.version, + parameters: data.params, + } } -} \ No newline at end of file +} diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index 127407c..e9e6087 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -1,10 +1,7 @@ use crate::handshake::errors::ServerHandshakeError; use crate::handshake::request::HandshakeRequest; use crate::handshake::response::HandshakeResponse; -use crate::message::backend::{ - AuthenticationOkData, BackendMessage, - ReadyForQueryData, -}; +use crate::message::backend::{AuthenticationOkData, BackendMessage, ReadyForQueryData}; use crate::message::special::{SpecialMessage, StartupMessageData}; use crate::reader::frontend::FrontendProtoReader; use crate::writer::backend::BackendProtoWriter; diff --git a/proto/src/message/backend.rs b/proto/src/message/backend.rs index 7ba0004..0aa0e20 100644 --- a/proto/src/message/backend.rs +++ b/proto/src/message/backend.rs @@ -1,9 +1,9 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; use crate::message::primitive::data::MessageData; use crate::message::primitive::pglist::PgList; use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; -use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; #[derive(Debug)] pub enum BackendMessage { @@ -362,6 +362,9 @@ mod tests { let data = vec![1, 2, 3]; let message = BackendMessage::deserialize(variant, &data); - assert!(matches!(message, Err(ProtoDeserializeError::InvalidVariant(0)))); + assert!(matches!( + message, + Err(ProtoDeserializeError::InvalidVariant(0)) + )); } } diff --git a/proto/src/message/frontend.rs b/proto/src/message/frontend.rs index 6d7f993..178be7b 100644 --- a/proto/src/message/frontend.rs +++ b/proto/src/message/frontend.rs @@ -1,8 +1,8 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; use crate::message::primitive::data::MessageData; use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; -use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; #[derive(Debug)] pub enum FrontendMessage { @@ -41,8 +41,8 @@ pub struct QueryData { #[cfg(test)] mod tests { - use crate::message::backend::BackendMessage; use super::*; + use crate::message::backend::BackendMessage; #[test] fn test_symmetric_query() { @@ -74,6 +74,9 @@ mod tests { let data = vec![1, 2, 3]; let message = BackendMessage::deserialize(variant, &data); - assert!(matches!(message, Err(ProtoDeserializeError::InvalidVariant(0)))); + assert!(matches!( + message, + Err(ProtoDeserializeError::InvalidVariant(0)) + )); } } diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 2c11140..0d8130c 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -1,6 +1,6 @@ pub mod backend; +pub mod errors; pub mod frontend; pub mod primitive; pub mod proto_message; pub mod special; -pub mod errors; diff --git a/proto/src/message/primitive/data.rs b/proto/src/message/primitive/data.rs index 37b491a..4df50ae 100644 --- a/proto/src/message/primitive/data.rs +++ b/proto/src/message/primitive/data.rs @@ -1,6 +1,6 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; use crate::message::primitive::config::pg_proto_config; use bincode::{Decode, Encode}; -use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; pub trait MessageData: Sized { fn serialize(&self) -> Result, ProtoSerializeError>; diff --git a/proto/src/reader/errors.rs b/proto/src/reader/errors.rs index 5ecd8ae..78b138d 100644 --- a/proto/src/reader/errors.rs +++ b/proto/src/reader/errors.rs @@ -1,16 +1,13 @@ +use crate::message::errors::ProtoDeserializeError; use thiserror::Error; use tokio::io; -use crate::message::errors::{ProtoDeserializeError}; #[derive(Debug, Error)] pub enum ProtoReadError { #[error("message has invalid length, got {0}")] InvalidLength(i32), #[error("message has too much data, got {actual}, limit is {limit}")] - LengthOverflow { - limit: usize, - actual: usize - }, + LengthOverflow { limit: usize, actual: usize }, #[error("reading from socket failed")] Io(#[from] io::Error), #[error("deserialization of inner data failed")] @@ -20,10 +17,7 @@ pub enum ProtoReadError { #[derive(Debug, Error)] pub enum ProtoPeekError { #[error("message has too much data, got {actual}, limit is {limit}")] - LengthOverflow { - limit: usize, - actual: usize - }, + LengthOverflow { limit: usize, actual: usize }, #[error("reading from socket failed")] Io(#[from] io::Error), #[error("deserialization of inner data failed")] @@ -33,10 +27,7 @@ pub enum ProtoPeekError { #[derive(Debug, Error)] pub enum ProtoConsumeError { #[error("unexpected data length, expected {expected}, got {actual}")] - UnexpectedDataLength { - expected: usize, - actual: usize - }, + UnexpectedDataLength { expected: usize, actual: usize }, #[error("reading from socket failed")] Io(#[from] io::Error), } diff --git a/proto/src/reader/frontend.rs b/proto/src/reader/frontend.rs index 45cf0db..bb2ffc8 100644 --- a/proto/src/reader/frontend.rs +++ b/proto/src/reader/frontend.rs @@ -1,18 +1,21 @@ use crate::message::frontend::FrontendMessage; +use crate::message::primitive::data::MessageData; use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData}; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; use crate::reader::oneway::OneWayProtoReader; use crate::reader::protoreader::ProtoReader; use crate::reader::utils::AsyncPeek; use async_trait::async_trait; use tokio::io; use tokio::io::{AsyncBufRead, AsyncBufReadExt}; -use crate::message::primitive::data::MessageData; -use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; #[async_trait] pub trait FrontendProtoReader: OneWayProtoReader { async fn peek_special_message(&mut self) -> Result, ProtoPeekError>; - async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError>; + async fn consume_special_message( + &mut self, + msg: &SpecialMessage, + ) -> Result<(), ProtoConsumeError>; } #[async_trait] @@ -36,7 +39,10 @@ where Ok(None) } - async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError> { + async fn consume_special_message( + &mut self, + msg: &SpecialMessage, + ) -> Result<(), ProtoConsumeError> { Ok(match msg { SpecialMessage::CancelRequest(_) => consume_cancel_request(self), SpecialMessage::SSLRequest => consume_ssl_request(self), @@ -159,16 +165,16 @@ where if size != 4 { return Err(ProtoConsumeError::UnexpectedDataLength { expected: 4, - actual: size - }) + actual: size, + }); } let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize; if length < 8 { return Err(ProtoConsumeError::UnexpectedDataLength { expected: 8, - actual: length - }) + actual: length, + }); } reader.inner.consume(length); diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs index 6f600c3..41297de 100644 --- a/proto/src/reader/mod.rs +++ b/proto/src/reader/mod.rs @@ -1,6 +1,6 @@ pub mod backend; +pub mod errors; pub mod frontend; pub mod oneway; pub mod protoreader; mod utils; -pub mod errors; diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs index 11937d7..d1db637 100644 --- a/proto/src/reader/oneway.rs +++ b/proto/src/reader/oneway.rs @@ -1,9 +1,9 @@ use crate::message::proto_message::ProtoMessage; +use crate::reader::errors::ProtoReadError; use crate::reader::protoreader::ProtoReader; use crate::reader::utils::AsyncPeek; use async_trait::async_trait; use tokio::io::{AsyncBufRead, AsyncReadExt}; -use crate::reader::errors::ProtoReadError; #[async_trait] pub trait OneWayProtoReader diff --git a/proto/src/writer/backend.rs b/proto/src/writer/backend.rs index c4203ab..cc22e5c 100644 --- a/proto/src/writer/backend.rs +++ b/proto/src/writer/backend.rs @@ -1,9 +1,9 @@ use crate::message::backend::BackendMessage; +use crate::writer::errors::ProtoWriteError; use crate::writer::oneway::OneWayProtoWriter; use crate::writer::protowriter::ProtoWriter; use async_trait::async_trait; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use crate::writer::errors::ProtoWriteError; #[async_trait] pub trait BackendProtoWriter: OneWayProtoWriter { diff --git a/proto/src/writer/errors.rs b/proto/src/writer/errors.rs index f014a69..5cc0a7b 100644 --- a/proto/src/writer/errors.rs +++ b/proto/src/writer/errors.rs @@ -1,6 +1,6 @@ +use crate::message::errors::ProtoSerializeError; use thiserror::Error; use tokio::io; -use crate::message::errors::ProtoSerializeError; #[derive(Debug, Error)] pub enum ProtoWriteError { diff --git a/proto/src/writer/frontend.rs b/proto/src/writer/frontend.rs index 286f1ba..4ca6c0b 100644 --- a/proto/src/writer/frontend.rs +++ b/proto/src/writer/frontend.rs @@ -1,24 +1,33 @@ use crate::message::frontend::FrontendMessage; -use crate::writer::oneway::OneWayProtoWriter; -use async_trait::async_trait; -use tokio::io::{AsyncWrite, AsyncWriteExt}; use crate::message::primitive::data::MessageData; use crate::message::special::{CancelRequestData, StartupMessageData}; use crate::writer::errors::ProtoWriteError; +use crate::writer::oneway::OneWayProtoWriter; use crate::writer::protowriter::ProtoWriter; +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; #[async_trait] pub trait FrontendProtoWriter: OneWayProtoWriter { - async fn write_startup_message(&mut self, startup_message: StartupMessageData) -> Result<(), ProtoWriteError>; - async fn write_cancel_request(&mut self, cancel_request: CancelRequestData) -> Result<(), ProtoWriteError>; + async fn write_startup_message( + &mut self, + startup_message: StartupMessageData, + ) -> Result<(), ProtoWriteError>; + async fn write_cancel_request( + &mut self, + cancel_request: CancelRequestData, + ) -> Result<(), ProtoWriteError>; } #[async_trait] impl FrontendProtoWriter for ProtoWriter where - W: AsyncWrite + Unpin + Send + W: AsyncWrite + Unpin + Send, { - async fn write_startup_message(&mut self, startup_message: StartupMessageData) -> Result<(), ProtoWriteError> { + async fn write_startup_message( + &mut self, + startup_message: StartupMessageData, + ) -> Result<(), ProtoWriteError> { let data = startup_message.serialize()?; let length = data.len() + 4; @@ -28,7 +37,10 @@ where Ok(()) } - async fn write_cancel_request(&mut self, cancel_request: CancelRequestData) -> Result<(), ProtoWriteError> { + async fn write_cancel_request( + &mut self, + cancel_request: CancelRequestData, + ) -> Result<(), ProtoWriteError> { let data = cancel_request.serialize()?; let length = data.len() + 4; @@ -74,20 +86,23 @@ mod tests { let writer = BufWriter::new(Vec::new()); let mut writer = ProtoWriter::new(writer); - writer.write_startup_message(StartupMessageData { - version: 196608, - params: vec![ - ("user".into(), "postgres".into()), - ("database".into(), "postgres".into()), - ], - }).await.unwrap(); + writer + .write_startup_message(StartupMessageData { + version: 196608, + params: vec![ + ("user".into(), "postgres".into()), + ("database".into(), "postgres".into()), + ], + }) + .await + .unwrap(); assert_eq!( writer.inner.buffer(), vec![ - 0, 0, 0, 40, 0, 3, 0, 0, b'u', b's', b'e', b'r', 0, b'p', b'o', - b's', b't', b'g', b'r', b'e', b's', 0, b'd', b'a', b't', b'a', b'b', b'a', b's', - b'e', 0, b'p', b'o', b's', b't', b'g', b'r', b'e', b's', 0 + 0, 0, 0, 40, 0, 3, 0, 0, b'u', b's', b'e', b'r', 0, b'p', b'o', b's', b't', b'g', + b'r', b'e', b's', 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'p', b'o', + b's', b't', b'g', b'r', b'e', b's', 0 ] ); } @@ -97,16 +112,17 @@ mod tests { let writer = BufWriter::new(Vec::new()); let mut writer = ProtoWriter::new(writer); - writer.write_cancel_request(CancelRequestData { - pid: 123, - secret: 234, - }).await.unwrap(); + writer + .write_cancel_request(CancelRequestData { + pid: 123, + secret: 234, + }) + .await + .unwrap(); assert_eq!( writer.inner.buffer(), - vec![ - 0, 0, 0, 12, 0, 0, 0, 123, 0, 0, 0, 234 - ] + vec![0, 0, 0, 12, 0, 0, 0, 123, 0, 0, 0, 234] ); } } diff --git a/proto/src/writer/mod.rs b/proto/src/writer/mod.rs index f5cd408..651a31e 100644 --- a/proto/src/writer/mod.rs +++ b/proto/src/writer/mod.rs @@ -1,5 +1,5 @@ pub mod backend; +pub mod errors; pub mod frontend; pub mod oneway; pub mod protowriter; -pub mod errors; diff --git a/proto/src/writer/oneway.rs b/proto/src/writer/oneway.rs index 17bb5ee..30d2665 100644 --- a/proto/src/writer/oneway.rs +++ b/proto/src/writer/oneway.rs @@ -1,8 +1,8 @@ use crate::message::proto_message::ProtoMessage; +use crate::writer::errors::ProtoWriteError; use crate::writer::protowriter::ProtoWriter; use async_trait::async_trait; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use crate::writer::errors::ProtoWriteError; #[async_trait] pub trait OneWayProtoWriter From 75b067762e589d1da3ea5c39ce297f352782a0fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 01:31:49 +0100 Subject: [PATCH 25/30] fix(proto): move handshake response instead of borrowing --- proto/src/handshake/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index e9e6087..4459f65 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -2,7 +2,7 @@ use crate::handshake::errors::ServerHandshakeError; use crate::handshake::request::HandshakeRequest; use crate::handshake::response::HandshakeResponse; use crate::message::backend::{AuthenticationOkData, BackendMessage, ReadyForQueryData}; -use crate::message::special::{SpecialMessage, StartupMessageData}; +use crate::message::special::SpecialMessage; use crate::reader::frontend::FrontendProtoReader; use crate::writer::backend::BackendProtoWriter; use crate::writer::protowriter::ProtoFlush; @@ -10,7 +10,7 @@ use crate::writer::protowriter::ProtoFlush; pub async fn do_server_handshake( writer: &mut (impl BackendProtoWriter + ProtoFlush), reader: &mut impl FrontendProtoReader, - response: &HandshakeResponse, + response: HandshakeResponse, ) -> Result { match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::SSLRequest) => { From 031816987699eeb64dada68463ea6f6680a20493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 23:40:31 +0100 Subject: [PATCH 26/30] feat(proto): add example client --- Cargo.lock | 25 +++++++++++++++ Cargo.toml | 2 ++ client/Cargo.toml | 11 +++++++ client/src/main.rs | 80 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+) create mode 100644 client/Cargo.toml create mode 100644 client/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index c064b5d..ddb3ac8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -183,6 +183,31 @@ dependencies = [ "syn", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "syn" version = "2.0.41" diff --git a/Cargo.toml b/Cargo.toml index 18f1651..69e4ac6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,4 +3,6 @@ resolver = "2" members = [ "minisql", "proto" + "proto", + "client" ] diff --git a/client/Cargo.toml b/client/Cargo.toml new file mode 100644 index 0000000..9cf09e6 --- /dev/null +++ b/client/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "client" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "1.35.1", features = ["full"] } +anyhow = "1.0.76" +proto = { path = "../proto" } \ No newline at end of file diff --git a/client/src/main.rs b/client/src/main.rs new file mode 100644 index 0000000..ac1384c --- /dev/null +++ b/client/src/main.rs @@ -0,0 +1,80 @@ +use proto::handshake::client::do_client_handshake; +use proto::handshake::request::HandshakeRequest; +use proto::reader::protoreader::ProtoReader; +use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::TcpStream; +use proto::message::backend::{BackendMessage, DataRowData, RowDescriptionData}; +use proto::message::frontend::{FrontendMessage, QueryData}; +use proto::reader::oneway::OneWayProtoReader; +use proto::writer::oneway::OneWayProtoWriter; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let add = "127.0.0.1:5432"; + + let mut stream = TcpStream::connect(add).await?; + let (reader, writer) = stream.split(); + + let mut writer = ProtoWriter::new(BufWriter::new(writer)); + let mut reader = ProtoReader::new(BufReader::new(reader), 1024); + + let request = HandshakeRequest::new(196608) + .parameter("user", "test user") + .parameter("client_encoding", "UTF8"); + + let response = do_client_handshake(&mut writer, &mut reader, request).await?; + + println!("Handshake complete:\n{response:?}"); + + writer.write_proto(FrontendMessage::Query(QueryData { + query: "SELECT * FROM users;".to_string().into(), + })).await?; + writer.flush().await?; + + loop { + let msg: BackendMessage = reader.read_proto().await?; + match msg { + BackendMessage::RowDescription(data) => { + print_header(data); + }, + BackendMessage::DataRow(data) => { + print_row(data); + }, + BackendMessage::CommandComplete(data) => { + println!("Command complete: {:?}", data); + }, + BackendMessage::ReadyForQuery(data) => { + println!("Ready for query: {:?}", data); + break; + }, + m => { + println!("Unexpected message: {:?}", m); + } + } + } + + writer.write_proto(FrontendMessage::Terminate).await?; + writer.flush().await?; + + Ok(()) +} + +fn print_header(header: RowDescriptionData) { + print!("Header -> "); + for column in Vec::from(header.columns) { + print!("{} | ", column.name.as_str()); + } + println!(); +} + +fn print_row(row: DataRowData) { + print!("Row -> "); + for column in Vec::from(row.columns) { + let bytes = Vec::from(column); + let string = String::from_utf8(bytes).unwrap(); + + print!("{} | ", string); + } + println!(); +} From 84d9fa2d50f49d5570974055e5b59d3c0b763942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sat, 23 Dec 2023 23:40:45 +0100 Subject: [PATCH 27/30] feat(proto): add example server --- Cargo.lock | 194 ++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 2 +- server/Cargo.toml | 11 +++ server/src/main.rs | 184 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 388 insertions(+), 3 deletions(-) create mode 100644 server/Cargo.toml create mode 100644 server/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index ddb3ac8..5c4af8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "anyhow" +version = "1.0.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59d2a3357dde987206219e78ecfbbb6e8dad06cbb65292758d3270e6254f7355" + [[package]] name = "async-trait" version = "0.1.74" @@ -28,6 +34,12 @@ dependencies = [ "syn", ] +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + [[package]] name = "backtrace" version = "0.3.69" @@ -62,6 +74,12 @@ dependencies = [ "virtue", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bytes" version = "1.5.0" @@ -83,18 +101,43 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "client" +version = "0.1.0" +dependencies = [ + "anyhow", + "proto", + "tokio", +] + [[package]] name = "gimli" version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + [[package]] name = "libc" version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "memchr" version = "2.6.4" @@ -114,6 +157,27 @@ dependencies = [ "adler", ] +[[package]] +name = "mio" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.32.1" @@ -123,6 +187,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -157,12 +244,27 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + [[package]] name = "rustc-demangle" version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.193" @@ -183,6 +285,15 @@ dependencies = [ "syn", ] +[[package]] +name = "server" +version = "0.1.0" +dependencies = [ + "anyhow", + "proto", + "tokio", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -241,14 +352,21 @@ dependencies = [ [[package]] name = "tokio" -version = "1.35.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", + "socket2", "tokio-macros", + "windows-sys", ] [[package]] @@ -273,3 +391,75 @@ name = "virtue" version = "0.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/Cargo.toml b/Cargo.toml index 69e4ac6..714e4bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ resolver = "2" members = [ "minisql", - "proto" "proto", + "server", "client" ] diff --git a/server/Cargo.toml b/server/Cargo.toml new file mode 100644 index 0000000..bca61ec --- /dev/null +++ b/server/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "server" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "1.35.1", features = ["full"] } +anyhow = "1.0.76" +proto = { path = "../proto" } \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs new file mode 100644 index 0000000..f5ea267 --- /dev/null +++ b/server/src/main.rs @@ -0,0 +1,184 @@ +use proto::handshake::response::HandshakeResponse; +use proto::handshake::server::do_server_handshake; +use proto::message::backend::{ + BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, + ReadyForQueryData, RowDescriptionData, +}; +use proto::message::frontend::FrontendMessage; +use proto::reader::oneway::OneWayProtoReader; +use proto::reader::protoreader::ProtoReader; +use proto::writer::backend::BackendProtoWriter; +use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::{TcpListener, TcpStream}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let addr = "0.0.0.0:5432"; + let listener = TcpListener::bind(&addr).await?; + println!("Server started at {addr}"); + + loop { + let (socket, _) = listener.accept().await?; + println!("New client connected: {}", socket.peer_addr()?); + tokio::spawn(async move { + let reason = handle_stream(socket).await; + println!("Client disconnected: {reason:?}"); + }); + } +} + +async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { + let (reader, writer) = stream.split(); + let mut writer = ProtoWriter::new(BufWriter::new(writer)); + let mut reader = ProtoReader::new(BufReader::new(reader), 1024); + + let response = HandshakeResponse::new("minisql", 123, 123); + let request = do_server_handshake(&mut writer, &mut reader, response).await?; + + println!("Handshake complete:\n{request:?}"); + + loop { + println!("Waiting for next message"); + let next: FrontendMessage = reader.read_proto().await?; + + match next { + FrontendMessage::Terminate => { + println!("Received Terminate"); + break; + } + FrontendMessage::Query(data) => { + println!("Received Query: {:?}", data); + if data.query.as_str().contains("car") { + println!("Sending error message"); + send_error_response(&mut writer, "Car not found").await?; + } else if data.query.as_str().to_lowercase().contains("select") { + println!("Sending table"); + send_query_repsonse(&mut writer).await?; + } else { + println!("Sending empty query"); + send_empty_query(&mut writer).await?; + } + send_ready_for_query(&mut writer).await?; + } + } + writer.flush().await?; + } + + Ok(()) +} + +async fn send_error_response( + writer: &mut impl BackendProtoWriter, + error_message: &str, +) -> anyhow::Result<()> { + writer + .write_proto( + ErrorResponseData { + code: b'M', + message: error_message.to_string().into(), + } + .into(), + ) + .await?; + + Ok(()) +} + +async fn send_ready_for_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + writer + .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) + .await?; + + Ok(()) +} + +async fn send_empty_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + writer + .write_proto(BackendMessage::EmptyQueryResponse) + .await?; + + Ok(()) +} + +async fn send_row_description(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + let columns = vec![ + ColumnDescription { + name: "id".to_string().into(), + table_oid: 123, + column_index: 1, + type_oid: 23, + type_size: 4, + type_modifier: -1, + format_code: 0, + }, + ColumnDescription { + name: "argument".to_string().into(), + table_oid: 123, + column_index: 2, + type_oid: 23, + type_size: 4, + type_modifier: -1, + format_code: 0, + }, + ColumnDescription { + name: "description".to_string().into(), + table_oid: 123, + column_index: 3, + type_oid: 1043, + type_size: 32, + type_modifier: -1, + format_code: 0, + }, + ]; + + writer + .write_proto( + RowDescriptionData { + columns: columns.into(), + } + .into(), + ) + .await?; + + Ok(()) +} + +async fn send_query_repsonse(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + send_row_description(writer).await?; + + write_row(writer, b"0", b"1337", b"auto").await?; + write_row(writer, b"1", b"69", b"bus").await?; + write_row(writer, b"2", b"420", b"kolo").await?; + + writer + .write_proto( + CommandCompleteData { + tag: "SELECT 3".to_string().into(), + } + .into(), + ) + .await?; + + Ok(()) +} + +async fn write_row( + writer: &mut impl BackendProtoWriter, + first: &[u8], + second: &[u8], + third: &[u8], +) -> anyhow::Result<()> { + let row_data = vec![ + first.to_vec().into(), + second.to_vec().into(), + third.to_vec().into(), + ] + .into(); + + writer + .write_proto(DataRowData { columns: row_data }.into()) + .await?; + + Ok(()) +} From eb8410718da934c8225fd494aba88e7f69b5b943 Mon Sep 17 00:00:00 2001 From: Yuriy Dupyn <2153100+omedusyo@users.noreply.github.com> Date: Thu, 28 Dec 2023 09:27:55 +0100 Subject: [PATCH 28/30] Fix typos --- client/src/main.rs | 4 ++-- server/src/main.rs | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/client/src/main.rs b/client/src/main.rs index ac1384c..e77ea5c 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -11,9 +11,9 @@ use proto::writer::oneway::OneWayProtoWriter; #[tokio::main] async fn main() -> anyhow::Result<()> { - let add = "127.0.0.1:5432"; + let addr = "127.0.0.1:5432"; - let mut stream = TcpStream::connect(add).await?; + let mut stream = TcpStream::connect(addr).await?; let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); diff --git a/server/src/main.rs b/server/src/main.rs index f5ea267..bda6dfd 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -34,6 +34,7 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { let mut reader = ProtoReader::new(BufReader::new(reader), 1024); let response = HandshakeResponse::new("minisql", 123, 123); + let request = do_server_handshake(&mut writer, &mut reader, response).await?; println!("Handshake complete:\n{request:?}"); @@ -54,7 +55,7 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { send_error_response(&mut writer, "Car not found").await?; } else if data.query.as_str().to_lowercase().contains("select") { println!("Sending table"); - send_query_repsonse(&mut writer).await?; + send_query_response(&mut writer).await?; } else { println!("Sending empty query"); send_empty_query(&mut writer).await?; @@ -144,7 +145,7 @@ async fn send_row_description(writer: &mut impl BackendProtoWriter) -> anyhow::R Ok(()) } -async fn send_query_repsonse(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { +async fn send_query_response(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { send_row_description(writer).await?; write_row(writer, b"0", b"1337", b"auto").await?; From c61b6021db5272e9ca1c5dab60c7648717a400f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sun, 31 Dec 2023 18:45:50 +0100 Subject: [PATCH 29/30] docs: handshake documentation --- proto/src/handshake/client.rs | 6 ++++++ proto/src/handshake/request.rs | 10 ++++++++-- proto/src/handshake/response.rs | 3 ++- proto/src/handshake/server.rs | 8 ++++++++ proto/src/lib.rs | 4 ++++ 5 files changed, 28 insertions(+), 3 deletions(-) diff --git a/proto/src/handshake/client.rs b/proto/src/handshake/client.rs index 2deb5bc..ff3aaed 100644 --- a/proto/src/handshake/client.rs +++ b/proto/src/handshake/client.rs @@ -7,15 +7,20 @@ use crate::reader::backend::BackendProtoReader; use crate::writer::frontend::FrontendProtoWriter; use crate::writer::protowriter::ProtoFlush; +/// Performs client-side handshake with the server until the `ReadyForQuery` message is received. +/// For more info visit the [`55.2.1. Start-up`](https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP) pub async fn do_client_handshake( writer: &mut (impl FrontendProtoWriter + ProtoFlush), reader: &mut impl BackendProtoReader, request: HandshakeRequest, ) -> Result { + + // Send StartupMessage without SSLRequest let startup_message: StartupMessageData = request.into(); writer.write_startup_message(startup_message).await?; writer.flush().await?; + // Wait for AuthenticationOk let auth = reader.read_proto().await?; if !matches!( auth, @@ -24,6 +29,7 @@ pub async fn do_client_handshake( return Err(ClientHandshakeError::UnexpectedAuthResponse(auth)); } + // Read server parameter messages until ReadyForQuery is received let mut messages = Vec::new(); loop { let msg = reader.read_proto().await?; diff --git a/proto/src/handshake/request.rs b/proto/src/handshake/request.rs index 408238f..9f6b9cb 100644 --- a/proto/src/handshake/request.rs +++ b/proto/src/handshake/request.rs @@ -3,11 +3,14 @@ use crate::message::special::StartupMessageData; #[derive(Debug)] pub struct HandshakeRequest { - version: i32, - parameters: Vec<(PgString, PgString)>, + pub version: i32, + pub parameters: Vec<(PgString, PgString)>, } impl HandshakeRequest { + + /// Creates a new `HandshakeRequest` with the specified version. + /// Expected `version` is 196608 (3.0). pub fn new(version: i32) -> Self { Self { version, @@ -15,6 +18,9 @@ impl HandshakeRequest { } } + /// Adds a parameter to the startup message. + /// Generally recognized names are `user`, `database`, `option` and `replication` but others can be used. + /// For more info visit [`StartupMessage`](https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE) pub fn parameter(mut self, key: &str, value: &str) -> Self { self.parameters.push((key.into(), value.into())); self diff --git a/proto/src/handshake/response.rs b/proto/src/handshake/response.rs index 38579b6..60d0b6f 100644 --- a/proto/src/handshake/response.rs +++ b/proto/src/handshake/response.rs @@ -11,7 +11,7 @@ pub struct HandshakeResponse { impl HandshakeResponse { pub fn new(name: &str, pid: i32, key: i32) -> Self { Self { - version: format!("16.0 ({name})", name = name), + version: format!("16.0 ({name})"), process_id: pid, secret_key: key, } @@ -37,6 +37,7 @@ impl TryFrom<&[BackendMessage]> for HandshakeResponse { process_id = Some(data.process); secret_key = Some(data.secret); } + // Different messages are ignored during the handshake _ => {} } } diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index 4459f65..6c8deb2 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -7,11 +7,15 @@ use crate::reader::frontend::FrontendProtoReader; use crate::writer::backend::BackendProtoWriter; use crate::writer::protowriter::ProtoFlush; +/// Performs server-side handshake with the client until ending it with `ReadyForQuery` message. +/// For more info visit the [`55.2.1. Start-up`](https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP) pub async fn do_server_handshake( writer: &mut (impl BackendProtoWriter + ProtoFlush), reader: &mut impl FrontendProtoReader, response: HandshakeResponse, ) -> Result { + + // Check if client requested SSL match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::SSLRequest) => { reader.consume_special_message(msg).await?; @@ -23,6 +27,7 @@ pub async fn do_server_handshake( } } + // Wait for mandatory StartupMessage let startup_message = match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::StartupMessage(data)) => { reader.consume_special_message(msg).await?; @@ -33,15 +38,18 @@ pub async fn do_server_handshake( } }; + // Authenticate client writer .write_proto(BackendMessage::from(AuthenticationOkData { status: 0 })) .await?; + // Send server parameters let messages: Vec = response.into(); for message in messages { writer.write_proto(message).await?; } + // Finish the handshake writer .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) .await?; diff --git a/proto/src/lib.rs b/proto/src/lib.rs index 69afed6..4395510 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -1,3 +1,7 @@ +//! # PostgreSQL 16 Protocol +//! Low-level PostgreSQL protocol implementation for the version 16, protocol version 3.0. +//! Includes server and client side handshake with no password authentication. + pub mod handshake; pub mod message; pub mod reader; From df5741224f6cbd67d7b5a4d05d5267465253ff60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sun, 31 Dec 2023 19:03:10 +0100 Subject: [PATCH 30/30] docs: data messages documentation --- proto/src/handshake/request.rs | 4 ++-- proto/src/lib.rs | 4 ++-- proto/src/message/backend.rs | 2 ++ proto/src/message/frontend.rs | 2 ++ proto/src/message/primitive/config.rs | 7 ------- proto/src/message/primitive/data.rs | 8 +++++++- proto/src/message/primitive/mod.rs | 1 - proto/src/message/primitive/pglist.rs | 3 +++ proto/src/message/primitive/pgstring.rs | 1 + proto/src/message/special.rs | 5 +++++ 10 files changed, 24 insertions(+), 13 deletions(-) delete mode 100644 proto/src/message/primitive/config.rs diff --git a/proto/src/handshake/request.rs b/proto/src/handshake/request.rs index 9f6b9cb..51b6ad5 100644 --- a/proto/src/handshake/request.rs +++ b/proto/src/handshake/request.rs @@ -9,8 +9,8 @@ pub struct HandshakeRequest { impl HandshakeRequest { - /// Creates a new `HandshakeRequest` with the specified version. - /// Expected `version` is 196608 (3.0). + /// Creates a new `HandshakeRequest` with the specified protocol version. + /// Expected `version` is `196608` for the 3.0. pub fn new(version: i32) -> Self { Self { version, diff --git a/proto/src/lib.rs b/proto/src/lib.rs index 4395510..e9d155d 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -1,5 +1,5 @@ -//! # PostgreSQL 16 Protocol -//! Low-level PostgreSQL protocol implementation for the version 16, protocol version 3.0. +//! # PostgreSQL Protocol +//! Low-level PostgreSQL protocol implementation for the server version 16, protocol version 3.0. //! Includes server and client side handshake with no password authentication. pub mod handshake; diff --git a/proto/src/message/backend.rs b/proto/src/message/backend.rs index 0aa0e20..869fe14 100644 --- a/proto/src/message/backend.rs +++ b/proto/src/message/backend.rs @@ -5,6 +5,8 @@ use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; +/// Backend messages sent from the server to the client. +/// For more info visit the [`55.2.3. Message Formats`](https://www.postgresql.org/docs/current/protocol-message-formats.html) #[derive(Debug)] pub enum BackendMessage { AuthenticationOk(AuthenticationOkData), diff --git a/proto/src/message/frontend.rs b/proto/src/message/frontend.rs index 178be7b..648938e 100644 --- a/proto/src/message/frontend.rs +++ b/proto/src/message/frontend.rs @@ -4,6 +4,8 @@ use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; +/// Frontend messages sent from the client to the server. +/// For more info visit the [`55.2.3. Message Formats`](https://www.postgresql.org/docs/current/protocol-message-formats.html) #[derive(Debug)] pub enum FrontendMessage { Query(QueryData), diff --git a/proto/src/message/primitive/config.rs b/proto/src/message/primitive/config.rs deleted file mode 100644 index 5aa74f5..0000000 --- a/proto/src/message/primitive/config.rs +++ /dev/null @@ -1,7 +0,0 @@ -use bincode::config::{BigEndian, Configuration, Fixint}; - -pub fn pg_proto_config() -> Configuration { - bincode::config::standard() - .with_big_endian() - .with_fixed_int_encoding() -} diff --git a/proto/src/message/primitive/data.rs b/proto/src/message/primitive/data.rs index 4df50ae..db19ad4 100644 --- a/proto/src/message/primitive/data.rs +++ b/proto/src/message/primitive/data.rs @@ -1,6 +1,12 @@ use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; -use crate::message::primitive::config::pg_proto_config; use bincode::{Decode, Encode}; +use bincode::config::{BigEndian, Configuration, Fixint}; + +fn pg_proto_config() -> Configuration { + bincode::config::standard() + .with_big_endian() + .with_fixed_int_encoding() +} pub trait MessageData: Sized { fn serialize(&self) -> Result, ProtoSerializeError>; diff --git a/proto/src/message/primitive/mod.rs b/proto/src/message/primitive/mod.rs index e275e6e..4e84a1b 100644 --- a/proto/src/message/primitive/mod.rs +++ b/proto/src/message/primitive/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod config; pub(crate) mod data; pub mod pglist; pub mod pgstring; diff --git a/proto/src/message/primitive/pglist.rs b/proto/src/message/primitive/pglist.rs index aa95ca2..1e76db3 100644 --- a/proto/src/message/primitive/pglist.rs +++ b/proto/src/message/primitive/pglist.rs @@ -4,6 +4,9 @@ use bincode::error::{DecodeError, EncodeError}; use bincode::{BorrowDecode, Decode, Encode}; use std::marker::PhantomData; +/// Item list common in PostgreSQL messages. +/// - Generic type `T` is the type of the items in the list. +/// - Generic type `U` is the type of the list length (`i16` or `i32`). #[derive(Debug, Clone, PartialEq, BorrowDecode)] pub struct PgList(Vec, PhantomData); diff --git a/proto/src/message/primitive/pgstring.rs b/proto/src/message/primitive/pgstring.rs index 2c6cd7b..58fad78 100644 --- a/proto/src/message/primitive/pgstring.rs +++ b/proto/src/message/primitive/pgstring.rs @@ -4,6 +4,7 @@ use bincode::enc::Encoder; use bincode::error::{DecodeError, EncodeError}; use bincode::{BorrowDecode, Decode, Encode}; +/// PostgreSQL format of string encoded as a null-terminated string. #[derive(Debug, Clone, BorrowDecode)] pub struct PgString(String); diff --git a/proto/src/message/special.rs b/proto/src/message/special.rs index 166320f..6c45ab9 100644 --- a/proto/src/message/special.rs +++ b/proto/src/message/special.rs @@ -4,10 +4,15 @@ use bincode::enc::Encoder; use bincode::error::{DecodeError, EncodeError}; use bincode::{Decode, Encode}; +/// Special messages sent during handshake or to cancel request. +/// Sent in different format to preserve compatibility with older protocol versions. #[derive(Debug)] pub enum SpecialMessage { + /// Sent by client to cancel request. CancelRequest(CancelRequestData), + /// Sent by client to request upgrade to SSL connection. SSLRequest, + /// Sent by client to initiate the handshake. StartupMessage(StartupMessageData), }