feat: add db state persistence
This commit is contained in:
parent
6bf4e34006
commit
e75ea5d5db
4 changed files with 115 additions and 39 deletions
|
|
@ -11,6 +11,7 @@ anyhow = "1.0.76"
|
|||
clap = { version = "4.4.18", features = ["derive"] }
|
||||
async-trait = "0.1.74"
|
||||
rand = "0.8.5"
|
||||
serde_json = "1.0.112"
|
||||
minisql = { path = "../minisql" }
|
||||
proto = { path = "../proto" }
|
||||
parser = { path = "../parser" }
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
use std::collections::HashMap;
|
||||
use std::io::ErrorKind;
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
|
|
@ -20,11 +21,13 @@ use proto::writer::protowriter::{ProtoFlush, ProtoWriter};
|
|||
|
||||
use crate::cancellation::ResetCancelToken;
|
||||
use crate::config::Configuration;
|
||||
use crate::persistence::state_to_file;
|
||||
use crate::proto_wrapper::{CompleteStatus, ServerProto};
|
||||
|
||||
mod config;
|
||||
mod proto_wrapper;
|
||||
mod cancellation;
|
||||
mod persistence;
|
||||
|
||||
type TokenStore = Arc<Mutex<HashMap<(i32, i32), ResetCancelToken>>>;
|
||||
type SharedDbState = Arc<RwLock<State>>;
|
||||
|
|
@ -32,8 +35,9 @@ type SharedDbState = Arc<RwLock<State>>;
|
|||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = Configuration::parse();
|
||||
let config = Arc::new(config);
|
||||
|
||||
let state = Arc::new(RwLock::new(State::new()));
|
||||
let state = Arc::new(RwLock::new(get_state(&config).await?));
|
||||
let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new()));
|
||||
|
||||
let addr = config.get_socket_address();
|
||||
|
|
@ -41,19 +45,36 @@ async fn main() -> anyhow::Result<()> {
|
|||
println!("Server started at {addr}");
|
||||
|
||||
loop {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let tokens = tokens.clone();
|
||||
|
||||
let (socket, _) = listener.accept().await?;
|
||||
println!("New client connected: {}", socket.peer_addr()?);
|
||||
tokio::spawn(async move {
|
||||
let reason = handle_stream(socket, state, tokens).await;
|
||||
let reason = handle_stream(socket, state, tokens, config).await;
|
||||
println!("Client disconnected: {reason:?}");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore) -> anyhow::Result<()> {
|
||||
async fn get_state(config: &Configuration) -> anyhow::Result<State> {
|
||||
let result = persistence::state_from_file(config.get_file_path()).await;
|
||||
match result {
|
||||
Err(ref e) if e.kind() == ErrorKind::NotFound => {
|
||||
println!("WARNING: No DB state file found, creating new one");
|
||||
Ok(State::new())
|
||||
}
|
||||
Err(e) => {
|
||||
Err(e)?
|
||||
}
|
||||
Ok(state) => {
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore, config: Arc<Configuration>) -> anyhow::Result<()> {
|
||||
let (reader, writer) = stream.split();
|
||||
let mut writer = ProtoWriter::new(BufWriter::new(writer));
|
||||
let mut reader = ProtoReader::new(BufReader::new(reader), 1024);
|
||||
|
|
@ -66,7 +87,7 @@ async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: Toke
|
|||
let request = do_server_handshake(&mut writer, &mut reader, response).await;
|
||||
|
||||
let result = match request {
|
||||
Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await,
|
||||
Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await,
|
||||
Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await,
|
||||
Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)),
|
||||
};
|
||||
|
|
@ -111,10 +132,10 @@ async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow:
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken) -> anyhow::Result<()>
|
||||
where
|
||||
R: FrontendProtoReader + Send,
|
||||
W: BackendProtoWriter + ProtoFlush + Send,
|
||||
async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, config: Arc<Configuration>) -> anyhow::Result<()>
|
||||
where
|
||||
R: FrontendProtoReader + Send,
|
||||
W: BackendProtoWriter + ProtoFlush + Send,
|
||||
{
|
||||
println!("Client connected: {:?}", request);
|
||||
|
||||
|
|
@ -126,7 +147,7 @@ where
|
|||
break;
|
||||
}
|
||||
FrontendMessage::Query(data) => {
|
||||
let result = handle_query(writer, &state, data.query.into(), &token).await;
|
||||
let result = handle_query(writer, &state, data.query.into(), &token, &config).await;
|
||||
match result {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
|
|
@ -142,9 +163,9 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken) -> anyhow::Result<()>
|
||||
where
|
||||
W: BackendProtoWriter + ProtoFlush + Send,
|
||||
async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken, config: &Arc<Configuration>) -> anyhow::Result<()>
|
||||
where
|
||||
W: BackendProtoWriter + ProtoFlush + Send,
|
||||
{
|
||||
let operation = {
|
||||
let state = state.read().await;
|
||||
|
|
@ -152,36 +173,51 @@ where
|
|||
parse_and_validate(query, &db_schema)?
|
||||
};
|
||||
|
||||
let mut state = state.write().await;
|
||||
let response = state.interpret(operation)?;
|
||||
let need_write = {
|
||||
let mut state = state.write().await;
|
||||
let response = state.interpret(operation)?;
|
||||
|
||||
match response {
|
||||
Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?,
|
||||
Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?,
|
||||
Response::Selected(schema, mut rows) => {
|
||||
match rows.next() {
|
||||
Some(row) => {
|
||||
writer.write_table_header(&schema, &row).await?;
|
||||
writer.write_table_row(&row).await?;
|
||||
|
||||
let mut sent_rows = 1;
|
||||
for row in rows {
|
||||
sent_rows += 1;
|
||||
writer.write_table_row(&row).await?;
|
||||
if token.is_canceled() {
|
||||
token.reset();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?;
|
||||
}
|
||||
_ => {
|
||||
writer.write_command_complete(CompleteStatus::Select(0)).await?;
|
||||
}
|
||||
match response {
|
||||
Response::Deleted(i) => {
|
||||
writer.write_command_complete(CompleteStatus::Delete(i)).await?;
|
||||
true
|
||||
}
|
||||
Response::Inserted => {
|
||||
writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?;
|
||||
true
|
||||
}
|
||||
Response::Selected(schema, mut rows) => {
|
||||
match rows.next() {
|
||||
Some(row) => {
|
||||
writer.write_table_header(&schema, &row).await?;
|
||||
writer.write_table_row(&row).await?;
|
||||
|
||||
let mut sent_rows = 1;
|
||||
for row in rows {
|
||||
sent_rows += 1;
|
||||
writer.write_table_row(&row).await?;
|
||||
if token.is_canceled() {
|
||||
token.reset();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?;
|
||||
}
|
||||
_ => {
|
||||
writer.write_command_complete(CompleteStatus::Select(0)).await?;
|
||||
}
|
||||
}
|
||||
false
|
||||
},
|
||||
Response::TableCreated => true,
|
||||
Response::IndexCreated => true,
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
|
||||
if need_write {
|
||||
let state = state.read().await;
|
||||
state_to_file(&state, &config.get_file_path()).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
15
server/src/persistence.rs
Normal file
15
server/src/persistence.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
use std::path::PathBuf;
|
||||
use tokio::{fs, io};
|
||||
use minisql::interpreter::State;
|
||||
|
||||
pub async fn state_from_file(path: &PathBuf) -> io::Result<State> {
|
||||
let content = fs::read_to_string(path).await?;
|
||||
let state = serde_json::from_str(&content)?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
pub async fn state_to_file(state: &State, path: &PathBuf) -> io::Result<()> {
|
||||
let content = serde_json::to_string(state)?;
|
||||
fs::write(path, content).await?;
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue