use std::time::Duration;
use azure_iot_operations_protocol::common::payload_serialize::{
DeserializationError, FormatIndicator, PayloadSerialize, SerializedPayload,
};
/// A state store request. Its `PayloadSerialize` impl encodes each variant as a RESP-style
/// array of bulk strings via the matching `serialize_*` helper in this module.
#[derive(Clone, Debug)]
pub(crate) enum Request {
    /// Encoded as `SET <key> <value>` followed by the optional condition / expiry
    /// arguments described by `options`.
    Set {
        key: Vec<u8>,
        value: Vec<u8>,
        options: SetOptions,
    },
    /// Encoded as `GET <key>`.
    Get {
        key: Vec<u8>,
    },
    /// Encoded as `DEL <key>`.
    Del {
        key: Vec<u8>,
    },
    /// Encoded as `VDEL <key> <value>`.
    // NOTE(review): presumably deletes `key` only when its stored value equals `value`;
    // confirm against the state store service documentation.
    VDel {
        key: Vec<u8>,
        value: Vec<u8>,
    },
    /// Encoded as `KEYNOTIFY <key>`, with a trailing `STOP` argument when `options.stop`
    /// is `true`.
    KeyNotify {
        key: Vec<u8>,
        options: KeyNotifyOptions,
    },
}
/// Optional modifiers appended to a `SET` request.
#[derive(Clone, Debug, Default)]
pub struct SetOptions {
    /// Condition argument sent with the request; defaults to `SetCondition::Unconditional`,
    /// which sends no condition argument.
    pub set_condition: SetCondition,
    /// When `Some`, sent as `PX <n>` where `<n>` is the duration in whole milliseconds
    /// (`Duration::as_millis`, so sub-millisecond precision is truncated). `None` sends no
    /// expiry arguments.
    pub expires: Option<Duration>,
}
/// Condition attached to a `SET` request.
#[derive(Clone, Debug, Default)]
pub enum SetCondition {
    /// Sent as the `NX` argument.
    // NOTE(review): presumably "set only if the key does not exist"; confirm with service docs.
    OnlyIfDoesNotExist,
    /// Sent as the `NEX` argument.
    // NOTE(review): presumably "set only if the key is absent or already holds an equal
    // value"; confirm with service docs.
    OnlyIfEqualOrDoesNotExist,
    /// No condition argument is sent. This is the `Default`.
    #[default]
    Unconditional,
}
/// Options for a `KEYNOTIFY` request.
#[derive(Clone, Debug, Default)]
pub(crate) struct KeyNotifyOptions {
    /// When `true`, a trailing `STOP` argument is appended to the request.
    // NOTE(review): presumably this cancels notifications for the key; confirm with service docs.
    pub stop: bool,
}
impl PayloadSerialize for Request {
    type Error = String;

    /// Encodes the request as RESP bytes tagged `application/octet-stream`.
    ///
    /// This never fails; the `Result` exists only to satisfy the trait.
    fn serialize(self) -> Result<SerializedPayload, String> {
        let payload = match self {
            Request::Set {
                key,
                value,
                options,
            } => serialize_set(&key, &value, &options),
            Request::Get { key } => serialize_get(&key),
            Request::Del { key } => serialize_del(&key),
            Request::VDel { key, value } => serialize_v_del(&key, &value),
            Request::KeyNotify { key, options } => serialize_key_notify(&key, &options),
        };
        Ok(SerializedPayload {
            payload,
            content_type: String::from("application/octet-stream"),
            format_indicator: FormatIndicator::UnspecifiedBytes,
        })
    }

    /// Decoding a request is not implemented; this always returns an
    /// `InvalidPayload` error.
    fn deserialize(
        _payload: &[u8],
        _content_type: Option<&String>,
        _format_indicator: &FormatIndicator,
    ) -> Result<Self, DeserializationError<String>> {
        Err(DeserializationError::InvalidPayload(
            "Not implemented".into(),
        ))
    }
}
/// Accumulates a RESP-style request: an array header (`*<n>\r\n`) followed by
/// bulk-string arguments (`$<len>\r\n<bytes>\r\n`).
struct RequestBufferBuilder {
    buffer: Vec<u8>,
}

impl RequestBufferBuilder {
    /// Creates a builder with an empty buffer.
    fn new() -> Self {
        Self { buffer: Vec::new() }
    }

