feat(proto): add generic proto reader

This commit is contained in:
Jindřich Moravec 2023-12-11 16:53:21 +01:00
parent 67af05ea42
commit 413e0216e3
5 changed files with 87 additions and 0 deletions

View file

@ -1,2 +1,3 @@
pub mod message;
pub mod reader;
pub mod writer;

3
proto/src/reader/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod oneway;
pub mod protoreader;
mod utils;

View file

@ -0,0 +1,37 @@
use crate::message::proto_message::ProtoMessage;
use crate::reader::protoreader::ProtoReader;
use crate::reader::utils::AsyncPeek;
use async_trait::async_trait;
use tokio::io::{AsyncBufRead, AsyncReadExt};
#[async_trait]
pub trait OneWayProtoReader<T>
where
T: ProtoMessage,
{
async fn read_proto(&mut self) -> anyhow::Result<T>;
}
#[async_trait]
impl<R, T> OneWayProtoReader<T> for ProtoReader<R>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
T: ProtoMessage,
{
async fn read_proto(&mut self) -> anyhow::Result<T> {
let variant = self.inner.read_u8().await?;
let length = self.inner.read_i32().await?;
if length < 4 {
return Err(anyhow::anyhow!("Invalid message length"));
}
if length > self.msg_len_limit {
return Err(anyhow::anyhow!("Message length over limit"));
}
let mut data = vec![0u8; (length - 4) as usize];
self.inner.read_exact(&mut data).await?;
T::deserialize(variant, &data)
}
}

View file

@ -0,0 +1,22 @@
use crate::reader::utils::AsyncPeek;
use tokio::io::AsyncBufRead;
pub struct ProtoReader<R>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
pub(super) inner: R,
pub(super) msg_len_limit: i32,
}
impl<R> ProtoReader<R>
where
R: AsyncBufRead + AsyncPeek + Unpin + Send,
{
pub fn new(reader: R, msg_len_limit: i32) -> ProtoReader<R> {
ProtoReader {
inner: reader,
msg_len_limit,
}
}
}

24
proto/src/reader/utils.rs Normal file
View file

@ -0,0 +1,24 @@
use async_trait::async_trait;
use tokio::io::{AsyncBufRead, AsyncBufReadExt};
#[async_trait]
pub trait AsyncPeek {
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
}
#[async_trait]
impl<T> AsyncPeek for T
where
T: AsyncBufRead + Unpin + Send,
{
async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let filled = self.fill_buf().await?;
if filled.len() >= buf.len() {
buf.copy_from_slice(&filled[..buf.len()]);
Ok(buf.len())
} else {
buf[..filled.len()].copy_from_slice(filled);
Ok(filled.len())
}
}
}