use std::io::Write; use clap::Parser; use proto::handshake::client::do_client_handshake; use proto::handshake::request::HandshakeRequest; use proto::reader::protoreader::ProtoReader; use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; use tokio::io::{BufReader, BufWriter}; use tokio::net::TcpStream; use proto::message::backend::{BackendMessage, CommandCompleteData, DataRowData, ErrorResponseData, RowDescriptionData}; use proto::message::frontend::{FrontendMessage, QueryData}; use proto::reader::oneway::OneWayProtoReader; use proto::writer::oneway::OneWayProtoWriter; #[derive(Parser)] struct Cli { /// Port number of the server. #[arg(short, long, default_value_t = 5432, help = "Port number of the server")] port: u16, /// Host name or IP address of the server. #[arg(long, default_value = "127.0.0.1", help = "Host name or IP address of the server")] host: String, /// User name sent to the server. #[arg(long, default_value = "minisql user", help = "User name")] username: String, } #[tokio::main] async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); let addr = format!("{}:{}", cli.host, cli.port); println!("Connecting to {}", addr); let mut stream = TcpStream::connect(addr).await?; let (reader, writer) = stream.split(); let mut writer = ProtoWriter::new(BufWriter::new(writer)); let mut reader = ProtoReader::new(BufReader::new(reader), 1024); let request = HandshakeRequest::new(196608) .parameter("user", cli.username.as_str()) .parameter("client_encoding", "UTF8"); let _ = do_client_handshake(&mut writer, &mut reader, request).await?; println!("Connected to the server"); let mut exit = false; let command = prompt()?; if let Some(cmd) = command { writer.write_proto(FrontendMessage::Query(QueryData { query: cmd.into(), })).await?; writer.flush().await?; } else { exit = true; } while !exit { let msg: BackendMessage = reader.read_proto().await?; match msg { BackendMessage::RowDescription(data) => { print_row_description(data); }, BackendMessage::DataRow(data) => { print_row_data(data); }, BackendMessage::CommandComplete(data) => { print_command_complete(data); }, BackendMessage::ErrorResponse(data) => { print_error_response(data); }, BackendMessage::EmptyQueryResponse => { println!("Empty query response"); }, BackendMessage::NoData => { println!("No data"); }, BackendMessage::ReadyForQuery(data) => { println!("Ready for next query ({})", data.status); let command = prompt()?; if let Some(cmd) = command { writer.write_proto(FrontendMessage::Query(QueryData { query: cmd.into(), })).await?; writer.flush().await?; } else { exit = true; } }, m => { println!("Unexpected message: {:?}", m); } } } writer.write_proto(FrontendMessage::Terminate).await?; writer.flush().await?; Ok(()) } fn prompt() -> std::io::Result> { print!("> "); std::io::stdout().flush()?; let mut line = String::new(); let _ = std::io::stdin().read_line(&mut line)?; let line = line.trim(); if line.is_empty() { return prompt(); } if line == "exit" || line == "quit" { return Ok(None); } Ok(Some(line.trim().to_string())) } fn print_error_response(data: ErrorResponseData) { println!("Error with code {}: {}", data.code, data.message.as_str()); } fn print_command_complete(data: CommandCompleteData) { println!("Result: {}", data.tag.as_str()); } fn print_row_description(data: RowDescriptionData) { let mut lengths = vec![]; let columns = Vec::from(data.columns); for i in 0..columns.len() { let column_name = columns[i].name.as_str(); lengths.push(column_name.len()); print!("{}", column_name); if i < columns.len() - 1 { print!(" | "); } } println!(); for i in 0..lengths.len() { for _ in 0..lengths[i] { print!("-"); } if i < lengths.len() - 1 { print!("-+-"); } } println!(); } fn print_row_data(data: DataRowData) { let columns = Vec::from(data.columns); let length = columns.len(); for column in columns.into_iter().enumerate() { let bytes = Vec::from(column.1); let string = String::from_utf8(bytes).unwrap(); print!("{}", string); if column.0 < length - 1 { print!(" | "); } } println!(); }