1use std::{cell::Cell, fmt, future::ready, future::Future, num::NonZeroU16, rc::Rc};
2
3use ntex_bytes::{ByteString, Bytes};
4use ntex_util::{channel::pool, future::Either, future::Ready};
5
6use crate::v3::shared::{Ack, AckType, MqttShared};
7use crate::v3::{codec, error::SendPacketError};
8use crate::{error::EncodeError, types::QoS};
9
10pub struct MqttSink(Rc<MqttShared>);
11
12impl Clone for MqttSink {
13 fn clone(&self) -> Self {
14 MqttSink(self.0.clone())
15 }
16}
17
18impl MqttSink {
19 pub(crate) fn new(state: Rc<MqttShared>) -> Self {
20 MqttSink(state)
21 }
22
23 pub(super) fn shared(&self) -> Rc<MqttShared> {
24 self.0.clone()
25 }
26
27 #[inline]
28 pub fn is_open(&self) -> bool {
30 !self.0.is_closed()
31 }
32
33 #[inline]
34 pub fn is_ready(&self) -> bool {
36 if self.0.is_closed() {
37 false
38 } else {
39 self.0.is_ready()
40 }
41 }
42
43 #[inline]
44 pub fn credit(&self) -> usize {
46 self.0.credit()
47 }
48
49 pub fn ready(&self) -> impl Future<Output = bool> {
53 if !self.0.is_closed() {
54 self.0
55 .wait_readiness()
56 .map(|rx| Either::Right(async move { rx.await.is_ok() }))
57 .unwrap_or_else(|| Either::Left(ready(true)))
58 } else {
59 Either::Left(ready(false))
60 }
61 }
62
63 #[inline]
64 pub fn close(&self) {
66 self.0.close();
67 }
68
69 #[inline]
70 pub fn force_close(&self) {
73 self.0.force_close();
74 }
75
76 #[inline]
77 pub(super) fn ping(&self) -> bool {
79 self.0.encode_packet(codec::Packet::PingRequest).is_ok()
80 }
81
82 #[inline]
83 pub fn publish<U>(&self, topic: U, payload: Bytes) -> PublishBuilder
85 where
86 ByteString: From<U>,
87 {
88 self.publish_pkt(
89 codec::Publish {
90 dup: false,
91 retain: false,
92 topic: topic.into(),
93 qos: codec::QoS::AtMostOnce,
94 packet_id: None,
95 payload_size: payload.len() as u32,
96 },
97 payload,
98 )
99 }
100
101 #[inline]
102 pub fn publish_pkt(&self, packet: codec::Publish, payload: Bytes) -> PublishBuilder {
104 PublishBuilder { packet, payload, shared: self.0.clone() }
105 }
106
107 pub fn publish_ack_cb<F>(&self, f: F)
112 where
113 F: Fn(NonZeroU16, bool) + 'static,
114 {
115 self.0.set_publish_ack(Box::new(f));
116 }
117
118 #[inline]
119 pub fn subscribe(&self) -> SubscribeBuilder {
123 SubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
124 }
125
126 #[inline]
127 pub fn unsubscribe(&self) -> UnsubscribeBuilder {
129 UnsubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
130 }
131}
132
133impl fmt::Debug for MqttSink {
134 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
135 fmt.debug_struct("MqttSink").finish()
136 }
137}
138
139pub struct PublishBuilder {
140 packet: codec::Publish,
141 shared: Rc<MqttShared>,
142 payload: Bytes,
143}
144
145impl PublishBuilder {
146 #[inline]
147 pub fn packet_id(mut self, id: u16) -> Self {
155 let id = NonZeroU16::new(id).expect("id 0 is not allowed");
156 self.packet.packet_id = Some(id);
157 self
158 }
159
160 #[inline]
161 pub fn dup(mut self, val: bool) -> Self {
163 self.packet.dup = val;
164 self
165 }
166
167 #[inline]
168 pub fn retain(mut self) -> Self {
170 self.packet.retain = true;
171 self
172 }
173
174 #[inline]
175 pub fn size(&self) -> u32 {
177 codec::encode::get_encoded_publish_size(&self.packet) as u32
178 }
179
180 pub fn streaming(mut self, size: u32) -> (StreamingPublishBuilder, StreamingPayload) {
182 self.packet.payload_size = size;
183 let payload = if self.payload.is_empty() { None } else { Some(self.payload) };
184
185 let (tx, rx) = self.shared.pool.waiters.channel();
186 (
187 StreamingPublishBuilder {
188 size,
189 payload,
190 tx: Some(tx),
191 shared: self.shared.clone(),
192 packet: self.packet,
193 },
194 StreamingPayload {
195 rx: Cell::new(Some(rx)),
196 shared: self.shared.clone(),
197 inprocess: Cell::new(false),
198 },
199 )
200 }
201
202 #[inline]
203 pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
205 if !self.shared.is_closed() {
206 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
207 self.packet.qos = codec::QoS::AtMostOnce;
208 self.shared
209 .encode_publish(self.packet, Some(self.payload))
210 .map_err(SendPacketError::Encode)
211 .map(|_| ())
212 } else {
213 log::error!("Mqtt sink is disconnected");
214 Err(SendPacketError::Disconnected)
215 }
216 }
217
218 pub fn send_at_least_once(mut self) -> impl Future<Output = Result<(), SendPacketError>> {
220 if !self.shared.is_closed() {
221 self.packet.qos = codec::QoS::AtLeastOnce;
222
223 if let Some(rx) = self.shared.wait_readiness() {
225 Either::Left(Either::Left(async move {
226 if rx.await.is_err() {
227 return Err(SendPacketError::Disconnected);
228 }
229 self.send_at_least_once_inner().await
230 }))
231 } else {
232 Either::Left(Either::Right(self.send_at_least_once_inner()))
233 }
234 } else {
235 Either::Right(Ready::Err(SendPacketError::Disconnected))
236 }
237 }
238
239 pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
243 if !self.shared.is_closed() {
244 if !self.shared.is_ready() {
246 panic!("Mqtt sink is not ready");
247 }
248 self.packet.qos = codec::QoS::AtLeastOnce;
249 let idx = self.shared.set_publish_id(&mut self.packet);
250
251 log::trace!("Publish (QoS1) to {:#?}", self.packet);
252
253 self.shared.wait_publish_response_no_block(
254 idx,
255 AckType::Publish,
256 self.packet,
257 Some(self.payload),
258 )
259 } else {
260 Err(SendPacketError::Disconnected)
261 }
262 }
263
264 fn send_at_least_once_inner(mut self) -> impl Future<Output = Result<(), SendPacketError>> {
265 let idx = self.shared.set_publish_id(&mut self.packet);
266 log::trace!("Publish (QoS1) to {:#?}", self.packet);
267
268 let rx = self.shared.wait_publish_response(
269 idx,
270 AckType::Publish,
271 self.packet,
272 Some(self.payload),
273 );
274 async move { rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected) }
275 }
276
277 pub fn send_exactly_once(
279 mut self,
280 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
281 if !self.shared.is_closed() {
282 self.packet.qos = codec::QoS::ExactlyOnce;
283
284 if let Some(rx) = self.shared.wait_readiness() {
286 Either::Left(Either::Left(async move {
287 if rx.await.is_err() {
288 return Err(SendPacketError::Disconnected);
289 }
290 self.send_exactly_once_inner().await
291 }))
292 } else {
293 Either::Left(Either::Right(self.send_exactly_once_inner()))
294 }
295 } else {
296 Either::Right(Ready::Err(SendPacketError::Disconnected))
297 }
298 }
299
300 fn send_exactly_once_inner(
301 mut self,
302 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
303 let shared = self.shared.clone();
304 let idx = shared.set_publish_id(&mut self.packet);
305 log::trace!("Publish (QoS2) to {:#?}", self.packet);
306
307 let rx = shared.wait_publish_response(
308 idx,
309 AckType::Receive,
310 self.packet,
311 Some(self.payload),
312 );
313 async move {
314 rx?.await
315 .map(move |_| PublishReceived { packet_id: Some(idx), shared })
316 .map_err(|_| SendPacketError::Disconnected)
317 }
318 }
319}
320
321pub struct PublishReceived {
323 packet_id: Option<NonZeroU16>,
324 shared: Rc<MqttShared>,
325}
326
327impl PublishReceived {
328 pub async fn release(mut self) -> Result<(), SendPacketError> {
330 let rx = self.shared.release_publish(self.packet_id.take().unwrap())?;
331
332 rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
333 }
334}
335
336impl Drop for PublishReceived {
337 fn drop(&mut self) {
338 if let Some(id) = self.packet_id.take() {
339 self.shared.release_publish(id);
340 }
341 }
342}
343
344pub struct SubscribeBuilder {
346 id: Option<NonZeroU16>,
347 shared: Rc<MqttShared>,
348 topic_filters: Vec<(ByteString, codec::QoS)>,
349}
350
351impl SubscribeBuilder {
352 #[inline]
353 pub fn packet_id(mut self, id: u16) -> Self {
357 if let Some(id) = NonZeroU16::new(id) {
358 self.id = Some(id);
359 self
360 } else {
361 panic!("id 0 is not allowed");
362 }
363 }
364
365 #[inline]
366 pub fn topic_filter(mut self, filter: ByteString, qos: codec::QoS) -> Self {
368 self.topic_filters.push((filter, qos));
369 self
370 }
371
372 #[inline]
373 pub fn size(&self) -> u32 {
375 codec::encode::get_encoded_subscribe_size(&self.topic_filters) as u32
376 }
377
378 pub async fn send(self) -> Result<Vec<codec::SubscribeReturnCode>, SendPacketError> {
380 if !self.shared.is_closed() {
381 if let Some(rx) = self.shared.wait_readiness() {
383 if rx.await.is_err() {
384 return Err(SendPacketError::Disconnected);
385 }
386 }
387 let idx = self.id.unwrap_or_else(|| self.shared.next_id());
388 let rx = self.shared.wait_response(idx, AckType::Subscribe)?;
389
390 log::trace!(
392 "Sending subscribe packet id: {} filters:{:?}",
393 idx,
394 self.topic_filters
395 );
396
397 match self.shared.encode_packet(codec::Packet::Subscribe {
398 packet_id: idx,
399 topic_filters: self.topic_filters,
400 }) {
401 Ok(_) => {
402 rx.await
404 .map_err(|_| SendPacketError::Disconnected)
405 .map(|pkt| pkt.subscribe())
406 }
407 Err(err) => Err(SendPacketError::Encode(err)),
408 }
409 } else {
410 Err(SendPacketError::Disconnected)
411 }
412 }
413}
414
415pub struct UnsubscribeBuilder {
417 id: Option<NonZeroU16>,
418 shared: Rc<MqttShared>,
419 topic_filters: Vec<ByteString>,
420}
421
422impl UnsubscribeBuilder {
423 #[inline]
424 pub fn packet_id(mut self, id: u16) -> Self {
428 if let Some(id) = NonZeroU16::new(id) {
429 self.id = Some(id);
430 self
431 } else {
432 panic!("id 0 is not allowed");
433 }
434 }
435
436 #[inline]
437 pub fn topic_filter(mut self, filter: ByteString) -> Self {
439 self.topic_filters.push(filter);
440 self
441 }
442
443 #[inline]
444 pub fn size(&self) -> u32 {
446 codec::encode::get_encoded_unsubscribe_size(&self.topic_filters) as u32
447 }
448
449 pub async fn send(self) -> Result<(), SendPacketError> {
451 let shared = self.shared;
452 let filters = self.topic_filters;
453
454 if !shared.is_closed() {
455 if let Some(rx) = shared.wait_readiness() {
457 if rx.await.is_err() {
458 return Err(SendPacketError::Disconnected);
459 }
460 }
461 let idx = self.id.unwrap_or_else(|| shared.next_id());
463 let rx = shared.wait_response(idx, AckType::Unsubscribe)?;
464
465 log::trace!("Sending unsubscribe packet id: {} filters:{:?}", idx, filters);
467
468 match shared.encode_packet(codec::Packet::Unsubscribe {
469 packet_id: idx,
470 topic_filters: filters,
471 }) {
472 Ok(_) => {
473 rx.await.map_err(|_| SendPacketError::Disconnected).map(|_| ())
475 }
476 Err(err) => Err(SendPacketError::Encode(err)),
477 }
478 } else {
479 Err(SendPacketError::Disconnected)
480 }
481 }
482}
483
484pub struct StreamingPublishBuilder {
485 shared: Rc<MqttShared>,
486 packet: codec::Publish,
487 payload: Option<Bytes>,
488 size: u32,
489 tx: Option<pool::Sender<()>>,
490}
491
impl StreamingPublishBuilder {
    /// Signals the paired [`StreamingPayload`], whose first `send` waits on this
    /// signal before accepting payload chunks.
    ///
    /// Only the first call sends anything (the sender is taken); later calls
    /// return `Ok(())`. Fails with [`SendPacketError::StreamingCancelled`] if the
    /// channel `send` reports an error.
    fn notify_payload_streamer(&mut self) -> Result<(), SendPacketError> {
        if let Some(tx) = self.tx.take() {
            tx.send(()).map_err(|_| SendPacketError::StreamingCancelled)
        } else {
            Ok(())
        }
    }

