minisql/proto/src/writer/frontend.rs
2023-12-23 01:28:30 +01:00

128 lines
3.5 KiB
Rust

use crate::message::frontend::FrontendMessage;
use crate::message::primitive::data::MessageData;
use crate::message::special::{CancelRequestData, StartupMessageData};
use crate::writer::errors::ProtoWriteError;
use crate::writer::oneway::OneWayProtoWriter;
use crate::writer::protowriter::ProtoWriter;
use async_trait::async_trait;
use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Frontend-side protocol writer.
///
/// Tagged messages are written through the supertrait's `write_proto`
/// (`OneWayProtoWriter<FrontendMessage>`). This trait adds the two messages
/// that the `ProtoWriter` implementation in this module writes *without* a
/// message-type byte: only a length prefix followed by the serialized body.
#[async_trait]
pub trait FrontendProtoWriter: OneWayProtoWriter<FrontendMessage> {
    /// Writes a startup message as a length-prefixed frame.
    ///
    /// # Errors
    ///
    /// Returns a `ProtoWriteError` if serializing the message or writing it fails.
    async fn write_startup_message(
        &mut self,
        startup_message: StartupMessageData,
    ) -> Result<(), ProtoWriteError>;
    /// Writes a cancel request as a length-prefixed frame.
    ///
    /// # Errors
    ///
    /// Returns a `ProtoWriteError` if serializing the request or writing it fails.
    async fn write_cancel_request(
        &mut self,
        cancel_request: CancelRequestData,
    ) -> Result<(), ProtoWriteError>;
}
/// Writes `payload` as a length-prefixed frame with no message-type byte.
///
/// The prefix is a big-endian `i32` (written via `AsyncWriteExt::write_i32`).
/// Its value counts its own four bytes plus the payload.
///
/// # Errors
///
/// Returns an error if the framed length does not fit in an `i32`, or if
/// either write to `writer` fails.
async fn write_length_prefixed<T>(
    writer: &mut T,
    payload: &[u8],
) -> Result<(), ProtoWriteError>
where
    T: AsyncWrite + Unpin,
{
    // Reject oversized payloads instead of letting an `as i32` cast silently
    // wrap into a bogus (possibly negative) length prefix on the wire.
    let length = payload
        .len()
        .checked_add(4)
        .and_then(|len| i32::try_from(len).ok())
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "message too large for an i32 length prefix",
            )
        })?;
    writer.write_i32(length).await?;
    writer.write_all(payload).await?;
    Ok(())
}

#[async_trait]
impl<W> FrontendProtoWriter for ProtoWriter<W>
where
    W: AsyncWrite + Unpin + Send,
{
    /// Serializes `startup_message` and writes it as a length-prefixed frame
    /// with no message-type byte.
    ///
    /// NOTE(review): the unit test in this file expects the frame to end
    /// right after the last name/value pair. PostgreSQL's StartupMessage also
    /// requires a trailing zero byte after the last pair. If interoperability
    /// with PostgreSQL is a goal, confirm whether
    /// `StartupMessageData::serialize` should emit it.
    async fn write_startup_message(
        &mut self,
        startup_message: StartupMessageData,
    ) -> Result<(), ProtoWriteError> {
        let data = startup_message.serialize()?;
        write_length_prefixed(&mut self.inner, &data).await
    }

    /// Serializes `cancel_request` and writes it as a length-prefixed frame
    /// with no message-type byte.
    ///
    /// NOTE(review): the unit test in this file expects a 12-byte frame
    /// (length, pid, secret). PostgreSQL's CancelRequest is 16 bytes and
    /// carries the request code 80877102 before the pid. Confirm whether
    /// `CancelRequestData::serialize` is meant to omit it.
    async fn write_cancel_request(
        &mut self,
        cancel_request: CancelRequestData,
    ) -> Result<(), ProtoWriteError> {
        let data = cancel_request.serialize()?;
        write_length_prefixed(&mut self.inner, &data).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::frontend::QueryData;
    use crate::writer::protowriter::ProtoWriter;
    use tokio::io::BufWriter;

    /// Creates a `ProtoWriter` over an in-memory `BufWriter`.
    ///
    /// Nothing is flushed in these tests, so every written byte is still held
    /// in the `BufWriter`'s internal buffer and is read back with
    /// `inner.buffer()`.
    fn in_memory_writer() -> ProtoWriter<BufWriter<Vec<u8>>> {
        ProtoWriter::new(BufWriter::new(Vec::new()))
    }

    #[tokio::test]
    async fn test_message_sequence() {
        let mut writer = in_memory_writer();
        let query = FrontendMessage::Query(QueryData {
            query: "SLIME".into(),
        });
        writer.write_proto(query).await.unwrap();
        writer.write_proto(FrontendMessage::Terminate).await.unwrap();

        // Query: tag 'Q', length 10, then the NUL-terminated query text.
        // Terminate: tag 'X', length 4, no body.
        let mut expected: Vec<u8> = Vec::new();
        expected.push(b'Q');
        expected.extend_from_slice(&10i32.to_be_bytes());
        expected.extend_from_slice(b"SLIME\0");
        expected.push(b'X');
        expected.extend_from_slice(&4i32.to_be_bytes());

        assert_eq!(writer.inner.buffer(), expected);
    }

    #[tokio::test]
    async fn test_startup_message() {
        let mut writer = in_memory_writer();
        let startup = StartupMessageData {
            version: 196608,
            params: vec![
                ("user".into(), "postgres".into()),
                ("database".into(), "postgres".into()),
            ],
        };
        writer.write_startup_message(startup).await.unwrap();

        // Untagged frame: length 40, version 196608 (0x00030000), then the
        // NUL-terminated key/value strings. NOTE(review): no extra trailing
        // zero byte follows the last pair; PostgreSQL expects one.
        let mut expected: Vec<u8> = Vec::new();
        expected.extend_from_slice(&40i32.to_be_bytes());
        expected.extend_from_slice(&196608i32.to_be_bytes());
        expected.extend_from_slice(b"user\0postgres\0database\0postgres\0");

        assert_eq!(writer.inner.buffer(), expected);
    }

    #[tokio::test]
    async fn test_cancel_request() {
        let mut writer = in_memory_writer();
        let cancel = CancelRequestData {
            pid: 123,
            secret: 234,
        };
        writer.write_cancel_request(cancel).await.unwrap();

        // Untagged frame made of three big-endian i32 words:
        // length 12, pid, secret.
        let expected: Vec<u8> = [12i32, 123, 234]
            .iter()
            .flat_map(|word| word.to_be_bytes())
            .collect();

        assert_eq!(writer.inner.buffer(), expected);
    }
}