use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; use crate::message::primitive::data::MessageData; use crate::message::primitive::pgstring::PgString; use crate::message::proto_message::ProtoMessage; use bincode::{Decode, Encode}; /// Frontend messages sent from the client to the server. /// For more info visit the [`55.2.3. Message Formats`](https://www.postgresql.org/docs/current/protocol-message-formats.html) #[derive(Debug)] pub enum FrontendMessage { Query(QueryData), Terminate, } impl ProtoMessage for FrontendMessage { fn variant(&self) -> u8 { match self { FrontendMessage::Query(_) => b'Q', FrontendMessage::Terminate => b'X', } } fn serialize(&self) -> Result, ProtoSerializeError> { match self { FrontendMessage::Query(data) => data.serialize(), FrontendMessage::Terminate => Ok(Vec::with_capacity(0)), } } fn deserialize(variant: u8, data: &[u8]) -> Result { match variant { b'Q' => Ok(FrontendMessage::Query(QueryData::deserialize(data)?)), b'X' => Ok(FrontendMessage::Terminate), v => Err(ProtoDeserializeError::InvalidVariant(v)), } } } #[derive(Debug, Clone, Encode, Decode)] pub struct QueryData { pub query: PgString, } #[cfg(test)] mod tests { use super::*; use crate::message::backend::BackendMessage; #[test] fn test_symmetric_query() { let frontend = FrontendMessage::Query(QueryData { query: PgString::from("SELECT * FROM foo WHERE bar = $1"), }); let raw = frontend.serialize().unwrap(); let variant = frontend.variant(); let message = FrontendMessage::deserialize(variant, &raw).unwrap(); assert!( matches!(message, FrontendMessage::Query(QueryData { query }) if query.as_str() == "SELECT * FROM foo WHERE bar = $1") ); } #[test] fn test_symmetric_terminate() { let frontend = FrontendMessage::Terminate; let raw = frontend.serialize().unwrap(); let variant = frontend.variant(); let message = FrontendMessage::deserialize(variant, &raw).unwrap(); assert!(matches!(message, FrontendMessage::Terminate)); } #[test] fn test_unknown_variant() { let variant = 0; let data = vec![1, 2, 3]; let message = BackendMessage::deserialize(variant, &data); assert!(matches!( message, Err(ProtoDeserializeError::InvalidVariant(0)) )); } }