diff --git a/proto/src/reader/errors.rs b/proto/src/reader/errors.rs new file mode 100644 index 0000000..5ecd8ae --- /dev/null +++ b/proto/src/reader/errors.rs @@ -0,0 +1,42 @@ +use thiserror::Error; +use tokio::io; +use crate::message::errors::{ProtoDeserializeError}; + +#[derive(Debug, Error)] +pub enum ProtoReadError { + #[error("message has invalid length, got {0}")] + InvalidLength(i32), + #[error("message has too much data, got {actual}, limit is {limit}")] + LengthOverflow { + limit: usize, + actual: usize + }, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), +} + +#[derive(Debug, Error)] +pub enum ProtoPeekError { + #[error("message has too much data, got {actual}, limit is {limit}")] + LengthOverflow { + limit: usize, + actual: usize + }, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), +} + +#[derive(Debug, Error)] +pub enum ProtoConsumeError { + #[error("unexpected data length, expected {expected}, got {actual}")] + UnexpectedDataLength { + expected: usize, + actual: usize + }, + #[error("reading from socket failed")] + Io(#[from] io::Error), +} diff --git a/proto/src/reader/frontend.rs b/proto/src/reader/frontend.rs index fc6ceca..45cf0db 100644 --- a/proto/src/reader/frontend.rs +++ b/proto/src/reader/frontend.rs @@ -1,17 +1,18 @@ use crate::message::frontend::FrontendMessage; -use crate::message::primitive::config::pg_proto_config; -use crate::message::special::{CancelRequestData, SpecialMessage}; +use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData}; use crate::reader::oneway::OneWayProtoReader; use crate::reader::protoreader::ProtoReader; use crate::reader::utils::AsyncPeek; -use anyhow::anyhow; use async_trait::async_trait; +use tokio::io; use tokio::io::{AsyncBufRead, AsyncBufReadExt}; +use crate::message::primitive::data::MessageData; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; #[async_trait] pub trait FrontendProtoReader: OneWayProtoReader { - async fn peek_special_message(&mut self) -> anyhow::Result>; - async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>; + async fn peek_special_message(&mut self) -> Result, ProtoPeekError>; + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError>; } #[async_trait] @@ -19,7 +20,7 @@ impl FrontendProtoReader for ProtoReader where R: AsyncBufRead + Unpin + Send, { - async fn peek_special_message(&mut self) -> anyhow::Result> { + async fn peek_special_message(&mut self) -> Result, ProtoPeekError> { if let Some(cancel) = try_get_cancel_request(&mut self).await? { return Ok(Some(cancel)); } @@ -35,7 +36,7 @@ where Ok(None) } - async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> { + async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError> { Ok(match msg { SpecialMessage::CancelRequest(_) => consume_cancel_request(self), SpecialMessage::SSLRequest => consume_ssl_request(self), @@ -46,7 +47,7 @@ where async fn try_get_cancel_request( reader: &mut ProtoReader, -) -> anyhow::Result> +) -> Result, io::Error> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { @@ -83,7 +84,7 @@ where async fn try_get_ssl_request( reader: &mut ProtoReader, -) -> anyhow::Result> +) -> Result, io::Error> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { @@ -114,7 +115,7 @@ where async fn try_get_startup_message( reader: &mut ProtoReader, -) -> anyhow::Result> +) -> Result, ProtoPeekError> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { @@ -128,7 +129,10 @@ where return Ok(None); } if length > reader.msg_len_limit { - return Err(anyhow!("Message length is over the limit")); + return Err(ProtoPeekError::LengthOverflow { + limit: reader.msg_len_limit as usize, + actual: length as usize, + }); } let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); @@ -142,23 +146,29 @@ where return Ok(None); } - let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0; - + let data = StartupMessageData::deserialize(&data[4..])?; Ok(Some(SpecialMessage::StartupMessage(data))) } -async fn consume_startup_message(reader: &mut ProtoReader) -> anyhow::Result<()> +async fn consume_startup_message(reader: &mut ProtoReader) -> Result<(), ProtoConsumeError> where R: AsyncBufRead + AsyncPeek + Unpin + Send, { let mut header = [0u8; 4]; - if reader.inner.peek(&mut header).await? != 4 { - return Err(anyhow!("Invalid header peek length")); + let size = reader.inner.peek(&mut header).await?; + if size != 4 { + return Err(ProtoConsumeError::UnexpectedDataLength { + expected: 4, + actual: size + }) } let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize; if length < 8 { - return Err(anyhow!("Invalid startup message length")); + return Err(ProtoConsumeError::UnexpectedDataLength { + expected: 8, + actual: length + }) } reader.inner.consume(length); diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs index 3c85a0f..6f600c3 100644 --- a/proto/src/reader/mod.rs +++ b/proto/src/reader/mod.rs @@ -3,3 +3,4 @@ pub mod frontend; pub mod oneway; pub mod protoreader; mod utils; +pub mod errors; diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs index 16508cb..11937d7 100644 --- a/proto/src/reader/oneway.rs +++ b/proto/src/reader/oneway.rs @@ -3,13 +3,14 @@ use crate::reader::protoreader::ProtoReader; use crate::reader::utils::AsyncPeek; use async_trait::async_trait; use tokio::io::{AsyncBufRead, AsyncReadExt}; +use crate::reader::errors::ProtoReadError; #[async_trait] pub trait OneWayProtoReader where T: ProtoMessage, { - async fn read_proto(&mut self) -> anyhow::Result; + async fn read_proto(&mut self) -> Result; } #[async_trait] @@ -18,15 +19,18 @@ where R: AsyncBufRead + AsyncPeek + Unpin + Send, T: ProtoMessage, { - async fn read_proto(&mut self) -> anyhow::Result { + async fn read_proto(&mut self) -> Result { let variant = self.inner.read_u8().await?; let length = self.inner.read_i32().await?; if length < 4 { - return Err(anyhow::anyhow!("Invalid message length")); + return Err(ProtoReadError::InvalidLength(length)); } if length > self.msg_len_limit { - return Err(anyhow::anyhow!("Message length over limit")); + return Err(ProtoReadError::LengthOverflow { + limit: self.msg_len_limit as usize, + actual: length as usize, + }); } let mut data = vec![0u8; (length - 4) as usize]; diff --git a/proto/src/reader/utils.rs b/proto/src/reader/utils.rs index e4a70eb..0ca8f85 100644 --- a/proto/src/reader/utils.rs +++ b/proto/src/reader/utils.rs @@ -3,7 +3,7 @@ use tokio::io::{AsyncBufRead, AsyncBufReadExt}; #[async_trait] pub trait AsyncPeek { - async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result; + async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result; } #[async_trait] @@ -11,7 +11,7 @@ impl AsyncPeek for T where T: AsyncBufRead + Unpin + Send, { - async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result { + async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result { let filled = self.fill_buf().await?; if filled.len() >= buf.len() { buf.copy_from_slice(&filled[..buf.len()]);