diff --git a/.gitignore b/.gitignore index 8cf2bff..e28629e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +.idea /target tmp_repl.txt diff --git a/Cargo.lock b/Cargo.lock index 7254440..8b89a08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,473 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "anyhow" +version = "1.0.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59d2a3357dde987206219e78ecfbbb6e8dad06cbb65292758d3270e6254f7355" + +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "bimap" version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" +[[package]] +name = "bincode" +version = "2.0.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f11ea1a0346b94ef188834a65c068a03aec181c94896d481d7a0a40d85b0ce95" +dependencies = [ + "bincode_derive", + "serde", +] + +[[package]] +name = "bincode_derive" +version = "2.0.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e30759b3b99a1b802a7a3aa21c85c3ded5c28e1c83170d82d70f08bbf7f3e4c" +dependencies = [ + "virtue", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "client" +version = "0.1.0" +dependencies = [ + "anyhow", + "proto", + "tokio", +] + +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + +[[package]] +name = "libc" +version = "0.2.151" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + [[package]] name = "minisql" version = "0.1.0" dependencies = [ "bimap", ] + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "proc-macro2" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proto" +version = "0.1.0" +dependencies = [ + "async-trait", + "bincode", + "thiserror", + "tokio", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "server" +version = "0.1.0" +dependencies = [ + "anyhow", + "proto", + "tokio", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio" +version = "1.35.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "virtue" +version = "0.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/Cargo.toml b/Cargo.toml index 3e0b7c2..714e4bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,4 +2,7 @@ resolver = "2" members = [ "minisql", + "proto", + "server", + "client" ] diff --git a/client/Cargo.toml b/client/Cargo.toml new file mode 100644 index 0000000..9cf09e6 --- /dev/null +++ b/client/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "client" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "1.35.1", features = ["full"] } +anyhow = "1.0.76" +proto = { path = "../proto" } \ No newline at end of file diff --git a/client/src/main.rs b/client/src/main.rs new file mode 100644 index 0000000..e77ea5c --- /dev/null +++ b/client/src/main.rs @@ -0,0 +1,80 @@ +use proto::handshake::client::do_client_handshake; +use proto::handshake::request::HandshakeRequest; +use proto::reader::protoreader::ProtoReader; +use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::TcpStream; +use proto::message::backend::{BackendMessage, DataRowData, RowDescriptionData}; +use proto::message::frontend::{FrontendMessage, QueryData}; +use proto::reader::oneway::OneWayProtoReader; +use proto::writer::oneway::OneWayProtoWriter; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let addr = "127.0.0.1:5432"; + + let mut stream = TcpStream::connect(addr).await?; + let (reader, writer) = stream.split(); + + let mut writer = ProtoWriter::new(BufWriter::new(writer)); + let mut reader = ProtoReader::new(BufReader::new(reader), 1024); + + let request = HandshakeRequest::new(196608) + .parameter("user", "test user") + .parameter("client_encoding", "UTF8"); + + let response = do_client_handshake(&mut writer, &mut reader, request).await?; + + println!("Handshake complete:\n{response:?}"); + + writer.write_proto(FrontendMessage::Query(QueryData { + query: "SELECT * FROM users;".to_string().into(), + })).await?; + writer.flush().await?; + + loop { + let msg: BackendMessage = reader.read_proto().await?; + match msg { + BackendMessage::RowDescription(data) => { + print_header(data); + }, + BackendMessage::DataRow(data) => { + print_row(data); + }, + BackendMessage::CommandComplete(data) => { + println!("Command complete: {:?}", data); + }, + BackendMessage::ReadyForQuery(data) => { + println!("Ready for query: {:?}", data); + break; + }, + m => { + println!("Unexpected message: {:?}", m); + } + } + } + + writer.write_proto(FrontendMessage::Terminate).await?; + writer.flush().await?; + + Ok(()) +} + +fn print_header(header: RowDescriptionData) { + print!("Header -> "); + for column in Vec::from(header.columns) { + print!("{} | ", column.name.as_str()); + } + println!(); +} + +fn print_row(row: DataRowData) { + print!("Row -> "); + for column in Vec::from(row.columns) { + let bytes = Vec::from(column); + let string = String::from_utf8(bytes).unwrap(); + + print!("{} | ", string); + } + println!(); +} diff --git a/proto/Cargo.toml b/proto/Cargo.toml new file mode 100644 index 0000000..c8c9134 --- /dev/null +++ b/proto/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "proto" +version = "0.1.0" +edition = "2021" + +[dependencies] +bincode = "2.0.0-rc.3" +tokio = { version = "1.34.0", features = ["io-util", "macros", "test-util"] } +async-trait = "0.1.74" +thiserror = "1.0.50" diff --git a/proto/src/handshake/client.rs b/proto/src/handshake/client.rs new file mode 100644 index 0000000..ff3aaed --- /dev/null +++ b/proto/src/handshake/client.rs @@ -0,0 +1,44 @@ +use crate::handshake::errors::ClientHandshakeError; +use crate::handshake::request::HandshakeRequest; +use crate::handshake::response::HandshakeResponse; +use crate::message::backend::{AuthenticationOkData, BackendMessage}; +use crate::message::special::StartupMessageData; +use crate::reader::backend::BackendProtoReader; +use crate::writer::frontend::FrontendProtoWriter; +use crate::writer::protowriter::ProtoFlush; + +/// Performs client-side handshake with the server until the `ReadyForQuery` message is received. +/// For more info visit the [`55.2.1. Start-up`](https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP) +pub async fn do_client_handshake( + writer: &mut (impl FrontendProtoWriter + ProtoFlush), + reader: &mut impl BackendProtoReader, + request: HandshakeRequest, +) -> Result { + + // Send StartupMessage without SSLRequest + let startup_message: StartupMessageData = request.into(); + writer.write_startup_message(startup_message).await?; + writer.flush().await?; + + // Wait for AuthenticationOk + let auth = reader.read_proto().await?; + if !matches!( + auth, + BackendMessage::AuthenticationOk(AuthenticationOkData { status: 0 }) + ) { + return Err(ClientHandshakeError::UnexpectedAuthResponse(auth)); + } + + // Read server parameter messages until ReadyForQuery is received + let mut messages = Vec::new(); + loop { + let msg = reader.read_proto().await?; + if matches!(msg, BackendMessage::ReadyForQuery(_)) { + break; + } + + messages.push(msg); + } + + HandshakeResponse::try_from(messages.as_slice()) +} diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs new file mode 100644 index 0000000..0811790 --- /dev/null +++ b/proto/src/handshake/errors.rs @@ -0,0 +1,36 @@ +use crate::message::backend::BackendMessage; +use crate::message::errors::ProtoDeserializeError; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; +use crate::writer::errors::ProtoWriteError; +use thiserror::Error; +use tokio::io; + +#[derive(Debug, Error)] +pub enum ClientHandshakeError { + #[error("unexpected response from server")] + UnexpectedResponse, + #[error("unexpected auth response")] + UnexpectedAuthResponse(BackendMessage), + #[error("socket communication failed")] + Io(#[from] io::Error), + #[error("writing message to socket failed")] + Write(#[from] ProtoWriteError), + #[error("reading message from socket failed")] + Read(#[from] ProtoReadError), +} + +#[derive(Debug, Error)] +pub enum ServerHandshakeError { + #[error("startup message not found")] + MissingStartupMessage, + #[error("socket communication failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), + #[error("peeking special message failed")] + Peek(#[from] ProtoPeekError), + #[error("consuming special message failed")] + Consume(#[from] ProtoConsumeError), + #[error("writing message to socket failed")] + Write(#[from] ProtoWriteError), +} diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs new file mode 100644 index 0000000..61e9c24 --- /dev/null +++ b/proto/src/handshake/mod.rs @@ -0,0 +1,5 @@ +pub mod client; +pub mod errors; +pub mod request; +pub mod response; +pub mod server; diff --git a/proto/src/handshake/request.rs b/proto/src/handshake/request.rs new file mode 100644 index 0000000..51b6ad5 --- /dev/null +++ b/proto/src/handshake/request.rs @@ -0,0 +1,46 @@ +use crate::message::primitive::pgstring::PgString; +use crate::message::special::StartupMessageData; + +#[derive(Debug)] +pub struct HandshakeRequest { + pub version: i32, + pub parameters: Vec<(PgString, PgString)>, +} + +impl HandshakeRequest { + + /// Creates a new `HandshakeRequest` with the specified protocol version. + /// Expected `version` is `196608` for the 3.0. + pub fn new(version: i32) -> Self { + Self { + version, + parameters: Vec::new(), + } + } + + /// Adds a parameter to the startup message. + /// Generally recognized names are `user`, `database`, `option` and `replication` but others can be used. + /// For more info visit [`StartupMessage`](https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE) + pub fn parameter(mut self, key: &str, value: &str) -> Self { + self.parameters.push((key.into(), value.into())); + self + } +} + +impl From for StartupMessageData { + fn from(request: HandshakeRequest) -> Self { + Self { + version: request.version, + params: request.parameters, + } + } +} + +impl From for HandshakeRequest { + fn from(data: StartupMessageData) -> Self { + Self { + version: data.version, + parameters: data.params, + } + } +} diff --git a/proto/src/handshake/response.rs b/proto/src/handshake/response.rs new file mode 100644 index 0000000..60d0b6f --- /dev/null +++ b/proto/src/handshake/response.rs @@ -0,0 +1,69 @@ +use crate::handshake::errors::ClientHandshakeError; +use crate::message::backend::{BackendKeyDataData, BackendMessage, ParameterStatusData}; + +#[derive(Debug)] +pub struct HandshakeResponse { + pub version: String, + pub process_id: i32, + pub secret_key: i32, +} + +impl HandshakeResponse { + pub fn new(name: &str, pid: i32, key: i32) -> Self { + Self { + version: format!("16.0 ({name})"), + process_id: pid, + secret_key: key, + } + } +} + +impl TryFrom<&[BackendMessage]> for HandshakeResponse { + type Error = ClientHandshakeError; + + fn try_from(messages: &[BackendMessage]) -> Result { + let mut version = None; + let mut process_id = None; + let mut secret_key = None; + + for message in messages { + match message { + BackendMessage::ParameterStatus(data) => { + if data.name.as_str() == "server_version" { + version = Some(String::from(data.value.as_str())); + } + } + BackendMessage::BackendKeyData(data) => { + process_id = Some(data.process); + secret_key = Some(data.secret); + } + // Different messages are ignored during the handshake + _ => {} + } + } + + match (version, process_id, secret_key) { + (Some(version), Some(process_id), Some(secret_key)) => Ok(Self { + version, + process_id, + secret_key, + }), + _ => Err(ClientHandshakeError::UnexpectedResponse), + } + } +} + +impl From for Vec { + fn from(response: HandshakeResponse) -> Self { + vec![ + BackendMessage::ParameterStatus(ParameterStatusData { + name: "server_version".into(), + value: response.version.into(), + }), + BackendMessage::BackendKeyData(BackendKeyDataData { + process: response.process_id, + secret: response.secret_key, + }), + ] + } +} diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs new file mode 100644 index 0000000..6c8deb2 --- /dev/null +++ b/proto/src/handshake/server.rs @@ -0,0 +1,59 @@ +use crate::handshake::errors::ServerHandshakeError; +use crate::handshake::request::HandshakeRequest; +use crate::handshake::response::HandshakeResponse; +use crate::message::backend::{AuthenticationOkData, BackendMessage, ReadyForQueryData}; +use crate::message::special::SpecialMessage; +use crate::reader::frontend::FrontendProtoReader; +use crate::writer::backend::BackendProtoWriter; +use crate::writer::protowriter::ProtoFlush; + +/// Performs server-side handshake with the client until ending it with `ReadyForQuery` message. +/// For more info visit the [`55.2.1. Start-up`](https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-START-UP) +pub async fn do_server_handshake( + writer: &mut (impl BackendProtoWriter + ProtoFlush), + reader: &mut impl FrontendProtoReader, + response: HandshakeResponse, +) -> Result { + + // Check if client requested SSL + match &reader.peek_special_message().await? { + Some(msg @ SpecialMessage::SSLRequest) => { + reader.consume_special_message(msg).await?; + writer.write_ssl_reject().await?; + writer.flush().await?; + } + _ => { + // No SSL request + } + } + + // Wait for mandatory StartupMessage + let startup_message = match &reader.peek_special_message().await? { + Some(msg @ SpecialMessage::StartupMessage(data)) => { + reader.consume_special_message(msg).await?; + data.clone() + } + _ => { + return Err(ServerHandshakeError::MissingStartupMessage); + } + }; + + // Authenticate client + writer + .write_proto(BackendMessage::from(AuthenticationOkData { status: 0 })) + .await?; + + // Send server parameters + let messages: Vec = response.into(); + for message in messages { + writer.write_proto(message).await?; + } + + // Finish the handshake + writer + .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) + .await?; + + writer.flush().await?; + Ok(startup_message.into()) +} diff --git a/proto/src/lib.rs b/proto/src/lib.rs new file mode 100644 index 0000000..e9d155d --- /dev/null +++ b/proto/src/lib.rs @@ -0,0 +1,8 @@ +//! # PostgreSQL Protocol +//! Low-level PostgreSQL protocol implementation for the server version 16, protocol version 3.0. +//! Includes server and client side handshake with no password authentication. + +pub mod handshake; +pub mod message; +pub mod reader; +pub mod writer; diff --git a/proto/src/message/backend.rs b/proto/src/message/backend.rs new file mode 100644 index 0000000..869fe14 --- /dev/null +++ b/proto/src/message/backend.rs @@ -0,0 +1,372 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; +use crate::message::primitive::data::MessageData; +use crate::message::primitive::pglist::PgList; +use crate::message::primitive::pgstring::PgString; +use crate::message::proto_message::ProtoMessage; +use bincode::{Decode, Encode}; + +/// Backend messages sent from the server to the client. +/// For more info visit the [`55.2.3. Message Formats`](https://www.postgresql.org/docs/current/protocol-message-formats.html) +#[derive(Debug)] +pub enum BackendMessage { + AuthenticationOk(AuthenticationOkData), + BackendKeyData(BackendKeyDataData), + CommandComplete(CommandCompleteData), + DataRow(DataRowData), + EmptyQueryResponse, + ErrorResponse(ErrorResponseData), + NoData, + ParameterStatus(ParameterStatusData), + ReadyForQuery(ReadyForQueryData), + RowDescription(RowDescriptionData), +} + +impl ProtoMessage for BackendMessage { + fn variant(&self) -> u8 { + match self { + BackendMessage::AuthenticationOk(_) => b'R', + BackendMessage::BackendKeyData(_) => b'K', + BackendMessage::CommandComplete(_) => b'C', + BackendMessage::DataRow(_) => b'D', + BackendMessage::EmptyQueryResponse => b'I', + BackendMessage::ErrorResponse(_) => b'E', + BackendMessage::NoData => b'n', + BackendMessage::ParameterStatus(_) => b'S', + BackendMessage::ReadyForQuery(_) => b'Z', + BackendMessage::RowDescription(_) => b'T', + } + } + + fn serialize(&self) -> Result, ProtoSerializeError> { + match self { + BackendMessage::AuthenticationOk(data) => data.serialize(), + BackendMessage::BackendKeyData(data) => data.serialize(), + BackendMessage::CommandComplete(data) => data.serialize(), + BackendMessage::DataRow(data) => data.serialize(), + BackendMessage::EmptyQueryResponse => Ok(Vec::with_capacity(0)), + BackendMessage::ErrorResponse(data) => data.serialize(), + BackendMessage::NoData => Ok(Vec::with_capacity(0)), + BackendMessage::ParameterStatus(data) => data.serialize(), + BackendMessage::ReadyForQuery(data) => data.serialize(), + BackendMessage::RowDescription(data) => data.serialize(), + } + } + + fn deserialize(variant: u8, data: &[u8]) -> Result { + match variant { + b'R' => Ok(BackendMessage::AuthenticationOk( + AuthenticationOkData::deserialize(data)?, + )), + b'K' => { + let data = BackendKeyDataData::deserialize(data)?; + Ok(BackendMessage::BackendKeyData(data)) + } + b'C' => { + let data = CommandCompleteData::deserialize(data)?; + Ok(BackendMessage::CommandComplete(data)) + } + b'D' => { + let data = DataRowData::deserialize(data)?; + Ok(BackendMessage::DataRow(data)) + } + b'I' => Ok(BackendMessage::EmptyQueryResponse), + b'E' => { + let data = ErrorResponseData::deserialize(data)?; + Ok(BackendMessage::ErrorResponse(data)) + } + b'n' => Ok(BackendMessage::NoData), + b'S' => { + let data = ParameterStatusData::deserialize(data)?; + Ok(BackendMessage::ParameterStatus(data)) + } + b'Z' => { + let data = ReadyForQueryData::deserialize(data)?; + Ok(BackendMessage::ReadyForQuery(data)) + } + b'T' => { + let data = RowDescriptionData::deserialize(data)?; + Ok(BackendMessage::RowDescription(data)) + } + v => Err(ProtoDeserializeError::InvalidVariant(v)), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct AuthenticationOkData { + pub status: i32, +} + +impl From for BackendMessage { + fn from(data: AuthenticationOkData) -> Self { + BackendMessage::AuthenticationOk(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct BackendKeyDataData { + pub process: i32, + pub secret: i32, +} + +impl From for BackendMessage { + fn from(data: BackendKeyDataData) -> Self { + BackendMessage::BackendKeyData(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct CommandCompleteData { + pub tag: PgString, +} + +impl From for BackendMessage { + fn from(data: CommandCompleteData) -> Self { + BackendMessage::CommandComplete(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct DataRowData { + pub columns: PgList, i16>, +} + +impl From for BackendMessage { + fn from(data: DataRowData) -> Self { + BackendMessage::DataRow(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ErrorResponseData { + pub code: u8, + pub message: PgString, +} + +impl From for BackendMessage { + fn from(data: ErrorResponseData) -> Self { + BackendMessage::ErrorResponse(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ParameterStatusData { + pub name: PgString, + pub value: PgString, +} + +impl From for BackendMessage { + fn from(data: ParameterStatusData) -> Self { + BackendMessage::ParameterStatus(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ReadyForQueryData { + pub status: u8, +} + +impl From for BackendMessage { + fn from(data: ReadyForQueryData) -> Self { + BackendMessage::ReadyForQuery(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct RowDescriptionData { + pub columns: PgList, +} + +impl From for BackendMessage { + fn from(data: RowDescriptionData) -> Self { + BackendMessage::RowDescription(data) + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ColumnDescription { + pub name: PgString, + pub table_oid: i32, + pub column_index: i16, + pub type_oid: i32, + pub type_size: i16, + pub type_modifier: i32, + pub format_code: i16, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symmetric_authentication_ok() { + let backend = BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 }) + )); + } + + #[test] + fn test_symmetric_backend_key_data() { + let backend = BackendMessage::BackendKeyData(BackendKeyDataData { + process: 123, + secret: 456, + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::BackendKeyData(BackendKeyDataData { + process: 123, + secret: 456 + }) + )); + } + + #[test] + fn test_symmetric_command_complete() { + let backend = BackendMessage::CommandComplete(CommandCompleteData { + tag: PgString::from("SELECT 1"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::CommandComplete(CommandCompleteData { tag }) if tag.as_str() == "SELECT 1" + )); + } + + #[test] + fn test_symmetric_data_row() { + let backend = BackendMessage::DataRow(DataRowData { + columns: PgList::from(vec![PgList::from(vec![1, 2, 3])]), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::DataRow(DataRowData { columns }) if columns == PgList::from(vec![PgList::from(vec![1, 2, 3])]) + )); + } + + #[test] + fn test_symmetric_empty_query_response() { + let backend = BackendMessage::EmptyQueryResponse; + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, BackendMessage::EmptyQueryResponse)); + } + + #[test] + fn test_symmetric_error_response() { + let backend = BackendMessage::ErrorResponse(ErrorResponseData { + code: b'X', + message: PgString::from("Some error"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ErrorResponse(ErrorResponseData { code, message }) if code == b'X' && message.as_str() == "Some error" + )); + } + + #[test] + fn test_symmetric_no_data() { + let backend = BackendMessage::NoData; + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, BackendMessage::NoData)); + } + + #[test] + fn test_symmetric_parameter_status() { + let backend = BackendMessage::ParameterStatus(ParameterStatusData { + name: PgString::from("Some name"), + value: PgString::from("Some value"), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ParameterStatus(ParameterStatusData { name, value }) if name.as_str() == "Some name" && value.as_str() == "Some value" + )); + } + + #[test] + fn test_symmetric_ready_for_query() { + let backend = BackendMessage::ReadyForQuery(ReadyForQueryData { status: b'I' }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!( + message, + BackendMessage::ReadyForQuery(ReadyForQueryData { status }) if status == b'I' + )); + } + + #[test] + fn test_symmetric_row_description() { + let backend = BackendMessage::RowDescription(RowDescriptionData { + columns: PgList::from(vec![ColumnDescription { + name: PgString::from("Some name"), + table_oid: 123, + column_index: 456, + type_oid: 789, + type_size: 101, + type_modifier: 112, + format_code: 113, + }]), + }); + let raw = backend.serialize().unwrap(); + let variant = backend.variant(); + + let message = BackendMessage::deserialize(variant, &raw).unwrap(); + assert!(match message { + BackendMessage::RowDescription(RowDescriptionData { columns }) => { + let columns: Vec = columns.into(); + let column = &columns[0]; + column.name.as_str() == "Some name" + && column.table_oid == 123 + && column.column_index == 456 + && column.type_oid == 789 + && column.type_size == 101 + && column.type_modifier == 112 + && column.format_code == 113 + } + _ => false, + },) + } + + #[test] + fn test_unknown_variant() { + let variant = 0; + let data = vec![1, 2, 3]; + + let message = BackendMessage::deserialize(variant, &data); + assert!(matches!( + message, + Err(ProtoDeserializeError::InvalidVariant(0)) + )); + } +} diff --git a/proto/src/message/errors.rs b/proto/src/message/errors.rs new file mode 100644 index 0000000..7a3e04e --- /dev/null +++ b/proto/src/message/errors.rs @@ -0,0 +1,16 @@ +use bincode::error::{DecodeError, EncodeError}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ProtoDeserializeError { + #[error("invalid message variant: {0}")] + InvalidVariant(u8), + #[error("decoding of inner data failed")] + DecodeData(#[from] DecodeError), +} + +#[derive(Debug, Error)] +pub enum ProtoSerializeError { + #[error("encoding of inner data failed")] + EncodeData(#[from] EncodeError), +} diff --git a/proto/src/message/frontend.rs b/proto/src/message/frontend.rs new file mode 100644 index 0000000..648938e --- /dev/null +++ b/proto/src/message/frontend.rs @@ -0,0 +1,84 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; +use crate::message::primitive::data::MessageData; +use crate::message::primitive::pgstring::PgString; +use crate::message::proto_message::ProtoMessage; +use bincode::{Decode, Encode}; + +/// Frontend messages sent from the client to the server. +/// For more info visit the [`55.2.3. Message Formats`](https://www.postgresql.org/docs/current/protocol-message-formats.html) +#[derive(Debug)] +pub enum FrontendMessage { + Query(QueryData), + Terminate, +} + +impl ProtoMessage for FrontendMessage { + fn variant(&self) -> u8 { + match self { + FrontendMessage::Query(_) => b'Q', + FrontendMessage::Terminate => b'X', + } + } + + fn serialize(&self) -> Result, ProtoSerializeError> { + match self { + FrontendMessage::Query(data) => data.serialize(), + FrontendMessage::Terminate => Ok(Vec::with_capacity(0)), + } + } + + fn deserialize(variant: u8, data: &[u8]) -> Result { + match variant { + b'Q' => Ok(FrontendMessage::Query(QueryData::deserialize(data)?)), + b'X' => Ok(FrontendMessage::Terminate), + v => Err(ProtoDeserializeError::InvalidVariant(v)), + } + } +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct QueryData { + pub query: PgString, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::backend::BackendMessage; + + #[test] + fn test_symmetric_query() { + let frontend = FrontendMessage::Query(QueryData { + query: PgString::from("SELECT * FROM foo WHERE bar = $1"), + }); + let raw = frontend.serialize().unwrap(); + let variant = frontend.variant(); + + let message = FrontendMessage::deserialize(variant, &raw).unwrap(); + assert!( + matches!(message, FrontendMessage::Query(QueryData { query }) if query.as_str() == "SELECT * FROM foo WHERE bar = $1") + ); + } + + #[test] + fn test_symmetric_terminate() { + let frontend = FrontendMessage::Terminate; + let raw = frontend.serialize().unwrap(); + let variant = frontend.variant(); + + let message = FrontendMessage::deserialize(variant, &raw).unwrap(); + assert!(matches!(message, FrontendMessage::Terminate)); + } + + #[test] + fn test_unknown_variant() { + let variant = 0; + let data = vec![1, 2, 3]; + + let message = BackendMessage::deserialize(variant, &data); + assert!(matches!( + message, + Err(ProtoDeserializeError::InvalidVariant(0)) + )); + } +} diff --git a/proto/src/message/mod.rs b/proto/src/message/mod.rs new file mode 100644 index 0000000..0d8130c --- /dev/null +++ b/proto/src/message/mod.rs @@ -0,0 +1,6 @@ +pub mod backend; +pub mod errors; +pub mod frontend; +pub mod primitive; +pub mod proto_message; +pub mod special; diff --git a/proto/src/message/primitive/data.rs b/proto/src/message/primitive/data.rs new file mode 100644 index 0000000..db19ad4 --- /dev/null +++ b/proto/src/message/primitive/data.rs @@ -0,0 +1,29 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; +use bincode::{Decode, Encode}; +use bincode::config::{BigEndian, Configuration, Fixint}; + +fn pg_proto_config() -> Configuration { + bincode::config::standard() + .with_big_endian() + .with_fixed_int_encoding() +} + +pub trait MessageData: Sized { + fn serialize(&self) -> Result, ProtoSerializeError>; + fn deserialize(data: &[u8]) -> Result; +} + +impl MessageData for T +where + T: Encode + Decode, +{ + #[inline] + fn serialize(&self) -> Result, ProtoSerializeError> { + Ok(bincode::encode_to_vec(self, pg_proto_config())?) + } + + #[inline] + fn deserialize(data: &[u8]) -> Result { + Ok(bincode::decode_from_slice(data, pg_proto_config())?.0) + } +} diff --git a/proto/src/message/primitive/mod.rs b/proto/src/message/primitive/mod.rs new file mode 100644 index 0000000..4e84a1b --- /dev/null +++ b/proto/src/message/primitive/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod data; +pub mod pglist; +pub mod pgstring; diff --git a/proto/src/message/primitive/pglist.rs b/proto/src/message/primitive/pglist.rs new file mode 100644 index 0000000..1e76db3 --- /dev/null +++ b/proto/src/message/primitive/pglist.rs @@ -0,0 +1,87 @@ +use bincode::de::Decoder; +use bincode::enc::Encoder; +use bincode::error::{DecodeError, EncodeError}; +use bincode::{BorrowDecode, Decode, Encode}; +use std::marker::PhantomData; + +/// Item list common in PostgreSQL messages. +/// - Generic type `T` is the type of the items in the list. +/// - Generic type `U` is the type of the list length (`i16` or `i32`). +#[derive(Debug, Clone, PartialEq, BorrowDecode)] +pub struct PgList(Vec, PhantomData); + +impl PgList { + pub fn as_slice(&self) -> &[T] { + &self.0 + } +} + +impl From> for Vec { + fn from(pg_list: PgList) -> Self { + pg_list.0 + } +} + +impl From> for PgList { + fn from(list: Vec) -> Self { + PgList(list, PhantomData) + } +} + +impl Encode for PgList +where + T: Encode, +{ + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let length = self.0.len() as i16; + length.encode(encoder)?; + for item in &self.0 { + item.encode(encoder)?; + } + Ok(()) + } +} + +impl Decode for PgList +where + T: Decode, +{ + fn decode(decoder: &mut D) -> Result { + let length = i16::decode(decoder)?; + let mut list = Vec::new(); + for _ in 0..length { + list.push(T::decode(decoder)?); + } + + Ok(PgList(list, PhantomData)) + } +} + +impl Encode for PgList +where + T: Encode, +{ + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let length = self.0.len() as i32; + length.encode(encoder)?; + for item in &self.0 { + item.encode(encoder)?; + } + Ok(()) + } +} + +impl Decode for PgList +where + T: Decode, +{ + fn decode(decoder: &mut D) -> Result { + let length = i32::decode(decoder)?; + let mut list = Vec::new(); + for _ in 0..length { + list.push(T::decode(decoder)?); + } + + Ok(PgList(list, PhantomData)) + } +} diff --git a/proto/src/message/primitive/pgstring.rs b/proto/src/message/primitive/pgstring.rs new file mode 100644 index 0000000..58fad78 --- /dev/null +++ b/proto/src/message/primitive/pgstring.rs @@ -0,0 +1,55 @@ +use bincode::de::Decoder; +use bincode::enc::write::Writer; +use bincode::enc::Encoder; +use bincode::error::{DecodeError, EncodeError}; +use bincode::{BorrowDecode, Decode, Encode}; + +/// PostgreSQL format of string encoded as a null-terminated string. +#[derive(Debug, Clone, BorrowDecode)] +pub struct PgString(String); + +impl PgString { + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl From<&str> for PgString { + fn from(string: &str) -> Self { + PgString(string.to_string()) + } +} + +impl From for String { + fn from(pg_string: PgString) -> Self { + pg_string.0 + } +} + +impl From for PgString { + fn from(string: String) -> Self { + PgString(string) + } +} + +impl Encode for PgString { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + encoder.writer().write(self.0.as_bytes())?; + encoder.writer().write(b"\0") + } +} + +impl Decode for PgString { + fn decode(decoder: &mut D) -> Result { + let mut string = String::new(); + loop { + let byte = u8::decode(decoder)?; + if byte == 0 { + break; + } + string.push(byte as char); + } + + Ok(PgString(string)) + } +} diff --git a/proto/src/message/proto_message.rs b/proto/src/message/proto_message.rs new file mode 100644 index 0000000..13986e1 --- /dev/null +++ b/proto/src/message/proto_message.rs @@ -0,0 +1,7 @@ +use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError}; + +pub trait ProtoMessage: Sized { + fn variant(&self) -> u8; + fn serialize(&self) -> Result, ProtoSerializeError>; + fn deserialize(variant: u8, data: &[u8]) -> Result; +} diff --git a/proto/src/message/special.rs b/proto/src/message/special.rs new file mode 100644 index 0000000..6c45ab9 --- /dev/null +++ b/proto/src/message/special.rs @@ -0,0 +1,65 @@ +use crate::message::primitive::pgstring::PgString; +use bincode::de::Decoder; +use bincode::enc::Encoder; +use bincode::error::{DecodeError, EncodeError}; +use bincode::{Decode, Encode}; + +/// Special messages sent during handshake or to cancel request. +/// Sent in different format to preserve compatibility with older protocol versions. +#[derive(Debug)] +pub enum SpecialMessage { + /// Sent by client to cancel request. + CancelRequest(CancelRequestData), + /// Sent by client to request upgrade to SSL connection. + SSLRequest, + /// Sent by client to initiate the handshake. + StartupMessage(StartupMessageData), +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct CancelRequestData { + pub pid: i32, + pub secret: i32, +} + +#[derive(Debug, Clone)] +pub struct StartupMessageData { + pub version: i32, + pub params: Vec<(PgString, PgString)>, +} + +impl Encode for StartupMessageData { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + self.version.encode(encoder)?; + for (key, value) in &self.params { + key.encode(encoder)?; + value.encode(encoder)?; + } + Ok(()) + } +} + +impl Decode for StartupMessageData { + fn decode(decoder: &mut D) -> Result { + let version = i32::decode(decoder)?; + let mut params = Vec::new(); + loop { + let maybe_key = PgString::decode(decoder); + match maybe_key { + Ok(_) => {} + Err(DecodeError::UnexpectedEnd { .. }) => break, + Err(e) => return Err(e), + } + + let maybe_value = PgString::decode(decoder); + match maybe_value { + Ok(_) => {} + Err(DecodeError::UnexpectedEnd { .. }) => break, + Err(e) => return Err(e), + } + + params.push((maybe_key.unwrap(), maybe_value.unwrap())); + } + Ok(StartupMessageData { version, params }) + } +} diff --git a/proto/src/reader/backend.rs b/proto/src/reader/backend.rs new file mode 100644 index 0000000..33db099 --- /dev/null +++ b/proto/src/reader/backend.rs @@ -0,0 +1,58 @@ +use crate::message::backend::BackendMessage; +use crate::reader::oneway::OneWayProtoReader; +use async_trait::async_trait; + +#[async_trait] +pub trait BackendProtoReader: OneWayProtoReader {} + +#[async_trait] +impl BackendProtoReader for R where R: OneWayProtoReader {} + +#[cfg(test)] +mod tests { + use crate::message::backend::{ + AuthenticationOkData, BackendKeyDataData, BackendMessage, CommandCompleteData, + }; + use crate::reader::oneway::OneWayProtoReader; + use crate::reader::protoreader::ProtoReader; + use std::io::Cursor; + use tokio::io::{AsyncBufReadExt, BufReader}; + + #[tokio::test] + async fn test_message_sequence() { + let data = [ + b'R', 0, 0, 0, 8, 0, 0, 0, 123, b'K', 0, 0, 0, 12, 0, 0, 0, 111, 0, 0, 0, 222, b'C', 0, + 0, 0, 8, b'A', b'B', b'C', 0, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let msg = reader.read_proto().await; + assert!(matches!( + msg, + Ok(BackendMessage::AuthenticationOk(AuthenticationOkData { + status: 123 + })) + )); + + let msg = reader.read_proto().await; + assert!(matches!( + msg, + Ok(BackendMessage::BackendKeyData(BackendKeyDataData { + process: 111, + secret: 222 + })) + )); + + let msg = reader.read_proto().await; + assert!(match msg { + Ok(BackendMessage::CommandComplete(CommandCompleteData { tag })) => + tag.as_str() == "ABC", + _ => false, + }); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } +} diff --git a/proto/src/reader/errors.rs b/proto/src/reader/errors.rs new file mode 100644 index 0000000..78b138d --- /dev/null +++ b/proto/src/reader/errors.rs @@ -0,0 +1,33 @@ +use crate::message::errors::ProtoDeserializeError; +use thiserror::Error; +use tokio::io; + +#[derive(Debug, Error)] +pub enum ProtoReadError { + #[error("message has invalid length, got {0}")] + InvalidLength(i32), + #[error("message has too much data, got {actual}, limit is {limit}")] + LengthOverflow { limit: usize, actual: usize }, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), +} + +#[derive(Debug, Error)] +pub enum ProtoPeekError { + #[error("message has too much data, got {actual}, limit is {limit}")] + LengthOverflow { limit: usize, actual: usize }, + #[error("reading from socket failed")] + Io(#[from] io::Error), + #[error("deserialization of inner data failed")] + Deserialize(#[from] ProtoDeserializeError), +} + +#[derive(Debug, Error)] +pub enum ProtoConsumeError { + #[error("unexpected data length, expected {expected}, got {actual}")] + UnexpectedDataLength { expected: usize, actual: usize }, + #[error("reading from socket failed")] + Io(#[from] io::Error), +} diff --git a/proto/src/reader/frontend.rs b/proto/src/reader/frontend.rs new file mode 100644 index 0000000..bb2ffc8 --- /dev/null +++ b/proto/src/reader/frontend.rs @@ -0,0 +1,295 @@ +use crate::message::frontend::FrontendMessage; +use crate::message::primitive::data::MessageData; +use crate::message::special::{CancelRequestData, SpecialMessage, StartupMessageData}; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; +use crate::reader::oneway::OneWayProtoReader; +use crate::reader::protoreader::ProtoReader; +use crate::reader::utils::AsyncPeek; +use async_trait::async_trait; +use tokio::io; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +#[async_trait] +pub trait FrontendProtoReader: OneWayProtoReader { + async fn peek_special_message(&mut self) -> Result, ProtoPeekError>; + async fn consume_special_message( + &mut self, + msg: &SpecialMessage, + ) -> Result<(), ProtoConsumeError>; +} + +#[async_trait] +impl FrontendProtoReader for ProtoReader +where + R: AsyncBufRead + Unpin + Send, +{ + async fn peek_special_message(&mut self) -> Result, ProtoPeekError> { + if let Some(cancel) = try_get_cancel_request(&mut self).await? { + return Ok(Some(cancel)); + } + + if let Some(ssl) = try_get_ssl_request(&mut self).await? { + return Ok(Some(ssl)); + } + + if let Some(startup) = try_get_startup_message(&mut self).await? { + return Ok(Some(startup)); + } + + Ok(None) + } + + async fn consume_special_message( + &mut self, + msg: &SpecialMessage, + ) -> Result<(), ProtoConsumeError> { + Ok(match msg { + SpecialMessage::CancelRequest(_) => consume_cancel_request(self), + SpecialMessage::SSLRequest => consume_ssl_request(self), + SpecialMessage::StartupMessage(_) => consume_startup_message(self).await?, + }) + } +} + +async fn try_get_cancel_request( + reader: &mut ProtoReader, +) -> Result, io::Error> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 16]; + if reader.inner.peek(&mut header).await? != 16 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length != 16 { + return Ok(None); + } + + let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if code != 80877102 { + return Ok(None); + } + + let pid = i32::from_be_bytes([header[8], header[9], header[10], header[11]]); + let secret = i32::from_be_bytes([header[12], header[13], header[14], header[15]]); + + Ok(Some(SpecialMessage::CancelRequest(CancelRequestData { + pid, + secret, + }))) +} + +fn consume_cancel_request(reader: &mut ProtoReader) +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + reader.inner.consume(16); +} + +async fn try_get_ssl_request( + reader: &mut ProtoReader, +) -> Result, io::Error> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 8]; + if reader.inner.peek(&mut header).await? != 8 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length != 8 { + return Ok(None); + } + + let code = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if code != 80877103 { + return Ok(None); + } + + Ok(Some(SpecialMessage::SSLRequest)) +} + +fn consume_ssl_request(reader: &mut ProtoReader) +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + reader.inner.consume(8); +} + +async fn try_get_startup_message( + reader: &mut ProtoReader, +) -> Result, ProtoPeekError> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 8]; + if reader.inner.peek(&mut header).await? != 8 { + return Ok(None); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]); + if length < 8 { + return Ok(None); + } + if length > reader.msg_len_limit { + return Err(ProtoPeekError::LengthOverflow { + limit: reader.msg_len_limit as usize, + actual: length as usize, + }); + } + + let version = i32::from_be_bytes([header[4], header[5], header[6], header[7]]); + if version != 196608 { + return Ok(None); + } + + let length = length as usize; + let mut data = vec![0u8; length]; + if reader.inner.peek(&mut data).await? != length { + return Ok(None); + } + + let data = StartupMessageData::deserialize(&data[4..])?; + Ok(Some(SpecialMessage::StartupMessage(data))) +} + +async fn consume_startup_message(reader: &mut ProtoReader) -> Result<(), ProtoConsumeError> +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + let mut header = [0u8; 4]; + let size = reader.inner.peek(&mut header).await?; + if size != 4 { + return Err(ProtoConsumeError::UnexpectedDataLength { + expected: 4, + actual: size, + }); + } + + let length = i32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize; + if length < 8 { + return Err(ProtoConsumeError::UnexpectedDataLength { + expected: 8, + actual: length, + }); + } + + reader.inner.consume(length); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::frontend::QueryData; + use crate::message::special::StartupMessageData; + use std::io::Cursor; + use tokio::io::{AsyncBufReadExt, BufReader}; + + #[tokio::test] + async fn test_message_sequence() { + let data = [ + b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let msg = reader.read_proto().await; + assert!( + match &msg { + Ok(FrontendMessage::Query(QueryData { query })) => query.as_str() == "SLIME", + _ => false, + }, + "{msg:?}" + ); + + let msg = reader.read_proto().await; + assert!(matches!(msg, Ok(FrontendMessage::Terminate)), "{msg:?}"); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_cancel_request() { + let data = [ + 0, 0, 0, 16, 0x04, 0xD2, 0x16, 0x2E, 0, 0, 0, 111, 0, 0, 0, 222, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(matches!( + peeked, + Some(SpecialMessage::CancelRequest(CancelRequestData { + pid: 111, + secret: 222 + })) + )); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_ssl_request() { + let data = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(matches!(peeked, Some(SpecialMessage::SSLRequest))); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } + + #[tokio::test] + async fn test_startup_message() { + let data = [ + 0, 0, 0, 26, 0, 3, 0, 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'b', b'r', + b'a', b'n', b'i', b'k', 0, 0, 0, + ]; + + let reader = BufReader::new(Cursor::new(&data)); + let mut reader = ProtoReader::new(reader, 1024); + + let peeked = reader.peek_special_message().await.unwrap(); + assert!(match &peeked { + Some(SpecialMessage::StartupMessage(StartupMessageData { + version: 196608, + params, + })) => + params.len() == 2 + && params[0].0.as_str() == "database" + && params[0].1.as_str() == "branik" + && params[1].0.as_str() == "" + && params[1].1.as_str() == "", + _ => false, + }); + + reader + .consume_special_message(&peeked.unwrap()) + .await + .unwrap(); + + let rest = reader.inner.fill_buf().await.unwrap(); + assert!(rest.is_empty()); + } +} diff --git a/proto/src/reader/mod.rs b/proto/src/reader/mod.rs new file mode 100644 index 0000000..41297de --- /dev/null +++ b/proto/src/reader/mod.rs @@ -0,0 +1,6 @@ +pub mod backend; +pub mod errors; +pub mod frontend; +pub mod oneway; +pub mod protoreader; +mod utils; diff --git a/proto/src/reader/oneway.rs b/proto/src/reader/oneway.rs new file mode 100644 index 0000000..d1db637 --- /dev/null +++ b/proto/src/reader/oneway.rs @@ -0,0 +1,41 @@ +use crate::message::proto_message::ProtoMessage; +use crate::reader::errors::ProtoReadError; +use crate::reader::protoreader::ProtoReader; +use crate::reader::utils::AsyncPeek; +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncReadExt}; + +#[async_trait] +pub trait OneWayProtoReader +where + T: ProtoMessage, +{ + async fn read_proto(&mut self) -> Result; +} + +#[async_trait] +impl OneWayProtoReader for ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, + T: ProtoMessage, +{ + async fn read_proto(&mut self) -> Result { + let variant = self.inner.read_u8().await?; + let length = self.inner.read_i32().await?; + + if length < 4 { + return Err(ProtoReadError::InvalidLength(length)); + } + if length > self.msg_len_limit { + return Err(ProtoReadError::LengthOverflow { + limit: self.msg_len_limit as usize, + actual: length as usize, + }); + } + + let mut data = vec![0u8; (length - 4) as usize]; + self.inner.read_exact(&mut data).await?; + + Ok(T::deserialize(variant, &data)?) + } +} diff --git a/proto/src/reader/protoreader.rs b/proto/src/reader/protoreader.rs new file mode 100644 index 0000000..5e3f572 --- /dev/null +++ b/proto/src/reader/protoreader.rs @@ -0,0 +1,22 @@ +use crate::reader::utils::AsyncPeek; +use tokio::io::AsyncBufRead; + +pub struct ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + pub(super) inner: R, + pub(super) msg_len_limit: i32, +} + +impl ProtoReader +where + R: AsyncBufRead + AsyncPeek + Unpin + Send, +{ + pub fn new(reader: R, msg_len_limit: i32) -> ProtoReader { + ProtoReader { + inner: reader, + msg_len_limit, + } + } +} diff --git a/proto/src/reader/utils.rs b/proto/src/reader/utils.rs new file mode 100644 index 0000000..0ca8f85 --- /dev/null +++ b/proto/src/reader/utils.rs @@ -0,0 +1,24 @@ +use async_trait::async_trait; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +#[async_trait] +pub trait AsyncPeek { + async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result; +} + +#[async_trait] +impl AsyncPeek for T +where + T: AsyncBufRead + Unpin + Send, +{ + async fn peek(&mut self, buf: &mut [u8]) -> tokio::io::Result { + let filled = self.fill_buf().await?; + if filled.len() >= buf.len() { + buf.copy_from_slice(&filled[..buf.len()]); + Ok(buf.len()) + } else { + buf[..filled.len()].copy_from_slice(filled); + Ok(filled.len()) + } + } +} diff --git a/proto/src/writer/backend.rs b/proto/src/writer/backend.rs new file mode 100644 index 0000000..cc22e5c --- /dev/null +++ b/proto/src/writer/backend.rs @@ -0,0 +1,60 @@ +use crate::message::backend::BackendMessage; +use crate::writer::errors::ProtoWriteError; +use crate::writer::oneway::OneWayProtoWriter; +use crate::writer::protowriter::ProtoWriter; +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +#[async_trait] +pub trait BackendProtoWriter: OneWayProtoWriter { + async fn write_ssl_reject(&mut self) -> Result<(), ProtoWriteError>; +} + +#[async_trait] +impl BackendProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + async fn write_ssl_reject(&mut self) -> Result<(), ProtoWriteError> { + self.inner.write_u8(b'N').await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::backend::AuthenticationOkData; + use crate::writer::protowriter::ProtoWriter; + use tokio::io::BufWriter; + + #[tokio::test] + async fn test_message_sequence() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer + .write_proto(BackendMessage::AuthenticationOk(AuthenticationOkData { + status: 123, + })) + .await + .unwrap(); + + writer.write_proto(BackendMessage::NoData).await.unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![b'R', 0, 0, 0, 8, 0, 0, 0, 123, b'n', 0, 0, 0, 4] + ); + } + + #[tokio::test] + async fn test_ssl_reject() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer.write_ssl_reject().await.unwrap(); + + assert_eq!(writer.inner.buffer(), vec![b'N']); + } +} diff --git a/proto/src/writer/errors.rs b/proto/src/writer/errors.rs new file mode 100644 index 0000000..5cc0a7b --- /dev/null +++ b/proto/src/writer/errors.rs @@ -0,0 +1,11 @@ +use crate::message::errors::ProtoSerializeError; +use thiserror::Error; +use tokio::io; + +#[derive(Debug, Error)] +pub enum ProtoWriteError { + #[error("writing to socket failed")] + Io(#[from] io::Error), + #[error("serialization of inner data failed")] + Serialize(#[from] ProtoSerializeError), +} diff --git a/proto/src/writer/frontend.rs b/proto/src/writer/frontend.rs new file mode 100644 index 0000000..4ca6c0b --- /dev/null +++ b/proto/src/writer/frontend.rs @@ -0,0 +1,128 @@ +use crate::message::frontend::FrontendMessage; +use crate::message::primitive::data::MessageData; +use crate::message::special::{CancelRequestData, StartupMessageData}; +use crate::writer::errors::ProtoWriteError; +use crate::writer::oneway::OneWayProtoWriter; +use crate::writer::protowriter::ProtoWriter; +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +#[async_trait] +pub trait FrontendProtoWriter: OneWayProtoWriter { + async fn write_startup_message( + &mut self, + startup_message: StartupMessageData, + ) -> Result<(), ProtoWriteError>; + async fn write_cancel_request( + &mut self, + cancel_request: CancelRequestData, + ) -> Result<(), ProtoWriteError>; +} + +#[async_trait] +impl FrontendProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + async fn write_startup_message( + &mut self, + startup_message: StartupMessageData, + ) -> Result<(), ProtoWriteError> { + let data = startup_message.serialize()?; + let length = data.len() + 4; + + self.inner.write_i32(length as i32).await?; + self.inner.write_all(&data).await?; + + Ok(()) + } + + async fn write_cancel_request( + &mut self, + cancel_request: CancelRequestData, + ) -> Result<(), ProtoWriteError> { + let data = cancel_request.serialize()?; + let length = data.len() + 4; + + self.inner.write_i32(length as i32).await?; + self.inner.write_all(&data).await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::frontend::QueryData; + use crate::writer::protowriter::ProtoWriter; + use tokio::io::BufWriter; + + #[tokio::test] + async fn test_message_sequence() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer + .write_proto(FrontendMessage::Query(QueryData { + query: "SLIME".into(), + })) + .await + .unwrap(); + + writer + .write_proto(FrontendMessage::Terminate) + .await + .unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![b'Q', 0, 0, 0, 10, b'S', b'L', b'I', b'M', b'E', 0, b'X', 0, 0, 0, 4] + ); + } + + #[tokio::test] + async fn test_startup_message() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer + .write_startup_message(StartupMessageData { + version: 196608, + params: vec![ + ("user".into(), "postgres".into()), + ("database".into(), "postgres".into()), + ], + }) + .await + .unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![ + 0, 0, 0, 40, 0, 3, 0, 0, b'u', b's', b'e', b'r', 0, b'p', b'o', b's', b't', b'g', + b'r', b'e', b's', 0, b'd', b'a', b't', b'a', b'b', b'a', b's', b'e', 0, b'p', b'o', + b's', b't', b'g', b'r', b'e', b's', 0 + ] + ); + } + + #[tokio::test] + async fn test_cancel_request() { + let writer = BufWriter::new(Vec::new()); + let mut writer = ProtoWriter::new(writer); + + writer + .write_cancel_request(CancelRequestData { + pid: 123, + secret: 234, + }) + .await + .unwrap(); + + assert_eq!( + writer.inner.buffer(), + vec![0, 0, 0, 12, 0, 0, 0, 123, 0, 0, 0, 234] + ); + } +} diff --git a/proto/src/writer/mod.rs b/proto/src/writer/mod.rs new file mode 100644 index 0000000..651a31e --- /dev/null +++ b/proto/src/writer/mod.rs @@ -0,0 +1,5 @@ +pub mod backend; +pub mod errors; +pub mod frontend; +pub mod oneway; +pub mod protowriter; diff --git a/proto/src/writer/oneway.rs b/proto/src/writer/oneway.rs new file mode 100644 index 0000000..30d2665 --- /dev/null +++ b/proto/src/writer/oneway.rs @@ -0,0 +1,32 @@ +use crate::message::proto_message::ProtoMessage; +use crate::writer::errors::ProtoWriteError; +use crate::writer::protowriter::ProtoWriter; +use async_trait::async_trait; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +#[async_trait] +pub trait OneWayProtoWriter +where + T: ProtoMessage, +{ + async fn write_proto(&mut self, message: T) -> Result<(), ProtoWriteError>; +} + +#[async_trait] +impl OneWayProtoWriter for ProtoWriter +where + W: AsyncWrite + Unpin + Send, + T: ProtoMessage + Send + 'static, +{ + async fn write_proto(&mut self, message: T) -> Result<(), ProtoWriteError> { + let variant = message.variant(); + let mut data = message.serialize()?; + let length = data.len() as i32 + 4; + + self.inner.write_u8(variant).await?; + self.inner.write_i32(length).await?; + self.inner.write_all(&mut data).await?; + + Ok(()) + } +} diff --git a/proto/src/writer/protowriter.rs b/proto/src/writer/protowriter.rs new file mode 100644 index 0000000..27aa9e4 --- /dev/null +++ b/proto/src/writer/protowriter.rs @@ -0,0 +1,35 @@ +use async_trait::async_trait; +use tokio::io; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +pub struct ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + pub(super) inner: W, +} + +impl ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + pub fn new(writer: W) -> ProtoWriter { + ProtoWriter { inner: writer } + } +} + +#[async_trait] +pub trait ProtoFlush { + async fn flush(&mut self) -> Result<(), io::Error>; +} + +#[async_trait] +impl ProtoFlush for ProtoWriter +where + W: AsyncWrite + Unpin + Send, +{ + async fn flush(&mut self) -> Result<(), io::Error> { + self.inner.flush().await?; + Ok(()) + } +} diff --git a/server/Cargo.toml b/server/Cargo.toml new file mode 100644 index 0000000..bca61ec --- /dev/null +++ b/server/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "server" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "1.35.1", features = ["full"] } +anyhow = "1.0.76" +proto = { path = "../proto" } \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs new file mode 100644 index 0000000..bda6dfd --- /dev/null +++ b/server/src/main.rs @@ -0,0 +1,185 @@ +use proto::handshake::response::HandshakeResponse; +use proto::handshake::server::do_server_handshake; +use proto::message::backend::{ + BackendMessage, ColumnDescription, CommandCompleteData, DataRowData, ErrorResponseData, + ReadyForQueryData, RowDescriptionData, +}; +use proto::message::frontend::FrontendMessage; +use proto::reader::oneway::OneWayProtoReader; +use proto::reader::protoreader::ProtoReader; +use proto::writer::backend::BackendProtoWriter; +use proto::writer::protowriter::{ProtoFlush, ProtoWriter}; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::{TcpListener, TcpStream}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let addr = "0.0.0.0:5432"; + let listener = TcpListener::bind(&addr).await?; + println!("Server started at {addr}"); + + loop { + let (socket, _) = listener.accept().await?; + println!("New client connected: {}", socket.peer_addr()?); + tokio::spawn(async move { + let reason = handle_stream(socket).await; + println!("Client disconnected: {reason:?}"); + }); + } +} + +async fn handle_stream(mut stream: TcpStream) -> anyhow::Result<()> { + let (reader, writer) = stream.split(); + let mut writer = ProtoWriter::new(BufWriter::new(writer)); + let mut reader = ProtoReader::new(BufReader::new(reader), 1024); + + let response = HandshakeResponse::new("minisql", 123, 123); + + let request = do_server_handshake(&mut writer, &mut reader, response).await?; + + println!("Handshake complete:\n{request:?}"); + + loop { + println!("Waiting for next message"); + let next: FrontendMessage = reader.read_proto().await?; + + match next { + FrontendMessage::Terminate => { + println!("Received Terminate"); + break; + } + FrontendMessage::Query(data) => { + println!("Received Query: {:?}", data); + if data.query.as_str().contains("car") { + println!("Sending error message"); + send_error_response(&mut writer, "Car not found").await?; + } else if data.query.as_str().to_lowercase().contains("select") { + println!("Sending table"); + send_query_response(&mut writer).await?; + } else { + println!("Sending empty query"); + send_empty_query(&mut writer).await?; + } + send_ready_for_query(&mut writer).await?; + } + } + writer.flush().await?; + } + + Ok(()) +} + +async fn send_error_response( + writer: &mut impl BackendProtoWriter, + error_message: &str, +) -> anyhow::Result<()> { + writer + .write_proto( + ErrorResponseData { + code: b'M', + message: error_message.to_string().into(), + } + .into(), + ) + .await?; + + Ok(()) +} + +async fn send_ready_for_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + writer + .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) + .await?; + + Ok(()) +} + +async fn send_empty_query(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + writer + .write_proto(BackendMessage::EmptyQueryResponse) + .await?; + + Ok(()) +} + +async fn send_row_description(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + let columns = vec![ + ColumnDescription { + name: "id".to_string().into(), + table_oid: 123, + column_index: 1, + type_oid: 23, + type_size: 4, + type_modifier: -1, + format_code: 0, + }, + ColumnDescription { + name: "argument".to_string().into(), + table_oid: 123, + column_index: 2, + type_oid: 23, + type_size: 4, + type_modifier: -1, + format_code: 0, + }, + ColumnDescription { + name: "description".to_string().into(), + table_oid: 123, + column_index: 3, + type_oid: 1043, + type_size: 32, + type_modifier: -1, + format_code: 0, + }, + ]; + + writer + .write_proto( + RowDescriptionData { + columns: columns.into(), + } + .into(), + ) + .await?; + + Ok(()) +} + +async fn send_query_response(writer: &mut impl BackendProtoWriter) -> anyhow::Result<()> { + send_row_description(writer).await?; + + write_row(writer, b"0", b"1337", b"auto").await?; + write_row(writer, b"1", b"69", b"bus").await?; + write_row(writer, b"2", b"420", b"kolo").await?; + + writer + .write_proto( + CommandCompleteData { + tag: "SELECT 3".to_string().into(), + } + .into(), + ) + .await?; + + Ok(()) +} + +async fn write_row( + writer: &mut impl BackendProtoWriter, + first: &[u8], + second: &[u8], + third: &[u8], +) -> anyhow::Result<()> { + let row_data = vec![ + first.to_vec().into(), + second.to_vec().into(), + third.to_vec().into(), + ] + .into(); + + writer + .write_proto(DataRowData { columns: row_data }.into()) + .await?; + + Ok(()) +}