    /// Consumes the builder and returns every byte written so far.
    fn get_buffer(self) -> Vec<u8> {
        let RequestBufferBuilder { buffer } = self;
        buffer
    }

    /// Writes the array header `*<num_elements>\r\n`.
    fn append_array_number(&mut self, num_elements: u32) {
        let header = format!("*{num_elements}\r\n");
        self.buffer.extend_from_slice(header.as_bytes());
    }

    /// Writes one bulk-string argument: `$<byte length>\r\n<arg>\r\n`.
    fn append_argument(&mut self, arg: &[u8]) {
        let length_header = format!("${}\r\n", arg.len());
        self.buffer.extend_from_slice(length_header.as_bytes());
        self.buffer.extend_from_slice(arg);
        self.buffer.extend_from_slice(b"\r\n");
    }
}
/// Counts the arguments `serialize_set` emits beyond the fixed `SET <key> <value>`:
/// one for a condition flag (`NX` / `NEX`) and two for an expiry (`PX <millis>`).
fn get_number_additional_arguments(options: &SetOptions) -> u32 {
    let condition_arguments = match options.set_condition {
        SetCondition::Unconditional => 0,
        SetCondition::OnlyIfDoesNotExist | SetCondition::OnlyIfEqualOrDoesNotExist => 1,
    };
    let expiry_arguments = match options.expires {
        Some(_) => 2,
        None => 0,
    };
    condition_arguments + expiry_arguments
}
/// Encodes `SET <key> <value> [NX|NEX] [PX <millis>]` as a RESP array of bulk strings.
///
/// The expiry, when present, is sent in whole milliseconds (`Duration::as_millis`).
fn serialize_set(key: &[u8], value: &[u8], options: &SetOptions) -> Vec<u8> {
    let mut request = RequestBufferBuilder::new();
    // `SET`, the key and the value are always present; options only add trailing arguments.
    request.append_array_number(3 + get_number_additional_arguments(options));
    for argument in [&b"SET"[..], key, value] {
        request.append_argument(argument);
    }

    let condition_flag: Option<&[u8]> = match options.set_condition {
        SetCondition::OnlyIfDoesNotExist => Some(&b"NX"[..]),
        SetCondition::OnlyIfEqualOrDoesNotExist => Some(&b"NEX"[..]),
        SetCondition::Unconditional => None,
    };
    if let Some(flag) = condition_flag {
        request.append_argument(flag);
    }

    if let Some(expires) = options.expires {
        let millis = expires.as_millis().to_string();
        request.append_argument(b"PX");
        request.append_argument(millis.as_bytes());
    }

    request.get_buffer()
}
/// Encodes `GET <key>` as a two-element RESP array.
fn serialize_get(key: &[u8]) -> Vec<u8> {
    let mut request = RequestBufferBuilder::new();
    request.append_array_number(2);
    for argument in [&b"GET"[..], key] {
        request.append_argument(argument);
    }
    request.get_buffer()
}
/// Encodes `DEL <key>` as a two-element RESP array.
fn serialize_del(key: &[u8]) -> Vec<u8> {
    let mut request = RequestBufferBuilder::new();
    request.append_array_number(2);
    for argument in [&b"DEL"[..], key] {
        request.append_argument(argument);
    }
    request.get_buffer()
}
/// Encodes `VDEL <key> <value>` as a three-element RESP array.
fn serialize_v_del(key: &[u8], value: &[u8]) -> Vec<u8> {
    let mut request = RequestBufferBuilder::new();
    request.append_array_number(3);
    for argument in [&b"VDEL"[..], key, value] {
        request.append_argument(argument);
    }
    request.get_buffer()
}
/// Encodes `KEYNOTIFY <key> [STOP]` as a RESP array; `STOP` is included only when
/// `options.stop` is `true`.
fn serialize_key_notify(key: &[u8], options: &KeyNotifyOptions) -> Vec<u8> {
    let stop_flag: Option<&[u8]> = if options.stop {
        Some(&b"STOP"[..])
    } else {
        None
    };
    let argument_count = if stop_flag.is_some() { 3 } else { 2 };

    let mut request = RequestBufferBuilder::new();
    request.append_array_number(argument_count);
    request.append_argument(b"KEYNOTIFY");
    request.append_argument(key);
    if let Some(flag) = stop_flag {
        request.append_argument(flag);
    }
    request.get_buffer()
}
/// A decoded state store reply (see the `PayloadSerialize` impl for the wire forms).
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum Response {
    /// The exact payload `+OK\r\n`.
    Ok,
    /// A bulk-string reply `$<len>\r\n<bytes>\r\n`, holding `<bytes>`.
    Value(Vec<u8>),
    /// An integer reply `:<n>\r\n` (other than the exact `:0` / `:-1` forms), holding `<n>`.
    ValuesDeleted(i64),
    /// The exact payload `:-1\r\n`.
    NotApplied,
    /// The exact payload `$-1\r\n` or `:0\r\n`.
    NotFound,
    /// An error reply `-ERR <message>\r\n`, holding `<message>`.
    Error(Vec<u8>),
}

