1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use rumqttc::v5::mqttbytes::v5::Packet;
6use rumqttc::v5::mqttbytes::QoS;
7use rumqttc::v5::{AsyncClient, Event, EventLoop, MqttOptions};
8use rumqttc::Outgoing;
9use rumqttc::Transport;
10use serde::{Deserialize, Serialize};
11use tokio::sync::{watch, Mutex};
12use tokio::time::sleep;
13use tracing::{error, info};
14use utoipa::ToSchema;
15
16use crate::cutil::generator::rand_string;
17use crate::cutil::meta::R;
18
19#[async_trait]
20pub trait MessageBroker: Send + Sync {
21 async fn subscribe(&self, names: Vec<String>, qos: Qos) -> R<()>;
22
23 async fn unsubscribe(&self, names: Vec<String>) -> R<()>;
24
25 async fn listen(&self, handler: Arc<dyn Fn(Message) -> R<()> + Send + Sync>) -> R<()>;
26
27 async fn publish(&self, message: Message) -> R<()>;
28
29 async fn shutdown(&self) -> R<()>;
30}
31
32#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
33pub struct Message {
34 pub name: String,
35 pub body: String,
36 pub qos: Qos,
37 pub retain: bool,
38}
39
40impl Default for Message {
41 fn default() -> Self {
42 Self {
43 name: "".to_string(),
44 body: "".to_string(),
45 qos: Qos::AtMostOnce,
46 retain: false,
47 }
48 }
49}
50
51#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Serialize, Deserialize, ToSchema)]
52pub enum Qos {
53 AtMostOnce = 0,
54 AtLeastOnce = 1,
55 ExactlyOnce = 2,
56}
57
/// Connection settings used to build the MQTT client.
#[derive(Debug, Clone)]
pub struct MessageBrokerOptions {
    /// Broker host name or address.
    pub host: String,
    /// Broker TCP port.
    pub port: u16,
    /// User name sent as part of the connection credentials.
    pub username: String,
    /// Password sent as part of the connection credentials.
    pub password: String,
    /// Client identifier presented to the broker.
    pub client_id: String,
    /// Keep-alive interval, in seconds.
    pub keep_alive: u64,
    /// Upper bound, in seconds, for the delay between reconnect attempts.
    pub max_reconnect_delay: u64,
}
68
69impl Default for MessageBrokerOptions {
70 fn default() -> Self {
71 Self {
72 host: String::default(),
73 port: 1883,
74 username: String::default(),
75 password: String::default(),
76 client_id: rand_string(16),
77 keep_alive: 60,
78 max_reconnect_delay: 300,
79 }
80 }
81}
82
83pub struct MessageBrokerImpl {
84 client: Arc<Mutex<AsyncClient>>,
85 eventloop: Arc<Mutex<EventLoop>>,
86 shutdown: watch::Receiver<bool>,
87 shutdown_tx: watch::Sender<bool>,
88 options: MessageBrokerOptions,
89}
90
91impl MessageBrokerImpl {
92 pub fn new(options: MessageBrokerOptions) -> R<MessageBrokerImpl> {
93 let (client, eventloop) = Self::create_mqtt_client(&options)?;
94 let (shutdown_tx, shutdown) = watch::channel(false);
95
96 Ok(MessageBrokerImpl {
97 client: Arc::new(Mutex::new(client)),
98 eventloop: Arc::new(Mutex::new(eventloop)),
99 shutdown,
100 shutdown_tx,
101 options,
102 })
103 }
104
105 fn create_mqtt_client(options: &MessageBrokerOptions) -> R<(AsyncClient, EventLoop)> {
106 let mut mqttoptions = MqttOptions::new(&options.client_id, &options.host, options.port);
107 mqttoptions.set_credentials(&options.username, &options.password);
108 mqttoptions.set_keep_alive(Duration::from_secs(options.keep_alive));
109 mqttoptions.set_clean_start(true);
110 mqttoptions.set_transport(Transport::Tcp);
111
112 Ok(AsyncClient::new(mqttoptions, 10))
113 }
114
115 async fn handle_reconnect(&self, reconnect_delay: u64) -> R<()> {
116 error!("Attempting to reconnect in {} seconds...", reconnect_delay);
117 sleep(Duration::from_secs(reconnect_delay)).await;
118
119 let (new_client, new_eventloop) = Self::create_mqtt_client(&self.options).unwrap();
120 *self.client.lock().await = new_client;
121 *self.eventloop.lock().await = new_eventloop;
122 Ok(())
123 }
124}
125
126#[async_trait]
127impl MessageBroker for MessageBrokerImpl {
128 async fn subscribe(&self, names: Vec<String>, qos: Qos) -> R<()> {
129 for name in &names {
130 self
131 .client
132 .lock()
133 .await
134 .subscribe(
135 name,
136 match qos {
137 Qos::AtMostOnce => QoS::AtMostOnce,
138 Qos::AtLeastOnce => QoS::AtLeastOnce,
139 Qos::ExactlyOnce => QoS::ExactlyOnce,
140 },
141 )
142 .await?;
143 }
144 Ok(())
145 }
146
147 async fn unsubscribe(&self, names: Vec<String>) -> R<()> {
148 for name in &names {
149 self.client.lock().await.unsubscribe(name).await?;
150 }
151 Ok(())
152 }
153
154 async fn listen(&self, handler: Arc<dyn Fn(Message) -> R<()> + Send + Sync>) -> R<()> {
155 let mut reconnect_delay = 1;
156 let mut shutdown_rx = self.shutdown.clone();
157
158 loop {
159 let mut eventloop = self.eventloop.lock().await;
160
161 tokio::select! {
162 Ok(event) = eventloop.poll() => match event {
163 Event::Incoming(Packet::Publish(msg)) => {
164 let message = Message {
165 name: String::from_utf8(msg.topic.to_vec()).unwrap_or("".to_string()),
166 qos: match msg.qos {
167 QoS::AtMostOnce => Qos::AtMostOnce,
168 QoS::AtLeastOnce => Qos::AtLeastOnce,
169 QoS::ExactlyOnce => Qos::ExactlyOnce,
170 },
171 retain: msg.retain,
172 body: String::from_utf8(msg.payload.to_vec()).unwrap_or("".to_string()),
173 };
174 if let Err(e) = handler(message) {
175 error!("Handler error: {}", e);
176 }
177 reconnect_delay = 1;
178 }
179 Event::Outgoing(Outgoing::Disconnect) => {
180 drop(eventloop);
181 self.handle_reconnect(reconnect_delay).await?;
182 reconnect_delay = (reconnect_delay * 2).min(self.options.max_reconnect_delay);
183 }
184 _ => {}
185 },
186 _ = shutdown_rx.changed() => break,
187 }
188 }
189
190 info!("MQTT listener shutdown complete");
191 Ok(())
192 }
193
194 async fn shutdown(&self) -> R<()> {
195 self.shutdown_tx.send(true)?;
196 tokio::time::sleep(Duration::from_millis(100)).await;
197 self.client.lock().await.disconnect().await?;
198 Ok(())
199 }
200
201 async fn publish(&self, message: Message) -> R<()> {
202 self
203 .client
204 .lock()
205 .await
206 .publish(
207 message.name,
208 match message.qos {
209 Qos::AtMostOnce => QoS::AtMostOnce,
210 Qos::AtLeastOnce => QoS::AtLeastOnce,
211 Qos::ExactlyOnce => QoS::ExactlyOnce,
212 },
213 message.retain,
214 message.body,
215 )
216 .await?;
217 Ok(())
218 }
219}