From 3512067c4ec5fbf77f4168b5bef9dc6aaec34c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:43:19 +0100 Subject: [PATCH] feat(proto): add backend messages --- proto/src/message/backend.rs | 357 +++++++++++++++++++++++++++++++++++ proto/src/message/mod.rs | 1 + 2 files changed, 358 insertions(+) create mode 100644 proto/src/message/backend.rs diff --git a/proto/src/message/backend.rs b/proto/src/message/backend.rs new file mode 100644 index 0000000..1d96833 --- /dev/null +++ b/proto/src/message/backend.rs @@ -0,0 +1,357 @@ +use crate::message::primitive::data::MessageData; +use crate::message::primitive::pglist::PgList; +use crate::message::primitive::pgstring::PgString; +use crate::message::proto_message::ProtoMessage; +use bincode::{Decode, Encode}; + +#[derive(Debug)] +pub enum BackendMessage { + AuthenticationOk(AuthenticationOkData), + BackendKeyData(BackendKeyDataData), + CommandComplete(CommandCompleteData), + DataRow(DataRowData), + EmptyQueryResponse, + ErrorResponse(ErrorResponseData), + NoData, + ParameterStatus(ParameterStatusData), + ReadyForQuery(ReadyForQueryData), + RowDescription(RowDescriptionData), +} + +impl ProtoMessage for BackendMessage { + fn variant(&self) -> u8 { + match self { + BackendMessage::AuthenticationOk(_) => b'R', + BackendMessage::BackendKeyData(_) => b'K', + BackendMessage::CommandComplete(_) => b'C', + BackendMessage::DataRow(_) => b'D', + BackendMessage::EmptyQueryResponse => b'I', + BackendMessage::ErrorResponse(_) => b'E', + BackendMessage::NoData => b'n', + BackendMessage::ParameterStatus(_) => b'S', + BackendMessage::ReadyForQuery(_) => b'Z', + BackendMessage::RowDescription(_) => b'T', + } + } + + fn serialize(&self) -> anyhow::Result> { + match self { + BackendMessage::AuthenticationOk(data) => data.serialize(), + BackendMessage::BackendKeyData(data) => data.serialize(), + BackendMessage::CommandComplete(data) => data.serialize(), + BackendMessage::DataRow(data) => data.serialize(), + BackendMessage::EmptyQueryResponse => Ok(Vec::with_capacity(0)), + BackendMessage::ErrorResponse(data) => data.serialize(), + BackendMessage::NoData => Ok(Vec::with_capacity(0)), + BackendMessage::ParameterStatus(data) => data.serialize(), + BackendMessage::ReadyForQuery(data) => data.serialize(), + BackendMessage::RowDescription(data) => data.serialize(), + } + } + + fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result { + match variant { + b'R' => Ok(BackendMessage::AuthenticationOk( + AuthenticationOkData::deserialize(data)?, + )), + b'K' => { + let data = BackendKeyDataData::deserialize(data)?; + Ok(BackendMessage::BackendKeyData(data)) + } + b'C' => { + let data = CommandCompleteData::deserialize(data)?; + Ok(BackendMessage::CommandComplete(data)) + } + b'D' => { + let data = DataRowData::deserialize(data)?; + Ok(BackendMessage::DataRow(data)) + } + b'I' => Ok(BackendMessage::EmptyQueryResponse), + b'E' => { + let data = ErrorResponseData::deserialize(data)?; + Ok(BackendMessage::ErrorResponse(data)) + } + b'n' => Ok(BackendMessage::NoData), + b'S' => { + let data = ParameterStatusData::deserialize(data)?; + Ok(BackendMessage::ParameterStatus(data)) + } + b'Z' => { + let data = ReadyForQueryData::deserialize(data)?; + Ok(BackendMessage::ReadyForQuery(data)) + } + b'T' => { + let data = RowDescriptionData::deserialize(data)?; + Ok(BackendMessage::RowDescription(data)) + } + v => Err(anyhow::anyhow!("Unknown backend message variant: {v}")), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct AuthenticationOkData { + pub status: i32, +} + +impl From for BackendMessage { + fn from(data: AuthenticationOkData) -> Self { + BackendMessage::AuthenticationOk(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct BackendKeyDataData { + pub process: i32, + pub secret: i32, +} + +impl From for BackendMessage { + fn from(data: BackendKeyDataData) -> Self { + BackendMessage::BackendKeyData(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct CommandCompleteData { + pub tag: PgString, +} + +impl From for BackendMessage { + fn from(data: CommandCompleteData) -> Self { + BackendMessage::CommandComplete(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct DataRowData { + pub columns: PgList, i16>, +} + +impl From for BackendMessage { + fn from(data: DataRowData) -> Self { + BackendMessage::DataRow(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ErrorResponseData { + pub code: u8, + pub message: PgString, +} + +impl From for BackendMessage { + fn from(data: ErrorResponseData) -> Self { + BackendMessage::ErrorResponse(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ParameterStatusData { + pub name: PgString, + pub value: PgString, +} + +impl From for BackendMessage { + fn from(data: ParameterStatusData) -> Self { + BackendMessage::ParameterStatus(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ReadyForQueryData { + pub status: u8, +} + +impl From for BackendMessage { + fn from(data: ReadyForQueryData) -> Self { + BackendMessage::ReadyForQuery(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct RowDescriptionData { + pub columns: PgList, +} + +impl From for BackendMessage { + fn from(data: RowDescriptionData) -> Self { + BackendMessage::RowDescription(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ColumnDescription { + pub name: PgString, + pub table_oid: i32, + pub column_index: i16, + pub type_oid: i32, + pub type_size: i16, + pub type_modifier: i32, + pub format_code: i16, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symmetric_authentication_ok() { + let backend = BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 }) + )); + } + + #[test] + fn test_symmetric_backend_key_data() { + let backend = BackendMessage::BackendKeyData(BackendKeyDataData { + process: 123, + secret: 456, + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::BackendKeyData(BackendKeyDataData { + process: 123, + secret: 456 + }) + )); + } + + #[test] + fn test_symmetric_command_complete() { + let backend = BackendMessage::CommandComplete(CommandCompleteData { + tag: PgString::from("SELECT 1"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::CommandComplete(CommandCompleteData { tag }) if tag.as_str() == "SELECT 1" + )); + } + + #[test] + fn test_symmetric_data_row() { + let backend = BackendMessage::DataRow(DataRowData { + columns: PgList::from(vec![PgList::from(vec![1, 2, 3])]), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::DataRow(DataRowData { columns }) if columns == PgList::from(vec![PgList::from(vec![1, 2, 3])]) + )); + } + + #[test] + fn test_symmetric_empty_query_response() { + let backend = BackendMessage::EmptyQueryResponse; + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, BackendMessage::EmptyQueryResponse)); + } + + #[test] + fn test_symmetric_error_response() { + let backend = BackendMessage::ErrorResponse(ErrorResponseData { + code: b'X', + message: PgString::from("Some error"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ErrorResponse(ErrorResponseData { code, message }) if code == b'X' && message.as_str() == "Some error" + )); + } + + #[test] + fn test_symmetric_no_data() { + let backend = BackendMessage::NoData; + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, BackendMessage::NoData)); + } + + #[test] + fn test_symmetric_parameter_status() { + let backend = BackendMessage::ParameterStatus(ParameterStatusData { + name: PgString::from("Some name"), + value: PgString::from("Some value"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ParameterStatus(ParameterStatusData { name, value }) if name.as_str() == "Some name" && value.as_str() == "Some value" + )); + } + + #[test] + fn test_symmetric_ready_for_query() { + let backend = BackendMessage::ReadyForQuery(ReadyForQueryData { status: b'I' }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ReadyForQuery(ReadyForQueryData { status }) if status == b'I' + )); + } + + #[test] + fn test_symmetric_row_description() { + let backend = BackendMessage::RowDescription(RowDescriptionData { + columns: PgList::from(vec![ColumnDescription { + name: PgString::from("Some name"), + table_oid: 123, + column_index: 456, + type_oid: 789, + type_size: 101, + type_modifier: 112, + format_code: 113, + }]), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(match message { + BackendMessage::RowDescription(RowDescriptionData { columns }) => { + let columns: Vec = columns.into(); + let column = &columns[0]; + column.name.as_str() == "Some name" + && column.table_oid == 123 + && column.column_index == 456 + && column.type_oid == 789 + && column.type_size == 101 + && column.type_modifier == 112 + && column.format_code == 113 + } + _ => false, + },) + } +} diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 79c2f96..5b3cd7a 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -1,2 +1,3 @@ +pub mod backend; pub mod primitive; pub mod proto_message;