impl Response {
    // Replies that are recognised by exact byte comparison.
    const RESPONSE_OK: &'static [u8] = b"+OK\r\n";
    const GET_RESPONSE_NOT_FOUND: &'static [u8] = b"$-1\r\n";
    const RESPONSE_NOT_APPLIED: &'static [u8] = b":-1\r\n";
    const RESPONSE_KEY_NOT_FOUND: &'static [u8] = b":0\r\n";

    // Prefixes and framing for variable-length replies.
    const RESPONSE_ERROR_PREFIX: &'static [u8] = b"-ERR ";
    const RESPONSE_LENGTH_PREFIX: &'static [u8] = b"$";
    const DELETE_RESPONSE_PREFIX: &'static [u8] = b":";
    const RESPONSE_SUFFIX: &'static [u8] = b"\r\n";

    /// Extracts `<message>` from a `-ERR <message>\r\n` payload.
    ///
    /// # Errors
    /// Returns an error if the `-ERR ` prefix or the trailing `\r\n` is missing.
    fn parse_error(payload: &[u8]) -> Result<Vec<u8>, String> {
        payload
            .strip_prefix(Self::RESPONSE_ERROR_PREFIX)
            .and_then(|rest| rest.strip_suffix(Self::RESPONSE_SUFFIX))
            .map(|message| message.to_vec())
            .ok_or_else(|| format!("Invalid error response: {payload:?}"))
    }
}
impl PayloadSerialize for Response {
type Error = String;
fn serialize(self) -> Result<SerializedPayload, String> {
Err("Not implemented".into())
}
fn deserialize(
payload: &[u8],
content_type: Option<&String>,
_format_indicator: &FormatIndicator,
) -> Result<Self, DeserializationError<String>> {
if let Some(content_type) = content_type {
if content_type != "application/octet-stream" {
return Err(DeserializationError::UnsupportedContentType(format!(
"Invalid content type: '{content_type:?}'. Must be 'application/octet-stream'"
)));
}
}
match payload {
Self::RESPONSE_OK => Ok(Response::Ok),
Self::GET_RESPONSE_NOT_FOUND | Self::RESPONSE_KEY_NOT_FOUND => Ok(Response::NotFound),
Self::RESPONSE_NOT_APPLIED => Ok(Response::NotApplied),
_ if payload.starts_with(Self::RESPONSE_ERROR_PREFIX) => {
Ok(Response::Error(Self::parse_error(payload)?))
}
_ if payload.starts_with(Self::RESPONSE_LENGTH_PREFIX) => Ok(Response::Value(
parse_value(payload, Self::RESPONSE_LENGTH_PREFIX)?,
)),
_ if payload.starts_with(Self::DELETE_RESPONSE_PREFIX) => {
match parse_numeric(payload, Self::DELETE_RESPONSE_PREFIX)?.try_into() {
Ok(n) => Ok(Response::ValuesDeleted(n)),
Err(e) => Err(DeserializationError::InvalidPayload(format!(
"Error parsing number of keys deleted: {e}. Payload: {payload:?}"
))),
}
}
_ => Err(DeserializationError::InvalidPayload(format!(
"Unknown response: {payload:?}"
))),
}
}
}
/// An operation reported in a `NOTIFY` payload.
#[derive(Clone, Debug, PartialEq)]
pub enum Operation {
    /// Decoded from `NOTIFY SET VALUE <value>`; holds the `<value>` bytes.
    Set(Vec<u8>),
    /// Decoded from the exact payload `NOTIFY DELETE`.
    Del,
}
impl Operation {
    /// The complete encoding of a delete notification: the two-element array
    /// `["NOTIFY", "DELETE"]`.
    const OPERATION_DELETE: &'static [u8] = b"*2\r\n$6\r\nNOTIFY\r\n$6\r\nDELETE\r\n";
    /// Everything that precedes the value's length digits in a set notification: the start of
    /// the four-element array `["NOTIFY", "SET", "VALUE", <value>]`, up to and including the
    /// `$` of the value's bulk-string header.
    const SET_WITH_VALUE_PREFIX: &'static [u8] =
        b"*4\r\n$6\r\nNOTIFY\r\n$3\r\nSET\r\n$5\r\nVALUE\r\n$";
}
impl PayloadSerialize for Operation {
    type Error = String;

