370 lines
12 KiB
Rust
370 lines
12 KiB
Rust
use crate::message::errors::{ProtoDeserializeError, ProtoSerializeError};
|
|
use crate::message::primitive::data::MessageData;
|
|
use crate::message::primitive::pglist::PgList;
|
|
use crate::message::primitive::pgstring::PgString;
|
|
use crate::message::proto_message::ProtoMessage;
|
|
use bincode::{Decode, Encode};
|
|
|
|
#[derive(Debug)]
|
|
pub enum BackendMessage {
|
|
AuthenticationOk(AuthenticationOkData),
|
|
BackendKeyData(BackendKeyDataData),
|
|
CommandComplete(CommandCompleteData),
|
|
DataRow(DataRowData),
|
|
EmptyQueryResponse,
|
|
ErrorResponse(ErrorResponseData),
|
|
NoData,
|
|
ParameterStatus(ParameterStatusData),
|
|
ReadyForQuery(ReadyForQueryData),
|
|
RowDescription(RowDescriptionData),
|
|
}
|
|
|
|
impl ProtoMessage for BackendMessage {
|
|
fn variant(&self) -> u8 {
|
|
match self {
|
|
BackendMessage::AuthenticationOk(_) => b'R',
|
|
BackendMessage::BackendKeyData(_) => b'K',
|
|
BackendMessage::CommandComplete(_) => b'C',
|
|
BackendMessage::DataRow(_) => b'D',
|
|
BackendMessage::EmptyQueryResponse => b'I',
|
|
BackendMessage::ErrorResponse(_) => b'E',
|
|
BackendMessage::NoData => b'n',
|
|
BackendMessage::ParameterStatus(_) => b'S',
|
|
BackendMessage::ReadyForQuery(_) => b'Z',
|
|
BackendMessage::RowDescription(_) => b'T',
|
|
}
|
|
}
|
|
|
|
fn serialize(&self) -> Result<Vec<u8>, ProtoSerializeError> {
|
|
match self {
|
|
BackendMessage::AuthenticationOk(data) => data.serialize(),
|
|
BackendMessage::BackendKeyData(data) => data.serialize(),
|
|
BackendMessage::CommandComplete(data) => data.serialize(),
|
|
BackendMessage::DataRow(data) => data.serialize(),
|
|
BackendMessage::EmptyQueryResponse => Ok(Vec::with_capacity(0)),
|
|
BackendMessage::ErrorResponse(data) => data.serialize(),
|
|
BackendMessage::NoData => Ok(Vec::with_capacity(0)),
|
|
BackendMessage::ParameterStatus(data) => data.serialize(),
|
|
BackendMessage::ReadyForQuery(data) => data.serialize(),
|
|
BackendMessage::RowDescription(data) => data.serialize(),
|
|
}
|
|
}
|
|
|
|
fn deserialize(variant: u8, data: &[u8]) -> Result<BackendMessage, ProtoDeserializeError> {
|
|
match variant {
|
|
b'R' => Ok(BackendMessage::AuthenticationOk(
|
|
AuthenticationOkData::deserialize(data)?,
|
|
)),
|
|
b'K' => {
|
|
let data = BackendKeyDataData::deserialize(data)?;
|
|
Ok(BackendMessage::BackendKeyData(data))
|
|
}
|
|
b'C' => {
|
|
let data = CommandCompleteData::deserialize(data)?;
|
|
Ok(BackendMessage::CommandComplete(data))
|
|
}
|
|
b'D' => {
|
|
let data = DataRowData::deserialize(data)?;
|
|
Ok(BackendMessage::DataRow(data))
|
|
}
|
|
b'I' => Ok(BackendMessage::EmptyQueryResponse),
|
|
b'E' => {
|
|
let data = ErrorResponseData::deserialize(data)?;
|
|
Ok(BackendMessage::ErrorResponse(data))
|
|
}
|
|
b'n' => Ok(BackendMessage::NoData),
|
|
b'S' => {
|
|
let data = ParameterStatusData::deserialize(data)?;
|
|
Ok(BackendMessage::ParameterStatus(data))
|
|
}
|
|
b'Z' => {
|
|
let data = ReadyForQueryData::deserialize(data)?;
|
|
Ok(BackendMessage::ReadyForQuery(data))
|
|
}
|
|
b'T' => {
|
|
let data = RowDescriptionData::deserialize(data)?;
|
|
Ok(BackendMessage::RowDescription(data))
|
|
}
|
|
v => Err(ProtoDeserializeError::InvalidVariant(v)),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct AuthenticationOkData {
|
|
pub status: i32,
|
|
}
|
|
|
|
impl From<AuthenticationOkData> for BackendMessage {
|
|
fn from(data: AuthenticationOkData) -> Self {
|
|
BackendMessage::AuthenticationOk(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct BackendKeyDataData {
|
|
pub process: i32,
|
|
pub secret: i32,
|
|
}
|
|
|
|
impl From<BackendKeyDataData> for BackendMessage {
|
|
fn from(data: BackendKeyDataData) -> Self {
|
|
BackendMessage::BackendKeyData(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct CommandCompleteData {
|
|
pub tag: PgString,
|
|
}
|
|
|
|
impl From<CommandCompleteData> for BackendMessage {
|
|
fn from(data: CommandCompleteData) -> Self {
|
|
BackendMessage::CommandComplete(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct DataRowData {
|
|
pub columns: PgList<PgList<u8, i32>, i16>,
|
|
}
|
|
|
|
impl From<DataRowData> for BackendMessage {
|
|
fn from(data: DataRowData) -> Self {
|
|
BackendMessage::DataRow(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct ErrorResponseData {
|
|
pub code: u8,
|
|
pub message: PgString,
|
|
}
|
|
|
|
impl From<ErrorResponseData> for BackendMessage {
|
|
fn from(data: ErrorResponseData) -> Self {
|
|
BackendMessage::ErrorResponse(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct ParameterStatusData {
|
|
pub name: PgString,
|
|
pub value: PgString,
|
|
}
|
|
|
|
impl From<ParameterStatusData> for BackendMessage {
|
|
fn from(data: ParameterStatusData) -> Self {
|
|
BackendMessage::ParameterStatus(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct ReadyForQueryData {
|
|
pub status: u8,
|
|
}
|
|
|
|
impl From<ReadyForQueryData> for BackendMessage {
|
|
fn from(data: ReadyForQueryData) -> Self {
|
|
BackendMessage::ReadyForQuery(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct RowDescriptionData {
|
|
pub columns: PgList<ColumnDescription, i16>,
|
|
}
|
|
|
|
impl From<RowDescriptionData> for BackendMessage {
|
|
fn from(data: RowDescriptionData) -> Self {
|
|
BackendMessage::RowDescription(data)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Encode, Decode)]
|
|
pub struct ColumnDescription {
|
|
pub name: PgString,
|
|
pub table_oid: i32,
|
|
pub column_index: i16,
|
|
pub type_oid: i32,
|
|
pub type_size: i16,
|
|
pub type_modifier: i32,
|
|
pub format_code: i16,
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `message`, then deserializes the bytes under the message's own tag.
    ///
    /// Either step failing panics via `unwrap`, which fails the calling test.
    fn roundtrip(message: &BackendMessage) -> BackendMessage {
        let raw = message.serialize().unwrap();
        BackendMessage::deserialize(message.variant(), &raw).unwrap()
    }

    #[test]
    fn test_symmetric_authentication_ok() {
        let original = BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::AuthenticationOk(AuthenticationOkData { status: 123 })
        ));
    }

    #[test]
    fn test_symmetric_backend_key_data() {
        let original = BackendMessage::BackendKeyData(BackendKeyDataData {
            process: 123,
            secret: 456,
        });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::BackendKeyData(BackendKeyDataData {
                process: 123,
                secret: 456
            })
        ));
    }

    #[test]
    fn test_symmetric_command_complete() {
        let original = BackendMessage::CommandComplete(CommandCompleteData {
            tag: PgString::from("SELECT 1"),
        });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::CommandComplete(CommandCompleteData { tag }) if tag.as_str() == "SELECT 1"
        ));
    }

    #[test]
    fn test_symmetric_data_row() {
        let expected: PgList<PgList<u8, i32>, i16> =
            PgList::from(vec![PgList::from(vec![1, 2, 3])]);
        let original = BackendMessage::DataRow(DataRowData {
            columns: expected.clone(),
        });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::DataRow(DataRowData { columns }) if columns == expected
        ));
    }

    #[test]
    fn test_symmetric_empty_query_response() {
        let decoded = roundtrip(&BackendMessage::EmptyQueryResponse);
        assert!(matches!(decoded, BackendMessage::EmptyQueryResponse));
    }

    #[test]
    fn test_symmetric_error_response() {
        let original = BackendMessage::ErrorResponse(ErrorResponseData {
            code: b'X',
            message: PgString::from("Some error"),
        });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::ErrorResponse(ErrorResponseData { code, message })
                if code == b'X' && message.as_str() == "Some error"
        ));
    }

    #[test]
    fn test_symmetric_no_data() {
        let decoded = roundtrip(&BackendMessage::NoData);
        assert!(matches!(decoded, BackendMessage::NoData));
    }

    #[test]
    fn test_symmetric_parameter_status() {
        let original = BackendMessage::ParameterStatus(ParameterStatusData {
            name: PgString::from("Some name"),
            value: PgString::from("Some value"),
        });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::ParameterStatus(ParameterStatusData { name, value })
                if name.as_str() == "Some name" && value.as_str() == "Some value"
        ));
    }

    #[test]
    fn test_symmetric_ready_for_query() {
        let original = BackendMessage::ReadyForQuery(ReadyForQueryData { status: b'I' });

        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            BackendMessage::ReadyForQuery(ReadyForQueryData { status }) if status == b'I'
        ));
    }

    #[test]
    fn test_symmetric_row_description() {
        let original = BackendMessage::RowDescription(RowDescriptionData {
            columns: PgList::from(vec![ColumnDescription {
                name: PgString::from("Some name"),
                table_oid: 123,
                column_index: 456,
                type_oid: 789,
                type_size: 101,
                type_modifier: 112,
                format_code: 113,
            }]),
        });

        let decoded = roundtrip(&original);
        let is_expected = match decoded {
            BackendMessage::RowDescription(RowDescriptionData { columns }) => {
                let columns: Vec<ColumnDescription> = columns.into();
                let first = &columns[0];
                first.name.as_str() == "Some name"
                    && first.table_oid == 123
                    && first.column_index == 456
                    && first.type_oid == 789
                    && first.type_size == 101
                    && first.type_modifier == 112
                    && first.format_code == 113
            }
            _ => false,
        };
        assert!(is_expected);
    }

    #[test]
    fn test_unknown_variant() {
        let payload = vec![1, 2, 3];

        let result = BackendMessage::deserialize(0, &payload);
        assert!(matches!(
            result,
            Err(ProtoDeserializeError::InvalidVariant(0))
        ));
    }
}
|