diff --git a/minisql/src/interpreter.rs b/minisql/src/interpreter.rs index 2caf1f0..c25c57c 100644 --- a/minisql/src/interpreter.rs +++ b/minisql/src/interpreter.rs @@ -73,10 +73,10 @@ impl State { &mut self.tables[table_position] } - fn attach_table(&mut self, table_name: TableName, table: Table) { + fn attach_table(&mut self, table: Table) { let new_table_position: TablePosition = self.tables.len(); self.table_name_position_mapping - .insert(table_name, new_table_position); + .insert(table.schema().table_name().clone(), new_table_position); self.tables.push(table); } @@ -126,9 +126,9 @@ impl State { Ok(Response::Deleted(rows_affected)) } - CreateTable(table_name, table_schema) => { + CreateTable(table_schema) => { let table = Table::new(table_schema); - self.attach_table(table_name, table); + self.attach_table(table); Ok(Response::TableCreated) } @@ -150,17 +150,13 @@ mod tests { use crate::operation::Operation; fn users_schema() -> TableSchema { - let id: ColumnPosition = 0; - let name: ColumnPosition = 1; - let age: ColumnPosition = 2; - TableSchema::new( "users".to_string(), - id, + "id".to_string(), vec!( - ("id".to_string(), id), - ("name".to_string(), name), - ("age".to_string(), age), + "id".to_string(), + "name".to_string(), + "age".to_string(), ), vec![DbType::Uuid, DbType::String, DbType::Int], ) @@ -173,7 +169,7 @@ mod tests { let users = users_schema.table_name().clone(); state - .interpret(Operation::CreateTable(users.clone(), users_schema)) + .interpret(Operation::CreateTable(users_schema)) .unwrap(); assert!(state.tables.len() == 1); @@ -187,11 +183,10 @@ mod tests { fn test_select_empty() { let mut state = State::new(); let users_schema = users_schema(); - let users = users_schema.table_name().clone(); let users_position = 0; state - .interpret(Operation::CreateTable(users, users_schema.clone())) + .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let response: Response = state .interpret(Operation::Select(users_position, users_schema.all_selection(), None)) @@ -215,7 +210,7 @@ mod tests { state - .interpret(Operation::CreateTable("users".to_string(), users_schema.clone())) + .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let (id, name, age) = ( @@ -267,7 +262,7 @@ mod tests { let name_column: ColumnPosition = 1; state - .interpret(CreateTable(users_schema.table_name().clone(), users_schema.clone())) + .interpret(CreateTable(users_schema.clone())) .unwrap(); let (id0, name0, age0) = ( @@ -384,7 +379,7 @@ mod tests { let id_column: ColumnPosition = 0; state - .interpret(CreateTable(users_schema.table_name().clone(), users_schema.clone())) + .interpret(CreateTable(users_schema.clone())) .unwrap(); let (id0, name0, age0) = ( @@ -458,7 +453,7 @@ mod tests { let name_column: ColumnPosition = 1; state - .interpret(CreateTable(users_schema.table_name().clone(), users_schema.clone())) + .interpret(CreateTable(users_schema.clone())) .unwrap(); state @@ -525,26 +520,25 @@ pub fn example() { let id_column: ColumnPosition = 0; let name_column: ColumnPosition = 1; - let age_column: ColumnPosition = 2; + // let age_column: ColumnPosition = 2; let users_schema: TableSchema = { TableSchema::new( "users".to_string(), - id_column, + "id".to_string(), vec!( - ("id".to_string(), id_column), - ("name".to_string(), name_column), - ("age".to_string(), age_column), + "id".to_string(), // 0 + "name".to_string(), // 1 + "age".to_string(), // 2 ), vec![DbType::Uuid, DbType::String, DbType::Int], ) }; let users_position: TablePosition = 0; - let users = users_schema.table_name().clone(); let mut state = State::new(); state - .interpret(Operation::CreateTable(users, users_schema.clone())) + .interpret(Operation::CreateTable(users_schema.clone())) .unwrap(); let (id0, name0, age0) = ( diff --git a/minisql/src/operation.rs b/minisql/src/operation.rs index 5aff265..fa56b70 100644 --- a/minisql/src/operation.rs +++ b/minisql/src/operation.rs @@ -1,4 +1,4 @@ -use crate::schema::{TableName, TableSchema}; +use crate::schema::TableSchema; use crate::type_system::Value; use crate::internals::row::ColumnPosition; use crate::interpreter::TablePosition; @@ -9,7 +9,7 @@ pub enum Operation { Select(TablePosition, ColumnSelection, Option), Insert(TablePosition, InsertionValues), Delete(TablePosition, Option), - CreateTable(TableName, TableSchema), + CreateTable(TableSchema), CreateIndex(TablePosition, ColumnPosition), } diff --git a/minisql/src/schema.rs b/minisql/src/schema.rs index f2bd707..66709ce 100644 --- a/minisql/src/schema.rs +++ b/minisql/src/schema.rs @@ -18,11 +18,15 @@ pub type TableName = String; pub type ColumnName = String; impl TableSchema { - pub fn new(table_name: TableName, primary_key: ColumnPosition, column_name_position_map: Vec<(ColumnName, ColumnPosition)>, types: Vec) -> Self { + pub fn new(table_name: TableName, primary_column_name: ColumnName, columns: Vec, types: Vec) -> Self { let mut column_name_position_mapping: BiMap = BiMap::new(); - for (column_name, column) in column_name_position_map { + for (column, column_name) in columns.into_iter().enumerate() { column_name_position_mapping.insert(column_name, column); } + let primary_key: ColumnPosition = match column_name_position_mapping.get_by_left(&primary_column_name).copied() { + Some(primary_key) => primary_key, + None => unreachable!() // SAFETY: Existence of unique primary key is ensured in validation. + }; Self { table_name, primary_key, column_name_position_mapping, types } } diff --git a/parser/src/parsing/create.rs b/parser/src/parsing/create.rs index c82df31..14a2234 100644 --- a/parser/src/parsing/create.rs +++ b/parser/src/parsing/create.rs @@ -1,4 +1,3 @@ -use minisql::{schema::{ColumnName, TableSchema}, type_system::DbType}; use nom::{ bytes::complete::tag, character::complete::{char, multispace0, multispace1}, @@ -8,7 +7,7 @@ use nom::{ }; use super::common::{parse_table_name, parse_identifier, parse_db_type}; -use crate::syntax::RawQuerySyntax; +use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = tag("CREATE")(input)?; @@ -20,33 +19,21 @@ pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { let (input, _) = char('(')(input)?; let (input, _) = multispace0(input)?; let (input, column_definitions) = parse_column_definitions(input)?; - let mut column_name_position_mapping = Vec::new(); - let mut types: Vec = Vec::new(); - let mut primary_key = None; - for (position, (column_name, db_type, pk)) in column_definitions.iter().enumerate() { - types.push(db_type.clone()); - if *pk { - primary_key = Some(position); - } - column_name_position_mapping.push((column_name.clone(), position)); - } let (input, _) = char(')')(input)?; let (input, _) = multispace0(input)?; let (input, _) = char(';')(input)?; - let schema = TableSchema::new( - table_name.to_string(), - primary_key.unwrap_or_default(), - column_name_position_mapping, - types - ); + let schema = RawTableSchema { + table_name: table_name.to_string(), + columns: column_definitions, + }; Ok(( input, - RawQuerySyntax::CreateTable(table_name.to_string(), schema), + RawQuerySyntax::CreateTable(schema), )) } -pub fn parse_column_definitions(input: &str) -> IResult<&str, Vec<(ColumnName, DbType, bool)>> { +fn parse_column_definitions(input: &str) -> IResult<&str, Vec> { separated_list0(terminated(char(','), multispace0), parse_column_definition)(input) } @@ -58,13 +45,13 @@ fn parse_primary_key(input: &str) -> IResult<&str, &str> { Ok((input, "PRIMARY KEY")) } -pub fn parse_column_definition(input: &str) -> IResult<&str, (ColumnName, DbType, bool)> { +fn parse_column_definition(input: &str) -> IResult<&str, ColumnSchema> { let (input, identifier) = parse_identifier(input)?; let (input, _) = multispace1(input)?; let (input, db_type) = parse_db_type(input)?; let (input, pk) = opt(parse_primary_key)(input).map(|(input, pk)| (input, pk.is_some()))?; let (input, _) = multispace0(input)?; - Ok((input, (identifier.to_string(), db_type, pk))) + Ok((input, ColumnSchema { column_name: identifier.to_string(), type_: db_type, is_primary: pk })) } #[cfg(test)] @@ -95,16 +82,23 @@ mod tests { #[test] fn test_parse_create() { let (_, create) = parse_create("CREATE TABLE \"Table1\"( id UUID , column1 INT );").expect("should parse"); - assert!(matches!(create, RawQuerySyntax::CreateTable(_ ,_))); + assert!(matches!(create, RawQuerySyntax::CreateTable(_))); match create { - RawQuerySyntax::CreateTable(name, schema) => { - assert_eq!(name, "Table1"); + RawQuerySyntax::CreateTable(schema) => { + assert_eq!(schema.table_name, "Table1"); assert_eq!(schema.number_of_columns(), 2); - assert_eq!(schema.get_column_position(&"id".to_string()).unwrap(), 0); - assert_eq!(schema.get_column_position(&"column1".to_string()).unwrap(), 1); + + let result_id = schema.get_column(&"id".to_string()); + assert!(matches!(result_id, Some(_))); + let Some(id_column) = result_id else { panic!() }; + assert_eq!(id_column.column_name, "id".to_string()); + + let result_column1 = schema.get_column(&"column1".to_string()); + assert!(matches!(result_column1, Some(_))); + let Some(column1_column) = result_column1 else { panic!() }; + assert_eq!(column1_column.column_name, "column1".to_string()); } _ => {} } - } } diff --git a/parser/src/syntax.rs b/parser/src/syntax.rs index 039a06d..e1258f0 100644 --- a/parser/src/syntax.rs +++ b/parser/src/syntax.rs @@ -1,17 +1,29 @@ -use minisql::{type_system::Value, schema::{TableSchema, ColumnName, TableName}}; +use minisql::{type_system::{Value, DbType}, schema::{ColumnName, TableName}}; -// TODO: Move this out into separate file and rename to something like Syntax, SyntaxTree, -// OperationSyntax, RawOperationSyntax +// ===Table Schema=== +#[derive(Debug, Clone, PartialEq)] +pub struct RawTableSchema { + pub table_name: TableName, + pub columns: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ColumnSchema { + pub column_name: ColumnName, + pub type_: DbType, + pub is_primary: bool, +} + +// ===Query=== pub enum RawQuerySyntax { Select(TableName, ColumnSelection, Option), Insert(TableName, InsertionValues), Delete(TableName, Option), // Update(...), - CreateTable(TableName, TableSchema), + CreateTable(RawTableSchema), CreateIndex(TableName, ColumnName), // DropTable(TableName), } - pub type InsertionValues = Vec<(ColumnName, Value)>; pub enum ColumnSelection { @@ -34,3 +46,17 @@ pub enum Condition { // Prefix(ColumnName, String), // Substring(ColumnName, String), // } + +impl RawTableSchema { + pub fn number_of_columns(&self) -> usize { + self.columns.len() + } + + pub fn get_column(&self, column_name: &ColumnName) -> Option { + self.columns.iter().find(|column_schema| column_name == &column_schema.column_name).cloned() + } + + pub fn get_columns(&self) -> Vec<&ColumnName> { + self.columns.iter().map(|ColumnSchema { column_name, .. }| column_name).collect() + } +} diff --git a/parser/src/validation.rs b/parser/src/validation.rs index b4dc383..b36ce55 100644 --- a/parser/src/validation.rs +++ b/parser/src/validation.rs @@ -2,7 +2,7 @@ use std::collections::{HashSet, BTreeMap}; use thiserror::Error; use crate::syntax; -use crate::syntax::RawQuerySyntax; +use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; use minisql::operation; use minisql::{operation::Operation, type_system::Value, schema::{TableSchema, ColumnName, TableName}, type_system::DbType, interpreter::{TablePosition, DbSchema}}; @@ -16,6 +16,10 @@ pub enum ValidationError { ColumnsDoNotExist(Vec), #[error("duplicate column {0}")] DuplicateColumn(ColumnName), + #[error("primary key missing in table {0}")] + PrimaryKeyMissing(TableName), + #[error("multiple primary keys found in table {0}")] + MultiplePrimaryKeysFound(TableName), #[error("type mismatch at column `{column_name:?}` (expected {expected_type:?}, found {received_type:?})")] TypeMismatch { column_name: ColumnName, @@ -38,8 +42,8 @@ pub fn validate_operation(syntax: RawQuerySyntax, db_schema: &DbSchema) -> Resul RawQuerySyntax::Delete(table_name, condition) => { validate_delete(table_name, condition, db_schema) }, - RawQuerySyntax::CreateTable(table_name, schema) => { - validate_create_table(table_name, schema, db_schema) + RawQuerySyntax::CreateTable(schema) => { + validate_create_table(schema, db_schema) }, RawQuerySyntax::CreateIndex(table_name, column_name) => { validate_create_index(table_name, column_name, db_schema) @@ -53,19 +57,52 @@ fn validate_table_exists<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName .map(|(_, table_position, table_schema)| (*table_position, *table_schema)) } -fn validate_create_table(table_name: TableName, table_schema: TableSchema, db_schema: &DbSchema) -> Result { +fn validate_create_table(raw_table_schema: RawTableSchema, db_schema: &DbSchema) -> Result { + let table_name: &TableName = &raw_table_schema.table_name; if let Some(_) = get_table_schema(db_schema, &table_name) { return Err(ValidationError::TableAlreadyExists(table_name.to_string())); } - find_first_duplicate(&table_schema.get_columns()) + let table_schema: TableSchema = validate_table_schema(raw_table_schema)?; + Ok(Operation::CreateTable(table_schema)) +} + +fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result { + // check for duplicate columns + find_first_duplicate(&raw_table_schema.get_columns()) .map_or_else( || Ok(()), |duplicate_column| Err(ValidationError::DuplicateColumn(duplicate_column.to_string())) )?; - // TODO: Ensure it has a primary key?? - Ok(Operation::CreateTable(table_name, table_schema)) + let mut primary_keys: Vec<(ColumnName, DbType)> = vec![]; + let mut columns: Vec = vec![]; + let mut types: Vec = vec![]; + for ColumnSchema { column_name, type_, is_primary } in raw_table_schema.columns { + if is_primary { + primary_keys.push((column_name.clone(), type_)) + } + columns.push(column_name); + types.push(type_); + } + + // Ensure it has exactly one primary key that has correct type. + if primary_keys.len() == 0 { + return Err(ValidationError::PrimaryKeyMissing(raw_table_schema.table_name.clone())) + } else if primary_keys.len() > 1 { + return Err(ValidationError::MultiplePrimaryKeysFound(raw_table_schema.table_name.clone())) + } else { + let (primary_column_name, primary_key_type) = primary_keys[0].clone(); + if primary_key_type == DbType::Uuid { + Ok(TableSchema::new(raw_table_schema.table_name, primary_column_name, columns, types)) + } else { + Err(ValidationError::TypeMismatch { + column_name: raw_table_schema.table_name.clone(), + received_type: primary_key_type, + expected_type: DbType::Uuid, + }) + } + } } fn validate_select(table_name: TableName, column_selection: syntax::ColumnSelection, condition: Option, db_schema: &DbSchema) -> Result { @@ -197,7 +234,7 @@ fn get_table_schema<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName) -> #[cfg(test)] mod tests { - use crate::syntax::{RawQuerySyntax, ColumnSelection, Condition}; + use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax, ColumnSelection, Condition}; use minisql::type_system::{Value, IndexableValue}; use minisql::operation::Operation; use minisql::operation; @@ -210,22 +247,30 @@ mod tests { use Condition::*; fn users_schema() -> TableSchema { - let id = 0; - let name = 1; - let age = 2; - TableSchema::new( "users".to_string(), - id, + "id".to_string(), vec!( - ("id".to_string(), id), - ("name".to_string(), name), - ("age".to_string(), age), + "id".to_string(), + "name".to_string(), + "age".to_string(), ), vec![DbType::Uuid, DbType::String, DbType::Int], ) } + fn raw_users_schema() -> RawTableSchema { + RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, + ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ], + } + } + + fn db_schema(users_schema: &TableSchema) -> DbSchema { vec![ ("users".to_string(), 0, users_schema), @@ -236,52 +281,79 @@ mod tests { vec![] } - // ====CreateTable==== #[test] fn test_create_basic() { - let users_schema: TableSchema = users_schema(); let db_schema: DbSchema = empty_db_schema(); - let syntax: RawQuerySyntax = CreateTable("users".to_string(), users_schema.clone()); + let syntax: RawQuerySyntax = CreateTable(raw_users_schema()); let result = validate_operation(syntax, &db_schema); - assert!(matches!(result, Ok(Operation::CreateTable(_, _)))); + assert!(matches!(result, Ok(Operation::CreateTable(_)))); - let Ok(Operation::CreateTable(table_name, _)) = result else { panic!() }; - assert!(table_name == "users".to_string()); + let Ok(Operation::CreateTable(schema)) = result else { panic!() }; + assert!(schema.table_name() == "users"); } - // #[test] - // fn test_create_duplicates_in_schema() { - // let id = 0; - // let name = 1; + #[test] + fn test_create_duplicates_in_schema() { + let raw_users_schema = RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::Number, is_primary: false }, + ], + }; - // let users_schema = TableSchema::new( - // "users".to_string(), - // id, - // vec!( - // ("id".to_string(), id), - // ("name".to_string(), name), - // ("name".to_string(), name + 1), - // ), - // vec![DbType::Uuid, DbType::String, DbType::Int], - // ); + let db_schema: DbSchema = empty_db_schema(); - // let db_schema: DbSchema = empty_db_schema(); + let syntax: RawQuerySyntax = CreateTable(raw_users_schema); + let result = validate_operation(syntax, &db_schema); + println!("{:?}", result); + assert!(matches!(result, Err(ValidationError::DuplicateColumn(_)))); + } - // let syntax: RawQuerySyntax = CreateTable("users".to_string(), users_schema.clone()); - // let result = validate_operation(syntax, &db_schema); - // println!("{:?}", result); - // assert!(matches!(result, Err(ValidationError::DuplicateColumn(_)))); + #[test] + fn test_create_primary_key_is_uuid() { + let raw_users_schema = RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, + ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ], + }; - // // TODO - // } + let db_schema: DbSchema = empty_db_schema(); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); + } + + #[test] + fn test_create_multiple_primary_keys() { + let raw_users_schema = RawTableSchema { + table_name: "users".to_string(), + columns: vec![ + ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, + ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: true }, + ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, + ], + }; + + let db_schema: DbSchema = empty_db_schema(); + + let syntax: RawQuerySyntax = CreateTable(raw_users_schema); + let result = validate_operation(syntax, &db_schema); + assert!(matches!(result, Err(ValidationError::MultiplePrimaryKeysFound(_)))); + } #[test] fn test_create_already_exists() { let users_schema: TableSchema = users_schema(); let db_schema: DbSchema = db_schema(&users_schema); - let syntax: RawQuerySyntax = CreateTable("users".to_string(), users_schema.clone()); + let syntax: RawQuerySyntax = CreateTable(raw_users_schema()); let result = validate_operation(syntax, &db_schema); assert!(matches!(result, Err(ValidationError::TableAlreadyExists(_)))); }