    /// Encoding a notification is not implemented; this always returns an error.
    fn serialize(self) -> Result<SerializedPayload, String> {
        Err("Not implemented".into())
    }

    /// Decodes a `NOTIFY` payload into an [`Operation`].
    ///
    /// A missing content type is accepted; any other content type than
    /// `application/octet-stream` is rejected.
    fn deserialize(
        payload: &[u8],
        content_type: Option<&String>,
        _format_indicator: &FormatIndicator,
    ) -> Result<Self, DeserializationError<String>> {
        if let Some(content_type) =
            content_type.filter(|ct| *ct != "application/octet-stream")
        {
            return Err(DeserializationError::UnsupportedContentType(format!(
                "Invalid content type: '{content_type:?}'. Must be 'application/octet-stream'"
            )));
        }

        if payload == Operation::OPERATION_DELETE {
            Ok(Operation::Del)
        } else if payload.starts_with(Operation::SET_WITH_VALUE_PREFIX) {
            let value = parse_value(payload, Operation::SET_WITH_VALUE_PREFIX)?;
            Ok(Operation::Set(value))
        } else {
            Err(DeserializationError::InvalidPayload(format!(
                "Unknown operation: {payload:?}"
            )))
        }
    }
}
/// Terminator that ends every RESP header line and bulk-string body.
const RESPONSE_SUFFIX: &[u8] = b"\r\n";

/// Parses an integer reply of the form `<prefix><digits>\r\n` into a `usize`.
///
/// # Errors
/// Returns an error if `payload` does not start with `prefix`, if the digits are missing,
/// contain a non-digit byte or overflow `usize`, or if anything other than exactly `\r\n`
/// follows them.
fn parse_numeric(payload: &[u8], prefix: &[u8]) -> Result<usize, String> {
    if let Some(val) = payload.strip_prefix(prefix) {
        let (number, terminator_index) = get_numeric(val)?;
        // `get_numeric` stops at the first '\r' (or the end of input) and never returns an
        // index past `val.len()`, so this slice is in bounds; the rest must be exactly "\r\n".
        if &val[terminator_index..] == RESPONSE_SUFFIX {
            return Ok(number);
        }
    }
    Err(format!("Invalid numeric response: {payload:?}"))
}

/// Reads a run of ASCII decimal digits from the start of `payload`, stopping at the first
/// `\r` or at the end of the slice.
///
/// Returns `(value, index)` where `index` is the position of the terminating `\r`, or
/// `payload.len()` if no `\r` was found.
///
/// # Errors
/// Returns an error if no digit precedes the terminator, if any byte other than a digit or
/// `\r` is encountered, or if the value overflows `usize`.
fn get_numeric(payload: &[u8]) -> Result<(usize, usize), String> {
    let mut value: usize = 0;
    let mut index: usize = 0;
    for &byte in payload {
        match byte {
            b'\r' => break,
            b'0'..=b'9' => {
                let digit = usize::from(byte - b'0');
                value = value.checked_mul(10).ok_or_else(|| {
                    format!("Multiplication overflow while parsing value length: {payload:?}")
                })?;
                value = value.checked_add(digit).ok_or_else(|| {
                    format!("Addition overflow while parsing value length: {payload:?}")
                })?;
            }
            _ => return Err(format!("Invalid value length format: {payload:?}")),
        }
        index += 1;
    }
    // An empty digit run (e.g. ":\r\n" or "$\r\n\r\n") must not silently parse as 0.
    if index == 0 {
        return Err(format!("Missing numeric value: {payload:?}"));
    }
    Ok((value, index))
}

