Merge branch 'cargo-format' into 'main'

Apply cargo format

See merge request x433485/minisql!22
This commit is contained in:
Yuriy Dupyn 2024-01-28 22:43:06 +01:00
commit 53c5d3f3f7
33 changed files with 885 additions and 530 deletions

View file

@ -1,24 +1,35 @@
use std::io::Write;
use clap::Parser; use clap::Parser;
use proto::handshake::client::do_client_handshake; use proto::handshake::client::do_client_handshake;
use proto::handshake::request::HandshakeRequest; use proto::handshake::request::HandshakeRequest;
use proto::reader::protoreader::ProtoReader; use proto::message::backend::{
use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; BackendMessage, CommandCompleteData, DataRowData, ErrorResponseData, RowDescriptionData,
use tokio::io::{BufReader, BufWriter}; };
use tokio::net::TcpStream;
use proto::message::backend::{BackendMessage, CommandCompleteData, DataRowData, ErrorResponseData, RowDescriptionData};
use proto::message::frontend::{FrontendMessage, QueryData}; use proto::message::frontend::{FrontendMessage, QueryData};
use proto::reader::oneway::OneWayProtoReader; use proto::reader::oneway::OneWayProtoReader;
use proto::reader::protoreader::ProtoReader;
use proto::writer::oneway::OneWayProtoWriter; use proto::writer::oneway::OneWayProtoWriter;
use proto::writer::protowriter::{ProtoFlush, ProtoWriter};
use std::io::Write;
use tokio::io::{BufReader, BufWriter};
use tokio::net::TcpStream;
#[derive(Parser)] #[derive(Parser)]
struct Cli { struct Cli {
/// Port number of the server. /// Port number of the server.
#[arg(short, long, default_value_t = 5432, help = "Port number of the server")] #[arg(
short,
long,
default_value_t = 5432,
help = "Port number of the server"
)]
port: u16, port: u16,
/// Host name or IP address of the server. /// Host name or IP address of the server.
#[arg(long, default_value = "127.0.0.1", help = "Host name or IP address of the server")] #[arg(
long,
default_value = "127.0.0.1",
help = "Host name or IP address of the server"
)]
host: String, host: String,
/// User name sent to the server. /// User name sent to the server.
@ -48,9 +59,9 @@ async fn main() -> anyhow::Result<()> {
let mut exit = false; let mut exit = false;
let command = prompt()?; let command = prompt()?;
if let Some(cmd) = command { if let Some(cmd) = command {
writer.write_proto(FrontendMessage::Query(QueryData { writer
query: cmd.into(), .write_proto(FrontendMessage::Query(QueryData { query: cmd.into() }))
})).await?; .await?;
writer.flush().await?; writer.flush().await?;
} else { } else {
exit = true; exit = true;
@ -61,35 +72,35 @@ async fn main() -> anyhow::Result<()> {
match msg { match msg {
BackendMessage::RowDescription(data) => { BackendMessage::RowDescription(data) => {
print_row_description(data); print_row_description(data);
}, }
BackendMessage::DataRow(data) => { BackendMessage::DataRow(data) => {
print_row_data(data); print_row_data(data);
}, }
BackendMessage::CommandComplete(data) => { BackendMessage::CommandComplete(data) => {
print_command_complete(data); print_command_complete(data);
}, }
BackendMessage::ErrorResponse(data) => { BackendMessage::ErrorResponse(data) => {
print_error_response(data); print_error_response(data);
}, }
BackendMessage::EmptyQueryResponse => { BackendMessage::EmptyQueryResponse => {
println!("Empty query response"); println!("Empty query response");
}, }
BackendMessage::NoData => { BackendMessage::NoData => {
println!("No data"); println!("No data");
}, }
BackendMessage::ReadyForQuery(data) => { BackendMessage::ReadyForQuery(data) => {
println!("Ready for next query ({})", data.status); println!("Ready for next query ({})", data.status);
let command = prompt()?; let command = prompt()?;
if let Some(cmd) = command { if let Some(cmd) = command {
writer.write_proto(FrontendMessage::Query(QueryData { writer
query: cmd.into(), .write_proto(FrontendMessage::Query(QueryData { query: cmd.into() }))
})).await?; .await?;
writer.flush().await?; writer.flush().await?;
} else { } else {
exit = true; exit = true;
} }
}, }
m => { m => {
println!("Unexpected message: {:?}", m); println!("Unexpected message: {:?}", m);
} }

View file

@ -1,8 +1,8 @@
use crate::schema::{ColumnName, TableName};
use crate::type_system::Uuid;
use std::num::{ParseFloatError, ParseIntError}; use std::num::{ParseFloatError, ParseIntError};
use std::str::Utf8Error; use std::str::Utf8Error;
use thiserror::Error; use thiserror::Error;
use crate::schema::{ColumnName, TableName};
use crate::type_system::Uuid;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum RuntimeError { pub enum RuntimeError {
@ -23,8 +23,5 @@ pub enum TypeConversionError {
#[error("failed to parse int from text")] #[error("failed to parse int from text")]
IntDecodeFailed(#[from] ParseIntError), IntDecodeFailed(#[from] ParseIntError),
#[error("unknown type with oid {oid} and size {size}")] #[error("unknown type with oid {oid} and size {size}")]
UnknownType { UnknownType { oid: i32, size: i16 },
oid: i32,
size: i16
}
} }

View file

@ -1,6 +1,6 @@
use crate::type_system::{IndexableValue, Uuid}; use crate::type_system::{IndexableValue, Uuid};
use std::collections::{BTreeMap, HashSet};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashSet};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct ColumnIndex { pub struct ColumnIndex {

View file

@ -1,10 +1,10 @@
use crate::type_system::Value;
use crate::operation::InsertionValues; use crate::operation::InsertionValues;
use std::ops::{Index, IndexMut};
use std::slice::SliceIndex;
use serde::{Deserialize, Serialize};
use crate::restricted_row::RestrictedRow; use crate::restricted_row::RestrictedRow;
use crate::schema::Column; use crate::schema::Column;
use crate::type_system::Value;
use serde::{Deserialize, Serialize};
use std::ops::{Index, IndexMut};
use std::slice::SliceIndex;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Row(Vec<Value>); pub struct Row(Vec<Value>);

View file

@ -1,12 +1,12 @@
use std::collections::{BTreeMap, HashMap, HashSet};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
use crate::error::RuntimeError; use crate::error::RuntimeError;
use crate::internals::column_index::ColumnIndex; use crate::internals::column_index::ColumnIndex;
use crate::internals::row::Row; use crate::internals::row::Row;
use crate::restricted_row::RestrictedRow; use crate::restricted_row::RestrictedRow;
use crate::schema::{Column, ColumnName, TableSchema, TableName};
use crate::result::DbResult; use crate::result::DbResult;
use crate::schema::{Column, ColumnName, TableName, TableSchema};
use crate::type_system::{IndexableValue, Uuid, Value}; use crate::type_system::{IndexableValue, Uuid, Value};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -69,7 +69,10 @@ impl Table {
.collect() .collect()
} }
pub fn select_all_rows(&self, selected_columns: Vec<Column>) -> impl Iterator<Item=RestrictedRow> + '_ { pub fn select_all_rows(
&self,
selected_columns: Vec<Column>,
) -> impl Iterator<Item = RestrictedRow> + '_ {
self.rows self.rows
.values() .values()
.map(move |row| row.restrict_columns(&selected_columns)) .map(move |row| row.restrict_columns(&selected_columns))
@ -80,29 +83,23 @@ impl Table {
selected_columns: Vec<Column>, selected_columns: Vec<Column>,
column: Column, column: Column,
value: Value, value: Value,
) -> DbResult<impl Iterator<Item=RestrictedRow> + '_> { ) -> DbResult<impl Iterator<Item = RestrictedRow> + '_> {
let restrict_columns_of_row = move |row: Row| row.restrict_columns(&selected_columns); let restrict_columns_of_row = move |row: Row| row.restrict_columns(&selected_columns);
match value { match value {
Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? { Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? {
Some(ids) => Some(ids) => Ok(self
Ok(self .get_rows_by_ids(ids)
.get_rows_by_ids(ids)
.into_iter()
.map(restrict_columns_of_row)
),
None =>
Ok(self
.get_rows_by_value(column, &Value::Indexable(value))
.into_iter()
.map(restrict_columns_of_row)
),
},
_ =>
Ok(self
.get_rows_by_value(column, &value)
.into_iter() .into_iter()
.map(restrict_columns_of_row) .map(restrict_columns_of_row)),
), None => Ok(self
.get_rows_by_value(column, &Value::Indexable(value))
.into_iter()
.map(restrict_columns_of_row)),
},
_ => Ok(self
.get_rows_by_value(column, &value)
.into_iter()
.map(restrict_columns_of_row)),
} }
} }
@ -116,7 +113,9 @@ impl Table {
} }
for (column, column_index) in &mut self.indexes { for (column, column_index) in &mut self.indexes {
if let Value::Indexable(val) = &row[*column] { column_index.add(val.clone(), id) } if let Value::Indexable(val) = &row[*column] {
column_index.add(val.clone(), id)
}
} }
let _ = self.rows.insert(id, row); let _ = self.rows.insert(id, row);
@ -168,11 +167,7 @@ impl Table {
number_of_rows number_of_rows
} }
pub fn delete_rows_where_eq( pub fn delete_rows_where_eq(&mut self, column: Column, value: Value) -> DbResult<usize> {
&mut self,
column: Column,
value: Value,
) -> DbResult<usize> {
match value { match value {
Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? { Value::Indexable(value) => match self.fetch_ids_from_index(column, &value)? {
Some(ids) => Ok(self.delete_rows_by_ids(ids)), Some(ids) => Ok(self.delete_rows_by_ids(ids)),
@ -187,7 +182,10 @@ impl Table {
if self.indexes.get(&column).is_some() { if self.indexes.get(&column).is_some() {
let column_name = self.schema.column_name_from_column(column).clone(); let column_name = self.schema.column_name_from_column(column).clone();
let table_name = self.schema.table_name().clone(); let table_name = self.schema.table_name().clone();
return Err(RuntimeError::AttemptToIndexAlreadyIndexedColumn(table_name, column_name)) return Err(RuntimeError::AttemptToIndexAlreadyIndexedColumn(
table_name,
column_name,
));
} }
let mut column_index: ColumnIndex = ColumnIndex::new(); let mut column_index: ColumnIndex = ColumnIndex::new();
update_index_from_table(&mut column_index, self, column)?; update_index_from_table(&mut column_index, self, column)?;
@ -203,7 +201,7 @@ impl Table {
if self.schema.is_primary(column) { if self.schema.is_primary(column) {
match value { match value {
IndexableValue::Uuid(id) => Ok(Some(HashSet::from([*id]))), IndexableValue::Uuid(id) => Ok(Some(HashSet::from([*id]))),
_ => unreachable!() // SAFETY: Validation guarantees primary column has correct Uuid type. _ => unreachable!(), // SAFETY: Validation guarantees primary column has correct Uuid type.
} }
} else { } else {
match self.indexes.get(&column) { match self.indexes.get(&column) {
@ -231,9 +229,7 @@ fn update_index_from_table(
let value = match &row[column] { let value = match &row[column] {
Value::Indexable(value) => value.clone(), Value::Indexable(value) => value.clone(),
_ => { _ => {
let column_name: ColumnName = table let column_name: ColumnName = table.schema.column_name_from_column(column);
.schema
.column_name_from_column(column);
return Err(RuntimeError::AttemptToIndexNonIndexableColumn( return Err(RuntimeError::AttemptToIndexNonIndexableColumn(
table.table_name().to_string(), table.table_name().to_string(),
column_name, column_name,

View file

@ -1,10 +1,10 @@
use crate::schema::{Column, TableName, TablePosition, TableSchema};
use crate::internals::table::Table; use crate::internals::table::Table;
use crate::operation::{Operation, Condition, ColumnSelection}; use crate::operation::{ColumnSelection, Condition, Operation};
use crate::restricted_row::RestrictedRow;
use crate::result::DbResult; use crate::result::DbResult;
use crate::schema::{Column, TableName, TablePosition, TableSchema};
use bimap::BiMap; use bimap::BiMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::restricted_row::RestrictedRow;
// Use `TablePosition` as index // Use `TablePosition` as index
pub type Tables = Vec<Table>; pub type Tables = Vec<Table>;
@ -18,7 +18,11 @@ pub struct State {
// #[derive(Debug)] // #[derive(Debug)]
pub enum Response<'a> { pub enum Response<'a> {
Selected(&'a TableSchema, ColumnSelection, Box<dyn Iterator<Item=RestrictedRow> + 'a + Send>), Selected(
&'a TableSchema,
ColumnSelection,
Box<dyn Iterator<Item = RestrictedRow> + 'a + Send>,
),
Inserted, Inserted,
Deleted(usize), // how many were deleted Deleted(usize), // how many were deleted
TableCreated, TableCreated,
@ -32,17 +36,15 @@ impl std::fmt::Debug for Response<'_> {
use Response::*; use Response::*;
match self { match self {
Selected(_schema, _columns, _rows) => Selected(_schema, _columns, _rows) =>
// TODO: How can we iterate through the rows without having to take ownership of // TODO: How can we iterate through the rows without having to take ownership of
// them? // them?
f.write_str("Some rows... trust me"), {
Inserted => f.write_str("Some rows... trust me")
f.write_str("Inserted"), }
Deleted(usize) => Inserted => f.write_str("Inserted"),
f.write_fmt(format_args!("Deleted({})", usize)), Deleted(usize) => f.write_fmt(format_args!("Deleted({})", usize)),
TableCreated => TableCreated => f.write_str("TableCreated"),
f.write_str("TableCreated"), IndexCreated => f.write_str("IndexCreated"),
IndexCreated =>
f.write_str("IndexCreated"),
} }
} }
} }
@ -97,22 +99,25 @@ impl State {
let selected_rows = match maybe_condition { let selected_rows = match maybe_condition {
None => { None => {
let rows = table.select_all_rows(column_selection.clone()); let rows = table.select_all_rows(column_selection.clone());
Box::new(rows) as Box<dyn Iterator<Item=RestrictedRow> + 'a + Send> Box::new(rows) as Box<dyn Iterator<Item = RestrictedRow> + 'a + Send>
}, }
Some(Condition::Eq(eq_column, value)) => { Some(Condition::Eq(eq_column, value)) => {
let x = let rows = table.select_rows_where_eq(
table.select_rows_where_eq( column_selection.clone(),
column_selection.clone(), eq_column,
eq_column, value,
value, )?;
)?; Box::new(rows) as Box<dyn Iterator<Item = RestrictedRow> + 'a + Send>
Box::new(x) as Box<dyn Iterator<Item=RestrictedRow> + 'a + Send>
} }
}; };
Ok(Response::Selected(table.schema(), column_selection, selected_rows)) Ok(Response::Selected(
}, table.schema(),
column_selection,
selected_rows,
))
}
Insert(table_position, values) => { Insert(table_position, values) => {
let table: &mut Table = self.table_at_mut(table_position); let table: &mut Table = self.table_at_mut(table_position);
@ -150,20 +155,16 @@ impl State {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::schema::Column;
use std::collections::HashSet;
use crate::type_system::{DbType, IndexableValue, Value};
use crate::operation::Operation; use crate::operation::Operation;
use crate::schema::Column;
use crate::type_system::{DbType, IndexableValue, Value};
use std::collections::HashSet;
fn users_schema() -> TableSchema { fn users_schema() -> TableSchema {
TableSchema::new( TableSchema::new(
"users".to_string(), "users".to_string(),
"id".to_string(), "id".to_string(),
vec!( vec!["id".to_string(), "name".to_string(), "age".to_string()],
"id".to_string(),
"name".to_string(),
"age".to_string(),
),
vec![DbType::Uuid, DbType::String, DbType::Int], vec![DbType::Uuid, DbType::String, DbType::Int],
) )
} }
@ -195,7 +196,11 @@ mod tests {
.interpret(Operation::CreateTable(users_schema.clone())) .interpret(Operation::CreateTable(users_schema.clone()))
.unwrap(); .unwrap();
let response: Response = state let response: Response = state
.interpret(Operation::Select(users_position, users_schema.all_selection(), None)) .interpret(Operation::Select(
users_position,
users_schema.all_selection(),
None,
))
.unwrap(); .unwrap();
assert!(matches!(response, Response::Selected(_, _, _))); assert!(matches!(response, Response::Selected(_, _, _)));
let Response::Selected(_, _, rows) = response else { let Response::Selected(_, _, rows) = response else {
@ -214,7 +219,6 @@ mod tests {
let users_schema = users_schema(); let users_schema = users_schema();
let users = 0; let users = 0;
state state
.interpret(Operation::CreateTable(users_schema.clone())) .interpret(Operation::CreateTable(users_schema.clone()))
.unwrap(); .unwrap();
@ -227,11 +231,7 @@ mod tests {
state state
.interpret(Operation::Insert( .interpret(Operation::Insert(
users, users,
vec![ vec![id.clone(), name.clone(), age.clone()],
id.clone(),
name.clone(),
age.clone(),
],
)) ))
.unwrap(); .unwrap();
@ -246,7 +246,7 @@ mod tests {
let rows: Vec<_> = rows.collect(); let rows: Vec<_> = rows.collect();
assert!(rows.len() == 1); assert!(rows.len() == 1);
let row = &rows[0]; let row = &rows[0];
assert!(row.len() == 3); assert!(row.len() == 3);
assert!(row[0].1 == id); assert!(row[0].1 == id);
assert!(row[1].1 == name); assert!(row[1].1 == name);
@ -267,9 +267,7 @@ mod tests {
let id_column: Column = 0; let id_column: Column = 0;
let name_column: Column = 1; let name_column: Column = 1;
state state.interpret(CreateTable(users_schema.clone())).unwrap();
.interpret(CreateTable(users_schema.clone()))
.unwrap();
let (id0, name0, age0) = ( let (id0, name0, age0) = (
Indexable(Uuid(0)), Indexable(Uuid(0)),
@ -279,11 +277,7 @@ mod tests {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id0.clone(), name0.clone(), age0.clone()],
id0.clone(),
name0.clone(),
age0.clone(),
],
)) ))
.unwrap(); .unwrap();
@ -294,17 +288,15 @@ mod tests {
); );
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id1.clone(), name1.clone(), age1.clone()],
id1.clone(),
name1.clone(),
age1.clone(),
],
)) ))
.unwrap(); .unwrap();
{ {
let response: Response = state.interpret(Select(users_position, users_schema.all_selection(), None)).unwrap(); let response: Response = state
.interpret(Select(users_position, users_schema.all_selection(), None))
.unwrap();
assert!(matches!(response, Response::Selected(_, _, _))); assert!(matches!(response, Response::Selected(_, _, _)));
let Response::Selected(_, _, rows) = response else { let Response::Selected(_, _, rows) = response else {
@ -384,9 +376,7 @@ mod tests {
let id_column: Column = 0; let id_column: Column = 0;
state state.interpret(CreateTable(users_schema.clone())).unwrap();
.interpret(CreateTable(users_schema.clone()))
.unwrap();
let (id0, name0, age0) = ( let (id0, name0, age0) = (
Indexable(Uuid(0)), Indexable(Uuid(0)),
@ -396,11 +386,7 @@ mod tests {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id0.clone(), name0.clone(), age0.clone()],
id0.clone(),
name0.clone(),
age0.clone(),
],
)) ))
.unwrap(); .unwrap();
@ -412,25 +398,20 @@ mod tests {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id1.clone(), name1.clone(), age1.clone()],
id1.clone(),
name1.clone(),
age1.clone(),
],
)) ))
.unwrap(); .unwrap();
{ {
let delete_response: Response = state let delete_response: Response = state
.interpret(Delete( .interpret(Delete(users_position, Some(Eq(id_column, id0.clone()))))
users_position,
Some(Eq(id_column, id0.clone())),
))
.unwrap(); .unwrap();
assert!(matches!(delete_response, Response::Deleted(1))); assert!(matches!(delete_response, Response::Deleted(1)));
} }
let response: Response = state.interpret(Select(users_position, users_schema.all_selection(), None)).unwrap(); let response: Response = state
.interpret(Select(users_position, users_schema.all_selection(), None))
.unwrap();
assert!(matches!(response, Response::Selected(_, _, _))); assert!(matches!(response, Response::Selected(_, _, _)));
let Response::Selected(_, _, rows) = response else { let Response::Selected(_, _, rows) = response else {
@ -458,9 +439,7 @@ mod tests {
let name_column: Column = 1; let name_column: Column = 1;
state state.interpret(CreateTable(users_schema.clone())).unwrap();
.interpret(CreateTable(users_schema.clone()))
.unwrap();
state state
.interpret(CreateIndex(users_position, name_column)) .interpret(CreateIndex(users_position, name_column))
@ -474,11 +453,7 @@ mod tests {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id0.clone(), name0.clone(), age0.clone()],
id0.clone(),
name0.clone(),
age0.clone(),
],
)) ))
.unwrap(); .unwrap();
@ -490,11 +465,7 @@ mod tests {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id1.clone(), name1.clone(), age1.clone()],
id1.clone(),
name1.clone(),
age1.clone(),
],
)) ))
.unwrap(); .unwrap();
@ -510,7 +481,10 @@ mod tests {
let plato_id = 0; let plato_id = 0;
let aristotle_id = 1; let aristotle_id = 1;
let plato_ids = index.get(&String("Plato".to_string())).cloned().unwrap_or(HashSet::new()); let plato_ids = index
.get(&String("Plato".to_string()))
.cloned()
.unwrap_or(HashSet::new());
assert!(plato_ids.contains(&plato_id)); assert!(plato_ids.contains(&plato_id));
assert!(!plato_ids.contains(&aristotle_id)); assert!(!plato_ids.contains(&aristotle_id));
assert!(plato_ids.len() == 1); assert!(plato_ids.len() == 1);
@ -518,7 +492,7 @@ mod tests {
} }
pub fn example() { pub fn example() {
use crate::type_system::{IndexableValue, Value, DbType}; use crate::type_system::{DbType, IndexableValue, Value};
use Condition::*; use Condition::*;
use IndexableValue::*; use IndexableValue::*;
use Operation::*; use Operation::*;
@ -532,11 +506,11 @@ pub fn example() {
TableSchema::new( TableSchema::new(
"users".to_string(), "users".to_string(),
"id".to_string(), "id".to_string(),
vec!( vec![
"id".to_string(), // 0 "id".to_string(), // 0
"name".to_string(), // 1 "name".to_string(), // 1
"age".to_string(), // 2 "age".to_string(), // 2
), ],
vec![DbType::Uuid, DbType::String, DbType::Int], vec![DbType::Uuid, DbType::String, DbType::Int],
) )
}; };
@ -556,11 +530,7 @@ pub fn example() {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id0.clone(), name0.clone(), age0.clone()],
id0.clone(),
name0.clone(),
age0.clone(),
],
)) ))
.unwrap(); .unwrap();
@ -573,18 +543,18 @@ pub fn example() {
state state
.interpret(Insert( .interpret(Insert(
users_position, users_position,
vec![ vec![id1.clone(), name1.clone(), age1.clone()],
id1.clone(),
name1.clone(),
age1.clone(),
],
)) ))
.unwrap(); .unwrap();
println!(); println!();
{ {
let response: Response = state let response: Response = state
.interpret(Operation::Select(users_position, users_schema.all_selection(), None)) .interpret(Operation::Select(
users_position,
users_schema.all_selection(),
None,
))
.unwrap(); .unwrap();
println!("==SELECT ALL=="); println!("==SELECT ALL==");
println!("{:?}", response); println!("{:?}", response);
@ -608,19 +578,12 @@ pub fn example() {
// TODO: Why do I have to write these braces explicitely? Why doesn't Rust compiler // TODO: Why do I have to write these braces explicitely? Why doesn't Rust compiler
// "infer" them? // "infer" them?
let _delete_response: Response = state let _delete_response: Response = state
.interpret(Delete( .interpret(Delete(users_position, Some(Eq(id_column, id0.clone()))))
users_position, .unwrap();
Some(Eq(id_column, id0.clone())),
))
.unwrap();
println!("==DELETE Plato=="); println!("==DELETE Plato==");
} }
let response: Response = state let response: Response = state
.interpret(Select( .interpret(Select(users_position, vec![name_column, id_column], None))
users_position,
vec![name_column, id_column],
None,
))
.unwrap(); .unwrap();
println!("==SELECT All=="); println!("==SELECT All==");
println!("{:?}", response); println!("{:?}", response);

View file

@ -1,8 +1,8 @@
pub mod schema;
pub mod interpreter;
pub mod operation;
pub mod type_system;
mod error; mod error;
mod internals; mod internals;
mod result; pub mod interpreter;
pub mod operation;
pub mod restricted_row; pub mod restricted_row;
mod result;
pub mod schema;
pub mod type_system;

View file

@ -1,7 +1,7 @@
use std::ops::Index;
use std::slice::SliceIndex;
use crate::schema::Column; use crate::schema::Column;
use crate::type_system::Value; use crate::type_system::Value;
use std::ops::Index;
use std::slice::SliceIndex;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RestrictedRow(Vec<(Column, Value)>); pub struct RestrictedRow(Vec<(Column, Value)>);
@ -32,8 +32,7 @@ impl RestrictedRow {
self.0.is_empty() self.0.is_empty()
} }
pub fn iter(&self) -> impl Iterator<Item=&(Column, Value)> { pub fn iter(&self) -> impl Iterator<Item = &(Column, Value)> {
self.0.iter() self.0.iter()
} }
} }

View file

@ -1,5 +1,5 @@
use crate::internals::row::Row; use crate::internals::row::Row;
use crate::operation::{InsertionValues, ColumnSelection}; use crate::operation::{ColumnSelection, InsertionValues};
use crate::result::DbResult; use crate::result::DbResult;
use crate::type_system::{DbType, IndexableValue, Uuid, Value}; use crate::type_system::{DbType, IndexableValue, Uuid, Value};
use bimap::BiMap; use bimap::BiMap;
@ -20,18 +20,30 @@ pub type TablePosition = usize;
pub type ColumnName = String; pub type ColumnName = String;
pub type Column = usize; pub type Column = usize;
impl TableSchema { impl TableSchema {
pub fn new(table_name: TableName, primary_column_name: ColumnName, columns: Vec<ColumnName>, types: Vec<DbType>) -> Self { pub fn new(
table_name: TableName,
primary_column_name: ColumnName,
columns: Vec<ColumnName>,
types: Vec<DbType>,
) -> Self {
let mut column_name_position_mapping: BiMap<ColumnName, Column> = BiMap::new(); let mut column_name_position_mapping: BiMap<ColumnName, Column> = BiMap::new();
for (column, column_name) in columns.into_iter().enumerate() { for (column, column_name) in columns.into_iter().enumerate() {
column_name_position_mapping.insert(column_name, column); column_name_position_mapping.insert(column_name, column);
} }
let primary_key: Column = match column_name_position_mapping.get_by_left(&primary_column_name).copied() { let primary_key: Column = match column_name_position_mapping
.get_by_left(&primary_column_name)
.copied()
{
Some(primary_key) => primary_key, Some(primary_key) => primary_key,
None => unreachable!() // SAFETY: Existence of unique primary key is ensured in validation. None => unreachable!(), // SAFETY: Existence of unique primary key is ensured in validation.
}; };
Self { table_name, primary_key, column_name_position_mapping, types } Self {
table_name,
primary_key,
column_name_position_mapping,
types,
}
} }
pub fn table_name(&self) -> &TableName { pub fn table_name(&self) -> &TableName {
@ -43,7 +55,10 @@ impl TableSchema {
} }
pub fn get_columns(&self) -> Vec<&ColumnName> { pub fn get_columns(&self) -> Vec<&ColumnName> {
self.column_name_position_mapping.iter().map(|(name, _)| name).collect() self.column_name_position_mapping
.iter()
.map(|(name, _)| name)
.collect()
} }
pub fn does_column_exist(&self, column_name: &ColumnName) -> bool { pub fn does_column_exist(&self, column_name: &ColumnName) -> bool {
@ -51,11 +66,17 @@ impl TableSchema {
} }
pub fn get_column(&self, column_name: &ColumnName) -> Option<Column> { pub fn get_column(&self, column_name: &ColumnName) -> Option<Column> {
self.column_name_position_mapping.get_by_left(column_name).copied() self.column_name_position_mapping
.get_by_left(column_name)
.copied()
} }
pub fn all_selection(&self) -> ColumnSelection { pub fn all_selection(&self) -> ColumnSelection {
let mut selection: ColumnSelection = self.column_name_position_mapping.iter().map(|(_, column)| *column).collect(); let mut selection: ColumnSelection = self
.column_name_position_mapping
.iter()
.map(|(_, column)| *column)
.collect();
selection.sort(); selection.sort();
selection selection
} }
@ -76,13 +97,10 @@ impl TableSchema {
// Assumes `column` comes from a validated source. // Assumes `column` comes from a validated source.
pub fn column_name_from_column(&self, column: Column) -> ColumnName { pub fn column_name_from_column(&self, column: Column) -> ColumnName {
match self match self.column_name_position_mapping.get_by_right(&column) {
.column_name_position_mapping
.get_by_right(&column)
{
Some(column_name) => column_name.clone(), Some(column_name) => column_name.clone(),
None => unreachable!() // SAFETY: The only way this function can get a column is from None => unreachable!(), // SAFETY: The only way this function can get a column is from
// validation, which guarantees there is such a colun. // validation, which guarantees there is such a colun.
} }
} }

View file

@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use crate::error::TypeConversionError; use crate::error::TypeConversionError;
use serde::{Deserialize, Serialize};
// ==============Types================ // ==============Types================
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@ -82,7 +82,11 @@ impl Value {
} }
} }
pub fn from_text_bytes(bytes: &[u8], type_oid: i32, type_size: i16) -> Result<Value, TypeConversionError> { pub fn from_text_bytes(
bytes: &[u8],
type_oid: i32,
type_size: i16,
) -> Result<Value, TypeConversionError> {
match (type_oid, type_size) { match (type_oid, type_size) {
(701, 8) => { (701, 8) => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(bytes)?;
@ -91,7 +95,7 @@ impl Value {
} }
(25, -2) => { (25, -2) => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(bytes)?;
let s = &s[..s.len() - 1]; // remove null terminator let s = &s[..s.len() - 1]; // remove null terminator
Ok(Value::Indexable(IndexableValue::String(s.to_string()))) Ok(Value::Indexable(IndexableValue::String(s.to_string())))
} }
(23, 8) => { (23, 8) => {
@ -111,8 +115,8 @@ impl Value {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{IndexableValue, Value};
use crate::error::TypeConversionError::UnknownType; use crate::error::TypeConversionError::UnknownType;
use super::{Value, IndexableValue};
#[test] #[test]
fn test_encode_number() { fn test_encode_number() {
@ -204,6 +208,9 @@ mod tests {
let bytes = value.as_text_bytes(); let bytes = value.as_text_bytes();
let from_bytes = Value::from_text_bytes(&bytes, oid, size); let from_bytes = Value::from_text_bytes(&bytes, oid, size);
assert!(matches!(from_bytes, Err(UnknownType { oid: 2950, size: 8 }))) assert!(matches!(
from_bytes,
Err(UnknownType { oid: 2950, size: 8 })
))
} }
} }

View file

@ -1,16 +1,22 @@
use minisql::{operation::Operation, interpreter::DbSchema};
use crate::syntax::RawQuerySyntax; use crate::syntax::RawQuerySyntax;
use minisql::{interpreter::DbSchema, operation::Operation};
use nom::{branch::alt, IResult}; use nom::{branch::alt, IResult};
use thiserror::Error; use thiserror::Error;
use crate::{parsing::{create::parse_create, delete::parse_delete, index::parse_create_index, insert::parse_insert, select::parse_select}, validation::{validate_operation, ValidationError}}; use crate::{
parsing::{
create::parse_create, delete::parse_delete, index::parse_create_index,
insert::parse_insert, select::parse_select,
},
validation::{validate_operation, ValidationError},
};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
#[error("parsing error: {0}")] #[error("parsing error: {0}")]
ParsingError(String), ParsingError(String),
#[error("validation error: {0}")] #[error("validation error: {0}")]
ValidationError(#[from] ValidationError) ValidationError(#[from] ValidationError),
} }
fn parse_statement(input: &str) -> IResult<&str, RawQuerySyntax> { fn parse_statement(input: &str) -> IResult<&str, RawQuerySyntax> {
@ -21,15 +27,13 @@ fn parse_statement(input: &str) -> IResult<&str, RawQuerySyntax> {
//parse_drop, //parse_drop,
parse_select, parse_select,
// parse_update, // parse_update,
parse_create_index parse_create_index,
))(input) ))(input)
} }
pub fn parse_and_validate(str_query: String, db_schema: &DbSchema) -> Result<Operation, Error> { pub fn parse_and_validate(str_query: String, db_schema: &DbSchema) -> Result<Operation, Error> {
let (_, op) = parse_statement(str_query.as_str()) let (_, op) =
.map_err(|err| { parse_statement(str_query.as_str()).map_err(|err| Error::ParsingError(err.to_string()))?;
Error::ParsingError(err.to_string())
})?;
Ok(validate_operation(op, db_schema)?) Ok(validate_operation(op, db_schema)?)
} }

View file

@ -1,8 +1,7 @@
mod parsing;
mod validation;
mod core; mod core;
mod parsing;
mod syntax; mod syntax;
mod validation;
pub use core::parse_and_validate; pub use core::parse_and_validate;
pub use core::Error; pub use core::Error;

View file

@ -1,20 +1,21 @@
use minisql::type_system::DbType;
use nom::{ use nom::{
character::complete::{alphanumeric1, char, multispace0, anychar, multispace1}, branch::alt,
bytes::complete::tag,
character::complete::{alphanumeric1, anychar, char, multispace0, multispace1},
combinator::peek, combinator::peek,
error::make_error, error::make_error,
sequence::{delimited, terminated}, sequence::{delimited, terminated},
bytes::complete::tag, IResult,
IResult, branch::alt,
}; };
use minisql::type_system::DbType;
use crate::syntax::Condition;
use super::literal::parse_db_value; use super::literal::parse_db_value;
use crate::syntax::Condition;
pub fn parse_table_name(input: &str) -> IResult<&str, &str> { pub fn parse_table_name(input: &str) -> IResult<&str, &str> {
alt(( alt((
delimited(char('"'), alphanumeric1, char('"')), delimited(char('"'), alphanumeric1, char('"')),
parse_identifier parse_identifier,
))(input) ))(input)
} }
@ -24,7 +25,10 @@ pub fn parse_identifier(input: &str) -> IResult<&str, &str> {
if first.is_alphabetic() { if first.is_alphabetic() {
alphanumeric1(input) alphanumeric1(input)
} else { } else {
Err(nom::Err::Error(make_error(input, nom::error::ErrorKind::Alpha))) Err(nom::Err::Error(make_error(
input,
nom::error::ErrorKind::Alpha,
)))
} }
} }
@ -39,7 +43,12 @@ pub fn parse_db_type(input: &str) -> IResult<&str, DbType> {
"INT" => DbType::Int, "INT" => DbType::Int,
"UUID" => DbType::Uuid, "UUID" => DbType::Uuid,
"NUMBER" => DbType::Number, "NUMBER" => DbType::Number,
_ => return Err(nom::Err::Failure(make_error(input, nom::error::ErrorKind::IsNot))) _ => {
return Err(nom::Err::Failure(make_error(
input,
nom::error::ErrorKind::IsNot,
)))
}
}; };
Ok((input, db_type)) Ok((input, db_type))
} }
@ -51,9 +60,7 @@ pub fn parse_condition(input: &str) -> IResult<&str, Option<Condition>> {
let (input, condition) = parse_equality(input)?; let (input, condition) = parse_equality(input)?;
Ok((input, Some(condition))) Ok((input, Some(condition)))
} }
Err(_) => { Err(_) => Ok((input, None)),
Ok((input, None))
}
} }
} }
@ -70,9 +77,9 @@ fn parse_equality(input: &str) -> IResult<&str, Condition> {
mod tests { mod tests {
use minisql::type_system::DbType; use minisql::type_system::DbType;
use crate::syntax::Condition;
use crate::parsing::common::{parse_db_type, parse_equality}; use crate::parsing::common::{parse_db_type, parse_equality};
use crate::syntax::Condition;
#[test] #[test]
fn test_parse_equality() { fn test_parse_equality() {
use minisql::type_system::{IndexableValue, Value}; use minisql::type_system::{IndexableValue, Value};
@ -89,10 +96,22 @@ mod tests {
#[test] #[test]
fn test_parse_db_type() { fn test_parse_db_type() {
assert!(matches!(parse_db_type("INT").expect("should parse").1, DbType::Int)); assert!(matches!(
assert!(matches!(parse_db_type("STRING").expect("should parse").1, DbType::String)); parse_db_type("INT").expect("should parse").1,
assert!(matches!(parse_db_type("UUID").expect("should parse").1, DbType::Uuid)); DbType::Int
assert!(matches!(parse_db_type("NUMBER").expect("should parse").1, DbType::Number)); ));
assert!(matches!(
parse_db_type("STRING").expect("should parse").1,
DbType::String
));
assert!(matches!(
parse_db_type("UUID").expect("should parse").1,
DbType::Uuid
));
assert!(matches!(
parse_db_type("NUMBER").expect("should parse").1,
DbType::Number
));
assert!(matches!(parse_db_type("Unknown"), Err(_))); assert!(matches!(parse_db_type("Unknown"), Err(_)));
} }
} }

View file

@ -1,13 +1,14 @@
use nom::{ use nom::{
bytes::complete::tag, bytes::complete::tag,
character::complete::{char, multispace0, multispace1}, character::complete::{char, multispace0, multispace1},
combinator::opt,
multi::separated_list0, multi::separated_list0,
sequence::terminated, sequence::terminated,
IResult, combinator::opt, IResult,
}; };
use super::common::{parse_table_name, parse_identifier, parse_db_type}; use super::common::{parse_db_type, parse_identifier, parse_table_name};
use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; use crate::syntax::{ColumnSchema, RawQuerySyntax, RawTableSchema};
pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> { pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> {
let (input, _) = tag("CREATE")(input)?; let (input, _) = tag("CREATE")(input)?;
@ -27,10 +28,7 @@ pub fn parse_create(input: &str) -> IResult<&str, RawQuerySyntax> {
table_name: table_name.to_string(), table_name: table_name.to_string(),
columns: column_definitions, columns: column_definitions,
}; };
Ok(( Ok((input, RawQuerySyntax::CreateTable(schema)))
input,
RawQuerySyntax::CreateTable(schema),
))
} }
fn parse_column_definitions(input: &str) -> IResult<&str, Vec<ColumnSchema>> { fn parse_column_definitions(input: &str) -> IResult<&str, Vec<ColumnSchema>> {
@ -51,7 +49,14 @@ fn parse_column_definition(input: &str) -> IResult<&str, ColumnSchema> {
let (input, db_type) = parse_db_type(input)?; let (input, db_type) = parse_db_type(input)?;
let (input, pk) = opt(parse_primary_key)(input).map(|(input, pk)| (input, pk.is_some()))?; let (input, pk) = opt(parse_primary_key)(input).map(|(input, pk)| (input, pk.is_some()))?;
let (input, _) = multispace0(input)?; let (input, _) = multispace0(input)?;
Ok((input, ColumnSchema { column_name: identifier.to_string(), type_: db_type, is_primary: pk })) Ok((
input,
ColumnSchema {
column_name: identifier.to_string(),
type_: db_type,
is_primary: pk,
},
))
} }
#[cfg(test)] #[cfg(test)]
@ -66,22 +71,28 @@ mod tests {
#[test] #[test]
fn test_parse_create_primary_key() { fn test_parse_create_primary_key() {
parse_create("CREATE TABLE \"Table1\"(id UUID PRIMARY KEY,column1 INT);").expect("should parse"); parse_create("CREATE TABLE \"Table1\"(id UUID PRIMARY KEY,column1 INT);")
.expect("should parse");
} }
#[test] #[test]
fn test_parse_create_no_quotes_table_name() { fn test_parse_create_no_quotes_table_name() {
parse_create("CREATE TABLE Table1(id UUID PRIMARY KEY,column1 INT);").expect("should parse"); parse_create("CREATE TABLE Table1(id UUID PRIMARY KEY,column1 INT);")
.expect("should parse");
} }
#[test] #[test]
fn test_parse_create_primary_key_with_spaces() { fn test_parse_create_primary_key_with_spaces() {
parse_create("CREATE TABLE \"Table1\" ( id UUID PRIMARY KEY , column1 INT ) ;").expect("should parse"); parse_create(
"CREATE TABLE \"Table1\" ( id UUID PRIMARY KEY , column1 INT ) ;",
)
.expect("should parse");
} }
#[test] #[test]
fn test_parse_create() { fn test_parse_create() {
let (_, create) = parse_create("CREATE TABLE \"Table1\"( id UUID , column1 INT );").expect("should parse"); let (_, create) = parse_create("CREATE TABLE \"Table1\"( id UUID , column1 INT );")
.expect("should parse");
assert!(matches!(create, RawQuerySyntax::CreateTable(_))); assert!(matches!(create, RawQuerySyntax::CreateTable(_)));
match create { match create {
RawQuerySyntax::CreateTable(schema) => { RawQuerySyntax::CreateTable(schema) => {
@ -95,7 +106,9 @@ mod tests {
let result_column1 = schema.get_column(&"column1".to_string()); let result_column1 = schema.get_column(&"column1".to_string());
assert!(matches!(result_column1, Some(_))); assert!(matches!(result_column1, Some(_)));
let Some(column1_column) = result_column1 else { panic!() }; let Some(column1_column) = result_column1 else {
panic!()
};
assert_eq!(column1_column.column_name, "column1".to_string()); assert_eq!(column1_column.column_name, "column1".to_string());
} }
_ => {} _ => {}

View file

@ -4,8 +4,8 @@ use nom::{
IResult, IResult,
}; };
use super::common::{parse_condition, parse_table_name};
use crate::syntax::RawQuerySyntax; use crate::syntax::RawQuerySyntax;
use super::common::{parse_table_name, parse_condition};
pub fn parse_delete(input: &str) -> IResult<&str, RawQuerySyntax> { pub fn parse_delete(input: &str) -> IResult<&str, RawQuerySyntax> {
let (input, _) = tag("DELETE")(input)?; let (input, _) = tag("DELETE")(input)?;
@ -25,14 +25,15 @@ pub fn parse_delete(input: &str) -> IResult<&str, RawQuerySyntax> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::syntax::RawQuerySyntax;
use crate::parsing::delete::parse_delete; use crate::parsing::delete::parse_delete;
use crate::syntax::RawQuerySyntax;
#[test] #[test]
fn test_parse_delete() { fn test_parse_delete() {
let (_, operation) = parse_delete("DELETE FROM \"T1\" WHERE id = 1 ;").expect("should parse"); let (_, operation) =
parse_delete("DELETE FROM \"T1\" WHERE id = 1 ;").expect("should parse");
assert!(matches!(operation, RawQuerySyntax::Delete(_, _))) assert!(matches!(operation, RawQuerySyntax::Delete(_, _)))
} }
// TODO: add test with condition // TODO: add test with condition
} }

View file

@ -2,7 +2,8 @@ use crate::syntax::RawQuerySyntax;
use nom::{ use nom::{
bytes::complete::tag, bytes::complete::tag,
character::complete::{char, multispace0, multispace1}, character::complete::{char, multispace0, multispace1},
IResult, combinator::opt, combinator::opt,
IResult,
}; };
use super::common::{parse_identifier, parse_table_name}; use super::common::{parse_identifier, parse_table_name};
@ -35,16 +36,16 @@ pub fn parse_create_index(input: &str) -> IResult<&str, RawQuerySyntax> {
Ok((input, operation)) Ok((input, operation))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::syntax::RawQuerySyntax;
use crate::parsing::index::parse_create_index; use crate::parsing::index::parse_create_index;
use crate::syntax::RawQuerySyntax;
#[test] #[test]
fn test_create_index() { fn test_create_index() {
let (_, syntax) = parse_create_index("CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" (email);").expect("should parse"); let (_, syntax) =
parse_create_index("CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" (email);")
.expect("should parse");
assert!(matches!(syntax, RawQuerySyntax::CreateIndex(_, _))); assert!(matches!(syntax, RawQuerySyntax::CreateIndex(_, _)));
match syntax { match syntax {
RawQuerySyntax::CreateIndex(table_name, column_name) => { RawQuerySyntax::CreateIndex(table_name, column_name) => {
@ -57,7 +58,10 @@ mod tests {
#[test] #[test]
fn test_create_index_with_spaces() { fn test_create_index_with_spaces() {
let (_, syntax) = parse_create_index("CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" ( email ) ;").expect("should parse"); let (_, syntax) = parse_create_index(
"CREATE UNIQUE INDEX idxcontactsemail ON \"contacts\" ( email ) ;",
)
.expect("should parse");
assert!(matches!(syntax, RawQuerySyntax::CreateIndex(_, _))); assert!(matches!(syntax, RawQuerySyntax::CreateIndex(_, _)));
match syntax { match syntax {
RawQuerySyntax::CreateIndex(table_name, column_name) => { RawQuerySyntax::CreateIndex(table_name, column_name) => {

View file

@ -1,9 +1,12 @@
use super::{literal::parse_db_value, common::{parse_table_name, parse_identifier}}; use super::{
common::{parse_identifier, parse_table_name},
literal::parse_db_value,
};
use crate::syntax::RawQuerySyntax; use crate::syntax::RawQuerySyntax;
use minisql::type_system::Value; use minisql::type_system::Value;
use nom::{ use nom::{
bytes::complete::tag, bytes::complete::tag,
character::complete::{multispace0, multispace1, char}, character::complete::{char, multispace0, multispace1},
combinator::map, combinator::map,
multi::separated_list0, multi::separated_list0,
sequence::terminated, sequence::terminated,
@ -14,7 +17,7 @@ pub fn parse_insert(input: &str) -> IResult<&str, RawQuerySyntax> {
let (input, _) = tag("INSERT")(input)?; let (input, _) = tag("INSERT")(input)?;
let (input, _) = multispace1(input)?; let (input, _) = multispace1(input)?;
let (input, _) = tag("INTO")(input)?; let (input, _) = tag("INTO")(input)?;
let (input, _) = multispace1(input)?; let (input, _) = multispace1(input)?;
let (input, table_name) = parse_table_name(input)?; let (input, table_name) = parse_table_name(input)?;
let (input, _) = multispace1(input)?; let (input, _) = multispace1(input)?;
let (input, _) = char('(')(input)?; let (input, _) = char('(')(input)?;
@ -34,27 +37,31 @@ pub fn parse_insert(input: &str) -> IResult<&str, RawQuerySyntax> {
let (input, _) = char(';')(input)?; let (input, _) = char(';')(input)?;
Ok(( Ok((
input, input,
RawQuerySyntax::Insert(table_name.to_string(), column_names.into_iter().zip(values).collect()), RawQuerySyntax::Insert(
table_name.to_string(),
column_names.into_iter().zip(values).collect(),
),
)) ))
} }
pub fn parse_columns(input: &str) -> IResult<&str, Vec<String>> { pub fn parse_columns(input: &str) -> IResult<&str, Vec<String>> {
separated_list0(terminated(char(','), multispace0), map(parse_identifier, |name|name.to_string()))(input) separated_list0(
terminated(char(','), multispace0),
map(parse_identifier, |name| name.to_string()),
)(input)
} }
pub fn parse_values(input: &str) -> IResult<&str, Vec<Value>> { pub fn parse_values(input: &str) -> IResult<&str, Vec<Value>> {
separated_list0(terminated(char(','), multispace0), parse_db_value)(input) separated_list0(terminated(char(','), multispace0), parse_db_value)(input)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use minisql::type_system::{IndexableValue, Value}; use minisql::type_system::{IndexableValue, Value};
use crate::syntax::RawQuerySyntax;
use super::parse_insert; use super::parse_insert;
use crate::syntax::RawQuerySyntax;
#[test] #[test]
fn test_parse_insert() { fn test_parse_insert() {
let sql = "INSERT INTO \"MyTable\" (id, data) VALUES(1, \"Text\");"; let sql = "INSERT INTO \"MyTable\" (id, data) VALUES(1, \"Text\");";
@ -63,11 +70,15 @@ mod tests {
("", RawQuerySyntax::Insert(table_name, insertion_values)) => { ("", RawQuerySyntax::Insert(table_name, insertion_values)) => {
assert_eq!(table_name, "MyTable"); assert_eq!(table_name, "MyTable");
assert_eq!( assert_eq!(
insertion_values, insertion_values,
vec![ vec![
("id".to_string(), Value::Indexable(IndexableValue::Int(1))), ("id".to_string(), Value::Indexable(IndexableValue::Int(1))),
("data".to_string(), Value::Indexable(IndexableValue::String("Text".to_string()))) (
]); "data".to_string(),
Value::Indexable(IndexableValue::String("Text".to_string()))
)
]
);
} }
_ => { _ => {
unreachable!() unreachable!()
@ -77,16 +88,22 @@ mod tests {
#[test] #[test]
fn test_parse_insert_with_spaces() { fn test_parse_insert_with_spaces() {
let sql = "INSERT INTO \"MyTable\" ( id, data ) VALUES ( 1, \"Text\" ) ;"; let sql =
"INSERT INTO \"MyTable\" ( id, data ) VALUES ( 1, \"Text\" ) ;";
let operation = parse_insert(sql).expect("should parse"); let operation = parse_insert(sql).expect("should parse");
match operation { match operation {
("", RawQuerySyntax::Insert(table_name, insertion_values)) => { ("", RawQuerySyntax::Insert(table_name, insertion_values)) => {
assert_eq!(table_name, "MyTable"); assert_eq!(table_name, "MyTable");
assert_eq!(insertion_values, assert_eq!(
insertion_values,
vec![ vec![
("id".to_string(), Value::Indexable(IndexableValue::Int(1))), ("id".to_string(), Value::Indexable(IndexableValue::Int(1))),
("data".to_string(), Value::Indexable(IndexableValue::String("Text".to_string()))) (
]); "data".to_string(),
Value::Indexable(IndexableValue::String("Text".to_string()))
)
]
);
} }
_ => { _ => {
unreachable!() unreachable!()

View file

@ -1,20 +1,16 @@
use minisql::type_system::{IndexableValue, Value}; use minisql::type_system::{IndexableValue, Value};
use nom::{ use nom::{
branch::alt, branch::alt,
character::complete::{u64, char, digit1, none_of}, character::complete::{char, digit1, none_of, u64},
combinator::opt, combinator::opt,
error::make_error,
multi::many0, multi::many0,
sequence::{delimited, pair, preceded}, sequence::{delimited, pair, preceded},
IResult, error::make_error IResult,
}; };
pub fn parse_db_value(input: &str) -> IResult<&str, Value> { pub fn parse_db_value(input: &str) -> IResult<&str, Value> {
alt(( alt((parse_string, parse_number, parse_int, parse_uuid))(input)
parse_string,
parse_number,
parse_int,
parse_uuid,
))(input)
} }
pub fn parse_number(input: &str) -> IResult<&str, Value> { pub fn parse_number(input: &str) -> IResult<&str, Value> {
@ -27,56 +23,47 @@ pub fn parse_number(input: &str) -> IResult<&str, Value> {
match frac_part { match frac_part {
Some((_fsign, fdigits)) => { Some((_fsign, fdigits)) => {
// Combine integer and fractional parts // Combine integer and fractional parts
let combined_parts = format!( let combined_parts = format!("{}{}.{}", sign.unwrap_or('+'), digits, fdigits);
"{}{}.{}",
sign.unwrap_or('+'),
digits,
fdigits
);
// Parse the combined parts as a floating-point number // Parse the combined parts as a floating-point number
let value = combined_parts.parse::<f64>() let value = combined_parts
.map_err(|_| { .parse::<f64>()
nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)) .map_err(|_| nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)))?;
})?;
Ok((input, Value::Number(value))) Ok((input, Value::Number(value)))
} }
None => { None => {
let value = format!("{}{}", sign.unwrap_or('+'), digits).parse::<u64>() let value = format!("{}{}", sign.unwrap_or('+'), digits)
.map_err(|_| { .parse::<u64>()
nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)) .map_err(|_| nom::Err::Failure(make_error(input, nom::error::ErrorKind::Fail)))?;
})?;
Ok((input, Value::Indexable(IndexableValue::Int(value)))) Ok((input, Value::Indexable(IndexableValue::Int(value))))
} }
} }
} }
pub fn parse_int(input: &str) -> IResult<&str, Value> { pub fn parse_int(input: &str) -> IResult<&str, Value> {
u64(input).map(|(input, v)| { u64(input).map(|(input, v)| (input, Value::Indexable(IndexableValue::Int(v))))
(input, Value::Indexable(IndexableValue::Int(v)))
})
} }
fn escape_tab(input:&str) -> IResult<&str, char> { fn escape_tab(input: &str) -> IResult<&str, char> {
let (input, _) = preceded(char('\\'), char('t'))(input)?; let (input, _) = preceded(char('\\'), char('t'))(input)?;
Ok((input, '\t')) Ok((input, '\t'))
} }
fn escape_backslash(input:&str) -> IResult<&str, char> { fn escape_backslash(input: &str) -> IResult<&str, char> {
let (input, _) = preceded(char('\\'), char('\\'))(input)?; let (input, _) = preceded(char('\\'), char('\\'))(input)?;
Ok((input, '\\')) Ok((input, '\\'))
} }
fn escape_newline(input:&str) -> IResult<&str, char> { fn escape_newline(input: &str) -> IResult<&str, char> {
let (input, _) = preceded(char('\\'), char('n'))(input)?; let (input, _) = preceded(char('\\'), char('n'))(input)?;
Ok((input, '\n')) Ok((input, '\n'))
} }
fn escape_carriegereturn(input:&str) -> IResult<&str, char> { fn escape_carriegereturn(input: &str) -> IResult<&str, char> {
let (input, _) = preceded(char('\\'), char('r'))(input)?; let (input, _) = preceded(char('\\'), char('r'))(input)?;
Ok((input, '\r')) Ok((input, '\r'))
} }
fn escape_doublequote(input:&str) -> IResult<&str, char> { fn escape_doublequote(input: &str) -> IResult<&str, char> {
preceded(char('\\'), char('"'))(input) preceded(char('\\'), char('"'))(input)
} }
@ -90,7 +77,7 @@ pub fn parse_string(input: &str) -> IResult<&str, Value> {
escape_newline, escape_newline,
escape_doublequote, escape_doublequote,
escape_tab, escape_tab,
none_of(r#"\""#) none_of(r#"\""#),
))), ))),
char('"'), char('"'),
)(input)?; )(input)?;
@ -102,23 +89,39 @@ pub fn parse_string(input: &str) -> IResult<&str, Value> {
} }
pub fn parse_uuid(input: &str) -> IResult<&str, Value> { pub fn parse_uuid(input: &str) -> IResult<&str, Value> {
let (input, value) = pair(char('u'), u64)(input).map(|(input, (_, v))| { let (input, value) = pair(char('u'), u64)(input)
(input, Value::Indexable(IndexableValue::Uuid(v))) .map(|(input, (_, v))| (input, Value::Indexable(IndexableValue::Uuid(v))))?;
})?;
Ok((input, value)) Ok((input, value))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use minisql::type_system::{IndexableValue, Value};
use crate::parsing::literal::{parse_db_value, parse_string, parse_uuid}; use crate::parsing::literal::{parse_db_value, parse_string, parse_uuid};
use minisql::type_system::{IndexableValue, Value};
#[test] #[test]
fn test_string_parser() { fn test_string_parser() {
assert_eq!(parse_string(r#""simple""#), Ok(("", Value::Indexable(IndexableValue::String(String::from("simple")))))); assert_eq!(
assert_eq!(parse_string(r#""\"\t\r\n\\""#), Ok(("", Value::Indexable(IndexableValue::String(String::from("\"\t\r\n\\")))))); parse_string(r#""simple""#),
assert_eq!(parse_string(r#""name is \"John\".""#), Ok(("", Value::Indexable(IndexableValue::String(String::from("name is \"John\".")))))); Ok((
"",
Value::Indexable(IndexableValue::String(String::from("simple")))
))
);
assert_eq!(
parse_string(r#""\"\t\r\n\\""#),
Ok((
"",
Value::Indexable(IndexableValue::String(String::from("\"\t\r\n\\")))
))
);
assert_eq!(
parse_string(r#""name is \"John\".""#),
Ok((
"",
Value::Indexable(IndexableValue::String(String::from("name is \"John\".")))
))
);
} }
#[test] #[test]
@ -132,39 +135,63 @@ mod tests {
assert_eq!(value, Value::Number(5.5)); assert_eq!(value, Value::Number(5.5));
let (_, _) = parse_db_value("\"STRING\"").expect("should parse"); let (_, _) = parse_db_value("\"STRING\"").expect("should parse");
let (input, value) = parse_db_value("\"abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ \"").expect("should parse"); let (input, value) =
parse_db_value("\"abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ \"").expect("should parse");
assert_eq!(input, ""); assert_eq!(input, "");
assert_eq!(value, Value::Indexable(IndexableValue::String("abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ ".to_string()))); assert_eq!(
value,
Value::Indexable(IndexableValue::String(
"abcdefghkjklmnopqrstuvwxyz!@#$%^&*()_+ ".to_string()
))
);
} }
#[test] #[test]
fn test_parse_positive_float() { fn test_parse_positive_float() {
assert_eq!(parse_db_value("23.213313"), Ok(("", Value::Number(23.213313)))); assert_eq!(
assert_eq!(parse_db_value("2241.9734"), Ok(("", Value::Number(2241.9734)))); parse_db_value("23.213313"),
Ok(("", Value::Number(23.213313)))
);
assert_eq!(
parse_db_value("2241.9734"),
Ok(("", Value::Number(2241.9734)))
);
} }
#[test] #[test]
fn test_parse_negative_float() { fn test_parse_negative_float() {
assert_eq!(parse_db_value("-9241.873654"), Ok(("", Value::Number(-9241.873654)))); assert_eq!(
assert_eq!(parse_db_value("-62625.0"), Ok(("", Value::Number(-62625.0)))); parse_db_value("-9241.873654"),
Ok(("", Value::Number(-9241.873654)))
);
assert_eq!(
parse_db_value("-62625.0"),
Ok(("", Value::Number(-62625.0)))
);
} }
#[test] #[test]
fn test_parse_float_between_0_and_1() { fn test_parse_float_between_0_and_1() {
assert_eq!(parse_db_value("0.873654"), Ok(("", Value::Number(0.873654)))); assert_eq!(
parse_db_value("0.873654"),
Ok(("", Value::Number(0.873654)))
);
assert_eq!(parse_db_value("0.62625"), Ok(("", Value::Number(0.62625)))); assert_eq!(parse_db_value("0.62625"), Ok(("", Value::Number(0.62625))));
} }
#[test] #[test]
fn test_parse_int() { fn test_parse_int() {
assert_eq!(parse_db_value("5134616"), Ok(("", Value::Indexable(IndexableValue::Int(5134616))))); assert_eq!(
parse_db_value("5134616"),
Ok(("", Value::Indexable(IndexableValue::Int(5134616))))
);
} }
#[test] #[test]
fn test_parse_uuid() { fn test_parse_uuid() {
assert_eq!(parse_uuid("u131515"), Ok(("", Value::Indexable(IndexableValue::Uuid(131515))))) assert_eq!(
parse_uuid("u131515"),
Ok(("", Value::Indexable(IndexableValue::Uuid(131515))))
)
} }
} }

View file

@ -1,7 +1,7 @@
pub(crate) mod literal;
pub(crate) mod select;
pub(crate) mod common; pub(crate) mod common;
pub(crate) mod create; pub(crate) mod create;
pub(crate) mod insert;
pub(crate) mod delete; pub(crate) mod delete;
pub(crate) mod index; pub(crate) mod index;
pub(crate) mod insert;
pub(crate) mod literal;
pub(crate) mod select;

View file

@ -1,9 +1,9 @@
use super::common::{parse_table_name, parse_column_name, parse_condition}; use super::common::{parse_column_name, parse_condition, parse_table_name};
use crate::syntax::{ColumnSelection, RawQuerySyntax}; use crate::syntax::{ColumnSelection, RawQuerySyntax};
use nom::{ use nom::{
branch::alt, branch::alt,
bytes::complete::tag, bytes::complete::tag,
character::complete::{multispace0, multispace1, char}, character::complete::{char, multispace0, multispace1},
combinator::map, combinator::map,
error::Error, error::Error,
multi::separated_list0, multi::separated_list0,
@ -44,10 +44,12 @@ pub fn try_parse_column_selection(input: &str) -> IResult<&str, ColumnSelection>
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::parsing::{
common::{parse_column_name, parse_table_name},
select::parse_select,
};
use crate::syntax::{ColumnSelection, RawQuerySyntax}; use crate::syntax::{ColumnSelection, RawQuerySyntax};
use crate::parsing::{common::{parse_column_name, parse_table_name}, select::parse_select};
#[test] #[test]
fn test_parse_select_all() { fn test_parse_select_all() {
let sql = "SELECT * FROM \"MyTable\";"; let sql = "SELECT * FROM \"MyTable\";";

View file

@ -1,4 +1,7 @@
use minisql::{type_system::{Value, DbType}, schema::{ColumnName, TableName}}; use minisql::{
schema::{ColumnName, TableName},
type_system::{DbType, Value},
};
// ===Table Schema=== // ===Table Schema===
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -53,10 +56,16 @@ impl RawTableSchema {
} }
pub fn get_column(&self, column_name: &ColumnName) -> Option<ColumnSchema> { pub fn get_column(&self, column_name: &ColumnName) -> Option<ColumnSchema> {
self.columns.iter().find(|column_schema| column_name == &column_schema.column_name).cloned() self.columns
.iter()
.find(|column_schema| column_name == &column_schema.column_name)
.cloned()
} }
pub fn get_columns(&self) -> Vec<&ColumnName> { pub fn get_columns(&self) -> Vec<&ColumnName> {
self.columns.iter().map(|ColumnSchema { column_name, .. }| column_name).collect() self.columns
.iter()
.map(|ColumnSchema { column_name, .. }| column_name)
.collect()
} }
} }

View file

@ -1,10 +1,16 @@
use std::collections::{HashSet, BTreeMap}; use std::collections::{BTreeMap, HashSet};
use thiserror::Error; use thiserror::Error;
use crate::syntax; use crate::syntax;
use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax}; use crate::syntax::{ColumnSchema, RawQuerySyntax, RawTableSchema};
use minisql::operation; use minisql::operation;
use minisql::{operation::Operation, type_system::Value, schema::{TableSchema, ColumnName, Column, TableName, TablePosition}, type_system::DbType, interpreter::DbSchema}; use minisql::{
interpreter::DbSchema,
operation::Operation,
schema::{Column, ColumnName, TableName, TablePosition, TableSchema},
type_system::DbType,
type_system::Value,
};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ValidationError { pub enum ValidationError {
@ -29,37 +35,46 @@ pub enum ValidationError {
expected_type: DbType, expected_type: DbType,
}, },
#[error("values for required columns {0:?} are missing")] #[error("values for required columns {0:?} are missing")]
RequiredColumnsAreMissing(Vec<ColumnName>) RequiredColumnsAreMissing(Vec<ColumnName>),
} }
/// Validates and converts the raw syntax into a proper interpreter operation based on db schema. /// Validates and converts the raw syntax into a proper interpreter operation based on db schema.
pub fn validate_operation(syntax: RawQuerySyntax, db_schema: &DbSchema) -> Result<Operation, ValidationError> { pub fn validate_operation(
syntax: RawQuerySyntax,
db_schema: &DbSchema,
) -> Result<Operation, ValidationError> {
match syntax { match syntax {
RawQuerySyntax::Select(table_name, column_selection, condition) => { RawQuerySyntax::Select(table_name, column_selection, condition) => {
validate_select(table_name, column_selection, condition, db_schema) validate_select(table_name, column_selection, condition, db_schema)
}, }
RawQuerySyntax::Insert(table_name, insertion_values) => { RawQuerySyntax::Insert(table_name, insertion_values) => {
validate_insert(table_name, insertion_values, db_schema) validate_insert(table_name, insertion_values, db_schema)
}, }
RawQuerySyntax::Delete(table_name, condition) => { RawQuerySyntax::Delete(table_name, condition) => {
validate_delete(table_name, condition, db_schema) validate_delete(table_name, condition, db_schema)
}, }
RawQuerySyntax::CreateTable(schema) => { RawQuerySyntax::CreateTable(schema) => validate_create_table(schema, db_schema),
validate_create_table(schema, db_schema)
},
RawQuerySyntax::CreateIndex(table_name, column_name) => { RawQuerySyntax::CreateIndex(table_name, column_name) => {
validate_create_index(table_name, column_name, db_schema) validate_create_index(table_name, column_name, db_schema)
}, }
} }
} }
fn validate_table_exists<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName) -> Result<(TablePosition, &'a TableSchema), ValidationError> { fn validate_table_exists<'a>(
db_schema.iter().find(|(tname, _, _)| table_name.eq(tname)) db_schema: &DbSchema<'a>,
table_name: &'a TableName,
) -> Result<(TablePosition, &'a TableSchema), ValidationError> {
db_schema
.iter()
.find(|(tname, _, _)| table_name.eq(tname))
.ok_or(ValidationError::TableDoesNotExist(table_name.to_string())) .ok_or(ValidationError::TableDoesNotExist(table_name.to_string()))
.map(|(_, table_position, table_schema)| (*table_position, *table_schema)) .map(|(_, table_position, table_schema)| (*table_position, *table_schema))
} }
fn validate_create_table(raw_table_schema: RawTableSchema, db_schema: &DbSchema) -> Result<Operation, ValidationError> { fn validate_create_table(
raw_table_schema: RawTableSchema,
db_schema: &DbSchema,
) -> Result<Operation, ValidationError> {
let table_name: &TableName = &raw_table_schema.table_name; let table_name: &TableName = &raw_table_schema.table_name;
if get_table_schema(db_schema, table_name).is_some() { if get_table_schema(db_schema, table_name).is_some() {
return Err(ValidationError::TableAlreadyExists(table_name.to_string())); return Err(ValidationError::TableAlreadyExists(table_name.to_string()));
@ -71,16 +86,24 @@ fn validate_create_table(raw_table_schema: RawTableSchema, db_schema: &DbSchema)
fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result<TableSchema, ValidationError> { fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result<TableSchema, ValidationError> {
// check for duplicate columns // check for duplicate columns
find_first_duplicate(&raw_table_schema.get_columns()) find_first_duplicate(&raw_table_schema.get_columns()).map_or_else(
.map_or_else( || Ok(()),
|| Ok(()), |duplicate_column| {
|duplicate_column| Err(ValidationError::DuplicateColumn(duplicate_column.to_string())) Err(ValidationError::DuplicateColumn(
)?; duplicate_column.to_string(),
))
},
)?;
let mut primary_keys: Vec<(ColumnName, DbType)> = vec![]; let mut primary_keys: Vec<(ColumnName, DbType)> = vec![];
let mut columns: Vec<ColumnName> = vec![]; let mut columns: Vec<ColumnName> = vec![];
let mut types: Vec<DbType> = vec![]; let mut types: Vec<DbType> = vec![];
for ColumnSchema { column_name, type_, is_primary } in raw_table_schema.columns { for ColumnSchema {
column_name,
type_,
is_primary,
} in raw_table_schema.columns
{
if is_primary { if is_primary {
primary_keys.push((column_name.clone(), type_)) primary_keys.push((column_name.clone(), type_))
} }
@ -91,13 +114,22 @@ fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result<TableSchema
// Ensure it has exactly one primary key that has correct type. // Ensure it has exactly one primary key that has correct type.
let number_of_primary_keys = primary_keys.len(); let number_of_primary_keys = primary_keys.len();
if number_of_primary_keys == 0 { if number_of_primary_keys == 0 {
Err(ValidationError::PrimaryKeyMissing(raw_table_schema.table_name.clone())) Err(ValidationError::PrimaryKeyMissing(
raw_table_schema.table_name.clone(),
))
} else if number_of_primary_keys > 1 { } else if number_of_primary_keys > 1 {
Err(ValidationError::MultiplePrimaryKeysFound(raw_table_schema.table_name.clone())) Err(ValidationError::MultiplePrimaryKeysFound(
raw_table_schema.table_name.clone(),
))
} else { } else {
let (primary_column_name, primary_key_type) = primary_keys[0].clone(); let (primary_column_name, primary_key_type) = primary_keys[0].clone();
if primary_key_type == DbType::Uuid { if primary_key_type == DbType::Uuid {
Ok(TableSchema::new(raw_table_schema.table_name, primary_column_name, columns, types)) Ok(TableSchema::new(
raw_table_schema.table_name,
primary_column_name,
columns,
types,
))
} else { } else {
Err(ValidationError::TypeMismatch { Err(ValidationError::TypeMismatch {
column_name: raw_table_schema.table_name.clone(), column_name: raw_table_schema.table_name.clone(),
@ -108,121 +140,189 @@ fn validate_table_schema(raw_table_schema: RawTableSchema) -> Result<TableSchema
} }
} }
fn validate_select(table_name: TableName, column_selection: syntax::ColumnSelection, condition: Option<syntax::Condition>, db_schema: &DbSchema) -> Result<Operation, ValidationError> { fn validate_select(
table_name: TableName,
column_selection: syntax::ColumnSelection,
condition: Option<syntax::Condition>,
db_schema: &DbSchema,
) -> Result<Operation, ValidationError> {
let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; let (table_position, schema) = validate_table_exists(db_schema, &table_name)?;
match column_selection { match column_selection {
syntax::ColumnSelection::Columns(columns) => { syntax::ColumnSelection::Columns(columns) => {
let non_existant_columns: Vec<ColumnName> = let non_existant_columns: Vec<ColumnName> = columns
columns.iter().filter_map(|column| .iter()
.filter_map(|column| {
if schema.does_column_exist(column) { if schema.does_column_exist(column) {
None None
} else { } else {
Some(column.clone()) Some(column.clone())
}).collect(); }
})
.collect();
if non_existant_columns.is_empty() { if non_existant_columns.is_empty() {
let selection: operation::ColumnSelection = let selection: operation::ColumnSelection = columns
columns.iter().filter_map(|column_name| schema.get_column(column_name)).collect(); .iter()
.filter_map(|column_name| schema.get_column(column_name))
.collect();
let validated_condition = validate_condition(condition, schema)?; let validated_condition = validate_condition(condition, schema)?;
Ok(Operation::Select(table_position, selection, validated_condition)) Ok(Operation::Select(
table_position,
selection,
validated_condition,
))
} else { } else {
Err(ValidationError::ColumnsDoNotExist(non_existant_columns)) Err(ValidationError::ColumnsDoNotExist(non_existant_columns))
} }
} }
syntax::ColumnSelection::All => { syntax::ColumnSelection::All => {
let validated_condition = validate_condition(condition, schema)?; let validated_condition = validate_condition(condition, schema)?;
Ok(Operation::Select(table_position, schema.all_selection(), validated_condition)) Ok(Operation::Select(
table_position,
schema.all_selection(),
validated_condition,
))
} }
} }
} }
fn validate_insert(table_name: TableName, insertion_values: syntax::InsertionValues, db_schema: &DbSchema) -> Result<Operation, ValidationError> { fn validate_insert(
table_name: TableName,
insertion_values: syntax::InsertionValues,
db_schema: &DbSchema,
) -> Result<Operation, ValidationError> {
let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; let (table_position, schema) = validate_table_exists(db_schema, &table_name)?;
// Check for duplicate columns in insertion_values. // Check for duplicate columns in insertion_values.
let columns_in_query_vec: Vec<&ColumnName> = insertion_values.iter().map(|(column_name, _)| column_name).collect(); let columns_in_query_vec: Vec<&ColumnName> = insertion_values
find_first_duplicate(&columns_in_query_vec) .iter()
.map_or_else( .map(|(column_name, _)| column_name)
|| Ok(()), .collect();
|duplicate_column| Err(ValidationError::DuplicateColumn(duplicate_column.to_string())) find_first_duplicate(&columns_in_query_vec).map_or_else(
)?; || Ok(()),
|duplicate_column| {
Err(ValidationError::DuplicateColumn(
duplicate_column.to_string(),
))
},
)?;
// Check that the set of columns in the insertion_values is the same as the set of required columns of the table. // Check that the set of columns in the insertion_values is the same as the set of required columns of the table.
let columns_in_query: HashSet<&ColumnName> = HashSet::from_iter(columns_in_query_vec); let columns_in_query: HashSet<&ColumnName> = HashSet::from_iter(columns_in_query_vec);
let columns_in_schema: HashSet<&ColumnName> = HashSet::from_iter(schema.get_columns()); let columns_in_schema: HashSet<&ColumnName> = HashSet::from_iter(schema.get_columns());
let non_existant_columns = Vec::from_iter(columns_in_query.difference(&columns_in_schema)); let non_existant_columns = Vec::from_iter(columns_in_query.difference(&columns_in_schema));
if !non_existant_columns.is_empty() { if !non_existant_columns.is_empty() {
return Err(ValidationError::ColumnsDoNotExist(non_existant_columns.iter().map(|column_name| column_name.to_string()).collect())); return Err(ValidationError::ColumnsDoNotExist(
non_existant_columns
.iter()
.map(|column_name| column_name.to_string())
.collect(),
));
} }
let missing_required_columns = Vec::from_iter(columns_in_schema.difference(&columns_in_query)); let missing_required_columns = Vec::from_iter(columns_in_schema.difference(&columns_in_query));
if !missing_required_columns.is_empty() { if !missing_required_columns.is_empty() {
return Err(ValidationError::RequiredColumnsAreMissing(missing_required_columns.iter().map(|str| str.to_string()).collect())); return Err(ValidationError::RequiredColumnsAreMissing(
missing_required_columns
.iter()
.map(|str| str.to_string())
.collect(),
));
} }
// Check types and prepare for creation of InsertionValues for the interpreter // Check types and prepare for creation of InsertionValues for the interpreter
let mut values_map: BTreeMap<Column, Value> = BTreeMap::new(); // The reason for using BTreeMap let mut values_map: BTreeMap<Column, Value> = BTreeMap::new(); // The reason for using BTreeMap
// instead of HashMap is that we need // instead of HashMap is that we need
// to get the values in a vector // to get the values in a vector
// sorted by the key. // sorted by the key.
for (column_name, value) in insertion_values { for (column_name, value) in insertion_values {
let (column, expected_type) = schema.get_typed_column(&column_name).ok_or(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]))?; // By the previous validation steps this is never gonna trigger an error. let (column, expected_type) =
schema
.get_typed_column(&column_name)
.ok_or(ValidationError::ColumnsDoNotExist(vec![
column_name.to_string()
]))?; // By the previous validation steps this is never gonna trigger an error.
let value_type = value.to_type(); let value_type = value.to_type();
if value_type != expected_type { if value_type != expected_type {
return Err(ValidationError::TypeMismatch { column_name: column_name.to_string(), received_type: value_type, expected_type }); return Err(ValidationError::TypeMismatch {
column_name: column_name.to_string(),
received_type: value_type,
expected_type,
});
} }
values_map.insert(column, value); values_map.insert(column, value);
} }
// WARNING: If you use `values_map: HashMap<_,_>`, this is not gonna sort values by key. // WARNING: If you use `values_map: HashMap<_,_>`, this is not gonna sort values by key.
let values: operation::InsertionValues = values_map.into_values().collect(); let values: operation::InsertionValues = values_map.into_values().collect();
// Note that one of the values is id. // Note that one of the values is id.
Ok(Operation::Insert(table_position, values)) Ok(Operation::Insert(table_position, values))
} }
fn validate_delete(table_name: TableName, condition: Option<syntax::Condition>, db_schema: &DbSchema) -> Result<Operation, ValidationError> { fn validate_delete(
table_name: TableName,
condition: Option<syntax::Condition>,
db_schema: &DbSchema,
) -> Result<Operation, ValidationError> {
let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; let (table_position, schema) = validate_table_exists(db_schema, &table_name)?;
let validated_condition = validate_condition(condition, schema)?; let validated_condition = validate_condition(condition, schema)?;
Ok(Operation::Delete(table_position, validated_condition)) Ok(Operation::Delete(table_position, validated_condition))
} }
fn validate_condition(condition: Option<syntax::Condition>, schema: &TableSchema) -> Result<Option<operation::Condition>, ValidationError> { fn validate_condition(
condition: Option<syntax::Condition>,
schema: &TableSchema,
) -> Result<Option<operation::Condition>, ValidationError> {
match condition { match condition {
Some(condition) => { Some(condition) => match condition {
match condition { syntax::Condition::Eq(column_name, value) => {
syntax::Condition::Eq(column_name, value) => { let (column, expected_type) = schema.get_typed_column(&column_name).ok_or(
let (column, expected_type) = schema.get_typed_column(&column_name).ok_or(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]))?; ValidationError::ColumnsDoNotExist(vec![column_name.to_string()]),
let value_type: DbType = value.to_type(); )?;
if expected_type.eq(&value_type) { let value_type: DbType = value.to_type();
Ok(Some(operation::Condition::Eq(column, value))) if expected_type.eq(&value_type) {
} else { Ok(Some(operation::Condition::Eq(column, value)))
Err(ValidationError::TypeMismatch { column_name: column_name.to_string(), received_type: value_type, expected_type }) } else {
} Err(ValidationError::TypeMismatch {
column_name: column_name.to_string(),
received_type: value_type,
expected_type,
})
} }
} }
} },
None => Ok(None) None => Ok(None),
} }
} }
fn validate_create_index(table_name: TableName, column_name: ColumnName, db_schema: &DbSchema) -> Result<Operation, ValidationError> { fn validate_create_index(
table_name: TableName,
column_name: ColumnName,
db_schema: &DbSchema,
) -> Result<Operation, ValidationError> {
let (table_position, schema) = validate_table_exists(db_schema, &table_name)?; let (table_position, schema) = validate_table_exists(db_schema, &table_name)?;
schema schema.get_typed_column(&column_name).map_or_else(
.get_typed_column(&column_name) || {
.map_or_else( Err(ValidationError::ColumnsDoNotExist(vec![
|| Err(ValidationError::ColumnsDoNotExist(vec![column_name.to_string()])), column_name.to_string()
|(column, type_)| { ]))
if type_.is_indexable() { },
Ok(Operation::CreateIndex(table_position, column)) |(column, type_)| {
} else { if type_.is_indexable() {
Err(ValidationError::AttemptToIndexNonIndexableColumn(column_name.clone(), table_name)) Ok(Operation::CreateIndex(table_position, column))
} } else {
Err(ValidationError::AttemptToIndexNonIndexableColumn(
column_name.clone(),
table_name,
))
} }
) },
)
} }
// ===Helpers=== // ===Helpers===
fn find_first_duplicate<T>(ts: &[T]) -> Option<&T> fn find_first_duplicate<T>(ts: &[T]) -> Option<&T>
where T: Eq + std::hash::Hash where
T: Eq + std::hash::Hash,
{ {
let mut already_seen_elements: HashSet<&T> = HashSet::new(); let mut already_seen_elements: HashSet<&T> = HashSet::new();
for t in ts { for t in ts {
@ -235,34 +335,35 @@ where T: Eq + std::hash::Hash
None None
} }
fn get_table_schema<'a>(db_schema: &DbSchema<'a>, table_name: &'a TableName) -> Option<&'a TableSchema> { fn get_table_schema<'a>(
let (_, _, table_schema) = db_schema.iter().find(|(tname, _, _)| table_name.eq(tname))?; db_schema: &DbSchema<'a>,
table_name: &'a TableName,
) -> Option<&'a TableSchema> {
let (_, _, table_schema) = db_schema
.iter()
.find(|(tname, _, _)| table_name.eq(tname))?;
Some(table_schema) Some(table_schema)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::syntax::{RawTableSchema, ColumnSchema, RawQuerySyntax, ColumnSelection, Condition};
use minisql::type_system::{Value, IndexableValue};
use minisql::operation::Operation;
use minisql::operation;
use minisql::schema::TableSchema;
use super::*; use super::*;
use crate::syntax::{ColumnSchema, ColumnSelection, Condition, RawQuerySyntax, RawTableSchema};
use minisql::operation;
use minisql::operation::Operation;
use minisql::schema::TableSchema;
use minisql::type_system::{IndexableValue, Value};
use Condition::*;
use IndexableValue::*;
use RawQuerySyntax::*; use RawQuerySyntax::*;
use Value::*; use Value::*;
use IndexableValue::*;
use Condition::*;
fn users_schema() -> TableSchema { fn users_schema() -> TableSchema {
TableSchema::new( TableSchema::new(
"users".to_string(), "users".to_string(),
"id".to_string(), "id".to_string(),
vec!( vec!["id".to_string(), "name".to_string(), "age".to_string()],
"id".to_string(),
"name".to_string(),
"age".to_string(),
),
vec![DbType::Uuid, DbType::String, DbType::Int], vec![DbType::Uuid, DbType::String, DbType::Int],
) )
} }
@ -271,18 +372,27 @@ mod tests {
RawTableSchema { RawTableSchema {
table_name: "users".to_string(), table_name: "users".to_string(),
columns: vec![ columns: vec![
ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, ColumnSchema {
ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, column_name: "id".to_string(),
ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, type_: DbType::Uuid,
is_primary: true,
},
ColumnSchema {
column_name: "name".to_string(),
type_: DbType::String,
is_primary: false,
},
ColumnSchema {
column_name: "age".to_string(),
type_: DbType::Int,
is_primary: false,
},
], ],
} }
} }
fn db_schema(users_schema: &TableSchema) -> DbSchema { fn db_schema(users_schema: &TableSchema) -> DbSchema {
vec![ vec![("users".to_string(), 0, users_schema)]
("users".to_string(), 0, users_schema),
]
} }
fn empty_db_schema() -> DbSchema<'static> { fn empty_db_schema() -> DbSchema<'static> {
@ -297,7 +407,9 @@ mod tests {
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::CreateTable(_)))); assert!(matches!(result, Ok(Operation::CreateTable(_))));
let Ok(Operation::CreateTable(schema)) = result else { panic!() }; let Ok(Operation::CreateTable(schema)) = result else {
panic!()
};
assert!(schema.table_name() == "users"); assert!(schema.table_name() == "users");
} }
@ -306,9 +418,21 @@ mod tests {
let raw_users_schema = RawTableSchema { let raw_users_schema = RawTableSchema {
table_name: "users".to_string(), table_name: "users".to_string(),
columns: vec![ columns: vec![
ColumnSchema { column_name: "id".to_string(), type_: DbType::Uuid, is_primary: true }, ColumnSchema {
ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, column_name: "id".to_string(),
ColumnSchema { column_name: "name".to_string(), type_: DbType::Number, is_primary: false }, type_: DbType::Uuid,
is_primary: true,
},
ColumnSchema {
column_name: "name".to_string(),
type_: DbType::String,
is_primary: false,
},
ColumnSchema {
column_name: "name".to_string(),
type_: DbType::Number,
is_primary: false,
},
], ],
}; };
@ -325,9 +449,21 @@ mod tests {
let raw_users_schema = RawTableSchema { let raw_users_schema = RawTableSchema {
table_name: "users".to_string(), table_name: "users".to_string(),
columns: vec![ columns: vec![
ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, ColumnSchema {
ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: false }, column_name: "id".to_string(),
ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, type_: DbType::Int,
is_primary: true,
},
ColumnSchema {
column_name: "name".to_string(),
type_: DbType::String,
is_primary: false,
},
ColumnSchema {
column_name: "age".to_string(),
type_: DbType::Int,
is_primary: false,
},
], ],
}; };
@ -343,9 +479,21 @@ mod tests {
let raw_users_schema = RawTableSchema { let raw_users_schema = RawTableSchema {
table_name: "users".to_string(), table_name: "users".to_string(),
columns: vec![ columns: vec![
ColumnSchema { column_name: "id".to_string(), type_: DbType::Int, is_primary: true }, ColumnSchema {
ColumnSchema { column_name: "name".to_string(), type_: DbType::String, is_primary: true }, column_name: "id".to_string(),
ColumnSchema { column_name: "age".to_string(), type_: DbType::Int, is_primary: false }, type_: DbType::Int,
is_primary: true,
},
ColumnSchema {
column_name: "name".to_string(),
type_: DbType::String,
is_primary: true,
},
ColumnSchema {
column_name: "age".to_string(),
type_: DbType::Int,
is_primary: false,
},
], ],
}; };
@ -353,7 +501,10 @@ mod tests {
let syntax: RawQuerySyntax = CreateTable(raw_users_schema); let syntax: RawQuerySyntax = CreateTable(raw_users_schema);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::MultiplePrimaryKeysFound(_)))); assert!(matches!(
result,
Err(ValidationError::MultiplePrimaryKeysFound(_))
));
} }
#[test] #[test]
@ -363,7 +514,10 @@ mod tests {
let syntax: RawQuerySyntax = CreateTable(raw_users_schema()); let syntax: RawQuerySyntax = CreateTable(raw_users_schema());
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::TableAlreadyExists(_)))); assert!(matches!(
result,
Err(ValidationError::TableAlreadyExists(_))
));
} }
// ====Select==== // ====Select====
@ -380,7 +534,9 @@ mod tests {
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::Select(_, _, _)))); assert!(matches!(result, Ok(Operation::Select(_, _, _))));
let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; let Ok(Operation::Select(table_position, column_selection, condition)) = result else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
assert!(condition == None); assert!(condition == None);
@ -392,7 +548,8 @@ mod tests {
let users_schema: TableSchema = users_schema(); let users_schema: TableSchema = users_schema();
let db_schema: DbSchema = db_schema(&users_schema); let db_schema: DbSchema = db_schema(&users_schema);
let syntax: RawQuerySyntax = Select("does_not_exist".to_string(), ColumnSelection::All, None); let syntax: RawQuerySyntax =
Select("does_not_exist".to_string(), ColumnSelection::All, None);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::TableDoesNotExist(_)))); assert!(matches!(result, Err(ValidationError::TableDoesNotExist(_))));
} }
@ -407,11 +564,17 @@ mod tests {
let name = 1; let name = 1;
let age = 2; let age = 2;
let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("age".to_string(), Indexable(Int(25))))); let syntax: RawQuerySyntax = Select(
"users".to_string(),
ColumnSelection::All,
Some(Eq("age".to_string(), Indexable(Int(25)))),
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::Select(_, _, _)))); assert!(matches!(result, Ok(Operation::Select(_, _, _))));
let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; let Ok(Operation::Select(table_position, column_selection, condition)) = result else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
assert!(column_selection == vec![id, name, age]); assert!(column_selection == vec![id, name, age]);
@ -428,11 +591,21 @@ mod tests {
let name = 1; let name = 1;
let age = 2; let age = 2;
let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::Columns(vec!["age".to_string(), "name".to_string(), "age".to_string()]), None); let syntax: RawQuerySyntax = Select(
"users".to_string(),
ColumnSelection::Columns(vec![
"age".to_string(),
"name".to_string(),
"age".to_string(),
]),
None,
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::Select(_, _, _)))); assert!(matches!(result, Ok(Operation::Select(_, _, _))));
let Ok(Operation::Select(table_position, column_selection, condition)) = result else { panic!() }; let Ok(Operation::Select(table_position, column_selection, condition)) = result else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
assert!(column_selection == vec![age, name, age]); assert!(column_selection == vec![age, name, age]);
@ -444,7 +617,11 @@ mod tests {
let users_schema: TableSchema = users_schema(); let users_schema: TableSchema = users_schema();
let db_schema: DbSchema = db_schema(&users_schema); let db_schema: DbSchema = db_schema(&users_schema);
let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::Columns(vec!["age".to_string(), "does_not_exist".to_string()]), None); let syntax: RawQuerySyntax = Select(
"users".to_string(),
ColumnSelection::Columns(vec!["age".to_string(), "does_not_exist".to_string()]),
None,
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_))));
} }
@ -454,7 +631,11 @@ mod tests {
let users_schema: TableSchema = users_schema(); let users_schema: TableSchema = users_schema();
let db_schema: DbSchema = db_schema(&users_schema); let db_schema: DbSchema = db_schema(&users_schema);
let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("does_not_exist".to_string(), Indexable(Int(25))))); let syntax: RawQuerySyntax = Select(
"users".to_string(),
ColumnSelection::All,
Some(Eq("does_not_exist".to_string(), Indexable(Int(25)))),
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_))));
} }
@ -464,7 +645,11 @@ mod tests {
let users_schema: TableSchema = users_schema(); let users_schema: TableSchema = users_schema();
let db_schema: DbSchema = db_schema(&users_schema); let db_schema: DbSchema = db_schema(&users_schema);
let syntax: RawQuerySyntax = Select("users".to_string(), ColumnSelection::All, Some(Eq("age".to_string(), Indexable(String("25".to_string()))))); let syntax: RawQuerySyntax = Select(
"users".to_string(),
ColumnSelection::All,
Some(Eq("age".to_string(), Indexable(String("25".to_string())))),
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
} }
@ -483,18 +668,28 @@ mod tests {
("name".to_string(), Indexable(String("Alice".to_string()))), ("name".to_string(), Indexable(String("Alice".to_string()))),
("id".to_string(), Indexable(Uuid(0))), ("id".to_string(), Indexable(Uuid(0))),
("age".to_string(), Indexable(Int(25))), ("age".to_string(), Indexable(Int(25))),
]); ],
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::Insert(_, _)))); assert!(matches!(result, Ok(Operation::Insert(_, _))));
let Ok(Operation::Insert(table_position, values)) = result else { panic!() }; let Ok(Operation::Insert(table_position, values)) = result else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
// Recall the order is // Recall the order is
// let id = 0; // let id = 0;
// let name = 1; // let name = 1;
// let age = 2; // let age = 2;
assert!(values == vec![Indexable(Uuid(0)), Indexable(String("Alice".to_string())), Indexable(Int(25))]); assert!(
values
== vec![
Indexable(Uuid(0)),
Indexable(String("Alice".to_string())),
Indexable(Int(25))
]
);
} }
#[test] #[test]
@ -509,7 +704,8 @@ mod tests {
("id".to_string(), Indexable(Uuid(0))), ("id".to_string(), Indexable(Uuid(0))),
("age".to_string(), Indexable(Int(25))), ("age".to_string(), Indexable(Int(25))),
("does_not_exist".to_string(), Indexable(Int(25))), ("does_not_exist".to_string(), Indexable(Int(25))),
]); ],
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_)))); assert!(matches!(result, Err(ValidationError::ColumnsDoNotExist(_))));
} }
@ -525,7 +721,8 @@ mod tests {
("name".to_string(), Indexable(String("Alice".to_string()))), ("name".to_string(), Indexable(String("Alice".to_string()))),
("id".to_string(), Indexable(Uuid(0))), ("id".to_string(), Indexable(Uuid(0))),
("age".to_string(), Number(25.0)), ("age".to_string(), Number(25.0)),
]); ],
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
} }
@ -542,7 +739,9 @@ mod tests {
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::Delete(_, None)))); assert!(matches!(result, Ok(Operation::Delete(_, None))));
let Ok(Operation::Delete(table_position, _)) = result else { panic!() }; let Ok(Operation::Delete(table_position, _)) = result else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
} }
@ -555,11 +754,21 @@ mod tests {
let users_position = 0; let users_position = 0;
let age = 2; let age = 2;
let syntax: RawQuerySyntax = Delete("users".to_string(), Some(Eq("age".to_string(), Indexable(Int(25))))); let syntax: RawQuerySyntax = Delete(
"users".to_string(),
Some(Eq("age".to_string(), Indexable(Int(25)))),
);
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::Delete(_, Some(operation::Condition::Eq(_, _)))))); assert!(matches!(
result,
Ok(Operation::Delete(_, Some(operation::Condition::Eq(_, _))))
));
let Ok(Operation::Delete(table_position, Some(operation::Condition::Eq(column, value)))) = result else { panic!() }; let Ok(Operation::Delete(table_position, Some(operation::Condition::Eq(column, value)))) =
result
else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
assert!(column == age); assert!(column == age);
@ -579,7 +788,9 @@ mod tests {
let result = validate_operation(syntax, &db_schema); let result = validate_operation(syntax, &db_schema);
assert!(matches!(result, Ok(Operation::CreateIndex(_, _)))); assert!(matches!(result, Ok(Operation::CreateIndex(_, _))));
let Ok(Operation::CreateIndex(table_position, column)) = result else { panic!() }; let Ok(Operation::CreateIndex(table_position, column)) = result else {
panic!()
};
assert!(table_position == users_position); assert!(table_position == users_position);
assert!(column == age); assert!(column == age);

View file

@ -14,7 +14,6 @@ pub async fn do_client_handshake(
reader: &mut impl BackendProtoReader, reader: &mut impl BackendProtoReader,
request: HandshakeRequest, request: HandshakeRequest,
) -> Result<HandshakeResponse, ClientHandshakeError> { ) -> Result<HandshakeResponse, ClientHandshakeError> {
// Send StartupMessage without SSLRequest // Send StartupMessage without SSLRequest
let startup_message: StartupMessageData = request.into(); let startup_message: StartupMessageData = request.into();
writer.write_startup_message(startup_message).await?; writer.write_startup_message(startup_message).await?;

View file

@ -1,10 +1,10 @@
use crate::message::backend::BackendMessage; use crate::message::backend::BackendMessage;
use crate::message::errors::ProtoDeserializeError; use crate::message::errors::ProtoDeserializeError;
use crate::message::special::CancelRequestData;
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError};
use crate::writer::errors::ProtoWriteError; use crate::writer::errors::ProtoWriteError;
use thiserror::Error; use thiserror::Error;
use tokio::io; use tokio::io;
use crate::message::special::CancelRequestData;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ClientHandshakeError { pub enum ClientHandshakeError {

View file

@ -8,7 +8,6 @@ pub struct HandshakeRequest {
} }
impl HandshakeRequest { impl HandshakeRequest {
/// Creates a new `HandshakeRequest` with the specified protocol version. /// Creates a new `HandshakeRequest` with the specified protocol version.
/// Expected `version` is `196608` for the 3.0. /// Expected `version` is `196608` for the 3.0.
pub fn new(version: i32) -> Self { pub fn new(version: i32) -> Self {

View file

@ -15,7 +15,6 @@ pub async fn do_server_handshake(
reader: &mut impl FrontendProtoReader, reader: &mut impl FrontendProtoReader,
response: HandshakeResponse, response: HandshakeResponse,
) -> Result<HandshakeRequest, ServerHandshakeError> { ) -> Result<HandshakeRequest, ServerHandshakeError> {
// Check if client requested SSL // Check if client requested SSL
match &reader.peek_special_message().await? { match &reader.peek_special_message().await? {
Some(msg @ SpecialMessage::SSLRequest) => { Some(msg @ SpecialMessage::SSLRequest) => {

View file

@ -1,6 +1,6 @@
use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError};
use bincode::{Decode, Encode};
use bincode::config::{BigEndian, Configuration, Fixint}; use bincode::config::{BigEndian, Configuration, Fixint};
use bincode::{Decode, Encode};
fn pg_proto_config() -> Configuration<BigEndian, Fixint> { fn pg_proto_config() -> Configuration<BigEndian, Fixint> {
bincode::config::standard() bincode::config::standard()

View file

@ -50,16 +50,17 @@ impl Decode for PgString {
bytes.push(byte); bytes.push(byte);
} }
let string = String::from_utf8(bytes) let string = String::from_utf8(bytes).map_err(|e| DecodeError::Utf8 {
.map_err(|e| DecodeError::Utf8 { inner: e.utf8_error() })?; inner: e.utf8_error(),
})?;
Ok(PgString(string)) Ok(PgString(string))
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::message::primitive::data::MessageData;
use super::*; use super::*;
use crate::message::primitive::data::MessageData;
#[test] #[test]
fn test_encode_decode_utf8() { fn test_encode_decode_utf8() {

View file

@ -1,5 +1,5 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
pub struct ResetCancelToken { pub struct ResetCancelToken {
is_canceled: Arc<AtomicBool>, is_canceled: Arc<AtomicBool>,
@ -31,4 +31,4 @@ impl Clone for ResetCancelToken {
is_canceled: self.is_canceled.clone(), is_canceled: self.is_canceled.clone(),
} }
} }
} }

View file

@ -1,6 +1,6 @@
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
use clap::Parser;
const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
@ -9,7 +9,12 @@ const LOCAL_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
pub struct Configuration { pub struct Configuration {
#[arg(short, long, default_value_t = LOCAL_IPV4, help = "IP address for the server to listen on")] #[arg(short, long, default_value_t = LOCAL_IPV4, help = "IP address for the server to listen on")]
address: IpAddr, address: IpAddr,
#[arg(short, long, default_value = "5432", help = "Port for the server to listen on")] #[arg(
short,
long,
default_value = "5432",
help = "Port for the server to listen on"
)]
port: u16, port: u16,
#[arg(short, long, help = "Path to the data file")] #[arg(short, long, help = "Path to the data file")]
file: PathBuf, file: PathBuf,
@ -25,4 +30,4 @@ impl Configuration {
pub fn get_file_path(&self) -> &PathBuf { pub fn get_file_path(&self) -> &PathBuf {
&self.file &self.file
} }
} }

View file

@ -24,10 +24,10 @@ use crate::config::Configuration;
use crate::persistence::state_to_file; use crate::persistence::state_to_file;
use crate::proto_wrapper::{CompleteStatus, ServerProto}; use crate::proto_wrapper::{CompleteStatus, ServerProto};
mod config;
mod proto_wrapper;
mod cancellation; mod cancellation;
mod config;
mod persistence; mod persistence;
mod proto_wrapper;
type TokenStore = Arc<Mutex<HashMap<(i32, i32), ResetCancelToken>>>; type TokenStore = Arc<Mutex<HashMap<(i32, i32), ResetCancelToken>>>;
type SharedDbState = Arc<RwLock<State>>; type SharedDbState = Arc<RwLock<State>>;
@ -65,16 +65,17 @@ async fn get_state(config: &Configuration) -> anyhow::Result<State> {
println!("WARNING: No DB state file found, creating new one"); println!("WARNING: No DB state file found, creating new one");
Ok(State::new()) Ok(State::new())
} }
Err(e) => { Err(e) => Err(e)?,
Err(e)? Ok(state) => Ok(state),
}
Ok(state) => {
Ok(state)
}
} }
} }
async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore, config: Arc<Configuration>) -> anyhow::Result<()> { async fn handle_stream(
mut stream: TcpStream,
state: SharedDbState,
tokens: TokenStore,
config: Arc<Configuration>,
) -> anyhow::Result<()> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut writer = ProtoWriter::new(BufWriter::new(writer));
let mut reader = ProtoReader::new(BufReader::new(reader), 1024); let mut reader = ProtoReader::new(BufReader::new(reader), 1024);
@ -88,7 +89,9 @@ async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: Toke
let result = match request { let result = match request {
Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await, Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await,
Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, Err(ServerHandshakeError::IsCancelRequest(cancel)) => {
handle_cancellation(cancel.pid, cancel.secret, &tokens).await
}
Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)),
}; };
@ -134,10 +137,17 @@ async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow:
Ok(()) Ok(())
} }
async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, config: Arc<Configuration>) -> anyhow::Result<()> async fn handle_connection<R, W>(
where reader: &mut R,
R: FrontendProtoReader + Send, writer: &mut W,
W: BackendProtoWriter + ProtoFlush + Send, request: HandshakeRequest,
state: SharedDbState,
token: ResetCancelToken,
config: Arc<Configuration>,
) -> anyhow::Result<()>
where
R: FrontendProtoReader + Send,
W: BackendProtoWriter + ProtoFlush + Send,
{ {
println!("Client connected: {:?}", request); println!("Client connected: {:?}", request);
@ -152,9 +162,7 @@ async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: Handsh
let result = handle_query(writer, &state, data.query.into(), &token, &config).await; let result = handle_query(writer, &state, data.query.into(), &token, &config).await;
match result { match result {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => writer.write_error_message(&e.to_string()).await?,
writer.write_error_message(&e.to_string()).await?
}
} }
writer.write_ready_for_query().await?; writer.write_ready_for_query().await?;
} }
@ -165,9 +173,15 @@ async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: Handsh
Ok(()) Ok(())
} }
async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken, config: &Arc<Configuration>) -> anyhow::Result<()> async fn handle_query<W>(
where writer: &mut W,
W: BackendProtoWriter + ProtoFlush + Send, state: &SharedDbState,
query: String,
token: &ResetCancelToken,
config: &Arc<Configuration>,
) -> anyhow::Result<()>
where
W: BackendProtoWriter + ProtoFlush + Send,
{ {
// Make sure token is reset before next query // Make sure token is reset before next query
token.reset(); token.reset();
@ -184,11 +198,15 @@ async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, t
match response { match response {
Response::Deleted(i) => { Response::Deleted(i) => {
writer.write_command_complete(CompleteStatus::Delete(i)).await?; writer
.write_command_complete(CompleteStatus::Delete(i))
.await?;
true true
} }
Response::Inserted => { Response::Inserted => {
writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?; writer
.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 })
.await?;
true true
} }
Response::Selected(schema, columns, mut rows) => { Response::Selected(schema, columns, mut rows) => {
@ -207,22 +225,30 @@ async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, t
} }
} }
writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; writer
.write_command_complete(CompleteStatus::Select(sent_rows))
.await?;
} }
_ => { _ => {
writer.write_command_complete(CompleteStatus::Select(0)).await?; writer
.write_command_complete(CompleteStatus::Select(0))
.await?;
} }
} }
false false
}, }
Response::TableCreated => { Response::TableCreated => {
writer.write_command_complete(CompleteStatus::CreateTable).await?; writer
.write_command_complete(CompleteStatus::CreateTable)
.await?;
true true
}, }
Response::IndexCreated => { Response::IndexCreated => {
writer.write_command_complete(CompleteStatus::CreateIndex).await?; writer
.write_command_complete(CompleteStatus::CreateIndex)
.await?;
true true
}, }
} }
}; };

