diff --git a/Cargo.lock b/Cargo.lock index 514f99c..4fee400 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,9 +77,9 @@ version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ - "proc-macro2 1.0.70", - "quote 1.0.33", - "syn 2.0.41", + "proc-macro2 1.0.78", + "quote 1.0.35", + "syn 2.0.48", ] [[package]] @@ -108,6 +108,9 @@ name = "bimap" version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" +dependencies = [ + "serde", +] [[package]] name = "bincode" @@ -184,9 +187,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", - "proc-macro2 1.0.70", - "quote 1.0.33", - "syn 2.0.41", + "proc-macro2 1.0.78", + "quote 1.0.35", + "syn 2.0.48", ] [[package]] @@ -241,6 +244,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + [[package]] name = "libc" version = "0.2.151" @@ -274,6 +283,7 @@ name = "minisql" version = "0.1.0" dependencies = [ "bimap", + "serde", "thiserror", ] @@ -405,9 +415,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.70" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -433,11 +443,11 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ - "proc-macro2 1.0.70", + "proc-macro2 1.0.78", ] [[package]] @@ -485,6 +495,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + [[package]] name = "scopeguard" version = "1.2.0" @@ -493,22 +509,33 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.193" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ - "proc-macro2 1.0.70", - "quote 1.0.33", - "syn 2.0.41", + "proc-macro2 1.0.78", + "quote 1.0.35", + "syn 2.0.48", +] + +[[package]] +name = "serde_json" +version = "1.0.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d1bd37ce2324cf3bf85e5a25f96eb4baf0d5aa6eba43e7ae8958870c4ec48ed" +dependencies = [ + "itoa", + "ryu", + "serde", ] [[package]] @@ -522,6 +549,7 @@ dependencies = [ "parser", "proto", "rand", + "serde_json", "tokio", ] @@ -569,12 +597,12 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.41" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ - "proc-macro2 1.0.70", - "quote 1.0.33", + "proc-macro2 1.0.78", + "quote 1.0.35", "unicode-ident", ] @@ -593,9 +621,9 @@ version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ - "proc-macro2 1.0.70", - "quote 1.0.33", - "syn 2.0.41", + "proc-macro2 1.0.78", + "quote 1.0.35", + "syn 2.0.48", ] [[package]] @@ -623,9 +651,9 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ - "proc-macro2 1.0.70", - "quote 1.0.33", - "syn 2.0.41", + "proc-macro2 1.0.78", + "quote 1.0.35", + "syn 2.0.48", ] [[package]] diff --git a/minisql/Cargo.toml b/minisql/Cargo.toml index 6164b6b..d200f70 100644 --- a/minisql/Cargo.toml +++ b/minisql/Cargo.toml @@ -6,5 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -bimap = "0.6.3" +bimap = { version = "0.6.3", features = ["serde"] } +serde = { version = "1.0.196", features = ["derive"] } thiserror = "1.0.50" diff --git a/minisql/src/internals/column_index.rs b/minisql/src/internals/column_index.rs index cdd331d..c5a5ebb 100644 --- a/minisql/src/internals/column_index.rs +++ b/minisql/src/internals/column_index.rs @@ -1,7 +1,8 @@ use crate::type_system::{IndexableValue, Uuid}; use std::collections::{BTreeMap, HashSet}; +use serde::{Deserialize, Serialize}; -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct ColumnIndex { index: BTreeMap>, } diff --git a/minisql/src/internals/row.rs b/minisql/src/internals/row.rs index 6fa10e1..23590d2 100644 --- a/minisql/src/internals/row.rs +++ b/minisql/src/internals/row.rs @@ -2,11 +2,12 @@ use crate::type_system::Value; use crate::operation::InsertionValues; use std::ops::{Index, IndexMut}; use std::slice::SliceIndex; +use serde::{Deserialize, Serialize}; use crate::restricted_row::RestrictedRow; pub type ColumnPosition = usize; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Row(Vec); impl Index for Row diff --git a/minisql/src/internals/table.rs b/minisql/src/internals/table.rs index 05ee66e..18c98a9 100644 --- a/minisql/src/internals/table.rs +++ b/minisql/src/internals/table.rs @@ -1,4 +1,5 @@ use std::collections::{BTreeMap, HashMap, HashSet}; +use serde::{Deserialize, Serialize}; use crate::error::Error; use crate::internals::column_index::ColumnIndex; @@ -8,7 +9,7 @@ use crate::schema::{ColumnName, TableSchema, TableName}; use crate::result::DbResult; use crate::type_system::{IndexableValue, Uuid, Value}; -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct Table { schema: TableSchema, rows: Rows, // TODO: Consider wrapping this in a lock. Also consider if we need to have the diff --git a/minisql/src/interpreter.rs b/minisql/src/interpreter.rs index 2caf1f0..85266b9 100644 --- a/minisql/src/interpreter.rs +++ b/minisql/src/interpreter.rs @@ -4,6 +4,7 @@ use crate::internals::table::Table; use crate::operation::{Operation, Condition}; use crate::result::DbResult; use bimap::BiMap; +use serde::{Deserialize, Serialize}; use crate::restricted_row::RestrictedRow; // Use `TablePosition` as index @@ -11,7 +12,7 @@ pub type Tables = Vec; pub type TablePosition = usize; // ==============Interpreter================ -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct State { table_name_position_mapping: BiMap, tables: Tables, diff --git a/minisql/src/schema.rs b/minisql/src/schema.rs index 97cd87b..c61b7c2 100644 --- a/minisql/src/schema.rs +++ b/minisql/src/schema.rs @@ -4,10 +4,11 @@ use crate::operation::{InsertionValues, ColumnSelection}; use crate::result::DbResult; use crate::type_system::{DbType, IndexableValue, Uuid, Value}; use bimap::BiMap; +use serde::{Deserialize, Serialize}; // Note that it is nice to split metadata from the data because // then you can give the metadata to the parser without giving it the data. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct TableSchema { table_name: TableName, // used for descriptive errors primary_key: ColumnPosition, diff --git a/minisql/src/type_system.rs b/minisql/src/type_system.rs index 8b95d7b..3d4c837 100644 --- a/minisql/src/type_system.rs +++ b/minisql/src/type_system.rs @@ -1,7 +1,8 @@ +use serde::{Deserialize, Serialize}; use crate::error::TypeConversionError; // ==============Types================ -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum DbType { String, Int, @@ -15,14 +16,14 @@ pub type Uuid = u64; // TODO: What about nulls? I would rather not have that in SQL, it sucks. // I would rather have non-nullable values by default, // and something like an explicit Option type for nulls. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Value { Number(f64), // TODO: Can't put floats as keys in maps, since they don't implement Eq. What to // do? Indexable(IndexableValue), } -#[derive(Debug, Ord, Eq, Clone, PartialOrd, PartialEq)] +#[derive(Debug, Ord, Eq, Clone, PartialOrd, PartialEq, Serialize, Deserialize)] pub enum IndexableValue { String(String), Int(u64), diff --git a/server/Cargo.toml b/server/Cargo.toml index 6a511f6..67c87d7 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,6 +11,7 @@ anyhow = "1.0.76" clap = { version = "4.4.18", features = ["derive"] } async-trait = "0.1.74" rand = "0.8.5" +serde_json = "1.0.112" minisql = { path = "../minisql" } proto = { path = "../proto" } parser = { path = "../parser" } \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index 2793e26..8ce6785 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::io::ErrorKind; use std::sync::Arc; use clap::Parser; @@ -20,11 +21,13 @@ use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use crate::cancellation::ResetCancelToken; use crate::config::Configuration; +use crate::persistence::state_to_file; use crate::proto_wrapper::{CompleteStatus, ServerProto}; mod config; mod proto_wrapper; mod cancellation; +mod persistence; type TokenStore = Arc>>; type SharedDbState = Arc>; @@ -32,8 +35,9 @@ type SharedDbState = Arc>; #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Configuration::parse(); + let config = Arc::new(config); - let state = Arc::new(RwLock::new(State::new())); + let state = Arc::new(RwLock::new(get_state(&config).await?)); let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new())); let addr = config.get_socket_address(); @@ -41,19 +45,36 @@ async fn main() -> anyhow::Result<()> { println!("Server started at {addr}"); loop { + let config = config.clone(); let state = state.clone(); let tokens = tokens.clone(); let (socket, _) = listener.accept().await?; println!("New client connected: {}", socket.peer_addr()?); tokio::spawn(async move { - let reason = handle_stream(socket, state, tokens).await; + let reason = handle_stream(socket, state, tokens, config).await; println!("Client disconnected: {reason:?}"); }); } } -async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore) -> anyhow::Result<()> { +async fn get_state(config: &Configuration) -> anyhow::Result { + let result = persistence::state_from_file(config.get_file_path()).await; + match result { + Err(ref e) if e.kind() == ErrorKind::NotFound => { + println!("WARNING: No DB state file found, creating new one"); + Ok(State::new()) + } + Err(e) => { + Err(e)? + } + Ok(state) => { + Ok(state) + } + } +} + +async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore, config: Arc) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); @@ -66,7 +87,7 @@ async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: Toke let request = do_server_handshake(&mut writer, &mut reader, response).await; let result = match request { - Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await, + Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await, Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), }; @@ -111,10 +132,10 @@ async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow: Ok(()) } -async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken) -> anyhow::Result<()> -where - R: FrontendProtoReader + Send, - W: BackendProtoWriter + ProtoFlush + Send, +async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, config: Arc) -> anyhow::Result<()> + where + R: FrontendProtoReader + Send, + W: BackendProtoWriter + ProtoFlush + Send, { println!("Client connected: {:?}", request); @@ -126,7 +147,7 @@ where break; } FrontendMessage::Query(data) => { - let result = handle_query(writer, &state, data.query.into(), &token).await; + let result = handle_query(writer, &state, data.query.into(), &token, &config).await; match result { Ok(_) => {} Err(e) => { @@ -142,9 +163,9 @@ where Ok(()) } -async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken) -> anyhow::Result<()> -where - W: BackendProtoWriter + ProtoFlush + Send, +async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken, config: &Arc) -> anyhow::Result<()> + where + W: BackendProtoWriter + ProtoFlush + Send, { let operation = { let state = state.read().await; @@ -152,36 +173,51 @@ where parse_and_validate(query, &db_schema)? }; - let mut state = state.write().await; - let response = state.interpret(operation)?; + let need_write = { + let mut state = state.write().await; + let response = state.interpret(operation)?; - match response { - Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?, - Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?, - Response::Selected(schema, mut rows) => { - match rows.next() { - Some(row) => { - writer.write_table_header(&schema, &row).await?; - writer.write_table_row(&row).await?; - - let mut sent_rows = 1; - for row in rows { - sent_rows += 1; - writer.write_table_row(&row).await?; - if token.is_canceled() { - token.reset(); - break; - } - } - - writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; - } - _ => { - writer.write_command_complete(CompleteStatus::Select(0)).await?; - } + match response { + Response::Deleted(i) => { + writer.write_command_complete(CompleteStatus::Delete(i)).await?; + true } + Response::Inserted => { + writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?; + true + } + Response::Selected(schema, mut rows) => { + match rows.next() { + Some(row) => { + writer.write_table_header(&schema, &row).await?; + writer.write_table_row(&row).await?; + + let mut sent_rows = 1; + for row in rows { + sent_rows += 1; + writer.write_table_row(&row).await?; + if token.is_canceled() { + token.reset(); + break; + } + } + + writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; + } + _ => { + writer.write_command_complete(CompleteStatus::Select(0)).await?; + } + } + false + }, + Response::TableCreated => true, + Response::IndexCreated => true, } - _ => {} + }; + + if need_write { + let state = state.read().await; + state_to_file(&state, &config.get_file_path()).await?; } Ok(()) diff --git a/server/src/persistence.rs b/server/src/persistence.rs new file mode 100644 index 0000000..980945a --- /dev/null +++ b/server/src/persistence.rs @@ -0,0 +1,15 @@ +use std::path::PathBuf; +use tokio::{fs, io}; +use minisql::interpreter::State; + +pub async fn state_from_file(path: &PathBuf) -> io::Result { + let content = fs::read_to_string(path).await?; + let state = serde_json::from_str(&content)?; + Ok(state) +} + +pub async fn state_to_file(state: &State, path: &PathBuf) -> io::Result<()> { + let content = serde_json::to_string(state)?; + fs::write(path, content).await?; + Ok(()) +}