diff --git a/Cargo.lock b/Cargo.lock index 0471c9d..a284cd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -210,6 +210,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "gimli" version = "0.28.1" @@ -256,6 +267,7 @@ version = "0.1.0" dependencies = [ "bimap", "thiserror", + "tokio", ] [[package]] @@ -326,6 +338,12 @@ version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.70" @@ -354,6 +372,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -400,9 +448,11 @@ name = "server" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "clap", "minisql", "proto", + "rand", "tokio", ] diff --git a/minisql/Cargo.toml b/minisql/Cargo.toml index 6164b6b..22d0242 100644 --- a/minisql/Cargo.toml +++ b/minisql/Cargo.toml @@ -8,3 +8,4 @@ edition = "2021" [dependencies] bimap = "0.6.3" thiserror = "1.0.50" +tokio = { version = "1.35.1", features = ["sync"] } diff --git a/minisql/src/interpreter.rs b/minisql/src/interpreter.rs index 805923e..26f11ef 100644 --- a/minisql/src/interpreter.rs +++ b/minisql/src/interpreter.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use crate::error::Error; use crate::internals::row::ColumnPosition; use crate::schema::{TableName, TableSchema}; @@ -6,6 +7,7 @@ use crate::operation::{ColumnSelection, Condition, Operation}; use crate::result::DbResult; use crate::type_system::{DbType, IndexableValue, Value}; use bimap::BiMap; +use tokio::sync::Mutex; use crate::restricted_row::RestrictedRow; // Use `TablePosition` as index @@ -21,7 +23,7 @@ pub struct State { // #[derive(Debug)] pub enum Response<'a> { - Selected(&'a TableSchema, Box + 'a>), + Selected(&'a TableSchema, Arc + 'a + Send>>), Inserted, Deleted(usize), // how many were deleted TableCreated, @@ -49,7 +51,7 @@ impl std::fmt::Debug for Response<'_> { } impl State { - fn new() -> Self { + pub fn new() -> Self { Self { table_name_position_mapping: BiMap::new(), tables: vec![], @@ -100,7 +102,7 @@ impl State { let selected_rows = match maybe_condition { None => { let x = table.select_all_rows(selected_column_positions); - Box::new(x) as Box + 'a> + Arc::new(Mutex::new(x)) as Arc + 'a + Send>> }, Some(Condition::Eq(eq_column_name, value)) => { @@ -113,7 +115,7 @@ impl State { eq_column_position, value, )?; - Box::new(x) as Box + 'a> + Arc::new(Mutex::new(x)) as Arc + 'a + Send>> } }; diff --git a/minisql/src/operation.rs b/minisql/src/operation.rs index 3b060c9..e30a044 100644 --- a/minisql/src/operation.rs +++ b/minisql/src/operation.rs @@ -6,6 +6,7 @@ use crate::type_system::Value; // Perhaps consider factoring the table name out // and think of the operations as operating on a unique table. // TODO: `TableName` should be replaced by `TablePosition` +#[derive(Debug)] pub enum Operation { Select(TableName, ColumnSelection, Option), Insert(TableName, InsertionValues), @@ -18,11 +19,13 @@ pub enum Operation { pub type InsertionValues = Vec<(ColumnName, Value)>; +#[derive(Debug)] pub enum ColumnSelection { All, Columns(Vec), } +#[derive(Debug)] pub enum Condition { // And(Box, Box), // Or(Box, Box), diff --git a/minisql/src/schema.rs b/minisql/src/schema.rs index c9574ac..4f2bf54 100644 --- a/minisql/src/schema.rs +++ b/minisql/src/schema.rs @@ -20,7 +20,7 @@ pub type TableName = String; pub type ColumnName = String; impl TableSchema { - pub(crate) fn new(table_name: TableName, primary_key: ColumnPosition, column_name_position_map: Vec<(ColumnName, ColumnPosition)>, types: Vec) -> Self { + pub fn new(table_name: TableName, primary_key: ColumnPosition, column_name_position_map: Vec<(ColumnName, ColumnPosition)>, types: Vec) -> Self { let mut column_name_position_mapping: BiMap = BiMap::new(); for (column_name, column_position) in column_name_position_map { column_name_position_mapping.insert(column_name, column_position); diff --git a/server/Cargo.toml b/server/Cargo.toml index 0dd40ff..256b592 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,5 +11,5 @@ clap = { version = "4.4.18", features = ["derive"] } tokio = { version = "1.35.1", features = ["full"] } minisql = { path = "../minisql" } proto = { path = "../proto" } - - +async-trait = "0.1.74" +rand = "0.8.5" diff --git a/server/src/main.rs b/server/src/main.rs index f0b8084..d4a2e89 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,49 +1,121 @@ -mod config; +use std::collections::HashMap; +use std::sync::Arc; -use std::net::SocketAddr; use clap::Parser; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{Mutex, RwLock}; + +use minisql::interpreter::{Response, State}; +use proto::handshake::errors::ServerHandshakeError; +use proto::handshake::request::HandshakeRequest; use proto::handshake::response::HandshakeResponse; use proto::handshake::server::do_server_handshake; -use proto::message::backend::{ - BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, - ReadyForQueryData, RowDescriptionData, -}; use proto::message::frontend::FrontendMessage; -use proto::reader::oneway::OneWayProtoReader; +use proto::reader::frontend::FrontendProtoReader; use proto::reader::protoreader::ProtoReader; use proto::writer::backend::BackendProtoWriter; use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; -use tokio::io::{BufReader, BufWriter}; -use tokio::net::{TcpListener, TcpStream}; + +use crate::cancellation::ResetCancelToken; use crate::config::Configuration; +use crate::parser_stub::parse_query; +use crate::proto_wrapper::{CompleteStatus, ServerProto}; + +mod config; +mod proto_wrapper; +mod cancellation; +mod parser_stub; + +type TokenStore = Arc>>; +type DbState = Arc>; #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Configuration::parse(); + let state = Arc::new(RwLock::new(State::new())); + let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new())); + let addr = config.get_socket_address(); let listener = TcpListener::bind(&addr).await?; println!("Server started at {addr}"); loop { + let state = state.clone(); + let tokens = tokens.clone(); + let (pid, key) = random_pid_key(); + let (socket, _) = listener.accept().await?; println!("New client connected: {}", socket.peer_addr()?); tokio::spawn(async move { - let reason = handle_stream(socket).await; + let reason = handle_stream(socket, state, tokens).await; println!("Client disconnected: {reason:?}"); }); } } -async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { +async fn handle_stream(mut stream: TcpStream, state: DbState, tokens: TokenStore) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); - let response = HandshakeResponse::new("minisql", 123, 123); + // Create a token with random PID and key + let (pid, key, token) = create_token(&tokens).await?; - let request = do_server_handshake(&mut writer, &mut reader, response).await?; + // Handle handshake + let response = HandshakeResponse::new("minisql", pid, key); + let request = do_server_handshake(&mut writer, &mut reader, response).await; + let result = match request { + Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await, + Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, + Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), + }; + + // Release cancellation token + let mut tokens = tokens.lock().await; + tokens.remove(&(pid, key)); + + result +} + +fn random_pid_key() -> (i32, i32) { + let pid = rand::random::(); + let key = rand::random::(); + (pid, key) +} + +async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> { + let token = ResetCancelToken::new(); + let mut tokens = tokens.lock().await; + loop { + let pid_key = random_pid_key(); + if !tokens.contains_key(&pid_key) { + tokens.insert(pid_key, token.clone()); + + let (pid, key) = pid_key; + return Ok((pid, key, token)); + } + } +} + +async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> { + let tokens = tokens.lock().await; + let token = tokens.get(&(pid, key)); + match token { + Some(t) => t.cancel(), + None => return Err(anyhow::anyhow!("Invalid PID and Key cancel combination")), + } + + Ok(()) +} + +async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: DbState, token: ResetCancelToken) -> anyhow::Result<()> +where + R: FrontendProtoReader + Send, + W: BackendProtoWriter + ProtoFlush + Send, +{ println!("Handshake complete:\n{request:?}"); loop { @@ -57,17 +129,48 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { } FrontendMessage::Query(data) => { println!("Received Query: {:?}", data); - if data.query.as_str().contains("car") { - println!("Sending error message"); - send_error_response(&mut writer, "Car not found").await?; - } else if data.query.as_str().to_lowercase().contains("select") { - println!("Sending table"); - send_query_response(&mut writer).await?; - } else { - println!("Sending empty query"); - send_empty_query(&mut writer).await?; + let operation = parse_query(data.query.as_str()); + println!("Parsed query: {:?}", operation); + + let mut state = state.write().await; + let result = state.interpret(operation); + println!("Result: {:?}", result); + + match result { + Err(e) => { + writer.write_error_message(&format!("Error: {:?}", e)).await?; + } + Ok(res) => { + match res { + Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?, + Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?, + Response::Selected(schema, rows) => { + let mut rows = rows.lock().await; + let first_row = rows.next(); + match first_row { + Some(row) => { + writer.write_table_header(&schema, &row).await?; + writer.write_table_row(&row).await?; + + let mut sent_rows = 1; + while let Some(row) = rows.next() { + writer.write_table_row(&row).await?; + sent_rows += 1; + } + + writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; + } + None => { + writer.write_command_complete(CompleteStatus::Select(0)).await?; + } + } + } + _ => {} + } + } } - send_ready_for_query(&mut writer).await?; + + writer.write_ready_for_query().await?; } } writer.flush().await?; @@ -75,118 +178,3 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { Ok(()) } - -async fn send_error_response( - writer: &mut impl BackendProtoWriter, - error_message: &str, -) -> anyhow::Result<()> { - writer - .write_proto( - ErrorResponseData { - code: b'M', - message: error_message.to_string().into(), - } - .into(), - ) - .await?; - - Ok(()) -} - -async fn send_ready_for_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - writer - .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) - .await?; - - Ok(()) -} - -async fn send_empty_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - writer - .write_proto(BackendMessage::EmptyQueryResponse) - .await?; - - Ok(()) -} - -async fn send_row_description(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - let columns = vec![ - ColumnDescription { - name: "id".to_string().into(), - table_oid: 123, - column_index: 1, - type_oid: 23, - type_size: 4, - type_modifier: -1, - format_code: 0, - }, - ColumnDescription { - name: "argument".to_string().into(), - table_oid: 123, - column_index: 2, - type_oid: 23, - type_size: 4, - type_modifier: -1, - format_code: 0, - }, - ColumnDescription { - name: "description".to_string().into(), - table_oid: 123, - column_index: 3, - type_oid: 1043, - type_size: 32, - type_modifier: -1, - format_code: 0, - }, - ]; - - writer - .write_proto( - RowDescriptionData { - columns: columns.into(), - } - .into(), - ) - .await?; - - Ok(()) -} - -async fn send_query_response(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - send_row_description(writer).await?; - - write_row(writer, b"0", b"1337", b"auto").await?; - write_row(writer, b"1", b"69", b"bus").await?; - write_row(writer, b"2", b"420", b"kolo").await?; - - writer - .write_proto( - CommandCompleteData { - tag: "SELECT 3".to_string().into(), - } - .into(), - ) - .await?; - - Ok(()) -} - -async fn write_row( - writer: &mut impl BackendProtoWriter, - first: &[u8], - second: &[u8], - third: &[u8], -) -> anyhow::Result<()> { - let row_data = vec![ - first.to_vec().into(), - second.to_vec().into(), - third.to_vec().into(), - ] - .into(); - - writer - .write_proto(DataRowData { columns: row_data }.into()) - .await?; - - Ok(()) -} diff --git a/server/src/parser_stub.rs b/server/src/parser_stub.rs new file mode 100644 index 0000000..c8f3e6a --- /dev/null +++ b/server/src/parser_stub.rs @@ -0,0 +1,63 @@ +use minisql::operation::{ColumnSelection, Operation}; +use minisql::schema::TableSchema; +use minisql::type_system::{DbType, IndexableValue, Value}; + +const TABLE_NAME: &'static str = "tablus"; + +static mut ID_COUNTER: u64 = 0; + +pub fn parse_query(query: &str) -> Operation { + if query.contains("select") { + if query.contains("*") { + Operation::Select(TABLE_NAME.to_string(), ColumnSelection::All, None) + } else { + Operation::Select(TABLE_NAME.to_string(), ColumnSelection::Columns(vec![ + "name".to_string(), + "price".to_string(), + ]), None) + } + } else if query.contains("insert") { + + let id = unsafe { + ID_COUNTER += 1; + ID_COUNTER + }; + + let rand_rak = rand::random::(); + let rand_price = rand::random::(); + + Operation::Insert(TABLE_NAME.to_string(), vec![ + ("id".to_string(), Value::Indexable(IndexableValue::Uuid(id))), + ("name".to_string(), Value::Indexable(IndexableValue::String(format!("Car {}", rand_rak)))), + ("price".to_string(), Value::Number(rand_price)), + ("mileage".to_string(), Value::Indexable(IndexableValue::Int(1234))), + ]) + } else if query.contains("delete") { + Operation::Delete(TABLE_NAME.to_string(), None) + } else if query.contains("create table") { + Operation::CreateTable(TABLE_NAME.to_string(), get_cars_schema()) + } else if query.contains("create index") { + Operation::CreateIndex(TABLE_NAME.to_string(), "price".to_string()) + } else { + panic!("Unknown query: {}", query); + } +} + +fn get_cars_schema() -> TableSchema { + TableSchema::new( + "cars".to_string(), + 0, + vec![ + ("id".to_string(), 0), + ("name".to_string(), 1), + ("price".to_string(), 2), + ("mileage".to_string(), 3), + ], + vec![ + DbType::Uuid, + DbType::String, + DbType::Number, + DbType::Int, + ] + ) +} \ No newline at end of file diff --git a/server/src/proto_wrapper.rs b/server/src/proto_wrapper.rs new file mode 100644 index 0000000..3415255 --- /dev/null +++ b/server/src/proto_wrapper.rs @@ -0,0 +1,104 @@ +use async_trait::async_trait; +use minisql::restricted_row::RestrictedRow; +use minisql::schema::TableSchema; +use minisql::type_system::{Value}; +use proto::message::backend::{BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, ReadyForQueryData, RowDescriptionData}; +use proto::message::primitive::pglist::PgList; +use proto::writer::backend::BackendProtoWriter; + +pub enum CompleteStatus { + Insert { + oid: i32, + rows: i32, + }, + Delete(usize), + Select(usize), +} + +impl CompleteStatus { + fn to_string(&self) -> String { + match self { + CompleteStatus::Insert { oid, rows } => format!("INSERT {} {}", oid, rows), + CompleteStatus::Delete(rows) => format!("DELETE {}", rows), + CompleteStatus::Select(rows) => format!("SELECT {}", rows), + } + } +} + +#[async_trait] +pub trait ServerProto { + async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>; + async fn write_ready_for_query(&mut self) -> anyhow::Result<()>; + async fn write_empty_query(&mut self) -> anyhow::Result<()>; + async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()>; + async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>; + async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>; +} + +#[async_trait] +impl ServerProto for W where W: BackendProtoWriter + Send { + async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> { + self.write_proto(ErrorResponseData { + code: b'M', + message: format!("{error_message}\0").into(), + }.into()).await?; + + Ok(()) + } + + async fn write_ready_for_query(&mut self) -> anyhow::Result<()> { + self.write_proto(ReadyForQueryData { status: b'I' }.into()).await?; + Ok(()) + } + + async fn write_empty_query(&mut self) -> anyhow::Result<()> { + self.write_proto(BackendMessage::EmptyQueryResponse).await?; + Ok(()) + } + + async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()> { + let columns = row.iter() + .map(|(index, value)| value_to_column_description(table_schema, value, index)) + .collect::>>()?; + + self.write_proto(RowDescriptionData { columns: columns.into() }.into()).await?; + Ok(()) + } + + async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()> { + let values = row.iter() + .map(|(_, value)| value.as_text_bytes().into()) + .collect::>>(); + + self.write_proto(BackendMessage::DataRow(DataRowData { + columns: values.into(), + })).await?; + Ok(()) + } + + async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> { + self.write_proto(BackendMessage::CommandComplete(CommandCompleteData { + tag: status.to_string().into(), + })).await?; + Ok(()) + } +} + +fn value_to_column_description(schema: &TableSchema, value: &Value, index: &usize) -> anyhow::Result { + let name = schema.column_name_from_column_position(*index)?; + + let table_oid = schema.table_name().as_bytes().as_ptr() as i32; + let column_index = (*index).try_into()?; + let type_oid = value.type_oid(); + let type_size = value.type_size(); + + Ok(ColumnDescription { + name: name.to_string().into(), + table_oid, + column_index, + type_oid, + type_size, + type_modifier: -1, + format_code: 0, // text format + }) +}