diff --git a/proto/src/writer/frontend.rs b/proto/src/writer/frontend.rs index ba634cc..286f1ba 100644 --- a/proto/src/writer/frontend.rs +++ b/proto/src/writer/frontend.rs @@ -1,12 +1,43 @@ use crate::message::frontend::FrontendMessage; use crate::writer::oneway::OneWayProtoWriter; use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use crate::message::primitive::data::MessageData; +use crate::message::special::{CancelRequestData, StartupMessageData}; +use crate::writer::errors::ProtoWriteError; +use crate::writer::protowriter::ProtoWriter; #[async_trait] -pub trait FrontendProtoWriter: OneWayProtoWriter {} +pub trait FrontendProtoWriter: OneWayProtoWriter { + async fn write_startup_message(&mut self, startup_message: StartupMessageData) -> Result<(), ProtoWriteError>; + async fn write_cancel_request(&mut self, cancel_request: CancelRequestData) -> Result<(), ProtoWriteError>; +} #[async_trait] -impl FrontendProtoWriter for W where W: OneWayProtoWriter {} +impl FrontendProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send +{ + async fn write_startup_message(&mut self, startup_message: StartupMessageData) -> Result<(), ProtoWriteError> { + let data = startup_message.serialize()?; + let length = data.len() + 4; + + self.inner.write_i32(length as i32).await?; + self.inner.write_all(&data).await?; + + Ok(()) + } + + async fn write_cancel_request(&mut self, cancel_request: CancelRequestData) -> Result<(), ProtoWriteError> { + let data = cancel_request.serialize()?; + let length = data.len() + 4; + + self.inner.write_i32(length as i32).await?; + self.inner.write_all(&data).await?; + + Ok(()) + } +} #[cfg(test)] mod tests { @@ -37,4 +68,45 @@ mod tests { vec![b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4] ); } + + #[tokio::test] + async fn test_startup_message() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer.write_startup_message(StartupMessageData { + version: 196608, + params: vec![ + ("user".into(), "postgres".into()), + ("database".into(), "postgres".into()), + ], + }).await.unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![ + 0, 0, 0, 40, 0, 3, 0, 0, b'u', b's', b'e', b'r', 0, b'p', b'o', + b's', b't', b'g', b'r', b'e', b's', 0, b'd', b'a', b't', b'a', b'b', b'a', b's', + b'e', 0, b'p', b'o', b's', b't', b'g', b'r', b'e', b's', 0 + ] + ); + } + + #[tokio::test] + async fn test_cancel_request() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer.write_cancel_request(CancelRequestData { + pid: 123, + secret: 234, + }).await.unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![ + 0, 0, 0, 12, 0, 0, 0, 123, 0, 0, 0, 234 + ] + ); + } }