    /// Encodes the publish header with QoS 0 (at most once).
    ///
    /// The payload streamer is signalled before the header is encoded.
    ///
    /// # Errors
    ///
    /// * [`SendPacketError::Disconnected`] if the connection is closed.
    /// * [`SendPacketError::StreamingCancelled`] if the streamer cannot be signalled.
    /// * [`SendPacketError::Encode`] if encoding fails.
    pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
        if !self.shared.is_closed() {
            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
            self.notify_payload_streamer()?;

            self.packet.qos = QoS::AtMostOnce;
            self.shared
                .encode_publish(self.packet, self.payload)
                .map_err(SendPacketError::Encode)
                .map(|_| ())
        } else {
            log::error!("Mqtt sink is disconnected");
            Err(SendPacketError::Disconnected)
        }
    }

    /// Sends the publish header with QoS 1 (at least once) and resolves once it
    /// is acknowledged.
    ///
    /// The closed check and the `wait_readiness` query run when this method is
    /// called. Because `send_at_least_once_inner` is an `async fn`, the packet id
    /// and encoding work only happen once the returned future is polled.
    pub fn send_at_least_once(mut self) -> impl Future<Output = Result<(), SendPacketError>> {
        if !self.shared.is_closed() {
            self.packet.qos = QoS::AtLeastOnce;

            if let Some(rx) = self.shared.wait_readiness() {
                // Wait for readiness first; a failed wait is reported as a disconnect.
                Either::Left(Either::Left(async move {
                    if rx.await.is_err() {
                        return Err(SendPacketError::Disconnected);
                    }
                    self.send_at_least_once_inner().await
                }))
            } else {
                Either::Left(Either::Right(self.send_at_least_once_inner()))
            }
        } else {
            Either::Right(Ready::Err(SendPacketError::Disconnected))
        }
    }

    /// Sends the publish header with QoS 1 without waiting for readiness or the
    /// acknowledgement.
    ///
    /// # Panics
    ///
    /// Panics if the connection is open but the shared state is not ready.
    ///
    /// # Errors
    ///
    /// * [`SendPacketError::Disconnected`] if the connection is closed.
    /// * [`SendPacketError::StreamingCancelled`] if the signal sender reports
    ///   cancellation (checked after the packet id has been assigned).
    pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
        if !self.shared.is_closed() {
            if !self.shared.is_ready() {
                panic!("Mqtt sink is not ready");
            }
            self.packet.qos = codec::QoS::AtLeastOnce;
            // `tx` is only taken by methods that consume `self`, so it is still present here.
            let tx = self.tx.take().unwrap();
            let idx = self.shared.set_publish_id(&mut self.packet);

            if tx.is_canceled() {
                Err(SendPacketError::StreamingCancelled)
            } else {
                log::trace!("Publish (QoS1) to {:#?}", self.packet);
                // The signal send result is ignored; the header is submitted right after.
                let _ = tx.send(());
                self.shared.wait_publish_response_no_block(
                    idx,
                    AckType::Publish,
                    self.packet,
                    self.payload,
                )
            }
        } else {
            Err(SendPacketError::Disconnected)
        }
    }

    /// Assigns a packet id, submits the QoS 1 publish header, signals the payload
    /// streamer, then waits for the `Publish` acknowledgement.
    async fn send_at_least_once_inner(mut self) -> Result<(), SendPacketError> {
        let idx = self.shared.set_publish_id(&mut self.packet);

        log::trace!("Publish (QoS1) to {:#?}", self.packet);

        // `tx` is only taken by methods that consume `self`, so it is still present here.
        let tx = self.tx.take().unwrap();
        if tx.is_canceled() {
            Err(SendPacketError::StreamingCancelled)
        } else {
            // Register and submit the header before waking the payload streamer.
            let rx = self.shared.wait_publish_response(
                idx,
                AckType::Publish,
                self.packet,
                self.payload,
            );
            let _ = tx.send(());

            rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
        }
    }
}
592
593pub struct StreamingPayload {
594 shared: Rc<MqttShared>,
595 rx: Cell<Option<pool::Receiver<()>>>,
596 inprocess: Cell<bool>,
597}
598
599impl StreamingPayload {
600 fn drop(&mut self) {
601 if self.inprocess.get() {
602 if self.shared.is_streaming() {
603 self.shared.streaming_dropped();
604 }
605 }
606 }
607}
608
609impl StreamingPayload {
610 pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
612 if let Some(rx) = self.rx.take() {
613 if rx.await.is_err() {
614 return Err(SendPacketError::StreamingCancelled);
615 }
616 log::trace!("Publish is encoded, ready to process payload");
617 self.inprocess.set(true);
618 }
619
620 if !self.inprocess.get() {
621 Err(EncodeError::UnexpectedPayload.into())
622 } else {
623 log::trace!("Sending payload chunk: {:?}", chunk.len());
624 self.shared.want_payload_stream().await?;
625
626 if !self.shared.encode_publish_payload(chunk)? {
627 self.inprocess.set(false);
628 }
629 Ok(())
630 }
631 }
632}