188 lines
6.1 KiB
Rust
188 lines
6.1 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
use clap::Parser;
|
|
use tokio::io::{BufReader, BufWriter};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::{Mutex, RwLock};
|
|
|
|
use minisql::interpreter::{Response, State};
|
|
use parser::parse_and_validate;
|
|
use proto::handshake::errors::ServerHandshakeError;
|
|
use proto::handshake::request::HandshakeRequest;
|
|
use proto::handshake::response::HandshakeResponse;
|
|
use proto::handshake::server::do_server_handshake;
|
|
use proto::message::frontend::FrontendMessage;
|
|
use proto::reader::frontend::FrontendProtoReader;
|
|
use proto::reader::protoreader::ProtoReader;
|
|
use proto::writer::backend::BackendProtoWriter;
|
|
use proto::writer::protowriter::{ProtoFlush, ProtoWriter};
|
|
|
|
use crate::cancellation::ResetCancelToken;
|
|
use crate::config::Configuration;
|
|
use crate::proto_wrapper::{CompleteStatus, ServerProto};
|
|
|
|
mod config;
|
|
mod proto_wrapper;
|
|
mod cancellation;
|
|
|
|
/// Shared, async-mutex-guarded map from a `(pid, key)` pair to the cancel token registered under it.
type TokenStore = Arc<Mutex<HashMap<(i32, i32), ResetCancelToken>>>;
|
|
/// Interpreter [`State`] shared between connection tasks behind an async read-write lock.
type SharedDbState = Arc<RwLock<State>>;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
let config = Configuration::parse();
|
|
|
|
let state = Arc::new(RwLock::new(State::new()));
|
|
let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new()));
|
|
|
|
let addr = config.get_socket_address();
|
|
let listener = TcpListener::bind(&addr).await?;
|
|
println!("Server started at {addr}");
|
|
|
|
loop {
|
|
let state = state.clone();
|
|
let tokens = tokens.clone();
|
|
|
|
let (socket, _) = listener.accept().await?;
|
|
println!("New client connected: {}", socket.peer_addr()?);
|
|
tokio::spawn(async move {
|
|
let reason = handle_stream(socket, state, tokens).await;
|
|
println!("Client disconnected: {reason:?}");
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn handle_stream(mut stream: TcpStream, state: SharedDbState, tokens: TokenStore) -> anyhow::Result<()> {
|
|
let (reader, writer) = stream.split();
|
|
let mut writer = ProtoWriter::new(BufWriter::new(writer));
|
|
let mut reader = ProtoReader::new(BufReader::new(reader), 1024);
|
|
|
|
// Create a token with random PID and key
|
|
let (pid, key, token) = create_token(&tokens).await?;
|
|
|
|
// Handle handshake
|
|
let response = HandshakeResponse::new("minisql", pid, key);
|
|
let request = do_server_handshake(&mut writer, &mut reader, response).await;
|
|
|
|
let result = match request {
|
|
Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await,
|
|
Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await,
|
|
Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)),
|
|
};
|
|
|
|
// Release cancellation token
|
|
let mut tokens = tokens.lock().await;
|
|
tokens.remove(&(pid, key));
|
|
|
|
result
|
|
}
|
|
|
|
fn random_pid_key() -> (i32, i32) {
|
|
let pid = rand::random::<i32>();
|
|
let key = rand::random::<i32>();
|
|
(pid, key)
|
|
}
|
|
|
|
async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> {
|
|
let token = ResetCancelToken::new();
|
|
let mut tokens = tokens.lock().await;
|
|
loop {
|
|
let pid_key = random_pid_key();
|
|
if !tokens.contains_key(&pid_key) {
|
|
tokens.insert(pid_key, token.clone());
|
|
|
|
let (pid, key) = pid_key;
|
|
return Ok((pid, key, token));
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handles a cancel request by cancelling the token registered under
/// `(pid, key)`.
///
/// # Errors
///
/// Fails when no token is registered under the given pair.
async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> {
    println!("Cancel request, PID: {}, Key: {}", pid, key);

    let registry = tokens.lock().await;
    registry
        .get(&(pid, key))
        .ok_or_else(|| anyhow::anyhow!("Invalid PID and Key cancel combination"))?
        .cancel();

    Ok(())
}
|
|
|
|
async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: SharedDbState, token: ResetCancelToken) -> anyhow::Result<()>
|
|
where
|
|
R: FrontendProtoReader + Send,
|
|
W: BackendProtoWriter + ProtoFlush + Send,
|
|
{
|
|
println!("Client connected: {:?}", request);
|
|
|
|
loop {
|
|
let next: FrontendMessage = reader.read_proto().await?;
|
|
|
|
match next {
|
|
FrontendMessage::Terminate => {
|
|
break;
|
|
}
|
|
FrontendMessage::Query(data) => {
|
|
let result = handle_query(writer, &state, data.query.into(), &token).await;
|
|
match result {
|
|
Ok(_) => {}
|
|
Err(e) => {
|
|
writer.write_error_message(&e.to_string()).await?
|
|
}
|
|
}
|
|
writer.write_ready_for_query().await?;
|
|
}
|
|
}
|
|
writer.flush().await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, token: &ResetCancelToken) -> anyhow::Result<()>
|
|
where
|
|
W: BackendProtoWriter + ProtoFlush + Send,
|
|
{
|
|
let operation = {
|
|
let state = state.read().await;
|
|
let db_schema = state.db_schema();
|
|
parse_and_validate(query, &db_schema)?
|
|
};
|
|
|
|
let mut state = state.write().await;
|
|
let response = state.interpret(operation)?;
|
|
|
|
match response {
|
|
Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?,
|
|
Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?,
|
|
Response::Selected(schema, mut rows) => {
|
|
match rows.next() {
|
|
Some(row) => {
|
|
writer.write_table_header(&schema, &row).await?;
|
|
writer.write_table_row(&row).await?;
|
|
|
|
let mut sent_rows = 1;
|
|
for row in rows {
|
|
sent_rows += 1;
|
|
writer.write_table_row(&row).await?;
|
|
if token.is_canceled() {
|
|
token.reset();
|
|
break;
|
|
}
|
|
}
|
|
|
|
writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?;
|
|
}
|
|
_ => {
|
|
writer.write_command_complete(CompleteStatus::Select(0)).await?;
|
|
}
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
Ok(())
|
|
}
|