From ee2742ea03eb5facdeb586cfee9ea821aa7932cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:43:46 +0100 Subject: [PATCH] feat(proto): add frontend messages --- proto/src/message/frontend.rs | 68 +++++++++++++++++++++++++++++++++++ proto/src/message/mod.rs | 1 + 2 files changed, 69 insertions(+) create mode 100644 proto/src/message/frontend.rs diff --git a/proto/src/message/frontend.rs b/proto/src/message/frontend.rs new file mode 100644 index 0000000..4914fc2 --- /dev/null +++ b/proto/src/message/frontend.rs @@ -0,0 +1,68 @@ +use crate::message::primitive::data::MessageData; +use crate::message::primitive::pgstring::PgString; +use crate::message::proto_message::ProtoMessage; +use bincode::{Decode, Encode}; + +#[derive(Debug)] +pub enum FrontendMessage { + Query(QueryData), + Terminate, +} + +impl ProtoMessage for FrontendMessage { + fn variant(&self) -> u8 { + match self { + FrontendMessage::Query(_) => b'Q', + FrontendMessage::Terminate => b'X', + } + } + + fn serialize(&self) -> anyhow::Result> { + match self { + FrontendMessage::Query(data) => data.serialize(), + FrontendMessage::Terminate => Ok(Vec::with_capacity(0)), + } + } + + fn deserialize(variant: u8, data: &[u8]) -> anyhow::Result { + match variant { + b'Q' => Ok(FrontendMessage::Query(QueryData::deserialize(data)?)), + b'X' => Ok(FrontendMessage::Terminate), + v => Err(anyhow::anyhow!("Unknown frontend message variant: {v}")), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct QueryData { + pub query: PgString, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symmetric_query() { + let frontend = FrontendMessage::Query(QueryData { + query: PgString::from("SELECT * FROM foo WHERE bar = $1"), + }); + let raw = frontend.serialize().unwrap(); + let variant = frontend.variant(); + + let message = FrontendMessage::deserialize(variant, &raw).unwrap(); + assert!( + matches!(message, FrontendMessage::Query(QueryData { query }) if query.as_str() == "SELECT * FROM foo WHERE bar = $1") + ); + } + + #[test] + fn test_symmetric_terminate() { + let frontend = FrontendMessage::Terminate; + let raw = frontend.serialize().unwrap(); + let variant = frontend.variant(); + + let message = FrontendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, FrontendMessage::Terminate)); + } +} diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs index 5b3cd7a..30ae1a3 100644 --- a/proto/src/message/mod.rs +++ b/proto/src/message/mod.rs @@ -1,3 +1,4 @@ pub mod backend; +pub mod frontend; pub mod primitive; pub mod proto_message;