1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, DispatcherConfig, IoBoxed};
4use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
5use ntex_util::time::{timeout_checked, Millis, Seconds};
6
7use crate::error::{HandshakeError, MqttError, ProtocolError};
8use crate::{service, types::QoS, InFlightService};
9
10use super::control::{Control, ControlAck};
11use super::default::{DefaultControlService, DefaultPublishService};
12use super::handshake::{Handshake, HandshakeAck};
13use super::shared::{MqttShared, MqttSinkPool};
14use super::{codec as mqtt, dispatcher::factory, MqttSink, Publish, Session};
15
16pub struct MqttServer<St, H, C, P, M = Identity> {
44 handshake: H,
45 control: C,
46 publish: P,
47 middleware: M,
48 max_qos: QoS,
49 max_size: u32,
50 max_send: u16,
51 max_send_size: (u32, u32),
52 min_chunk_size: u32,
53 handle_qos_after_disconnect: Option<QoS>,
54 connect_timeout: Seconds,
55 config: DispatcherConfig,
56 pub(super) pool: Rc<MqttSinkPool>,
57 _t: PhantomData<St>,
58}
59
60impl<St, H>
61 MqttServer<
62 St,
63 H,
64 DefaultControlService<St, H::Error>,
65 DefaultPublishService<St, H::Error>,
66 InFlightService,
67 >
68where
69 St: 'static,
70 H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
71 H::Error: fmt::Debug,
72{
73 pub fn new<F>(handshake: F) -> Self
75 where
76 F: IntoServiceFactory<H, Handshake>,
77 {
78 let config = DispatcherConfig::default();
79 config.set_disconnect_timeout(Seconds(3));
80
81 MqttServer {
82 config,
83 handshake: handshake.into_factory(),
84 control: DefaultControlService::default(),
85 publish: DefaultPublishService::default(),
86 middleware: InFlightService::new(16, 65535),
87 max_qos: QoS::AtLeastOnce,
88 max_size: 0,
89 max_send: 16,
90 max_send_size: (65535, 512),
91 min_chunk_size: 32 * 1024,
92 handle_qos_after_disconnect: None,
93 connect_timeout: Seconds::ZERO,
94 pool: Default::default(),
95 _t: PhantomData,
96 }
97 }
98}
99
100impl<St, H, C, P> MqttServer<St, H, C, P, InFlightService> {
101 pub fn max_receive(mut self, val: u16) -> Self {
105 self.middleware = self.middleware.max_receive(val);
106 self
107 }
108
109 pub fn max_receive_size(mut self, val: usize) -> Self {
113 self.middleware = self.middleware.max_receive_size(val);
114 self
115 }
116}
117
118impl<St, H, C, P, M> MqttServer<St, H, C, P, M>
119where
120 St: 'static,
121 H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
122 C: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
123 P: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
124 H::Error:
125 From<C::Error> + From<C::InitError> + From<P::Error> + From<P::InitError> + fmt::Debug,
126{
127 pub fn connect_timeout(mut self, timeout: Seconds) -> Self {
135 self.connect_timeout = timeout;
136 self
137 }
138
139 pub fn disconnect_timeout(self, val: Seconds) -> Self {
148 self.config.set_disconnect_timeout(val);
149 self
150 }
151
152 pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
160 self.config.set_frame_read_rate(timeout, max_timeout, rate);
161 self
162 }
163
164 pub fn max_qos(mut self, qos: QoS) -> Self {
169 self.max_qos = qos;
170 self
171 }
172
173 pub fn max_size(mut self, size: u32) -> Self {
178 self.max_size = size;
179 self
180 }
181
182 pub fn max_send(mut self, val: u16) -> Self {
186 self.max_send = val;
187 self
188 }
189
190 pub fn max_send_size(mut self, val: u32) -> Self {
194 self.max_send_size = (val, val / 10);
195 self
196 }
197
198 pub fn min_chunk_size(mut self, size: u32) -> Self {
205 self.min_chunk_size = size;
206 self
207 }
208
209 pub fn handle_qos_after_disconnect(mut self, max_handle_qos: Option<QoS>) -> Self {
231 self.handle_qos_after_disconnect = max_handle_qos;
232 self
233 }
234
235 pub fn control<F, Srv>(self, service: F) -> MqttServer<St, H, Srv, P, M>
240 where
241 F: IntoServiceFactory<Srv, Control<H::Error>, Session<St>>,
242 Srv: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
243 H::Error: From<Srv::Error> + From<Srv::InitError>,
244 {
245 MqttServer {
246 handshake: self.handshake,
247 publish: self.publish,
248 control: service.into_factory(),
249 config: self.config,
250 middleware: self.middleware,
251 max_qos: self.max_qos,
252 max_size: self.max_size,
253 max_send: self.max_send,
254 max_send_size: self.max_send_size,
255 min_chunk_size: self.min_chunk_size,
256 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
257 connect_timeout: self.connect_timeout,
258 pool: self.pool,
259 _t: PhantomData,
260 }
261 }
262
263 pub fn publish<F, Srv>(self, publish: F) -> MqttServer<St, H, C, Srv, M>
265 where
266 F: IntoServiceFactory<Srv, Publish, Session<St>>,
267 Srv: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
268 H::Error: From<Srv::Error> + From<Srv::InitError> + fmt::Debug,
269 {
270 MqttServer {
271 handshake: self.handshake,
272 publish: publish.into_factory(),
273 control: self.control,
274 config: self.config,
275 middleware: self.middleware,
276 max_qos: self.max_qos,
277 max_size: self.max_size,
278 max_send: self.max_send,
279 max_send_size: self.max_send_size,
280 min_chunk_size: self.min_chunk_size,
281 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
282 connect_timeout: self.connect_timeout,
283 pool: self.pool,
284 _t: PhantomData,
285 }
286 }
287
288 pub fn reset_middlewares(self) -> MqttServer<St, H, C, P, Identity> {
290 MqttServer {
291 middleware: Identity,
292 handshake: self.handshake,
293 publish: self.publish,
294 control: self.control,
295 config: self.config,
296 max_qos: self.max_qos,
297 max_size: self.max_size,
298 max_send: self.max_send,
299 max_send_size: self.max_send_size,
300 min_chunk_size: self.min_chunk_size,
301 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
302 connect_timeout: self.connect_timeout,
303 pool: self.pool,
304 _t: PhantomData,
305 }
306 }
307
308 pub fn middleware<U>(self, mw: U) -> MqttServer<St, H, C, P, Stack<M, U>> {
316 MqttServer {
317 middleware: Stack::new(self.middleware, mw),
318 handshake: self.handshake,
319 publish: self.publish,
320 control: self.control,
321 config: self.config,
322 max_qos: self.max_qos,
323 max_size: self.max_size,
324 max_send: self.max_send,
325 max_send_size: self.max_send_size,
326 min_chunk_size: self.min_chunk_size,
327 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
328 connect_timeout: self.connect_timeout,
329 pool: self.pool,
330 _t: PhantomData,
331 }
332 }
333}
334
335impl<St, H, C, P, M> MqttServer<St, H, C, P, M>
336where
337 St: 'static,
338 H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
339 C: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
340 P: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
341 H::Error:
342 From<C::Error> + From<C::InitError> + From<P::Error> + From<P::InitError> + fmt::Debug,
343{
344 pub fn finish(
346 self,
347 ) -> service::MqttServer<
348 Session<St>,
349 impl ServiceFactory<
350 IoBoxed,
351 Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
352 Error = MqttError<H::Error>,
353 InitError = H::InitError,
354 >,
355 impl ServiceFactory<
356 DispatchItem<Rc<MqttShared>>,
357 Session<St>,
358 Response = Option<mqtt::Encoded>,
359 Error = MqttError<H::Error>,
360 InitError = MqttError<H::Error>,
361 >,
362 M,
363 Rc<MqttShared>,
364 > {
365 service::MqttServer::new(
366 HandshakeFactory {
367 factory: self.handshake,
368 max_size: self.max_size,
369 max_send: self.max_send,
370 max_send_size: self.max_send_size,
371 min_chunk_size: self.min_chunk_size,
372 connect_timeout: self.connect_timeout,
373 pool: self.pool.clone(),
374 _t: PhantomData,
375 },
376 factory(self.publish, self.control, self.max_qos, self.handle_qos_after_disconnect),
377 self.middleware,
378 self.config,
379 )
380 }
381}
382
383struct HandshakeFactory<St, H> {
384 factory: H,
385 max_size: u32,
386 max_send: u16,
387 max_send_size: (u32, u32),
388 min_chunk_size: u32,
389 connect_timeout: Seconds,
390 pool: Rc<MqttSinkPool>,
391 _t: PhantomData<St>,
392}
393
394impl<St, H> ServiceFactory<IoBoxed> for HandshakeFactory<St, H>
395where
396 H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
397 H::Error: fmt::Debug,
398{
399 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
400 type Error = MqttError<H::Error>;
401
402 type Service = HandshakeService<St, H::Service>;
403 type InitError = H::InitError;
404
405 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
406 Ok(HandshakeService {
407 max_size: self.max_size,
408 max_send: self.max_send,
409 max_send_size: self.max_send_size,
410 min_chunk_size: self.min_chunk_size,
411 pool: self.pool.clone(),
412 service: self.factory.create(()).await?,
413 connect_timeout: self.connect_timeout.into(),
414 _t: PhantomData,
415 })
416 }
417}
418
419struct HandshakeService<St, H> {
420 service: H,
421 max_size: u32,
422 max_send: u16,
423 max_send_size: (u32, u32),
424 min_chunk_size: u32,
425 pool: Rc<MqttSinkPool>,
426 connect_timeout: Millis,
427 _t: PhantomData<St>,
428}
429
430impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
431where
432 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
433 H::Error: fmt::Debug,
434{
435 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
436 type Error = MqttError<H::Error>;
437
438 ntex_service::forward_ready!(service, MqttError::Service);
439 ntex_service::forward_shutdown!(service);
440
441 async fn call(
442 &self,
443 io: IoBoxed,
444 ctx: ServiceCtx<'_, Self>,
445 ) -> Result<Self::Response, Self::Error> {
446 log::trace!("Starting mqtt v3 handshake");
447
448 let (h, l) = self.max_send_size;
449 io.memory_pool().set_write_params(h, l);
450
451 let codec = mqtt::Codec::default();
452 codec.set_max_size(self.max_size);
453 codec.set_min_chunk_size(self.min_chunk_size);
454 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone()));
455
456 let packet = timeout_checked(self.connect_timeout, io.recv(&shared.codec))
458 .await
459 .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
460 .map_err(|err| {
461 log::trace!("Error is received during mqtt handshake: {:?}", err);
462 MqttError::Handshake(HandshakeError::from(err))
463 })?
464 .ok_or_else(|| {
465 log::trace!("Server mqtt is disconnected during handshake");
466 MqttError::Handshake(HandshakeError::Disconnected(None))
467 })?;
468
469 match packet {
470 mqtt::Decoded::Packet(mqtt::Packet::Connect(connect), size) => {
471 let ack = ctx
473 .call(&self.service, Handshake::new(connect, size, io, shared))
474 .await
475 .map_err(MqttError::Service)?;
476
477 match ack.session {
478 Some(session) => {
479 let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
480 session_present: ack.session_present,
481 return_code: mqtt::ConnectAckReason::ConnectionAccepted,
482 });
483
484 log::trace!("Sending success handshake ack: {:#?}", pkt);
485
486 ack.shared.set_cap(ack.max_send.unwrap_or(self.max_send) as usize);
487 ack.io.encode(mqtt::Encoded::Packet(pkt.into()), &ack.shared.codec)?;
488 Ok((
489 ack.io,
490 ack.shared.clone(),
491 Session::new(session, MqttSink::new(ack.shared)),
492 ack.keepalive,
493 ))
494 }
495 None => {
496 let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
497 session_present: false,
498 return_code: ack.return_code,
499 });
500
501 log::trace!("Sending failed handshake ack: {:#?}", pkt);
502 ack.io.encode(mqtt::Encoded::Packet(pkt.into()), &ack.shared.codec)?;
503 let _ = ack.io.shutdown().await;
504
505 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
506 }
507 }
508 }
509 mqtt::Decoded::Packet(packet, _) => {
510 log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet);
511 Err(MqttError::Handshake(HandshakeError::Protocol(
512 ProtocolError::unexpected_packet(
513 packet.packet_type(),
514 "MQTT-3.1.0-1: Expected CONNECT packet",
515 ),
516 )))
517 }
518 mqtt::Decoded::Publish(..) => {
519 log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
520 Err(MqttError::Handshake(HandshakeError::Protocol(
521 ProtocolError::unexpected_packet(
522 crate::types::packet_type::PUBLISH_START,
523 "Expected CONNECT packet [MQTT-3.1.0-1]",
524 ),
525 )))
526 }
527 mqtt::Decoded::PayloadChunk(..) => unreachable!(),
528 }
529 }
530}