use std::collections::HashMap; use std::io::ErrorKind; use std::sync::Arc; use clap::Parser; use tokio::io::{BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, RwLock}; use minisql::interpreter::{Response, State}; use parser::parse_and_validate; use proto::handshake::errors::ServerHandshakeError; use proto::handshake::request::HandshakeRequest; use proto::handshake::response::HandshakeResponse; use proto::handshake::server::do_server_handshake; use proto::message::frontend::FrontendMessage; use proto::reader::frontend::FrontendProtoReader; use proto::reader::protoreader::ProtoReader; use proto::writer::backend::BackendProtoWriter; use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use crate::cancellation::ResetCancelToken; use crate::config::Configuration; use crate::persistence::state_to_file; use crate::proto_wrapper::{CompleteStatus, ServerProto}; mod config; mod proto_wrapper; mod cancellation; mod persistence; type TokenStore = Arc>>; type SharedDbState = Arc>; #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Configuration::parse(); let config = Arc::new(config); let state = Arc::new(RwLock::new(get_state(&config).await?)); let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new())); let addr = config.get_socket_address(); let listener = TcpListener::bind(&addr).await?; println!("Server started at {addr}"); loop { let config = config.clone(); let state = state.clone(); let tokens = tokens.clone(); let (socket, _) = listener.accept().await?; println!("New client connected: {}", socket.peer_addr()?); tokio::spawn(async move { let reason = handle_stream(socket, state, tokens, config).await; println!("Client disconnected: {reason:?}"); }); } } async fn get_state(config: &Configuration) -> anyhow::Result { let result = persistence::state_from_file(config.get_file_path()).await; match result { Err(ref e) if e.kind() == ErrorKind::NotFound => { println!("WARNING: No DB state file found, creating new one"); Ok(State::new()) } Err(e) => { Err(e)? } Ok(state) => { Ok(state) } } } async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore, config: Arc) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); // Create a token with random PID and key let (pid, key, token) = create_token(&tokens).await?; // Handle handshake let response = HandshakeResponse::new("minisql", pid, key); let request = do_server_handshake(&mut writer, &mut reader, response).await; let result = match request { Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await, Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), }; // Release cancellation token let mut tokens = tokens.lock().await; tokens.remove(&(pid, key)); result } fn random_pid_key() -> (i32, i32) { let pid = rand::random::(); let key = rand::random::(); (pid, key) } async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> { let token = ResetCancelToken::new(); let mut tokens = tokens.lock().await; loop { let pid_key = random_pid_key(); if !tokens.contains_key(&pid_key) { tokens.insert(pid_key, token.clone()); let (pid, key) = pid_key; return Ok((pid, key, token)); } } } async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> { println!("Cancel request, PID: {}, Key: {}", pid, key); let tokens = tokens.lock().await; let token = tokens.get(&(pid, key)); match token { Some(t) => t.cancel(), None => return Err(anyhow::anyhow!("Invalid PID and Key cancel combination")), } Ok(()) } async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken, config: Arc) -> anyhow::Result<()> where R: FrontendProtoReader + Send, W: BackendProtoWriter + ProtoFlush + Send, { println!("Client connected: {:?}", request); loop { let next: FrontendMessage = reader.read_proto().await?; match next { FrontendMessage::Terminate => { break; } FrontendMessage::Query(data) => { let result = handle_query(writer, &state, data.query.into(), &token, &config).await; match result { Ok(_) => {} Err(e) => { writer.write_error_message(&e.to_string()).await? } } writer.write_ready_for_query().await?; } } writer.flush().await?; } Ok(()) } async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken, config: &Arc) -> anyhow::Result<()> where W: BackendProtoWriter + ProtoFlush + Send, { // Make sure token is reset before next query token.reset(); let operation = { let state = state.read().await; let db_schema = state.db_schema(); parse_and_validate(query, &db_schema)? }; let need_write = { let mut state = state.write().await; let response = state.interpret(operation)?; match response { Response::Deleted(i) => { writer.write_command_complete(CompleteStatus::Delete(i)).await?; true } Response::Inserted => { writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?; true } Response::Selected(schema, columns, mut rows) => { writer.write_table_header(&schema, &columns).await?; match rows.next() { Some(row) => { writer.write_table_row(&row).await?; let mut sent_rows = 1; for row in rows { sent_rows += 1; writer.write_table_row(&row).await?; if token.is_canceled() { token.reset(); break; } } writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; } _ => { writer.write_command_complete(CompleteStatus::Select(0)).await?; } } false }, Response::TableCreated => { writer.write_command_complete(CompleteStatus::CreateTable).await?; true }, Response::IndexCreated => { writer.write_command_complete(CompleteStatus::CreateIndex).await?; true }, } }; if need_write { let state = state.read().await; state_to_file(&state, &config.get_file_path()).await?; } Ok(()) }