minisql/proto/src/message/backend.rs
2023-12-23 01:28:30 +01:00

370 lines
12 KiB
Rust

use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError};
use crate::message::primitive::data::MessageData;
use crate::message::primitive::pglist::PgList;
use crate::message::primitive::pgstring::PgString;
use crate::message::proto_message::ProtoMessage;
use bincode::{Decode, Encode};
/// The set of backend protocol messages this module can encode and decode.
///
/// Each variant is identified on the wire by a one-byte tag (returned by
/// `ProtoMessage::variant`); variants with a payload wrap a dedicated
/// `*Data` struct, while payload-free variants are plain unit variants.
///
/// NOTE(review): the names and tag bytes mirror PostgreSQL's backend
/// (server-to-client) messages — confirm the intended direction with callers.
#[derive(Debug)]
pub enum BackendMessage {
/// Tag `b'R'`; carries an [`AuthenticationOkData`] payload.
AuthenticationOk(AuthenticationOkData),
/// Tag `b'K'`; carries a [`BackendKeyDataData`] payload.
BackendKeyData(BackendKeyDataData),
/// Tag `b'C'`; carries a [`CommandCompleteData`] payload.
CommandComplete(CommandCompleteData),
/// Tag `b'D'`; carries a [`DataRowData`] payload.
DataRow(DataRowData),
/// Tag `b'I'`; has no payload.
EmptyQueryResponse,
/// Tag `b'E'`; carries an [`ErrorResponseData`] payload.
ErrorResponse(ErrorResponseData),
/// Tag `b'n'`; has no payload.
NoData,
/// Tag `b'S'`; carries a [`ParameterStatusData`] payload.
ParameterStatus(ParameterStatusData),
/// Tag `b'Z'`; carries a [`ReadyForQueryData`] payload.
ReadyForQuery(ReadyForQueryData),
/// Tag `b'T'`; carries a [`RowDescriptionData`] payload.
RowDescription(RowDescriptionData),
}
impl ProtoMessage for BackendMessage {
fn variant(&self) -> u8 {
match self {
BackendMessage::AuthenticationOk(_) => b'R',
BackendMessage::BackendKeyData(_) => b'K',
BackendMessage::CommandComplete(_) => b'C',
BackendMessage::DataRow(_) => b'D',
BackendMessage::EmptyQueryResponse => b'I',
BackendMessage::ErrorResponse(_) => b'E',
BackendMessage::NoData => b'n',
BackendMessage::ParameterStatus(_) => b'S',
BackendMessage::ReadyForQuery(_) => b'Z',
BackendMessage::RowDescription(_) => b'T',
}
}
fn serialize(&self) -> Result<Vec<u8>, ProtoSerializeError> {
match self {
BackendMessage::AuthenticationOk(data) => data.serialize(),
BackendMessage::BackendKeyData(data) => data.serialize(),
BackendMessage::CommandComplete(data) => data.serialize(),
BackendMessage::DataRow(data) => data.serialize(),
BackendMessage::EmptyQueryResponse => Ok(Vec::with_capacity(0)),
BackendMessage::ErrorResponse(data) => data.serialize(),
BackendMessage::NoData => Ok(Vec::with_capacity(0)),
BackendMessage::ParameterStatus(data) => data.serialize(),
BackendMessage::ReadyForQuery(data) => data.serialize(),
BackendMessage::RowDescription(data) => data.serialize(),
}
}
fn deserialize(variant: u8, data: &[u8]) -> Result<BackendMessage, ProtoDeserializeError> {
match variant {
b'R' => Ok(BackendMessage::AuthenticationOk(
AuthenticationOkData::deserialize(data)?,
)),
b'K' => {
let data = BackendKeyDataData::deserialize(data)?;
Ok(BackendMessage::BackendKeyData(data))
}
b'C' => {
let data = CommandCompleteData::deserialize(data)?;
Ok(BackendMessage::CommandComplete(data))
}
b'D' => {
let data = DataRowData::deserialize(data)?;
Ok(BackendMessage::DataRow(data))
}
b'I' => Ok(BackendMessage::EmptyQueryResponse),
b'E' => {
let data = ErrorResponseData::deserialize(data)?;
Ok(BackendMessage::ErrorResponse(data))
}
b'n' => Ok(BackendMessage::NoData),
b'S' => {
let data = ParameterStatusData::deserialize(data)?;
Ok(BackendMessage::ParameterStatus(data))
}
b'Z' => {
let data = ReadyForQueryData::deserialize(data)?;
Ok(BackendMessage::ReadyForQuery(data))
}
b'T' => {
let data = RowDescriptionData::deserialize(data)?;
Ok(BackendMessage::RowDescription(data))
}
v => Err(ProtoDeserializeError::InvalidVariant(v)),
}
}
}
/// Payload of [`BackendMessage::AuthenticationOk`] (tag `b'R'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct AuthenticationOkData {
/// Authentication status value.
/// NOTE(review): presumably 0 signals success, as in PostgreSQL — confirm.
pub status: i32,
}
impl From<AuthenticationOkData> for BackendMessage {
fn from(data: AuthenticationOkData) -> Self {
BackendMessage::AuthenticationOk(data)
}
}
/// Payload of [`BackendMessage::BackendKeyData`] (tag `b'K'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct BackendKeyDataData {
/// NOTE(review): presumably the backend process identifier — confirm.
pub process: i32,
/// NOTE(review): presumably the secret key paired with `process` for
/// cancel requests, as in PostgreSQL — confirm.
pub secret: i32,
}
impl From<BackendKeyDataData> for BackendMessage {
fn from(data: BackendKeyDataData) -> Self {
BackendMessage::BackendKeyData(data)
}
}
/// Payload of [`BackendMessage::CommandComplete`] (tag `b'C'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct CommandCompleteData {
/// Command tag text (the unit tests use `"SELECT 1"` as an example).
pub tag: PgString,
}
impl From<CommandCompleteData> for BackendMessage {
fn from(data: CommandCompleteData) -> Self {
BackendMessage::CommandComplete(data)
}
}
/// Payload of [`BackendMessage::DataRow`] (tag `b'D'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct DataRowData {
/// The row's column values: an outer list whose elements are byte lists.
/// NOTE(review): the second `PgList` type parameter presumably selects the
/// length-prefix integer (`i16` column count, `i32` byte length per value)
/// — confirm against `PgList`.
pub columns: PgList<PgList<u8, i32>, i16>,
}
impl From<DataRowData> for BackendMessage {
fn from(data: DataRowData) -> Self {
BackendMessage::DataRow(data)
}
}
/// Payload of [`BackendMessage::ErrorResponse`] (tag `b'E'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct ErrorResponseData {
/// Single-byte code accompanying the message.
/// NOTE(review): presumably a PostgreSQL error field-type byte — confirm.
pub code: u8,
/// Error text.
pub message: PgString,
}
impl From<ErrorResponseData> for BackendMessage {
fn from(data: ErrorResponseData) -> Self {
BackendMessage::ErrorResponse(data)
}
}
/// Payload of [`BackendMessage::ParameterStatus`] (tag `b'S'`): a name/value
/// pair of strings.
#[derive(Debug, Clone, Encode, Decode)]
pub struct ParameterStatusData {
/// Parameter name.
pub name: PgString,
/// Parameter value.
pub value: PgString,
}
impl From<ParameterStatusData> for BackendMessage {
fn from(data: ParameterStatusData) -> Self {
BackendMessage::ParameterStatus(data)
}
}
/// Payload of [`BackendMessage::ReadyForQuery`] (tag `b'Z'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct ReadyForQueryData {
/// Single status byte (the unit tests use `b'I'`).
/// NOTE(review): presumably the transaction status indicator, as in
/// PostgreSQL — confirm.
pub status: u8,
}
impl From<ReadyForQueryData> for BackendMessage {
fn from(data: ReadyForQueryData) -> Self {
BackendMessage::ReadyForQuery(data)
}
}
/// Payload of [`BackendMessage::RowDescription`] (tag `b'T'`).
#[derive(Debug, Clone, Encode, Decode)]
pub struct RowDescriptionData {
/// One [`ColumnDescription`] per described column.
/// NOTE(review): the `i16` parameter presumably selects the count-prefix
/// integer type — confirm against `PgList`.
pub columns: PgList<ColumnDescription, i16>,
}
impl From<RowDescriptionData> for BackendMessage {
fn from(data: RowDescriptionData) -> Self {
BackendMessage::RowDescription(data)
}
}
/// Description of a single column, used as the element type of
/// [`RowDescriptionData::columns`].
///
/// NOTE(review): the field set matches the per-field layout of PostgreSQL's
/// RowDescription message; the per-field notes below are inferred from the
/// names and should be confirmed against the protocol documentation.
#[derive(Debug, Clone, Encode, Decode)]
pub struct ColumnDescription {
/// Column name.
pub name: PgString,
/// Presumably the OID of the originating table — confirm.
pub table_oid: i32,
/// Presumably the column's index/attribute number within that table — confirm.
pub column_index: i16,
/// Presumably the OID of the column's data type — confirm.
pub type_oid: i32,
/// Presumably the data type's size — confirm.
pub type_size: i16,
/// Presumably the type modifier — confirm.
pub type_modifier: i32,
/// Presumably the format code (text vs. binary) — confirm.
pub format_code: i16,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `original`, then decodes the bytes again using the tag
    /// reported by `variant()`. Panics if either step fails.
    fn round_trip(original: BackendMessage) -> BackendMessage {
        let bytes = original.serialize().unwrap();
        let tag = original.variant();
        BackendMessage::deserialize(tag, &bytes).unwrap()
    }

