refactor(proto): reuse code in handshakes

This commit is contained in:
Jindřich Moravec 2023-12-23 00:52:53 +01:00
parent 7b2dce4dfb
commit c1744711d3
9 changed files with 111 additions and 88 deletions

View file

@ -1,6 +1,6 @@
use crate::handshake::client::error::ClientHandshakeError; use crate::handshake::errors::ClientHandshakeError;
use crate::handshake::client::request::ClientHandshakeRequest; use crate::handshake::request::HandshakeRequest;
use crate::handshake::client::response::ClientHandshakeResponse; use crate::handshake::response::HandshakeResponse;
use crate::message::backend::{AuthenticationOkData, BackendMessage}; use crate::message::backend::{AuthenticationOkData, BackendMessage};
use crate::message::special::StartupMessageData; use crate::message::special::StartupMessageData;
use crate::reader::backend::BackendProtoReader; use crate::reader::backend::BackendProtoReader;
@ -10,8 +10,8 @@ use crate::writer::protowriter::ProtoFlush;
pub async fn do_client_handshake( pub async fn do_client_handshake(
writer: &mut (impl FrontendProtoWriter + ProtoFlush), writer: &mut (impl FrontendProtoWriter + ProtoFlush),
reader: &mut impl BackendProtoReader, reader: &mut impl BackendProtoReader,
request: ClientHandshakeRequest, request: HandshakeRequest,
) -> Result<ClientHandshakeResponse, ClientHandshakeError> { ) -> Result<HandshakeResponse, ClientHandshakeError> {
let startup_message: StartupMessageData = request.into(); let startup_message: StartupMessageData = request.into();
writer.write_startup_message(startup_message).await?; writer.write_startup_message(startup_message).await?;
@ -30,5 +30,5 @@ pub async fn do_client_handshake(
messages.push(msg); messages.push(msg);
} }
ClientHandshakeResponse::from_backend_messages(&messages) HandshakeResponse::try_from(messages.as_slice())
} }

View file

