feat(proto): add frontend message reader

This commit is contained in:
Jindřich Moravec 2023-12-11 16:54:02 +01:00
parent 413e0216e3
commit 0a6e486005
2 changed files with 280 additions and 0 deletions

View file

@ -0,0 +1,279 @@
use crate::message::frontend::FrontendMessage;
use crate::message::primitive::config::pg_proto_config;
use crate::message::special::{CancelRequestData, SpecialMessage};
use crate::reader::oneway::OneWayProtoReader;
use crate::reader::protoreader::ProtoReader;
use crate::reader::utils::AsyncPeek;
use anyhow::anyhow;
use async_trait::async_trait;
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
#[async_trait]
pub trait FrontendProtoReader: OneWayProtoReader<FrontendMessage> {
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>>;
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>;
}
#[async_trait]
impl<R> FrontendProtoReader for ProtoReader<R>
where
R: AsyncBufRead + Unpin + Send,
{
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>> {
if let Some(cancel) = try_get_cancel_request(&mut self).await? {
return Ok(Some(cancel));
}
if let Some(ssl) = try_get_ssl_request(&mut self).await? {
return Ok(Some(ssl));
}
if let Some(startup) = try_get_startup_message(&mut self).await? {
return Ok(Some(startup));
}
Ok(None)
}
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> {
Ok(match msg {
SpecialMessage::CancelRequest(_) => consume_cancel_request(self),
SpecialMessage::SSLRequest => consume_ssl_request(self),
SpecialMessage::StartupMessage(_) => consume_startup_message(self).await?,
})
}
}
async fn try_get_cancel_request<R>(
reader: &mut ProtoReader<R>,
) -> anyhow::Result<Option<SpecialMessage>>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
let mut header = [0u8; 16];
if reader.inner.peek(&mut header).await? != 16 {
return Ok(None);
}
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
if length != 16 {
return Ok(None);
}
let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
if code != 80877102 {
return Ok(None);
}
let pid = i32::from_be_bytes([header[8], header[9], header[10], header[11]]);
let secret = i32::from_be_bytes([header[12], header[13], header[14], header[15]]);
Ok(Some(SpecialMessage::CancelRequest(CancelRequestData {
pid,
secret,
})))
}
fn consume_cancel_request<R>(reader: &mut ProtoReader<R>)
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
reader.inner.consume(16);
}
async fn try_get_ssl_request<R>(
reader: &mut ProtoReader<R>,
) -> anyhow::Result<Option<SpecialMessage>>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
let mut header = [0u8; 8];
if reader.inner.peek(&mut header).await? != 8 {
return Ok(None);
}
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
if length != 8 {
return Ok(None);
}
let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
if code != 80877103 {
return Ok(None);
}
Ok(Some(SpecialMessage::SSLRequest))
}
fn consume_ssl_request<R>(reader: &mut ProtoReader<R>)
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
reader.inner.consume(8);
}
async fn try_get_startup_message<R>(
reader: &mut ProtoReader<R>,
) -> anyhow::Result<Option<SpecialMessage>>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
let mut header = [0u8; 8];
if reader.inner.peek(&mut header).await? != 8 {
return Ok(None);
}
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
if length < 8 {
return Ok(None);
}
if length > reader.msg_len_limit {
return Err(anyhow!("Message length is over the limit"));
}
let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
if version != 196608 {
return Ok(None);
}
let length = length as usize;
let mut data = vec![0u8; length];
if reader.inner.peek(&mut data).await? != length {
return Ok(None);
}
let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0;
Ok(Some(SpecialMessage::StartupMessage(data)))
}
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> anyhow::Result<()>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
let mut header = [0u8; 4];
if reader.inner.peek(&mut header).await? != 4 {
return Err(anyhow!("Invalid header peek length"));
}
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
if length < 8 {
return Err(anyhow!("Invalid startup message length"));
}
reader.inner.consume(length);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::frontend::QueryData;
use crate::message::special::StartupMessageData;
use std::io::Cursor;
use tokio::io::{AsyncBufReadExt, BufReader};
#[tokio::test]
async fn test_message_sequence() {
let data = [
b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4,
];
let reader = BufReader::new(Cursor::new(&data));
let mut reader = ProtoReader::new(reader, 1024);
let msg = reader.read_proto().await;
assert!(
match &msg {
Ok(FrontendMessage::Query(QueryData { query })) => query.as_str() == "SLIME",
_ => false,
},
"{msg:?}"
);
let msg = reader.read_proto().await;
assert!(matches!(msg, Ok(FrontendMessage::Terminate)), "{msg:?}");
let rest = reader.inner.fill_buf().await.unwrap();
assert!(rest.is_empty());
}
#[tokio::test]
async fn test_cancel_request() {
let data = [
0, 0, 0, 16, 0x04, 0xD2, 0x16, 0x2E, 0, 0, 0, 111, 0, 0, 0, 222,
];
let reader = BufReader::new(Cursor::new(&data));
let mut reader = ProtoReader::new(reader, 1024);
let peeked = reader.peek_special_message().await.unwrap();
assert!(matches!(
peeked,
Some(SpecialMessage::CancelRequest(CancelRequestData {
pid: 111,
secret: 222
}))
));
reader
.consume_special_message(&peeked.unwrap())
.await
.unwrap();
let rest = reader.inner.fill_buf().await.unwrap();
assert!(rest.is_empty());
}
#[tokio::test]
async fn test_ssl_request() {
let data = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F];
let reader = BufReader::new(Cursor::new(&data));
let mut reader = ProtoReader::new(reader, 1024);
let peeked = reader.peek_special_message().await.unwrap();
assert!(matches!(peeked, Some(SpecialMessage::SSLRequest)));
reader
.consume_special_message(&peeked.unwrap())
.await
.unwrap();
let rest = reader.inner.fill_buf().await.unwrap();
assert!(rest.is_empty());
}
#[tokio::test]
async fn test_startup_message() {
let data = [
0, 0, 0, 26, 0, 3, 0, 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'b', b'r',
b'a', b'n', b'i', b'k', 0, 0, 0,
];
let reader = BufReader::new(Cursor::new(&data));
let mut reader = ProtoReader::new(reader, 1024);
let peeked = reader.peek_special_message().await.unwrap();
assert!(match &peeked {
Some(SpecialMessage::StartupMessage(StartupMessageData {
version: 196608,
params,
})) =>
params.len() == 2
&& params[0].0.as_str() == "database"
&& params[0].1.as_str() == "branik"
&& params[1].0.as_str() == ""
&& params[1].1.as_str() == "",
_ => false,
});
reader
.consume_special_message(&peeked.unwrap())
.await
.unwrap();
let rest = reader.inner.fill_buf().await.unwrap();
assert!(rest.is_empty());
}
}

View file

@ -1,3 +1,4 @@
pub mod frontend;
pub mod oneway; pub mod oneway;
pub mod protoreader; pub mod protoreader;
mod utils; mod utils;