diff --git a/Cargo.lock b/Cargo.lock index 428e5cb..514f99c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,54 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "anstream" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + +[[package]] +name = "anstyle-parse" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "anyhow" version = "1.0.76" @@ -107,6 +155,46 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clap" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +dependencies = [ + "heck", + "proc-macro2 1.0.70", + "quote 1.0.33", + "syn 2.0.41", +] + +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + [[package]] name = "client" version = "0.1.0" @@ -118,12 +206,35 @@ dependencies = [ "tokio", ] +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "gimli" version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.3.3" @@ -163,6 +274,7 @@ name = "minisql" version = "0.1.0" dependencies = [ "bimap", + "thiserror", ] [[package]] @@ -182,7 +294,7 @@ checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -256,7 +368,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -276,6 +388,12 @@ version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "0.4.30" @@ -322,6 +440,36 @@ dependencies = [ "proc-macro2 1.0.70", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -368,9 +516,12 @@ name = "server" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "clap", "minisql", "parser", "proto", + "rand", "tokio", ] @@ -396,9 +547,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "syn" version = "0.15.44" @@ -457,7 +614,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -483,6 +640,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "version_check" version = "0.1.5" @@ -507,7 +670,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -516,13 +688,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -531,38 +718,80 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" diff --git a/minisql/Cargo.toml b/minisql/Cargo.toml index 1a108f1..6164b6b 100644 --- a/minisql/Cargo.toml +++ b/minisql/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" [dependencies] bimap = "0.6.3" +thiserror = "1.0.50" diff --git a/minisql/src/error.rs b/minisql/src/error.rs index 6660f9d..451ad7a 100644 --- a/minisql/src/error.rs +++ b/minisql/src/error.rs @@ -1,11 +1,33 @@ +use std::num::{ParseFloatError, ParseIntError}; +use std::str::Utf8Error; +use thiserror::Error; use crate::internals::row::ColumnPosition; use crate::schema::{ColumnName, TableName}; use crate::type_system::{DbType, Uuid, Value}; -#[derive(Debug)] +#[derive(Debug, Error)] pub enum Error { + #[error("column position {1} of table {0} does not exist")] ColumnPositionDoesNotExist(TableName, ColumnPosition), + #[error("column {1} of table {0} has unexpected type {2:?} and value {3:?}")] ValueDoesNotMatchExpectedType(TableName, ColumnName, DbType, Value), + #[error("table {0} already contains row with id {1}")] AttemptingToInsertAlreadyPresentId(TableName, Uuid), + #[error("table {0} cannot be indexed on column {1}")] AttemptToIndexNonIndexableColumn(TableName, ColumnName), } + +#[derive(Debug, Error)] +pub enum TypeConversionError { + #[error("failed to decode bytes to string")] + TextDecodeFailed(#[from] Utf8Error), + #[error("failed to parse float from text")] + NumberDecodeFailed(#[from] ParseFloatError), + #[error("failed to parse int from text")] + IntDecodeFailed(#[from] ParseIntError), + #[error("unknown type with oid {oid} and size {size}")] + UnknownType { + oid: i32, + size: i16 + } +} diff --git a/minisql/src/internals/row.rs b/minisql/src/internals/row.rs index 508ca40..6fa10e1 100644 --- a/minisql/src/internals/row.rs +++ b/minisql/src/internals/row.rs @@ -2,6 +2,7 @@ use crate::type_system::Value; use crate::operation::InsertionValues; use std::ops::{Index, IndexMut}; use std::slice::SliceIndex; +use crate::restricted_row::RestrictedRow; pub type ColumnPosition = usize; @@ -63,14 +64,15 @@ impl Row { self.0.get(column_position) } - pub fn restrict_columns(&self, columns: &Vec) -> Row { + pub fn restrict_columns(&self, columns: &Vec) -> RestrictedRow { // If the index from `columns` is non-existant in `row`, it will just ignore it. - let mut subrow: Row = Row::new(); + let mut subrow: Vec<(ColumnPosition, Value)> = vec![]; for column_position in columns { if let Some(value) = self.get(*column_position) { - subrow.0.push(value.clone()) + subrow.push((*column_position, value.clone())); } } - subrow + + subrow.into() } } diff --git a/minisql/src/internals/table.rs b/minisql/src/internals/table.rs index ac6ca5a..05ee66e 100644 --- a/minisql/src/internals/table.rs +++ b/minisql/src/internals/table.rs @@ -3,6 +3,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use crate::error::Error; use crate::internals::column_index::ColumnIndex; use crate::internals::row::{ColumnPosition, Row}; +use crate::restricted_row::RestrictedRow; use crate::schema::{ColumnName, TableSchema, TableName}; use crate::result::DbResult; use crate::type_system::{IndexableValue, Uuid, Value}; @@ -67,7 +68,7 @@ impl Table { .collect() } - pub fn select_all_rows<'a>(&'a self, selected_column_positions: Vec) -> impl Iterator + 'a { + pub fn select_all_rows<'a>(&'a self, selected_column_positions: Vec) -> impl Iterator + 'a { self.rows .values() .map(move |row| row.restrict_columns(&selected_column_positions)) @@ -78,7 +79,7 @@ impl Table { selected_column_positions: Vec, column_position: ColumnPosition, value: Value, - ) -> DbResult + 'a> { + ) -> DbResult + 'a> { let restrict_columns_of_row = move |row: Row| row.restrict_columns(&selected_column_positions); match value { Value::Indexable(value) => match self.fetch_ids_from_index(column_position, &value)? { diff --git a/minisql/src/interpreter.rs b/minisql/src/interpreter.rs index dd0e870..2caf1f0 100644 --- a/minisql/src/interpreter.rs +++ b/minisql/src/interpreter.rs @@ -1,9 +1,10 @@ -use crate::internals::row::Row; +use crate::internals::row::ColumnPosition; use crate::schema::{TableName, TableSchema}; use crate::internals::table::Table; use crate::operation::{Operation, Condition}; use crate::result::DbResult; use bimap::BiMap; +use crate::restricted_row::RestrictedRow; // Use `TablePosition` as index pub type Tables = Vec; @@ -18,7 +19,7 @@ pub struct State { // #[derive(Debug)] pub enum Response<'a> { - Selected(&'a TableSchema, Box + 'a + Send>), + Selected(&'a TableSchema, Box + 'a + Send>), Inserted, Deleted(usize), // how many were deleted TableCreated, @@ -89,8 +90,8 @@ impl State { let selected_rows = match maybe_condition { None => { - let x = table.select_all_rows(column_selection); - Box::new(x) as Box + 'a + Send> + let rows = table.select_all_rows(column_selection); + Box::new(rows) as Box + 'a + Send> }, Some(Condition::Eq(eq_column, value)) => { @@ -100,7 +101,7 @@ impl State { eq_column, value, )?; - Box::new(x) as Box + 'a + Send> + Box::new(x) as Box + 'a + Send> } }; @@ -246,9 +247,9 @@ mod tests { let row = &rows[0]; assert!(row.len() == 3); - assert!(row[0] == id); - assert!(row[1] == name); - assert!(row[2] == age); + assert!(row[0].1 == id); + assert!(row[1].1 == name); + assert!(row[2].1 == age); } #[test] @@ -305,23 +306,24 @@ mod tests { let response: Response = state.interpret(Select(users_position, users_schema.all_selection(), None)).unwrap(); assert!(matches!(response, Response::Selected(_, _))); - let Response::Selected(_schema, rows) = response else { + let Response::Selected(_, rows) = response else { panic!() }; - let rows: Vec<_> = rows.collect(); + + let rows: Vec<_> = rows.collect(); assert!(rows.len() == 2); let row0 = &rows[0]; let row1 = &rows[1]; assert!(row0.len() == 3); - assert!(row0[0] == id0); - assert!(row0[1] == name0); - assert!(row0[2] == age0); + assert!(row0[0].1 == id0); + assert!(row0[1].1 == name0); + assert!(row0[2].1 == age0); assert!(row1.len() == 3); - assert!(row1[0] == id1); - assert!(row1[1] == name1); - assert!(row1[2] == age1); + assert!(row1[0].1 == id1); + assert!(row1[1].1 == name1); + assert!(row1[2].1 == age1); } { @@ -333,7 +335,7 @@ mod tests { )) .unwrap(); assert!(matches!(response, Response::Selected(_, _))); - let Response::Selected(_schema, rows) = response else { + let Response::Selected(_, rows) = response else { panic!() }; let rows: Vec<_> = rows.collect(); @@ -341,9 +343,9 @@ mod tests { let row0 = &rows[0]; assert!(row0.len() == 3); - assert!(row0[0] == id0); - assert!(row0[1] == name0); - assert!(row0[2] == age0); + assert!(row0[0].1 == id0); + assert!(row0[1].1 == name0); + assert!(row0[2].1 == age0); } { @@ -355,7 +357,7 @@ mod tests { )) .unwrap(); assert!(matches!(response, Response::Selected(_, _))); - let Response::Selected(_schema, rows) = response else { + let Response::Selected(_, rows) = response else { panic!() }; let rows: Vec<_> = rows.collect(); @@ -363,8 +365,8 @@ mod tests { let row0 = &rows[0]; assert!(row0.len() == 2); - assert!(row0[0] == name0); - assert!(row0[1] == id0); + assert!(row0[0].1 == name0); + assert!(row0[1].1 == id0); } } @@ -430,7 +432,7 @@ mod tests { let response: Response = state.interpret(Select(users_position, users_schema.all_selection(), None)).unwrap(); assert!(matches!(response, Response::Selected(_, _))); - let Response::Selected(_schema, rows) = response else { + let Response::Selected(_, rows) = response else { panic!() }; let rows: Vec<_> = rows.collect(); @@ -438,9 +440,9 @@ mod tests { let row = &rows[0]; assert!(row.len() == 3); - assert!(row[0] == id1); - assert!(row[1] == name1); - assert!(row[2] == age1); + assert!(row[0].1 == id1); + assert!(row[1].1 == name1); + assert!(row[2].1 == age1); } #[test] @@ -516,7 +518,6 @@ mod tests { pub fn example() { use crate::type_system::{IndexableValue, Value, DbType}; - use crate::internals::row::ColumnPosition; use Condition::*; use IndexableValue::*; use Operation::*; diff --git a/minisql/src/lib.rs b/minisql/src/lib.rs index b8e95c3..f9a0b09 100644 --- a/minisql/src/lib.rs +++ b/minisql/src/lib.rs @@ -5,3 +5,4 @@ pub mod type_system; mod error; mod internals; mod result; +pub mod restricted_row; diff --git a/minisql/src/operation.rs b/minisql/src/operation.rs index 6bbf918..dae8718 100644 --- a/minisql/src/operation.rs +++ b/minisql/src/operation.rs @@ -4,6 +4,7 @@ use crate::internals::row::ColumnPosition; use crate::interpreter::TablePosition; // Validated operation. Constructed by validation crate. +#[derive(Debug)] pub enum Operation { Select(TablePosition, ColumnSelection, Option), Insert(TablePosition, InsertionValues), @@ -16,6 +17,7 @@ pub type InsertionValues = Vec; pub type ColumnSelection = Vec; +#[derive(Debug)] pub enum Condition { Eq(ColumnPosition, Value), } diff --git a/minisql/src/restricted_row.rs b/minisql/src/restricted_row.rs new file mode 100644 index 0000000..26be188 --- /dev/null +++ b/minisql/src/restricted_row.rs @@ -0,0 +1,35 @@ +use std::ops::Index; +use std::slice::SliceIndex; +use crate::internals::row::ColumnPosition; +use crate::type_system::Value; + +#[derive(Debug, Clone)] +pub struct RestrictedRow(Vec<(ColumnPosition, Value)>); + +impl Index for RestrictedRow +where + Idx: SliceIndex<[(ColumnPosition, Value)]>, +{ + type Output = Idx::Output; + + fn index(&self, index: Idx) -> &Self::Output { + &self.0[index] + } +} + +impl From> for RestrictedRow { + fn from(v: Vec<(ColumnPosition, Value)>) -> Self { + RestrictedRow(v) + } +} + +impl RestrictedRow { + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } +} + diff --git a/minisql/src/type_system.rs b/minisql/src/type_system.rs index 7bf5f60..8b95d7b 100644 --- a/minisql/src/type_system.rs +++ b/minisql/src/type_system.rs @@ -1,3 +1,5 @@ +use crate::error::TypeConversionError; + // ==============Types================ #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum DbType { @@ -39,4 +41,156 @@ impl Value { }, } } + + pub fn type_oid(&self) -> i32 { + match self { + Self::Number(_) => 701, + Self::Indexable(val) => match val { + IndexableValue::String(_) => 25, + IndexableValue::Int(_) => 23, + IndexableValue::Uuid(_) => 2950, + }, + } + } + + pub fn type_size(&self) -> i16 { + match self { + Self::Number(_) => 8, + Self::Indexable(val) => match val { + IndexableValue::String(_) => -2, // null terminated string + IndexableValue::Int(_) => 8, + IndexableValue::Uuid(_) => 16, + }, + } + } + + pub fn as_text_bytes(&self) -> Vec { + match self { + Self::Number(n) => format!("{n}").into_bytes(), + Self::Indexable(i) => match i { + IndexableValue::String(s) => format!("{s}\0").into_bytes(), + IndexableValue::Int(i) => format!("{i}").into_bytes(), + IndexableValue::Uuid(u) => format!("{u}").into_bytes(), + }, + } + } + + pub fn from_text_bytes(bytes: &[u8], type_oid: i32, type_size: i16) -> Result { + match (type_oid, type_size) { + (701, 8) => { + let s = std::str::from_utf8(bytes)?; + let n = s.parse::()?; + Ok(Value::Number(n)) + } + (25, -2) => { + let s = std::str::from_utf8(bytes)?; + let s = &s[..s.len() - 1]; // remove null terminator + Ok(Value::Indexable(IndexableValue::String(s.to_string()))) + } + (23, 8) => { + let s = std::str::from_utf8(bytes)?; + let n = s.parse::()?; + Ok(Value::Indexable(IndexableValue::Int(n))) + } + (2950, 16) => { + let s = std::str::from_utf8(bytes)?; + let n = s.parse::()?; + Ok(Value::Indexable(IndexableValue::Uuid(n))) + } + (oid, size) => Err(TypeConversionError::UnknownType { oid, size }), + } + } +} + +#[cfg(test)] +mod tests { + use crate::error::TypeConversionError::UnknownType; + use super::{Value, IndexableValue}; + + #[test] + fn test_encode_number() { + let value = Value::Number(123.456); + let oid = value.type_oid(); + let size = value.type_size(); + + let bytes = value.as_text_bytes(); + let from_bytes = Value::from_text_bytes(&bytes, oid, size).unwrap(); + + assert_eq!(value, from_bytes); + + assert_eq!(oid, 701); + assert_eq!(size, 8); + } + + #[test] + fn test_encode_string() { + let value = Value::Indexable(IndexableValue::String("hello".to_string())); + let oid = value.type_oid(); + let size = value.type_size(); + + let bytes = value.as_text_bytes(); + let from_bytes = Value::from_text_bytes(&bytes, oid, size).unwrap(); + + assert_eq!(value, from_bytes); + + assert_eq!(oid, 25); + assert_eq!(size, -2); + } + + #[test] + fn test_encode_string_utf8() { + let value = Value::Indexable(IndexableValue::String("#速度与激情9 早上好中国 现在我有冰激淋 我很喜欢冰激淋 但是《速度与激情9》比冰激淋 🍧🍦🍨".to_string())); + let oid = value.type_oid(); + let size = value.type_size(); + + let bytes = value.as_text_bytes(); + let from_bytes = Value::from_text_bytes(&bytes, oid, size).unwrap(); + + assert_eq!(value, from_bytes); + + assert_eq!(oid, 25); + assert_eq!(size, -2); + } + + #[test] + fn test_encode_int() { + let value = Value::Indexable(IndexableValue::Int(123)); + let oid = value.type_oid(); + let size = value.type_size(); + + let bytes = value.as_text_bytes(); + let from_bytes = Value::from_text_bytes(&bytes, oid, size).unwrap(); + + assert_eq!(value, from_bytes); + + assert_eq!(oid, 23); + assert_eq!(size, 8); + } + + #[test] + fn test_encode_uuid() { + let value = Value::Indexable(IndexableValue::Uuid(123)); + let oid = value.type_oid(); + let size = value.type_size(); + + let bytes = value.as_text_bytes(); + let from_bytes = Value::from_text_bytes(&bytes, oid, size).unwrap(); + + assert_eq!(value, from_bytes); + + assert_eq!(oid, 2950); + assert_eq!(size, 16); + } + + #[test] + fn test_mismatched_size() { + let value = Value::Indexable(IndexableValue::Uuid(123)); + let oid = value.type_oid(); + let size = 8; + + let bytes = value.as_text_bytes(); + let from_bytes = Value::from_text_bytes(&bytes, oid, size); + + assert!(matches!(from_bytes, Err(UnknownType { oid: 2950, size: 8 }))) + } } diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs index 0811790..cd2a8c4 100644 --- a/proto/src/handshake/errors.rs +++ b/proto/src/handshake/errors.rs @@ -4,6 +4,7 @@ use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; use crate::writer::errors::ProtoWriteError; use thiserror::Error; use tokio::io; +use crate::message::special::CancelRequestData; #[derive(Debug, Error)] pub enum ClientHandshakeError { @@ -23,6 +24,8 @@ pub enum ClientHandshakeError { pub enum ServerHandshakeError { #[error("startup message not found")] MissingStartupMessage, + #[error("cancel request found instead of startup message")] + IsCancelRequest(CancelRequestData), #[error("socket communication failed")] Io(#[from] io::Error), #[error("deserialization of inner data failed")] diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index 6c8deb2..d1a332f 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -8,6 +8,7 @@ use crate::writer::backend::BackendProtoWriter; use crate::writer::protowriter::ProtoFlush; /// Performs server-side handshake with the client until ending it with `ReadyForQuery` message. +/// Client can send `CancelRequest` message instead of `StartupMessage` to cancel the request. /// For more info visit the [`55.2.1. Start-up`](https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP) pub async fn do_server_handshake( writer: &mut (impl BackendProtoWriter + ProtoFlush), @@ -27,12 +28,16 @@ pub async fn do_server_handshake( } } - // Wait for mandatory StartupMessage + // Wait for mandatory StartupMessage or CancelRequest let startup_message = match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::StartupMessage(data)) => { reader.consume_special_message(msg).await?; data.clone() } + Some(msg @ SpecialMessage::CancelRequest(data)) => { + reader.consume_special_message(msg).await?; + return Err(ServerHandshakeError::IsCancelRequest(data.clone())); + } _ => { return Err(ServerHandshakeError::MissingStartupMessage); } diff --git a/server/Cargo.toml b/server/Cargo.toml index f542cc7..6a511f6 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -8,6 +8,9 @@ edition = "2021" [dependencies] tokio = { version = "1.35.1", features = ["full"] } anyhow = "1.0.76" -proto = { path = "../proto" } +clap = { version = "4.4.18", features = ["derive"] } +async-trait = "0.1.74" +rand = "0.8.5" minisql = { path = "../minisql" } -parser = { path = "../parser" } +proto = { path = "../proto" } +parser = { path = "../parser" } \ No newline at end of file diff --git a/server/src/cancellation.rs b/server/src/cancellation.rs new file mode 100644 index 0000000..59f2cb1 --- /dev/null +++ b/server/src/cancellation.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; + +pub struct ResetCancelToken { + is_canceled: Arc, +} + +impl ResetCancelToken { + pub fn new() -> Self { + Self { + is_canceled: Arc::new(AtomicBool::new(false)), + } + } + + pub fn is_canceled(&self) -> bool { + self.is_canceled.load(Ordering::SeqCst) + } + + pub fn cancel(&self) { + self.is_canceled.store(true, Ordering::SeqCst); + } + + pub fn reset(&self) { + self.is_canceled.store(false, Ordering::SeqCst); + } +} + +impl Clone for ResetCancelToken { + fn clone(&self) -> Self { + Self { + is_canceled: self.is_canceled.clone(), + } + } +} \ No newline at end of file diff --git a/server/src/config.rs b/server/src/config.rs new file mode 100644 index 0000000..68ae54b --- /dev/null +++ b/server/src/config.rs @@ -0,0 +1,28 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::PathBuf; +use clap::Parser; + +const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); + +#[derive(Debug, Parser)] +#[command(author, version, about)] +pub struct Configuration { + #[arg(short, long, default_value_t = LOCAL_IPV4, help = "IP address for the server to listen on")] + address: IpAddr, + #[arg(short, long, default_value = "5432", help = "Port for the server to listen on")] + port: u16, + #[arg(short, long, help = "Path to the data file")] + file: PathBuf, +} + +impl Configuration { + #[inline] + pub fn get_socket_address(&self) -> SocketAddr { + SocketAddr::new(self.address, self.port) + } + + #[inline] + pub fn get_file_path(&self) -> &PathBuf { + &self.file + } +} \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index add7b2b..2793e26 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,78 +1,139 @@ -use minisql::interpreter::State; -use parser::{parse_and_validate, Error}; +use std::collections::HashMap; +use std::sync::Arc; + +use clap::Parser; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{Mutex, RwLock}; + +use minisql::interpreter::{Response, State}; +use parser::parse_and_validate; +use proto::handshake::errors::ServerHandshakeError; +use proto::handshake::request::HandshakeRequest; use proto::handshake::response::HandshakeResponse; use proto::handshake::server::do_server_handshake; -use proto::message::backend::{ - BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, - ReadyForQueryData, RowDescriptionData, -}; use proto::message::frontend::FrontendMessage; -use proto::reader::oneway::OneWayProtoReader; +use proto::reader::frontend::FrontendProtoReader; use proto::reader::protoreader::ProtoReader; use proto::writer::backend::BackendProtoWriter; use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; -use tokio::io::{BufReader, BufWriter}; -use tokio::net::{TcpListener, TcpStream}; + +use crate::cancellation::ResetCancelToken; +use crate::config::Configuration; +use crate::proto_wrapper::{CompleteStatus, ServerProto}; + +mod config; +mod proto_wrapper; +mod cancellation; + +type TokenStore = Arc>>; +type SharedDbState = Arc>; #[tokio::main] async fn main() -> anyhow::Result<()> { - let addr = "0.0.0.0:5432"; + let config = Configuration::parse(); + + let state = Arc::new(RwLock::new(State::new())); + let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new())); + + let addr = config.get_socket_address(); let listener = TcpListener::bind(&addr).await?; println!("Server started at {addr}"); loop { + let state = state.clone(); + let tokens = tokens.clone(); + let (socket, _) = listener.accept().await?; println!("New client connected: {}", socket.peer_addr()?); tokio::spawn(async move { - let reason = handle_stream(socket).await; + let reason = handle_stream(socket, state, tokens).await; println!("Client disconnected: {reason:?}"); }); } } -async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { +async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); - let response = HandshakeResponse::new("minisql", 123, 123); + // Create a token with random PID and key + let (pid, key, token) = create_token(&tokens).await?; - let request = do_server_handshake(&mut writer, &mut reader, response).await?; + // Handle handshake + let response = HandshakeResponse::new("minisql", pid, key); + let request = do_server_handshake(&mut writer, &mut reader, response).await; - println!("Handshake complete:\n{request:?}"); - let mut state = State::new(); + let result = match request { + Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await, + Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, + Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), + }; + + // Release cancellation token + let mut tokens = tokens.lock().await; + tokens.remove(&(pid, key)); + + result +} + +fn random_pid_key() -> (i32, i32) { + let pid = rand::random::(); + let key = rand::random::(); + (pid, key) +} + +async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> { + let token = ResetCancelToken::new(); + let mut tokens = tokens.lock().await; + loop { + let pid_key = random_pid_key(); + if !tokens.contains_key(&pid_key) { + tokens.insert(pid_key, token.clone()); + + let (pid, key) = pid_key; + return Ok((pid, key, token)); + } + } +} + +async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> { + println!("Cancel request, PID: {}, Key: {}", pid, key); + + let tokens = tokens.lock().await; + let token = tokens.get(&(pid, key)); + match token { + Some(t) => t.cancel(), + None => return Err(anyhow::anyhow!("Invalid PID and Key cancel combination")), + } + + Ok(()) +} + +async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken) -> anyhow::Result<()> +where + R: FrontendProtoReader + Send, + W: BackendProtoWriter + ProtoFlush + Send, +{ + println!("Client connected: {:?}", request); loop { - println!("Waiting for next message"); let next: FrontendMessage = reader.read_proto().await?; match next { FrontendMessage::Terminate => { - println!("Received Terminate"); break; } FrontendMessage::Query(data) => { - println!("Received Query: {:?}", data); - let db_schema = state.db_schema(); - match parse_and_validate(data.query.as_str().to_string(), &db_schema) { - Ok(operation) => { - match state.interpret(operation) { - Ok(_) => { - send_query_response(&mut writer).await?; - } - Err(err) => { - send_error_response(&mut writer, &format!("error interpreting: {:?}", err)).await?; - } - } - }, - Err(Error::ParsingError(err)) => { - send_error_response(&mut writer, &format!("parsing error: {:?}", err)).await?; + let result = handle_query(writer, &state, data.query.into(), &token).await; + match result { + Ok(_) => {} + Err(e) => { + writer.write_error_message(&e.to_string()).await? } - Err(Error::ValidationError(v)) => { - send_error_response(&mut writer, &format!("validation error: {:?}", v)).await?; - } - }; - send_ready_for_query(&mut writer).await?; + } + writer.write_ready_for_query().await?; } } writer.flush().await?; @@ -81,117 +142,47 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { Ok(()) } -async fn send_error_response( - writer: &mut impl BackendProtoWriter, - error_message: &str, -) -> anyhow::Result<()> { - writer - .write_proto( - ErrorResponseData { - code: b'M', - message: error_message.to_string().into(), +async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken) -> anyhow::Result<()> +where + W: BackendProtoWriter + ProtoFlush + Send, +{ + let operation = { + let state = state.read().await; + let db_schema = state.db_schema(); + parse_and_validate(query, &db_schema)? + }; + + let mut state = state.write().await; + let response = state.interpret(operation)?; + + match response { + Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?, + Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?, + Response::Selected(schema, mut rows) => { + match rows.next() { + Some(row) => { + writer.write_table_header(&schema, &row).await?; + writer.write_table_row(&row).await?; + + let mut sent_rows = 1; + for row in rows { + sent_rows += 1; + writer.write_table_row(&row).await?; + if token.is_canceled() { + token.reset(); + break; + } + } + + writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; + } + _ => { + writer.write_command_complete(CompleteStatus::Select(0)).await?; + } } - .into(), - ) - .await?; - - Ok(()) -} - -async fn send_ready_for_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - writer - .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) - .await?; - - Ok(()) -} - -async fn send_empty_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - writer - .write_proto(BackendMessage::EmptyQueryResponse) - .await?; - - Ok(()) -} - -async fn send_row_description(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - let columns = vec![ - ColumnDescription { - name: "id".to_string().into(), - table_oid: 123, - column_index: 1, - type_oid: 23, - type_size: 4, - type_modifier: -1, - format_code: 0, - }, - ColumnDescription { - name: "argument".to_string().into(), - table_oid: 123, - column_index: 2, - type_oid: 23, - type_size: 4, - type_modifier: -1, - format_code: 0, - }, - ColumnDescription { - name: "description".to_string().into(), - table_oid: 123, - column_index: 3, - type_oid: 1043, - type_size: 32, - type_modifier: -1, - format_code: 0, - }, - ]; - - writer - .write_proto( - RowDescriptionData { - columns: columns.into(), - } - .into(), - ) - .await?; - - Ok(()) -} - -async fn send_query_response(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { - send_row_description(writer).await?; - - write_row(writer, b"0", b"1337", b"auto").await?; - write_row(writer, b"1", b"69", b"bus").await?; - write_row(writer, b"2", b"420", b"kolo").await?; - - writer - .write_proto( - CommandCompleteData { - tag: "SELECT 3".to_string().into(), - } - .into(), - ) - .await?; - - Ok(()) -} - -async fn write_row( - writer: &mut impl BackendProtoWriter, - first: &[u8], - second: &[u8], - third: &[u8], -) -> anyhow::Result<()> { - let row_data = vec![ - first.to_vec().into(), - second.to_vec().into(), - third.to_vec().into(), - ] - .into(); - - writer - .write_proto(DataRowData { columns: row_data }.into()) - .await?; + } + _ => {} + } Ok(()) } diff --git a/server/src/proto_wrapper.rs b/server/src/proto_wrapper.rs new file mode 100644 index 0000000..3415255 --- /dev/null +++ b/server/src/proto_wrapper.rs @@ -0,0 +1,104 @@ +use async_trait::async_trait; +use minisql::restricted_row::RestrictedRow; +use minisql::schema::TableSchema; +use minisql::type_system::{Value}; +use proto::message::backend::{BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, ReadyForQueryData, RowDescriptionData}; +use proto::message::primitive::pglist::PgList; +use proto::writer::backend::BackendProtoWriter; + +pub enum CompleteStatus { + Insert { + oid: i32, + rows: i32, + }, + Delete(usize), + Select(usize), +} + +impl CompleteStatus { + fn to_string(&self) -> String { + match self { + CompleteStatus::Insert { oid, rows } => format!("INSERT {} {}", oid, rows), + CompleteStatus::Delete(rows) => format!("DELETE {}", rows), + CompleteStatus::Select(rows) => format!("SELECT {}", rows), + } + } +} + +#[async_trait] +pub trait ServerProto { + async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>; + async fn write_ready_for_query(&mut self) -> anyhow::Result<()>; + async fn write_empty_query(&mut self) -> anyhow::Result<()>; + async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()>; + async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>; + async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>; +} + +#[async_trait] +impl ServerProto for W where W: BackendProtoWriter + Send { + async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> { + self.write_proto(ErrorResponseData { + code: b'M', + message: format!("{error_message}\0").into(), + }.into()).await?; + + Ok(()) + } + + async fn write_ready_for_query(&mut self) -> anyhow::Result<()> { + self.write_proto(ReadyForQueryData { status: b'I' }.into()).await?; + Ok(()) + } + + async fn write_empty_query(&mut self) -> anyhow::Result<()> { + self.write_proto(BackendMessage::EmptyQueryResponse).await?; + Ok(()) + } + + async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()> { + let columns = row.iter() + .map(|(index, value)| value_to_column_description(table_schema, value, index)) + .collect::>>()?; + + self.write_proto(RowDescriptionData { columns: columns.into() }.into()).await?; + Ok(()) + } + + async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()> { + let values = row.iter() + .map(|(_, value)| value.as_text_bytes().into()) + .collect::>>(); + + self.write_proto(BackendMessage::DataRow(DataRowData { + columns: values.into(), + })).await?; + Ok(()) + } + + async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> { + self.write_proto(BackendMessage::CommandComplete(CommandCompleteData { + tag: status.to_string().into(), + })).await?; + Ok(()) + } +} + +fn value_to_column_description(schema: &TableSchema, value: &Value, index: &usize) -> anyhow::Result { + let name = schema.column_name_from_column_position(*index)?; + + let table_oid = schema.table_name().as_bytes().as_ptr() as i32; + let column_index = (*index).try_into()?; + let type_oid = value.type_oid(); + let type_size = value.type_size(); + + Ok(ColumnDescription { + name: name.to_string().into(), + table_oid, + column_index, + type_oid, + type_size, + type_modifier: -1, + format_code: 0, // text format + }) +}