use crate::message::frontend::FrontendMessage; use crate::message::primitive::data::MessageData; use crate::message::special::{CancelRequestData, StartupMessageData}; use crate::writer::errors::ProtoWriteError; use crate::writer::oneway::OneWayProtoWriter; use crate::writer::protowriter::ProtoWriter; use async_trait::async_trait; use tokio::io::{AsyncWrite, AsyncWriteExt}; #[async_trait] pub trait FrontendProtoWriter: OneWayProtoWriter { async fn write_startup_message( &mut self, startup_message: StartupMessageData, ) -> Result<(), ProtoWriteError>; async fn write_cancel_request( &mut self, cancel_request: CancelRequestData, ) -> Result<(), ProtoWriteError>; } #[async_trait] impl FrontendProtoWriter for ProtoWriter where W: AsyncWrite + Unpin + Send, { async fn write_startup_message( &mut self, startup_message: StartupMessageData, ) -> Result<(), ProtoWriteError> { let data = startup_message.serialize()?; let length = data.len() + 4; self.inner.write_i32(length as i32).await?; self.inner.write_all(&data).await?; Ok(()) } async fn write_cancel_request( &mut self, cancel_request: CancelRequestData, ) -> Result<(), ProtoWriteError> { let data = cancel_request.serialize()?; let length = data.len() + 4; self.inner.write_i32(length as i32).await?; self.inner.write_all(&data).await?; Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::message::frontend::QueryData; use crate::writer::protowriter::ProtoWriter; use tokio::io::BufWriter; #[tokio::test] async fn test_message_sequence() { let writer = BufWriter::new(Vec::new()); let mut writer = ProtoWriter::new(writer); writer .write_proto(FrontendMessage::Query(QueryData { query: "SLIME".into(), })) .await .unwrap(); writer .write_proto(FrontendMessage::Terminate) .await .unwrap(); assert_eq!( writer.inner.buffer(), vec![b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4] ); } #[tokio::test] async fn test_startup_message() { let writer = BufWriter::new(Vec::new()); let mut writer = ProtoWriter::new(writer); writer .write_startup_message(StartupMessageData { version: 196608, params: vec![ ("user".into(), "postgres".into()), ("database".into(), "postgres".into()), ], }) .await .unwrap(); assert_eq!( writer.inner.buffer(), vec![ 0, 0, 0, 40, 0, 3, 0, 0, b'u', b's', b'e', b'r', 0, b'p', b'o', b's', b't', b'g', b'r', b'e', b's', 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'p', b'o', b's', b't', b'g', b'r', b'e', b's', 0 ] ); } #[tokio::test] async fn test_cancel_request() { let writer = BufWriter::new(Vec::new()); let mut writer = ProtoWriter::new(writer); writer .write_cancel_request(CancelRequestData { pid: 123, secret: 234, }) .await .unwrap(); assert_eq!( writer.inner.buffer(), vec![0, 0, 0, 12, 0, 0, 0, 123, 0, 0, 0, 234] ); } }