feat(proto): add frontend message reader
This commit is contained in:
parent
413e0216e3
commit
0a6e486005
2 changed files with 280 additions and 0 deletions
279
proto/src/reader/frontend.rs
Normal file
279
proto/src/reader/frontend.rs
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
use crate::message::frontend::FrontendMessage;
|
||||
use crate::message::primitive::config::pg_proto_config;
|
||||
use crate::message::special::{CancelRequestData, SpecialMessage};
|
||||
use crate::reader::oneway::OneWayProtoReader;
|
||||
use crate::reader::protoreader::ProtoReader;
|
||||
use crate::reader::utils::AsyncPeek;
|
||||
use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
|
||||
|
||||
#[async_trait]
|
||||
pub trait FrontendProtoReader: OneWayProtoReader<FrontendMessage> {
|
||||
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>>;
|
||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<R> FrontendProtoReader for ProtoReader<R>
|
||||
where
|
||||
R: AsyncBufRead + Unpin + Send,
|
||||
{
|
||||
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>> {
|
||||
if let Some(cancel) = try_get_cancel_request(&mut self).await? {
|
||||
return Ok(Some(cancel));
|
||||
}
|
||||
|
||||
if let Some(ssl) = try_get_ssl_request(&mut self).await? {
|
||||
return Ok(Some(ssl));
|
||||
}
|
||||
|
||||
if let Some(startup) = try_get_startup_message(&mut self).await? {
|
||||
return Ok(Some(startup));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> {
|
||||
Ok(match msg {
|
||||
SpecialMessage::CancelRequest(_) => consume_cancel_request(self),
|
||||
SpecialMessage::SSLRequest => consume_ssl_request(self),
|
||||
SpecialMessage::StartupMessage(_) => consume_startup_message(self).await?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_get_cancel_request<R>(
|
||||
reader: &mut ProtoReader<R>,
|
||||
) -> anyhow::Result<Option<SpecialMessage>>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
let mut header = [0u8; 16];
|
||||
if reader.inner.peek(&mut header).await? != 16 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
|
||||
if length != 16 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
|
||||
if code != 80877102 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let pid = i32::from_be_bytes([header[8], header[9], header[10], header[11]]);
|
||||
let secret = i32::from_be_bytes([header[12], header[13], header[14], header[15]]);
|
||||
|
||||
Ok(Some(SpecialMessage::CancelRequest(CancelRequestData {
|
||||
pid,
|
||||
secret,
|
||||
})))
|
||||
}
|
||||
|
||||
fn consume_cancel_request<R>(reader: &mut ProtoReader<R>)
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
reader.inner.consume(16);
|
||||
}
|
||||
|
||||
async fn try_get_ssl_request<R>(
|
||||
reader: &mut ProtoReader<R>,
|
||||
) -> anyhow::Result<Option<SpecialMessage>>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
let mut header = [0u8; 8];
|
||||
if reader.inner.peek(&mut header).await? != 8 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
|
||||
if length != 8 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
|
||||
if code != 80877103 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(SpecialMessage::SSLRequest))
|
||||
}
|
||||
|
||||
fn consume_ssl_request<R>(reader: &mut ProtoReader<R>)
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
reader.inner.consume(8);
|
||||
}
|
||||
|
||||
async fn try_get_startup_message<R>(
|
||||
reader: &mut ProtoReader<R>,
|
||||
) -> anyhow::Result<Option<SpecialMessage>>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
let mut header = [0u8; 8];
|
||||
if reader.inner.peek(&mut header).await? != 8 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
|
||||
if length < 8 {
|
||||
return Ok(None);
|
||||
}
|
||||
if length > reader.msg_len_limit {
|
||||
return Err(anyhow!("Message length is over the limit"));
|
||||
}
|
||||
|
||||
let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
|
||||
if version != 196608 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let length = length as usize;
|
||||
let mut data = vec![0u8; length];
|
||||
if reader.inner.peek(&mut data).await? != length {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0;
|
||||
|
||||
Ok(Some(SpecialMessage::StartupMessage(data)))
|
||||
}
|
||||
|
||||
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> anyhow::Result<()>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
let mut header = [0u8; 4];
|
||||
if reader.inner.peek(&mut header).await? != 4 {
|
||||
return Err(anyhow!("Invalid header peek length"));
|
||||
}
|
||||
|
||||
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
|
||||
if length < 8 {
|
||||
return Err(anyhow!("Invalid startup message length"));
|
||||
}
|
||||
|
||||
reader.inner.consume(length);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::message::frontend::QueryData;
|
||||
use crate::message::special::StartupMessageData;
|
||||
use std::io::Cursor;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_sequence() {
|
||||
let data = [
|
||||
b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4,
|
||||
];
|
||||
|
||||
let reader = BufReader::new(Cursor::new(&data));
|
||||
let mut reader = ProtoReader::new(reader, 1024);
|
||||
|
||||
let msg = reader.read_proto().await;
|
||||
assert!(
|
||||
match &msg {
|
||||
Ok(FrontendMessage::Query(QueryData { query })) => query.as_str() == "SLIME",
|
||||
_ => false,
|
||||
},
|
||||
"{msg:?}"
|
||||
);
|
||||
|
||||
let msg = reader.read_proto().await;
|
||||
assert!(matches!(msg, Ok(FrontendMessage::Terminate)), "{msg:?}");
|
||||
|
||||
let rest = reader.inner.fill_buf().await.unwrap();
|
||||
assert!(rest.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cancel_request() {
|
||||
let data = [
|
||||
0, 0, 0, 16, 0x04, 0xD2, 0x16, 0x2E, 0, 0, 0, 111, 0, 0, 0, 222,
|
||||
];
|
||||
|
||||
let reader = BufReader::new(Cursor::new(&data));
|
||||
let mut reader = ProtoReader::new(reader, 1024);
|
||||
|
||||
let peeked = reader.peek_special_message().await.unwrap();
|
||||
assert!(matches!(
|
||||
peeked,
|
||||
Some(SpecialMessage::CancelRequest(CancelRequestData {
|
||||
pid: 111,
|
||||
secret: 222
|
||||
}))
|
||||
));
|
||||
|
||||
reader
|
||||
.consume_special_message(&peeked.unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let rest = reader.inner.fill_buf().await.unwrap();
|
||||
assert!(rest.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ssl_request() {
|
||||
let data = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F];
|
||||
|
||||
let reader = BufReader::new(Cursor::new(&data));
|
||||
let mut reader = ProtoReader::new(reader, 1024);
|
||||
|
||||
let peeked = reader.peek_special_message().await.unwrap();
|
||||
assert!(matches!(peeked, Some(SpecialMessage::SSLRequest)));
|
||||
|
||||
reader
|
||||
.consume_special_message(&peeked.unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let rest = reader.inner.fill_buf().await.unwrap();
|
||||
assert!(rest.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_startup_message() {
|
||||
let data = [
|
||||
0, 0, 0, 26, 0, 3, 0, 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'b', b'r',
|
||||
b'a', b'n', b'i', b'k', 0, 0, 0,
|
||||
];
|
||||
|
||||
let reader = BufReader::new(Cursor::new(&data));
|
||||
let mut reader = ProtoReader::new(reader, 1024);
|
||||
|
||||
let peeked = reader.peek_special_message().await.unwrap();
|
||||
assert!(match &peeked {
|
||||
Some(SpecialMessage::StartupMessage(StartupMessageData {
|
||||
version: 196608,
|
||||
params,
|
||||
})) =>
|
||||
params.len() == 2
|
||||
&& params[0].0.as_str() == "database"
|
||||
&& params[0].1.as_str() == "branik"
|
||||
&& params[1].0.as_str() == ""
|
||||
&& params[1].1.as_str() == "",
|
||||
_ => false,
|
||||
});
|
||||
|
||||
reader
|
||||
.consume_special_message(&peeked.unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let rest = reader.inner.fill_buf().await.unwrap();
|
||||
assert!(rest.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
pub mod frontend;
|
||||
pub mod oneway;
|
||||
pub mod protoreader;
|
||||
mod utils;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue