use std::{fmt, fs, time::Duration};
use async_trait::async_trait;
use bytes::Bytes;
use openssl::{pkey::PKey, x509::X509};
use rumqttc::{self, TlsConfiguration, Transport, tokio_native_tls::native_tls};
use thiserror::Error;
use crate::connection_settings::MqttConnectionSettings;
use crate::control_packet::{
AuthProperties, Publish, PublishProperties, QoS, SubscribeProperties, UnsubscribeProperties,
};
use crate::error::{
AckError, AckErrorKind, ConnectionError, DisconnectError, DisconnectErrorKind, PublishError,
PublishErrorKind, ReauthError, ReauthErrorKind, SubscribeError, SubscribeErrorKind,
UnsubscribeError, UnsubscribeErrorKind,
};
use crate::interface::{
CompletionToken, Event, MqttAck, MqttClient, MqttDisconnect, MqttEventLoop, MqttPubSub,
};
use crate::topic::{TopicFilter, TopicName};
/// Alias for the concrete MQTT client type used by this adapter: the `rumqttc` v5 async client.
pub type ClientAlias = rumqttc::v5::AsyncClient;
/// Alias for the concrete event loop type used by this adapter: the `rumqttc` v5 event loop.
pub type EventLoopAlias = rumqttc::v5::EventLoop;
impl From<rumqttc::v5::ClientError> for PublishError {
    /// Translates a `rumqttc` client error into a [`PublishError`].
    ///
    /// Both request-sending failures are reported as a detached client.
    fn from(err: rumqttc::v5::ClientError) -> Self {
        use rumqttc::v5::ClientError;

        // No wildcard arm: a newly added `ClientError` variant fails to compile here
        // instead of being silently mapped.
        let kind = match err {
            ClientError::Request(_) | ClientError::TryRequest(_) => {
                PublishErrorKind::DetachedClient
            }
        };
        PublishError::new(kind)
    }
}
impl From<rumqttc::v5::ClientError> for SubscribeError {
    /// Translates a `rumqttc` client error into a [`SubscribeError`].
    ///
    /// Both request-sending failures are reported as a detached client.
    fn from(err: rumqttc::v5::ClientError) -> Self {
        use rumqttc::v5::ClientError;

        // No wildcard arm: a newly added `ClientError` variant fails to compile here
        // instead of being silently mapped.
        let kind = match err {
            ClientError::Request(_) | ClientError::TryRequest(_) => {
                SubscribeErrorKind::DetachedClient
            }
        };
        SubscribeError::new(kind)
    }
}
impl From<rumqttc::v5::ClientError> for UnsubscribeError {
    /// Translates a `rumqttc` client error into an [`UnsubscribeError`].
    ///
    /// Both request-sending failures are reported as a detached client.
    fn from(err: rumqttc::v5::ClientError) -> Self {
        use rumqttc::v5::ClientError;

        // No wildcard arm: a newly added `ClientError` variant fails to compile here
        // instead of being silently mapped.
        let kind = match err {
            ClientError::Request(_) | ClientError::TryRequest(_) => {
                UnsubscribeErrorKind::DetachedClient
            }
        };
        UnsubscribeError::new(kind)
    }
}
impl From<rumqttc::v5::ClientError> for AckError {
    /// Translates a `rumqttc` client error into an [`AckError`].
    ///
    /// Both request-sending failures are reported as a detached client.
    fn from(err: rumqttc::v5::ClientError) -> Self {
        use rumqttc::v5::ClientError;

        // No wildcard arm: a newly added `ClientError` variant fails to compile here
        // instead of being silently mapped.
        let kind = match err {
            ClientError::Request(_) | ClientError::TryRequest(_) => AckErrorKind::DetachedClient,
        };
        AckError::new(kind)
    }
}
impl From<rumqttc::v5::ClientError> for DisconnectError {
    /// Translates a `rumqttc` client error into a [`DisconnectError`].
    ///
    /// Both request-sending failures are reported as a detached client.
    fn from(err: rumqttc::v5::ClientError) -> Self {
        use rumqttc::v5::ClientError;

        // No wildcard arm: a newly added `ClientError` variant fails to compile here
        // instead of being silently mapped.
        let kind = match err {
            ClientError::Request(_) | ClientError::TryRequest(_) => {
                DisconnectErrorKind::DetachedClient
            }
        };
        DisconnectError::new(kind)
    }
}
impl From<rumqttc::v5::ClientError> for ReauthError {
    /// Translates a `rumqttc` client error into a [`ReauthError`].
    ///
    /// Both request-sending failures are reported as a detached client.
    fn from(err: rumqttc::v5::ClientError) -> Self {
        use rumqttc::v5::ClientError;

        // No wildcard arm: a newly added `ClientError` variant fails to compile here
        // instead of being silently mapped.
        let kind = match err {
            ClientError::Request(_) | ClientError::TryRequest(_) => {
                ReauthErrorKind::DetachedClient
            }
        };
        ReauthError::new(kind)
    }
}
#[async_trait]
impl MqttPubSub for rumqttc::v5::AsyncClient {
    /// Publishes `payload` to `topic`.
    ///
    /// The topic is checked with [`TopicName::is_valid_topic_name`] before anything is
    /// handed to the client; an invalid name yields `PublishErrorKind::InvalidTopicName`.
    /// On success the notice returned by the client is wrapped in a [`CompletionToken`].
    async fn publish(
        &self,
        topic: impl Into<String> + Send,
        qos: QoS,
        retain: bool,
        payload: impl Into<Bytes> + Send,
    ) -> Result<CompletionToken, PublishError> {
        let topic_name: String = topic.into();
        if TopicName::is_valid_topic_name(&topic_name) {
            let notice = self.publish(topic_name, qos, retain, payload).await?;
            Ok(CompletionToken(Box::new(notice.wait_async())))
        } else {
            Err(PublishError::new(PublishErrorKind::InvalidTopicName))
        }
    }

