1use std::{
2 collections::HashMap,
3 error::Error as StdError,
4 str::FromStr,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use async_trait::async_trait;
10use regex::Regex;
11use rumqttc::{
12 AsyncClient as RumqttConnection, ClientError, Event as RumqttEvent,
13 MqttOptions as RumqttOption, NetworkOptions, Packet, Publish, TlsConfiguration, Transport,
14};
15use tokio::{
16 task::{self, JoinHandle},
17 time,
18};
19
20use super::uri::{MQTTScheme, MQTTUri};
21use crate::{
22 connection::{EventHandler, GmqConnection, Status},
23 randomstring, ID_SIZE,
24};
25
26#[derive(Clone)]
28pub struct MqttConnection {
29 opts: InnerOptions,
31 status: Arc<Mutex<Status>>,
33 conn: Arc<Mutex<Option<RumqttConnection>>>,
35 handlers: Arc<Mutex<HashMap<String, Arc<dyn EventHandler>>>>,
37 packet_handlers: Arc<Mutex<HashMap<String, Arc<dyn PacketHandler>>>>,
42 ev_loop: Arc<Mutex<Option<JoinHandle<()>>>>,
44}
45
/// User-facing options accepted by [`MqttConnection::new`].
pub struct MqttConnectionOptions {
    /// Broker URI, e.g. `mqtt://localhost` (the default). URIs parsed with the `MQTTS` scheme use
    /// TLS transport.
    pub uri: String,
    /// Connection timeout in milliseconds; `0` selects the default (3000 ms).
    pub connect_timeout_millis: u64,
    /// Delay in milliseconds before reconnecting after the connection drops; `0` selects the
    /// default (1000 ms).
    pub reconnect_millis: u64,
    /// MQTT client ID, which must match `^[0-9A-Za-z-]{1,23}$`. `None` generates a random ID
    /// prefixed with `general-mq-`.
    pub client_id: Option<String>,
    /// Value of the MQTT clean-session flag.
    pub clean_session: bool,
}
67
68pub(super) trait PacketHandler: Send + Sync {
70 fn on_publish(&self, packet: Publish);
72}
73
74#[derive(Clone)]
76struct InnerOptions {
77 uri: MQTTUri,
79 connect_timeout_millis: u64,
81 reconnect_millis: u64,
83 client_id: String,
85 clean_session: bool,
87}
88
/// Default connection timeout in milliseconds, used by `Default` and when the option is 0.
const DEF_CONN_TIMEOUT_MS: u64 = 3000;
/// Default reconnect delay in milliseconds, used by `Default` and when the option is 0.
const DEF_RECONN_TIME_MS: u64 = 1000;
/// Allowed MQTT client IDs: 1 to 23 ASCII letters, digits or `-`.
const CLIENT_ID_PATTERN: &str = "^[0-9A-Za-z-]{1,23}$";
95
96impl MqttConnection {
97 pub fn new(opts: MqttConnectionOptions) -> Result<MqttConnection, String> {
99 let uri = MQTTUri::from_str(opts.uri.as_str())?;
100
101 Ok(MqttConnection {
102 opts: InnerOptions {
103 uri,
104 connect_timeout_millis: match opts.connect_timeout_millis {
105 0 => DEF_CONN_TIMEOUT_MS,
106 _ => opts.connect_timeout_millis,
107 },
108 reconnect_millis: match opts.reconnect_millis {
109 0 => DEF_RECONN_TIME_MS,
110 _ => opts.reconnect_millis,
111 },
112 client_id: match opts.client_id {
113 None => format!("general-mq-{}", randomstring(12)),
114 Some(client_id) => {
115 let re = Regex::new(CLIENT_ID_PATTERN).unwrap();
116 if !re.is_match(client_id.as_str()) {
117 return Err(format!("client_id is not match {}", CLIENT_ID_PATTERN));
118 }
119 client_id
120 }
121 },
122 clean_session: opts.clean_session,
123 },
124 status: Arc::new(Mutex::new(Status::Closed)),
125 conn: Arc::new(Mutex::new(None)),
126 handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn EventHandler>>::new())),
127 packet_handlers: Arc::new(Mutex::new(HashMap::<String, Arc<dyn PacketHandler>>::new())),
128 ev_loop: Arc::new(Mutex::new(None)),
129 })
130 }
131
132 pub(super) fn add_packet_handler(&mut self, name: &str, handler: Arc<dyn PacketHandler>) {
134 self.packet_handlers
135 .lock()
136 .unwrap()
137 .insert(name.to_string(), handler);
138 }
139
140 pub(super) fn remove_packet_handler(&mut self, name: &str) {
142 self.packet_handlers.lock().unwrap().remove(name);
143 }
144
145 pub(super) fn get_raw_connection(&self) -> Option<RumqttConnection> {
147 match self.conn.lock().unwrap().as_ref() {
148 None => None,
149 Some(conn) => Some(conn.clone()),
150 }
151 }
152}
153
154#[async_trait]
155impl GmqConnection for MqttConnection {
156 fn status(&self) -> Status {
157 *self.status.lock().unwrap()
158 }
159
160 fn add_handler(&mut self, handler: Arc<dyn EventHandler>) -> String {
161 let id = randomstring(ID_SIZE);
162 self.handlers.lock().unwrap().insert(id.clone(), handler);
163 id
164 }
165
166 fn remove_handler(&mut self, id: &str) {
167 self.handlers.lock().unwrap().remove(id);
168 }
169
170 fn connect(&mut self) -> Result<(), Box<dyn StdError>> {
171 {
172 let mut task_handle_mutex = self.ev_loop.lock().unwrap();
173 if (*task_handle_mutex).is_some() {
174 return Ok(());
175 }
176 *self.status.lock().unwrap() = Status::Connecting;
177 *task_handle_mutex = Some(create_event_loop(self));
178 }
179 Ok(())
180 }
181
182 async fn close(&mut self) -> Result<(), Box<dyn StdError + Send + Sync>> {
183 match { self.ev_loop.lock().unwrap().take() } {
184 None => return Ok(()),
185 Some(handle) => handle.abort(),
186 }
187 {
188 *self.status.lock().unwrap() = Status::Closing;
189 }
190
191 let conn = { self.conn.lock().unwrap().take() };
192 let mut result: Result<(), ClientError> = Ok(());
193 if let Some(conn) = conn {
194 result = conn.disconnect().await;
195 }
196
197 {
198 *self.status.lock().unwrap() = Status::Closed;
199 }
200 let handlers = { (*self.handlers.lock().unwrap()).clone() };
201 for (id, handler) in handlers {
202 let conn = Arc::new(self.clone());
203 task::spawn(async move {
204 handler.on_status(id.clone(), conn, Status::Closed).await;
205 });
206 }
207
208 result?;
209 Ok(())
210 }
211}
212
213impl Default for MqttConnectionOptions {
214 fn default() -> Self {
215 MqttConnectionOptions {
216 uri: "mqtt://localhost".to_string(),
217 connect_timeout_millis: DEF_CONN_TIMEOUT_MS,
218 reconnect_millis: DEF_RECONN_TIME_MS,
219 client_id: None,
220 clean_session: true,
221 }
222 }
223}
224
225fn create_event_loop(conn: &MqttConnection) -> JoinHandle<()> {
227 let this = Arc::new(conn.clone());
228 task::spawn(async move {
229 loop {
230 match this.status() {
231 Status::Closing | Status::Closed => task::yield_now().await,
232 Status::Connecting | Status::Connected => {
233 let mut opts = RumqttOption::new(
234 this.opts.client_id.as_str(),
235 this.opts.uri.host.as_str(),
236 this.opts.uri.port,
237 );
238 opts.set_clean_session(this.opts.clean_session)
239 .set_credentials(
240 this.opts.uri.username.as_str(),
241 this.opts.uri.password.as_str(),
242 );
243 if this.opts.uri.scheme == MQTTScheme::MQTTS {
244 opts.set_transport(Transport::Tls(TlsConfiguration::default()));
245 }
246
247 let mut to_disconnected = false;
248 let (client, mut event_loop) = RumqttConnection::new(opts, 10);
249 let mut net_opts = NetworkOptions::new();
250 net_opts.set_connection_timeout(this.opts.connect_timeout_millis);
251 event_loop.set_network_options(net_opts);
252 loop {
253 match event_loop.poll().await {
254 Err(_) => {
255 if this.status() == Status::Connected {
256 to_disconnected = true;
257 }
258 break;
259 }
260 Ok(event) => {
261 let packet = match event {
262 RumqttEvent::Incoming(packet) => packet,
263 _ => continue,
264 };
265 match packet {
266 Packet::Publish(packet) => {
267 if this.status() != Status::Connected {
268 continue;
269 }
270 let handler = {
271 let topic = packet.topic.as_str();
272 match this.packet_handlers.lock().unwrap().get(topic) {
273 None => continue,
274 Some(handler) => handler.clone(),
275 }
276 };
277 handler.on_publish(packet);
278 }
279 Packet::ConnAck(_) => {
280 let mut to_connected = false;
281 {
282 let mut status_mutex = this.status.lock().unwrap();
283 let status = *status_mutex;
284 if status == Status::Closing || status == Status::Closed
285 {
286 break;
287 } else if status != Status::Connected {
288 *this.conn.lock().unwrap() = Some(client.clone());
289 *status_mutex = Status::Connected;
290 to_connected = true;
291 }
292 }
293
294 if to_connected {
295 let handlers =
296 { (*this.handlers.lock().unwrap()).clone() };
297 for (id, handler) in handlers {
298 let conn = this.clone();
299 task::spawn(async move {
300 handler
301 .on_status(
302 id.clone(),
303 conn,
304 Status::Connected,
305 )
306 .await;
307 });
308 }
309 }
310 }
311 _ => {}
312 }
313 }
314 }
315 }
316
317 {
318 let mut status_mutex = this.status.lock().unwrap();
319 if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
320 continue;
321 }
322 let _ = this.conn.lock().unwrap().take();
323 *status_mutex = Status::Disconnected;
324 }
325
326 if to_disconnected {
327 let handlers = { (*this.handlers.lock().unwrap()).clone() };
328 for (id, handler) in handlers {
329 let conn = this.clone();
330 task::spawn(async move {
331 handler
332 .on_status(id.clone(), conn, Status::Disconnected)
333 .await;
334 });
335 }
336 }
337 time::sleep(Duration::from_millis(this.opts.reconnect_millis)).await;
338 {
339 let mut status_mutex = this.status.lock().unwrap();
340 if *status_mutex == Status::Closing || *status_mutex == Status::Closed {
341 continue;
342 }
343 *status_mutex = Status::Connecting;
344 }
345 if to_disconnected {
346 let handlers = { (*this.handlers.lock().unwrap()).clone() };
347 for (id, handler) in handlers {
348 let conn = this.clone();
349 task::spawn(async move {
350 handler
351 .on_status(id.clone(), conn, Status::Connecting)
352 .await;
353 });
354 }
355 }
356 }
357 Status::Disconnected => {
358 *this.status.lock().unwrap() = Status::Connecting;
359 }
360 }
361 }
362 })
363}