    #[test]
    fn test_symmetric_authentication_ok() {
        let decoded = round_trip(BackendMessage::AuthenticationOk(AuthenticationOkData {
            status: 123,
        }));
        assert!(matches!(
            decoded,
            BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 })
        ));
    }

    #[test]
    fn test_symmetric_backend_key_data() {
        let decoded = round_trip(BackendMessage::BackendKeyData(BackendKeyDataData {
            process: 123,
            secret: 456,
        }));
        assert!(matches!(
            decoded,
            BackendMessage::BackendKeyData(BackendKeyDataData {
                process: 123,
                secret: 456
            })
        ));
    }

    #[test]
    fn test_symmetric_command_complete() {
        let decoded = round_trip(BackendMessage::CommandComplete(CommandCompleteData {
            tag: PgString::from("SELECT 1"),
        }));
        assert!(matches!(
            decoded,
            BackendMessage::CommandComplete(CommandCompleteData { tag })
                if tag.as_str() == "SELECT 1"
        ));
    }

    #[test]
    fn test_symmetric_data_row() {
        let decoded = round_trip(BackendMessage::DataRow(DataRowData {
            columns: PgList::from(vec![PgList::from(vec![1, 2, 3])]),
        }));
        assert!(matches!(
            decoded,
            BackendMessage::DataRow(DataRowData { columns })
                if columns == PgList::from(vec![PgList::from(vec![1, 2, 3])])
        ));
    }

    #[test]
    fn test_symmetric_empty_query_response() {
        let decoded = round_trip(BackendMessage::EmptyQueryResponse);
        assert!(matches!(decoded, BackendMessage::EmptyQueryResponse));
    }

    #[test]
    fn test_symmetric_error_response() {
        let decoded = round_trip(BackendMessage::ErrorResponse(ErrorResponseData {
            code: b'X',
            message: PgString::from("Some error"),
        }));
        assert!(matches!(
            decoded,
            BackendMessage::ErrorResponse(ErrorResponseData { code, message })
                if code == b'X' && message.as_str() == "Some error"
        ));
    }

    #[test]
    fn test_symmetric_no_data() {
        let decoded = round_trip(BackendMessage::NoData);
        assert!(matches!(decoded, BackendMessage::NoData));
    }

    #[test]
    fn test_symmetric_parameter_status() {
        let decoded = round_trip(BackendMessage::ParameterStatus(ParameterStatusData {
            name: PgString::from("Some name"),
            value: PgString::from("Some value"),
        }));
        assert!(matches!(
            decoded,
            BackendMessage::ParameterStatus(ParameterStatusData { name, value })
                if name.as_str() == "Some name" && value.as_str() == "Some value"
        ));
    }

    #[test]
    fn test_symmetric_ready_for_query() {
        let decoded = round_trip(BackendMessage::ReadyForQuery(ReadyForQueryData {
            status: b'I',
        }));
        assert!(matches!(
            decoded,
            BackendMessage::ReadyForQuery(ReadyForQueryData { status }) if status == b'I'
        ));
    }

    #[test]
    fn test_symmetric_row_description() {
        let decoded = round_trip(BackendMessage::RowDescription(RowDescriptionData {
            columns: PgList::from(vec![ColumnDescription {
                name: PgString::from("Some name"),
                table_oid: 123,
                column_index: 456,
                type_oid: 789,
                type_size: 101,
                type_modifier: 112,
                format_code: 113,
            }]),
        }));
        // Evaluate the field-by-field comparison first, then assert on it.
        let all_fields_match = match decoded {
            BackendMessage::RowDescription(RowDescriptionData { columns }) => {
                let described: Vec<ColumnDescription> = columns.into();
                let first = &described[0];
                first.name.as_str() == "Some name"
                    && first.table_oid == 123
                    && first.column_index == 456
                    && first.type_oid == 789
                    && first.type_size == 101
                    && first.type_modifier == 112
                    && first.format_code == 113
            }
            _ => false,
        };
        assert!(all_fields_match);
    }

    #[test]
    fn test_unknown_variant() {
        // Tag 0 is not assigned to any backend message.
        let payload = vec![1, 2, 3];
        let outcome = BackendMessage::deserialize(0, &payload);
        assert!(matches!(
            outcome,
            Err(ProtoDeserializeError::InvalidVariant(0))
        ));
    }
}