ntex_mqtt/v3/
server.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, DispatcherConfig, IoBoxed};
4use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
5use ntex_util::time::{timeout_checked, Millis, Seconds};
6
7use crate::error::{HandshakeError, MqttError, ProtocolError};
8use crate::{service, types::QoS, InFlightService};
9
10use super::control::{Control, ControlAck};
11use super::default::{DefaultControlService, DefaultPublishService};
12use super::handshake::{Handshake, HandshakeAck};
13use super::shared::{MqttShared, MqttSinkPool};
14use super::{codec as mqtt, dispatcher::factory, MqttSink, Publish, Session};
15
16/// Mqtt v3.1.1 server
17///
18/// `St` - connection state
19/// `H` - handshake service
20/// `C` - service for handling control messages
21/// `P` - service for handling publish
22///
23/// Every mqtt connection is handled in several steps. First step is handshake. Server calls
24/// handshake service with `Handshake` message, during this step service can authenticate connect
25/// packet, it must return instance of connection state `St`.
26///
27/// Handshake service could be expressed as simple function:
28///
29/// ```rust,ignore
30/// use ntex_mqtt::v3::{Handshake, HandshakeAck};
31///
32/// async fn handshake(hnd: Handshake) -> Result<HandshakeAkc<MyState>, MyError> {
33///     Ok(hnd.ack(MyState::new(), false))
34/// }
35/// ```
36///
37/// During next stage, control and publish services get constructed,
38/// both factories receive `Session<St>` state object as an argument. Publish service
39/// handles `Publish` packet. On success, server server sends `PublishAck` packet to
40/// the client, in case of error connection get closed. Control service receives all
41/// other packets, like `Subscribe`, `Unsubscribe` etc. Also control service receives
42/// errors from publish service and connection disconnect.
43pub struct MqttServer<St, H, C, P, M = Identity> {
44    handshake: H,
45    control: C,
46    publish: P,
47    middleware: M,
48    max_qos: QoS,
49    max_size: u32,
50    max_send: u16,
51    max_send_size: (u32, u32),
52    min_chunk_size: u32,
53    handle_qos_after_disconnect: Option<QoS>,
54    connect_timeout: Seconds,
55    config: DispatcherConfig,
56    pub(super) pool: Rc<MqttSinkPool>,
57    _t: PhantomData<St>,
58}
59
60impl<St, H>
61    MqttServer<
62        St,
63        H,
64        DefaultControlService<St, H::Error>,
65        DefaultPublishService<St, H::Error>,
66        InFlightService,
67    >
68where
69    St: 'static,
70    H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
71    H::Error: fmt::Debug,
72{
73    /// Create server factory and provide handshake service
74    pub fn new<F>(handshake: F) -> Self
75    where
76        F: IntoServiceFactory<H, Handshake>,
77    {
78        let config = DispatcherConfig::default();
79        config.set_disconnect_timeout(Seconds(3));
80
81        MqttServer {
82            config,
83            handshake: handshake.into_factory(),
84            control: DefaultControlService::default(),
85            publish: DefaultPublishService::default(),
86            middleware: InFlightService::new(16, 65535),
87            max_qos: QoS::AtLeastOnce,
88            max_size: 0,
89            max_send: 16,
90            max_send_size: (65535, 512),
91            min_chunk_size: 32 * 1024,
92            handle_qos_after_disconnect: None,
93            connect_timeout: Seconds::ZERO,
94            pool: Default::default(),
95            _t: PhantomData,
96        }
97    }
98}
99
100impl<St, H, C, P> MqttServer<St, H, C, P, InFlightService> {
101    /// Number of inbound in-flight concurrent messages.
102    ///
103    /// By default inbound is set to 16 messages
104    pub fn max_receive(mut self, val: u16) -> Self {
105        self.middleware = self.middleware.max_receive(val);
106        self
107    }
108
109    /// Total size of inbound in-flight messages.
110    ///
111    /// By default total inbound in-flight size is set to 64Kb
112    pub fn max_receive_size(mut self, val: usize) -> Self {
113        self.middleware = self.middleware.max_receive_size(val);
114        self
115    }
116}
117
118impl<St, H, C, P, M> MqttServer<St, H, C, P, M>
119where
120    St: 'static,
121    H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
122    C: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
123    P: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
124    H::Error:
125        From<C::Error> + From<C::InitError> + From<P::Error> + From<P::InitError> + fmt::Debug,
126{
127    /// Set client timeout for first `Connect` frame.
128    ///
129    /// Defines a timeout for reading `Connect` frame. If a client does not transmit
130    /// the entire frame within this time, the connection is terminated with
131    /// Mqtt::Handshake(HandshakeError::Timeout) error.
132    ///
133    /// By default, connect timeout is disabled.
134    pub fn connect_timeout(mut self, timeout: Seconds) -> Self {
135        self.connect_timeout = timeout;
136        self
137    }
138
139    /// Set server connection disconnect timeout.
140    ///
141    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
142    /// within this time, the connection get dropped.
143    ///
144    /// To disable timeout set value to 0.
145    ///
146    /// By default disconnect timeout is set to 3 seconds.
147    pub fn disconnect_timeout(self, val: Seconds) -> Self {
148        self.config.set_disconnect_timeout(val);
149        self
150    }
151
152    /// Set read rate parameters for single frame.
153    ///
154    /// Set read timeout, max timeout and rate for reading payload. If the client
155    /// sends `rate` amount of data within `timeout` period of time, extend timeout by `timeout` seconds.
156    /// But no more than `max_timeout` timeout.
157    ///
158    /// By default frame read rate is disabled.
159    pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
160        self.config.set_frame_read_rate(timeout, max_timeout, rate);
161        self
162    }
163
164    /// Set max allowed QoS.
165    ///
166    /// If peer sends publish with higher qos then ProtocolError::MaxQoSViolated(..)
167    /// By default max qos is set to `ExactlyOnce`.
168    pub fn max_qos(mut self, qos: QoS) -> Self {
169        self.max_qos = qos;
170        self
171    }
172
173    /// Set max inbound frame size.
174    ///
175    /// If max size is set to `0`, size is unlimited.
176    /// By default max size is set to `0`
177    pub fn max_size(mut self, size: u32) -> Self {
178        self.max_size = size;
179        self
180    }
181
182    /// Number of outgoing concurrent messages.
183    ///
184    /// By default outgoing is set to 16 messages
185    pub fn max_send(mut self, val: u16) -> Self {
186        self.max_send = val;
187        self
188    }
189
190    /// Total size of outgoing messages.
191    ///
192    /// By default total outgoing size is set to 64Kb
193    pub fn max_send_size(mut self, val: u32) -> Self {
194        self.max_send_size = (val, val / 10);
195        self
196    }
197
198    /// Set min payload chunk size.
199    ///
200    /// If the minimum size is set to `0`, incoming payload chunks
201    /// will be processed immediately. Otherwise, the codec will
202    /// accumulate chunks until the total size reaches the specified minimum.
203    /// By default min size is set to `0`
204    pub fn min_chunk_size(mut self, size: u32) -> Self {
205        self.min_chunk_size = size;
206        self
207    }
208
209    /// Handle max received QoS messages after client disconnect.
210    ///
211    /// By default, messages received before dispatched to the publish service will be dropped if
212    /// the client disconnect is detected on the server.
213    ///
214    /// If this option is set to `Some(QoS::AtMostOnce)`, only the received QoS 0 messages will
215    /// always be handled by the server's publish service no matter if the client is disconnected
216    /// or not.
217    ///
218    /// If this option is set to `Some(QoS::AtLeastOnce)`, the received QoS 0 and QoS 1 messages
219    /// will always be handled by the server's publish service no matter if the client
220    /// is disconnected or not. The QoS 2 messages will be dropped if the client disconnecting is
221    /// detected before the server dispatches them to the publish service.
222    ///
223    /// If this option is set to `Some(QoS::ExactlyOnce)`, all the messages received will always
224    /// be handled by the server's publish service no matter if the client is disconnected or not.
225    ///
226    /// The received messages which QoS larger than the `max_handle_qos` will not be guaranteed to
227    /// be handled or not after the client disconnect. It depends on the network condition.
228    ///
229    /// By default handle-qos-after-disconnect is set to `None`
230    pub fn handle_qos_after_disconnect(mut self, max_handle_qos: Option<QoS>) -> Self {
231        self.handle_qos_after_disconnect = max_handle_qos;
232        self
233    }
234
235    /// Service to handle control packets
236    ///
237    /// All control packets are processed sequentially, max number of buffered
238    /// control packets is 16.
239    pub fn control<F, Srv>(self, service: F) -> MqttServer<St, H, Srv, P, M>
240    where
241        F: IntoServiceFactory<Srv, Control<H::Error>, Session<St>>,
242        Srv: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
243        H::Error: From<Srv::Error> + From<Srv::InitError>,
244    {
245        MqttServer {
246            handshake: self.handshake,
247            publish: self.publish,
248            control: service.into_factory(),
249            config: self.config,
250            middleware: self.middleware,
251            max_qos: self.max_qos,
252            max_size: self.max_size,
253            max_send: self.max_send,
254            max_send_size: self.max_send_size,
255            min_chunk_size: self.min_chunk_size,
256            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
257            connect_timeout: self.connect_timeout,
258            pool: self.pool,
259            _t: PhantomData,
260        }
261    }
262
263    /// Set service to handle publish packets and create mqtt server factory
264    pub fn publish<F, Srv>(self, publish: F) -> MqttServer<St, H, C, Srv, M>
265    where
266        F: IntoServiceFactory<Srv, Publish, Session<St>>,
267        Srv: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
268        H::Error: From<Srv::Error> + From<Srv::InitError> + fmt::Debug,
269    {
270        MqttServer {
271            handshake: self.handshake,
272            publish: publish.into_factory(),
273            control: self.control,
274            config: self.config,
275            middleware: self.middleware,
276            max_qos: self.max_qos,
277            max_size: self.max_size,
278            max_send: self.max_send,
279            max_send_size: self.max_send_size,
280            min_chunk_size: self.min_chunk_size,
281            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
282            connect_timeout: self.connect_timeout,
283            pool: self.pool,
284            _t: PhantomData,
285        }
286    }
287
288    /// Remove all middlewares
289    pub fn reset_middlewares(self) -> MqttServer<St, H, C, P, Identity> {
290        MqttServer {
291            middleware: Identity,
292            handshake: self.handshake,
293            publish: self.publish,
294            control: self.control,
295            config: self.config,
296            max_qos: self.max_qos,
297            max_size: self.max_size,
298            max_send: self.max_send,
299            max_send_size: self.max_send_size,
300            min_chunk_size: self.min_chunk_size,
301            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
302            connect_timeout: self.connect_timeout,
303            pool: self.pool,
304            _t: PhantomData,
305        }
306    }
307
308    /// Registers middleware, in the form of a middleware component (type),
309    /// that runs during inbound and/or outbound processing in the request
310    /// lifecycle (request -> response), modifying request/response as
311    /// necessary, across all requests managed by the *Server*.
312    ///
313    /// Use middleware when you need to read or modify *every* request or
314    /// response in some way.
315    pub fn middleware<U>(self, mw: U) -> MqttServer<St, H, C, P, Stack<M, U>> {
316        MqttServer {
317            middleware: Stack::new(self.middleware, mw),
318            handshake: self.handshake,
319            publish: self.publish,
320            control: self.control,
321            config: self.config,
322            max_qos: self.max_qos,
323            max_size: self.max_size,
324            max_send: self.max_send,
325            max_send_size: self.max_send_size,
326            min_chunk_size: self.min_chunk_size,
327            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
328            connect_timeout: self.connect_timeout,
329            pool: self.pool,
330            _t: PhantomData,
331        }
332    }
333}
334
335impl<St, H, C, P, M> MqttServer<St, H, C, P, M>
336where
337    St: 'static,
338    H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
339    C: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
340    P: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
341    H::Error:
342        From<C::Error> + From<C::InitError> + From<P::Error> + From<P::InitError> + fmt::Debug,
343{
344    /// Finish server configuration and create mqtt server factory
345    pub fn finish(
346        self,
347    ) -> service::MqttServer<
348        Session<St>,
349        impl ServiceFactory<
350            IoBoxed,
351            Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
352            Error = MqttError<H::Error>,
353            InitError = H::InitError,
354        >,
355        impl ServiceFactory<
356            DispatchItem<Rc<MqttShared>>,
357            Session<St>,
358            Response = Option<mqtt::Encoded>,
359            Error = MqttError<H::Error>,
360            InitError = MqttError<H::Error>,
361        >,
362        M,
363        Rc<MqttShared>,
364    > {
365        service::MqttServer::new(
366            HandshakeFactory {
367                factory: self.handshake,
368                max_size: self.max_size,
369                max_send: self.max_send,
370                max_send_size: self.max_send_size,
371                min_chunk_size: self.min_chunk_size,
372                connect_timeout: self.connect_timeout,
373                pool: self.pool.clone(),
374                _t: PhantomData,
375            },
376            factory(self.publish, self.control, self.max_qos, self.handle_qos_after_disconnect),
377            self.middleware,
378            self.config,
379        )
380    }
381}
382
383struct HandshakeFactory<St, H> {
384    factory: H,
385    max_size: u32,
386    max_send: u16,
387    max_send_size: (u32, u32),
388    min_chunk_size: u32,
389    connect_timeout: Seconds,
390    pool: Rc<MqttSinkPool>,
391    _t: PhantomData<St>,
392}
393
394impl<St, H> ServiceFactory<IoBoxed> for HandshakeFactory<St, H>
395where
396    H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
397    H::Error: fmt::Debug,
398{
399    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
400    type Error = MqttError<H::Error>;
401
402    type Service = HandshakeService<St, H::Service>;
403    type InitError = H::InitError;
404
405    async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
406        Ok(HandshakeService {
407            max_size: self.max_size,
408            max_send: self.max_send,
409            max_send_size: self.max_send_size,
410            min_chunk_size: self.min_chunk_size,
411            pool: self.pool.clone(),
412            service: self.factory.create(()).await?,
413            connect_timeout: self.connect_timeout.into(),
414            _t: PhantomData,
415        })
416    }
417}
418
419struct HandshakeService<St, H> {
420    service: H,
421    max_size: u32,
422    max_send: u16,
423    max_send_size: (u32, u32),
424    min_chunk_size: u32,
425    pool: Rc<MqttSinkPool>,
426    connect_timeout: Millis,
427    _t: PhantomData<St>,
428}
429
430impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
431where
432    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
433    H::Error: fmt::Debug,
434{
435    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
436    type Error = MqttError<H::Error>;
437
438    ntex_service::forward_ready!(service, MqttError::Service);
439    ntex_service::forward_shutdown!(service);
440
441    async fn call(
442        &self,
443        io: IoBoxed,
444        ctx: ServiceCtx<'_, Self>,
445    ) -> Result<Self::Response, Self::Error> {
446        log::trace!("Starting mqtt v3 handshake");
447
448        let (h, l) = self.max_send_size;
449        io.memory_pool().set_write_params(h, l);
450
451        let codec = mqtt::Codec::default();
452        codec.set_max_size(self.max_size);
453        codec.set_min_chunk_size(self.min_chunk_size);
454        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone()));
455
456        // read first packet
457        let packet = timeout_checked(self.connect_timeout, io.recv(&shared.codec))
458            .await
459            .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
460            .map_err(|err| {
461                log::trace!("Error is received during mqtt handshake: {:?}", err);
462                MqttError::Handshake(HandshakeError::from(err))
463            })?
464            .ok_or_else(|| {
465                log::trace!("Server mqtt is disconnected during handshake");
466                MqttError::Handshake(HandshakeError::Disconnected(None))
467            })?;
468
469        match packet {
470            mqtt::Decoded::Packet(mqtt::Packet::Connect(connect), size) => {
471                // authenticate mqtt connection
472                let ack = ctx
473                    .call(&self.service, Handshake::new(connect, size, io, shared))
474                    .await
475                    .map_err(MqttError::Service)?;
476
477                match ack.session {
478                    Some(session) => {
479                        let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
480                            session_present: ack.session_present,
481                            return_code: mqtt::ConnectAckReason::ConnectionAccepted,
482                        });
483
484                        log::trace!("Sending success handshake ack: {:#?}", pkt);
485
486                        ack.shared.set_cap(ack.max_send.unwrap_or(self.max_send) as usize);
487                        ack.io.encode(mqtt::Encoded::Packet(pkt.into()), &ack.shared.codec)?;
488                        Ok((
489                            ack.io,
490                            ack.shared.clone(),
491                            Session::new(session, MqttSink::new(ack.shared)),
492                            ack.keepalive,
493                        ))
494                    }
495                    None => {
496                        let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
497                            session_present: false,
498                            return_code: ack.return_code,
499                        });
500
501                        log::trace!("Sending failed handshake ack: {:#?}", pkt);
502                        ack.io.encode(mqtt::Encoded::Packet(pkt.into()), &ack.shared.codec)?;
503                        let _ = ack.io.shutdown().await;
504
505                        Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
506                    }
507                }
508            }
509            mqtt::Decoded::Packet(packet, _) => {
510                log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet);
511                Err(MqttError::Handshake(HandshakeError::Protocol(
512                    ProtocolError::unexpected_packet(
513                        packet.packet_type(),
514                        "MQTT-3.1.0-1: Expected CONNECT packet",
515                    ),
516                )))
517            }
518            mqtt::Decoded::Publish(..) => {
519                log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
520                Err(MqttError::Handshake(HandshakeError::Protocol(
521                    ProtocolError::unexpected_packet(
522                        crate::types::packet_type::PUBLISH_START,
523                        "Expected CONNECT packet [MQTT-3.1.0-1]",
524                    ),
525                )))
526            }
527            mqtt::Decoded::PayloadChunk(..) => unreachable!(),
528        }
529    }
530}