From e75ea5d5dbf3fdd45c47ae45e7d5643f9bd2f716 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jind=C5=99ich=20Moravec?= Date: Sun, 28 Jan 2024 16:19:21 +0100 Subject: [PATCH] feat: add db state persistence --- Cargo.lock | 24 ++++++++ server/Cargo.toml | 1 + server/src/main.rs | 114 +++++++++++++++++++++++++------------- server/src/persistence.rs | 15 +++++ 4 files changed, 115 insertions(+), 39 deletions(-) create mode 100644 server/src/persistence.rs diff --git a/Cargo.lock b/Cargo.lock index dcd6dea..4fee400 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -244,6 +244,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + [[package]] name = "libc" version = "0.2.151" @@ -489,6 +495,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + [[package]] name = "scopeguard" version = "1.2.0" @@ -515,6 +527,17 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "serde_json" +version = "1.0.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d1bd37ce2324cf3bf85e5a25f96eb4baf0d5aa6eba43e7ae8958870c4ec48ed" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "server" version = "0.1.0" @@ -526,6 +549,7 @@ dependencies = [ "parser", "proto", "rand", + "serde_json", "tokio", ] diff --git a/server/Cargo.toml b/server/Cargo.toml index 6a511f6..67c87d7 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,6 +11,7 @@ anyhow = "1.0.76" clap = { version = "4.4.18", features = ["derive"] } async-trait = "0.1.74" rand = "0.8.5" +serde_json = "1.0.112" minisql = { path = "../minisql" } proto = { path = "../proto" } parser = { path = "../parser" } \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index 2793e26..8ce6785 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::io::ErrorKind; use std::sync::Arc; use clap::Parser; @@ -20,11 +21,13 @@ use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use crate::cancellation::ResetCancelToken; use crate::config::Configuration; +use crate::persistence::state_to_file; use crate::proto_wrapper::{CompleteStatus, ServerProto}; mod config; mod proto_wrapper; mod cancellation; +mod persistence; type TokenStore = Arc>>; type SharedDbState = Arc>; @@ -32,8 +35,9 @@ type SharedDbState = Arc>; #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Configuration::parse(); + let config = Arc::new(config); - let state = Arc::new(RwLock::new(State::new())); + let state = Arc::new(RwLock::new(get_state(&config).await?)); let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new())); let addr = config.get_socket_address(); @@ -41,19 +45,36 @@ async fn main() -> anyhow::Result<()> { println!("Server started at {addr}"); loop { + let config = config.clone(); let state = state.clone(); let tokens = tokens.clone(); let (socket, _) = listener.accept().await?; println!("New client connected: {}", socket.peer_addr()?); tokio::spawn(async move { - let reason = handle_stream(socket, state, tokens).await; + let reason = handle_stream(socket, state, tokens, config).await; println!("Client disconnected: {reason:?}"); }); } } -async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore) -> anyhow::Result<()> { +async fn get_state(config: &Configuration) -> anyhow::Result { + let result = persistence::state_from_file(config.get_file_path()).await; + match result { + Err(ref e) if e.kind() == ErrorKind::NotFound => { + println!("WARNING: No DB state file found, creating new one"); + Ok(State::new()) + } + Err(e) => { + Err(e)? + } + Ok(state) => { + Ok(state) + } + } +} + +async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore, config: Arc) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); @@ -66,7 +87,7 @@ async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: Toke let request = do_server_handshake(&mut writer, &mut reader, response).await; let result = match request { - Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await, + Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await, Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), }; @@ -111,10 +132,10 @@ async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow: Ok(()) } -async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken) -> anyhow::Result<()> -where - R: FrontendProtoReader + Send, - W: BackendProtoWriter + ProtoFlush + Send, +async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, config: Arc) -> anyhow::Result<()> + where + R: FrontendProtoReader + Send, + W: BackendProtoWriter + ProtoFlush + Send, { println!("Client connected: {:?}", request); @@ -126,7 +147,7 @@ where break; } FrontendMessage::Query(data) => { - let result = handle_query(writer, &state, data.query.into(), &token).await; + let result = handle_query(writer, &state, data.query.into(), &token, &config).await; match result { Ok(_) => {} Err(e) => { @@ -142,9 +163,9 @@ where Ok(()) } -async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken) -> anyhow::Result<()> -where - W: BackendProtoWriter + ProtoFlush + Send, +async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken, config: &Arc) -> anyhow::Result<()> + where + W: BackendProtoWriter + ProtoFlush + Send, { let operation = { let state = state.read().await; @@ -152,36 +173,51 @@ where parse_and_validate(query, &db_schema)? }; - let mut state = state.write().await; - let response = state.interpret(operation)?; + let need_write = { + let mut state = state.write().await; + let response = state.interpret(operation)?; - match response { - Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?, - Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?, - Response::Selected(schema, mut rows) => { - match rows.next() { - Some(row) => { - writer.write_table_header(&schema, &row).await?; - writer.write_table_row(&row).await?; - - let mut sent_rows = 1; - for row in rows { - sent_rows += 1; - writer.write_table_row(&row).await?; - if token.is_canceled() { - token.reset(); - break; - } - } - - writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; - } - _ => { - writer.write_command_complete(CompleteStatus::Select(0)).await?; - } + match response { + Response::Deleted(i) => { + writer.write_command_complete(CompleteStatus::Delete(i)).await?; + true } + Response::Inserted => { + writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?; + true + } + Response::Selected(schema, mut rows) => { + match rows.next() { + Some(row) => { + writer.write_table_header(&schema, &row).await?; + writer.write_table_row(&row).await?; + + let mut sent_rows = 1; + for row in rows { + sent_rows += 1; + writer.write_table_row(&row).await?; + if token.is_canceled() { + token.reset(); + break; + } + } + + writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; + } + _ => { + writer.write_command_complete(CompleteStatus::Select(0)).await?; + } + } + false + }, + Response::TableCreated => true, + Response::IndexCreated => true, } - _ => {} + }; + + if need_write { + let state = state.read().await; + state_to_file(&state, &config.get_file_path()).await?; } Ok(()) diff --git a/server/src/persistence.rs b/server/src/persistence.rs new file mode 100644 index 0000000..980945a --- /dev/null +++ b/server/src/persistence.rs @@ -0,0 +1,15 @@ +use std::path::PathBuf; +use tokio::{fs, io}; +use minisql::interpreter::State; + +pub async fn state_from_file(path: &PathBuf) -> io::Result { + let content = fs::read_to_string(path).await?; + let state = serde_json::from_str(&content)?; + Ok(state) +} + +pub async fn state_to_file(state: &State, path: &PathBuf) -> io::Result<()> { + let content = serde_json::to_string(state)?; + fs::write(path, content).await?; + Ok(()) +}