From 413e0216e367b47741523e95a16bbab3537680c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 11 Dec 2023 16:53:21 +0100 Subject: [PATCH] feat(proto): add generic proto reader --- proto/src/lib.rs | 1 + proto/src/reader/mod.rs | 3 +++ proto/src/reader/oneway.rs | 37 +++++++++++++++++++++++++++++++++ proto/src/reader/protoreader.rs | 22 ++++++++++++++++++++ proto/src/reader/utils.rs | 24 +++++++++++++++++++++ 5 files changed, 87 insertions(+) create mode 100644 proto/src/reader/mod.rs create mode 100644 proto/src/reader/oneway.rs create mode 100644 proto/src/reader/protoreader.rs create mode 100644 proto/src/reader/utils.rs diff --git a/proto/src/lib.rs b/proto/src/lib.rs index c65ece0..c964d21 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -1,2 +1,3 @@ pub mod message; +pub mod reader; pub mod writer; diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs new file mode 100644 index 0000000..450e7d4 --- /dev/null +++ b/proto/src/reader/mod.rs @@ -0,0 +1,3 @@ +pub mod oneway; +pub mod protoreader; +mod utils; diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs new file mode 100644 index 0000000..4cca72b --- /dev/null +++ b/proto/src/reader/oneway.rs @@ -0,0 +1,37 @@ +use crate::message::proto_message::ProtoMessage; +use crate::reader::protoreader::ProtoReader; +use crate::reader::utils::AsyncPeek; +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncReadExt}; + +#[async_trait] +pub trait OneWayProtoReader +where + T: ProtoMessage, +{ + async fn read_proto(&mut self) -> anyhow::Result; +} + +#[async_trait] +impl OneWayProtoReader for ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, + T: ProtoMessage, +{ + async fn read_proto(&mut self) -> anyhow::Result { + let variant = self.inner.read_u8().await?; + let length = self.inner.read_i32().await?; + + if length < 4 { + return Err(anyhow::anyhow!("Invalid message length")); + } + if length > self.msg_len_limit { + return Err(anyhow::anyhow!("Message length over limit")); + } + + let mut data = vec![0u8; (length - 4) as usize]; + self.inner.read_exact(&mut data).await?; + + T::deserialize(variant, &data) + } +} diff --git a/proto/src/reader/protoreader.rs b/proto/src/reader/protoreader.rs new file mode 100644 index 0000000..5e3f572 --- /dev/null +++ b/proto/src/reader/protoreader.rs @@ -0,0 +1,22 @@ +use crate::reader::utils::AsyncPeek; +use tokio::io::AsyncBufRead; + +pub struct ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + pub(super) inner: R, + pub(super) msg_len_limit: i32, +} + +impl ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + pub fn new(reader: R, msg_len_limit: i32) -> ProtoReader { + ProtoReader { + inner: reader, + msg_len_limit, + } + } +} diff --git a/proto/src/reader/utils.rs b/proto/src/reader/utils.rs new file mode 100644 index 0000000..e4a70eb --- /dev/null +++ b/proto/src/reader/utils.rs @@ -0,0 +1,24 @@ +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +#[async_trait] +pub trait AsyncPeek { + async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result; +} + +#[async_trait] +impl AsyncPeek for T +where + T: AsyncBufRead + Unpin + Send, +{ + async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result { + let filled = self.fill_buf().await?; + if filled.len() >= buf.len() { + buf.copy_from_slice(&filled[..buf.len()]); + Ok(buf.len()) + } else { + buf[..filled.len()].copy_from_slice(filled); + Ok(filled.len()) + } + } +}