diff --git a/proto/src/reader/frontend.rs b/proto/src/reader/frontend.rs new file mode 100644 index 0000000..fc6ceca --- /dev/null +++ b/proto/src/reader/frontend.rs @@ -0,0 +1,279 @@ +use crate::message::frontend::FrontendMessage; +use crate::message::primitive::config::pg_proto_config; +use crate::message::special::{CancelRequestData, SpecialMessage}; +use crate::reader::oneway::OneWayProtoReader; +use crate::reader::protoreader::ProtoReader; +use crate::reader::utils::AsyncPeek; +use anyhow::anyhow; +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +#[async_trait] +pub trait FrontendProtoReader: OneWayProtoReader { + async fn peek_special_message(&mut self) -> anyhow::Result>; + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>; +} + +#[async_trait] +impl FrontendProtoReader for ProtoReader +where + R: AsyncBufRead + Unpin + Send, +{ + async fn peek_special_message(&mut self) -> anyhow::Result> { + if let Some(cancel) = try_get_cancel_request(&mut self).await? { + return Ok(Some(cancel)); + } + + if let Some(ssl) = try_get_ssl_request(&mut self).await? { + return Ok(Some(ssl)); + } + + if let Some(startup) = try_get_startup_message(&mut self).await? { + return Ok(Some(startup)); + } + + Ok(None) + } + + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> { + Ok(match msg { + SpecialMessage::CancelRequest(_) => consume_cancel_request(self), + SpecialMessage::SSLRequest => consume_ssl_request(self), + SpecialMessage::StartupMessage(_) => consume_startup_message(self).await?, + }) + } +} + +async fn try_get_cancel_request( + reader: &mut ProtoReader, +) -> anyhow::Result> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 16]; + if reader.inner.peek(&mut header).await? != 16 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length != 16 { + return Ok(None); + } + + let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if code != 80877102 { + return Ok(None); + } + + let pid = i32::from_be_bytes([header[8], header[9], header[10], header[11]]); + let secret = i32::from_be_bytes([header[12], header[13], header[14], header[15]]); + + Ok(Some(SpecialMessage::CancelRequest(CancelRequestData { + pid, + secret, + }))) +} + +fn consume_cancel_request(reader: &mut ProtoReader) +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + reader.inner.consume(16); +} + +async fn try_get_ssl_request( + reader: &mut ProtoReader, +) -> anyhow::Result> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 8]; + if reader.inner.peek(&mut header).await? != 8 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length != 8 { + return Ok(None); + } + + let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if code != 80877103 { + return Ok(None); + } + + Ok(Some(SpecialMessage::SSLRequest)) +} + +fn consume_ssl_request(reader: &mut ProtoReader) +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + reader.inner.consume(8); +} + +async fn try_get_startup_message( + reader: &mut ProtoReader, +) -> anyhow::Result> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 8]; + if reader.inner.peek(&mut header).await? != 8 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length < 8 { + return Ok(None); + } + if length > reader.msg_len_limit { + return Err(anyhow!("Message length is over the limit")); + } + + let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if version != 196608 { + return Ok(None); + } + + let length = length as usize; + let mut data = vec![0u8; length]; + if reader.inner.peek(&mut data).await? != length { + return Ok(None); + } + + let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0; + + Ok(Some(SpecialMessage::StartupMessage(data))) +} + +async fn consume_startup_message(reader: &mut ProtoReader) -> anyhow::Result<()> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 4]; + if reader.inner.peek(&mut header).await? != 4 { + return Err(anyhow!("Invalid header peek length")); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize; + if length < 8 { + return Err(anyhow!("Invalid startup message length")); + } + + reader.inner.consume(length); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::frontend::QueryData; + use crate::message::special::StartupMessageData; + use std::io::Cursor; + use tokio::io::{AsyncBufReadExt, BufReader}; + + #[tokio::test] + async fn test_message_sequence() { + let data = [ + b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let msg = reader.read_proto().await; + assert!( + match &msg { + Ok(FrontendMessage::Query(QueryData { query })) => query.as_str() == "SLIME", + _ => false, + }, + "{msg:?}" + ); + + let msg = reader.read_proto().await; + assert!(matches!(msg, Ok(FrontendMessage::Terminate)), "{msg:?}"); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_cancel_request() { + let data = [ + 0, 0, 0, 16, 0x04, 0xD2, 0x16, 0x2E, 0, 0, 0, 111, 0, 0, 0, 222, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(matches!( + peeked, + Some(SpecialMessage::CancelRequest(CancelRequestData { + pid: 111, + secret: 222 + })) + )); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_ssl_request() { + let data = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(matches!(peeked, Some(SpecialMessage::SSLRequest))); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_startup_message() { + let data = [ + 0, 0, 0, 26, 0, 3, 0, 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'b', b'r', + b'a', b'n', b'i', b'k', 0, 0, 0, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(match &peeked { + Some(SpecialMessage::StartupMessage(StartupMessageData { + version: 196608, + params, + })) => + params.len() == 2 + && params[0].0.as_str() == "database" + && params[0].1.as_str() == "branik" + && params[1].0.as_str() == "" + && params[1].1.as_str() == "", + _ => false, + }); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } +} diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs index 450e7d4..53bf75c 100644 --- a/proto/src/reader/mod.rs +++ b/proto/src/reader/mod.rs @@ -1,3 +1,4 @@ +pub mod frontend; pub mod oneway; pub mod protoreader; mod utils;