use std::collections::HashMap; use std::sync::Arc; use clap::Parser; use tokio::io::{BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, RwLock}; use minisql::interpreter::{Response, State}; use parser::parse_and_validate; use proto::handshake::errors::ServerHandshakeError; use proto::handshake::request::HandshakeRequest; use proto::handshake::response::HandshakeResponse; use proto::handshake::server::do_server_handshake; use proto::message::frontend::FrontendMessage; use proto::reader::frontend::FrontendProtoReader; use proto::reader::protoreader::ProtoReader; use proto::writer::backend::BackendProtoWriter; use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use crate::cancellation::ResetCancelToken; use crate::config::Configuration; use crate::proto_wrapper::{CompleteStatus, ServerProto}; mod config; mod proto_wrapper; mod cancellation; type TokenStore = Arc>>; type SharedDbState = Arc>; #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Configuration::parse(); let state = Arc::new(RwLock::new(State::new())); let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new())); let addr = config.get_socket_address(); let listener = TcpListener::bind(&addr).await?; println!("Server started at {addr}"); loop { let state = state.clone(); let tokens = tokens.clone(); let (socket, _) = listener.accept().await?; println!("New client connected: {}", socket.peer_addr()?); tokio::spawn(async move { let reason = handle_stream(socket, state, tokens).await; println!("Client disconnected: {reason:?}"); }); } } async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore) -> anyhow::Result<()> { let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); // Create a token with random PID and key let (pid, key, token) = create_token(&tokens).await?; // Handle handshake let response = HandshakeResponse::new("minisql", pid, key); let request = do_server_handshake(&mut writer, &mut reader, response).await; let result = match request { Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await, Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await, Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)), }; // Release cancellation token let mut tokens = tokens.lock().await; tokens.remove(&(pid, key)); result } fn random_pid_key() -> (i32, i32) { let pid = rand::random::(); let key = rand::random::(); (pid, key) } async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> { let token = ResetCancelToken::new(); let mut tokens = tokens.lock().await; loop { let pid_key = random_pid_key(); if !tokens.contains_key(&pid_key) { tokens.insert(pid_key, token.clone()); let (pid, key) = pid_key; return Ok((pid, key, token)); } } } async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> { println!("Cancel request, PID: {}, Key: {}", pid, key); let tokens = tokens.lock().await; let token = tokens.get(&(pid, key)); match token { Some(t) => t.cancel(), None => return Err(anyhow::anyhow!("Invalid PID and Key cancel combination")), } Ok(()) } async fn handle_connection(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken) -> anyhow::Result<()> where R: FrontendProtoReader + Send, W: BackendProtoWriter + ProtoFlush + Send, { println!("Client connected: {:?}", request); loop { let next: FrontendMessage = reader.read_proto().await?; match next { FrontendMessage::Terminate => { break; } FrontendMessage::Query(data) => { let result = handle_query(writer, &state, data.query.into(), &token).await; match result { Ok(_) => {} Err(e) => { writer.write_error_message(&e.to_string()).await? } } writer.write_ready_for_query().await?; } } writer.flush().await?; } Ok(()) } async fn handle_query(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken) -> anyhow::Result<()> where W: BackendProtoWriter + ProtoFlush + Send, { let operation = { let state = state.read().await; let db_schema = state.db_schema(); parse_and_validate(query, &db_schema)? }; let mut state = state.write().await; let response = state.interpret(operation)?; match response { Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?, Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?, Response::Selected(schema, mut rows) => { match rows.next() { Some(row) => { writer.write_table_header(&schema, &row).await?; writer.write_table_row(&row).await?; let mut sent_rows = 1; for row in rows { sent_rows += 1; writer.write_table_row(&row).await?; if token.is_canceled() { token.reset(); break; } } writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?; } _ => { writer.write_command_complete(CompleteStatus::Select(0)).await?; } } } _ => {} } Ok(()) }