refactor(proto): reuse code in handshakes

This commit is contained in:
Jindřich Moravec 2023-12-23 00:52:53 +01:00
parent 7b2dce4dfb
commit c1744711d3
9 changed files with 111 additions and 88 deletions

View file

@ -1,6 +1,6 @@
use crate::handshake::client::error::ClientHandshakeError;
use crate::handshake::client::request::ClientHandshakeRequest;
use crate::handshake::client::response::ClientHandshakeResponse;
use crate::handshake::errors::ClientHandshakeError;
use crate::handshake::request::HandshakeRequest;
use crate::handshake::response::HandshakeResponse;
use crate::message::backend::{AuthenticationOkData, BackendMessage};
use crate::message::special::StartupMessageData;
use crate::reader::backend::BackendProtoReader;
@ -10,8 +10,8 @@ use crate::writer::protowriter::ProtoFlush;
pub async fn do_client_handshake(
writer: &mut (impl FrontendProtoWriter + ProtoFlush),
reader: &mut impl BackendProtoReader,
request: ClientHandshakeRequest,
) -> Result<ClientHandshakeResponse, ClientHandshakeError> {
request: HandshakeRequest,
) -> Result<HandshakeResponse, ClientHandshakeError> {
let startup_message: StartupMessageData = request.into();
writer.write_startup_message(startup_message).await?;
@ -30,5 +30,5 @@ pub async fn do_client_handshake(
messages.push(msg);
}
ClientHandshakeResponse::from_backend_messages(&messages)
HandshakeResponse::try_from(messages.as_slice())
}

View file