/// Parses a bulk string `<prefix><len>\r\n<value>\r\n` and returns `<value>`.
///
/// The payload is treated as untrusted: truncated headers and absurd declared lengths are
/// reported as errors rather than causing an out-of-bounds index or arithmetic overflow.
///
/// # Errors
/// Returns an error if the prefix is missing, the length header is malformed, the declared
/// length does not match the bytes present, or the value is not terminated by `\r\n`.
fn parse_value(payload: &[u8], prefix: &[u8]) -> Result<Vec<u8>, String> {
    let Some(stripped_payload) = payload.strip_prefix(prefix) else {
        return Err(format!(
            "Invalid payload, must start with {prefix:?}: {payload:?}"
        ));
    };

    let (value_len, length_end) = get_numeric(stripped_payload)?;
    // `length_end` is the index of the '\r' ending the length header, or `len()` when no
    // '\r' exists. `get` makes a truncated header (e.g. "$5") an error instead of a panic.
    if stripped_payload.get(length_end + 1) != Some(&b'\n') {
        return Err(format!("Invalid format: {payload:?}"));
    }

    // The `get` above succeeded, so `value_start <= len()` and the subtraction cannot underflow.
    let value_start = length_end + 2;
    let remaining = stripped_payload.len() - value_start;
    // `value_len` is attacker-controlled, so compare with checked arithmetic instead of
    // computing `value_start + value_len + 2`, which can overflow.
    if value_len.checked_add(RESPONSE_SUFFIX.len()) != Some(remaining) {
        return Err(format!(
            "Value length does not match actual value length: {payload:?}"
        ));
    }

    let (value, closing_bytes) = stripped_payload[value_start..].split_at(value_len);
    if closing_bytes != RESPONSE_SUFFIX {
        return Err(format!("Invalid format: {payload:?}"));
    }
    Ok(value.to_vec())
}
#[cfg(test)]
mod tests {
    use test_case::test_case;

    use super::*;

    const OCTET_STREAM: &str = "application/octet-stream";

    /// Deserializes `payload` as a `Response` tagged with the octet-stream content type.
    fn deserialize_octet_stream(payload: &[u8]) -> Result<Response, DeserializationError<String>> {
        Response::deserialize(
            payload,
            Some(&OCTET_STREAM.to_string()),
            &FormatIndicator::UnspecifiedBytes,
        )
    }

    /// The `SerializedPayload` every request serialization is expected to produce for `bytes`.
    fn expected_payload(bytes: &[u8]) -> SerializedPayload {
        SerializedPayload {
            payload: bytes.to_vec(),
            content_type: OCTET_STREAM.to_string(),
            format_indicator: FormatIndicator::UnspecifiedBytes,
        }
    }

    #[test_case(b"+OK\r\n", &Response::Ok; "test_set_response")]
    #[test_case(b":-1\r\n", &Response::NotApplied; "test_did_not_set_response")]
    #[test_case(b"$4\r\n1234\r\n", &Response::Value(b"1234".to_vec()); "test_get_response_success")]
    #[test_case(b"$0\r\n\r\n", &Response::Value(b"".to_vec()); "test_get_response_empty_success")]
    #[test_case(b"$-1\r\n", &Response::NotFound; "test_get_response_no_key")]
    #[test_case(b":1\r\n", &Response::ValuesDeleted(1); "test_del_response")]
    #[test_case(b":-1\r\n", &Response::NotApplied; "test_vdel_no_match_response")]
    #[test_case(b":6\r\n", &Response::ValuesDeleted(6); "test_del_multiple_response")]
    #[test_case(b":0\r\n", &Response::NotFound; "test_del_no_key")]
    #[test_case(b"-ERR syntax error\r\n", &Response::Error(b"syntax error".to_vec()); "test_error_response")]
    #[test_case(b"-ERR \r\n", &Response::Error(b"".to_vec()); "test_empty_error_response_success")]
    fn test_response_deserialization_success(payload: &[u8], expected: &Response) {
        let actual = deserialize_octet_stream(payload).unwrap();
        assert_eq!(actual, *expected);
    }

    #[test]
    fn test_response_deserialization_no_content_type_success() {
        let actual =
            Response::deserialize(b"+OK\r\n", None, &FormatIndicator::UnspecifiedBytes).unwrap();
        assert_eq!(actual, Response::Ok);
    }