    /// Same as `publish`, additionally forwarding `properties` to the client.
    async fn publish_with_properties(
        &self,
        topic: impl Into<String> + Send,
        qos: QoS,
        retain: bool,
        payload: impl Into<Bytes> + Send,
        properties: PublishProperties,
    ) -> Result<CompletionToken, PublishError> {
        let topic_name: String = topic.into();
        if TopicName::is_valid_topic_name(&topic_name) {
            let notice = self
                .publish_with_properties(topic_name, qos, retain, payload, properties)
                .await?;
            Ok(CompletionToken(Box::new(notice.wait_async())))
        } else {
            Err(PublishError::new(PublishErrorKind::InvalidTopicName))
        }
    }

    /// Subscribes to `topic` at the requested `qos`.
    ///
    /// The filter is checked with [`TopicFilter::is_valid_topic_filter`] first; an invalid
    /// filter yields `SubscribeErrorKind::InvalidTopicFilter`.
    async fn subscribe(
        &self,
        topic: impl Into<String> + Send,
        qos: QoS,
    ) -> Result<CompletionToken, SubscribeError> {
        let topic_filter: String = topic.into();
        if TopicFilter::is_valid_topic_filter(&topic_filter) {
            let notice = self.subscribe(topic_filter, qos).await?;
            Ok(CompletionToken(Box::new(notice.wait_async())))
        } else {
            Err(SubscribeError::new(SubscribeErrorKind::InvalidTopicFilter))
        }
    }

    /// Same as `subscribe`, additionally forwarding `properties` to the client.
    async fn subscribe_with_properties(
        &self,
        topic: impl Into<String> + Send,
        qos: QoS,
        properties: SubscribeProperties,
    ) -> Result<CompletionToken, SubscribeError> {
        let topic_filter: String = topic.into();
        if TopicFilter::is_valid_topic_filter(&topic_filter) {
            let notice = self
                .subscribe_with_properties(topic_filter, qos, properties)
                .await?;
            Ok(CompletionToken(Box::new(notice.wait_async())))
        } else {
            Err(SubscribeError::new(SubscribeErrorKind::InvalidTopicFilter))
        }
    }

    /// Unsubscribes from `topic`.
    ///
    /// The filter is checked with [`TopicFilter::is_valid_topic_filter`] first; an invalid
    /// filter yields `UnsubscribeErrorKind::InvalidTopicFilter`.
    async fn unsubscribe(
        &self,
        topic: impl Into<String> + Send,
    ) -> Result<CompletionToken, UnsubscribeError> {
        let topic_filter: String = topic.into();
        if TopicFilter::is_valid_topic_filter(&topic_filter) {
            let notice = self.unsubscribe(topic_filter).await?;
            Ok(CompletionToken(Box::new(notice.wait_async())))
        } else {
            Err(UnsubscribeError::new(
                UnsubscribeErrorKind::InvalidTopicFilter,
            ))
        }
    }

