refactor(proto): replace anyhow with thiserror in readers

This commit is contained in:
Jindřich Moravec 2023-12-15 16:21:51 +01:00
parent 58c69928a1
commit da6410ce05
5 changed files with 80 additions and 23 deletions

View file

@ -0,0 +1,42 @@
use thiserror::Error;
use tokio::io;
use crate::message::errors::{ProtoDeserializeError};
#[derive(Debug, Error)]
pub enum ProtoReadError {
#[error("message has invalid length, got {0}")]
InvalidLength(i32),
#[error("message has too much data, got {actual}, limit is {limit}")]
LengthOverflow {
limit: usize,
actual: usize
},
#[error("reading from socket failed")]
Io(#[from] io::Error),
#[error("deserialization of inner data failed")]
Deserialize(#[from] ProtoDeserializeError),
}
#[derive(Debug, Error)]
pub enum ProtoPeekError {
#[error("message has too much data, got {actual}, limit is {limit}")]
LengthOverflow {
limit: usize,
actual: usize
},
#[error("reading from socket failed")]
Io(#[from] io::Error),
#[error("deserialization of inner data failed")]
Deserialize(#[from] ProtoDeserializeError),
}
#[derive(Debug, Error)]
pub enum ProtoConsumeError {
#[error("unexpected data length, expected {expected}, got {actual}")]
UnexpectedDataLength {
expected: usize,
actual: usize
},
#[error("reading from socket failed")]
Io(#[from] io::Error),
}

View file

@ -1,17 +1,18 @@
use crate::message::frontend::FrontendMessage;
use crate::message::primitive::config::pg_proto_config;
use crate::message::special::{CancelRequestData, SpecialMessage};
use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData};
use crate::reader::oneway::OneWayProtoReader;
use crate::reader::protoreader::ProtoReader;
use crate::reader::utils::AsyncPeek;
use anyhow::anyhow;
use async_trait::async_trait;
use tokio::io;
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
use crate::message::primitive::data::MessageData;
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError};
#[async_trait]
pub trait FrontendProtoReader: OneWayProtoReader<FrontendMessage> {
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>>;
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>;
async fn peek_special_message(&mut self) -> Result<Option<SpecialMessage>, ProtoPeekError>;
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError>;
}
#[async_trait]
@ -19,7 +20,7 @@ impl<R> FrontendProtoReader for ProtoReader<R>
where
R: AsyncBufRead + Unpin + Send,
{
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>> {
async fn peek_special_message(&mut self) -> Result<Option<SpecialMessage>, ProtoPeekError> {
if let Some(cancel) = try_get_cancel_request(&mut self).await? {
return Ok(Some(cancel));
}
@ -35,7 +36,7 @@ where
Ok(None)
}
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> {
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError> {
Ok(match msg {
SpecialMessage::CancelRequest(_) => consume_cancel_request(self),
SpecialMessage::SSLRequest => consume_ssl_request(self),
@ -46,7 +47,7 @@ where
async fn try_get_cancel_request<R>(
reader: &mut ProtoReader<R>,
) -> anyhow::Result<Option<SpecialMessage>>
) -> Result<Option<SpecialMessage>, io::Error>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
@ -83,7 +84,7 @@ where
async fn try_get_ssl_request<R>(
reader: &mut ProtoReader<R>,
) -> anyhow::Result<Option<SpecialMessage>>
) -> Result<Option<SpecialMessage>, io::Error>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
@ -114,7 +115,7 @@ where
async fn try_get_startup_message<R>(
reader: &mut ProtoReader<R>,
) -> anyhow::Result<Option<SpecialMessage>>
) -> Result<Option<SpecialMessage>, ProtoPeekError>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
@ -128,7 +129,10 @@ where
return Ok(None);
}
if length > reader.msg_len_limit {
return Err(anyhow!("Message length is over the limit"));
return Err(ProtoPeekError::LengthOverflow {
limit: reader.msg_len_limit as usize,
actual: length as usize,
});
}
let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
@ -142,23 +146,29 @@ where
return Ok(None);
}
let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0;
let data = StartupMessageData::deserialize(&data[4..])?;
Ok(Some(SpecialMessage::StartupMessage(data)))
}
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> anyhow::Result<()>
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> Result<(), ProtoConsumeError>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
let mut header = [0u8; 4];
if reader.inner.peek(&mut header).await? != 4 {
return Err(anyhow!("Invalid header peek length"));
let size = reader.inner.peek(&mut header).await?;
if size != 4 {
return Err(ProtoConsumeError::UnexpectedDataLength {
expected: 4,
actual: size
})
}
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
if length < 8 {
return Err(anyhow!("Invalid startup message length"));
return Err(ProtoConsumeError::UnexpectedDataLength {
expected: 8,
actual: length
})
}
reader.inner.consume(length);

View file

@ -3,3 +3,4 @@ pub mod frontend;
pub mod oneway;
pub mod protoreader;
mod utils;
pub mod errors;

View file

@ -3,13 +3,14 @@ use crate::reader::protoreader::ProtoReader;
use crate::reader::utils::AsyncPeek;
use async_trait::async_trait;
use tokio::io::{AsyncBufRead, AsyncReadExt};
use crate::reader::errors::ProtoReadError;
#[async_trait]
pub trait OneWayProtoReader<T>
where
T: ProtoMessage,
{
async fn read_proto(&mut self) -> anyhow::Result<T>;
async fn read_proto(&mut self) -> Result<T, ProtoReadError>;
}
#[async_trait]
@ -18,15 +19,18 @@ where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
T: ProtoMessage,
{
async fn read_proto(&mut self) -> anyhow::Result<T> {
async fn read_proto(&mut self) -> Result<T, ProtoReadError> {
let variant = self.inner.read_u8().await?;
let length = self.inner.read_i32().await?;
if length < 4 {
return Err(anyhow::anyhow!("Invalid message length"));
return Err(ProtoReadError::InvalidLength(length));
}
if length > self.msg_len_limit {
return Err(anyhow::anyhow!("Message length over limit"));
return Err(ProtoReadError::LengthOverflow {
limit: self.msg_len_limit as usize,
actual: length as usize,
});
}
let mut data = vec![0u8; (length - 4) as usize];

View file

@ -3,7 +3,7 @@ use tokio::io::{AsyncBufRead, AsyncBufReadExt};
#[async_trait]
pub trait AsyncPeek {
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize>;
}
#[async_trait]
@ -11,7 +11,7 @@ impl<T> AsyncPeek for T
where
T: AsyncBufRead + Unpin + Send,
{
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
let filled = self.fill_buf().await?;
if filled.len() >= buf.len() {
buf.copy_from_slice(&filled[..buf.len()]);