refactor(proto): replace anyhow with thiserror in readers
This commit is contained in:
parent
58c69928a1
commit
da6410ce05
5 changed files with 80 additions and 23 deletions
42
proto/src/reader/errors.rs
Normal file
42
proto/src/reader/errors.rs
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io;
|
||||||
|
use crate::message::errors::{ProtoDeserializeError};
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ProtoReadError {
|
||||||
|
#[error("message has invalid length, got {0}")]
|
||||||
|
InvalidLength(i32),
|
||||||
|
#[error("message has too much data, got {actual}, limit is {limit}")]
|
||||||
|
LengthOverflow {
|
||||||
|
limit: usize,
|
||||||
|
actual: usize
|
||||||
|
},
|
||||||
|
#[error("reading from socket failed")]
|
||||||
|
Io(#[from] io::Error),
|
||||||
|
#[error("deserialization of inner data failed")]
|
||||||
|
Deserialize(#[from] ProtoDeserializeError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ProtoPeekError {
|
||||||
|
#[error("message has too much data, got {actual}, limit is {limit}")]
|
||||||
|
LengthOverflow {
|
||||||
|
limit: usize,
|
||||||
|
actual: usize
|
||||||
|
},
|
||||||
|
#[error("reading from socket failed")]
|
||||||
|
Io(#[from] io::Error),
|
||||||
|
#[error("deserialization of inner data failed")]
|
||||||
|
Deserialize(#[from] ProtoDeserializeError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ProtoConsumeError {
|
||||||
|
#[error("unexpected data length, expected {expected}, got {actual}")]
|
||||||
|
UnexpectedDataLength {
|
||||||
|
expected: usize,
|
||||||
|
actual: usize
|
||||||
|
},
|
||||||
|
#[error("reading from socket failed")]
|
||||||
|
Io(#[from] io::Error),
|
||||||
|
}
|
||||||
|
|
@ -1,17 +1,18 @@
|
||||||
use crate::message::frontend::FrontendMessage;
|
use crate::message::frontend::FrontendMessage;
|
||||||
use crate::message::primitive::config::pg_proto_config;
|
use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData};
|
||||||
use crate::message::special::{CancelRequestData, SpecialMessage};
|
|
||||||
use crate::reader::oneway::OneWayProtoReader;
|
use crate::reader::oneway::OneWayProtoReader;
|
||||||
use crate::reader::protoreader::ProtoReader;
|
use crate::reader::protoreader::ProtoReader;
|
||||||
use crate::reader::utils::AsyncPeek;
|
use crate::reader::utils::AsyncPeek;
|
||||||
use anyhow::anyhow;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use tokio::io;
|
||||||
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
|
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
|
||||||
|
use crate::message::primitive::data::MessageData;
|
||||||
|
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait FrontendProtoReader: OneWayProtoReader<FrontendMessage> {
|
pub trait FrontendProtoReader: OneWayProtoReader<FrontendMessage> {
|
||||||
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>>;
|
async fn peek_special_message(&mut self) -> Result<Option<SpecialMessage>, ProtoPeekError>;
|
||||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()>;
|
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -19,7 +20,7 @@ impl<R> FrontendProtoReader for ProtoReader<R>
|
||||||
where
|
where
|
||||||
R: AsyncBufRead + Unpin + Send,
|
R: AsyncBufRead + Unpin + Send,
|
||||||
{
|
{
|
||||||
async fn peek_special_message(&mut self) -> anyhow::Result<Option<SpecialMessage>> {
|
async fn peek_special_message(&mut self) -> Result<Option<SpecialMessage>, ProtoPeekError> {
|
||||||
if let Some(cancel) = try_get_cancel_request(&mut self).await? {
|
if let Some(cancel) = try_get_cancel_request(&mut self).await? {
|
||||||
return Ok(Some(cancel));
|
return Ok(Some(cancel));
|
||||||
}
|
}
|
||||||
|
|
@ -35,7 +36,7 @@ where
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> anyhow::Result<()> {
|
async fn consume_special_message(&mut self, msg: &SpecialMessage) -> Result<(), ProtoConsumeError> {
|
||||||
Ok(match msg {
|
Ok(match msg {
|
||||||
SpecialMessage::CancelRequest(_) => consume_cancel_request(self),
|
SpecialMessage::CancelRequest(_) => consume_cancel_request(self),
|
||||||
SpecialMessage::SSLRequest => consume_ssl_request(self),
|
SpecialMessage::SSLRequest => consume_ssl_request(self),
|
||||||
|
|
@ -46,7 +47,7 @@ where
|
||||||
|
|
||||||
async fn try_get_cancel_request<R>(
|
async fn try_get_cancel_request<R>(
|
||||||
reader: &mut ProtoReader<R>,
|
reader: &mut ProtoReader<R>,
|
||||||
) -> anyhow::Result<Option<SpecialMessage>>
|
) -> Result<Option<SpecialMessage>, io::Error>
|
||||||
where
|
where
|
||||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||||
{
|
{
|
||||||
|
|
@ -83,7 +84,7 @@ where
|
||||||
|
|
||||||
async fn try_get_ssl_request<R>(
|
async fn try_get_ssl_request<R>(
|
||||||
reader: &mut ProtoReader<R>,
|
reader: &mut ProtoReader<R>,
|
||||||
) -> anyhow::Result<Option<SpecialMessage>>
|
) -> Result<Option<SpecialMessage>, io::Error>
|
||||||
where
|
where
|
||||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||||
{
|
{
|
||||||
|
|
@ -114,7 +115,7 @@ where
|
||||||
|
|
||||||
async fn try_get_startup_message<R>(
|
async fn try_get_startup_message<R>(
|
||||||
reader: &mut ProtoReader<R>,
|
reader: &mut ProtoReader<R>,
|
||||||
) -> anyhow::Result<Option<SpecialMessage>>
|
) -> Result<Option<SpecialMessage>, ProtoPeekError>
|
||||||
where
|
where
|
||||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||||
{
|
{
|
||||||
|
|
@ -128,7 +129,10 @@ where
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
if length > reader.msg_len_limit {
|
if length > reader.msg_len_limit {
|
||||||
return Err(anyhow!("Message length is over the limit"));
|
return Err(ProtoPeekError::LengthOverflow {
|
||||||
|
limit: reader.msg_len_limit as usize,
|
||||||
|
actual: length as usize,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
|
let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]);
|
||||||
|
|
@ -142,23 +146,29 @@ where
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
let data = bincode::decode_from_slice(&data[4..], pg_proto_config())?.0;
|
let data = StartupMessageData::deserialize(&data[4..])?;
|
||||||
|
|
||||||
Ok(Some(SpecialMessage::StartupMessage(data)))
|
Ok(Some(SpecialMessage::StartupMessage(data)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> anyhow::Result<()>
|
async fn consume_startup_message<R>(reader: &mut ProtoReader<R>) -> Result<(), ProtoConsumeError>
|
||||||
where
|
where
|
||||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||||
{
|
{
|
||||||
let mut header = [0u8; 4];
|
let mut header = [0u8; 4];
|
||||||
if reader.inner.peek(&mut header).await? != 4 {
|
let size = reader.inner.peek(&mut header).await?;
|
||||||
return Err(anyhow!("Invalid header peek length"));
|
if size != 4 {
|
||||||
|
return Err(ProtoConsumeError::UnexpectedDataLength {
|
||||||
|
expected: 4,
|
||||||
|
actual: size
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
|
let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
|
||||||
if length < 8 {
|
if length < 8 {
|
||||||
return Err(anyhow!("Invalid startup message length"));
|
return Err(ProtoConsumeError::UnexpectedDataLength {
|
||||||
|
expected: 8,
|
||||||
|
actual: length
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
reader.inner.consume(length);
|
reader.inner.consume(length);
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,4 @@ pub mod frontend;
|
||||||
pub mod oneway;
|
pub mod oneway;
|
||||||
pub mod protoreader;
|
pub mod protoreader;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
pub mod errors;
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,14 @@ use crate::reader::protoreader::ProtoReader;
|
||||||
use crate::reader::utils::AsyncPeek;
|
use crate::reader::utils::AsyncPeek;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::io::{AsyncBufRead, AsyncReadExt};
|
use tokio::io::{AsyncBufRead, AsyncReadExt};
|
||||||
|
use crate::reader::errors::ProtoReadError;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait OneWayProtoReader<T>
|
pub trait OneWayProtoReader<T>
|
||||||
where
|
where
|
||||||
T: ProtoMessage,
|
T: ProtoMessage,
|
||||||
{
|
{
|
||||||
async fn read_proto(&mut self) -> anyhow::Result<T>;
|
async fn read_proto(&mut self) -> Result<T, ProtoReadError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -18,15 +19,18 @@ where
|
||||||
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
R: AsyncBufRead + AsyncPeek + Unpin + Send,
|
||||||
T: ProtoMessage,
|
T: ProtoMessage,
|
||||||
{
|
{
|
||||||
async fn read_proto(&mut self) -> anyhow::Result<T> {
|
async fn read_proto(&mut self) -> Result<T, ProtoReadError> {
|
||||||
let variant = self.inner.read_u8().await?;
|
let variant = self.inner.read_u8().await?;
|
||||||
let length = self.inner.read_i32().await?;
|
let length = self.inner.read_i32().await?;
|
||||||
|
|
||||||
if length < 4 {
|
if length < 4 {
|
||||||
return Err(anyhow::anyhow!("Invalid message length"));
|
return Err(ProtoReadError::InvalidLength(length));
|
||||||
}
|
}
|
||||||
if length > self.msg_len_limit {
|
if length > self.msg_len_limit {
|
||||||
return Err(anyhow::anyhow!("Message length over limit"));
|
return Err(ProtoReadError::LengthOverflow {
|
||||||
|
limit: self.msg_len_limit as usize,
|
||||||
|
actual: length as usize,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut data = vec![0u8; (length - 4) as usize];
|
let mut data = vec![0u8; (length - 4) as usize];
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ use tokio::io::{AsyncBufRead, AsyncBufReadExt};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait AsyncPeek {
|
pub trait AsyncPeek {
|
||||||
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
|
async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -11,7 +11,7 @@ impl<T> AsyncPeek for T
|
||||||
where
|
where
|
||||||
T: AsyncBufRead + Unpin + Send,
|
T: AsyncBufRead + Unpin + Send,
|
||||||
{
|
{
|
||||||
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
|
||||||
let filled = self.fill_buf().await?;
|
let filled = self.fill_buf().await?;
|
||||||
if filled.len() >= buf.len() {
|
if filled.len() >= buf.len() {
|
||||||
buf.copy_from_slice(&filled[..buf.len()]);
|
buf.copy_from_slice(&filled[..buf.len()]);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue