ntex_mqtt/v3/
sink.rs

1use std::{cell::Cell, fmt, future::ready, future::Future, num::NonZeroU16, rc::Rc};
2
3use ntex_bytes::{ByteString, Bytes};
4use ntex_util::{channel::pool, future::Either, future::Ready};
5
6use crate::v3::shared::{Ack, AckType, MqttShared};
7use crate::v3::{codec, error::SendPacketError};
8use crate::{error::EncodeError, types::QoS};
9
10pub struct MqttSink(Rc<MqttShared>);
11
12impl Clone for MqttSink {
13    fn clone(&self) -> Self {
14        MqttSink(self.0.clone())
15    }
16}
17
18impl MqttSink {
19    pub(crate) fn new(state: Rc<MqttShared>) -> Self {
20        MqttSink(state)
21    }
22
23    pub(super) fn shared(&self) -> Rc<MqttShared> {
24        self.0.clone()
25    }
26
27    #[inline]
28    /// Check if io stream is open
29    pub fn is_open(&self) -> bool {
30        !self.0.is_closed()
31    }
32
33    #[inline]
34    /// Check if sink is ready
35    pub fn is_ready(&self) -> bool {
36        if self.0.is_closed() {
37            false
38        } else {
39            self.0.is_ready()
40        }
41    }
42
43    #[inline]
44    /// Get client receive credit
45    pub fn credit(&self) -> usize {
46        self.0.credit()
47    }
48
49    /// Get notification when packet could be send to the peer.
50    ///
51    /// Result indicates if connection is alive
52    pub fn ready(&self) -> impl Future<Output = bool> {
53        if !self.0.is_closed() {
54            self.0
55                .wait_readiness()
56                .map(|rx| Either::Right(async move { rx.await.is_ok() }))
57                .unwrap_or_else(|| Either::Left(ready(true)))
58        } else {
59            Either::Left(ready(false))
60        }
61    }
62
63    #[inline]
64    /// Close mqtt connection
65    pub fn close(&self) {
66        self.0.close();
67    }
68
69    #[inline]
70    /// Force close mqtt connection. mqtt dispatcher does not wait for uncompleted
71    /// responses, but it flushes buffers.
72    pub fn force_close(&self) {
73        self.0.force_close();
74    }
75
76    #[inline]
77    /// Send ping
78    pub(super) fn ping(&self) -> bool {
79        self.0.encode_packet(codec::Packet::PingRequest).is_ok()
80    }
81
82    #[inline]
83    /// Create publish message builder
84    pub fn publish<U>(&self, topic: U, payload: Bytes) -> PublishBuilder
85    where
86        ByteString: From<U>,
87    {
88        self.publish_pkt(
89            codec::Publish {
90                dup: false,
91                retain: false,
92                topic: topic.into(),
93                qos: codec::QoS::AtMostOnce,
94                packet_id: None,
95                payload_size: payload.len() as u32,
96            },
97            payload,
98        )
99    }
100
101    #[inline]
102    /// Create publish builder with publish packet
103    pub fn publish_pkt(&self, packet: codec::Publish, payload: Bytes) -> PublishBuilder {
104        PublishBuilder { packet, payload, shared: self.0.clone() }
105    }
106
107    /// Set publish ack callback
108    ///
109    /// Use non-blocking send, PublishBuilder::send_at_least_once_no_block()
110    /// First argument is packet id, second argument is "disconnected" state
111    pub fn publish_ack_cb<F>(&self, f: F)
112    where
113        F: Fn(NonZeroU16, bool) + 'static,
114    {
115        self.0.set_publish_ack(Box::new(f));
116    }
117
118    #[inline]
119    /// Create subscribe packet builder
120    ///
121    /// panics if id is 0
122    pub fn subscribe(&self) -> SubscribeBuilder {
123        SubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
124    }
125
126    #[inline]
127    /// Create unsubscribe packet builder
128    pub fn unsubscribe(&self) -> UnsubscribeBuilder {
129        UnsubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
130    }
131}
132
133impl fmt::Debug for MqttSink {
134    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
135        fmt.debug_struct("MqttSink").finish()
136    }
137}
138
139pub struct PublishBuilder {
140    packet: codec::Publish,
141    shared: Rc<MqttShared>,
142    payload: Bytes,
143}
144
145impl PublishBuilder {
146    #[inline]
147    /// Set packet id.
148    ///
149    /// Note: if packet id is not set, it gets generated automatically.
150    /// Packet id management should not be mixed, it should be auto-generated
151    /// or set by user. Otherwise collisions could occure.
152    ///
153    /// panics if id is 0
154    pub fn packet_id(mut self, id: u16) -> Self {
155        let id = NonZeroU16::new(id).expect("id 0 is not allowed");
156        self.packet.packet_id = Some(id);
157        self
158    }
159
160    #[inline]
161    /// This might be re-delivery of an earlier attempt to send the Packet.
162    pub fn dup(mut self, val: bool) -> Self {
163        self.packet.dup = val;
164        self
165    }
166
167    #[inline]
168    /// Set retain flag
169    pub fn retain(mut self) -> Self {
170        self.packet.retain = true;
171        self
172    }
173
174    #[inline]
175    /// Get size of the publish packet
176    pub fn size(&self) -> u32 {
177        codec::encode::get_encoded_publish_size(&self.packet) as u32
178    }
179
180    /// Create streamimng publish builder
181    pub fn streaming(mut self, size: u32) -> (StreamingPublishBuilder, StreamingPayload) {
182        self.packet.payload_size = size;
183        let payload = if self.payload.is_empty() { None } else { Some(self.payload) };
184
185        let (tx, rx) = self.shared.pool.waiters.channel();
186        (
187            StreamingPublishBuilder {
188                size,
189                payload,
190                tx: Some(tx),
191                shared: self.shared.clone(),
192                packet: self.packet,
193            },
194            StreamingPayload {
195                rx: Cell::new(Some(rx)),
196                shared: self.shared.clone(),
197                inprocess: Cell::new(false),
198            },
199        )
200    }
201
202    #[inline]
203    /// Send publish packet with QoS 0
204    pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
205        if !self.shared.is_closed() {
206            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
207            self.packet.qos = codec::QoS::AtMostOnce;
208            self.shared
209                .encode_publish(self.packet, Some(self.payload))
210                .map_err(SendPacketError::Encode)
211                .map(|_| ())
212        } else {
213            log::error!("Mqtt sink is disconnected");
214            Err(SendPacketError::Disconnected)
215        }
216    }
217
218    /// Send publish packet with QoS 1
219    pub fn send_at_least_once(mut self) -> impl Future<Output = Result<(), SendPacketError>> {
220        if !self.shared.is_closed() {
221            self.packet.qos = codec::QoS::AtLeastOnce;
222
223            // handle client receive maximum
224            if let Some(rx) = self.shared.wait_readiness() {
225                Either::Left(Either::Left(async move {
226                    if rx.await.is_err() {
227                        return Err(SendPacketError::Disconnected);
228                    }
229                    self.send_at_least_once_inner().await
230                }))
231            } else {
232                Either::Left(Either::Right(self.send_at_least_once_inner()))
233            }
234        } else {
235            Either::Right(Ready::Err(SendPacketError::Disconnected))
236        }
237    }
238
239    /// Non-blocking send publish packet with QoS 1
240    ///
241    /// Panics if sink is not ready or publish ack callback is not set
242    pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
243        if !self.shared.is_closed() {
244            // check readiness
245            if !self.shared.is_ready() {
246                panic!("Mqtt sink is not ready");
247            }
248            self.packet.qos = codec::QoS::AtLeastOnce;
249            let idx = self.shared.set_publish_id(&mut self.packet);
250
251            log::trace!("Publish (QoS1) to {:#?}", self.packet);
252
253            self.shared.wait_publish_response_no_block(
254                idx,
255                AckType::Publish,
256                self.packet,
257                Some(self.payload),
258            )
259        } else {
260            Err(SendPacketError::Disconnected)
261        }
262    }
263
264    fn send_at_least_once_inner(mut self) -> impl Future<Output = Result<(), SendPacketError>> {
265        let idx = self.shared.set_publish_id(&mut self.packet);
266        log::trace!("Publish (QoS1) to {:#?}", self.packet);
267
268        let rx = self.shared.wait_publish_response(
269            idx,
270            AckType::Publish,
271            self.packet,
272            Some(self.payload),
273        );
274        async move { rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected) }
275    }
276
277    /// Send publish packet with QoS 2
278    pub fn send_exactly_once(
279        mut self,
280    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
281        if !self.shared.is_closed() {
282            self.packet.qos = codec::QoS::ExactlyOnce;
283
284            // handle client receive maximum
285            if let Some(rx) = self.shared.wait_readiness() {
286                Either::Left(Either::Left(async move {
287                    if rx.await.is_err() {
288                        return Err(SendPacketError::Disconnected);
289                    }
290                    self.send_exactly_once_inner().await
291                }))
292            } else {
293                Either::Left(Either::Right(self.send_exactly_once_inner()))
294            }
295        } else {
296            Either::Right(Ready::Err(SendPacketError::Disconnected))
297        }
298    }
299
300    fn send_exactly_once_inner(
301        mut self,
302    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
303        let shared = self.shared.clone();
304        let idx = shared.set_publish_id(&mut self.packet);
305        log::trace!("Publish (QoS2) to {:#?}", self.packet);
306
307        let rx = shared.wait_publish_response(
308            idx,
309            AckType::Receive,
310            self.packet,
311            Some(self.payload),
312        );
313        async move {
314            rx?.await
315                .map(move |_| PublishReceived { packet_id: Some(idx), shared })
316                .map_err(|_| SendPacketError::Disconnected)
317        }
318    }
319}
320
321/// Publish released for QoS2
322pub struct PublishReceived {
323    packet_id: Option<NonZeroU16>,
324    shared: Rc<MqttShared>,
325}
326
327impl PublishReceived {
328    /// Release publish
329    pub async fn release(mut self) -> Result<(), SendPacketError> {
330        let rx = self.shared.release_publish(self.packet_id.take().unwrap())?;
331
332        rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
333    }
334}
335
336impl Drop for PublishReceived {
337    fn drop(&mut self) {
338        if let Some(id) = self.packet_id.take() {
339            self.shared.release_publish(id);
340        }
341    }
342}
343
344/// Subscribe packet builder
345pub struct SubscribeBuilder {
346    id: Option<NonZeroU16>,
347    shared: Rc<MqttShared>,
348    topic_filters: Vec<(ByteString, codec::QoS)>,
349}
350
351impl SubscribeBuilder {
352    #[inline]
353    /// Set packet id.
354    ///
355    /// panics if id is 0
356    pub fn packet_id(mut self, id: u16) -> Self {
357        if let Some(id) = NonZeroU16::new(id) {
358            self.id = Some(id);
359            self
360        } else {
361            panic!("id 0 is not allowed");
362        }
363    }
364
365    #[inline]
366    /// Add topic filter
367    pub fn topic_filter(mut self, filter: ByteString, qos: codec::QoS) -> Self {
368        self.topic_filters.push((filter, qos));
369        self
370    }
371
372    #[inline]
373    /// Get size of the subscribe packet
374    pub fn size(&self) -> u32 {
375        codec::encode::get_encoded_subscribe_size(&self.topic_filters) as u32
376    }
377
378    /// Send subscribe packet
379    pub async fn send(self) -> Result<Vec<codec::SubscribeReturnCode>, SendPacketError> {
380        if !self.shared.is_closed() {
381            // handle client receive maximum
382            if let Some(rx) = self.shared.wait_readiness() {
383                if rx.await.is_err() {
384                    return Err(SendPacketError::Disconnected);
385                }
386            }
387            let idx = self.id.unwrap_or_else(|| self.shared.next_id());
388            let rx = self.shared.wait_response(idx, AckType::Subscribe)?;
389
390            // send subscribe to client
391            log::trace!(
392                "Sending subscribe packet id: {} filters:{:?}",
393                idx,
394                self.topic_filters
395            );
396
397            match self.shared.encode_packet(codec::Packet::Subscribe {
398                packet_id: idx,
399                topic_filters: self.topic_filters,
400            }) {
401                Ok(_) => {
402                    // wait ack from peer
403                    rx.await
404                        .map_err(|_| SendPacketError::Disconnected)
405                        .map(|pkt| pkt.subscribe())
406                }
407                Err(err) => Err(SendPacketError::Encode(err)),
408            }
409        } else {
410            Err(SendPacketError::Disconnected)
411        }
412    }
413}
414
415/// Unsubscribe packet builder
416pub struct UnsubscribeBuilder {
417    id: Option<NonZeroU16>,
418    shared: Rc<MqttShared>,
419    topic_filters: Vec<ByteString>,
420}
421
422impl UnsubscribeBuilder {
423    #[inline]
424    /// Set packet id.
425    ///
426    /// panics if id is 0
427    pub fn packet_id(mut self, id: u16) -> Self {
428        if let Some(id) = NonZeroU16::new(id) {
429            self.id = Some(id);
430            self
431        } else {
432            panic!("id 0 is not allowed");
433        }
434    }
435
436    #[inline]
437    /// Add topic filter
438    pub fn topic_filter(mut self, filter: ByteString) -> Self {
439        self.topic_filters.push(filter);
440        self
441    }
442
443    #[inline]
444    /// Get size of the unsubscribe packet
445    pub fn size(&self) -> u32 {
446        codec::encode::get_encoded_unsubscribe_size(&self.topic_filters) as u32
447    }
448
449    /// Send unsubscribe packet
450    pub async fn send(self) -> Result<(), SendPacketError> {
451        let shared = self.shared;
452        let filters = self.topic_filters;
453
454        if !shared.is_closed() {
455            // handle client receive maximum
456            if let Some(rx) = shared.wait_readiness() {
457                if rx.await.is_err() {
458                    return Err(SendPacketError::Disconnected);
459                }
460            }
461            // allocate packet id
462            let idx = self.id.unwrap_or_else(|| shared.next_id());
463            let rx = shared.wait_response(idx, AckType::Unsubscribe)?;
464
465            // send subscribe to client
466            log::trace!("Sending unsubscribe packet id: {} filters:{:?}", idx, filters);
467
468            match shared.encode_packet(codec::Packet::Unsubscribe {
469                packet_id: idx,
470                topic_filters: filters,
471            }) {
472                Ok(_) => {
473                    // wait ack from peer
474                    rx.await.map_err(|_| SendPacketError::Disconnected).map(|_| ())
475                }
476                Err(err) => Err(SendPacketError::Encode(err)),
477            }
478        } else {
479            Err(SendPacketError::Disconnected)
480        }
481    }
482}
483
484pub struct StreamingPublishBuilder {
485    shared: Rc<MqttShared>,
486    packet: codec::Publish,
487    payload: Option<Bytes>,
488    size: u32,
489    tx: Option<pool::Sender<()>>,
490}
491
492impl StreamingPublishBuilder {
493    fn notify_payload_streamer(&mut self) -> Result<(), SendPacketError> {
494        if let Some(tx) = self.tx.take() {
495            tx.send(()).map_err(|_| SendPacketError::StreamingCancelled)
496        } else {
497            Ok(())
498        }
499    }
500
501    /// Send publish packet with QoS 0
502    pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
503        if !self.shared.is_closed() {
504            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
505            self.notify_payload_streamer()?;
506
507            self.packet.qos = QoS::AtMostOnce;
508            self.shared
509                .encode_publish(self.packet, self.payload)
510                .map_err(SendPacketError::Encode)
511                .map(|_| ())
512        } else {
513            log::error!("Mqtt sink is disconnected");
514            Err(SendPacketError::Disconnected)
515        }
516    }
517
518    /// Send publish packet with QoS 1
519    pub fn send_at_least_once(mut self) -> impl Future<Output = Result<(), SendPacketError>> {
520        if !self.shared.is_closed() {
521            self.packet.qos = QoS::AtLeastOnce;
522
523            // handle client receive maximum
524            if let Some(rx) = self.shared.wait_readiness() {
525                Either::Left(Either::Left(async move {
526                    if rx.await.is_err() {
527                        return Err(SendPacketError::Disconnected);
528                    }
529                    self.send_at_least_once_inner().await
530                }))
531            } else {
532                Either::Left(Either::Right(self.send_at_least_once_inner()))
533            }
534        } else {
535            Either::Right(Ready::Err(SendPacketError::Disconnected))
536        }
537    }
538
539    /// Non-blocking send publish packet with QoS 1
540    ///
541    /// Panics if sink is not ready or publish ack callback is not set
542    pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
543        if !self.shared.is_closed() {
544            // check readiness
545            if !self.shared.is_ready() {
546                panic!("Mqtt sink is not ready");
547            }
548            self.packet.qos = codec::QoS::AtLeastOnce;
549            let tx = self.tx.take().unwrap();
550            let idx = self.shared.set_publish_id(&mut self.packet);
551
552            if tx.is_canceled() {
553                Err(SendPacketError::StreamingCancelled)
554            } else {
555                log::trace!("Publish (QoS1) to {:#?}", self.packet);
556                let _ = tx.send(());
557                self.shared.wait_publish_response_no_block(
558                    idx,
559                    AckType::Publish,
560                    self.packet,
561                    self.payload,
562                )
563            }
564        } else {
565            Err(SendPacketError::Disconnected)
566        }
567    }
568
569    async fn send_at_least_once_inner(mut self) -> Result<(), SendPacketError> {
570        // packet id
571        let idx = self.shared.set_publish_id(&mut self.packet);
572
573        // send publish to client
574        log::trace!("Publish (QoS1) to {:#?}", self.packet);
575
576        let tx = self.tx.take().unwrap();
577        if tx.is_canceled() {
578            Err(SendPacketError::StreamingCancelled)
579        } else {
580            let rx = self.shared.wait_publish_response(
581                idx,
582                AckType::Publish,
583                self.packet,
584                self.payload,
585            );
586            let _ = tx.send(());
587
588            rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
589        }
590    }
591}
592
593pub struct StreamingPayload {
594    shared: Rc<MqttShared>,
595    rx: Cell<Option<pool::Receiver<()>>>,
596    inprocess: Cell<bool>,
597}
598
599impl StreamingPayload {
600    fn drop(&mut self) {
601        if self.inprocess.get() {
602            if self.shared.is_streaming() {
603                self.shared.streaming_dropped();
604            }
605        }
606    }
607}
608
609impl StreamingPayload {
610    /// Send payload chunk
611    pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
612        if let Some(rx) = self.rx.take() {
613            if rx.await.is_err() {
614                return Err(SendPacketError::StreamingCancelled);
615            }
616            log::trace!("Publish is encoded, ready to process payload");
617            self.inprocess.set(true);
618        }
619
620        if !self.inprocess.get() {
621            Err(EncodeError::UnexpectedPayload.into())
622        } else {
623            log::trace!("Sending payload chunk: {:?}", chunk.len());
624            self.shared.want_payload_stream().await?;
625
626            if !self.shared.encode_publish_payload(chunk)? {
627                self.inprocess.set(false);
628            }
629            Ok(())
630        }
631    }
632}