From b5405d757529825f6cbd179fbf0aa911e79dfc7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Mon, 5 Feb 2024 16:14:33 +0100 Subject: [PATCH] refactor: create trait for writing response from interpreter --- Cargo.lock | 10 +-- minisql/Cargo.toml | 2 + minisql/src/lib.rs | 1 + minisql/src/response_writer.rs | 36 ++++++++++ server/src/config.rs | 5 +- server/src/main.rs | 16 ++--- server/src/proto_wrapper.rs | 128 ++++++++++++++++----------------- 7 files changed, 118 insertions(+), 80 deletions(-) create mode 100644 minisql/src/response_writer.rs diff --git a/Cargo.lock b/Cargo.lock index 14ae99d..5c022e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,15 +67,15 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.76" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59d2a3357dde987206219e78ecfbbb6e8dad06cbb65292758d3270e6254f7355" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" [[package]] name = "async-trait" -version = "0.1.74" +version = "0.1.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2 1.0.78", "quote 1.0.35", @@ -283,6 +283,8 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" name = "minisql" version = "0.1.0" dependencies = [ + "anyhow", + "async-trait", "bimap", "proto", "serde", diff --git a/minisql/Cargo.toml b/minisql/Cargo.toml index f44d711..8ab1ba7 100644 --- a/minisql/Cargo.toml +++ b/minisql/Cargo.toml @@ -7,6 +7,8 @@ rust-version = "1.74" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow = "1.0.79" +async-trait = "0.1.77" bimap = { version = "0.6.3", features = ["serde"] } serde = { version = "1.0.196", features = ["derive"] } thiserror = "1.0.50" diff --git a/minisql/src/lib.rs b/minisql/src/lib.rs index baa4b4b..9d27314 100644 --- a/minisql/src/lib.rs +++ b/minisql/src/lib.rs @@ -2,6 +2,7 @@ mod error; mod internals; pub mod interpreter; pub mod operation; +pub mod response_writer; pub mod restricted_row; mod result; pub mod schema; diff --git a/minisql/src/response_writer.rs b/minisql/src/response_writer.rs new file mode 100644 index 0000000..15bd788 --- /dev/null +++ b/minisql/src/response_writer.rs @@ -0,0 +1,36 @@ +use crate::operation::ColumnSelection; +use crate::restricted_row::RestrictedRow; +use crate::schema::TableSchema; +use async_trait::async_trait; +use std::fmt; + +pub enum CompleteStatus { + Insert { oid: i32, rows: i32 }, + Delete(usize), + Select(usize), + CreateTable, + CreateIndex, +} + +impl fmt::Display for CompleteStatus { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + CompleteStatus::Insert { oid, rows } => write!(f, "INSERT {} {}", oid, rows), + CompleteStatus::Delete(rows) => write!(f, "DELETE {}", rows), + CompleteStatus::Select(rows) => write!(f, "SELECT {}", rows), + CompleteStatus::CreateTable => write!(f, "CREATE TABLE"), + CompleteStatus::CreateIndex => write!(f, "CREATE INDEX"), + } + } +} + +#[async_trait] +pub trait ResponseWriter { + async fn write_table_header( + &mut self, + table_schema: &TableSchema, + columns: &ColumnSelection, + ) -> anyhow::Result<()>; + async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>; + async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>; +} diff --git a/server/src/config.rs b/server/src/config.rs index efdb3c4..2dedcbc 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,6 +1,7 @@ use clap::Parser; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::PathBuf; +use std::time::Duration; const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); @@ -34,7 +35,7 @@ impl Configuration { } #[inline] - pub fn get_throttle(&self) -> Option { - self.throttle + pub fn get_throttle(&self) -> Option { + self.throttle.map(|d| Duration::from_millis(d)) } } diff --git a/server/src/main.rs b/server/src/main.rs index 74d35e4..e1e6f3d 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -8,6 +8,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, RwLock}; use minisql::interpreter::{Response, State}; +use minisql::response_writer::{CompleteStatus, ResponseWriter}; use parser::parse_and_validate; use proto::handshake::errors::ServerHandshakeError; use proto::handshake::request::HandshakeRequest; @@ -22,7 +23,7 @@ use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use crate::cancellation::ResetCancelToken; use crate::config::Configuration; use crate::persistence::state_to_file; -use crate::proto_wrapper::{CompleteStatus, ServerProto}; +use crate::proto_wrapper::ServerProtoWrapper; mod cancellation; mod config; @@ -87,8 +88,11 @@ async fn handle_stream( let response = HandshakeResponse::new("minisql", pid, key); let request = do_server_handshake(&mut writer, &mut reader, response).await; + let mut wrapped_writer = ServerProtoWrapper::new(writer, config.get_throttle()); let result = match request { - Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await, + Ok(req) => { + handle_connection(&mut reader, &mut wrapped_writer, req, state, token, config).await + } Err(ServerHandshakeError::IsCancelRequest(cancel)) => { handle_cancellation(cancel.pid, cancel.secret, &tokens).await } @@ -139,7 +143,7 @@ async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow: async fn handle_connection( reader: &mut R, - writer: &mut W, + writer: &mut ServerProtoWrapper, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, @@ -174,7 +178,7 @@ where } async fn handle_query( - writer: &mut W, + writer: &mut ServerProtoWrapper, state: &SharedDbState, query: String, token: &ResetCancelToken, @@ -223,10 +227,6 @@ where token.reset(); break; } - if let Some(throttle) = config.get_throttle() { - writer.flush().await?; - tokio::time::sleep(tokio::time::Duration::from_millis(throttle)).await; - } } writer diff --git a/server/src/proto_wrapper.rs b/server/src/proto_wrapper.rs index 1ddf59c..4d76ade 100644 --- a/server/src/proto_wrapper.rs +++ b/server/src/proto_wrapper.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use minisql::operation::ColumnSelection; +use minisql::response_writer::{CompleteStatus, ResponseWriter}; use minisql::restricted_row::RestrictedRow; use minisql::schema::{Column, TableSchema}; use proto::message::backend::{ @@ -9,71 +10,57 @@ use proto::message::backend::{ use proto::message::primitive::pglist::PgList; use proto::message::primitive::pgoid::PgOid; use proto::writer::backend::BackendProtoWriter; -use std::fmt; +use proto::writer::protowriter::ProtoFlush; +use std::io::Error; +use std::time::Duration; -pub enum CompleteStatus { - Insert { oid: i32, rows: i32 }, - Delete(usize), - Select(usize), - CreateTable, - CreateIndex, -} +pub struct ServerProtoWrapper(W, Option); -impl fmt::Display for CompleteStatus { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - CompleteStatus::Insert { oid, rows } => write!(f, "INSERT {} {}", oid, rows), - CompleteStatus::Delete(rows) => write!(f, "DELETE {}", rows), - CompleteStatus::Select(rows) => write!(f, "SELECT {}", rows), - CompleteStatus::CreateTable => write!(f, "CREATE TABLE"), - CompleteStatus::CreateIndex => write!(f, "CREATE INDEX"), - } - } -} - -#[async_trait] -pub trait ServerProto { - async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>; - async fn write_ready_for_query(&mut self) -> anyhow::Result<()>; - async fn write_empty_query(&mut self) -> anyhow::Result<()>; - async fn write_table_header( - &mut self, - table_schema: &TableSchema, - columns: &ColumnSelection, - ) -> anyhow::Result<()>; - async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>; - async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>; -} - -#[async_trait] -impl ServerProto for W +impl ServerProtoWrapper where - W: BackendProtoWriter + Send, + W: BackendProtoWriter + ProtoFlush + Send, { - async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> { - self.write_proto( - ErrorResponseData { - code: b'M', - message: format!("{error_message}\0").into(), - } - .into(), - ) - .await?; + pub fn new(writer: W, throttle: Option) -> Self { + Self(writer, throttle) + } + + pub async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> { + self.0 + .write_proto( + ErrorResponseData { + code: b'M', + message: format!("{error_message}\0").into(), + } + .into(), + ) + .await?; Ok(()) } - async fn write_ready_for_query(&mut self) -> anyhow::Result<()> { - self.write_proto(ReadyForQueryData { status: b'I' }.into()) + pub async fn write_ready_for_query(&mut self) -> anyhow::Result<()> { + self.0 + .write_proto(ReadyForQueryData { status: b'I' }.into()) .await?; Ok(()) } +} - async fn write_empty_query(&mut self) -> anyhow::Result<()> { - self.write_proto(BackendMessage::EmptyQueryResponse).await?; - Ok(()) +#[async_trait] +impl ProtoFlush for ServerProtoWrapper +where + W: ProtoFlush + Send, +{ + async fn flush(&mut self) -> Result<(), Error> { + self.0.flush().await } +} +#[async_trait] +impl ResponseWriter for ServerProtoWrapper +where + W: BackendProtoWriter + ProtoFlush + Send, +{ async fn write_table_header( &mut self, table_schema: &TableSchema, @@ -84,13 +71,14 @@ where .map(|column| column_to_description(table_schema, *column)) .collect::>>()?; - self.write_proto( - RowDescriptionData { - columns: columns.into(), - } - .into(), - ) - .await?; + self.0 + .write_proto( + RowDescriptionData { + columns: columns.into(), + } + .into(), + ) + .await?; Ok(()) } @@ -100,18 +88,26 @@ where .map(|(_, value)| value.as_text_bytes().into()) .collect::>>(); - self.write_proto(BackendMessage::DataRow(DataRowData { - columns: values.into(), - })) - .await?; + self.0 + .write_proto(BackendMessage::DataRow(DataRowData { + columns: values.into(), + })) + .await?; + + if let Some(throttle) = self.1 { + self.0.flush().await?; + tokio::time::sleep(throttle).await; + } + Ok(()) } async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> { - self.write_proto(BackendMessage::CommandComplete(CommandCompleteData { - tag: status.to_string().into(), - })) - .await?; + self.0 + .write_proto(BackendMessage::CommandComplete(CommandCompleteData { + tag: status.to_string().into(), + })) + .await?; Ok(()) } }