minisql/server/src/main.rs
2024-01-28 22:40:41 +01:00

261 lines
7.9 KiB
Rust

use std::collections::HashMap;
use std::io::ErrorKind;
use std::sync::Arc;
use clap::Parser;
use tokio::io::{BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, RwLock};
use minisql::interpreter::{Response, State};
use parser::parse_and_validate;
use proto::handshake::errors::ServerHandshakeError;
use proto::handshake::request::HandshakeRequest;
use proto::handshake::response::HandshakeResponse;
use proto::handshake::server::do_server_handshake;
use proto::message::frontend::FrontendMessage;
use proto::reader::frontend::FrontendProtoReader;
use proto::reader::protoreader::ProtoReader;
use proto::writer::backend::BackendProtoWriter;
use proto::writer::protowriter::{ProtoFlush, ProtoWriter};
use crate::cancellation::ResetCancelToken;
use crate::config::Configuration;
use crate::persistence::state_to_file;
use crate::proto_wrapper::{CompleteStatus, ServerProto};
mod cancellation;
mod config;
mod persistence;
mod proto_wrapper;
/// Cancellation tokens of live connections, keyed by the (PID, secret key) pair that
/// is handed to each client in its handshake response.
type TokenStore = Arc<Mutex<HashMap<(i32, i32), ResetCancelToken>>>;
/// Database state shared by every connection task.
type SharedDbState = Arc<RwLock<State>>;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let config = Configuration::parse();
let config = Arc::new(config);
let state = Arc::new(RwLock::new(get_state(&config).await?));
let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new()));
let addr = config.get_socket_address();
let listener = TcpListener::bind(&addr).await?;
println!("Server started at {addr}");
loop {
let config = config.clone();
let state = state.clone();
let tokens = tokens.clone();
let (socket, _) = listener.accept().await?;
println!("New client connected: {}", socket.peer_addr()?);
tokio::spawn(async move {
let reason = handle_stream(socket, state, tokens, config).await;
println!("Client disconnected: {reason:?}");
});
}
}
async fn get_state(config: &Configuration) -> anyhow::Result<State> {
let result = persistence::state_from_file(config.get_file_path()).await;
match result {
Err(ref e) if e.kind() == ErrorKind::NotFound => {
println!("WARNING: No DB state file found, creating new one");
Ok(State::new())
}
Err(e) => Err(e)?,
Ok(state) => Ok(state),
}
}
async fn handle_stream(
mut stream: TcpStream,
state: SharedDbState,
tokens: TokenStore,
config: Arc<Configuration>,
) -> anyhow::Result<()> {
let (reader, writer) = stream.split();
let mut writer = ProtoWriter::new(BufWriter::new(writer));
let mut reader = ProtoReader::new(BufReader::new(reader), 1024);
// Create a token with random PID and key
let (pid, key, token) = create_token(&tokens).await?;
// Handle handshake
let response = HandshakeResponse::new("minisql", pid, key);
let request = do_server_handshake(&mut writer, &mut reader, response).await;
let result = match request {
Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token, config).await,
Err(ServerHandshakeError::IsCancelRequest(cancel)) => {
handle_cancellation(cancel.pid, cancel.secret, &tokens).await
}
Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)),
};
// Release cancellation token
let mut tokens = tokens.lock().await;
tokens.remove(&(pid, key));
result
}
/// Draws a random `(pid, key)` pair; the PID is drawn first, then the key.
fn random_pid_key() -> (i32, i32) {
    (rand::random::<i32>(), rand::random::<i32>())
}
/// Registers a new cancellation token under a random (PID, key) pair not yet in use.
///
/// Random pairs are drawn until one is found that is absent from `tokens`. The store's
/// lock is held for the whole search and insertion. Returns the chosen PID, the key,
/// and a handle to the inserted token.
async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> {
    use std::collections::hash_map::Entry;

    let token = ResetCancelToken::new();
    let mut registry = tokens.lock().await;
    let (pid, key) = loop {
        let candidate = random_pid_key();
        if let Entry::Vacant(slot) = registry.entry(candidate) {
            slot.insert(token.clone());
            break candidate;
        }
    };
    Ok((pid, key, token))
}
/// Handles a cancel request by cancelling the token registered under `(pid, key)`.
///
/// # Errors
/// Returns an error if no token is registered under that PID/key pair.
async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> {
    println!("Cancel request, PID: {}, Key: {}", pid, key);
    let registry = tokens.lock().await;
    registry
        .get(&(pid, key))
        .ok_or_else(|| anyhow::anyhow!("Invalid PID and Key cancel combination"))?
        .cancel();
    Ok(())
}
/// Runs the message loop of an established session until the client sends `Terminate`.
///
/// Each `Query` message is executed with `handle_query`. A failing query is reported
/// to the client as an error message instead of ending the session. Every query is
/// followed by a ready-for-query message, and the writer is flushed afterwards.
///
/// # Errors
/// Returns an error if reading from or writing to the client fails.
async fn handle_connection<R, W>(
    reader: &mut R,
    writer: &mut W,
    request: HandshakeRequest,
    state: SharedDbState,
    token: ResetCancelToken,
    config: Arc<Configuration>,
) -> anyhow::Result<()>
where
    R: FrontendProtoReader + Send,
    W: BackendProtoWriter + ProtoFlush + Send,
{
    println!("Client connected: {:?}", request);
    loop {
        let message: FrontendMessage = reader.read_proto().await?;
        let data = match message {
            FrontendMessage::Terminate => return Ok(()),
            FrontendMessage::Query(data) => data,
        };

        // Query failures are reported to the client; the session stays open.
        if let Err(e) = handle_query(writer, &state, data.query.into(), &token, &config).await {
            writer.write_error_message(&e.to_string()).await?;
        }
        writer.write_ready_for_query().await?;
        writer.flush().await?;
    }
}
/// Parses, validates and executes one SQL query, writing its result to `writer`.
///
/// Validation and execution happen under a single write-lock acquisition, so the
/// schema a query was validated against cannot change before it runs. Statements
/// that modify the database are persisted to `config.get_file_path()` while the lock
/// is still held (downgraded to a read lock). This keeps other writers out until the
/// file is written, so two connections can never write the state file at the same time.
///
/// For SELECT results, `token` is checked after each row is written. Once a
/// cancellation is seen, the token is reset, the remaining rows are skipped, and the
/// number of rows actually sent is reported.
///
/// # Errors
/// Returns an error if any of these fail: parsing/validation, interpretation,
/// writing to the client, or persisting the state.
async fn handle_query<W>(
    writer: &mut W,
    state: &SharedDbState,
    query: String,
    token: &ResetCancelToken,
    config: &Arc<Configuration>,
) -> anyhow::Result<()>
where
    W: BackendProtoWriter + ProtoFlush + Send,
{
    // Make sure a cancellation aimed at an earlier query does not affect this one.
    token.reset();

    // NOTE(review): the write lock is held while results are streamed to the client,
    // so a slow client stalls every other connection (pre-existing behavior).
    let mut db = state.write().await;

    // Validate against exactly the schema the query will be executed on.
    let operation = {
        let db_schema = db.db_schema();
        parse_and_validate(query, &db_schema)?
    };

    // `response` may borrow from the guard; keep it in its own scope so the guard
    // can be downgraded afterwards.
    let need_write = {
        let response = db.interpret(operation)?;
        match response {
            Response::Deleted(i) => {
                writer
                    .write_command_complete(CompleteStatus::Delete(i))
                    .await?;
                true
            }
            Response::Inserted => {
                writer
                    .write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 })
                    .await?;
                true
            }
            Response::Selected(schema, columns, rows) => {
                writer.write_table_header(schema, &columns).await?;
                let mut sent_rows = 0;
                for row in rows {
                    writer.write_table_row(&row).await?;
                    sent_rows += 1;
                    if token.is_canceled() {
                        token.reset();
                        break;
                    }
                }
                writer
                    .write_command_complete(CompleteStatus::Select(sent_rows))
                    .await?;
                false
            }
            Response::TableCreated => {
                writer
                    .write_command_complete(CompleteStatus::CreateTable)
                    .await?;
                true
            }
            Response::IndexCreated => {
                writer
                    .write_command_complete(CompleteStatus::CreateIndex)
                    .await?;
                true
            }
        }
    };

    if need_write {
        // Atomic downgrade: no other writer can modify the state (and then persist
        // it) until this snapshot has been written, so persists never overlap.
        let db = db.downgrade();
        state_to_file(&db, config.get_file_path()).await?;
    }
    Ok(())
}