diff --git a/proto/src/handshake/client/shaker.rs b/proto/src/handshake/client.rs similarity index 73% rename from proto/src/handshake/client/shaker.rs rename to proto/src/handshake/client.rs index 893f0b7..9226b2a 100644 --- a/proto/src/handshake/client/shaker.rs +++ b/proto/src/handshake/client.rs @@ -1,6 +1,6 @@ -use crate::handshake::client::error::ClientHandshakeError; -use crate::handshake::client::request::ClientHandshakeRequest; -use crate::handshake::client::response::ClientHandshakeResponse; +use crate::handshake::errors::ClientHandshakeError; +use crate::handshake::request::HandshakeRequest; +use crate::handshake::response::HandshakeResponse; use crate::message::backend::{AuthenticationOkData, BackendMessage}; use crate::message::special::StartupMessageData; use crate::reader::backend::BackendProtoReader; @@ -10,8 +10,8 @@ use crate::writer::protowriter::ProtoFlush; pub async fn do_client_handshake( writer: &mut (impl FrontendProtoWriter + ProtoFlush), reader: &mut impl BackendProtoReader, - request: ClientHandshakeRequest, -) -> Result { + request: HandshakeRequest, +) -> Result { let startup_message: StartupMessageData = request.into(); writer.write_startup_message(startup_message).await?; @@ -30,5 +30,5 @@ pub async fn do_client_handshake( messages.push(msg); } - ClientHandshakeResponse::from_backend_messages(&messages) + HandshakeResponse::try_from(messages.as_slice()) } \ No newline at end of file diff --git a/proto/src/handshake/client/error.rs b/proto/src/handshake/client/error.rs deleted file mode 100644 index ab6f23c..0000000 --- a/proto/src/handshake/client/error.rs +++ /dev/null @@ -1,16 +0,0 @@ -use thiserror::Error; -use crate::message::backend::BackendMessage; -use crate::reader::errors::ProtoReadError; -use crate::writer::errors::ProtoWriteError; - -#[derive(Debug, Error)] -pub enum ClientHandshakeError { - #[error("unexpected response from server")] - UnexpectedResponse, - #[error("unexpected auth response")] - UnexpectedAuthResponse(BackendMessage), - #[error("writing message to socket failed")] - Write(#[from] ProtoWriteError), - #[error("reading message from socket failed")] - Read(#[from] ProtoReadError), -} \ No newline at end of file diff --git a/proto/src/handshake/client/mod.rs b/proto/src/handshake/client/mod.rs deleted file mode 100644 index 060973a..0000000 --- a/proto/src/handshake/client/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod request; -pub mod response; -pub mod shaker; -pub mod error; \ No newline at end of file diff --git a/proto/src/handshake/client/response.rs b/proto/src/handshake/client/response.rs deleted file mode 100644 index 777cb81..0000000 --- a/proto/src/handshake/client/response.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::handshake::client::error::ClientHandshakeError; -use crate::message::backend::BackendMessage; - -pub struct ClientHandshakeResponse { - pub version: String, - pub process_id: i32, - pub secret_key: i32, -} - -impl ClientHandshakeResponse { - pub fn from_backend_messages(message: &[BackendMessage]) -> Result { - let mut version = None; - let mut process_id = None; - let mut secret_key = None; - for message in message { - match message { - BackendMessage::ParameterStatus(data) => { - if data.name.as_str() == "server_version" { - version = Some(String::from(data.value.as_str())); - } - } - BackendMessage::BackendKeyData(data) => { - process_id = Some(data.process); - secret_key = Some(data.secret); - } - _ => {} - } - } - - match (version, process_id, secret_key) { - (Some(version), Some(process_id), Some(secret_key)) => { - Ok(Self { version, process_id, secret_key }) - } - _ => Err(ClientHandshakeError::UnexpectedResponse), - } - } -} \ No newline at end of file diff --git a/proto/src/handshake/errors.rs b/proto/src/handshake/errors.rs index 90565ac..d2bc1bd 100644 --- a/proto/src/handshake/errors.rs +++ b/proto/src/handshake/errors.rs @@ -1,14 +1,27 @@ use thiserror::Error; use tokio::io; +use crate::message::backend::BackendMessage; use crate::message::errors::ProtoDeserializeError; -use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; +use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError}; use crate::writer::errors::ProtoWriteError; +#[derive(Debug, Error)] +pub enum ClientHandshakeError { + #[error("unexpected response from server")] + UnexpectedResponse, + #[error("unexpected auth response")] + UnexpectedAuthResponse(BackendMessage), + #[error("writing message to socket failed")] + Write(#[from] ProtoWriteError), + #[error("reading message from socket failed")] + Read(#[from] ProtoReadError), +} + #[derive(Debug, Error)] pub enum ServerHandshakeError { #[error("startup message not found")] MissingStartupMessage, - #[error("reading from socket failed")] + #[error("socket communication failed")] Io(#[from] io::Error), #[error("deserialization of inner data failed")] Deserialize(#[from] ProtoDeserializeError), @@ -18,4 +31,4 @@ pub enum ServerHandshakeError { Consume(#[from] ProtoConsumeError), #[error("writing message to socket failed")] Write(#[from] ProtoWriteError), -} +} \ No newline at end of file diff --git a/proto/src/handshake/mod.rs b/proto/src/handshake/mod.rs index 0b12ab2..86f92ab 100644 --- a/proto/src/handshake/mod.rs +++ b/proto/src/handshake/mod.rs @@ -1,3 +1,5 @@ -pub mod server; -pub mod errors; +pub mod response; +pub mod request; pub mod client; +pub mod server; +pub mod errors; \ No newline at end of file diff --git a/proto/src/handshake/client/request.rs b/proto/src/handshake/request.rs similarity index 59% rename from proto/src/handshake/client/request.rs rename to proto/src/handshake/request.rs index 54998b5..ec3e123 100644 --- a/proto/src/handshake/client/request.rs +++ b/proto/src/handshake/request.rs @@ -1,12 +1,12 @@ use crate::message::primitive::pgstring::PgString; use crate::message::special::StartupMessageData; -pub struct ClientHandshakeRequest { +pub struct HandshakeRequest { version: i32, parameters: Vec<(PgString, PgString)>, } -impl ClientHandshakeRequest { +impl HandshakeRequest { pub fn new(version: i32) -> Self { Self { version, parameters: Vec::new() } } @@ -17,8 +17,14 @@ impl ClientHandshakeRequest { } } -impl From for StartupMessageData { - fn from(request: ClientHandshakeRequest) -> Self { +impl From for StartupMessageData { + fn from(request: HandshakeRequest) -> Self { Self { version: request.version, params: request.parameters } } } + +impl From for HandshakeRequest { + fn from(data: StartupMessageData) -> Self { + Self { version: data.version, parameters: data.params } + } +} \ No newline at end of file diff --git a/proto/src/handshake/response.rs b/proto/src/handshake/response.rs new file mode 100644 index 0000000..84a908f --- /dev/null +++ b/proto/src/handshake/response.rs @@ -0,0 +1,65 @@ +use crate::handshake::errors::ClientHandshakeError; +use crate::message::backend::{BackendKeyDataData, BackendMessage, ParameterStatusData}; + +pub struct HandshakeResponse { + pub version: String, + pub process_id: i32, + pub secret_key: i32, +} + +impl HandshakeResponse { + pub fn new(name: &str, pid: i32, key: i32) -> Self { + Self { + version: format!("16.0 ({name})", name = name), + process_id: pid, + secret_key: key, + } + } +} + +impl TryFrom<&[BackendMessage]> for HandshakeResponse { + type Error = ClientHandshakeError; + + fn try_from(messages: &[BackendMessage]) -> Result { + let mut version = None; + let mut process_id = None; + let mut secret_key = None; + + for message in messages { + match message { + BackendMessage::ParameterStatus(data) => { + if data.name.as_str() == "server_version" { + version = Some(String::from(data.value.as_str())); + } + } + BackendMessage::BackendKeyData(data) => { + process_id = Some(data.process); + secret_key = Some(data.secret); + } + _ => {} + } + } + + match (version, process_id, secret_key) { + (Some(version), Some(process_id), Some(secret_key)) => { + Ok(Self { version, process_id, secret_key }) + } + _ => Err(ClientHandshakeError::UnexpectedResponse), + } + } +} + +impl From<&HandshakeResponse> for Vec { + fn from(response: &HandshakeResponse) -> Self { + vec![ + BackendMessage::ParameterStatus(ParameterStatusData { + name: "server_version".into(), + value: response.version.clone().into(), + }), + BackendMessage::BackendKeyData(BackendKeyDataData { + process: response.process_id, + secret: response.secret_key, + }), + ] + } +} \ No newline at end of file diff --git a/proto/src/handshake/server.rs b/proto/src/handshake/server.rs index b998c66..127407c 100644 --- a/proto/src/handshake/server.rs +++ b/proto/src/handshake/server.rs @@ -1,6 +1,8 @@ use crate::handshake::errors::ServerHandshakeError; +use crate::handshake::request::HandshakeRequest; +use crate::handshake::response::HandshakeResponse; use crate::message::backend::{ - AuthenticationOkData, BackendKeyDataData, BackendMessage, ParameterStatusData, + AuthenticationOkData, BackendMessage, ReadyForQueryData, }; use crate::message::special::{SpecialMessage, StartupMessageData}; @@ -11,10 +13,8 @@ use crate::writer::protowriter::ProtoFlush; pub async fn do_server_handshake( writer: &mut (impl BackendProtoWriter + ProtoFlush), reader: &mut impl FrontendProtoReader, - name: &str, - process: i32, - secret: i32, -) -> Result { + response: &HandshakeResponse, +) -> Result { match &reader.peek_special_message().await? { Some(msg @ SpecialMessage::SSLRequest) => { reader.consume_special_message(msg).await?; @@ -40,21 +40,15 @@ pub async fn do_server_handshake( .write_proto(BackendMessage::from(AuthenticationOkData { status: 0 })) .await?; - writer - .write_proto(BackendMessage::from(ParameterStatusData { - name: "server_version".to_string().into(), - value: format!("16.0 ({name})").into(), - })) - .await?; - - writer - .write_proto(BackendMessage::from(BackendKeyDataData { process, secret })) - .await?; + let messages: Vec = response.into(); + for message in messages { + writer.write_proto(message).await?; + } writer .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) .await?; writer.flush().await?; - Ok(startup_message) + Ok(startup_message.into()) }