    #[test_case(b"1"; "too short")]
    #[test_case(b"11\r\nhello world\r\n"; "no $ on get response")]
    #[test_case(b"$11hello world\r\n"; "missing first newline")]
    #[test_case(b"$11\r\nhello world"; "missing second newline")]
    #[test_case(b"$not an integer\r\nhello world"; "length not an integer")]
    #[test_case(b"$11\r\nthis string is longer than 11 characters\r\n"; "length not accurate")]
    #[test_case(b"-ERR\r\n"; "Malformed error")]
    #[test_case(b"ERR description\r\n"; "Error missing minus")]
    #[test_case(b"-ERR description"; "Error missing newline")]
    #[test_case(b":"; "Delete response too short")]
    #[test_case(b"1234\r\n"; "Delete response doesn't start with colon")]
    #[test_case(b":1234"; "Delete response doesn't end with newline")]
    #[test_case(b":not an integer\r\n"; "Delete response value not integer")]
    #[test_case(b"+hello world\r\n"; "Incorrect OK value")]
    #[test_case(b"+"; "OK response too short")]
    #[test_case(b"OK\r\n"; "OK response doesn't start with plus sign")]
    #[test_case(b"+OK"; "OK response doesn't end with newline")]
    fn test_response_deserialization_failures(payload: &[u8]) {
        assert!(deserialize_octet_stream(payload).is_err());
    }

    #[test]
    fn test_response_deserialization_content_type_failure() {
        let result = Response::deserialize(
            b"+OK\r\n",
            Some(&"application/json".to_string()),
            &FormatIndicator::UnspecifiedBytes,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_number() {
        let parsed = parse_numeric(b":1234\r\n", Response::DELETE_RESPONSE_PREFIX).unwrap();
        assert_eq!(parsed, 1234);
    }

    #[test_case(SetOptions::default(),
        b"*3\r\n$3\r\nSET\r\n$7\r\ntestkey\r\n$9\r\ntestvalue\r\n";
        "default")]
    #[test_case(SetOptions {set_condition: SetCondition::OnlyIfDoesNotExist, ..Default::default()},
        b"*4\r\n$3\r\nSET\r\n$7\r\ntestkey\r\n$9\r\ntestvalue\r\n$2\r\nNX\r\n";
        "OnlyIfDoesNotExist")]
    #[test_case(SetOptions {set_condition: SetCondition::OnlyIfEqualOrDoesNotExist, ..Default::default()},
        b"*4\r\n$3\r\nSET\r\n$7\r\ntestkey\r\n$9\r\ntestvalue\r\n$3\r\nNEX\r\n";
        "OnlyIfEqualOrDoesNotExist")]
    #[test_case(SetOptions {expires: Some(Duration::from_millis(10)), ..Default::default()},
        b"*5\r\n$3\r\nSET\r\n$7\r\ntestkey\r\n$9\r\ntestvalue\r\n$2\r\nPX\r\n$2\r\n10\r\n";
        "expires set")]
    fn test_serialize_set_options(set_options: SetOptions, expected: &[u8]) {
        let request = Request::Set {
            key: b"testkey".to_vec(),
            value: b"testvalue".to_vec(),
            options: set_options,
        };
        assert_eq!(Request::serialize(request).unwrap(), expected_payload(expected));
    }

    #[test]
    fn test_serialize_empty_set() {
        let request = Request::Set {
            key: b"".to_vec(),
            value: b"".to_vec(),
            options: SetOptions::default(),
        };
        assert_eq!(
            Request::serialize(request).unwrap(),
            expected_payload(b"*3\r\n$3\r\nSET\r\n$0\r\n\r\n$0\r\n\r\n")
        );
    }

    #[test]
    fn test_serialize_get() {
        let request = Request::Get {
            key: b"testkey".to_vec(),
        };
        assert_eq!(
            Request::serialize(request).unwrap(),
            expected_payload(b"*2\r\n$3\r\nGET\r\n$7\r\ntestkey\r\n")
        );
    }

    #[test]
    fn test_serialize_del() {
        let request = Request::Del {
            key: b"testkey".to_vec(),
        };
        assert_eq!(
            Request::serialize(request).unwrap(),
            expected_payload(b"*2\r\n$3\r\nDEL\r\n$7\r\ntestkey\r\n")
        );
    }

    #[test]
    fn test_serialize_vdel() {
        let request = Request::VDel {
            key: b"testkey".to_vec(),
            value: b"testvalue".to_vec(),
        };
        assert_eq!(
            Request::serialize(request).unwrap(),
            expected_payload(b"*3\r\n$4\r\nVDEL\r\n$7\r\ntestkey\r\n$9\r\ntestvalue\r\n")
        );
    }
}