diff --git a/client/src/main.rs b/client/src/main.rs index e0fd46b..a481d96 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,24 +1,35 @@ -use std::io::Write; use clap::Parser; use proto::handshake::client::do_client_handshake; use proto::handshake::request::HandshakeRequest; -use proto::reader::protoreader::ProtoReader; -use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; -use tokio::io::{BufReader, BufWriter}; -use tokio::net::TcpStream; -use proto::message::backend::{BackendMessage, CommandCompleteData, DataRowData, ErrorResponseData, RowDescriptionData}; +use proto::message::backend::{ + BackendMessage, CommandCompleteData, DataRowData, ErrorResponseData, RowDescriptionData, +}; use proto::message::frontend::{FrontendMessage, QueryData}; use proto::reader::oneway::OneWayProtoReader; +use proto::reader::protoreader::ProtoReader; use proto::writer::oneway::OneWayProtoWriter; +use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; +use std::io::Write; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::TcpStream; #[derive(Parser)] struct Cli { /// Port number of the server. - #[arg(short, long, default_value_t = 5432, help = "Port number of the server")] + #[arg( + short, + long, + default_value_t = 5432, + help = "Port number of the server" + )] port: u16, /// Host name or IP address of the server. - #[arg(long, default_value = "127.0.0.1", help = "Host name or IP address of the server")] + #[arg( + long, + default_value = "127.0.0.1", + help = "Host name or IP address of the server" + )] host: String, /// User name sent to the server. @@ -48,9 +59,9 @@ async fn main() -> anyhow::Result<()> { let mut exit = false; let command = prompt()?; if let Some(cmd) = command { - writer.write_proto(FrontendMessage::Query(QueryData { - query: cmd.into(), - })).await?; + writer + .write_proto(FrontendMessage::Query(QueryData { query: cmd.into() })) + .await?; writer.flush().await?; } else { exit = true; @@ -61,35 +72,35 @@ async fn main() -> anyhow::Result<()> { match msg { BackendMessage::RowDescription(data) => { print_row_description(data); - }, + } BackendMessage::DataRow(data) => { print_row_data(data); - }, + } BackendMessage::CommandComplete(data) => { print_command_complete(data); - }, + } BackendMessage::ErrorResponse(data) => { print_error_response(data); - }, + } BackendMessage::EmptyQueryResponse => { println!("Empty query response"); - }, + } BackendMessage::NoData => { println!("No data"); - }, + } BackendMessage::ReadyForQuery(data) => { println!("Ready for next query ({})", data.status); let command = prompt()?; if let Some(cmd) = command { - writer.write_proto(FrontendMessage::Query(QueryData { - query: cmd.into(), - })).await?; + writer + .write_proto(FrontendMessage::Query(QueryData { query: cmd.into() })) + .await?; writer.flush().await?; } else { exit = true; } - }, + } m => { println!("Unexpected message: {:?}", m); } diff --git a/minisql/src/error.rs b/minisql/src/error.rs index 71e832b..adc2933 100644 --- a/minisql/src/error.rs +++ b/minisql/src/error.rs @@ -1,8 +1,8 @@ +use crate::schema::{ColumnName, TableName}; +use crate::type_system::Uuid; use std::num::{ParseFloatError, ParseIntError}; use std::str::Utf8Error; use thiserror::Error; -use crate::schema::{ColumnName, TableName}; -use crate::type_system::Uuid; #[derive(Debug, Error)] pub enum RuntimeError { @@ -23,8 +23,5 @@ pub enum TypeConversionError { #[error("failed to parse int from text")] IntDecodeFailed(#[from] ParseIntError), #[error("unknown type with oid {oid} and size {size}")] - UnknownType { - oid: i32, - size: i16 - } + UnknownType { oid: i32, size: i16 }, } diff --git a/minisql/src/internals/column_index.rs b/minisql/src/internals/column_index.rs index 6d7e2e6..6f0619c 100644 --- a/minisql/src/internals/column_index.rs +++ b/minisql/src/internals/column_index.rs @@ -1,6 +1,6 @@ use crate::type_system::{IndexableValue, Uuid}; -use std::collections::{BTreeMap, HashSet}; use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashSet}; #[derive(Debug, Serialize, Deserialize)] pub struct ColumnIndex { diff --git a/minisql/src/internals/row.rs b/minisql/src/internals/row.rs index 201ace6..248f2f7 100644 --- a/minisql/src/internals/row.rs +++ b/minisql/src/internals/row.rs @@ -1,10 +1,10 @@ -use crate::type_system::Value; use crate::operation::InsertionValues; -use std::ops::{Index, IndexMut}; -use std::slice::SliceIndex; -use serde::{Deserialize, Serialize}; use crate::restricted_row::RestrictedRow; use crate::schema::Column; +use crate::type_system::Value; +use serde::{Deserialize, Serialize}; +use std::ops::{Index, IndexMut}; +use std::slice::SliceIndex; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Row(Vec); diff --git a/minisql/src/internals/table.rs b/minisql/src/internals/table.rs index aff4b37..0f9bb2c 100644 --- a/minisql/src/internals/table.rs +++ b/minisql/src/internals/table.rs @@ -1,12 +1,12 @@ -use std::collections::{BTreeMap, HashMap, HashSet}; use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashMap, HashSet}; use crate::error::RuntimeError; use crate::internals::column_index::ColumnIndex; use crate::internals::row::Row; use crate::restricted_row::RestrictedRow; -use crate::schema::{Column, ColumnName, TableSchema, TableName}; use crate::result::DbResult; +use crate::schema::{Column, ColumnName, TableName, TableSchema}; use crate::type_system::{IndexableValue, Uuid, Value}; #[derive(Debug, Serialize, Deserialize)] @@ -69,7 +69,10 @@ impl Table { .collect() } - pub fn select_all_rows(&self, selected_columns: Vec) -> impl Iterator + '_ { + pub fn select_all_rows( + &self, + selected_columns: Vec, + ) -> impl Iterator + '_ { self.rows .values() .map(move |row| row.restrict_columns(&selected_columns)) @@ -80,29 +83,23 @@ impl Table { selected_columns: Vec, column: Column, value: Value, - ) -> DbResult + '_> { + ) -> DbResult + '_> { let restrict_columns_of_row = move |row: Row| row.restrict_columns(&selected_columns); match value { Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? { - Some(ids) => - Ok(self - .get_rows_by_ids(ids) - .into_iter() - .map(restrict_columns_of_row) - ), - None => - Ok(self - .get_rows_by_value(column, &Value::Indexable(value)) - .into_iter() - .map(restrict_columns_of_row) - ), - }, - _ => - Ok(self - .get_rows_by_value(column, &value) + Some(ids) => Ok(self + .get_rows_by_ids(ids) .into_iter() - .map(restrict_columns_of_row) - ), + .map(restrict_columns_of_row)), + None => Ok(self + .get_rows_by_value(column, &Value::Indexable(value)) + .into_iter() + .map(restrict_columns_of_row)), + }, + _ => Ok(self + .get_rows_by_value(column, &value) + .into_iter() + .map(restrict_columns_of_row)), } } @@ -116,7 +113,9 @@ impl Table { } for (column, column_index) in &mut self.indexes { - if let Value::Indexable(val) = &row[*column] { column_index.add(val.clone(), id) } + if let Value::Indexable(val) = &row[*column] { + column_index.add(val.clone(), id) + } } let _ = self.rows.insert(id, row); @@ -168,11 +167,7 @@ impl Table { number_of_rows } - pub fn delete_rows_where_eq( - &mut self, - column: Column, - value: Value, - ) -> DbResult { + pub fn delete_rows_where_eq(&mut self, column: Column, value: Value) -> DbResult { match value { Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? { Some(ids) => Ok(self.delete_rows_by_ids(ids)), @@ -187,7 +182,10 @@ impl Table { if self.indexes.get(&column).is_some() { let column_name = self.schema.column_name_from_column(column).clone(); let table_name = self.schema.table_name().clone(); - return Err(RuntimeError::AttemptToIndexAlreadyIndexedColumn(table_name, column_name)) + return Err(RuntimeError::AttemptToIndexAlreadyIndexedColumn( + table_name, + column_name, + )); } let mut column_index: ColumnIndex = ColumnIndex::new(); update_index_from_table(&mut column_index, self, column)?; @@ -203,7 +201,7 @@ impl Table { if self.schema.is_primary(column) { match value { IndexableValue::Uuid(id) => Ok(Some(HashSet::from([*id]))), - _ => unreachable!() // SAFETY: Validation guarantees primary column has correct Uuid type. + _ => unreachable!(), // SAFETY: Validation guarantees primary column has correct Uuid type. } } else { match self.indexes.get(&column) { @@ -231,9 +229,7 @@ fn update_index_from_table( let value = match &row[column] { Value::Indexable(value) => value.clone(), _ => { - let column_name: ColumnName = table - .schema - .column_name_from_column(column); + let column_name: ColumnName = table.schema.column_name_from_column(column); return Err(RuntimeError::AttemptToIndexNonIndexableColumn( table.table_name().to_string(), column_name, diff --git a/minisql/src/interpreter.rs b/minisql/src/interpreter.rs index 93045fd..773eac1 100644 --- a/minisql/src/interpreter.rs +++ b/minisql/src/interpreter.rs @@ -1,10 +1,10 @@ -use crate::schema::{Column, TableName, TablePosition, TableSchema}; use crate::internals::table::Table; -use crate::operation::{Operation, Condition, ColumnSelection}; +use crate::operation::{ColumnSelection, Condition, Operation}; +use crate::restricted_row::RestrictedRow; use crate::result::DbResult; +use crate::schema::{Column, TableName, TablePosition, TableSchema}; use bimap::BiMap; use serde::{Deserialize, Serialize}; -use crate::restricted_row::RestrictedRow; // Use `TablePosition` as index pub type Tables = Vec; @@ -18,7 +18,11 @@ pub struct State { // #[derive(Debug)] pub enum Response<'a> { - Selected(&'a TableSchema, ColumnSelection, Box + 'a + Send>), + Selected( + &'a TableSchema, + ColumnSelection, + Box + 'a + Send>, + ), Inserted, Deleted(usize), // how many were deleted TableCreated, @@ -32,17 +36,15 @@ impl std::fmt::Debug for Response<'_> { use Response::*; match self { Selected(_schema, _columns, _rows) => - // TODO: How can we iterate through the rows without having to take ownership of - // them? - f.write_str("Some rows... trust me"), - Inserted => - f.write_str("Inserted"), - Deleted(usize) => - f.write_fmt(format_args!("Deleted({})", usize)), - TableCreated => - f.write_str("TableCreated"), - IndexCreated => - f.write_str("IndexCreated"), + // TODO: How can we iterate through the rows without having to take ownership of + // them? + { + f.write_str("Some rows... trust me") + } + Inserted => f.write_str("Inserted"), + Deleted(usize) => f.write_fmt(format_args!("Deleted({})", usize)), + TableCreated => f.write_str("TableCreated"), + IndexCreated => f.write_str("IndexCreated"), } } } @@ -97,22 +99,25 @@ impl State { let selected_rows = match maybe_condition { None => { let rows = table.select_all_rows(column_selection.clone()); - Box::new(rows) as Box + 'a + Send> - }, + Box::new(rows) as Box + 'a + Send> + } Some(Condition::Eq(eq_column, value)) => { - let x = - table.select_rows_where_eq( - column_selection.clone(), - eq_column, - value, - )?; - Box::new(x) as Box + 'a + Send> + let rows = table.select_rows_where_eq( + column_selection.clone(), + eq_column, + value, + )?; + Box::new(rows) as Box + 'a + Send> } }; - Ok(Response::Selected(table.schema(), column_selection, selected_rows)) - }, + Ok(Response::Selected( + table.schema(), + column_selection, + selected_rows, + )) + } Insert(table_position, values) => { let table: &mut Table = self.table_at_mut(table_position); @@ -150,20 +155,16 @@ impl State { #[cfg(test)] mod tests { use super::*; - use crate::schema::Column; - use std::collections::HashSet; - use crate::type_system::{DbType, IndexableValue, Value}; use crate::operation::Operation; + use crate::schema::Column; + use crate::type_system::{DbType, IndexableValue, Value}; + use std::collections::HashSet; fn users_schema() -> TableSchema { TableSchema::new( "users".to_string(), "id".to_string(), - vec!( - "id".to_string(), - "name".to_string(), - "age".to_string(), - ), + vec!["id".to_string(), "name".to_string(), "age".to_string()], vec![DbType::Uuid, DbType::String, DbType::Int], ) } @@ -195,7 +196,11 @@ mod tests { .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let response: Response = state - .interpret(Operation::Select(users_position, users_schema.all_selection(), None)) + .interpret(Operation::Select( + users_position, + users_schema.all_selection(), + None, + )) .unwrap(); assert!(matches!(response, Response::Selected(_, _, _))); let Response::Selected(_, _, rows) = response else { @@ -214,7 +219,6 @@ mod tests { let users_schema = users_schema(); let users = 0; - state .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); @@ -227,11 +231,7 @@ mod tests { state .interpret(Operation::Insert( users, - vec![ - id.clone(), - name.clone(), - age.clone(), - ], + vec![id.clone(), name.clone(), age.clone()], )) .unwrap(); @@ -246,7 +246,7 @@ mod tests { let rows: Vec<_> = rows.collect(); assert!(rows.len() == 1); let row = &rows[0]; - + assert!(row.len() == 3); assert!(row[0].1 == id); assert!(row[1].1 == name); @@ -267,9 +267,7 @@ mod tests { let id_column: Column = 0; let name_column: Column = 1; - state - .interpret(CreateTable(users_schema.clone())) - .unwrap(); + state.interpret(CreateTable(users_schema.clone())).unwrap(); let (id0, name0, age0) = ( Indexable(Uuid(0)), @@ -279,11 +277,7 @@ mod tests { state .interpret(Insert( users_position, - vec![ - id0.clone(), - name0.clone(), - age0.clone(), - ], + vec![id0.clone(), name0.clone(), age0.clone()], )) .unwrap(); @@ -294,17 +288,15 @@ mod tests { ); state .interpret(Insert( - users_position, - vec![ - id1.clone(), - name1.clone(), - age1.clone(), - ], + users_position, + vec![id1.clone(), name1.clone(), age1.clone()], )) .unwrap(); { - let response: Response = state.interpret(Select(users_position, users_schema.all_selection(), None)).unwrap(); + let response: Response = state + .interpret(Select(users_position, users_schema.all_selection(), None)) + .unwrap(); assert!(matches!(response, Response::Selected(_, _, _))); let Response::Selected(_, _, rows) = response else { @@ -384,9 +376,7 @@ mod tests { let id_column: Column = 0; - state - .interpret(CreateTable(users_schema.clone())) - .unwrap(); + state.interpret(CreateTable(users_schema.clone())).unwrap(); let (id0, name0, age0) = ( Indexable(Uuid(0)), @@ -396,11 +386,7 @@ mod tests { state .interpret(Insert( users_position, - vec![ - id0.clone(), - name0.clone(), - age0.clone(), - ], + vec![id0.clone(), name0.clone(), age0.clone()], )) .unwrap(); @@ -412,25 +398,20 @@ mod tests { state .interpret(Insert( users_position, - vec![ - id1.clone(), - name1.clone(), - age1.clone(), - ], + vec![id1.clone(), name1.clone(), age1.clone()], )) .unwrap(); { let delete_response: Response = state - .interpret(Delete( - users_position, - Some(Eq(id_column, id0.clone())), - )) + .interpret(Delete(users_position, Some(Eq(id_column, id0.clone())))) .unwrap(); assert!(matches!(delete_response, Response::Deleted(1))); } - let response: Response = state.interpret(Select(users_position, users_schema.all_selection(), None)).unwrap(); + let response: Response = state + .interpret(Select(users_position, users_schema.all_selection(), None)) + .unwrap(); assert!(matches!(response, Response::Selected(_, _, _))); let Response::Selected(_, _, rows) = response else { @@ -458,9 +439,7 @@ mod tests { let name_column: Column = 1; - state - .interpret(CreateTable(users_schema.clone())) - .unwrap(); + state.interpret(CreateTable(users_schema.clone())).unwrap(); state .interpret(CreateIndex(users_position, name_column)) @@ -474,11 +453,7 @@ mod tests { state .interpret(Insert( users_position, - vec![ - id0.clone(), - name0.clone(), - age0.clone(), - ], + vec![id0.clone(), name0.clone(), age0.clone()], )) .unwrap(); @@ -490,11 +465,7 @@ mod tests { state .interpret(Insert( users_position, - vec![ - id1.clone(), - name1.clone(), - age1.clone(), - ], + vec![id1.clone(), name1.clone(), age1.clone()], )) .unwrap(); @@ -510,7 +481,10 @@ mod tests { let plato_id = 0; let aristotle_id = 1; - let plato_ids = index.get(&String("Plato".to_string())).cloned().unwrap_or(HashSet::new()); + let plato_ids = index + .get(&String("Plato".to_string())) + .cloned() + .unwrap_or(HashSet::new()); assert!(plato_ids.contains(&plato_id)); assert!(!plato_ids.contains(&aristotle_id)); assert!(plato_ids.len() == 1); @@ -518,7 +492,7 @@ mod tests { } pub fn example() { - use crate::type_system::{IndexableValue, Value, DbType}; + use crate::type_system::{DbType, IndexableValue, Value}; use Condition::*; use IndexableValue::*; use Operation::*; @@ -532,11 +506,11 @@ pub fn example() { TableSchema::new( "users".to_string(), "id".to_string(), - vec!( - "id".to_string(), // 0 + vec![ + "id".to_string(), // 0 "name".to_string(), // 1 - "age".to_string(), // 2 - ), + "age".to_string(), // 2 + ], vec![DbType::Uuid, DbType::String, DbType::Int], ) }; @@ -556,11 +530,7 @@ pub fn example() { state .interpret(Insert( users_position, - vec![ - id0.clone(), - name0.clone(), - age0.clone(), - ], + vec![id0.clone(), name0.clone(), age0.clone()], )) .unwrap(); @@ -573,18 +543,18 @@ pub fn example() { state .interpret(Insert( users_position, - vec![ - id1.clone(), - name1.clone(), - age1.clone(), - ], + vec![id1.clone(), name1.clone(), age1.clone()], )) .unwrap(); println!(); { let response: Response = state - .interpret(Operation::Select(users_position, users_schema.all_selection(), None)) + .interpret(Operation::Select( + users_position, + users_schema.all_selection(), + None, + )) .unwrap(); println!("==SELECT ALL=="); println!("{:?}", response); @@ -608,19 +578,12 @@ pub fn example() { // TODO: Why do I have to write these braces explicitely? Why doesn't Rust compiler // "infer" them? let _delete_response: Response = state - .interpret(Delete( - users_position, - Some(Eq(id_column, id0.clone())), - )) - .unwrap(); + .interpret(Delete(users_position, Some(Eq(id_column, id0.clone())))) + .unwrap(); println!("==DELETE Plato=="); } let response: Response = state - .interpret(Select( - users_position, - vec![name_column, id_column], - None, - )) + .interpret(Select(users_position, vec![name_column, id_column], None)) .unwrap(); println!("==SELECT All=="); println!("{:?}", response); diff --git a/minisql/src/lib.rs b/minisql/src/lib.rs index f9a0b09..baa4b4b 100644 --- a/minisql/src/lib.rs +++ b/minisql/src/lib.rs @@ -1,8 +1,8 @@ -pub mod schema; -pub mod interpreter; -pub mod operation; -pub mod type_system; mod error; mod internals; -mod result; +pub mod interpreter; +pub mod operation; pub mod restricted_row; +mod result; +pub mod schema; +pub mod type_system; diff --git a/minisql/src/restricted_row.rs b/minisql/src/restricted_row.rs index 77cafcb..15aebba 100644 --- a/minisql/src/restricted_row.rs +++ b/minisql/src/restricted_row.rs @@ -1,7 +1,7 @@ -use std::ops::Index; -use std::slice::SliceIndex; use crate::schema::Column; use crate::type_system::Value; +use std::ops::Index; +use std::slice::SliceIndex; #[derive(Debug, Clone)] pub struct RestrictedRow(Vec<(Column, Value)>); @@ -32,8 +32,7 @@ impl RestrictedRow { self.0.is_empty() } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.0.iter() } } - diff --git a/minisql/src/schema.rs b/minisql/src/schema.rs index 314eae9..ed5ddce 100644 --- a/minisql/src/schema.rs +++ b/minisql/src/schema.rs @@ -1,5 +1,5 @@ use crate::internals::row::Row; -use crate::operation::{InsertionValues, ColumnSelection}; +use crate::operation::{ColumnSelection, InsertionValues}; use crate::result::DbResult; use crate::type_system::{DbType, IndexableValue, Uuid, Value}; use bimap::BiMap; @@ -20,18 +20,30 @@ pub type TablePosition = usize; pub type ColumnName = String; pub type Column = usize; - impl TableSchema { - pub fn new(table_name: TableName, primary_column_name: ColumnName, columns: Vec, types: Vec) -> Self { + pub fn new( + table_name: TableName, + primary_column_name: ColumnName, + columns: Vec, + types: Vec, + ) -> Self { let mut column_name_position_mapping: BiMap = BiMap::new(); for (column, column_name) in columns.into_iter().enumerate() { column_name_position_mapping.insert(column_name, column); } - let primary_key: Column = match column_name_position_mapping.get_by_left(&primary_column_name).copied() { + let primary_key: Column = match column_name_position_mapping + .get_by_left(&primary_column_name) + .copied() + { Some(primary_key) => primary_key, - None => unreachable!() // SAFETY: Existence of unique primary key is ensured in validation. + None => unreachable!(), // SAFETY: Existence of unique primary key is ensured in validation. }; - Self { table_name, primary_key, column_name_position_mapping, types } + Self { + table_name, + primary_key, + column_name_position_mapping, + types, + } } pub fn table_name(&self) -> &TableName { @@ -43,7 +55,10 @@ impl TableSchema { } pub fn get_columns(&self) -> Vec<&ColumnName> { - self.column_name_position_mapping.iter().map(|(name, _)| name).collect() + self.column_name_position_mapping + .iter() + .map(|(name, _)| name) + .collect() } pub fn does_column_exist(&self, column_name: &ColumnName) -> bool { @@ -51,11 +66,17 @@ impl TableSchema { } pub fn get_column(&self, column_name: &ColumnName) -> Option { - self.column_name_position_mapping.get_by_left(column_name).copied() + self.column_name_position_mapping + .get_by_left(column_name) + .copied() } pub fn all_selection(&self) -> ColumnSelection { - let mut selection: ColumnSelection = self.column_name_position_mapping.iter().map(|(_, column)| *column).collect(); + let mut selection: ColumnSelection = self + .column_name_position_mapping + .iter() + .map(|(_, column)| *column) + .collect(); selection.sort(); selection } @@ -76,13 +97,10 @@ impl TableSchema { // Assumes `column` comes from a validated source. pub fn column_name_from_column(&self, column: Column) -> ColumnName { - match self - .column_name_position_mapping - .get_by_right(&column) - { + match self.column_name_position_mapping.get_by_right(&column) { Some(column_name) => column_name.clone(), - None => unreachable!() // SAFETY: The only way this function can get a column is from - // validation, which guarantees there is such a colun. + None => unreachable!(), // SAFETY: The only way this function can get a column is from + // validation, which guarantees there is such a colun. } } diff --git a/minisql/src/type_system.rs b/minisql/src/type_system.rs index d9248fc..abae701 100644 --- a/minisql/src/type_system.rs +++ b/minisql/src/type_system.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use crate::error::TypeConversionError; +use serde::{Deserialize, Serialize}; // ==============Types================ #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -82,7 +82,11 @@ impl Value { } } - pub fn from_text_bytes(bytes: &[u8], type_oid: i32, type_size: i16) -> Result { + pub fn from_text_bytes( + bytes: &[u8], + type_oid: i32, + type_size: i16, + ) -> Result { match (type_oid, type_size) { (701, 8) => { let s = std::str::from_utf8(bytes)?; @@ -91,7 +95,7 @@ impl Value { } (25, -2) => { let s = std::str::from_utf8(bytes)?; - let s = &s[..s.len() - 1]; // remove null terminator + let s = &s[..s.len() - 1]; // remove null terminator Ok(Value::Indexable(IndexableValue::String(s.to_string()))) } (23, 8) => { @@ -111,8 +115,8 @@ impl Value { #[cfg(test)] mod tests { + use super::{IndexableValue, Value}; use crate::error::TypeConversionError::UnknownType; - use super::{Value, IndexableValue}; #[test] fn test_encode_number() { @@ -204,6 +208,9 @@ mod tests { let bytes = value.as_text_bytes(); let from_bytes = Value::from_text_bytes(&bytes, oid, size); - assert!(matches!(from_bytes, Err(UnknownType { oid: 2950, size: 8 }))) + assert!(matches!( + from_bytes, + Err(UnknownType { oid: 2950, size: 8 }) + )) } } diff --git a/parser/src/core.rs b/parser/src/core.rs index d063a07..75d2256 100644 --- a/parser/src/core.rs +++ b/parser/src/core.rs @@ -1,16 +1,22 @@ -use minisql::{operation::Operation, interpreter::DbSchema}; use crate::syntax::RawQuerySyntax; +use minisql::{interpreter::DbSchema, operation::Operation}; use nom::{branch::alt, IResult}; use thiserror::Error; -use crate::{parsing::{create::parse_create, delete::parse_delete, index::parse_create_index, insert::parse_insert, select::parse_select}, validation::{validate_operation, ValidationError}}; +use crate::{ + parsing::{ + create::parse_create, delete::parse_delete, index::parse_create_index, + insert::parse_insert, select::parse_select, + }, + validation::{validate_operation, ValidationError}, +}; #[derive(Debug, Error)] pub enum Error { #[error("parsing error: {0}")] ParsingError(String), #[error("validation error: {0}")] - ValidationError(#[from] ValidationError) + ValidationError(#[from] ValidationError), } fn parse_statement(input: &str) -> IResult<&str, RawQuerySyntax> { @@ -21,15 +27,13 @@ fn parse_statement(input: &str) -> IResult<&str, RawQuerySyntax> { //parse_drop, parse_select, // parse_update, - parse_create_index + parse_create_index, ))(input) } pub fn parse_and_validate(str_query: String, db_schema: &DbSchema) -> Result { - let (_, op) = parse_statement(str_query.as_str()) - .map_err(|err| { - Error::ParsingError(err.to_string()) - })?; + let (_, op) = + parse_statement(str_query.as_str()).map_err(|err| Error::ParsingError(err.to_string()))?; Ok(validate_operation(op, db_schema)?) } diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 18eab0c..f043fb7 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -1,8 +1,7 @@ - -mod parsing; -mod validation; mod core; +mod parsing; mod syntax; +mod validation; pub use core::parse_and_validate; pub use core::Error; diff --git a/parser/src/parsing/common.rs b/parser/src/parsing/common.rs index 6787105..508289e 100644 --- a/parser/src/parsing/common.rs +++ b/parser/src/parsing/common.rs @@ -1,20 +1,21 @@ +use minisql::type_system::DbType; use nom::{ - character::complete::{alphanumeric1, char, multispace0, anychar, multispace1}, + branch::alt, + bytes::complete::tag, + character::complete::{alphanumeric1, anychar, char, multispace0, multispace1}, combinator::peek, error::make_error, sequence::{delimited, terminated}, - bytes::complete::tag, - IResult, branch::alt, + IResult, }; -use minisql::type_system::DbType; -use crate::syntax::Condition; use super::literal::parse_db_value; +use crate::syntax::Condition; pub fn parse_table_name(input: &str) -> IResult<&str, &str> { alt(( delimited(char('"'), alphanumeric1, char('"')), - parse_identifier + parse_identifier, ))(input) } @@ -24,7 +25,10 @@ pub fn parse_identifier(input: &str) -> IResult<&str, &str> { if first.is_alphabetic() { alphanumeric1(input) } else { - Err(nom::Err::Error(make_error(input, nom::error::ErrorKind::Alpha))) + Err(nom::Err::Error(make_error( + input, + nom::error::ErrorKind::Alpha, + ))) } } @@ -39,7 +43,12 @@ pub fn parse_db_type(input: &str) -> IResult<&str, DbType> { "INT" => DbType::Int, "UUID" => DbType::Uuid, "NUMBER" => DbType::Number, - _ => return Err(nom::Err::Failure(make_error(input, nom::error::ErrorKind::IsNot))) + _ => { + return Err(nom::Err::Failure(make_error( + input, + nom::error::ErrorKind::IsNot, + ))) + } }; Ok((input, db_type)) } @@ -51,9 +60,7 @@ pub fn parse_condition(input: &str) -> IResult<&str, Option> { let (input, condition) = parse_equality(input)?; Ok((input, Some(condition))) } - Err(_) => { - Ok((input, None)) - } + Err(_) => Ok((input, None)), } } @@ -70,9 +77,9 @@ fn parse_equality(input: &str) -> IResult<&str, Condition> { mod tests { use minisql::type_system::DbType; - use crate::syntax::Condition; use crate::parsing::common::{parse_db_type, parse_equality}; - + use crate::syntax::Condition; + #[test] fn test_parse_equality() { use minisql::type_system::{IndexableValue, Value}; @@ -89,10 +96,22 @@ mod tests { #[test] fn test_parse_db_type() { - assert!(matches!(parse_db_type("INT").expect("should parse").1, DbType::Int)); - assert!(matches!(parse_db_type("STRING").expect("should parse").1, DbType::String)); - assert!(matches!(parse_db_type("UUID").expect("should parse").1, DbType::Uuid)); - assert!(matches!(parse_db_type("NUMBER").expect("should parse").1, DbType::Number)); + assert!(matches!( + parse_db_type("INT").expect("should parse").1, + DbType::Int + )); + assert!(matches!( + parse_db_type("STRING").expect("should parse").1, + DbType::String + )); + assert!(matches!( + parse_db_type("UUID").expect("should parse").1, + DbType::Uuid + )); + assert!(matches!( + parse_db_type("NUMBER").expect("should parse").1, + DbType::Number + )); assert!(matches!(parse_db_type("Unknown"), Err(_))); } } diff --git a/parser/src/parsing/create.rs b/parser/src/parsing/create.rs index 14a2234..f13e2bf 100644 --- a/parser/src/parsing/create.rs +++ b/parser/src/parsing/create.rs @@ -1,13 +1,14 @@ use nom::{ bytes::complete::tag, character::complete::{char, multispace0, multispace1}, + combinator::opt, multi::separated_list0, sequence::terminated, - IResult, combinator::opt, + IResult, }; -use super::common::{parse_table_name, parse_identifier, parse_db_type}; -use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; +use super::common::{parse_db_type, parse_identifier, parse_table_name}; +use crate::syntax::{ColumnSchema, RawQuerySyntax, RawTableSchema}; pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = tag("CREATE")(input)?; @@ -27,10 +28,7 @@ pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { table_name: table_name.to_string(), columns: column_definitions, }; - Ok(( - input, - RawQuerySyntax::CreateTable(schema), - )) + Ok((input, RawQuerySyntax::CreateTable(schema))) } fn parse_column_definitions(input: &str) -> IResult<&str, Vec> { @@ -51,7 +49,14 @@ fn parse_column_definition(input: &str) -> IResult<&str, ColumnSchema> { let (input, db_type) = parse_db_type(input)?; let (input, pk) = opt(parse_primary_key)(input).map(|(input, pk)| (input, pk.is_some()))?; let (input, _) = multispace0(input)?; - Ok((input, ColumnSchema { column_name: identifier.to_string(), type_: db_type, is_primary: pk })) + Ok(( + input, + ColumnSchema { + column_name: identifier.to_string(), + type_: db_type, + is_primary: pk, + }, + )) } #[cfg(test)] @@ -66,22 +71,28 @@ mod tests { #[test] fn test_parse_create_primary_key() { - parse_create("CREATE TABLE \"Table1\"(id UUID PRIMARY KEY,column1 INT);").expect("should parse"); + parse_create("CREATE TABLE \"Table1\"(id UUID PRIMARY KEY,column1 INT);") + .expect("should parse"); } #[test] fn test_parse_create_no_quotes_table_name() { - parse_create("CREATE TABLE Table1(id UUID PRIMARY KEY,column1 INT);").expect("should parse"); + parse_create("CREATE TABLE Table1(id UUID PRIMARY KEY,column1 INT);") + .expect("should parse"); } #[test] fn test_parse_create_primary_key_with_spaces() { - parse_create("CREATE TABLE \"Table1\" ( id UUID PRIMARY KEY , column1 INT ) ;").expect("should parse"); + parse_create( + "CREATE TABLE \"Table1\" ( id UUID PRIMARY KEY , column1 INT ) ;", + ) + .expect("should parse"); } #[test] fn test_parse_create() { - let (_, create) = parse_create("CREATE TABLE \"Table1\"( id UUID , column1 INT );").expect("should parse"); + let (_, create) = parse_create("CREATE TABLE \"Table1\"( id UUID , column1 INT );") + .expect("should parse"); assert!(matches!(create, RawQuerySyntax::CreateTable(_))); match create { RawQuerySyntax::CreateTable(schema) => { @@ -95,7 +106,9 @@ mod tests { let result_column1 = schema.get_column(&"column1".to_string()); assert!(matches!(result_column1, Some(_))); - let Some(column1_column) = result_column1 else { panic!() }; + let Some(column1_column) = result_column1 else { + panic!() + }; assert_eq!(column1_column.column_name, "column1".to_string()); } _ => {} diff --git a/parser/src/parsing/delete.rs b/parser/src/parsing/delete.rs index 4dc07d2..2cbe88f 100644 --- a/parser/src/parsing/delete.rs +++ b/parser/src/parsing/delete.rs @@ -4,8 +4,8 @@ use nom::{ IResult, }; +use super::common::{parse_condition, parse_table_name}; use crate::syntax::RawQuerySyntax; -use super::common::{parse_table_name, parse_condition}; pub fn parse_delete(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = tag("DELETE")(input)?; @@ -25,14 +25,15 @@ pub fn parse_delete(input: &str) -> IResult<&str, RawQuerySyntax> { #[cfg(test)] mod tests { - use crate::syntax::RawQuerySyntax; use crate::parsing::delete::parse_delete; + use crate::syntax::RawQuerySyntax; #[test] fn test_parse_delete() { - let (_, operation) = parse_delete("DELETE FROM \"T1\" WHERE id = 1 ;").expect("should parse"); + let (_, operation) = + parse_delete("DELETE FROM \"T1\" WHERE id = 1 ;").expect("should parse"); assert!(matches!(operation, RawQuerySyntax::Delete(_, _))) } -// TODO: add test with condition + // TODO: add test with condition } diff --git a/parser/src/parsing/index.rs b/parser/src/parsing/index.rs index 68233ba..4a38b32 100644 --- a/parser/src/parsing/index.rs +++ b/parser/src/parsing/index.rs @@ -2,7 +2,8 @@ use crate::syntax::RawQuerySyntax; use nom::{ bytes::complete::tag, character::complete::{char, multispace0, multispace1}, - IResult, combinator::opt, + combinator::opt, + IResult, }; use super::common::{parse_identifier, parse_table_name}; @@ -35,16 +36,16 @@ pub fn parse_create_index(input: &str) -> IResult<&str, RawQuerySyntax> { Ok((input, operation)) } - #[cfg(test)] mod tests { - use crate::syntax::RawQuerySyntax; use crate::parsing::index::parse_create_index; + use crate::syntax::RawQuerySyntax; - #[test] fn test_create_index() { - let (_, syntax) = parse_create_index("CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" (email);").expect("should parse"); + let (_, syntax) = + parse_create_index("CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" (email);") + .expect("should parse"); assert!(matches!(syntax, RawQuerySyntax::CreateIndex(_, _))); match syntax { RawQuerySyntax::CreateIndex(table_name, column_name) => { @@ -57,7 +58,10 @@ mod tests { #[test] fn test_create_index_with_spaces() { - let (_, syntax) = parse_create_index("CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" ( email ) ;").expect("should parse"); + let (_, syntax) = parse_create_index( + "CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" ( email ) ;", + ) + .expect("should parse"); assert!(matches!(syntax, RawQuerySyntax::CreateIndex(_, _))); match syntax { RawQuerySyntax::CreateIndex(table_name, column_name) => { diff --git a/parser/src/parsing/insert.rs b/parser/src/parsing/insert.rs index 5f1f64a..0b62574 100644 --- a/parser/src/parsing/insert.rs +++ b/parser/src/parsing/insert.rs @@ -1,9 +1,12 @@ -use super::{literal::parse_db_value, common::{parse_table_name, parse_identifier}}; +use super::{ + common::{parse_identifier, parse_table_name}, + literal::parse_db_value, +}; use crate::syntax::RawQuerySyntax; use minisql::type_system::Value; use nom::{ bytes::complete::tag, - character::complete::{multispace0, multispace1, char}, + character::complete::{char, multispace0, multispace1}, combinator::map, multi::separated_list0, sequence::terminated, @@ -14,7 +17,7 @@ pub fn parse_insert(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = tag("INSERT")(input)?; let (input, _) = multispace1(input)?; let (input, _) = tag("INTO")(input)?; - let (input, _) = multispace1(input)?; + let (input, _) = multispace1(input)?; let (input, table_name) = parse_table_name(input)?; let (input, _) = multispace1(input)?; let (input, _) = char('(')(input)?; @@ -34,27 +37,31 @@ pub fn parse_insert(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = char(';')(input)?; Ok(( input, - RawQuerySyntax::Insert(table_name.to_string(), column_names.into_iter().zip(values).collect()), + RawQuerySyntax::Insert( + table_name.to_string(), + column_names.into_iter().zip(values).collect(), + ), )) } pub fn parse_columns(input: &str) -> IResult<&str, Vec> { - separated_list0(terminated(char(','), multispace0), map(parse_identifier, |name|name.to_string()))(input) + separated_list0( + terminated(char(','), multispace0), + map(parse_identifier, |name| name.to_string()), + )(input) } pub fn parse_values(input: &str) -> IResult<&str, Vec> { separated_list0(terminated(char(','), multispace0), parse_db_value)(input) } - #[cfg(test)] mod tests { use minisql::type_system::{IndexableValue, Value}; - use crate::syntax::RawQuerySyntax; use super::parse_insert; + use crate::syntax::RawQuerySyntax; - #[test] fn test_parse_insert() { let sql = "INSERT INTO \"MyTable\" (id, data) VALUES(1, \"Text\");"; @@ -63,11 +70,15 @@ mod tests { ("", RawQuerySyntax::Insert(table_name, insertion_values)) => { assert_eq!(table_name, "MyTable"); assert_eq!( - insertion_values, + insertion_values, vec![ ("id".to_string(), Value::Indexable(IndexableValue::Int(1))), - ("data".to_string(), Value::Indexable(IndexableValue::String("Text".to_string()))) - ]); + ( + "data".to_string(), + Value::Indexable(IndexableValue::String("Text".to_string())) + ) + ] + ); } _ => { unreachable!() @@ -77,16 +88,22 @@ mod tests { #[test] fn test_parse_insert_with_spaces() { - let sql = "INSERT INTO \"MyTable\" ( id, data ) VALUES ( 1, \"Text\" ) ;"; + let sql = + "INSERT INTO \"MyTable\" ( id, data ) VALUES ( 1, \"Text\" ) ;"; let operation = parse_insert(sql).expect("should parse"); match operation { ("", RawQuerySyntax::Insert(table_name, insertion_values)) => { assert_eq!(table_name, "MyTable"); - assert_eq!(insertion_values, + assert_eq!( + insertion_values, vec![ ("id".to_string(), Value::Indexable(IndexableValue::Int(1))), - ("data".to_string(), Value::Indexable(IndexableValue::String("Text".to_string()))) - ]); + ( + "data".to_string(), + Value::Indexable(IndexableValue::String("Text".to_string())) + ) + ] + ); } _ => { unreachable!() diff --git a/parser/src/parsing/literal.rs b/parser/src/parsing/literal.rs index 921a8bf..c37f963 100644 --- a/parser/src/parsing/literal.rs +++ b/parser/src/parsing/literal.rs @@ -1,20 +1,16 @@ use minisql::type_system::{IndexableValue, Value}; use nom::{ branch::alt, - character::complete::{u64, char, digit1, none_of}, + character::complete::{char, digit1, none_of, u64}, combinator::opt, + error::make_error, multi::many0, sequence::{delimited, pair, preceded}, - IResult, error::make_error + IResult, }; pub fn parse_db_value(input: &str) -> IResult<&str, Value> { - alt(( - parse_string, - parse_number, - parse_int, - parse_uuid, - ))(input) + alt((parse_string, parse_number, parse_int, parse_uuid))(input) } pub fn parse_number(input: &str) -> IResult<&str, Value> { @@ -27,56 +23,47 @@ pub fn parse_number(input: &str) -> IResult<&str, Value> { match frac_part { Some((_fsign, fdigits)) => { // Combine integer and fractional parts - let combined_parts = format!( - "{}{}.{}", - sign.unwrap_or('+'), - digits, - fdigits - ); + let combined_parts = format!("{}{}.{}", sign.unwrap_or('+'), digits, fdigits); // Parse the combined parts as a floating-point number - let value = combined_parts.parse::() - .map_err(|_| { - nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)) - })?; + let value = combined_parts + .parse::() + .map_err(|_| nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)))?; Ok((input, Value::Number(value))) } None => { - let value = format!("{}{}", sign.unwrap_or('+'), digits).parse::() - .map_err(|_| { - nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)) - })?; + let value = format!("{}{}", sign.unwrap_or('+'), digits) + .parse::() + .map_err(|_| nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)))?; Ok((input, Value::Indexable(IndexableValue::Int(value)))) } } } pub fn parse_int(input: &str) -> IResult<&str, Value> { - u64(input).map(|(input, v)| { - (input, Value::Indexable(IndexableValue::Int(v))) - }) + u64(input).map(|(input, v)| (input, Value::Indexable(IndexableValue::Int(v)))) } -fn escape_tab(input:&str) -> IResult<&str, char> { +fn escape_tab(input: &str) -> IResult<&str, char> { let (input, _) = preceded(char('\\'), char('t'))(input)?; Ok((input, '\t')) } -fn escape_backslash(input:&str) -> IResult<&str, char> { +fn escape_backslash(input: &str) -> IResult<&str, char> { let (input, _) = preceded(char('\\'), char('\\'))(input)?; Ok((input, '\\')) } -fn escape_newline(input:&str) -> IResult<&str, char> { +fn escape_newline(input: &str) -> IResult<&str, char> { let (input, _) = preceded(char('\\'), char('n'))(input)?; Ok((input, '\n')) } -fn escape_carriegereturn(input:&str) -> IResult<&str, char> { +fn escape_carriegereturn(input: &str) -> IResult<&str, char> { let (input, _) = preceded(char('\\'), char('r'))(input)?; Ok((input, '\r')) } -fn escape_doublequote(input:&str) -> IResult<&str, char> { +fn escape_doublequote(input: &str) -> IResult<&str, char> { preceded(char('\\'), char('"'))(input) } @@ -90,7 +77,7 @@ pub fn parse_string(input: &str) -> IResult<&str, Value> { escape_newline, escape_doublequote, escape_tab, - none_of(r#"\""#) + none_of(r#"\""#), ))), char('"'), )(input)?; @@ -102,23 +89,39 @@ pub fn parse_string(input: &str) -> IResult<&str, Value> { } pub fn parse_uuid(input: &str) -> IResult<&str, Value> { - let (input, value) = pair(char('u'), u64)(input).map(|(input, (_, v))| { - (input, Value::Indexable(IndexableValue::Uuid(v))) - })?; + let (input, value) = pair(char('u'), u64)(input) + .map(|(input, (_, v))| (input, Value::Indexable(IndexableValue::Uuid(v))))?; Ok((input, value)) } #[cfg(test)] mod tests { - use minisql::type_system::{IndexableValue, Value}; use crate::parsing::literal::{parse_db_value, parse_string, parse_uuid}; - + use minisql::type_system::{IndexableValue, Value}; #[test] fn test_string_parser() { - assert_eq!(parse_string(r#""simple""#), Ok(("", Value::Indexable(IndexableValue::String(String::from("simple")))))); - assert_eq!(parse_string(r#""\"\t\r\n\\""#), Ok(("", Value::Indexable(IndexableValue::String(String::from("\"\t\r\n\\")))))); - assert_eq!(parse_string(r#""name is \"John\".""#), Ok(("", Value::Indexable(IndexableValue::String(String::from("name is \"John\".")))))); + assert_eq!( + parse_string(r#""simple""#), + Ok(( + "", + Value::Indexable(IndexableValue::String(String::from("simple"))) + )) + ); + assert_eq!( + parse_string(r#""\"\t\r\n\\""#), + Ok(( + "", + Value::Indexable(IndexableValue::String(String::from("\"\t\r\n\\"))) + )) + ); + assert_eq!( + parse_string(r#""name is \"John\".""#), + Ok(( + "", + Value::Indexable(IndexableValue::String(String::from("name is \"John\"."))) + )) + ); } #[test] @@ -132,39 +135,63 @@ mod tests { assert_eq!(value, Value::Number(5.5)); let (_, _) = parse_db_value("\"STRING\"").expect("should parse"); - let (input, value) = parse_db_value("\"abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ \"").expect("should parse"); + let (input, value) = + parse_db_value("\"abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ \"").expect("should parse"); assert_eq!(input, ""); - assert_eq!(value, Value::Indexable(IndexableValue::String("abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ ".to_string()))); - + assert_eq!( + value, + Value::Indexable(IndexableValue::String( + "abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ ".to_string() + )) + ); } - #[test] fn test_parse_positive_float() { - assert_eq!(parse_db_value("23.213313"), Ok(("", Value::Number(23.213313)))); - assert_eq!(parse_db_value("2241.9734"), Ok(("", Value::Number(2241.9734)))); + assert_eq!( + parse_db_value("23.213313"), + Ok(("", Value::Number(23.213313))) + ); + assert_eq!( + parse_db_value("2241.9734"), + Ok(("", Value::Number(2241.9734))) + ); } #[test] fn test_parse_negative_float() { - assert_eq!(parse_db_value("-9241.873654"), Ok(("", Value::Number(-9241.873654)))); - assert_eq!(parse_db_value("-62625.0"), Ok(("", Value::Number(-62625.0)))); + assert_eq!( + parse_db_value("-9241.873654"), + Ok(("", Value::Number(-9241.873654))) + ); + assert_eq!( + parse_db_value("-62625.0"), + Ok(("", Value::Number(-62625.0))) + ); } #[test] fn test_parse_float_between_0_and_1() { - assert_eq!(parse_db_value("0.873654"), Ok(("", Value::Number(0.873654)))); + assert_eq!( + parse_db_value("0.873654"), + Ok(("", Value::Number(0.873654))) + ); assert_eq!(parse_db_value("0.62625"), Ok(("", Value::Number(0.62625)))); } - #[test] fn test_parse_int() { - assert_eq!(parse_db_value("5134616"), Ok(("", Value::Indexable(IndexableValue::Int(5134616))))); + assert_eq!( + parse_db_value("5134616"), + Ok(("", Value::Indexable(IndexableValue::Int(5134616)))) + ); } #[test] fn test_parse_uuid() { - assert_eq!(parse_uuid("u131515"), Ok(("", Value::Indexable(IndexableValue::Uuid(131515))))) + assert_eq!( + parse_uuid("u131515"), + Ok(("", Value::Indexable(IndexableValue::Uuid(131515)))) + ) } } diff --git a/parser/src/parsing/mod.rs b/parser/src/parsing/mod.rs index 482deb4..9d7903e 100644 --- a/parser/src/parsing/mod.rs +++ b/parser/src/parsing/mod.rs @@ -1,7 +1,7 @@ -pub(crate) mod literal; -pub(crate) mod select; pub(crate) mod common; pub(crate) mod create; -pub(crate) mod insert; pub(crate) mod delete; pub(crate) mod index; +pub(crate) mod insert; +pub(crate) mod literal; +pub(crate) mod select; diff --git a/parser/src/parsing/select.rs b/parser/src/parsing/select.rs index e45aa3c..3d14bd6 100644 --- a/parser/src/parsing/select.rs +++ b/parser/src/parsing/select.rs @@ -1,9 +1,9 @@ -use super::common::{parse_table_name, parse_column_name, parse_condition}; +use super::common::{parse_column_name, parse_condition, parse_table_name}; use crate::syntax::{ColumnSelection, RawQuerySyntax}; use nom::{ branch::alt, bytes::complete::tag, - character::complete::{multispace0, multispace1, char}, + character::complete::{char, multispace0, multispace1}, combinator::map, error::Error, multi::separated_list0, @@ -44,10 +44,12 @@ pub fn try_parse_column_selection(input: &str) -> IResult<&str, ColumnSelection> #[cfg(test)] mod tests { + use crate::parsing::{ + common::{parse_column_name, parse_table_name}, + select::parse_select, + }; use crate::syntax::{ColumnSelection, RawQuerySyntax}; - use crate::parsing::{common::{parse_column_name, parse_table_name}, select::parse_select}; - #[test] fn test_parse_select_all() { let sql = "SELECT * FROM \"MyTable\";"; diff --git a/parser/src/syntax.rs b/parser/src/syntax.rs index e1258f0..27d3306 100644 --- a/parser/src/syntax.rs +++ b/parser/src/syntax.rs @@ -1,4 +1,7 @@ -use minisql::{type_system::{Value, DbType}, schema::{ColumnName, TableName}}; +use minisql::{ + schema::{ColumnName, TableName}, + type_system::{DbType, Value}, +}; // ===Table Schema=== #[derive(Debug, Clone, PartialEq)] @@ -53,10 +56,16 @@ impl RawTableSchema { } pub fn get_column(&self, column_name: &ColumnName) -> Option { - self.columns.iter().find(|column_schema| column_name == &column_schema.column_name).cloned() + self.columns + .iter() + .find(|column_schema| column_name == &column_schema.column_name) + .cloned() } pub fn get_columns(&self) -> Vec<&ColumnName> { - self.columns.iter().map(|ColumnSchema { column_name, .. }| column_name).collect() + self.columns + .iter() + .map(|ColumnSchema { column_name, .. }| column_name) + .collect() } } diff --git a/parser/src/validation.rs b/parser/src/validation.rs index 31cae4b..bd948a7 100644 --- a/parser/src/validation.rs +++ b/parser/src/validation.rs @@ -1,10 +1,16 @@ -use std::collections::{HashSet, BTreeMap}; +use std::collections::{BTreeMap, HashSet}; use thiserror::Error; use crate::syntax; -use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; +use crate::syntax::{ColumnSchema, RawQuerySyntax, RawTableSchema}; use minisql::operation; -use minisql::{operation::Operation, type_system::Value, schema::{TableSchema, ColumnName, Column, TableName, TablePosition}, type_system::DbType, interpreter::DbSchema}; +use minisql::{ + interpreter::DbSchema, + operation::Operation, + schema::{Column, ColumnName, TableName, TablePosition, TableSchema}, + type_system::DbType, + type_system::Value, +}; #[derive(Debug, Error)] pub enum ValidationError { @@ -29,37 +35,46 @@ pub enum ValidationError { expected_type: DbType, }, #[error("values for required columns {0:?} are missing")] - RequiredColumnsAreMissing(Vec) + RequiredColumnsAreMissing(Vec), } /// Validates and converts the raw syntax into a proper interpreter operation based on db schema. -pub fn validate_operation(syntax: RawQuerySyntax, db_schema: &DbSchema) -> Result { +pub fn validate_operation( + syntax: RawQuerySyntax, + db_schema: &DbSchema, +) -> Result { match syntax { RawQuerySyntax::Select(table_name, column_selection, condition) => { validate_select(table_name, column_selection, condition, db_schema) - }, + } RawQuerySyntax::Insert(table_name, insertion_values) => { validate_insert(table_name, insertion_values, db_schema) - }, + } RawQuerySyntax::Delete(table_name, condition) => { validate_delete(table_name, condition, db_schema) - }, - RawQuerySyntax::CreateTable(schema) => { - validate_create_table(schema, db_schema) - }, + } + RawQuerySyntax::CreateTable(schema) => validate_create_table(schema, db_schema), RawQuerySyntax::CreateIndex(table_name, column_name) => { validate_create_index(table_name, column_name, db_schema) - }, + } } } -fn validate_table_exists<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName) -> Result<(TablePosition, &'a TableSchema), ValidationError> { - db_schema.iter().find(|(tname, _, _)| table_name.eq(tname)) +fn validate_table_exists<'a>( + db_schema: &DbSchema<'a>, + table_name: &'a TableName, +) -> Result<(TablePosition, &'a TableSchema), ValidationError> { + db_schema + .iter() + .find(|(tname, _, _)| table_name.eq(tname)) .ok_or(ValidationError::TableDoesNotExist(table_name.to_string())) .map(|(_, table_position, table_schema)| (*table_position, *table_schema)) } -fn validate_create_table(raw_table_schema: RawTableSchema, db_schema: &DbSchema) -> Result { +fn validate_create_table( + raw_table_schema: RawTableSchema, + db_schema: &DbSchema, +) -> Result { let table_name: &TableName = &raw_table_schema.table_name; if get_table_schema(db_schema, table_name).is_some() { return Err(ValidationError::TableAlreadyExists(table_name.to_string())); @@ -71,16 +86,24 @@ fn validate_create_table(raw_table_schema: RawTableSchema, db_schema: &DbSchema) fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result { // check for duplicate columns - find_first_duplicate(&raw_table_schema.get_columns()) - .map_or_else( - || Ok(()), - |duplicate_column| Err(ValidationError::DuplicateColumn(duplicate_column.to_string())) - )?; + find_first_duplicate(&raw_table_schema.get_columns()).map_or_else( + || Ok(()), + |duplicate_column| { + Err(ValidationError::DuplicateColumn( + duplicate_column.to_string(), + )) + }, + )?; let mut primary_keys: Vec<(ColumnName, DbType)> = vec![]; let mut columns: Vec = vec![]; let mut types: Vec = vec![]; - for ColumnSchema { column_name, type_, is_primary } in raw_table_schema.columns { + for ColumnSchema { + column_name, + type_, + is_primary, + } in raw_table_schema.columns + { if is_primary { primary_keys.push((column_name.clone(), type_)) } @@ -91,13 +114,22 @@ fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result 1 { - Err(ValidationError::MultiplePrimaryKeysFound(raw_table_schema.table_name.clone())) + Err(ValidationError::MultiplePrimaryKeysFound( + raw_table_schema.table_name.clone(), + )) } else { let (primary_column_name, primary_key_type) = primary_keys[0].clone(); if primary_key_type == DbType::Uuid { - Ok(TableSchema::new(raw_table_schema.table_name, primary_column_name, columns, types)) + Ok(TableSchema::new( + raw_table_schema.table_name, + primary_column_name, + columns, + types, + )) } else { Err(ValidationError::TypeMismatch { column_name: raw_table_schema.table_name.clone(), @@ -108,121 +140,189 @@ fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result, db_schema: &DbSchema) -> Result { +fn validate_select( + table_name: TableName, + column_selection: syntax::ColumnSelection, + condition: Option, + db_schema: &DbSchema, +) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; match column_selection { syntax::ColumnSelection::Columns(columns) => { - let non_existant_columns: Vec = - columns.iter().filter_map(|column| + let non_existant_columns: Vec = columns + .iter() + .filter_map(|column| { if schema.does_column_exist(column) { None } else { Some(column.clone()) - }).collect(); + } + }) + .collect(); if non_existant_columns.is_empty() { - let selection: operation::ColumnSelection = - columns.iter().filter_map(|column_name| schema.get_column(column_name)).collect(); + let selection: operation::ColumnSelection = columns + .iter() + .filter_map(|column_name| schema.get_column(column_name)) + .collect(); let validated_condition = validate_condition(condition, schema)?; - Ok(Operation::Select(table_position, selection, validated_condition)) + Ok(Operation::Select( + table_position, + selection, + validated_condition, + )) } else { Err(ValidationError::ColumnsDoNotExist(non_existant_columns)) } } syntax::ColumnSelection::All => { let validated_condition = validate_condition(condition, schema)?; - Ok(Operation::Select(table_position, schema.all_selection(), validated_condition)) + Ok(Operation::Select( + table_position, + schema.all_selection(), + validated_condition, + )) } } } -fn validate_insert(table_name: TableName, insertion_values: syntax::InsertionValues, db_schema: &DbSchema) -> Result { +fn validate_insert( + table_name: TableName, + insertion_values: syntax::InsertionValues, + db_schema: &DbSchema, +) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; // Check for duplicate columns in insertion_values. - let columns_in_query_vec: Vec<&ColumnName> = insertion_values.iter().map(|(column_name, _)| column_name).collect(); - find_first_duplicate(&columns_in_query_vec) - .map_or_else( - || Ok(()), - |duplicate_column| Err(ValidationError::DuplicateColumn(duplicate_column.to_string())) - )?; + let columns_in_query_vec: Vec<&ColumnName> = insertion_values + .iter() + .map(|(column_name, _)| column_name) + .collect(); + find_first_duplicate(&columns_in_query_vec).map_or_else( + || Ok(()), + |duplicate_column| { + Err(ValidationError::DuplicateColumn( + duplicate_column.to_string(), + )) + }, + )?; // Check that the set of columns in the insertion_values is the same as the set of required columns of the table. let columns_in_query: HashSet<&ColumnName> = HashSet::from_iter(columns_in_query_vec); let columns_in_schema: HashSet<&ColumnName> = HashSet::from_iter(schema.get_columns()); let non_existant_columns = Vec::from_iter(columns_in_query.difference(&columns_in_schema)); if !non_existant_columns.is_empty() { - return Err(ValidationError::ColumnsDoNotExist(non_existant_columns.iter().map(|column_name| column_name.to_string()).collect())); + return Err(ValidationError::ColumnsDoNotExist( + non_existant_columns + .iter() + .map(|column_name| column_name.to_string()) + .collect(), + )); } let missing_required_columns = Vec::from_iter(columns_in_schema.difference(&columns_in_query)); if !missing_required_columns.is_empty() { - return Err(ValidationError::RequiredColumnsAreMissing(missing_required_columns.iter().map(|str| str.to_string()).collect())); + return Err(ValidationError::RequiredColumnsAreMissing( + missing_required_columns + .iter() + .map(|str| str.to_string()) + .collect(), + )); } // Check types and prepare for creation of InsertionValues for the interpreter let mut values_map: BTreeMap = BTreeMap::new(); // The reason for using BTreeMap - // instead of HashMap is that we need - // to get the values in a vector - // sorted by the key. + // instead of HashMap is that we need + // to get the values in a vector + // sorted by the key. for (column_name, value) in insertion_values { - let (column, expected_type) = schema.get_typed_column(&column_name).ok_or(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]))?; // By the previous validation steps this is never gonna trigger an error. + let (column, expected_type) = + schema + .get_typed_column(&column_name) + .ok_or(ValidationError::ColumnsDoNotExist(vec![ + column_name.to_string() + ]))?; // By the previous validation steps this is never gonna trigger an error. let value_type = value.to_type(); if value_type != expected_type { - return Err(ValidationError::TypeMismatch { column_name: column_name.to_string(), received_type: value_type, expected_type }); + return Err(ValidationError::TypeMismatch { + column_name: column_name.to_string(), + received_type: value_type, + expected_type, + }); } values_map.insert(column, value); } // WARNING: If you use `values_map: HashMap<_,_>`, this is not gonna sort values by key. - let values: operation::InsertionValues = values_map.into_values().collect(); + let values: operation::InsertionValues = values_map.into_values().collect(); // Note that one of the values is id. Ok(Operation::Insert(table_position, values)) } -fn validate_delete(table_name: TableName, condition: Option, db_schema: &DbSchema) -> Result { +fn validate_delete( + table_name: TableName, + condition: Option, + db_schema: &DbSchema, +) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; let validated_condition = validate_condition(condition, schema)?; Ok(Operation::Delete(table_position, validated_condition)) } -fn validate_condition(condition: Option, schema: &TableSchema) -> Result, ValidationError> { +fn validate_condition( + condition: Option, + schema: &TableSchema, +) -> Result, ValidationError> { match condition { - Some(condition) => { - match condition { - syntax::Condition::Eq(column_name, value) => { - let (column, expected_type) = schema.get_typed_column(&column_name).ok_or(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]))?; - let value_type: DbType = value.to_type(); - if expected_type.eq(&value_type) { - Ok(Some(operation::Condition::Eq(column, value))) - } else { - Err(ValidationError::TypeMismatch { column_name: column_name.to_string(), received_type: value_type, expected_type }) - } + Some(condition) => match condition { + syntax::Condition::Eq(column_name, value) => { + let (column, expected_type) = schema.get_typed_column(&column_name).ok_or( + ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]), + )?; + let value_type: DbType = value.to_type(); + if expected_type.eq(&value_type) { + Ok(Some(operation::Condition::Eq(column, value))) + } else { + Err(ValidationError::TypeMismatch { + column_name: column_name.to_string(), + received_type: value_type, + expected_type, + }) } } - } - None => Ok(None) + }, + None => Ok(None), } } -fn validate_create_index(table_name: TableName, column_name: ColumnName, db_schema: &DbSchema) -> Result { +fn validate_create_index( + table_name: TableName, + column_name: ColumnName, + db_schema: &DbSchema, +) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; - schema - .get_typed_column(&column_name) - .map_or_else( - || Err(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()])), - |(column, type_)| { - if type_.is_indexable() { - Ok(Operation::CreateIndex(table_position, column)) - } else { - Err(ValidationError::AttemptToIndexNonIndexableColumn(column_name.clone(), table_name)) - } + schema.get_typed_column(&column_name).map_or_else( + || { + Err(ValidationError::ColumnsDoNotExist(vec![ + column_name.to_string() + ])) + }, + |(column, type_)| { + if type_.is_indexable() { + Ok(Operation::CreateIndex(table_position, column)) + } else { + Err(ValidationError::AttemptToIndexNonIndexableColumn( + column_name.clone(), + table_name, + )) } - ) + }, + ) } // ===Helpers=== -fn find_first_duplicate(ts: &[T]) -> Option<&T> -where T: Eq + std::hash::Hash +fn find_first_duplicate(ts: &[T]) -> Option<&T> +where + T: Eq + std::hash::Hash, { let mut already_seen_elements: HashSet<&T> = HashSet::new(); for t in ts { @@ -235,34 +335,35 @@ where T: Eq + std::hash::Hash None } -fn get_table_schema<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName) -> Option<&'a TableSchema> { - let (_, _, table_schema) = db_schema.iter().find(|(tname, _, _)| table_name.eq(tname))?; +fn get_table_schema<'a>( + db_schema: &DbSchema<'a>, + table_name: &'a TableName, +) -> Option<&'a TableSchema> { + let (_, _, table_schema) = db_schema + .iter() + .find(|(tname, _, _)| table_name.eq(tname))?; Some(table_schema) } #[cfg(test)] mod tests { - use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax, ColumnSelection, Condition}; - use minisql::type_system::{Value, IndexableValue}; - use minisql::operation::Operation; - use minisql::operation; - use minisql::schema::TableSchema; use super::*; + use crate::syntax::{ColumnSchema, ColumnSelection, Condition, RawQuerySyntax, RawTableSchema}; + use minisql::operation; + use minisql::operation::Operation; + use minisql::schema::TableSchema; + use minisql::type_system::{IndexableValue, Value}; + use Condition::*; + use IndexableValue::*; use RawQuerySyntax::*; use Value::*; - use IndexableValue::*; - use Condition::*; fn users_schema() -> TableSchema { TableSchema::new( "users".to_string(), "id".to_string(), - vec!( - "id".to_string(), - "name".to_string(), - "age".to_string(), - ), + vec!["id".to_string(), "name".to_string(), "age".to_string()], vec![DbType::Uuid, DbType::String, DbType::Int], ) } @@ -271,18 +372,27 @@ mod tests { RawTableSchema { table_name: "users".to_string(), columns: vec![ - ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, - ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, - ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ColumnSchema { + column_name: "id".to_string(), + type_: DbType::Uuid, + is_primary: true, + }, + ColumnSchema { + column_name: "name".to_string(), + type_: DbType::String, + is_primary: false, + }, + ColumnSchema { + column_name: "age".to_string(), + type_: DbType::Int, + is_primary: false, + }, ], } } - fn db_schema(users_schema: &TableSchema) -> DbSchema { - vec![ - ("users".to_string(), 0, users_schema), - ] + vec![("users".to_string(), 0, users_schema)] } fn empty_db_schema() -> DbSchema<'static> { @@ -297,7 +407,9 @@ mod tests { let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::CreateTable(_)))); - let Ok(Operation::CreateTable(schema)) = result else { panic!() }; + let Ok(Operation::CreateTable(schema)) = result else { + panic!() + }; assert!(schema.table_name() == "users"); } @@ -306,9 +418,21 @@ mod tests { let raw_users_schema = RawTableSchema { table_name: "users".to_string(), columns: vec![ - ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, - ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, - ColumnSchema { column_name: "name".to_string(), type_: DbType::Number, is_primary: false }, + ColumnSchema { + column_name: "id".to_string(), + type_: DbType::Uuid, + is_primary: true, + }, + ColumnSchema { + column_name: "name".to_string(), + type_: DbType::String, + is_primary: false, + }, + ColumnSchema { + column_name: "name".to_string(), + type_: DbType::Number, + is_primary: false, + }, ], }; @@ -325,9 +449,21 @@ mod tests { let raw_users_schema = RawTableSchema { table_name: "users".to_string(), columns: vec![ - ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, - ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, - ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ColumnSchema { + column_name: "id".to_string(), + type_: DbType::Int, + is_primary: true, + }, + ColumnSchema { + column_name: "name".to_string(), + type_: DbType::String, + is_primary: false, + }, + ColumnSchema { + column_name: "age".to_string(), + type_: DbType::Int, + is_primary: false, + }, ], }; @@ -343,9 +479,21 @@ mod tests { let raw_users_schema = RawTableSchema { table_name: "users".to_string(), columns: vec![ - ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, - ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: true }, - ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ColumnSchema { + column_name: "id".to_string(), + type_: DbType::Int, + is_primary: true, + }, + ColumnSchema { + column_name: "name".to_string(), + type_: DbType::String, + is_primary: true, + }, + ColumnSchema { + column_name: "age".to_string(), + type_: DbType::Int, + is_primary: false, + }, ], }; @@ -353,7 +501,10 @@ mod tests { let syntax: RawQuerySyntax = CreateTable(raw_users_schema); let result = validate_operation(syntax, &db_schema); - assert!(matches!(result, Err(ValidationError::MultiplePrimaryKeysFound(_)))); + assert!(matches!( + result, + Err(ValidationError::MultiplePrimaryKeysFound(_)) + )); } #[test] @@ -363,7 +514,10 @@ mod tests { let syntax: RawQuerySyntax = CreateTable(raw_users_schema()); let result = validate_operation(syntax, &db_schema); - assert!(matches!(result, Err(ValidationError::TableAlreadyExists(_)))); + assert!(matches!( + result, + Err(ValidationError::TableAlreadyExists(_)) + )); } // ====Select==== @@ -380,7 +534,9 @@ mod tests { let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::Select(_, _, _)))); - let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; + let Ok(Operation::Select(table_position, column_selection, condition)) = result else { + panic!() + }; assert!(table_position == users_position); assert!(condition == None); @@ -392,7 +548,8 @@ mod tests { let users_schema: TableSchema = users_schema(); let db_schema: DbSchema = db_schema(&users_schema); - let syntax: RawQuerySyntax = Select("does_not_exist".to_string(), ColumnSelection::All, None); + let syntax: RawQuerySyntax = + Select("does_not_exist".to_string(), ColumnSelection::All, None); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::TableDoesNotExist(_)))); } @@ -407,11 +564,17 @@ mod tests { let name = 1; let age = 2; - let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("age".to_string(), Indexable(Int(25))))); + let syntax: RawQuerySyntax = Select( + "users".to_string(), + ColumnSelection::All, + Some(Eq("age".to_string(), Indexable(Int(25)))), + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::Select(_, _, _)))); - let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; + let Ok(Operation::Select(table_position, column_selection, condition)) = result else { + panic!() + }; assert!(table_position == users_position); assert!(column_selection == vec![id, name, age]); @@ -428,11 +591,21 @@ mod tests { let name = 1; let age = 2; - let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::Columns(vec!["age".to_string(), "name".to_string(), "age".to_string()]), None); + let syntax: RawQuerySyntax = Select( + "users".to_string(), + ColumnSelection::Columns(vec![ + "age".to_string(), + "name".to_string(), + "age".to_string(), + ]), + None, + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::Select(_, _, _)))); - let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; + let Ok(Operation::Select(table_position, column_selection, condition)) = result else { + panic!() + }; assert!(table_position == users_position); assert!(column_selection == vec![age, name, age]); @@ -444,7 +617,11 @@ mod tests { let users_schema: TableSchema = users_schema(); let db_schema: DbSchema = db_schema(&users_schema); - let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::Columns(vec!["age".to_string(), "does_not_exist".to_string()]), None); + let syntax: RawQuerySyntax = Select( + "users".to_string(), + ColumnSelection::Columns(vec!["age".to_string(), "does_not_exist".to_string()]), + None, + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); } @@ -454,7 +631,11 @@ mod tests { let users_schema: TableSchema = users_schema(); let db_schema: DbSchema = db_schema(&users_schema); - let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("does_not_exist".to_string(), Indexable(Int(25))))); + let syntax: RawQuerySyntax = Select( + "users".to_string(), + ColumnSelection::All, + Some(Eq("does_not_exist".to_string(), Indexable(Int(25)))), + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); } @@ -464,7 +645,11 @@ mod tests { let users_schema: TableSchema = users_schema(); let db_schema: DbSchema = db_schema(&users_schema); - let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("age".to_string(), Indexable(String("25".to_string()))))); + let syntax: RawQuerySyntax = Select( + "users".to_string(), + ColumnSelection::All, + Some(Eq("age".to_string(), Indexable(String("25".to_string())))), + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); } @@ -483,18 +668,28 @@ mod tests { ("name".to_string(), Indexable(String("Alice".to_string()))), ("id".to_string(), Indexable(Uuid(0))), ("age".to_string(), Indexable(Int(25))), - ]); + ], + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::Insert(_, _)))); - let Ok(Operation::Insert(table_position, values)) = result else { panic!() }; + let Ok(Operation::Insert(table_position, values)) = result else { + panic!() + }; assert!(table_position == users_position); // Recall the order is // let id = 0; // let name = 1; // let age = 2; - assert!(values == vec![Indexable(Uuid(0)), Indexable(String("Alice".to_string())), Indexable(Int(25))]); + assert!( + values + == vec![ + Indexable(Uuid(0)), + Indexable(String("Alice".to_string())), + Indexable(Int(25)) + ] + ); } #[test] @@ -509,7 +704,8 @@ mod tests { ("id".to_string(), Indexable(Uuid(0))), ("age".to_string(), Indexable(Int(25))), ("does_not_exist".to_string(), Indexable(Int(25))), - ]); + ], + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); } @@ -525,7 +721,8 @@ mod tests { ("name".to_string(), Indexable(String("Alice".to_string()))), ("id".to_string(), Indexable(Uuid(0))), ("age".to_string(), Number(25.0)), - ]); + ], + ); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); } @@ -542,7 +739,9 @@ mod tests { let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::Delete(_, None)))); - let Ok(Operation::Delete(table_position, _)) = result else { panic!() }; + let Ok(Operation::Delete(table_position, _)) = result else { + panic!() + }; assert!(table_position == users_position); } @@ -555,11 +754,21 @@ mod tests { let users_position = 0; let age = 2; - let syntax: RawQuerySyntax = Delete("users".to_string(), Some(Eq("age".to_string(), Indexable(Int(25))))); + let syntax: RawQuerySyntax = Delete( + "users".to_string(), + Some(Eq("age".to_string(), Indexable(Int(25)))), + ); let result = validate_operation(syntax, &db_schema); - assert!(matches!(result, Ok(Operation::Delete(_, Some(operation::Condition::Eq(_, _)))))); + assert!(matches!( + result, + Ok(Operation::Delete(_, Some(operation::Condition::Eq(_, _)))) + )); - let Ok(Operation::Delete(table_position, Some(operation::Condition::Eq(column, value)))) = result else { panic!() }; + let Ok(Operation::Delete(table_position, Some(operation::Condition::Eq(column, value)))) = + result + else { + panic!() + }; assert!(table_position == users_position); assert!(column == age); @@ -579,7 +788,9 @@ mod tests { let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Ok(Operation::CreateIndex(_, _)))); - let Ok(Operation::CreateIndex(table_position, column)) = result else { panic!() }; + let Ok(Operation::CreateIndex(table_position, column)) = result else { + panic!() + }; assert!(table_position == users_position); assert!(column == age); diff --git a/proto/src/handshake/client.rs b/proto/src/handshake/client.rs index ff3aaed..160f72c 100644 --- a/proto/src/handshake/client.rs +++ b/proto/src/handshake/client.rs @@ -14,7 +14,6 @@ pub async fn do_client_handshake( reader: &mut impl BackendProtoReader, request: HandshakeRequest, ) -> Result { - // Send StartupMessage without SSLRequest let startup_message: StartupMessageData = request.into(); writer.write_startup_message(startup_message).await?; diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs index cd2a8c4..ca7934b 100644 --- a/proto/src/handshake/errors.rs +++ b/proto/src/handshake/errors.rs @@ -1,10 +1,10 @@ use crate::message::backend::BackendMessage; use crate::message::errors::ProtoDeserializeError; +use crate::message::special::CancelRequestData; use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; use crate::writer::errors::ProtoWriteError; use thiserror::Error; use tokio::io; -use crate::message::special::CancelRequestData; #[derive(Debug, Error)] pub enum ClientHandshakeError { diff --git a/proto/src/handshake/request.rs b/proto/src/handshake/request.rs index 51b6ad5..5d334c6 100644 --- a/proto/src/handshake/request.rs +++ b/proto/src/handshake/request.rs @@ -8,7 +8,6 @@ pub struct HandshakeRequest { } impl HandshakeRequest { - /// Creates a new `HandshakeRequest` with the specified protocol version. /// Expected `version` is `196608` for the 3.0. pub fn new(version: i32) -> Self { diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index d1a332f..7077d42 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -15,7 +15,6 @@ pub async fn do_server_handshake( reader: &mut impl FrontendProtoReader, response: HandshakeResponse, ) -> Result { - // Check if client requested SSL match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::SSLRequest) => { diff --git a/proto/src/message/primitive/data.rs b/proto/src/message/primitive/data.rs index db19ad4..684444c 100644 --- a/proto/src/message/primitive/data.rs +++ b/proto/src/message/primitive/data.rs @@ -1,6 +1,6 @@ use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; -use bincode::{Decode, Encode}; use bincode::config::{BigEndian, Configuration, Fixint}; +use bincode::{Decode, Encode}; fn pg_proto_config() -> Configuration { bincode::config::standard() diff --git a/proto/src/message/primitive/pgstring.rs b/proto/src/message/primitive/pgstring.rs index c768528..278b2e0 100644 --- a/proto/src/message/primitive/pgstring.rs +++ b/proto/src/message/primitive/pgstring.rs @@ -50,16 +50,17 @@ impl Decode for PgString { bytes.push(byte); } - let string = String::from_utf8(bytes) - .map_err(|e| DecodeError::Utf8 { inner: e.utf8_error() })?; + let string = String::from_utf8(bytes).map_err(|e| DecodeError::Utf8 { + inner: e.utf8_error(), + })?; Ok(PgString(string)) } } #[cfg(test)] mod tests { - use crate::message::primitive::data::MessageData; use super::*; + use crate::message::primitive::data::MessageData; #[test] fn test_encode_decode_utf8() { diff --git a/server/src/cancellation.rs b/server/src/cancellation.rs index 59f2cb1..4609f48 100644 --- a/server/src/cancellation.rs +++ b/server/src/cancellation.rs @@ -1,5 +1,5 @@ -use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; pub struct ResetCancelToken { is_canceled: Arc, @@ -31,4 +31,4 @@ impl Clone for ResetCancelToken { is_canceled: self.is_canceled.clone(), } } -} \ No newline at end of file +} diff --git a/server/src/config.rs b/server/src/config.rs index 68ae54b..2f15935 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,6 +1,6 @@ +use clap::Parser; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::PathBuf; -use clap::Parser; const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); @@ -9,7 +9,12 @@ const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); pub struct Configuration { #[arg(short, long, default_value_t = LOCAL_IPV4, help = "IP address for the server to listen on")] address: IpAddr, - #[arg(short, long, default_value = "5432", help = "Port for the server to listen on")] + #[arg( + short, + long, + default_value = "5432", + help = "Port for the server to listen on" + )] port: u16, #[arg(short, long, help = "Path to the data file")] file: PathBuf, @@ -25,4 +30,4 @@ impl Configuration { pub fn get_file_path(&self) -> &PathBuf { &self.file } -} \ No newline at end of file +} diff --git a/server/src/main.rs b/server/src/main.rs index fbaaafb..e54677e 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -24,10 +24,10 @@ use crate::config::Configuration; use crate::persistence::state_to_file; use crate::proto_wrapper::{CompleteStatus, ServerProto}; -mod config; -mod proto_wrapper; mod cancellation; +mod config; mod persistence; +mod proto_wrapper; type TokenStore = Arc>>; type SharedDbState = Arc>; @@ -65,16 +65,17 @@ async fn get_state(config: &Configuration) -> anyhow::Result { println!("WARNING: No DB state file found, creating new one"); Ok(State::new()) } - Err(e) => { - Err(e)? - } - Ok(state) => { - Ok(state) - } + Err(e) => Err(e)?, + Ok(state) => Ok(state), } } -async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore, config: Arc) -> anyhow::Result<()> { +async fn handle_stream( + mut stream: TcpStream, + state: SharedDbState, + tokens: TokenStore, + config: Arc, +) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); @@ -88,7 +89,9 @@ async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: Toke let result = match request { Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await, - Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, + Err(ServerHandshakeError::IsCancelRequest(cancel)) => { + handle_cancellation(cancel.pid, cancel.secret, &tokens).await + } Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), }; @@ -134,10 +137,17 @@ async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow: Ok(()) } -async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, config: Arc) -> anyhow::Result<()> - where - R: FrontendProtoReader + Send, - W: BackendProtoWriter + ProtoFlush + Send, +async fn handle_connection( + reader: &mut R, + writer: &mut W, + request: HandshakeRequest, + state: SharedDbState, + token: ResetCancelToken, + config: Arc, +) -> anyhow::Result<()> +where + R: FrontendProtoReader + Send, + W: BackendProtoWriter + ProtoFlush + Send, { println!("Client connected: {:?}", request); @@ -152,9 +162,7 @@ async fn handle_connection(reader: &mut R, writer: &mut W, request: Handsh let result = handle_query(writer, &state, data.query.into(), &token, &config).await; match result { Ok(_) => {} - Err(e) => { - writer.write_error_message(&e.to_string()).await? - } + Err(e) => writer.write_error_message(&e.to_string()).await?, } writer.write_ready_for_query().await?; } @@ -165,9 +173,15 @@ async fn handle_connection(reader: &mut R, writer: &mut W, request: Handsh Ok(()) } -async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken, config: &Arc) -> anyhow::Result<()> - where - W: BackendProtoWriter + ProtoFlush + Send, +async fn handle_query( + writer: &mut W, + state: &SharedDbState, + query: String, + token: &ResetCancelToken, + config: &Arc, +) -> anyhow::Result<()> +where + W: BackendProtoWriter + ProtoFlush + Send, { // Make sure token is reset before next query token.reset(); @@ -184,11 +198,15 @@ async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, t match response { Response::Deleted(i) => { - writer.write_command_complete(CompleteStatus::Delete(i)).await?; + writer + .write_command_complete(CompleteStatus::Delete(i)) + .await?; true } Response::Inserted => { - writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?; + writer + .write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }) + .await?; true } Response::Selected(schema, columns, mut rows) => { @@ -207,22 +225,30 @@ async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, t } } - writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; + writer + .write_command_complete(CompleteStatus::Select(sent_rows)) + .await?; } _ => { - writer.write_command_complete(CompleteStatus::Select(0)).await?; + writer + .write_command_complete(CompleteStatus::Select(0)) + .await?; } } false - }, + } Response::TableCreated => { - writer.write_command_complete(CompleteStatus::CreateTable).await?; + writer + .write_command_complete(CompleteStatus::CreateTable) + .await?; true - }, + } Response::IndexCreated => { - writer.write_command_complete(CompleteStatus::CreateIndex).await?; + writer + .write_command_complete(CompleteStatus::CreateIndex) + .await?; true - }, + } } }; diff --git a/server/src/persistence.rs b/server/src/persistence.rs index 980945a..12479c2 100644 --- a/server/src/persistence.rs +++ b/server/src/persistence.rs @@ -1,6 +1,6 @@ +use minisql::interpreter::State; use std::path::PathBuf; use tokio::{fs, io}; -use minisql::interpreter::State; pub async fn state_from_file(path: &PathBuf) -> io::Result { let content = fs::read_to_string(path).await?; diff --git a/server/src/proto_wrapper.rs b/server/src/proto_wrapper.rs index 6462847..960522e 100644 --- a/server/src/proto_wrapper.rs +++ b/server/src/proto_wrapper.rs @@ -1,20 +1,20 @@ use async_trait::async_trait; +use minisql::operation::ColumnSelection; +use minisql::restricted_row::RestrictedRow; +use minisql::schema::{Column, TableSchema}; +use proto::message::backend::{ + BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, + ReadyForQueryData, RowDescriptionData, +}; +use proto::message::primitive::pglist::PgList; +use proto::writer::backend::BackendProtoWriter; use rand::Rng; use rand_pcg::Pcg64; use rand_seeder::Seeder; use std::fmt; -use minisql::operation::ColumnSelection; -use minisql::restricted_row::RestrictedRow; -use minisql::schema::{Column, TableSchema}; -use proto::message::backend::{BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, ReadyForQueryData, RowDescriptionData}; -use proto::message::primitive::pglist::PgList; -use proto::writer::backend::BackendProtoWriter; pub enum CompleteStatus { - Insert { - oid: i32, - rows: i32, - }, + Insert { oid: i32, rows: i32 }, Delete(usize), Select(usize), CreateTable, @@ -38,24 +38,36 @@ pub trait ServerProto { async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>; async fn write_ready_for_query(&mut self) -> anyhow::Result<()>; async fn write_empty_query(&mut self) -> anyhow::Result<()>; - async fn write_table_header(&mut self, table_schema: &TableSchema, columns: &ColumnSelection) -> anyhow::Result<()>; + async fn write_table_header( + &mut self, + table_schema: &TableSchema, + columns: &ColumnSelection, + ) -> anyhow::Result<()>; async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>; async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>; } #[async_trait] -impl ServerProto for W where W: BackendProtoWriter + Send { +impl ServerProto for W +where + W: BackendProtoWriter + Send, +{ async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> { - self.write_proto(ErrorResponseData { - code: b'M', - message: format!("{error_message}\0").into(), - }.into()).await?; + self.write_proto( + ErrorResponseData { + code: b'M', + message: format!("{error_message}\0").into(), + } + .into(), + ) + .await?; Ok(()) } async fn write_ready_for_query(&mut self) -> anyhow::Result<()> { - self.write_proto(ReadyForQueryData { status: b'I' }.into()).await?; + self.write_proto(ReadyForQueryData { status: b'I' }.into()) + .await?; Ok(()) } @@ -64,35 +76,52 @@ impl ServerProto for W where W: BackendProtoWriter + Send { Ok(()) } - async fn write_table_header(&mut self, table_schema: &TableSchema, columns: &ColumnSelection) -> anyhow::Result<()> { - let columns = columns.iter() + async fn write_table_header( + &mut self, + table_schema: &TableSchema, + columns: &ColumnSelection, + ) -> anyhow::Result<()> { + let columns = columns + .iter() .map(|column| column_to_description(table_schema, *column)) .collect::>>()?; - self.write_proto(RowDescriptionData { columns: columns.into() }.into()).await?; + self.write_proto( + RowDescriptionData { + columns: columns.into(), + } + .into(), + ) + .await?; Ok(()) } async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()> { - let values = row.iter() + let values = row + .iter() .map(|(_, value)| value.as_text_bytes().into()) .collect::>>(); self.write_proto(BackendMessage::DataRow(DataRowData { columns: values.into(), - })).await?; + })) + .await?; Ok(()) } async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> { self.write_proto(BackendMessage::CommandComplete(CommandCompleteData { tag: status.to_string().into(), - })).await?; + })) + .await?; Ok(()) } } -fn column_to_description(schema: &TableSchema, column: Column) -> anyhow::Result { +fn column_to_description( + schema: &TableSchema, + column: Column, +) -> anyhow::Result { let table_name = schema.table_name(); let table_oid = table_name_to_oid(table_name);