@ -1,16 +0,0 @@
use thiserror::Error;
use crate::message::backend::BackendMessage;
use crate::reader::errors::ProtoReadError;
use crate::writer::errors::ProtoWriteError;
#[derive(Debug, Error)]
pub enum ClientHandshakeError {
#[error("unexpected response from server")]
UnexpectedResponse,
#[error("unexpected auth response")]
UnexpectedAuthResponse(BackendMessage),
#[error("writing message to socket failed")]
Write(#[from] ProtoWriteError),
#[error("reading message from socket failed")]
Read(#[from] ProtoReadError),
}

View file

@ -1,4 +0,0 @@
pub mod request;
pub mod response;
pub mod shaker;
pub mod error;

View file

@ -1,37 +0,0 @@
use crate::handshake::client::error::ClientHandshakeError;
use crate::message::backend::BackendMessage;
pub struct ClientHandshakeResponse {
pub version: String,
pub process_id: i32,
pub secret_key: i32,
}
impl ClientHandshakeResponse {
pub fn from_backend_messages(message: &[BackendMessage]) -> Result<Self, ClientHandshakeError> {
let mut version = None;
let mut process_id = None;
let mut secret_key = None;
for message in message {
match message {
BackendMessage::ParameterStatus(data) => {
if data.name.as_str() == "server_version" {
version = Some(String::from(data.value.as_str()));
}
}
BackendMessage::BackendKeyData(data) => {
process_id = Some(data.process);
secret_key = Some(data.secret);
}
_ => {}
}
}
match (version, process_id, secret_key) {
(Some(version), Some(process_id), Some(secret_key)) => {
Ok(Self { version, process_id, secret_key })
}
_ => Err(ClientHandshakeError::UnexpectedResponse),
}
}
}

View file

@ -1,14 +1,27 @@
use thiserror::Error;
use tokio::io;
use crate::message::backend::BackendMessage;
use crate::message::errors::ProtoDeserializeError;
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError};
use crate::reader::errors::{ProtoConsumeError, ProtoPeekError, ProtoReadError};
use crate::writer::errors::ProtoWriteError;
#[derive(Debug, Error)]
pub enum ClientHandshakeError {
#[error("unexpected response from server")]
UnexpectedResponse,
#[error("unexpected auth response")]
UnexpectedAuthResponse(BackendMessage),
#[error("writing message to socket failed")]
Write(#[from] ProtoWriteError),
#[error("reading message from socket failed")]
Read(#[from] ProtoReadError),
}
#[derive(Debug, Error)]
pub enum ServerHandshakeError {
#[error("startup message not found")]
MissingStartupMessage,
#[error("reading from socket failed")]
#[error("socket communication failed")]
Io(#[from] io::Error),
#[error("deserialization of inner data failed")]
Deserialize(#[from] ProtoDeserializeError),
@ -18,4 +31,4 @@ pub enum ServerHandshakeError {
Consume(#[from] ProtoConsumeError),
#[error("writing message to socket failed")]
Write(#[from] ProtoWriteError),
}
}

View file

@ -1,3 +1,5 @@
pub mod server;
pub mod errors;
pub mod response;
pub mod request;
pub mod client;
pub mod server;
pub mod errors;

View file

@ -1,12 +1,12 @@
use crate::message::primitive::pgstring::PgString;
use crate::message::special::StartupMessageData;
pub struct ClientHandshakeRequest {
pub struct HandshakeRequest {
version: i32,
parameters: Vec<(PgString, PgString)>,
}
impl ClientHandshakeRequest {
impl HandshakeRequest {
pub fn new(version: i32) -> Self {
Self { version, parameters: Vec::new() }
}
@ -17,8 +17,14 @@ impl ClientHandshakeRequest {
}
}
impl From<ClientHandshakeRequest> for StartupMessageData {
fn from(request: ClientHandshakeRequest) -> Self {
impl From<HandshakeRequest> for StartupMessageData {
fn from(request: HandshakeRequest) -> Self {
Self { version: request.version, params: request.parameters }
}
}
impl From<StartupMessageData> for HandshakeRequest {
fn from(data: StartupMessageData) -> Self {
Self { version: data.version, parameters: data.params }
}
}

View file

@ -0,0 +1,65 @@
use crate::handshake::errors::ClientHandshakeError;
use crate::message::backend::{BackendKeyDataData, BackendMessage, ParameterStatusData};
pub struct HandshakeResponse {
pub version: String,
pub process_id: i32,
pub secret_key: i32,
}
impl HandshakeResponse {
pub fn new(name: &str, pid: i32, key: i32) -> Self {
Self {
version: format!("16.0 ({name})", name = name),
process_id: pid,
secret_key: key,
}
}
}
impl TryFrom<&[BackendMessage]> for HandshakeResponse {
type Error = ClientHandshakeError;
fn try_from(messages: &[BackendMessage]) -> Result<Self, Self::Error> {
let mut version = None;
let mut process_id = None;
let mut secret_key = None;
for message in messages {
match message {
BackendMessage::ParameterStatus(data) => {
if data.name.as_str() == "server_version" {
version = Some(String::from(data.value.as_str()));
}
}
BackendMessage::BackendKeyData(data) => {
process_id = Some(data.process);
secret_key = Some(data.secret);
}
_ => {}
}
}
match (version, process_id, secret_key) {
(Some(version), Some(process_id), Some(secret_key)) => {
Ok(Self { version, process_id, secret_key })
}
_ => Err(ClientHandshakeError::UnexpectedResponse),
}
}
}
impl From<&HandshakeResponse> for Vec<BackendMessage> {
fn from(response: &HandshakeResponse) -> Self {
vec![
BackendMessage::ParameterStatus(ParameterStatusData {
name: "server_version".into(),
value: response.version.clone().into(),
}),
BackendMessage::BackendKeyData(BackendKeyDataData {
process: response.process_id,
secret: response.secret_key,
}),
]
}
}

View file

@ -1,6 +1,8 @@
use crate::handshake::errors::ServerHandshakeError;
use crate::handshake::request::HandshakeRequest;
use crate::handshake::response::HandshakeResponse;
use crate::message::backend::{
AuthenticationOkData, BackendKeyDataData, BackendMessage, ParameterStatusData,
AuthenticationOkData, BackendMessage,
ReadyForQueryData,
};
use crate::message::special::{SpecialMessage, StartupMessageData};
@ -11,10 +13,8 @@ use crate::writer::protowriter::ProtoFlush;
pub async fn do_server_handshake(
writer: &mut (impl BackendProtoWriter + ProtoFlush),
reader: &mut impl FrontendProtoReader,
name: &str,
process: i32,
secret: i32,
) -> Result<StartupMessageData, ServerHandshakeError> {
response: &HandshakeResponse,
) -> Result<HandshakeRequest, ServerHandshakeError> {
match &reader.peek_special_message().await? {
Some(msg @ SpecialMessage::SSLRequest) => {
reader.consume_special_message(msg).await?;
@ -40,21 +40,15 @@ pub async fn do_server_handshake(
.write_proto(BackendMessage::from(AuthenticationOkData { status: 0 }))
.await?;
writer
.write_proto(BackendMessage::from(ParameterStatusData {
name: "server_version".to_string().into(),
value: format!("16.0 ({name})").into(),
}))
.await?;
writer
.write_proto(BackendMessage::from(BackendKeyDataData { process, secret }))
.await?;
let messages: Vec<BackendMessage> = response.into();
for message in messages {
writer.write_proto(message).await?;
}
writer
.write_proto(BackendMessage::from(ReadyForQueryData { status: b'I' }))
.await?;
writer.flush().await?;
Ok(startup_message)
Ok(startup_message.into())
}