// ntex_mqtt/server.rs

1use std::{fmt, io, marker, task::Context};
2
3use ntex_codec::{Decoder, Encoder};
4use ntex_io::{DispatchItem, Filter, Io, IoBoxed};
5use ntex_service::{Middleware, Service, ServiceCtx, ServiceFactory};
6use ntex_util::future::{join, select, Either};
7use ntex_util::time::{Deadline, Millis, Seconds};
8
9use crate::version::{ProtocolVersion, VersionCodec};
10use crate::{error::HandshakeError, error::MqttError, service};
11
12/// Mqtt Server
13pub struct MqttServer<V3, V5, Err, InitErr> {
14    svc_v3: V3,
15    svc_v5: V5,
16    connect_timeout: Millis,
17    _t: marker::PhantomData<(Err, InitErr)>,
18}
19
20impl<Err, InitErr>
21    MqttServer<
22        DefaultProtocolServer<Err, InitErr>,
23        DefaultProtocolServer<Err, InitErr>,
24        Err,
25        InitErr,
26    >
27{
28    /// Create mqtt server
29    pub fn new() -> Self {
30        MqttServer {
31            svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
32            svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
33            connect_timeout: Millis(5_000),
34            _t: marker::PhantomData,
35        }
36    }
37}
38
39impl<Err, InitErr> Default
40    for MqttServer<
41        DefaultProtocolServer<Err, InitErr>,
42        DefaultProtocolServer<Err, InitErr>,
43        Err,
44        InitErr,
45    >
46{
47    fn default() -> Self {
48        MqttServer::new()
49    }
50}
51
52impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr> {
53    /// Set client timeout reading protocol version.
54    ///
55    /// Defines a timeout for reading protocol version. If a client does not transmit
56    /// version of the protocol within this time, the connection is terminated with
57    /// Mqtt::Handshake(HandshakeError::Timeout) error.
58    ///
59    /// By default, timeuot is 5 seconds.
60    pub fn protocol_version_timeout(mut self, timeout: Seconds) -> Self {
61        self.connect_timeout = timeout.into();
62        self
63    }
64}
65
66impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
67where
68    Err: fmt::Debug,
69    V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
70    V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
71{
72    /// Service to handle v3 protocol
73    pub fn v3<St, H, P, M, Codec>(
74        self,
75        service: service::MqttServer<St, H, P, M, Codec>,
76    ) -> MqttServer<
77        impl ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
78        V5,
79        Err,
80        InitErr,
81    >
82    where
83        St: 'static,
84        H: ServiceFactory<
85                IoBoxed,
86                Response = (IoBoxed, Codec, St, Seconds),
87                Error = MqttError<Err>,
88                InitError = InitErr,
89            > + 'static,
90        P: ServiceFactory<
91                DispatchItem<Codec>,
92                St,
93                Response = Option<<Codec as Encoder>::Item>,
94                Error = MqttError<Err>,
95                InitError = MqttError<Err>,
96            > + 'static,
97        M: Middleware<P::Service>,
98        M::Service: Service<
99                DispatchItem<Codec>,
100                Response = Option<<Codec as Encoder>::Item>,
101                Error = MqttError<Err>,
102            > + 'static,
103        Codec: Encoder + Decoder + Clone + 'static,
104    {
105        MqttServer {
106            svc_v3: service,
107            svc_v5: self.svc_v5,
108            connect_timeout: self.connect_timeout,
109            _t: marker::PhantomData,
110        }
111    }
112
113    /// Service to handle v5 protocol
114    pub fn v5<St, H, P, M, Codec>(
115        self,
116        service: service::MqttServer<St, H, P, M, Codec>,
117    ) -> MqttServer<
118        V3,
119        impl ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
120        Err,
121        InitErr,
122    >
123    where
124        St: 'static,
125        H: ServiceFactory<
126                IoBoxed,
127                Response = (IoBoxed, Codec, St, Seconds),
128                Error = MqttError<Err>,
129                InitError = InitErr,
130            > + 'static,
131        P: ServiceFactory<
132                DispatchItem<Codec>,
133                St,
134                Response = Option<<Codec as Encoder>::Item>,
135                Error = MqttError<Err>,
136                InitError = MqttError<Err>,
137            > + 'static,
138        M: Middleware<P::Service>,
139        M::Service: Service<
140                DispatchItem<Codec>,
141                Response = Option<<Codec as Encoder>::Item>,
142                Error = MqttError<Err>,
143            > + 'static,
144        Codec: Encoder + Decoder + Clone + 'static,
145    {
146        MqttServer {
147            svc_v3: self.svc_v3,
148            svc_v5: service,
149            connect_timeout: self.connect_timeout,
150            _t: marker::PhantomData,
151        }
152    }
153}
154
155impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
156where
157    V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
158    V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
159{
160    async fn create_service(
161        &self,
162    ) -> Result<MqttServerImpl<V3::Service, V5::Service, Err>, InitErr> {
163        let (v3, v5) = join(self.svc_v3.create(()), self.svc_v5.create(())).await;
164        let v3 = v3?;
165        let v5 = v5?;
166        Ok(MqttServerImpl {
167            handlers: (v3, v5),
168            connect_timeout: self.connect_timeout,
169            _t: marker::PhantomData,
170        })
171    }
172}
173
174impl<V3, V5, Err, InitErr> ServiceFactory<IoBoxed> for MqttServer<V3, V5, Err, InitErr>
175where
176    V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
177        + 'static,
178    V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
179        + 'static,
180    Err: 'static,
181    InitErr: 'static,
182{
183    type Response = ();
184    type Error = MqttError<Err>;
185    type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
186    type InitError = InitErr;
187
188    async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
189        self.create_service().await
190    }
191}
192
193impl<F, V3, V5, Err, InitErr> ServiceFactory<Io<F>> for MqttServer<V3, V5, Err, InitErr>
194where
195    F: Filter,
196    V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
197        + 'static,
198    V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
199        + 'static,
200    Err: 'static,
201    InitErr: 'static,
202{
203    type Response = ();
204    type Error = MqttError<Err>;
205    type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
206    type InitError = InitErr;
207
208    async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
209        self.create_service().await
210    }
211}
212
213/// Mqtt Server
214pub struct MqttServerImpl<V3, V5, Err> {
215    handlers: (V3, V5),
216    connect_timeout: Millis,
217    _t: marker::PhantomData<Err>,
218}
219
220impl<V3, V5, Err> Service<IoBoxed> for MqttServerImpl<V3, V5, Err>
221where
222    V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
223    V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
224{
225    type Response = ();
226    type Error = MqttError<Err>;
227
228    #[inline]
229    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
230        let (ready1, ready2) =
231            join(ctx.ready(&self.handlers.0), ctx.ready(&self.handlers.1)).await;
232        ready1?;
233        ready2
234    }
235
236    #[inline]
237    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
238        self.handlers.0.poll(cx)?;
239        self.handlers.1.poll(cx)
240    }
241
242    #[inline]
243    async fn shutdown(&self) {
244        self.handlers.0.shutdown().await;
245        self.handlers.1.shutdown().await;
246    }
247
248    #[inline]
249    async fn call(
250        &self,
251        io: IoBoxed,
252        ctx: ServiceCtx<'_, Self>,
253    ) -> Result<Self::Response, Self::Error> {
254        // try to read Version, buffer may already contain info
255        let res = io
256            .decode(&VersionCodec)
257            .map_err(|e| MqttError::Handshake(HandshakeError::Protocol(e.into())))?;
258        if let Some(ver) = res {
259            match ver {
260                ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
261                ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
262            }
263        } else {
264            let fut = async {
265                match io.recv(&VersionCodec).await {
266                    Ok(ver) => Ok(ver),
267                    Err(Either::Left(e)) => {
268                        Err(MqttError::Handshake(HandshakeError::Protocol(e.into())))
269                    }
270                    Err(Either::Right(e)) => {
271                        Err(MqttError::Handshake(HandshakeError::Disconnected(Some(e))))
272                    }
273                }
274            };
275
276            match select(&mut Deadline::new(self.connect_timeout), fut).await {
277                Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)),
278                Either::Right(Ok(Some(ver))) => match ver {
279                    ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
280                    ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
281                },
282                Either::Right(Ok(None)) => {
283                    Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
284                }
285                Either::Right(Err(e)) => Err(e),
286            }
287        }
288    }
289}
290
291impl<F, V3, V5, Err> Service<Io<F>> for MqttServerImpl<V3, V5, Err>
292where
293    F: Filter,
294    V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
295    V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
296{
297    type Response = ();
298    type Error = MqttError<Err>;
299
300    #[inline]
301    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
302        Service::<IoBoxed>::ready(self, ctx).await
303    }
304
305    #[inline]
306    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
307        Service::<IoBoxed>::poll(self, cx)
308    }
309
310    #[inline]
311    async fn shutdown(&self) {
312        Service::<IoBoxed>::shutdown(self).await
313    }
314
315    #[inline]
316    async fn call(
317        &self,
318        io: Io<F>,
319        ctx: ServiceCtx<'_, Self>,
320    ) -> Result<Self::Response, Self::Error> {
321        Service::<IoBoxed>::call(self, IoBoxed::from(io), ctx).await
322    }
323}
324
325pub struct DefaultProtocolServer<Err, InitErr> {
326    ver: ProtocolVersion,
327    _t: marker::PhantomData<(Err, InitErr)>,
328}
329
330impl<Err, InitErr> DefaultProtocolServer<Err, InitErr> {
331    fn new(ver: ProtocolVersion) -> Self {
332        Self { ver, _t: marker::PhantomData }
333    }
334}
335
336impl<Err, InitErr> ServiceFactory<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
337    type Response = ();
338    type Error = MqttError<Err>;
339    type Service = DefaultProtocolServer<Err, InitErr>;
340    type InitError = InitErr;
341
342    async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
343        Ok(DefaultProtocolServer { ver: self.ver, _t: marker::PhantomData })
344    }
345}
346
347impl<Err, InitErr> Service<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
348    type Response = ();
349    type Error = MqttError<Err>;
350
351    async fn call(
352        &self,
353        _: IoBoxed,
354        _: ServiceCtx<'_, Self>,
355    ) -> Result<Self::Response, Self::Error> {
356        Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new(
357            io::ErrorKind::Other,
358            format!("Protocol is not supported: {:?}", self.ver),
359        )))))
360    }
361}