minisql/proto/src/message/frontend.rs
2023-12-31 19:04:40 +01:00

84 lines
2.6 KiB
Rust

use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError};
use crate::message::primitive::data::MessageData;
use crate::message::primitive::pgstring::PgString;
use crate::message::proto_message::ProtoMessage;
use bincode::{Decode, Encode};
/// Frontend messages sent from the client to the server.
///
/// Each variant is one message type of the PostgreSQL frontend/backend protocol.
/// Only the message types listed here are modeled. For the wire layout of each
/// message, see the PostgreSQL documentation on
/// [Message Formats](https://www.postgresql.org/docs/current/protocol-message-formats.html).
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Simple query (type byte `Q`), carrying the SQL text in [`QueryData`].
    Query(QueryData),
    /// Terminate (type byte `X`); encoded with an empty payload.
    Terminate,
}
impl ProtoMessage for FrontendMessage {
    /// Returns the message type byte that identifies this message on the wire.
    fn variant(&self) -> u8 {
        match self {
            Self::Query(_) => b'Q',
            Self::Terminate => b'X',
        }
    }

    /// Encodes the message payload; the type byte is reported separately by `variant`.
    ///
    /// `Terminate` has no payload and therefore encodes to an empty buffer.
    fn serialize(&self) -> Result<Vec<u8>, ProtoSerializeError> {
        let payload = match self {
            Self::Query(query_data) => query_data.serialize()?,
            Self::Terminate => Vec::new(),
        };
        Ok(payload)
    }

    /// Rebuilds a message from its type byte and payload.
    ///
    /// # Errors
    ///
    /// Returns `ProtoDeserializeError::InvalidVariant` for any type byte other than
    /// `Q` or `X`, and propagates any error from decoding a `Query` payload.
    fn deserialize(variant: u8, data: &[u8]) -> Result<Self, ProtoDeserializeError> {
        let message = match variant {
            b'Q' => Self::Query(QueryData::deserialize(data)?),
            b'X' => Self::Terminate,
            unknown => return Err(ProtoDeserializeError::InvalidVariant(unknown)),
        };
        Ok(message)
    }
}
/// Payload of a [`FrontendMessage::Query`] message.
///
/// Derives bincode's `Encode`/`Decode` so it can be turned into and parsed from raw bytes.
#[derive(Clone, Debug, Decode, Encode)]
pub struct QueryData {
    /// The SQL text the client asks the server to execute.
    pub query: PgString,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Query` message survives a serialize/deserialize round trip unchanged.
    #[test]
    fn test_symmetric_query() {
        const SQL: &str = "SELECT * FROM foo WHERE bar = $1";
        let frontend = FrontendMessage::Query(QueryData {
            query: PgString::from(SQL),
        });
        let raw = frontend.serialize().unwrap();
        let variant = frontend.variant();
        assert_eq!(variant, b'Q');
        let message = FrontendMessage::deserialize(variant, &raw).unwrap();
        assert!(
            matches!(message, FrontendMessage::Query(QueryData { query }) if query.as_str() == SQL)
        );
    }

    /// A `Terminate` message has an empty payload and round-trips unchanged.
    #[test]
    fn test_symmetric_terminate() {
        let frontend = FrontendMessage::Terminate;
        let raw = frontend.serialize().unwrap();
        assert!(raw.is_empty());
        let variant = frontend.variant();
        assert_eq!(variant, b'X');
        let message = FrontendMessage::deserialize(variant, &raw).unwrap();
        assert!(matches!(message, FrontendMessage::Terminate));
    }

    /// `FrontendMessage` (not the backend enum) rejects type bytes it does not know.
    #[test]
    fn test_unknown_variant() {
        let variant = 0;
        let data = [1, 2, 3];
        let message = FrontendMessage::deserialize(variant, &data);
        assert!(matches!(
            message,
            Err(ProtoDeserializeError::InvalidVariant(0))
        ));
    }
}