    /// Same as `unsubscribe`, additionally forwarding `properties` to the client.
    async fn unsubscribe_with_properties(
        &self,
        topic: impl Into<String> + Send,
        properties: UnsubscribeProperties,
    ) -> Result<CompletionToken, UnsubscribeError> {
        let topic_filter: String = topic.into();
        if TopicFilter::is_valid_topic_filter(&topic_filter) {
            let notice = self
                .unsubscribe_with_properties(topic_filter, properties)
                .await?;
            Ok(CompletionToken(Box::new(notice.wait_async())))
        } else {
            Err(UnsubscribeError::new(
                UnsubscribeErrorKind::InvalidTopicFilter,
            ))
        }
    }
}
#[async_trait]
impl MqttAck for rumqttc::v5::AsyncClient {
    /// Acknowledges `publish` with a `Success` reason.
    ///
    /// The returned [`CompletionToken`] wraps a future that resolves to `Ok(())`
    /// immediately; it is not tied to any later notification from the client.
    async fn ack(&self, publish: &Publish) -> Result<CompletionToken, AckError> {
        let mut ack_request = self.get_manual_ack(publish);
        ack_request.set_reason(rumqttc::v5::ManualAckReason::Success);

        // A failure to hand the ack to the client is converted into an `AckError` by `?`.
        self.manual_ack(ack_request).await?;

        Ok(CompletionToken(Box::new(async { Ok(()) })))
    }
}
#[async_trait]
impl MqttClient for rumqttc::v5::AsyncClient {
async fn reauth(&self, auth_props: AuthProperties) -> Result<(), ReauthError> {
Ok(self.reauth(Some(auth_props)).await?)
}
}
#[async_trait]
impl MqttDisconnect for rumqttc::v5::AsyncClient {
async fn disconnect(&self) -> Result<(), DisconnectError> {
Ok(self.disconnect().await?)
}
}
#[async_trait]
impl MqttEventLoop for rumqttc::v5::EventLoop {
    /// Forwards to `poll` on the event loop and returns its result unchanged.
    async fn poll(&mut self) -> Result<Event, ConnectionError> {
        self.poll().await
    }

    /// Updates the clean-start flag stored in the event loop's options.
    fn set_clean_start(&mut self, clean_start: bool) {
        let options = &mut self.options;
        options.set_clean_start(clean_start);
    }

    /// Updates the authentication method stored in the event loop's options.
    fn set_authentication_method(&mut self, authentication_method: Option<String>) {
        let options = &mut self.options;
        options.set_authentication_method(authentication_method);
    }

