diff --git a/minisql/src/error.rs b/minisql/src/error.rs index 451ad7a..71e832b 100644 --- a/minisql/src/error.rs +++ b/minisql/src/error.rs @@ -1,20 +1,17 @@ use std::num::{ParseFloatError, ParseIntError}; use std::str::Utf8Error; use thiserror::Error; -use crate::internals::row::ColumnPosition; use crate::schema::{ColumnName, TableName}; -use crate::type_system::{DbType, Uuid, Value}; +use crate::type_system::Uuid; #[derive(Debug, Error)] -pub enum Error { - #[error("column position {1} of table {0} does not exist")] - ColumnPositionDoesNotExist(TableName, ColumnPosition), - #[error("column {1} of table {0} has unexpected type {2:?} and value {3:?}")] - ValueDoesNotMatchExpectedType(TableName, ColumnName, DbType, Value), +pub enum RuntimeError { #[error("table {0} already contains row with id {1}")] AttemptingToInsertAlreadyPresentId(TableName, Uuid), #[error("table {0} cannot be indexed on column {1}")] AttemptToIndexNonIndexableColumn(TableName, ColumnName), + #[error("table {0} already indexes column {1}")] + AttemptToIndexAlreadyIndexedColumn(TableName, ColumnName), } #[derive(Debug, Error)] diff --git a/minisql/src/internals/row.rs b/minisql/src/internals/row.rs index 23590d2..0e48867 100644 --- a/minisql/src/internals/row.rs +++ b/minisql/src/internals/row.rs @@ -61,16 +61,16 @@ impl Row { self.0.len() } - pub fn get(&self, column_position: ColumnPosition) -> Option<&Value> { - self.0.get(column_position) + pub fn get(&self, column: ColumnPosition) -> Option<&Value> { + self.0.get(column) } pub fn restrict_columns(&self, columns: &Vec) -> RestrictedRow { // If the index from `columns` is non-existant in `row`, it will just ignore it. let mut subrow: Vec<(ColumnPosition, Value)> = vec![]; - for column_position in columns { - if let Some(value) = self.get(*column_position) { - subrow.push((*column_position, value.clone())); + for column in columns { + if let Some(value) = self.get(*column) { + subrow.push((*column, value.clone())); } } diff --git a/minisql/src/internals/table.rs b/minisql/src/internals/table.rs index 18c98a9..dc83639 100644 --- a/minisql/src/internals/table.rs +++ b/minisql/src/internals/table.rs @@ -1,7 +1,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use serde::{Deserialize, Serialize}; -use crate::error::Error; +use crate::error::RuntimeError; use crate::internals::column_index::ColumnIndex; use crate::internals::row::{ColumnPosition, Row}; use crate::restricted_row::RestrictedRow; @@ -55,12 +55,12 @@ impl Table { .collect() } - fn get_rows_by_value(&self, column_position: ColumnPosition, value: &Value) -> Vec { + fn get_rows_by_value(&self, column: ColumnPosition, value: &Value) -> Vec { // brute-force search self.rows .values() .filter_map(|row| { - if row.get(column_position) == Some(value) { + if row.get(column) == Some(value) { Some(row.clone()) } else { None @@ -69,21 +69,21 @@ impl Table { .collect() } - pub fn select_all_rows<'a>(&'a self, selected_column_positions: Vec) -> impl Iterator + 'a { + pub fn select_all_rows<'a>(&'a self, selected_columns: Vec) -> impl Iterator + 'a { self.rows .values() - .map(move |row| row.restrict_columns(&selected_column_positions)) + .map(move |row| row.restrict_columns(&selected_columns)) } pub fn select_rows_where_eq<'a>( &'a self, - selected_column_positions: Vec, - column_position: ColumnPosition, + selected_columns: Vec, + column: ColumnPosition, value: Value, ) -> DbResult + 'a> { - let restrict_columns_of_row = move |row: Row| row.restrict_columns(&selected_column_positions); + let restrict_columns_of_row = move |row: Row| row.restrict_columns(&selected_columns); match value { - Value::Indexable(value) => match self.fetch_ids_from_index(column_position, &value)? { + Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? { Some(ids) => Ok(self .get_rows_by_ids(ids) @@ -92,14 +92,14 @@ impl Table { ), None => Ok(self - .get_rows_by_value(column_position, &Value::Indexable(value)) + .get_rows_by_value(column, &Value::Indexable(value)) .into_iter() .map(restrict_columns_of_row) ), }, _ => Ok(self - .get_rows_by_value(column_position, &value) + .get_rows_by_value(column, &value) .into_iter() .map(restrict_columns_of_row) ), @@ -109,22 +109,16 @@ impl Table { // ======Insertion====== pub fn insert_row_at(&mut self, id: Uuid, row: Row) -> DbResult<()> { if self.rows.get(&id).is_some() { - return Err(Error::AttemptingToInsertAlreadyPresentId( + return Err(RuntimeError::AttemptingToInsertAlreadyPresentId( self.table_name().clone(), id, )); } - for (column_position, column_index) in &mut self.indexes { - match row.get(*column_position) { - Some(Value::Indexable(val)) => column_index.add(val.clone(), id), - Some(_) => {} - None => { - return Err(Error::ColumnPositionDoesNotExist( - self.schema.table_name().clone(), // Note that I can't simply use self.table_name() here because of rust borrowing rules. - *column_position, - )) - } + for (column, column_index) in &mut self.indexes { + match &row[*column] { + Value::Indexable(val) => column_index.add(val.clone(), id), + _ => {}, } } @@ -136,8 +130,8 @@ impl Table { fn delete_row_by_id(&mut self, id: Uuid) -> usize { match self.rows.remove(&id) { Some(row) => { - for (column_position, column_index) in &mut self.indexes { - if let Value::Indexable(value) = &row[*column_position] { + for (column, column_index) in &mut self.indexes { + if let Value::Indexable(value) = &row[*column] { let _ = column_index.remove(value, id); }; } @@ -155,12 +149,12 @@ impl Table { total_count } - fn delete_rows_by_value(&mut self, column_position: ColumnPosition, value: &Value) -> usize { + fn delete_rows_by_value(&mut self, column: ColumnPosition, value: &Value) -> usize { let matched_ids: HashSet = self .rows .iter() .filter_map(|(id, row)| { - if row.get(column_position) == Some(value) { + if row.get(column) == Some(value) { Some(*id) } else { None @@ -179,50 +173,43 @@ impl Table { pub fn delete_rows_where_eq( &mut self, - column_position: ColumnPosition, + column: ColumnPosition, value: Value, ) -> DbResult { match value { - Value::Indexable(value) => match self.fetch_ids_from_index(column_position, &value)? { + Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? { Some(ids) => Ok(self.delete_rows_by_ids(ids)), - None => Ok(self.delete_rows_by_value(column_position, &Value::Indexable(value))), + None => Ok(self.delete_rows_by_value(column, &Value::Indexable(value))), }, - _ => Ok(self.delete_rows_by_value(column_position, &value)), + _ => Ok(self.delete_rows_by_value(column, &value)), } } // ======Indexing====== - pub fn attach_index(&mut self, column_position: ColumnPosition) -> DbResult<()> { + pub fn attach_index(&mut self, column: ColumnPosition) -> DbResult<()> { + if self.indexes.get(&column).is_some() { + let column_name = self.schema.column_name_from_column(column).clone(); + let table_name = self.schema.table_name().clone(); + return Err(RuntimeError::AttemptToIndexAlreadyIndexedColumn(table_name, column_name)) + } let mut column_index: ColumnIndex = ColumnIndex::new(); - update_index_from_table(&mut column_index, self, column_position)?; - self.indexes.insert(column_position, column_index); + update_index_from_table(&mut column_index, self, column)?; + self.indexes.insert(column, column_index); Ok(()) } fn fetch_ids_from_index( &self, - column_position: ColumnPosition, + column: ColumnPosition, value: &IndexableValue, ) -> DbResult>> { - if self.schema.is_primary(column_position) { + if self.schema.is_primary(column) { match value { IndexableValue::Uuid(id) => Ok(Some(HashSet::from([*id]))), - _ => { - // TODO: This validation step is not really necessary. - let column_name: ColumnName = self - .schema - .column_name_from_column_position(column_position)?; - let type_ = self.schema.column_type(column_position); - Err(Error::ValueDoesNotMatchExpectedType( - self.table_name().clone(), - column_name, - type_, - Value::Indexable(value.clone()), - )) - } + _ => unreachable!() // SAFETY: Validation guarantees primary column has correct Uuid type. } } else { - match self.indexes.get(&column_position) { + match self.indexes.get(&column) { Some(index) => { // Note that we are cloning the ids here! This can be very wasteful in some cases. // Theoretically it would be possible to return a reference, @@ -241,26 +228,21 @@ impl Table { fn update_index_from_table( column_index: &mut ColumnIndex, table: &Table, - column_position: ColumnPosition, + column: ColumnPosition, ) -> DbResult<()> { for (id, row) in &table.rows { - let value = match row.get(column_position) { - Some(Value::Indexable(value)) => value.clone(), - Some(_) => { + let value = match &row[column] { + Value::Indexable(value) => value.clone(), + _ => { let column_name: ColumnName = table .schema - .column_name_from_column_position(column_position)?; - return Err(Error::AttemptToIndexNonIndexableColumn( + .column_name_from_column(column); + // TODO: Perhaps this should be handled in validation? + return Err(RuntimeError::AttemptToIndexNonIndexableColumn( table.table_name().to_string(), column_name, )); } - None => { - return Err(Error::ColumnPositionDoesNotExist( - table.table_name().to_string(), - column_position, - )) - } }; column_index.add(value, *id) } diff --git a/minisql/src/interpreter.rs b/minisql/src/interpreter.rs index 85266b9..08602af 100644 --- a/minisql/src/interpreter.rs +++ b/minisql/src/interpreter.rs @@ -74,10 +74,10 @@ impl State { &mut self.tables[table_position] } - fn attach_table(&mut self, table_name: TableName, table: Table) { + fn attach_table(&mut self, table: Table) { let new_table_position: TablePosition = self.tables.len(); self.table_name_position_mapping - .insert(table_name, new_table_position); + .insert(table.schema().table_name().clone(), new_table_position); self.tables.push(table); } @@ -127,9 +127,9 @@ impl State { Ok(Response::Deleted(rows_affected)) } - CreateTable(table_name, table_schema) => { + CreateTable(table_schema) => { let table = Table::new(table_schema); - self.attach_table(table_name, table); + self.attach_table(table); Ok(Response::TableCreated) } @@ -151,17 +151,13 @@ mod tests { use crate::operation::Operation; fn users_schema() -> TableSchema { - let id: ColumnPosition = 0; - let name: ColumnPosition = 1; - let age: ColumnPosition = 2; - TableSchema::new( "users".to_string(), - id, + "id".to_string(), vec!( - ("id".to_string(), id), - ("name".to_string(), name), - ("age".to_string(), age), + "id".to_string(), + "name".to_string(), + "age".to_string(), ), vec![DbType::Uuid, DbType::String, DbType::Int], ) @@ -174,7 +170,7 @@ mod tests { let users = users_schema.table_name().clone(); state - .interpret(Operation::CreateTable(users.clone(), users_schema)) + .interpret(Operation::CreateTable(users_schema)) .unwrap(); assert!(state.tables.len() == 1); @@ -188,11 +184,10 @@ mod tests { fn test_select_empty() { let mut state = State::new(); let users_schema = users_schema(); - let users = users_schema.table_name().clone(); let users_position = 0; state - .interpret(Operation::CreateTable(users, users_schema.clone())) + .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let response: Response = state .interpret(Operation::Select(users_position, users_schema.all_selection(), None)) @@ -216,7 +211,7 @@ mod tests { state - .interpret(Operation::CreateTable("users".to_string(), users_schema.clone())) + .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let (id, name, age) = ( @@ -268,7 +263,7 @@ mod tests { let name_column: ColumnPosition = 1; state - .interpret(CreateTable(users_schema.table_name().clone(), users_schema.clone())) + .interpret(CreateTable(users_schema.clone())) .unwrap(); let (id0, name0, age0) = ( @@ -385,7 +380,7 @@ mod tests { let id_column: ColumnPosition = 0; state - .interpret(CreateTable(users_schema.table_name().clone(), users_schema.clone())) + .interpret(CreateTable(users_schema.clone())) .unwrap(); let (id0, name0, age0) = ( @@ -459,7 +454,7 @@ mod tests { let name_column: ColumnPosition = 1; state - .interpret(CreateTable(users_schema.table_name().clone(), users_schema.clone())) + .interpret(CreateTable(users_schema.clone())) .unwrap(); state @@ -526,26 +521,25 @@ pub fn example() { let id_column: ColumnPosition = 0; let name_column: ColumnPosition = 1; - let age_column: ColumnPosition = 2; + // let age_column: ColumnPosition = 2; let users_schema: TableSchema = { TableSchema::new( "users".to_string(), - id_column, + "id".to_string(), vec!( - ("id".to_string(), id_column), - ("name".to_string(), name_column), - ("age".to_string(), age_column), + "id".to_string(), // 0 + "name".to_string(), // 1 + "age".to_string(), // 2 ), vec![DbType::Uuid, DbType::String, DbType::Int], ) }; let users_position: TablePosition = 0; - let users = users_schema.table_name().clone(); let mut state = State::new(); state - .interpret(Operation::CreateTable(users, users_schema.clone())) + .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let (id0, name0, age0) = ( diff --git a/minisql/src/operation.rs b/minisql/src/operation.rs index dae8718..fa56b70 100644 --- a/minisql/src/operation.rs +++ b/minisql/src/operation.rs @@ -1,23 +1,24 @@ -use crate::schema::{TableName, TableSchema}; +use crate::schema::TableSchema; use crate::type_system::Value; use crate::internals::row::ColumnPosition; use crate::interpreter::TablePosition; // Validated operation. Constructed by validation crate. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Operation { Select(TablePosition, ColumnSelection, Option), Insert(TablePosition, InsertionValues), Delete(TablePosition, Option), - CreateTable(TableName, TableSchema), + CreateTable(TableSchema), CreateIndex(TablePosition, ColumnPosition), } +// Assumes that these are sorted by column position. pub type InsertionValues = Vec; pub type ColumnSelection = Vec; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Condition { Eq(ColumnPosition, Value), } diff --git a/minisql/src/result.rs b/minisql/src/result.rs index fcad8b5..ace48fe 100644 --- a/minisql/src/result.rs +++ b/minisql/src/result.rs @@ -1,3 +1,3 @@ -use crate::error::Error; +use crate::error::RuntimeError; -pub type DbResult = Result; +pub type DbResult = Result; diff --git a/minisql/src/schema.rs b/minisql/src/schema.rs index c61b7c2..16a41bf 100644 --- a/minisql/src/schema.rs +++ b/minisql/src/schema.rs @@ -1,4 +1,3 @@ -use crate::error::Error; use crate::internals::row::{ColumnPosition, Row}; use crate::operation::{InsertionValues, ColumnSelection}; use crate::result::DbResult; @@ -8,7 +7,7 @@ use serde::{Deserialize, Serialize}; // Note that it is nice to split metadata from the data because // then you can give the metadata to the parser without giving it the data. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct TableSchema { table_name: TableName, // used for descriptive errors primary_key: ColumnPosition, @@ -20,11 +19,15 @@ pub type TableName = String; pub type ColumnName = String; impl TableSchema { - pub fn new(table_name: TableName, primary_key: ColumnPosition, column_name_position_map: Vec<(ColumnName, ColumnPosition)>, types: Vec) -> Self { + pub fn new(table_name: TableName, primary_column_name: ColumnName, columns: Vec, types: Vec) -> Self { let mut column_name_position_mapping: BiMap = BiMap::new(); - for (column_name, column_position) in column_name_position_map { - column_name_position_mapping.insert(column_name, column_position); + for (column, column_name) in columns.into_iter().enumerate() { + column_name_position_mapping.insert(column_name, column); } + let primary_key: ColumnPosition = match column_name_position_mapping.get_by_left(&primary_column_name).copied() { + Some(primary_key) => primary_key, + None => unreachable!() // SAFETY: Existence of unique primary key is ensured in validation. + }; Self { table_name, primary_key, column_name_position_mapping, types } } @@ -32,8 +35,8 @@ impl TableSchema { &self.table_name } - pub fn column_type(&self, column_position: ColumnPosition) -> DbType { - self.types[column_position] + pub fn column_type(&self, column: ColumnPosition) -> DbType { + self.types[column] } pub fn get_columns(&self) -> Vec<&ColumnName> { @@ -64,23 +67,19 @@ impl TableSchema { self.types.get(position).copied() } - pub fn is_primary(&self, column_position: ColumnPosition) -> bool { - self.primary_key == column_position + pub fn is_primary(&self, column: ColumnPosition) -> bool { + self.primary_key == column } - pub fn column_name_from_column_position( - &self, - column_position: ColumnPosition, - ) -> DbResult { + // Assumes `column` comes from a validated source. + pub fn column_name_from_column(&self, column: ColumnPosition) -> ColumnName { match self .column_name_position_mapping - .get_by_right(&column_position) + .get_by_right(&column) { - Some(column_name) => Ok(column_name.clone()), - None => Err(Error::ColumnPositionDoesNotExist( - self.table_name.clone(), - column_position, - )), + Some(column_name) => column_name.clone(), + None => unreachable!() // SAFETY: The only way this function can get a column is from + // validation, which guarantees there is such a colun. } } @@ -96,8 +95,8 @@ impl TableSchema { let id: Uuid = match row.get(self.primary_key) { Some(Value::Indexable(IndexableValue::Uuid(id))) => *id, - Some(_) => unreachable!(), // SAFETY: Should be guaranteed by validation - None => unreachable!(), // SAFETY: Should be guaranteed by validation + Some(_) => unreachable!(), // SAFETY: Should be guaranteed by validation (type-safety) + None => unreachable!(), // SAFETY: Should be guaranteed by validation (missing columns) }; Ok((id, row)) diff --git a/minisql/src/type_system.rs b/minisql/src/type_system.rs index 3d4c837..f098ed5 100644 --- a/minisql/src/type_system.rs +++ b/minisql/src/type_system.rs @@ -31,6 +31,18 @@ pub enum IndexableValue { // TODO: what about null? } +impl DbType { + pub fn is_indexable(&self) -> bool { + match self { + Self::String => true, + Self::Int => true, + Self::Number => false, + Self::Uuid => true, + } + } + +} + impl Value { pub fn to_type(&self) -> DbType { match self { diff --git a/parser/src/core.rs b/parser/src/core.rs index 2cd2432..ec0c140 100644 --- a/parser/src/core.rs +++ b/parser/src/core.rs @@ -29,8 +29,8 @@ pub fn parse_statements<'a>(input: &'a str) -> IResult<&str, Vec many0(parse_statement)(input) } -pub fn parse_and_validate(query: String, db_schema: &DbSchema) -> Result { - let (_, op) = parse_statement(query.as_str()) +pub fn parse_and_validate(str_query: String, db_schema: &DbSchema) -> Result { + let (_, op) = parse_statement(str_query.as_str()) .map_err(|err| { Error::ParsingError(err.to_string()) })?; diff --git a/parser/src/parsing/create.rs b/parser/src/parsing/create.rs index c82df31..14a2234 100644 --- a/parser/src/parsing/create.rs +++ b/parser/src/parsing/create.rs @@ -1,4 +1,3 @@ -use minisql::{schema::{ColumnName, TableSchema}, type_system::DbType}; use nom::{ bytes::complete::tag, character::complete::{char, multispace0, multispace1}, @@ -8,7 +7,7 @@ use nom::{ }; use super::common::{parse_table_name, parse_identifier, parse_db_type}; -use crate::syntax::RawQuerySyntax; +use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = tag("CREATE")(input)?; @@ -20,33 +19,21 @@ pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = char('(')(input)?; let (input, _) = multispace0(input)?; let (input, column_definitions) = parse_column_definitions(input)?; - let mut column_name_position_mapping = Vec::new(); - let mut types: Vec = Vec::new(); - let mut primary_key = None; - for (position, (column_name, db_type, pk)) in column_definitions.iter().enumerate() { - types.push(db_type.clone()); - if *pk { - primary_key = Some(position); - } - column_name_position_mapping.push((column_name.clone(), position)); - } let (input, _) = char(')')(input)?; let (input, _) = multispace0(input)?; let (input, _) = char(';')(input)?; - let schema = TableSchema::new( - table_name.to_string(), - primary_key.unwrap_or_default(), - column_name_position_mapping, - types - ); + let schema = RawTableSchema { + table_name: table_name.to_string(), + columns: column_definitions, + }; Ok(( input, - RawQuerySyntax::CreateTable(table_name.to_string(), schema), + RawQuerySyntax::CreateTable(schema), )) } -pub fn parse_column_definitions(input: &str) -> IResult<&str, Vec<(ColumnName, DbType, bool)>> { +fn parse_column_definitions(input: &str) -> IResult<&str, Vec> { separated_list0(terminated(char(','), multispace0), parse_column_definition)(input) } @@ -58,13 +45,13 @@ fn parse_primary_key(input: &str) -> IResult<&str, &str> { Ok((input, "PRIMARY KEY")) } -pub fn parse_column_definition(input: &str) -> IResult<&str, (ColumnName, DbType, bool)> { +fn parse_column_definition(input: &str) -> IResult<&str, ColumnSchema> { let (input, identifier) = parse_identifier(input)?; let (input, _) = multispace1(input)?; let (input, db_type) = parse_db_type(input)?; let (input, pk) = opt(parse_primary_key)(input).map(|(input, pk)| (input, pk.is_some()))?; let (input, _) = multispace0(input)?; - Ok((input, (identifier.to_string(), db_type, pk))) + Ok((input, ColumnSchema { column_name: identifier.to_string(), type_: db_type, is_primary: pk })) } #[cfg(test)] @@ -95,16 +82,23 @@ mod tests { #[test] fn test_parse_create() { let (_, create) = parse_create("CREATE TABLE \"Table1\"( id UUID , column1 INT );").expect("should parse"); - assert!(matches!(create, RawQuerySyntax::CreateTable(_ ,_))); + assert!(matches!(create, RawQuerySyntax::CreateTable(_))); match create { - RawQuerySyntax::CreateTable(name, schema) => { - assert_eq!(name, "Table1"); + RawQuerySyntax::CreateTable(schema) => { + assert_eq!(schema.table_name, "Table1"); assert_eq!(schema.number_of_columns(), 2); - assert_eq!(schema.get_column_position(&"id".to_string()).unwrap(), 0); - assert_eq!(schema.get_column_position(&"column1".to_string()).unwrap(), 1); + + let result_id = schema.get_column(&"id".to_string()); + assert!(matches!(result_id, Some(_))); + let Some(id_column) = result_id else { panic!() }; + assert_eq!(id_column.column_name, "id".to_string()); + + let result_column1 = schema.get_column(&"column1".to_string()); + assert!(matches!(result_column1, Some(_))); + let Some(column1_column) = result_column1 else { panic!() }; + assert_eq!(column1_column.column_name, "column1".to_string()); } _ => {} } - } } diff --git a/parser/src/syntax.rs b/parser/src/syntax.rs index 039a06d..e1258f0 100644 --- a/parser/src/syntax.rs +++ b/parser/src/syntax.rs @@ -1,17 +1,29 @@ -use minisql::{type_system::Value, schema::{TableSchema, ColumnName, TableName}}; +use minisql::{type_system::{Value, DbType}, schema::{ColumnName, TableName}}; -// TODO: Move this out into separate file and rename to something like Syntax, SyntaxTree, -// OperationSyntax, RawOperationSyntax +// ===Table Schema=== +#[derive(Debug, Clone, PartialEq)] +pub struct RawTableSchema { + pub table_name: TableName, + pub columns: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ColumnSchema { + pub column_name: ColumnName, + pub type_: DbType, + pub is_primary: bool, +} + +// ===Query=== pub enum RawQuerySyntax { Select(TableName, ColumnSelection, Option), Insert(TableName, InsertionValues), Delete(TableName, Option), // Update(...), - CreateTable(TableName, TableSchema), + CreateTable(RawTableSchema), CreateIndex(TableName, ColumnName), // DropTable(TableName), } - pub type InsertionValues = Vec<(ColumnName, Value)>; pub enum ColumnSelection { @@ -34,3 +46,17 @@ pub enum Condition { // Prefix(ColumnName, String), // Substring(ColumnName, String), // } + +impl RawTableSchema { + pub fn number_of_columns(&self) -> usize { + self.columns.len() + } + + pub fn get_column(&self, column_name: &ColumnName) -> Option { + self.columns.iter().find(|column_schema| column_name == &column_schema.column_name).cloned() + } + + pub fn get_columns(&self) -> Vec<&ColumnName> { + self.columns.iter().map(|ColumnSchema { column_name, .. }| column_name).collect() + } +} diff --git a/parser/src/validation.rs b/parser/src/validation.rs index 0f60f00..b818448 100644 --- a/parser/src/validation.rs +++ b/parser/src/validation.rs @@ -1,9 +1,8 @@ -use std::collections::HashSet; -use std::collections::HashMap; +use std::collections::{HashSet, BTreeMap}; use thiserror::Error; use crate::syntax; -use crate::syntax::RawQuerySyntax; +use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; use minisql::operation; use minisql::{operation::Operation, type_system::Value, schema::{TableSchema, ColumnName, TableName}, type_system::DbType, interpreter::{TablePosition, DbSchema}}; @@ -17,6 +16,12 @@ pub enum ValidationError { ColumnsDoNotExist(Vec), #[error("duplicate column {0}")] DuplicateColumn(ColumnName), + #[error("primary key missing in table {0}")] + PrimaryKeyMissing(TableName), + #[error("multiple primary keys found in table {0}")] + MultiplePrimaryKeysFound(TableName), + #[error("attempt to index non-indexable column {1} in table {0}")] + AttemptToIndexNonIndexableColumn(TableName, ColumnName), #[error("type mismatch at column `{column_name:?}` (expected {expected_type:?}, found {received_type:?})")] TypeMismatch { column_name: ColumnName, @@ -28,8 +33,8 @@ pub enum ValidationError { } /// Validates and converts the raw syntax into a proper interpreter operation based on db schema. -pub fn validate_operation(query: RawQuerySyntax, db_schema: &DbSchema) -> Result { - match query { +pub fn validate_operation(syntax: RawQuerySyntax, db_schema: &DbSchema) -> Result { + match syntax { RawQuerySyntax::Select(table_name, column_selection, condition) => { validate_select(table_name, column_selection, condition, db_schema) }, @@ -39,8 +44,8 @@ pub fn validate_operation(query: RawQuerySyntax, db_schema: &DbSchema) -> Result RawQuerySyntax::Delete(table_name, condition) => { validate_delete(table_name, condition, db_schema) }, - RawQuerySyntax::CreateTable(table_name, schema) => { - validate_create(table_name, schema, db_schema) + RawQuerySyntax::CreateTable(schema) => { + validate_create_table(schema, db_schema) }, RawQuerySyntax::CreateIndex(table_name, column_name) => { validate_create_index(table_name, column_name, db_schema) @@ -54,31 +59,64 @@ fn validate_table_exists<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName .map(|(_, table_position, table_schema)| (*table_position, *table_schema)) } -pub fn validate_create(table_name: TableName, table_schema: TableSchema, db_schema: &DbSchema) -> Result { +fn validate_create_table(raw_table_schema: RawTableSchema, db_schema: &DbSchema) -> Result { + let table_name: &TableName = &raw_table_schema.table_name; if let Some(_) = get_table_schema(db_schema, &table_name) { return Err(ValidationError::TableAlreadyExists(table_name.to_string())); } - find_first_duplicate(&table_schema.get_columns()) + let table_schema: TableSchema = validate_table_schema(raw_table_schema)?; + Ok(Operation::CreateTable(table_schema)) +} + +fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result { + // check for duplicate columns + find_first_duplicate(&raw_table_schema.get_columns()) .map_or_else( || Ok(()), |duplicate_column| Err(ValidationError::DuplicateColumn(duplicate_column.to_string())) )?; - // TODO: Ensure it has a primary key?? - Ok(Operation::CreateTable(table_name, table_schema)) + let mut primary_keys: Vec<(ColumnName, DbType)> = vec![]; + let mut columns: Vec = vec![]; + let mut types: Vec = vec![]; + for ColumnSchema { column_name, type_, is_primary } in raw_table_schema.columns { + if is_primary { + primary_keys.push((column_name.clone(), type_)) + } + columns.push(column_name); + types.push(type_); + } + + // Ensure it has exactly one primary key that has correct type. + if primary_keys.len() == 0 { + return Err(ValidationError::PrimaryKeyMissing(raw_table_schema.table_name.clone())) + } else if primary_keys.len() > 1 { + return Err(ValidationError::MultiplePrimaryKeysFound(raw_table_schema.table_name.clone())) + } else { + let (primary_column_name, primary_key_type) = primary_keys[0].clone(); + if primary_key_type == DbType::Uuid { + Ok(TableSchema::new(raw_table_schema.table_name, primary_column_name, columns, types)) + } else { + Err(ValidationError::TypeMismatch { + column_name: raw_table_schema.table_name.clone(), + received_type: primary_key_type, + expected_type: DbType::Uuid, + }) + } + } } -pub fn validate_select(table_name: TableName, column_selection: syntax::ColumnSelection, condition: Option, db_schema: &DbSchema) -> Result { +fn validate_select(table_name: TableName, column_selection: syntax::ColumnSelection, condition: Option, db_schema: &DbSchema) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; match column_selection { syntax::ColumnSelection::Columns(columns) => { let non_existant_columns: Vec = columns.iter().filter_map(|column| if schema.does_column_exist(&column) { - Some(column.clone()) - } else { None + } else { + Some(column.clone()) }).collect(); if non_existant_columns.len() > 0 { Err(ValidationError::ColumnsDoNotExist(non_existant_columns)) @@ -96,7 +134,7 @@ pub fn validate_select(table_name: TableName, column_selection: syntax::ColumnSe } } -pub fn validate_insert(table_name: TableName, insertion_values: syntax::InsertionValues, db_schema: &DbSchema) -> Result { +fn validate_insert(table_name: TableName, insertion_values: syntax::InsertionValues, db_schema: &DbSchema) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; // Check for duplicate columns in insertion_values. @@ -120,7 +158,10 @@ pub fn validate_insert(table_name: TableName, insertion_values: syntax::Insertio } // Check types and prepare for creation of InsertionValues for the interpreter - let mut values_map: HashMap<_, Value> = HashMap::new(); + let mut values_map: BTreeMap<_, Value> = BTreeMap::new(); // The reason for using BTreeMap + // instead of HashMap is that we need + // to get the values in a vector + // sorted by the key. for (column_name, value) in insertion_values { let (column, expected_type) = schema.get_column(&column_name).ok_or(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]))?; // By the previous validation steps this is never gonna trigger an error. let value_type = value.to_type(); @@ -130,13 +171,14 @@ pub fn validate_insert(table_name: TableName, insertion_values: syntax::Insertio values_map.insert(column, value); } - // These are values ordered by the column position - let values: operation::InsertionValues = values_map.into_values().collect(); + // WARNING: If you use `values_map: HashMap<_,_>`, this is not gonna sort values by key. + let values: operation::InsertionValues = values_map.into_values().collect(); + // Note that one of the values is id. Ok(Operation::Insert(table_position, values)) } -pub fn validate_delete(table_name: TableName, condition: Option, db_schema: &DbSchema) -> Result { +fn validate_delete(table_name: TableName, condition: Option, db_schema: &DbSchema) -> Result { let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; let validated_condition = validate_condition(condition, schema)?; Ok(Operation::Delete(table_position, validated_condition)) @@ -162,13 +204,18 @@ fn validate_condition(condition: Option, schema: &TableSchema } fn validate_create_index(table_name: TableName, column_name: ColumnName, db_schema: &DbSchema) -> Result { - // TODO: You should disallow indexing of Number columns. let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; schema - .get_column_position(&column_name) + .get_column(&column_name) .map_or_else( || Err(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()])), - |column| Ok(Operation::CreateIndex(table_position, column)) + |(column, type_)| { + if type_.is_indexable() { + Ok(Operation::CreateIndex(table_position, column)) + } else { + Err(ValidationError::AttemptToIndexNonIndexableColumn(column_name.clone(), table_name)) + } + } ) } @@ -191,3 +238,359 @@ fn get_table_schema<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName) -> let (_, _, table_schema) = db_schema.iter().find(|(tname, _, _)| table_name.eq(tname))?; Some(table_schema) } + +#[cfg(test)] +mod tests { + use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax, ColumnSelection, Condition}; + use minisql::type_system::{Value, IndexableValue}; + use minisql::operation::Operation; + use minisql::operation; + use minisql::schema::TableSchema; + use super::*; + + use RawQuerySyntax::*; + use Value::*; + use IndexableValue::*; + use Condition::*; + + fn users_schema() -> TableSchema { + TableSchema::new( + "users".to_string(), + "id".to_string(), + vec!( + "id".to_string(), + "name".to_string(), + "age".to_string(), + ), + vec![DbType::Uuid, DbType::String, DbType::Int], + ) + } + + fn raw_users_schema() -> RawTableSchema { + RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, + ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ], + } + } + + + fn db_schema(users_schema: &TableSchema) -> DbSchema { + vec![ + ("users".to_string(), 0, users_schema), + ] + } + + fn empty_db_schema() -> DbSchema<'static> { + vec![] + } + + #[test] + fn test_create_basic() { + let db_schema: DbSchema = empty_db_schema(); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema()); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::CreateTable(_)))); + + let Ok(Operation::CreateTable(schema)) = result else { panic!() }; + assert!(schema.table_name() == "users"); + } + + #[test] + fn test_create_duplicates_in_schema() { + let raw_users_schema = RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::Number, is_primary: false }, + ], + }; + + let db_schema: DbSchema = empty_db_schema(); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema); + let result = validate_operation(syntax, &db_schema); + println!("{:?}", result); + assert!(matches!(result, Err(ValidationError::DuplicateColumn(_)))); + } + + #[test] + fn test_create_primary_key_is_uuid() { + let raw_users_schema = RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, + ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ], + }; + + let db_schema: DbSchema = empty_db_schema(); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); + } + + #[test] + fn test_create_multiple_primary_keys() { + let raw_users_schema = RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: true }, + ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ], + }; + + let db_schema: DbSchema = empty_db_schema(); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::MultiplePrimaryKeysFound(_)))); + } + + #[test] + fn test_create_already_exists() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema()); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::TableAlreadyExists(_)))); + } + + // ====Select==== + #[test] + fn test_select_basic() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + let users_position = 0; + let id = 0; + let name = 1; + let age = 2; + + let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, None); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::Select(_, _, _)))); + + let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; + + assert!(table_position == users_position); + assert!(condition == None); + assert!(column_selection == vec![id, name, age]); + } + + #[test] + fn test_select_non_existent_table() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = Select("does_not_exist".to_string(), ColumnSelection::All, None); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::TableDoesNotExist(_)))); + } + + #[test] + fn test_select_eq() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let users_position = 0; + let id = 0; + let name = 1; + let age = 2; + + let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("age".to_string(), Indexable(Int(25))))); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::Select(_, _, _)))); + + let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; + + assert!(table_position == users_position); + assert!(column_selection == vec![id, name, age]); + + assert!(condition == Some(operation::Condition::Eq(age, Indexable(Int(25))))); + } + + #[test] + fn test_select_eq_columns_selection() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let users_position = 0; + let name = 1; + let age = 2; + + let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::Columns(vec!["age".to_string(), "name".to_string(), "age".to_string()]), None); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::Select(_, _, _)))); + + let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; + + assert!(table_position == users_position); + assert!(column_selection == vec![age, name, age]); + assert!(condition == None); + } + + #[test] + fn test_select_eq_columns_selection_nonexistent_column_selected() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::Columns(vec!["age".to_string(), "does_not_exist".to_string()]), None); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); + } + + #[test] + fn test_select_eq_non_existent_column() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("does_not_exist".to_string(), Indexable(Int(25))))); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); + } + + #[test] + fn test_select_eq_type_error() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("age".to_string(), Indexable(String("25".to_string()))))); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); + } + + // ====Insert==== + #[test] + fn test_insert() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let users_position = 0; + + let syntax: RawQuerySyntax = Insert( + "users".to_string(), + vec![ + ("name".to_string(), Indexable(String("Alice".to_string()))), + ("id".to_string(), Indexable(Uuid(0))), + ("age".to_string(), Indexable(Int(25))), + ]); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::Insert(_, _)))); + + let Ok(Operation::Insert(table_position, values)) = result else { panic!() }; + + assert!(table_position == users_position); + // Recall the order is + // let id = 0; + // let name = 1; + // let age = 2; + assert!(values == vec![Indexable(Uuid(0)), Indexable(String("Alice".to_string())), Indexable(Int(25))]); + } + + #[test] + fn test_insert_non_existent_column() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = Insert( + "users".to_string(), + vec![ + ("name".to_string(), Indexable(String("Alice".to_string()))), + ("id".to_string(), Indexable(Uuid(0))), + ("age".to_string(), Indexable(Int(25))), + ("does_not_exist".to_string(), Indexable(Int(25))), + ]); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); + } + + #[test] + fn test_insert_ill_typed_column() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = Insert( + "users".to_string(), + vec![ + ("name".to_string(), Indexable(String("Alice".to_string()))), + ("id".to_string(), Indexable(Uuid(0))), + ("age".to_string(), Number(25.0)), + ]); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); + } + + // ====Delete==== + #[test] + fn test_delete_all() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let users_position = 0; + + let syntax: RawQuerySyntax = Delete("users".to_string(), None); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::Delete(_, None)))); + + let Ok(Operation::Delete(table_position, _)) = result else { panic!() }; + + assert!(table_position == users_position); + } + + #[test] + fn test_delete_eq() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let users_position = 0; + let age = 2; + + let syntax: RawQuerySyntax = Delete("users".to_string(), Some(Eq("age".to_string(), Indexable(Int(25))))); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::Delete(_, Some(operation::Condition::Eq(_, _)))))); + + let Ok(Operation::Delete(table_position, Some(operation::Condition::Eq(column, value)))) = result else { panic!() }; + + assert!(table_position == users_position); + assert!(column == age); + assert!(value == Indexable(Int(25))); + } + + // ====CreateIndex==== + #[test] + fn test_create_index() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let users_position = 0; + let age = 2; + + let syntax: RawQuerySyntax = CreateIndex("users".to_string(), "age".to_string()); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Ok(Operation::CreateIndex(_, _)))); + + let Ok(Operation::CreateIndex(table_position, column)) = result else { panic!() }; + + assert!(table_position == users_position); + assert!(column == age); + } + + #[test] + fn test_create_index_nonexistent_column() { + let users_schema: TableSchema = users_schema(); + let db_schema: DbSchema = db_schema(&users_schema); + + let syntax: RawQuerySyntax = CreateIndex("users".to_string(), "does_not_exist".to_string()); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); + } +} diff --git a/server/src/proto_wrapper.rs b/server/src/proto_wrapper.rs index 3415255..bd74ba6 100644 --- a/server/src/proto_wrapper.rs +++ b/server/src/proto_wrapper.rs @@ -58,7 +58,7 @@ impl ServerProto for W where W: BackendProtoWriter + Send { async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()> { let columns = row.iter() - .map(|(index, value)| value_to_column_description(table_schema, value, index)) + .map(|(index, value)| value_to_column_description(table_schema, value, *index)) .collect::>>()?; self.write_proto(RowDescriptionData { columns: columns.into() }.into()).await?; @@ -84,11 +84,11 @@ impl ServerProto for W where W: BackendProtoWriter + Send { } } -fn value_to_column_description(schema: &TableSchema, value: &Value, index: &usize) -> anyhow::Result { - let name = schema.column_name_from_column_position(*index)?; +fn value_to_column_description(schema: &TableSchema, value: &Value, index: usize) -> anyhow::Result { + let name = schema.column_name_from_column(index); let table_oid = schema.table_name().as_bytes().as_ptr() as i32; - let column_index = (*index).try_into()?; + let column_index = index.try_into()?; let type_oid = value.type_oid(); let type_size = value.type_size();