Merge remote-tracking branch 'origin/main' into clippy-formatting

This commit is contained in:
Yuriy Dupyn 2024-01-28 22:27:03 +01:00
commit b836ba5e04
4 changed files with 122 additions and 32 deletions

View file

@ -1,3 +1,4 @@
use std::io::Write;
use clap::Parser; use clap::Parser;
use proto::handshake::client::do_client_handshake; use proto::handshake::client::do_client_handshake;
use proto::handshake::request::HandshakeRequest; use proto::handshake::request::HandshakeRequest;
@ -5,7 +6,7 @@ use proto::reader::protoreader::ProtoReader;
use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use proto::writer::protowriter::{ProtoFlush, ProtoWriter};
use tokio::io::{BufReader, BufWriter}; use tokio::io::{BufReader, BufWriter};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use proto::message::backend::{BackendMessage, DataRowData, RowDescriptionData}; use proto::message::backend::{BackendMessage, CommandCompleteData, DataRowData, ErrorResponseData, RowDescriptionData};
use proto::message::frontend::{FrontendMessage, QueryData}; use proto::message::frontend::{FrontendMessage, QueryData};
use proto::reader::oneway::OneWayProtoReader; use proto::reader::oneway::OneWayProtoReader;
use proto::writer::oneway::OneWayProtoWriter; use proto::writer::oneway::OneWayProtoWriter;
@ -19,12 +20,17 @@ struct Cli {
/// Host name or IP address of the server. /// Host name or IP address of the server.
#[arg(long, default_value = "127.0.0.1", help = "Host name or IP address of the server")] #[arg(long, default_value = "127.0.0.1", help = "Host name or IP address of the server")]
host: String, host: String,
/// User name sent to the server.
#[arg(long, default_value = "minisql user", help = "User name")]
username: String,
} }
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let cli = Cli::parse(); let cli = Cli::parse();
let addr = format!("{}:{}", cli.host, cli.port); let addr = format!("{}:{}", cli.host, cli.port);
println!("Connecting to {}", addr);
let mut stream = TcpStream::connect(addr).await?; let mut stream = TcpStream::connect(addr).await?;
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
@ -33,42 +39,55 @@ async fn main() -> anyhow::Result<()> {
let mut reader = ProtoReader::new(BufReader::new(reader), 1024); let mut reader = ProtoReader::new(BufReader::new(reader), 1024);
let request = HandshakeRequest::new(196608) let request = HandshakeRequest::new(196608)
.parameter("user", "test user") .parameter("user", cli.username.as_str())
.parameter("client_encoding", "UTF8"); .parameter("client_encoding", "UTF8");
let response = do_client_handshake(&mut writer, &mut reader, request).await?; let _ = do_client_handshake(&mut writer, &mut reader, request).await?;
println!("Connected to the server");
println!("Handshake complete:\n{response:?}"); let mut exit = false;
let command = prompt()?;
if let Some(cmd) = command {
writer.write_proto(FrontendMessage::Query(QueryData {
query: cmd.into(),
})).await?;
writer.flush().await?;
} else {
exit = true;
}
writer.write_proto(FrontendMessage::Query(QueryData { while !exit {
query: "SELECT * FROM users;".to_string().into(),
})).await?;
writer.flush().await?;
let mut line = String::new();
loop {
let msg: BackendMessage = reader.read_proto().await?; let msg: BackendMessage = reader.read_proto().await?;
match msg { match msg {
BackendMessage::RowDescription(data) => { BackendMessage::RowDescription(data) => {
print_header(data); print_row_description(data);
}, },
BackendMessage::DataRow(data) => { BackendMessage::DataRow(data) => {
print_row(data); print_row_data(data);
}, },
BackendMessage::CommandComplete(data) => { BackendMessage::CommandComplete(data) => {
println!("Command complete: {:?}", data); print_command_complete(data);
},
BackendMessage::ErrorResponse(data) => {
print_error_response(data);
},
BackendMessage::EmptyQueryResponse => {
println!("Empty query response");
},
BackendMessage::NoData => {
println!("No data");
}, },
BackendMessage::ReadyForQuery(data) => { BackendMessage::ReadyForQuery(data) => {
println!("Ready for query: {:?}", data); println!("Ready for next query ({})", data.status);
line.clear();
let res = std::io::stdin().read_line(&mut line); let command = prompt()?;
if res.is_ok() { if let Some(cmd) = command {
if line.eq("exit") {
break;
}
writer.write_proto(FrontendMessage::Query(QueryData { writer.write_proto(FrontendMessage::Query(QueryData {
query: line.clone().into(), query: cmd.into(),
})).await?; })).await?;
writer.flush().await?; writer.flush().await?;
} else {
exit = true;
} }
}, },
m => { m => {
@ -83,21 +102,70 @@ async fn main() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
fn print_header(header: RowDescriptionData) { fn prompt() -> std::io::Result<Option<String>> {
print!("Header -> "); print!("> ");
for column in Vec::from(header.columns) { std::io::stdout().flush()?;
print!("{} | ", column.name.as_str());
let mut line = String::new();
let _ = std::io::stdin().read_line(&mut line)?;
let line = line.trim();
if line.is_empty() {
return prompt();
}
if line == "exit" || line == "quit" {
return Ok(None);
}
Ok(Some(line.trim().to_string()))
}
fn print_error_response(data: ErrorResponseData) {
println!("Error with code {}: {}", data.code, data.message.as_str());
}
fn print_command_complete(data: CommandCompleteData) {
println!("Result: {}", data.tag.as_str());
}
fn print_row_description(data: RowDescriptionData) {
let mut lengths = vec![];
let columns = Vec::from(data.columns);
for i in 0..columns.len() {
let column_name = columns[i].name.as_str();
lengths.push(column_name.len());
print!("{}", column_name);
if i < columns.len() - 1 {
print!(" | ");
}
}
println!();
for i in 0..lengths.len() {
for _ in 0..lengths[i] {
print!("-");
}
if i < lengths.len() - 1 {
print!("-+-");
}
} }
println!(); println!();
} }
fn print_row(row: DataRowData) { fn print_row_data(data: DataRowData) {
print!("Row -> "); let columns = Vec::from(data.columns);
for column in Vec::from(row.columns) { let length = columns.len();
let bytes = Vec::from(column); for column in columns.into_iter().enumerate() {
let bytes = Vec::from(column.1);
let string = String::from_utf8(bytes).unwrap(); let string = String::from_utf8(bytes).unwrap();
print!("{} | ", string); print!("{}", string);
if column.0 < length - 1 {
print!(" | ");
}
} }
println!(); println!();
} }

1
demo.json Normal file

File diff suppressed because one or more lines are too long

View file

@ -41,15 +41,33 @@ impl Encode for PgString {
impl Decode for PgString { impl Decode for PgString {
fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> { fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
let mut string = String::new(); let mut bytes = Vec::new();
loop { loop {
let byte = u8::decode(decoder)?; let byte = u8::decode(decoder)?;
if byte == 0 { if byte == 0 {
break; break;
} }
string.push(byte as char); bytes.push(byte);
} }
let string = String::from_utf8(bytes)
.map_err(|e| DecodeError::Utf8 { inner: e.utf8_error() })?;
Ok(PgString(string)) Ok(PgString(string))
} }
} }
#[cfg(test)]
mod tests {
use crate::message::primitive::data::MessageData;
use super::*;
#[test]
fn test_encode_decode_utf8() {
let pg_string = PgString::from("áhój jěžkó");
let encoded = pg_string.serialize().unwrap();
let decoded: PgString = PgString::deserialize(&encoded).unwrap();
let actual = decoded.as_str();
assert_eq!("áhój jěžkó", actual);
}
}

View file

@ -169,6 +169,9 @@ async fn handle_query<W>(writer: &mut W, state: &SharedDbState, query: String, t
where where
W: BackendProtoWriter + ProtoFlush + Send, W: BackendProtoWriter + ProtoFlush + Send,
{ {
// Make sure token is reset before next query
token.reset();
let operation = { let operation = {
let state = state.read().await; let state = state.read().await;
let db_schema = state.db_schema(); let db_schema = state.db_schema();