@ -1,16 +0,0 @@
use thiserror::Error;
use crate::message::backend::BackendMessage;
use crate::reader::errors::ProtoReadError;
use crate::writer::errors::ProtoWriteError;
#[derive(Debug, Error)]
pub enum ClientHandshakeError {
#[error("unexpected response from server")]
UnexpectedResponse,
#[error("unexpected auth response")]
UnexpectedAuthResponse(BackendMessage),
#[error("writing message to socket failed")]
Write(#[from] ProtoWriteError),
#[error("reading message from socket failed")]
Read(#[from] ProtoReadError),
}

View file

@ -1,4 +0,0 @@
pub mod request;
pub mod response;
pub mod shaker;
pub mod error;

View file

@ -1,37 +0,0 @@
use crate::handshake::client::error::ClientHandshakeError;
use crate::message::backend::BackendMessage;
pub struct ClientHandshakeResponse {
pub version: String,
pub process_id: i32,
pub secret_key: i32,
}
impl ClientHandshakeResponse {
pub fn from_backend_messages(message: &[BackendMessage]) -> Result<Self, ClientHandshakeError> {
let mut version = None;
let mut process_id = None;
let mut secret_key = None;
for message in message {
match message {
BackendMessage::ParameterStatus(data) => {
if data.name.as_str() == "server_version" {
version = Some(String::from(data.value.as_str()));
}
}
BackendMessage::BackendKeyData(data) => {
process_id = Some(data.process);
secret_key = Some(data.secret);
}
_ => {}
}
}
match (version, process_id, secret_key) {
(Some(version), Some(process_id), Some(secret_key)) => {
Ok(Self { version, process_id, secret_key })
}
_ => Err(ClientHandshakeError::UnexpectedResponse),
}
}
}

View file

@ -1,14 +1,27 @@
use thiserror::Error; use thiserror::Error;
use tokio::io; use tokio::io;
use crate::message::backend::BackendMessage;
use crate::message::errors::ProtoDeserializeError; use crate::message::errors::ProtoDeserializeError;
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError}; use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError};
use crate::writer::errors::ProtoWriteError; use crate::writer::errors::ProtoWriteError;
#[derive(Debug, Error)]
pub enum ClientHandshakeError {
#[error("unexpected response from server")]
UnexpectedResponse,
#[error("unexpected auth response")]
UnexpectedAuthResponse(BackendMessage),
#[error("writing message to socket failed")]
Write(#[from] ProtoWriteError),
#[error("reading message from socket failed")]
Read(#[from] ProtoReadError),
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ServerHandshakeError { pub enum ServerHandshakeError {
#[error("startup message not found")] #[error("startup message not found")]
MissingStartupMessage, MissingStartupMessage,
#[error("reading from socket failed")] #[error("socket communication failed")]
Io(#[from] io::Error), Io(#[from] io::Error),
#[error("deserialization of inner data failed")] #[error("deserialization of inner data failed")]
Deserialize(#[from] ProtoDeserializeError), Deserialize(#[from] ProtoDeserializeError),

View file

@ -1,3 +1,5 @@
pub mod response;
pub mod request;
pub mod client;
pub mod server; pub mod server;
pub mod errors; pub mod errors;
pub mod client;

View file

@ -1,12 +1,12 @@
use crate::message::primitive::pgstring::PgString; use crate::message::primitive::pgstring::PgString;
use crate::message::special::StartupMessageData; use crate::message::special::StartupMessageData;
pub struct ClientHandshakeRequest { pub struct HandshakeRequest {
version: i32, version: i32,
parameters: Vec<(PgString, PgString)>, parameters: Vec<(PgString, PgString)>,
} }
impl ClientHandshakeRequest { impl HandshakeRequest {
pub fn new(version: i32) -> Self { pub fn new(version: i32) -> Self {
Self { version, parameters: Vec::new() } Self { version, parameters: Vec::new() }
} }
@ -17,8 +17,14 @@ impl ClientHandshakeRequest {
} }
} }
impl From<ClientHandshakeRequest> for StartupMessageData { impl From<HandshakeRequest> for StartupMessageData {
fn from(request: ClientHandshakeRequest) -> Self { fn from(request: HandshakeRequest) -> Self {
Self { version: request.version, params: request.parameters } Self { version: request.version, params: request.parameters }
} }
} }
impl From<StartupMessageData> for HandshakeRequest {
fn from(data: StartupMessageData) -> Self {
Self { version: data.version, parameters: data.params }
}
}

View file

@ -0,0 +1,65 @@
use crate::handshake::errors::ClientHandshakeError;
use crate::message::backend::{BackendKeyDataData, BackendMessage, ParameterStatusData};
pub struct HandshakeResponse {
pub version: String,
pub process_id: i32,
pub secret_key: i32,
}
impl HandshakeResponse {
pub fn new(name: &str, pid: i32, key: i32) -> Self {
Self {
version: format!("16.0 ({name})", name = name),
process_id: pid,
secret_key: key,
}
}
}
impl TryFrom<&[BackendMessage]> for HandshakeResponse {
type Error = ClientHandshakeError;
fn try_from(messages: &[BackendMessage]) -> Result<Self, Self::Error> {
let mut version = None;
let mut process_id = None;
let mut secret_key = None;
for message in messages {
match message {
BackendMessage::ParameterStatus(data) => {
if data.name.as_str() == "server_version" {
version = Some(String::from(data.value.as_str()));
}
}
BackendMessage::BackendKeyData(data) => {
process_id = Some(data.process);
secret_key = Some(data.secret);
}
_ => {}
}
}
match (version, process_id, secret_key) {
(Some(version), Some(process_id), Some(secret_key)) => {
Ok(Self { version, process_id, secret_key })
}
_ => Err(ClientHandshakeError::UnexpectedResponse),
}
}
}
impl From<&HandshakeResponse> for Vec<BackendMessage> {
fn from(response: &HandshakeResponse) -> Self {
vec![
BackendMessage::ParameterStatus(ParameterStatusData {
name: "server_version".into(),
value: response.version.clone().into(),
}),
BackendMessage::BackendKeyData(BackendKeyDataData {
process: response.process_id,
secret: response.secret_key,
}),
]
}
}

View file

@ -1,6 +1,8 @@
use crate::handshake::errors::ServerHandshakeError; use crate::handshake::errors::ServerHandshakeError;
use crate::handshake::request::HandshakeRequest;
use crate::handshake::response::HandshakeResponse;
use crate::message::backend::{ use crate::message::backend::{
AuthenticationOkData, BackendKeyDataData, BackendMessage, ParameterStatusData, AuthenticationOkData, BackendMessage,
ReadyForQueryData, ReadyForQueryData,
}; };
use crate::message::special::{SpecialMessage, StartupMessageData}; use crate::message::special::{SpecialMessage, StartupMessageData};
@ -11,10 +13,8 @@ use crate::writer::protowriter::ProtoFlush;
pub async fn do_server_handshake( pub async fn do_server_handshake(
writer: &mut (impl BackendProtoWriter + ProtoFlush), writer: &mut (impl BackendProtoWriter + ProtoFlush),
reader: &mut impl FrontendProtoReader, reader: &mut impl FrontendProtoReader,
name: &str, response: &HandshakeResponse,
process: i32, ) -> Result<HandshakeRequest, ServerHandshakeError> {
secret: i32,
) -> Result<StartupMessageData, ServerHandshakeError> {
match &reader.peek_special_message().await? { match &reader.peek_special_message().await? {
Some(msg @ SpecialMessage::SSLRequest) => { Some(msg @ SpecialMessage::SSLRequest) => {
reader.consume_special_message(msg).await?; reader.consume_special_message(msg).await?;
@ -40,21 +40,15 @@ pub async fn do_server_handshake(
.write_proto(BackendMessage::from(AuthenticationOkData { status: 0 })) .write_proto(BackendMessage::from(AuthenticationOkData { status: 0 }))
.await?; .await?;
writer let messages: Vec<BackendMessage> = response.into();
.write_proto(BackendMessage::from(ParameterStatusData { for message in messages {
name: "server_version".to_string().into(), writer.write_proto(message).await?;
value: format!("16.0 ({name})").into(), }
}))
.await?;
writer
.write_proto(BackendMessage::from(BackendKeyDataData { process, secret }))
.await?;
writer writer
.write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' })) .write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' }))
.await?; .await?;
writer.flush().await?; writer.flush().await?;
Ok(startup_message) Ok(startup_message.into())
} }