use std::rc::Rc;
use ntex::connect::{self, Address, Connect, Connector};
use ntex::io::{DispatcherConfig, IoBoxed};
use ntex::service::{IntoService, Pipeline, Service};
use ntex::time::{timeout_checked, Seconds};
use ntex::util::{ByteString, Bytes, PoolId};
use super::{codec, connection::Client, error::ClientError, error::ProtocolError};
use crate::v3::shared::{MqttShared, MqttSinkPool};
pub struct MqttConnector<A, T> {
address: A,
connector: Pipeline<T>,
pkt: codec::Connect,
max_send: usize,
max_receive: usize,
max_packet_size: u32,
handshake_timeout: Seconds,
config: DispatcherConfig,
pool: Rc<MqttSinkPool>,
}
impl<A> MqttConnector<A, ()>
where
A: Address + Clone,
{
#[allow(clippy::new_ret_no_self)]
pub fn new(address: A) -> MqttConnector<A, Connector<A>> {
let config = DispatcherConfig::default();
config.set_disconnect_timeout(Seconds(3)).set_keepalive_timeout(Seconds(0));
MqttConnector {
address,
config,
pkt: codec::Connect::default(),
connector: Pipeline::new(Connector::default()),
max_send: 16,
max_receive: 16,
max_packet_size: 64 * 1024,
handshake_timeout: Seconds::ZERO,
pool: Rc::new(MqttSinkPool::default()),
}
}
}
impl<A, T> MqttConnector<A, T>
where
A: Address + Clone,
{
#[inline]
pub fn client_id<U>(mut self, client_id: U) -> Self
where
ByteString: From<U>,
{
self.pkt.client_id = client_id.into();
self
}
#[inline]
pub fn clean_session(mut self) -> Self {
self.pkt.clean_session = true;
self
}
#[inline]
pub fn keep_alive(mut self, val: Seconds) -> Self {
self.pkt.keep_alive = val.seconds() as u16;
self
}
#[inline]
pub fn last_will(mut self, val: codec::LastWill) -> Self {
self.pkt.last_will = Some(val);
self
}
#[inline]
pub fn username<U>(mut self, val: U) -> Self
where
ByteString: From<U>,
{
self.pkt.username = Some(val.into());
self
}
#[inline]
pub fn password(mut self, val: Bytes) -> Self {
self.pkt.password = Some(val);
self
}
#[inline]
pub fn max_send(mut self, val: u16) -> Self {
self.max_send = val as usize;
self
}
#[inline]
pub fn max_receive(mut self, val: u16) -> Self {
self.max_receive = val as usize;
self
}
#[inline]
pub fn max_packet_size(mut self, val: u32) -> Self {
self.max_packet_size = val;
self
}
#[inline]
pub fn packet<F>(mut self, f: F) -> Self
where
F: FnOnce(&mut codec::Connect),
{
f(&mut self.pkt);
self
}
pub fn handshake_timeout(mut self, timeout: Seconds) -> Self {
self.handshake_timeout = timeout;
self
}
pub fn disconnect_timeout(self, timeout: Seconds) -> Self {
self.config.set_disconnect_timeout(timeout);
self
}
pub fn memory_pool(self, id: PoolId) -> Self {
self.pool.pool.set(id.pool_ref());
self
}
pub fn connector<U, F>(self, connector: F) -> MqttConnector<A, U>
where
F: IntoService<U, Connect<A>>,
U: Service<Connect<A>, Error = connect::ConnectError>,
IoBoxed: From<U::Response>,
{
MqttConnector {
connector: Pipeline::new(connector.into_service()),
pkt: self.pkt,
address: self.address,
config: self.config,
max_send: self.max_send,
max_receive: self.max_receive,
max_packet_size: self.max_packet_size,
handshake_timeout: self.handshake_timeout,
pool: self.pool,
}
}
}
impl<A, T> MqttConnector<A, T>
where
A: Address + Clone,
T: Service<Connect<A>, Error = connect::ConnectError>,
IoBoxed: From<T::Response>,
{
pub async fn connect(&self) -> Result<Client, ClientError<codec::ConnectAck>> {
match timeout_checked(self.handshake_timeout, self._connect()).await {
Ok(res) => res.map_err(From::from),
Err(_) => Err(ClientError::HandshakeTimeout),
}
}
async fn _connect(&self) -> Result<Client, ClientError<codec::ConnectAck>> {
let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into();
let pkt = self.pkt.clone();
let max_send = self.max_send;
let max_receive = self.max_receive;
let keepalive_timeout = pkt.keep_alive;
let config = self.config.clone();
let pool = self.pool.clone();
let codec = codec::Codec::new();
codec.set_max_size(self.max_packet_size);
io.encode(pkt.into(), &codec)?;
let packet = io.recv(&codec).await.map_err(ClientError::from)?.ok_or_else(|| {
log::trace!("Mqtt server is disconnected during handshake");
ClientError::Disconnected(None)
})?;
let shared = Rc::new(MqttShared::new(io.get_ref(), codec, true, pool));
match packet {
(codec::Packet::ConnectAck(pkt), _) => {
log::trace!("Connect ack response from server: session: present: {:?}, return code: {:?}", pkt.session_present, pkt.return_code);
if pkt.return_code == codec::ConnectAckReason::ConnectionAccepted {
shared.set_cap(max_send);
Ok(Client::new(
io,
shared,
pkt.session_present,
Seconds(keepalive_timeout),
max_receive,
config,
))
} else {
Err(ClientError::Ack(pkt))
}
}
(p, _) => Err(ProtocolError::unexpected_packet(
p.packet_type(),
"Expected CONNACK packet",
)
.into()),
}
}
}