    /// Updates the authentication data stored in the event loop's options.
    fn set_authentication_data(&mut self, authentication_data: Option<Bytes>) {
        let options = &mut self.options;
        options.set_authentication_data(authentication_data);
    }
}
/// Builds a `rumqttc` v5 client and its event loop from the given settings.
///
/// # Arguments
/// * `connection_settings` - settings converted into `rumqttc` options.
/// * `channel_capacity` - capacity passed to `AsyncClient::new`; must not be `usize::MAX`.
/// * `manual_ack` - whether manual acknowledgements are enabled on the options.
/// * `connection_user_properties` - appended after any user properties the converted
///   options already carry.
///
/// # Errors
/// * [`MqttAdapterError::Other`] if `channel_capacity` is `usize::MAX`.
/// * [`MqttAdapterError::ConnectionSettings`] if the settings cannot be converted.
pub fn client(
    connection_settings: MqttConnectionSettings,
    channel_capacity: usize,
    manual_ack: bool,
    connection_user_properties: Vec<(String, String)>,
) -> Result<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop), MqttAdapterError> {
    // Reject the unsupported capacity before doing any conversion work.
    if channel_capacity == usize::MAX {
        let reason = "rumqttc does not support channel capacity of usize::MAX".to_string();
        return Err(MqttAdapterError::Other(reason));
    }

    let mut options = rumqttc::v5::MqttOptions::try_from(connection_settings)?;
    options.set_manual_acks(manual_ack);

    // Existing properties first, caller-supplied ones after.
    let merged_properties: Vec<(String, String)> = options
        .user_properties()
        .into_iter()
        .chain(connection_user_properties)
        .collect();
    options.set_user_properties(merged_properties);

    let (async_client, event_loop) = rumqttc::v5::AsyncClient::new(options, channel_capacity);
    Ok((async_client, event_loop))
}
/// Error returned by [`client`] when the adapter cannot be set up.
#[derive(Error, Debug)]
pub enum MqttAdapterError {
/// The connection settings could not be converted; displays as the wrapped error.
#[error(transparent)]
ConnectionSettings(#[from] ConnectionSettingsAdapterError),
/// Any other adapter failure, described by the contained message.
#[error("Other adapter error: {0}")]
Other(String),
}
/// Error describing why a connection setting could not be adapted to `rumqttc` options.
///
/// Displays as `"{msg}: {field}"`. The optional underlying cause is exposed through
/// `std::error::Error::source`.
#[derive(Error, Debug)]
#[error("{msg}: {field}")]
pub struct ConnectionSettingsAdapterError {
    /// Short description of what went wrong.
    msg: String,
    /// The setting (and its value) that caused the failure.
    field: ConnectionSettingsField,
    /// Underlying cause, if any. Bounded by `Send + Sync` so that this error (and
    /// anything wrapping it) can cross thread and task boundaries.
    #[source]
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
/// Identifies the connection setting an adapter error refers to, carrying the value involved.
#[derive(Debug)]
pub enum ConnectionSettingsField {
/// The configured session expiry duration.
SessionExpiry(Duration),
/// The configured password file path.
PasswordFile(String),
/// The configured "use TLS" flag.
UseTls(bool),
/// The configured SAT auth file path.
SatAuthFile(String),
}
impl fmt::Display for ConnectionSettingsField {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnectionSettingsField::SessionExpiry(v) => write!(f, "Session Expiry: {v:?}"),
ConnectionSettingsField::PasswordFile(v) => write!(f, "Password File: {v:?}"),
ConnectionSettingsField::UseTls(v) => write!(f, "Use TLS: {v:?}"),
ConnectionSettingsField::SatAuthFile(v) => write!(f, "SAT Auth File: {v:?}"),
}
}
}
/// Error raised while building the TLS configuration; displays as its message.
#[derive(Error, Debug)]
#[error("{msg}")]
pub struct TlsError {
/// Human-readable description of the failure.
msg: String,
/// Underlying cause, if one was recorded.
source: Option<anyhow::Error>,
}
impl TlsError {
pub fn new(msg: &str) -> Self {
TlsError {
msg: msg.to_string(),
source: None,
}
}
}
impl TryFrom<MqttConnectionSettings> for rumqttc::v5::MqttOptions {
    type Error = ConnectionSettingsAdapterError;

