feat: finish work on server

This commit is contained in:
Jindřich Moravec 2024-01-25 23:07:27 +01:00
parent 7b79dd69b4
commit 51ed3bbc5c
9 changed files with 356 additions and 145 deletions

50
Cargo.lock generated
View file

@ -210,6 +210,17 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "getrandom"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]] [[package]]
name = "gimli" name = "gimli"
version = "0.28.1" version = "0.28.1"
@ -256,6 +267,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"bimap", "bimap",
"thiserror", "thiserror",
"tokio",
] ]
[[package]] [[package]]
@ -326,6 +338,12 @@ version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.70" version = "1.0.70"
@ -354,6 +372,36 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.4.1" version = "0.4.1"
@ -400,9 +448,11 @@ name = "server"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"clap", "clap",
"minisql", "minisql",
"proto", "proto",
"rand",
"tokio", "tokio",
] ]

View file

@ -8,3 +8,4 @@ edition = "2021"
[dependencies] [dependencies]
bimap = "0.6.3" bimap = "0.6.3"
thiserror = "1.0.50" thiserror = "1.0.50"
tokio = { version = "1.35.1", features = ["sync"] }

View file

@ -1,3 +1,4 @@
use std::sync::Arc;
use crate::error::Error; use crate::error::Error;
use crate::internals::row::ColumnPosition; use crate::internals::row::ColumnPosition;
use crate::schema::{TableName, TableSchema}; use crate::schema::{TableName, TableSchema};
@ -6,6 +7,7 @@ use crate::operation::{ColumnSelection, Condition, Operation};
use crate::result::DbResult; use crate::result::DbResult;
use crate::type_system::{DbType, IndexableValue, Value}; use crate::type_system::{DbType, IndexableValue, Value};
use bimap::BiMap; use bimap::BiMap;
use tokio::sync::Mutex;
use crate::restricted_row::RestrictedRow; use crate::restricted_row::RestrictedRow;
// Use `TablePosition` as index // Use `TablePosition` as index
@ -21,7 +23,7 @@ pub struct State {
// #[derive(Debug)] // #[derive(Debug)]
pub enum Response<'a> { pub enum Response<'a> {
Selected(&'a TableSchema, Box<dyn Iterator<Item=RestrictedRow> + 'a>), Selected(&'a TableSchema, Arc<Mutex<dyn Iterator<Item=RestrictedRow> + 'a + Send>>),
Inserted, Inserted,
Deleted(usize), // how many were deleted Deleted(usize), // how many were deleted
TableCreated, TableCreated,
@ -49,7 +51,7 @@ impl std::fmt::Debug for Response<'_> {
} }
impl State { impl State {
fn new() -> Self { pub fn new() -> Self {
Self { Self {
table_name_position_mapping: BiMap::new(), table_name_position_mapping: BiMap::new(),
tables: vec![], tables: vec![],
@ -100,7 +102,7 @@ impl State {
let selected_rows = match maybe_condition { let selected_rows = match maybe_condition {
None => { None => {
let x = table.select_all_rows(selected_column_positions); let x = table.select_all_rows(selected_column_positions);
Box::new(x) as Box<dyn Iterator<Item=RestrictedRow> + 'a> Arc::new(Mutex::new(x)) as Arc<Mutex<dyn Iterator<Item=RestrictedRow> + 'a + Send>>
}, },
Some(Condition::Eq(eq_column_name, value)) => { Some(Condition::Eq(eq_column_name, value)) => {
@ -113,7 +115,7 @@ impl State {
eq_column_position, eq_column_position,
value, value,
)?; )?;
Box::new(x) as Box<dyn Iterator<Item=RestrictedRow> + 'a> Arc::new(Mutex::new(x)) as Arc<Mutex<dyn Iterator<Item=RestrictedRow> + 'a + Send>>
} }
}; };

View file

@ -6,6 +6,7 @@ use crate::type_system::Value;
// Perhaps consider factoring the table name out // Perhaps consider factoring the table name out
// and think of the operations as operating on a unique table. // and think of the operations as operating on a unique table.
// TODO: `TableName` should be replaced by `TablePosition` // TODO: `TableName` should be replaced by `TablePosition`
#[derive(Debug)]
pub enum Operation { pub enum Operation {
Select(TableName, ColumnSelection, Option<Condition>), Select(TableName, ColumnSelection, Option<Condition>),
Insert(TableName, InsertionValues), Insert(TableName, InsertionValues),
@ -18,11 +19,13 @@ pub enum Operation {
pub type InsertionValues = Vec<(ColumnName, Value)>; pub type InsertionValues = Vec<(ColumnName, Value)>;
#[derive(Debug)]
pub enum ColumnSelection { pub enum ColumnSelection {
All, All,
Columns(Vec<ColumnName>), Columns(Vec<ColumnName>),
} }
#[derive(Debug)]
pub enum Condition { pub enum Condition {
// And(Box<Condition>, Box<Condition>), // And(Box<Condition>, Box<Condition>),
// Or(Box<Condition>, Box<Condition>), // Or(Box<Condition>, Box<Condition>),

View file

@ -20,7 +20,7 @@ pub type TableName = String;
pub type ColumnName = String; pub type ColumnName = String;
impl TableSchema { impl TableSchema {
pub(crate) fn new(table_name: TableName, primary_key: ColumnPosition, column_name_position_map: Vec<(ColumnName, ColumnPosition)>, types: Vec<DbType>) -> Self { pub fn new(table_name: TableName, primary_key: ColumnPosition, column_name_position_map: Vec<(ColumnName, ColumnPosition)>, types: Vec<DbType>) -> Self {
let mut column_name_position_mapping: BiMap<ColumnName, ColumnPosition> = BiMap::new(); let mut column_name_position_mapping: BiMap<ColumnName, ColumnPosition> = BiMap::new();
for (column_name, column_position) in column_name_position_map { for (column_name, column_position) in column_name_position_map {
column_name_position_mapping.insert(column_name, column_position); column_name_position_mapping.insert(column_name, column_position);

View file

@ -11,5 +11,5 @@ clap = { version = "4.4.18", features = ["derive"] }
tokio = { version = "1.35.1", features = ["full"] } tokio = { version = "1.35.1", features = ["full"] }
minisql = { path = "../minisql" } minisql = { path = "../minisql" }
proto = { path = "../proto" } proto = { path = "../proto" }
async-trait = "0.1.74"
rand = "0.8.5"

View file

@ -1,49 +1,121 @@
mod config; use std::collections::HashMap;
use std::sync::Arc;
use std::net::SocketAddr;
use clap::Parser; use clap::Parser;
use tokio::io::{BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, RwLock};
use minisql::interpreter::{Response, State};
use proto::handshake::errors::ServerHandshakeError;
use proto::handshake::request::HandshakeRequest;
use proto::handshake::response::HandshakeResponse; use proto::handshake::response::HandshakeResponse;
use proto::handshake::server::do_server_handshake; use proto::handshake::server::do_server_handshake;
use proto::message::backend::{
BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData,
ReadyForQueryData, RowDescriptionData,
};
use proto::message::frontend::FrontendMessage; use proto::message::frontend::FrontendMessage;
use proto::reader::oneway::OneWayProtoReader; use proto::reader::frontend::FrontendProtoReader;
use proto::reader::protoreader::ProtoReader; use proto::reader::protoreader::ProtoReader;
use proto::writer::backend::BackendProtoWriter; use proto::writer::backend::BackendProtoWriter;
use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use proto::writer::protowriter::{ProtoFlush, ProtoWriter};
use tokio::io::{BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream}; use crate::cancellation::ResetCancelToken;
use crate::config::Configuration; use crate::config::Configuration;
use crate::parser_stub::parse_query;
use crate::proto_wrapper::{CompleteStatus, ServerProto};
mod config;
mod proto_wrapper;
mod cancellation;
mod parser_stub;
type TokenStore = Arc<Mutex<HashMap<(i32, i32), ResetCancelToken>>>;
type DbState = Arc<RwLock<State>>;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let config = Configuration::parse(); let config = Configuration::parse();
let state = Arc::new(RwLock::new(State::new()));
let tokens = Arc::new(Mutex::new(HashMap::<(i32, i32), ResetCancelToken>::new()));
let addr = config.get_socket_address(); let addr = config.get_socket_address();
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
println!("Server started at {addr}"); println!("Server started at {addr}");
loop { loop {
let state = state.clone();
let tokens = tokens.clone();
let (pid, key) = random_pid_key();
let (socket, _) = listener.accept().await?; let (socket, _) = listener.accept().await?;
println!("New client connected: {}", socket.peer_addr()?); println!("New client connected: {}", socket.peer_addr()?);
tokio::spawn(async move { tokio::spawn(async move {
let reason = handle_stream(socket).await; let reason = handle_stream(socket, state, tokens).await;
println!("Client disconnected: {reason:?}"); println!("Client disconnected: {reason:?}");
}); });
} }
} }
async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { async fn handle_stream(mut stream: TcpStream, state: DbState, tokens: TokenStore) -> anyhow::Result<()> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut writer = ProtoWriter::new(BufWriter::new(writer));
let mut reader = ProtoReader::new(BufReader::new(reader), 1024); let mut reader = ProtoReader::new(BufReader::new(reader), 1024);
let response = HandshakeResponse::new("minisql", 123, 123); // Create a token with random PID and key
let (pid, key, token) = create_token(&tokens).await?;
let request = do_server_handshake(&mut writer, &mut reader, response).await?; // Handle handshake
let response = HandshakeResponse::new("minisql", pid, key);
let request = do_server_handshake(&mut writer, &mut reader, response).await;
let result = match request {
Ok(req) => handle_connection(&mut reader, &mut writer, req, state, token).await,
Err(ServerHandshakeError::IsCancelRequest(cancel)) => handle_cancellation(cancel.pid, cancel.secret, &tokens).await,
Err(e) => Err(anyhow::anyhow!("Error during handshake: {:?}", e)),
};
// Release cancellation token
let mut tokens = tokens.lock().await;
tokens.remove(&(pid, key));
result
}
fn random_pid_key() -> (i32, i32) {
let pid = rand::random::<i32>();
let key = rand::random::<i32>();
(pid, key)
}
async fn create_token(tokens: &TokenStore) -> anyhow::Result<(i32, i32, ResetCancelToken)> {
let token = ResetCancelToken::new();
let mut tokens = tokens.lock().await;
loop {
let pid_key = random_pid_key();
if !tokens.contains_key(&pid_key) {
tokens.insert(pid_key, token.clone());
let (pid, key) = pid_key;
return Ok((pid, key, token));
}
}
}
async fn handle_cancellation(pid: i32, key: i32, tokens: &TokenStore) -> anyhow::Result<()> {
let tokens = tokens.lock().await;
let token = tokens.get(&(pid, key));
match token {
Some(t) => t.cancel(),
None => return Err(anyhow::anyhow!("Invalid PID and Key cancel combination")),
}
Ok(())
}
async fn handle_connection<R, W>(reader: &mut R, writer: &mut W, request: HandshakeRequest, state: DbState, token: ResetCancelToken) -> anyhow::Result<()>
where
R: FrontendProtoReader + Send,
W: BackendProtoWriter + ProtoFlush + Send,
{
println!("Handshake complete:\n{request:?}"); println!("Handshake complete:\n{request:?}");
loop { loop {
@ -57,17 +129,48 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> {
} }
FrontendMessage::Query(data) => { FrontendMessage::Query(data) => {
println!("Received Query: {:?}", data); println!("Received Query: {:?}", data);
if data.query.as_str().contains("car") { let operation = parse_query(data.query.as_str());
println!("Sending error message"); println!("Parsed query: {:?}", operation);
send_error_response(&mut writer, "Car not found").await?;
} else if data.query.as_str().to_lowercase().contains("select") { let mut state = state.write().await;
println!("Sending table"); let result = state.interpret(operation);
send_query_response(&mut writer).await?; println!("Result: {:?}", result);
} else {
println!("Sending empty query"); match result {
send_empty_query(&mut writer).await?; Err(e) => {
writer.write_error_message(&format!("Error: {:?}", e)).await?;
} }
send_ready_for_query(&mut writer).await?; Ok(res) => {
match res {
Response::Deleted(i) => writer.write_command_complete(CompleteStatus::Delete(i)).await?,
Response::Inserted => writer.write_command_complete(CompleteStatus::Insert { oid: 0, rows: 1 }).await?,
Response::Selected(schema, rows) => {
let mut rows = rows.lock().await;
let first_row = rows.next();
match first_row {
Some(row) => {
writer.write_table_header(&schema, &row).await?;
writer.write_table_row(&row).await?;
let mut sent_rows = 1;
while let Some(row) = rows.next() {
writer.write_table_row(&row).await?;
sent_rows += 1;
}
writer.write_command_complete(CompleteStatus::Select(sent_rows)).await?;
}
None => {
writer.write_command_complete(CompleteStatus::Select(0)).await?;
}
}
}
_ => {}
}
}
}
writer.write_ready_for_query().await?;
} }
} }
writer.flush().await?; writer.flush().await?;
@ -75,118 +178,3 @@ async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn send_error_response(
writer: &mut impl BackendProtoWriter,
error_message: &str,
) -> anyhow::Result<()> {
writer
.write_proto(
ErrorResponseData {
code: b'M',
message: error_message.to_string().into(),
}
.into(),
)
.await?;
Ok(())
}
async fn send_ready_for_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> {
writer
.write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' }))
.await?;
Ok(())
}
async fn send_empty_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> {
writer
.write_proto(BackendMessage::EmptyQueryResponse)
.await?;
Ok(())
}
async fn send_row_description(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> {
let columns = vec![
ColumnDescription {
name: "id".to_string().into(),
table_oid: 123,
column_index: 1,
type_oid: 23,
type_size: 4,
type_modifier: -1,
format_code: 0,
},
ColumnDescription {
name: "argument".to_string().into(),
table_oid: 123,
column_index: 2,
type_oid: 23,
type_size: 4,
type_modifier: -1,
format_code: 0,
},
ColumnDescription {
name: "description".to_string().into(),
table_oid: 123,
column_index: 3,
type_oid: 1043,
type_size: 32,
type_modifier: -1,
format_code: 0,
},
];
writer
.write_proto(
RowDescriptionData {
columns: columns.into(),
}
.into(),
)
.await?;
Ok(())
}
async fn send_query_response(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> {
send_row_description(writer).await?;
write_row(writer, b"0", b"1337", b"auto").await?;
write_row(writer, b"1", b"69", b"bus").await?;
write_row(writer, b"2", b"420", b"kolo").await?;
writer
.write_proto(
CommandCompleteData {
tag: "SELECT 3".to_string().into(),
}
.into(),
)
.await?;
Ok(())
}
async fn write_row(
writer: &mut impl BackendProtoWriter,
first: &[u8],
second: &[u8],
third: &[u8],
) -> anyhow::Result<()> {
let row_data = vec![
first.to_vec().into(),
second.to_vec().into(),
third.to_vec().into(),
]
.into();
writer
.write_proto(DataRowData { columns: row_data }.into())
.await?;
Ok(())
}

63
server/src/parser_stub.rs Normal file
View file

@ -0,0 +1,63 @@
use minisql::operation::{ColumnSelection, Operation};
use minisql::schema::TableSchema;
use minisql::type_system::{DbType, IndexableValue, Value};
const TABLE_NAME: &'static str = "tablus";
static mut ID_COUNTER: u64 = 0;
pub fn parse_query(query: &str) -> Operation {
if query.contains("select") {
if query.contains("*") {
Operation::Select(TABLE_NAME.to_string(), ColumnSelection::All, None)
} else {
Operation::Select(TABLE_NAME.to_string(), ColumnSelection::Columns(vec![
"name".to_string(),
"price".to_string(),
]), None)
}
} else if query.contains("insert") {
let id = unsafe {
ID_COUNTER += 1;
ID_COUNTER
};
let rand_rak = rand::random::<u8>();
let rand_price = rand::random::<f64>();
Operation::Insert(TABLE_NAME.to_string(), vec![
("id".to_string(), Value::Indexable(IndexableValue::Uuid(id))),
("name".to_string(), Value::Indexable(IndexableValue::String(format!("Car {}", rand_rak)))),
("price".to_string(), Value::Number(rand_price)),
("mileage".to_string(), Value::Indexable(IndexableValue::Int(1234))),
])
} else if query.contains("delete") {
Operation::Delete(TABLE_NAME.to_string(), None)
} else if query.contains("create table") {
Operation::CreateTable(TABLE_NAME.to_string(), get_cars_schema())
} else if query.contains("create index") {
Operation::CreateIndex(TABLE_NAME.to_string(), "price".to_string())
} else {
panic!("Unknown query: {}", query);
}
}
fn get_cars_schema() -> TableSchema {
TableSchema::new(
"cars".to_string(),
0,
vec![
("id".to_string(), 0),
("name".to_string(), 1),
("price".to_string(), 2),
("mileage".to_string(), 3),
],
vec![
DbType::Uuid,
DbType::String,
DbType::Number,
DbType::Int,
]
)
}

104
server/src/proto_wrapper.rs Normal file
View file

@ -0,0 +1,104 @@
use async_trait::async_trait;
use minisql::restricted_row::RestrictedRow;
use minisql::schema::TableSchema;
use minisql::type_system::{Value};
use proto::message::backend::{BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, ReadyForQueryData, RowDescriptionData};
use proto::message::primitive::pglist::PgList;
use proto::writer::backend::BackendProtoWriter;
pub enum CompleteStatus {
Insert {
oid: i32,
rows: i32,
},
Delete(usize),
Select(usize),
}
impl CompleteStatus {
fn to_string(&self) -> String {
match self {
CompleteStatus::Insert { oid, rows } => format!("INSERT {} {}", oid, rows),
CompleteStatus::Delete(rows) => format!("DELETE {}", rows),
CompleteStatus::Select(rows) => format!("SELECT {}", rows),
}
}
}
#[async_trait]
pub trait ServerProto {
async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()>;
async fn write_ready_for_query(&mut self) -> anyhow::Result<()>;
async fn write_empty_query(&mut self) -> anyhow::Result<()>;
async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()>;
async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()>;
async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()>;
}
#[async_trait]
impl<W> ServerProto for W where W: BackendProtoWriter + Send {
async fn write_error_message(&mut self, error_message: &str) -> anyhow::Result<()> {
self.write_proto(ErrorResponseData {
code: b'M',
message: format!("{error_message}\0").into(),
}.into()).await?;
Ok(())
}
async fn write_ready_for_query(&mut self) -> anyhow::Result<()> {
self.write_proto(ReadyForQueryData { status: b'I' }.into()).await?;
Ok(())
}
async fn write_empty_query(&mut self) -> anyhow::Result<()> {
self.write_proto(BackendMessage::EmptyQueryResponse).await?;
Ok(())
}
async fn write_table_header(&mut self, table_schema: &TableSchema, row: &RestrictedRow) -> anyhow::Result<()> {
let columns = row.iter()
.map(|(index, value)| value_to_column_description(table_schema, value, index))
.collect::<anyhow::Result<Vec<ColumnDescription>>>()?;
self.write_proto(RowDescriptionData { columns: columns.into() }.into()).await?;
Ok(())
}
async fn write_table_row(&mut self, row: &RestrictedRow) -> anyhow::Result<()> {
let values = row.iter()
.map(|(_, value)| value.as_text_bytes().into())
.collect::<Vec<PgList<u8, i32>>>();
self.write_proto(BackendMessage::DataRow(DataRowData {
columns: values.into(),
})).await?;
Ok(())
}
async fn write_command_complete(&mut self, status: CompleteStatus) -> anyhow::Result<()> {
self.write_proto(BackendMessage::CommandComplete(CommandCompleteData {
tag: status.to_string().into(),
})).await?;
Ok(())
}
}
fn value_to_column_description(schema: &TableSchema, value: &Value, index: &usize) -> anyhow::Result<ColumnDescription> {
let name = schema.column_name_from_column_position(*index)?;
let table_oid = schema.table_name().as_bytes().as_ptr() as i32;
let column_index = (*index).try_into()?;
let type_oid = value.type_oid();
let type_size = value.type_size();
Ok(ColumnDescription {
name: name.to_string().into(),
table_oid,
column_index,
type_oid,
type_size,
type_modifier: -1,
format_code: 0, // text format
})
}