View file

@ -1,6 +1,6 @@
use minisql::interpreter::State;
use std::path::PathBuf; use std::path::PathBuf;
use tokio::{fs, io}; use tokio::{fs, io};
use minisql::interpreter::State;
pub async fn state_from_file(path: &PathBuf) -> io::Result<State> { pub async fn state_from_file(path: &PathBuf) -> io::Result<State> {
let content = fs::read_to_string(path).await?; let content = fs::read_to_string(path).await?;

View file

@ -1,20 +1,20 @@
use async_trait::async_trait; use async_trait::async_trait;
use minisql::operation::ColumnSelection;
use minisql::restricted_row::RestrictedRow;
use minisql::schema::{Column, TableSchema};
use proto::message::backend::{
BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData,
ReadyForQueryData, RowDescriptionData,
};
use proto::message::primitive::pglist::PgList;
use proto::writer::backend::BackendProtoWriter;
use rand::Rng; use rand::Rng;
use rand_pcg::Pcg64; use rand_pcg::Pcg64;
use rand_seeder::Seeder; use rand_seeder::Seeder;
use std::fmt; use std::fmt;
use minisql::operation::ColumnSelection;
use minisql::restricted_row::RestrictedRow;
use minisql::schema::{Column, TableSchema};
use proto::message::backend::{BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, ReadyForQueryData, RowDescriptionData};
use proto::message::primitive::pglist::PgList;
use proto::writer::backend::BackendProtoWriter;
pub enum CompleteStatus { pub enum CompleteStatus {
Insert { Insert { oid: i32, rows: i32 },
oid: i32,
rows: i32,
},
Delete(usize), Delete(usize),
Select(usize), Select(usize),
CreateTable, CreateTable,
@ -38,24 +38,36 @@ pub trait ServerProto {
async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>; async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>;
async fn write_ready_for_query(&mut self) -> anyhow::Result<()>; async fn write_ready_for_query(&mut self) -> anyhow::Result<()>;
async fn write_empty_query(&mut self) -> anyhow::Result<()>; async fn write_empty_query(&mut self) -> anyhow::Result<()>;
async fn write_table_header(&mut self, table_schema: &TableSchema, columns: &ColumnSelection) -> anyhow::Result<()>; async fn write_table_header(
&mut self,
table_schema: &TableSchema,
columns: &ColumnSelection,
) -> anyhow::Result<()>;
async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>; async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>;
async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>; async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>;
} }
#[async_trait] #[async_trait]
impl<W> ServerProto for W where W: BackendProtoWriter + Send { impl<W> ServerProto for W
where
W: BackendProtoWriter + Send,
{
async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> { async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> {
self.write_proto(ErrorResponseData { self.write_proto(
code: b'M', ErrorResponseData {
message: format!("{error_message}\0").into(), code: b'M',
}.into()).await?; message: format!("{error_message}\0").into(),
}
.into(),
)
.await?;
Ok(()) Ok(())
} }
async fn write_ready_for_query(&mut self) -> anyhow::Result<()> { async fn write_ready_for_query(&mut self) -> anyhow::Result<()> {
self.write_proto(ReadyForQueryData { status: b'I' }.into()).await?; self.write_proto(ReadyForQueryData { status: b'I' }.into())
.await?;
Ok(()) Ok(())
} }
@ -64,35 +76,52 @@ impl<W> ServerProto for W where W: BackendProtoWriter + Send {
Ok(()) Ok(())
} }
async fn write_table_header(&mut self, table_schema: &TableSchema, columns: &ColumnSelection) -> anyhow::Result<()> { async fn write_table_header(
let columns = columns.iter() &mut self,
table_schema: &TableSchema,
columns: &ColumnSelection,
) -> anyhow::Result<()> {
let columns = columns
.iter()
.map(|column| column_to_description(table_schema, *column)) .map(|column| column_to_description(table_schema, *column))
.collect::<anyhow::Result<Vec<ColumnDescription>>>()?; .collect::<anyhow::Result<Vec<ColumnDescription>>>()?;
self.write_proto(RowDescriptionData { columns: columns.into() }.into()).await?; self.write_proto(
RowDescriptionData {
columns: columns.into(),
}
.into(),
)
.await?;
Ok(()) Ok(())
} }
async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()> { async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()> {
let values = row.iter() let values = row
.iter()
.map(|(_, value)| value.as_text_bytes().into()) .map(|(_, value)| value.as_text_bytes().into())
.collect::<Vec<PgList<u8, i32>>>(); .collect::<Vec<PgList<u8, i32>>>();
self.write_proto(BackendMessage::DataRow(DataRowData { self.write_proto(BackendMessage::DataRow(DataRowData {
columns: values.into(), columns: values.into(),
})).await?; }))
.await?;
Ok(()) Ok(())
} }
async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> { async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> {
self.write_proto(BackendMessage::CommandComplete(CommandCompleteData { self.write_proto(BackendMessage::CommandComplete(CommandCompleteData {
tag: status.to_string().into(), tag: status.to_string().into(),
})).await?; }))
.await?;
Ok(()) Ok(())
} }
} }
fn column_to_description(schema: &TableSchema, column: Column) -> anyhow::Result<ColumnDescription> { fn column_to_description(
schema: &TableSchema,
column: Column,
) -> anyhow::Result<ColumnDescription> {
let table_name = schema.table_name(); let table_name = schema.table_name();
let table_oid = table_name_to_oid(table_name); let table_oid = table_name_to_oid(table_name);