    /// Converts [`MqttConnectionSettings`] into `rumqttc` v5 options.
    ///
    /// # Errors
    /// Returns a [`ConnectionSettingsAdapterError`] when:
    /// * the session expiry is shorter than 5 seconds or does not fit in a `u32` of seconds,
    /// * the password file cannot be read,
    /// * the TLS configuration cannot be built,
    /// * the SAT auth file cannot be read.
    fn try_from(value: MqttConnectionSettings) -> Result<Self, Self::Error> {
        // `value` is consumed here, so its fields are moved rather than cloned.
        let mut mqtt_options =
            rumqttc::v5::MqttOptions::new(value.client_id, value.hostname, value.tcp_port);
        mqtt_options.set_keep_alive(value.keep_alive);
        mqtt_options.set_receive_maximum(Some(value.receive_max));
        // With no configured limit, the maximum packet size is set to `u32::MAX`.
        mqtt_options.set_max_packet_size(value.receive_packet_size_max.or(Some(u32::MAX)));

        // Session expiry: must fit in a u32 number of seconds and be at least 5 seconds.
        let session_expiry = value.session_expiry;
        let session_expiry_secs = match u32::try_from(session_expiry.as_secs()) {
            Ok(secs) => secs,
            Err(e) => {
                return Err(ConnectionSettingsAdapterError {
                    msg: "cannot convert to u32".to_string(),
                    field: ConnectionSettingsField::SessionExpiry(session_expiry),
                    source: Some(Box::new(e)),
                });
            }
        };
        if session_expiry_secs < 5 {
            return Err(ConnectionSettingsAdapterError {
                // The check accepts exactly 5 seconds, so the message says ">=".
                msg: "require >= 5 seconds".to_string(),
                field: ConnectionSettingsField::SessionExpiry(session_expiry),
                source: None,
            });
        }
        mqtt_options.set_session_expiry_interval(Some(session_expiry_secs));

        mqtt_options.set_connection_timeout(value.connection_timeout.as_secs());
        mqtt_options.set_clean_start(value.clean_start);

        // Credentials are only set when a username is present. A password file takes
        // precedence over an inline password; with neither, the password is empty.
        if let Some(username) = value.username {
            let password = match value.password_file {
                Some(password_file) => match fs::read_to_string(&password_file) {
                    Ok(password) => password,
                    Err(e) => {
                        return Err(ConnectionSettingsAdapterError {
                            msg: "cannot read password file".to_string(),
                            field: ConnectionSettingsField::PasswordFile(password_file),
                            source: Some(Box::new(e)),
                        });
                    }
                },
                None => value.password.unwrap_or_default(),
            };
            mqtt_options.set_credentials(username, password);
        }

        // The TLS-related files are only consulted when TLS is enabled.
        if value.use_tls {
            let transport = tls_config(
                value.ca_file,
                value.cert_file,
                value.key_file,
                value.key_password_file,
            )
            .map_err(|e| ConnectionSettingsAdapterError {
                msg: "tls config error".to_string(),
                field: ConnectionSettingsField::UseTls(true),
                source: Some(Box::new(TlsError {
                    msg: e.to_string(),
                    source: Some(e),
                })),
            })?;
            mqtt_options.set_transport(transport);
        }

        // SAT auth: the file's raw bytes become the authentication data.
        if let Some(sat_file) = value.sat_file {
            mqtt_options.set_authentication_method(Some("K8S-SAT".to_string()));
            let sat_auth = match fs::read(&sat_file) {
                Ok(bytes) => bytes,
                Err(e) => {
                    return Err(ConnectionSettingsAdapterError {
                        msg: "cannot read sat auth file".to_string(),
                        field: ConnectionSettingsField::SatAuthFile(sat_file),
                        source: Some(Box::new(e)),
                    });
                }
            };
            mqtt_options.set_authentication_data(Some(sat_auth.into()));
        }

        Ok(mqtt_options)
    }
}
fn read_root_ca_certs(ca_file: String) -> Result<Vec<native_tls::Certificate>, anyhow::Error> {
let mut ca_certs = Vec::new();
let ca_pem = fs::read(ca_file)?;
let certs = &mut X509::stack_from_pem(&ca_pem)?;
ca_certs.append(certs);
if ca_certs.is_empty() {
Err(TlsError::new("No CA certs available in CA File"))?;
}
ca_certs.sort();
ca_certs.dedup();
Ok(ca_certs
.iter()
.map(|cert| {
native_tls::Certificate::from_pem(&cert.to_pem().expect("cert should serialize to PEM"))
.expect("Failed to deserialize cert")
})
.collect())
}
/// Builds a `native_tls`-based [`Transport`] from the optional TLS files.
///
/// * `ca_file` - when present, its certificates are added as trusted roots.
/// * `cert_file` / `key_file` - a client identity is configured only when both are
///   present; if either is missing, no identity is set.
/// * `key_password_file` - when present (and an identity is being built), its raw bytes
///   are used as the passphrase for the private key.
///
/// The minimum protocol version is TLS 1.2.
///
/// # Errors
/// Returns an error if any file cannot be read, any certificate or key cannot be parsed
/// or re-encoded, or the identity or connector cannot be built.
fn tls_config(
    ca_file: Option<String>,
    cert_file: Option<String>,
    key_file: Option<String>,
    key_password_file: Option<String>,
) -> Result<Transport, anyhow::Error> {
    let mut builder = native_tls::TlsConnector::builder();
    builder.min_protocol_version(Some(native_tls::Protocol::Tlsv12));

    if let Some(ca_path) = ca_file {
        for root_cert in read_root_ca_certs(ca_path)? {
            builder.add_root_certificate(root_cert);
        }
    }

    if let (Some(cert_path), Some(key_path)) = (cert_file, key_file) {
        // Re-encode the whole client chain into one contiguous PEM buffer.
        let chain_bytes = fs::read(cert_path)?;
        let mut chain_pem: Vec<u8> = Vec::new();
        for cert in X509::stack_from_pem(&chain_bytes)? {
            chain_pem.extend(cert.to_pem()?);
        }

        // Load the key (decrypting it if a passphrase file was given), then export it
        // as PKCS#8 PEM for `Identity::from_pkcs8`.
        let key_bytes = fs::read(key_path)?;
        let private_key = match key_password_file {
            Some(password_path) => {
                let passphrase = fs::read(password_path)?;
                PKey::private_key_from_pem_passphrase(&key_bytes, &passphrase)?
            }
            None => PKey::private_key_from_pem(&key_bytes)?,
        };
        let key_pem = private_key.private_key_to_pem_pkcs8()?;

        let identity = native_tls::Identity::from_pkcs8(&chain_pem, &key_pem).map_err(|err| {
            TlsError::new(&format!("Failed to build TLS client identity: {err}"))
        })?;
        builder.identity(identity);
    }

    let connector = builder
        .build()
        .map_err(|err| TlsError::new(&format!("Failed to build TLS connector: {err}")))?;

    Ok(Transport::Tls(TlsConfiguration::NativeConnector(connector)))
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::MqttConnectionSettingsBuilder;

    /// Returns the path, as a `String`, of `file_name` inside the shared dummy
    /// credentials directory (located relative to this crate's manifest directory).
    fn dummy_credential(file_name: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("../../eng/test/dummy_credentials");
        path.push(file_name);
        path.into_os_string().into_string().unwrap()
    }

    /// Runs the conversion under test.
    fn convert(
        settings: crate::connection_settings::MqttConnectionSettings,
    ) -> Result<rumqttc::v5::MqttOptions, ConnectionSettingsAdapterError> {
        settings.try_into()
    }

    #[test]
    fn test_mqtt_connection_settings_no_tls() {
        let settings = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .use_tls(false)
            .build()
            .unwrap();
        assert!(convert(settings).is_ok());
    }

    #[test]
    fn test_mqtt_connection_settings_username() {
        // Username with an inline password.
        let with_password = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .use_tls(false)
            .username("test_username".to_string())
            .password("test_password".to_string())
            .build()
            .unwrap();
        assert!(convert(with_password).is_ok());

        // Username only.
        let username_only = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .use_tls(false)
            .username("test_username".to_string())
            .build()
            .unwrap();
        assert!(convert(username_only).is_ok());

        // Username with a password file.
        let with_password_file = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .use_tls(false)
            .username("test_username".to_string())
            .password_file(dummy_credential("TestMqttPasswordFile.txt"))
            .build()
            .unwrap();
        assert!(convert(with_password_file).is_ok());
    }

    #[test]
    fn test_mqtt_connection_settings_ca_file() {
        let settings = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .ca_file(dummy_credential("TestCa.txt"))
            .build()
            .unwrap();
        assert!(convert(settings).is_ok());
    }

    #[test]
    fn test_mqtt_connection_settings_ca_file_plus_cert() {
        let settings = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .ca_file(dummy_credential("TestCa.txt"))
            .cert_file(dummy_credential("TestCert1Pem.txt"))
            .key_file(dummy_credential("TestCert1Key.txt"))
            .build()
            .unwrap();
        assert!(convert(settings).is_ok());
    }

    #[test]
    fn test_mqtt_connection_settings_cert() {
        let settings = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .cert_file(dummy_credential("TestCert1Pem.txt"))
            .key_file(dummy_credential("TestCert1Key.txt"))
            .build()
            .unwrap();
        assert!(convert(settings).is_ok());
    }

    #[test]
    fn test_mqtt_connection_settings_cert_key_file_password() {
        let settings = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .cert_file(dummy_credential("TestCert2Pem.txt"))
            .key_file(dummy_credential("TestCert2KeyEncrypted.txt"))
            .key_password_file(dummy_credential("TestCert2KeyPasswordFile.txt"))
            .build()
            .unwrap();
        assert!(convert(settings).is_ok());
    }

    #[test]
    fn test_receive_packet_size_max_override_none() {
        let settings = MqttConnectionSettingsBuilder::default()
            .client_id("test_client_id".to_string())
            .hostname("test_host".to_string())
            .receive_packet_size_max(None)
            .build()
            .unwrap();
        let converted = convert(settings);
        assert!(converted.is_ok());
        // With no explicit limit, the adapter falls back to `u32::MAX`.
        assert_eq!(converted.unwrap().max_packet_size(), Some(u32::MAX));
    }
}