1use std::{fmt, io, marker, task::Context};
2
3use ntex_codec::{Decoder, Encoder};
4use ntex_io::{DispatchItem, Filter, Io, IoBoxed};
5use ntex_service::{Middleware, Service, ServiceCtx, ServiceFactory};
6use ntex_util::future::{join, select, Either};
7use ntex_util::time::{Deadline, Millis, Seconds};
8
9use crate::version::{ProtocolVersion, VersionCodec};
10use crate::{error::HandshakeError, error::MqttError, service};
11
/// MQTT server factory that keeps one handler factory per protocol version.
///
/// `V3` and `V5` are the factories for the MQTT 3 and MQTT 5 handlers. `Err` and
/// `InitErr` are not stored in any field; they are carried by the `PhantomData`
/// marker only.
pub struct MqttServer<V3, V5, Err, InitErr> {
    /// Factory for the handler used when the client announces `ProtocolVersion::MQTT3`.
    svc_v3: V3,
    /// Factory for the handler used when the client announces `ProtocolVersion::MQTT5`.
    svc_v5: V5,
    /// Timeout copied into every created `MqttServerImpl`, where it bounds the wait
    /// for the client's protocol version.
    connect_timeout: Millis,
    _t: marker::PhantomData<(Err, InitErr)>,
}
19
20impl<Err, InitErr>
21 MqttServer<
22 DefaultProtocolServer<Err, InitErr>,
23 DefaultProtocolServer<Err, InitErr>,
24 Err,
25 InitErr,
26 >
27{
28 pub fn new() -> Self {
30 MqttServer {
31 svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
32 svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
33 connect_timeout: Millis(5_000),
34 _t: marker::PhantomData,
35 }
36 }
37}
38
39impl<Err, InitErr> Default
40 for MqttServer<
41 DefaultProtocolServer<Err, InitErr>,
42 DefaultProtocolServer<Err, InitErr>,
43 Err,
44 InitErr,
45 >
46{
47 fn default() -> Self {
48 MqttServer::new()
49 }
50}
51
52impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr> {
53 pub fn protocol_version_timeout(mut self, timeout: Seconds) -> Self {
61 self.connect_timeout = timeout.into();
62 self
63 }
64}
65
66impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
67where
68 Err: fmt::Debug,
69 V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
70 V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
71{
72 pub fn v3<St, H, P, M, Codec>(
74 self,
75 service: service::MqttServer<St, H, P, M, Codec>,
76 ) -> MqttServer<
77 impl ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
78 V5,
79 Err,
80 InitErr,
81 >
82 where
83 St: 'static,
84 H: ServiceFactory<
85 IoBoxed,
86 Response = (IoBoxed, Codec, St, Seconds),
87 Error = MqttError<Err>,
88 InitError = InitErr,
89 > + 'static,
90 P: ServiceFactory<
91 DispatchItem<Codec>,
92 St,
93 Response = Option<<Codec as Encoder>::Item>,
94 Error = MqttError<Err>,
95 InitError = MqttError<Err>,
96 > + 'static,
97 M: Middleware<P::Service>,
98 M::Service: Service<
99 DispatchItem<Codec>,
100 Response = Option<<Codec as Encoder>::Item>,
101 Error = MqttError<Err>,
102 > + 'static,
103 Codec: Encoder + Decoder + Clone + 'static,
104 {
105 MqttServer {
106 svc_v3: service,
107 svc_v5: self.svc_v5,
108 connect_timeout: self.connect_timeout,
109 _t: marker::PhantomData,
110 }
111 }
112
113 pub fn v5<St, H, P, M, Codec>(
115 self,
116 service: service::MqttServer<St, H, P, M, Codec>,
117 ) -> MqttServer<
118 V3,
119 impl ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
120 Err,
121 InitErr,
122 >
123 where
124 St: 'static,
125 H: ServiceFactory<
126 IoBoxed,
127 Response = (IoBoxed, Codec, St, Seconds),
128 Error = MqttError<Err>,
129 InitError = InitErr,
130 > + 'static,
131 P: ServiceFactory<
132 DispatchItem<Codec>,
133 St,
134 Response = Option<<Codec as Encoder>::Item>,
135 Error = MqttError<Err>,
136 InitError = MqttError<Err>,
137 > + 'static,
138 M: Middleware<P::Service>,
139 M::Service: Service<
140 DispatchItem<Codec>,
141 Response = Option<<Codec as Encoder>::Item>,
142 Error = MqttError<Err>,
143 > + 'static,
144 Codec: Encoder + Decoder + Clone + 'static,
145 {
146 MqttServer {
147 svc_v3: self.svc_v3,
148 svc_v5: service,
149 connect_timeout: self.connect_timeout,
150 _t: marker::PhantomData,
151 }
152 }
153}
154
155impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
156where
157 V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
158 V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>,
159{
160 async fn create_service(
161 &self,
162 ) -> Result<MqttServerImpl<V3::Service, V5::Service, Err>, InitErr> {
163 let (v3, v5) = join(self.svc_v3.create(()), self.svc_v5.create(())).await;
164 let v3 = v3?;
165 let v5 = v5?;
166 Ok(MqttServerImpl {
167 handlers: (v3, v5),
168 connect_timeout: self.connect_timeout,
169 _t: marker::PhantomData,
170 })
171 }
172}
173
174impl<V3, V5, Err, InitErr> ServiceFactory<IoBoxed> for MqttServer<V3, V5, Err, InitErr>
175where
176 V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
177 + 'static,
178 V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
179 + 'static,
180 Err: 'static,
181 InitErr: 'static,
182{
183 type Response = ();
184 type Error = MqttError<Err>;
185 type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
186 type InitError = InitErr;
187
188 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
189 self.create_service().await
190 }
191}
192
193impl<F, V3, V5, Err, InitErr> ServiceFactory<Io<F>> for MqttServer<V3, V5, Err, InitErr>
194where
195 F: Filter,
196 V3: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
197 + 'static,
198 V5: ServiceFactory<IoBoxed, Response = (), Error = MqttError<Err>, InitError = InitErr>
199 + 'static,
200 Err: 'static,
201 InitErr: 'static,
202{
203 type Response = ();
204 type Error = MqttError<Err>;
205 type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
206 type InitError = InitErr;
207
208 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
209 self.create_service().await
210 }
211}
212
/// Runtime service that reads the protocol version from a connection and
/// forwards the connection to the matching handler.
pub struct MqttServerImpl<V3, V5, Err> {
    /// Handler services: `.0` receives MQTT 3 connections, `.1` receives MQTT 5.
    handlers: (V3, V5),
    /// Deadline for receiving the protocol version when it is not already buffered.
    connect_timeout: Millis,
    _t: marker::PhantomData<Err>,
}
219
220impl<V3, V5, Err> Service<IoBoxed> for MqttServerImpl<V3, V5, Err>
221where
222 V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
223 V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
224{
225 type Response = ();
226 type Error = MqttError<Err>;
227
228 #[inline]
229 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
230 let (ready1, ready2) =
231 join(ctx.ready(&self.handlers.0), ctx.ready(&self.handlers.1)).await;
232 ready1?;
233 ready2
234 }
235
236 #[inline]
237 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
238 self.handlers.0.poll(cx)?;
239 self.handlers.1.poll(cx)
240 }
241
242 #[inline]
243 async fn shutdown(&self) {
244 self.handlers.0.shutdown().await;
245 self.handlers.1.shutdown().await;
246 }
247
248 #[inline]
249 async fn call(
250 &self,
251 io: IoBoxed,
252 ctx: ServiceCtx<'_, Self>,
253 ) -> Result<Self::Response, Self::Error> {
254 let res = io
256 .decode(&VersionCodec)
257 .map_err(|e| MqttError::Handshake(HandshakeError::Protocol(e.into())))?;
258 if let Some(ver) = res {
259 match ver {
260 ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
261 ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
262 }
263 } else {
264 let fut = async {
265 match io.recv(&VersionCodec).await {
266 Ok(ver) => Ok(ver),
267 Err(Either::Left(e)) => {
268 Err(MqttError::Handshake(HandshakeError::Protocol(e.into())))
269 }
270 Err(Either::Right(e)) => {
271 Err(MqttError::Handshake(HandshakeError::Disconnected(Some(e))))
272 }
273 }
274 };
275
276 match select(&mut Deadline::new(self.connect_timeout), fut).await {
277 Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)),
278 Either::Right(Ok(Some(ver))) => match ver {
279 ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
280 ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
281 },
282 Either::Right(Ok(None)) => {
283 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
284 }
285 Either::Right(Err(e)) => Err(e),
286 }
287 }
288 }
289}
290
291impl<F, V3, V5, Err> Service<Io<F>> for MqttServerImpl<V3, V5, Err>
292where
293 F: Filter,
294 V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
295 V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
296{
297 type Response = ();
298 type Error = MqttError<Err>;
299
300 #[inline]
301 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
302 Service::<IoBoxed>::ready(self, ctx).await
303 }
304
305 #[inline]
306 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
307 Service::<IoBoxed>::poll(self, cx)
308 }
309
310 #[inline]
311 async fn shutdown(&self) {
312 Service::<IoBoxed>::shutdown(self).await
313 }
314
315 #[inline]
316 async fn call(
317 &self,
318 io: Io<F>,
319 ctx: ServiceCtx<'_, Self>,
320 ) -> Result<Self::Response, Self::Error> {
321 Service::<IoBoxed>::call(self, IoBoxed::from(io), ctx).await
322 }
323}
324
/// Placeholder handler used for a protocol version that has no configured
/// service; its `Service::call` always returns an error.
pub struct DefaultProtocolServer<Err, InitErr> {
    /// Protocol version this placeholder stands in for; used in the error message.
    ver: ProtocolVersion,
    _t: marker::PhantomData<(Err, InitErr)>,
}
329
330impl<Err, InitErr> DefaultProtocolServer<Err, InitErr> {
331 fn new(ver: ProtocolVersion) -> Self {
332 Self { ver, _t: marker::PhantomData }
333 }
334}
335
336impl<Err, InitErr> ServiceFactory<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
337 type Response = ();
338 type Error = MqttError<Err>;
339 type Service = DefaultProtocolServer<Err, InitErr>;
340 type InitError = InitErr;
341
342 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
343 Ok(DefaultProtocolServer { ver: self.ver, _t: marker::PhantomData })
344 }
345}
346
347impl<Err, InitErr> Service<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
348 type Response = ();
349 type Error = MqttError<Err>;
350
351 async fn call(
352 &self,
353 _: IoBoxed,
354 _: ServiceCtx<'_, Self>,
355 ) -> Result<Self::Response, Self::Error> {
356 Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new(
357 io::ErrorKind::Other,
358 format!("Protocol is not supported: {:?}", self.ver),
359 )))))
360 }
361}