general_mq/mqtt/connection.rs

1use std::{
2    collections::HashMap,
3    error::Error as StdError,
4    str::FromStr,
5    sync::{Arc, Mutex},
6    time::Duration,
7};
8
9use async_trait::async_trait;
10use regex::Regex;
11use rumqttc::{
12    AsyncClient as RumqttConnection, ClientError, Event as RumqttEvent,
13    MqttOptions as RumqttOption, NetworkOptions, Packet, Publish, TlsConfiguration, Transport,
14};
15use tokio::{
16    task::{self, JoinHandle},
17    time,
18};
19
20use super::uri::{MQTTScheme, MQTTUri};
21use crate::{
22    connection::{EventHandler, GmqConnection, Status},
23    randomstring, ID_SIZE,
24};
25
26/// Manages a MQTT connection.
27#[derive(Clone)]
28pub struct MqttConnection {
29    /// Options of the connection.
30    opts: InnerOptions,
31    /// Connection status.
32    status: Arc<Mutex<Status>>,
33    /// Hold the connection instance.
34    conn: Arc<Mutex<Option<RumqttConnection>>>,
35    /// Event handlers.
36    handlers: Arc<Mutex<HashMap<String, Arc<dyn EventHandler>>>>,
37    /// Publish packet handlers. The key is **the queue name**.
38    ///
39    /// Because MQTT is connection-driven, the receiver [`crate::MqttQueue`] queues must register a
40    /// handler to receive Publish packets.
41    packet_handlers: Arc<Mutex<HashMap<String, Arc<dyn PacketHandler>>>>,
42    /// The event loop to manage and monitor the connection instance.
43    ev_loop: Arc<Mutex<Option<JoinHandle<()>>>>,
44}
45
/// User-facing options for [`MqttConnection::new`].
pub struct MqttConnectionOptions {
    /// Broker URI in the form `mqtt|mqtts://username:password@host:port`.
    ///
    /// Defaults to `mqtt://localhost`.
    pub uri: String,
    /// Connection timeout, in milliseconds.
    ///
    /// `0` (and the default) means `3000`.
    pub connect_timeout_millis: u64,
    /// Delay, in milliseconds, between a disconnection and the next reconnection attempt.
    ///
    /// `0` (and the default) means `1000`.
    pub reconnect_millis: u64,
    /// Client identifier; `None` asks for a randomly generated one.
    pub client_id: Option<String>,
    /// MQTT clean-session flag.
    ///
    /// **Note**: this is not stable.
    pub clean_session: bool,
}
67
68/// Packet handler definitions.
69pub(super) trait PacketHandler: Send + Sync {
70    /// For **Publish** packets.
71    fn on_publish(&self, packet: Publish);
72}
73
74/// The validated options for management.
75#[derive(Clone)]
76struct InnerOptions {
77    /// The formatted URI resource.
78    uri: MQTTUri,
79    /// Connection timeout in milliseconds.
80    connect_timeout_millis: u64,
81    /// Time in milliseconds from disconnection to reconnection.
82    reconnect_millis: u64,
83    /// Client identifier.
84    client_id: String,
85    /// Clean session flag.
86    clean_session: bool,
87}
88
/// Default connect timeout in milliseconds.
const DEF_CONN_TIMEOUT_MS: u64 = 3000;
/// Default reconnect time in milliseconds.
const DEF_RECONN_TIME_MS: u64 = 1000;
/// The accepted pattern of the client identifier: 1 to 23 ASCII letters, digits or `-`.
///
/// NOTE(review): the 1..=23 length appears to follow the MQTT 3.1.1 ClientId range that servers
/// are required to accept (characters `0-9a-zA-Z`). `-` is outside that guaranteed set, so a
/// strict broker could reject it. Confirm against the target broker.
const CLIENT_ID_PATTERN: &str = "^[0-9A-Za-z-]{1,23}$";
95
96impl MqttConnection {
97    /// Create a connection instance.
98    pub fn new(opts: MqttConnectionOptions) -> Result<MqttConnection, String> {
99        let uri = MQTTUri::from_str(opts.uri.as_str())?;
100
101        Ok(MqttConnection {
102            opts: InnerOptions {
103                uri,
104                connect_timeout_millis: match opts.connect_timeout_millis {
105                    0 => DEF_CONN_TIMEOUT_MS,
106                    _ => opts.connect_timeout_millis,
107                },
108                reconnect_millis: match opts.reconnect_millis {
109                    0 => DEF_RECONN_TIME_MS,
110                    _ => opts.reconnect_millis,
111                },
112                client_id: match opts.client_id {
113                    None => format!("general-mq-{}", randomstring(12)),
114                    Some(client_id) => {
115                        let re = Regex::new(CLIENT_ID_PATTERN).unwrap();
116                        if !re.is_match(client_id.as_str()) {
117                            return Err(format!("client_id is not match {}", CLIENT_ID_PATTERN));
118                        }
119                        client_id
120                    }
121                },
122                clean_session: opts.clean_session,
123            },
124            status: Arc::new(Mutex::new(Status::Closed)),
125            conn: Arc::new(Mutex::new(None)),
126            handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn EventHandler>>::new())),
127            packet_handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn PacketHandler>>::new())),
128            ev_loop: Arc::new(Mutex::new(None)),
129        })
130    }
131
132    /// To add a packet handler for [`crate::MqttQueue`]. The `name` is **the queue name**.
133    pub(super) fn add_packet_handler(&mut self, name: &str, handler: Arc<dyn PacketHandler>) {
134        self.packet_handlers
135            .lock()
136            .unwrap()
137            .insert(name.to_string(), handler);
138    }
139
140    /// To remove a packet handler. The `name` is **the queue name**.
141    pub(super) fn remove_packet_handler(&mut self, name: &str) {
142        self.packet_handlers.lock().unwrap().remove(name);
143    }
144
145    /// To get the raw MQTT connection instance for topic operations such as subscribe or publish.
146    pub(super) fn get_raw_connection(&self) -> Option<RumqttConnection> {
147        match self.conn.lock().unwrap().as_ref() {
148            None => None,
149            Some(conn) => Some(conn.clone()),
150        }
151    }
152}
153
154#[async_trait]
155impl GmqConnection for MqttConnection {
156    fn status(&self) -> Status {
157        *self.status.lock().unwrap()
158    }
159
160    fn add_handler(&mut self, handler: Arc<dyn EventHandler>) -> String {
161        let id = randomstring(ID_SIZE);
162        self.handlers.lock().unwrap().insert(id.clone(), handler);
163        id
164    }
165
166    fn remove_handler(&mut self, id: &str) {
167        self.handlers.lock().unwrap().remove(id);
168    }
169
170    fn connect(&mut self) -> Result<(), Box<dyn StdError>> {
171        {
172            let mut task_handle_mutex = self.ev_loop.lock().unwrap();
173            if (*task_handle_mutex).is_some() {
174                return Ok(());
175            }
176            *self.status.lock().unwrap() = Status::Connecting;
177            *task_handle_mutex = Some(create_event_loop(self));
178        }
179        Ok(())
180    }
181
182    async fn close(&mut self) -> Result<(), Box<dyn StdError + Send + Sync>> {
183        match { self.ev_loop.lock().unwrap().take() } {
184            None => return Ok(()),
185            Some(handle) => handle.abort(),
186        }
187        {
188            *self.status.lock().unwrap() = Status::Closing;
189        }
190
191        let conn = { self.conn.lock().unwrap().take() };
192        let mut result: Result<(), ClientError> = Ok(());
193        if let Some(conn) = conn {
194            result = conn.disconnect().await;
195        }
196
197        {
198            *self.status.lock().unwrap() = Status::Closed;
199        }
200        let handlers = { (*self.handlers.lock().unwrap()).clone() };
201        for (id, handler) in handlers {
202            let conn = Arc::new(self.clone());
203            task::spawn(async move {
204                handler.on_status(id.clone(), conn, Status::Closed).await;
205            });
206        }
207
208        result?;
209        Ok(())
210    }
211}
212
213impl Default for MqttConnectionOptions {
214    fn default() -> Self {
215        MqttConnectionOptions {
216            uri: "mqtt://localhost".to_string(),
217            connect_timeout_millis: DEF_CONN_TIMEOUT_MS,
218            reconnect_millis: DEF_RECONN_TIME_MS,
219            client_id: None,
220            clean_session: true,
221        }
222    }
223}
224
225/// To create an event loop runtime task.
226fn create_event_loop(conn: &MqttConnection) -> JoinHandle<()> {
227    let this = Arc::new(conn.clone());
228    task::spawn(async move {
229        loop {
230            match this.status() {
231                Status::Closing | Status::Closed => task::yield_now().await,
232                Status::Connecting | Status::Connected => {
233                    let mut opts = RumqttOption::new(
234                        this.opts.client_id.as_str(),
235                        this.opts.uri.host.as_str(),
236                        this.opts.uri.port,
237                    );
238                    opts.set_clean_session(this.opts.clean_session)
239                        .set_credentials(
240                            this.opts.uri.username.as_str(),
241                            this.opts.uri.password.as_str(),
242                        );
243                    if this.opts.uri.scheme == MQTTScheme::MQTTS {
244                        opts.set_transport(Transport::Tls(TlsConfiguration::default()));
245                    }
246
247                    let mut to_disconnected = false;
248                    let (client, mut event_loop) = RumqttConnection::new(opts, 10);
249                    let mut net_opts = NetworkOptions::new();
250                    net_opts.set_connection_timeout(this.opts.connect_timeout_millis);
251                    event_loop.set_network_options(net_opts);
252                    loop {
253                        match event_loop.poll().await {
254                            Err(_) => {
255                                if this.status() == Status::Connected {
256                                    to_disconnected = true;
257                                }
258                                break;
259                            }
260                            Ok(event) => {
261                                let packet = match event {
262                                    RumqttEvent::Incoming(packet) => packet,
263                                    _ => continue,
264                                };
265                                match packet {
266                                    Packet::Publish(packet) => {
267                                        if this.status() != Status::Connected {
268                                            continue;
269                                        }
270                                        let handler = {
271                                            let topic = packet.topic.as_str();
272                                            match this.packet_handlers.lock().unwrap().get(topic) {
273                                                None => continue,
274                                                Some(handler) => handler.clone(),
275                                            }
276                                        };
277                                        handler.on_publish(packet);
278                                    }
279                                    Packet::ConnAck(_) => {
280                                        let mut to_connected = false;
281                                        {
282                                            let mut status_mutex = this.status.lock().unwrap();
283                                            let status = *status_mutex;
284                                            if status == Status::Closing || status == Status::Closed
285                                            {
286                                                break;
287                                            } else if status != Status::Connected {
288                                                *this.conn.lock().unwrap() = Some(client.clone());
289                                                *status_mutex = Status::Connected;
290                                                to_connected = true;
291                                            }
292                                        }
293
294                                        if to_connected {
295                                            let handlers =
296                                                { (*this.handlers.lock().unwrap()).clone() };
297                                            for (id, handler) in handlers {
298                                                let conn = this.clone();
299                                                task::spawn(async move {
300                                                    handler
301                                                        .on_status(
302                                                            id.clone(),
303                                                            conn,
304                                                            Status::Connected,
305                                                        )
306                                                        .await;
307                                                });
308                                            }
309                                        }
310                                    }
311                                    _ => {}
312                                }
313                            }
314                        }
315                    }
316
317                    {
318                        let mut status_mutex = this.status.lock().unwrap();
319                        if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
320                            continue;
321                        }
322                        let _ = this.conn.lock().unwrap().take();
323                        *status_mutex = Status::Disconnected;
324                    }
325
326                    if to_disconnected {
327                        let handlers = { (*this.handlers.lock().unwrap()).clone() };
328                        for (id, handler) in handlers {
329                            let conn = this.clone();
330                            task::spawn(async move {
331                                handler
332                                    .on_status(id.clone(), conn, Status::Disconnected)
333                                    .await;
334                            });
335                        }
336                    }
337                    time::sleep(Duration::from_millis(this.opts.reconnect_millis)).await;
338                    {
339                        let mut status_mutex = this.status.lock().unwrap();
340                        if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
341                            continue;
342                        }
343                        *status_mutex = Status::Connecting;
344                    }
345                    if to_disconnected {
346                        let handlers = { (*this.handlers.lock().unwrap()).clone() };
347                        for (id, handler) in handlers {
348                            let conn = this.clone();
349                            task::spawn(async move {
350                                handler
351                                    .on_status(id.clone(), conn, Status::Connecting)
352                                    .await;
353                            });
354                        }
355                    }
356                }
357                Status::Disconnected => {
358                    *this.status.lock().unwrap() = Status::Connecting;
359                }
360            }
361        }
362    })
363}