refactor(proto): replace anyhow with thiserror in readers
This commit is contained in:
parent
58c69928a1
commit
da6410ce05
5 changed files with 80 additions and 23 deletions
42
proto/src/reader/errors.rs
Normal file
42
proto/src/reader/errors.rs
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
use thiserror::Error;
|
||||
use tokio::io;
|
||||
use crate::message::errors::{ProtoDeserializeError};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProtoReadError {
|
||||
#[error("message has invalid length, got {0}")]
|
||||
InvalidLength(i32),
|
||||
#[error("message has too much data, got {actual}, limit is {limit}")]
|
||||
LengthOverflow {
|
||||
limit: usize,
|
||||
actual: usize
|
||||
},
|
||||
#[error("reading from socket failed")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("deserialization of inner data failed")]
|
||||
Deserialize(#[from] ProtoDeserializeError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProtoPeekError {
|
||||
#[error("message has too much data, got {actual}, limit is {limit}")]
|
||||
LengthOverflow {
|
||||
limit: usize,
|
||||
actual: usize
|
||||
},
|
||||
#[error("reading from socket failed")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("deserialization of inner data failed")]
|
||||
Deserialize(#[from] ProtoDeserializeError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProtoConsumeError {
|
||||
#[error("unexpected data length, expected {expected}, got {actual}")]
|
||||
UnexpectedDataLength {
|
||||
expected: usize,
|
||||
actual: usize
|
||||
},
|
||||
#[error("reading from socket failed")]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
|
@ -1,17 +1,18 @@
|
|||
use crate::message::frontend::FrontendMessage;
|
||||
use crate::message::primitive::config::pg_proto_config;
|
||||
use crate::message::special::{CancelRequestData, SpecialMessage};
|
||||
use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData};
|
||||
use crate::reader::oneway::OneWayProtoReader;
|
||||
use crate::reader::protoreader::ProtoReader;
|
||||
use crate::reader::utils::AsyncPeek;
|
||||
use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io;
|
||||
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
|
||||
use crate::message::primitive::data::MessageData;
|
||||
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError};
|
||||
|
||||
#[async_trait]
|
||||
pub trait FrontendProtoReader: OneWayProtoReader<FrontendMessage> {
|
||||
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>>;
|
||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>;
|
||||
async fn peek_special_message(&mut self) -> Result<Option<SpecialMessage>, ProtoPeekError>;
|
||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -19,7 +20,7 @@ impl<R> FrontendProtoReader for ProtoReader<R>
|
|||
where
|
||||
R: AsyncBufRead + Unpin + Send,
|
||||
{
|
||||
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>> {
|
||||
async fn peek_special_message(&mut self) -> Result<Option<SpecialMessage>, ProtoPeekError> {
|
||||
if let Some(cancel) = try_get_cancel_request(&mut self).await? {
|
||||
return Ok(Some(cancel));
|
||||
}
|
||||
|
|
@ -35,7 +36,7 @@ where
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> {
|
||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError> {
|
||||
Ok(match msg {
|
||||
SpecialMessage::CancelRequest(_) => consume_cancel_request(self),
|
||||
SpecialMessage::SSLRequest => consume_ssl_request(self),
|
||||
|
|
@ -46,7 +47,7 @@ where
|
|||
|
||||
async fn try_get_cancel_request<R>(
|
||||
reader: &mut ProtoReader<R>,
|
||||
) -> anyhow::Result<Option<SpecialMessage>>
|
||||
) -> Result<Option<SpecialMessage>, io::Error>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
|
|
@ -83,7 +84,7 @@ where
|
|||
|
||||
async fn try_get_ssl_request<R>(
|
||||
reader: &mut ProtoReader<R>,
|
||||
) -> anyhow::Result<Option<SpecialMessage>>
|
||||
) -> Result<Option<SpecialMessage>, io::Error>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
|
|
@ -114,7 +115,7 @@ where
|
|||
|
||||
async fn try_get_startup_message<R>(
|
||||
reader: &mut ProtoReader<R>,
|
||||
) -> anyhow::Result<Option<SpecialMessage>>
|
||||
) -> Result<Option<SpecialMessage>, ProtoPeekError>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
|
|
@ -128,7 +129,10 @@ where
|
|||
return Ok(None);
|
||||
}
|
||||
if length > reader.msg_len_limit {
|
||||
return Err(anyhow!("Message length is over the limit"));
|
||||
return Err(ProtoPeekError::LengthOverflow {
|
||||
limit: reader.msg_len_limit as usize,
|
||||
actual: length as usize,
|
||||
});
|
||||
}
|
||||
|
||||
let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
|
||||
|
|
@ -142,23 +146,29 @@ where
|
|||
return Ok(None);
|
||||
}
|
||||
|
||||
let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0;
|
||||
|
||||
let data = StartupMessageData::deserialize(&data[4..])?;
|
||||
Ok(Some(SpecialMessage::StartupMessage(data)))
|
||||
}
|
||||
|
||||
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> anyhow::Result<()>
|
||||
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> Result<(), ProtoConsumeError>
|
||||
where
|
||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
{
|
||||
let mut header = [0u8; 4];
|
||||
if reader.inner.peek(&mut header).await? != 4 {
|
||||
return Err(anyhow!("Invalid header peek length"));
|
||||
let size = reader.inner.peek(&mut header).await?;
|
||||
if size != 4 {
|
||||
return Err(ProtoConsumeError::UnexpectedDataLength {
|
||||
expected: 4,
|
||||
actual: size
|
||||
})
|
||||
}
|
||||
|
||||
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
|
||||
if length < 8 {
|
||||
return Err(anyhow!("Invalid startup message length"));
|
||||
return Err(ProtoConsumeError::UnexpectedDataLength {
|
||||
expected: 8,
|
||||
actual: length
|
||||
})
|
||||
}
|
||||
|
||||
reader.inner.consume(length);
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ pub mod frontend;
|
|||
pub mod oneway;
|
||||
pub mod protoreader;
|
||||
mod utils;
|
||||
pub mod errors;
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ use crate::reader::protoreader::ProtoReader;
|
|||
use crate::reader::utils::AsyncPeek;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncBufRead, AsyncReadExt};
|
||||
use crate::reader::errors::ProtoReadError;
|
||||
|
||||
#[async_trait]
|
||||
pub trait OneWayProtoReader<T>
|
||||
where
|
||||
T: ProtoMessage,
|
||||
{
|
||||
async fn read_proto(&mut self) -> anyhow::Result<T>;
|
||||
async fn read_proto(&mut self) -> Result<T, ProtoReadError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -18,15 +19,18 @@ where
|
|||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||
T: ProtoMessage,
|
||||
{
|
||||
async fn read_proto(&mut self) -> anyhow::Result<T> {
|
||||
async fn read_proto(&mut self) -> Result<T, ProtoReadError> {
|
||||
let variant = self.inner.read_u8().await?;
|
||||
let length = self.inner.read_i32().await?;
|
||||
|
||||
if length < 4 {
|
||||
return Err(anyhow::anyhow!("Invalid message length"));
|
||||
return Err(ProtoReadError::InvalidLength(length));
|
||||
}
|
||||
if length > self.msg_len_limit {
|
||||
return Err(anyhow::anyhow!("Message length over limit"));
|
||||
return Err(ProtoReadError::LengthOverflow {
|
||||
limit: self.msg_len_limit as usize,
|
||||
actual: length as usize,
|
||||
});
|
||||
}
|
||||
|
||||
let mut data = vec![0u8; (length - 4) as usize];
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use tokio::io::{AsyncBufRead, AsyncBufReadExt};
|
|||
|
||||
#[async_trait]
|
||||
pub trait AsyncPeek {
|
||||
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
|
||||
async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -11,7 +11,7 @@ impl<T> AsyncPeek for T
|
|||
where
|
||||
T: AsyncBufRead + Unpin + Send,
|
||||
{
|
||||
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
|
||||
let filled = self.fill_buf().await?;
|
||||
if filled.len() >= buf.len() {
|
||||
buf.copy_from_slice(&filled[..buf.len()]);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue