use crate::notice::NoticeTx;
use crate::NoticeError;
use crate::{framed::Network, Transport};
use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError};
use crate::{MqttOptions, Outgoing};
use crate::framed::AsyncReadWrite;
use crate::mqttbytes::v4::*;
use flume::{bounded, Receiver, Sender};
use tokio::net::{lookup_host, TcpSocket, TcpStream};
use tokio::select;
use tokio::time::{self, Instant, Sleep};
use std::collections::VecDeque;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::time::Duration;
#[cfg(unix)]
use {std::path::Path, tokio::net::UnixStream};
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
use crate::tls;
#[cfg(feature = "websocket")]
use {
crate::websockets::{split_url, validate_response_headers, UrlError},
async_tungstenite::tungstenite::client::IntoClientRequest,
ws_stream_tungstenite::WsStream,
};
#[cfg(feature = "proxy")]
use crate::proxy::ProxyError;
/// Errors that can be returned while connecting to the broker or while
/// polling the event loop.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The MQTT state machine returned an error while handling a packet.
    #[error("Mqtt state: {0}")]
    MqttState(#[from] StateError),
    /// Establishing the connection did not finish within the configured
    /// connection timeout.
    #[error("Network timeout")]
    NetworkTimeout,
    /// Flushing the network did not finish within the configured timeout.
    #[error("Flush timeout")]
    FlushTimeout,
    /// Error reported by the websocket layer.
    #[cfg(feature = "websocket")]
    #[error("Websocket: {0}")]
    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
    /// Error while building the websocket connect request.
    #[cfg(feature = "websocket")]
    #[error("Websocket Connect: {0}")]
    WsConnect(#[from] http::Error),
    /// Error reported by the TLS layer.
    #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
    #[error("TLS: {0}")]
    Tls(#[from] tls::Error),
    /// Underlying I/O error.
    #[error("I/O: {0}")]
    Io(#[from] io::Error),
    /// The broker answered the connect request with a non-success return code.
    #[error("Connection refused, return code: `{0:?}`")]
    ConnectionRefused(ConnectReturnCode),
    /// The first packet received after sending the connect request was not a
    /// ConnAck.
    #[error("Expected ConnAck packet, received: {0:?}")]
    NotConnAck(Packet),
    /// Receiving from the request channel failed, so no further requests can
    /// arrive.
    #[error("Requests done")]
    RequestsDone,
    /// The broker address could not be split into host and port for a
    /// websocket transport.
    #[cfg(feature = "websocket")]
    #[error("Invalid Url: {0}")]
    InvalidUrl(#[from] UrlError),
    /// Connecting through the configured proxy failed.
    #[cfg(feature = "proxy")]
    #[error("Proxy Connect: {0}")]
    Proxy(#[from] ProxyError),
    /// The websocket handshake response failed header validation.
    // The message used to end in a dangling ": " without ever printing the
    // wrapped error; `{0}` now includes it.
    #[cfg(feature = "websocket")]
    #[error("Websocket response validation error: {0}")]
    ResponseValidation(#[from] crate::websockets::ValidationError),
}
/// Event loop state: the connection options, the MQTT protocol state, the
/// request channel, the backlog of pending requests and the current network
/// connection (if any).
pub struct EventLoop {
/// Options used when connecting to the broker.
pub mqtt_options: MqttOptions,
/// Current MQTT protocol state.
pub state: MqttState,
/// Receiving half of the request channel; each request is paired with the
/// `NoticeTx` that is handed to the state together with it.
requests_rx: Receiver<(NoticeTx, Request)>,
/// Sending half of the request channel created in `new`.
pub(crate) requests_tx: Sender<(NoticeTx, Request)>,
/// Requests that are served before anything new is taken from the channel.
pub pending: VecDeque<(NoticeTx, Request)>,
/// Active network connection; `None` while disconnected.
pub network: Option<Network>,
/// Timer whose expiry makes the event loop send a `PingReq`; `None` while no
/// keep-alive timer is armed.
keepalive_timeout: Option<Pin<Box<Sleep>>>,
/// Network-level options (supplies the connection timeout, among others).
pub network_options: NetworkOptions,
}
/// Event yielded by the event loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
/// Wraps an `Incoming` value; `poll` yields the broker's ConnAck this way.
Incoming(Incoming),
/// Wraps an `Outgoing` value.
Outgoing(Outgoing),
}
impl EventLoop {
pub fn new(mqtt_options: MqttOptions, cap: usize) -> EventLoop {
let (requests_tx, requests_rx) = bounded(cap);
let pending = VecDeque::new();
let max_inflight = mqtt_options.inflight;
let manual_acks = mqtt_options.manual_acks;
EventLoop {
mqtt_options,
state: MqttState::new(max_inflight, manual_acks),
requests_tx,
requests_rx,
pending,
network: None,
keepalive_timeout: None,
network_options: NetworkOptions::new(),
}
}
pub fn clean(&mut self) {
self.network = None;
self.keepalive_timeout = None;
self.pending.extend(self.state.clean());
let requests_in_channel = self.requests_rx.drain();
self.pending.extend(requests_in_channel);
}
pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
if self.network.is_none() {
let (network, connack) = match time::timeout(
Duration::from_secs(self.network_options.connection_timeout()),
connect(&self.mqtt_options, self.network_options.clone()),
)
.await
{
Ok(inner) => inner?,
Err(_) => return Err(ConnectionError::NetworkTimeout),
};
if !connack.session_present {
for (tx, request) in self.pending.drain(..) {
if let Request::Publish(_) = request {
tx.error(NoticeError::SessionReset)
}
}
}
self.network = Some(network);
if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() {
self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive)));
}
return Ok(Event::Incoming(Packet::ConnAck(connack)));
}
match self.select().await {
Ok(v) => Ok(v),
Err(e) => {
self.clean();
Err(e)
}
}
}
/// Returns the next event, waiting on the network, the request sources and the
/// keep-alive timer when no event is already queued.
///
/// Events already sitting in `self.state.events` are returned first, without
/// touching the network. Otherwise whichever enabled branch of the `select!`
/// completes first is handled, the network is flushed (bounded by the
/// connection timeout) and the event at the front of `self.state.events` is
/// returned.
///
/// # Errors
///
/// Returns `ConnectionError::FlushTimeout` when a flush exceeds the timeout,
/// `ConnectionError::RequestsDone` when no further request can be obtained,
/// and propagates network and state errors.
///
/// # Panics
///
/// Panics if `self.network` is `None`, or if a branch finishes without an
/// event being present in `self.state.events`.
async fn select(&mut self) -> Result<Event, ConnectionError> {
let network = self.network.as_mut().unwrap();
// Flow-control snapshot used to gate the request branch below.
let inflight_full = self.state.inflight >= self.mqtt_options.inflight;
let collision = self.state.collision.is_some();
// The connection timeout also bounds every flush in this function.
let network_timeout = Duration::from_secs(self.network_options.connection_timeout());
if let Some(event) = self.state.events.pop_front() {
return Ok(event);
}
// Stand-in future for the keep-alive branch when no timer is armed: the branch
// expression still needs a future to name even though its guard disables it.
let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
select! {
// Network became readable: let `readb` process the input against the state,
// then flush whatever is waiting to be written.
o = network.readb(&mut self.state) => {
o?;
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Err(_)=> return Err(ConnectionError::FlushTimeout),
};
Ok(self.state.events.pop_front().unwrap())
},
// Next request to send. Enabled when there is a backlog in `pending`, or when
// the inflight window is not full and there is no collision.
o = Self::next_request(
&mut self.pending,
&self.requests_rx,
self.mqtt_options.pending_throttle
), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
Ok((tx, request)) => {
if let Some(outgoing) = self.state.handle_outgoing_packet(tx, request)? {
network.write(outgoing).await?;
}
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Err(_)=> return Err(ConnectionError::FlushTimeout),
};
Ok(self.state.events.pop_front().unwrap())
}
Err(_) => Err(ConnectionError::RequestsDone),
},
// Keep-alive timer fired: re-arm it and send a PingReq.
_ = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
if self.keepalive_timeout.is_some() && !self.mqtt_options.keep_alive.is_zero() => {
let timeout = self.keepalive_timeout.as_mut().unwrap();
timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);
// The second half of the notice pair is discarded: nothing waits on a ping.
let (tx, _) = NoticeTx::new();
if let Some(outgoing) = self.state.handle_outgoing_packet(tx, Request::PingReq(PingReq))? {
network.write(outgoing).await?;
}
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Err(_)=> return Err(ConnectionError::FlushTimeout),
};
Ok(self.state.events.pop_front().unwrap())
}
}
}
/// Returns a clone of the current network options.
pub fn network_options(&self) -> NetworkOptions {
self.network_options.clone()
}
/// Replaces the network options and returns `&mut Self` so calls can be chained.
pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self {
self.network_options = network_options;
self
}
/// Produces the next request to hand to the MQTT state.
///
/// Requests queued in `pending` are served first, each one only after waiting
/// for `pending_throttle`. When `pending` is empty, this waits for the next
/// request on `rx` instead.
///
/// # Errors
///
/// Returns [`ConnectionError::RequestsDone`] when receiving from `rx` fails.
async fn next_request(
    pending: &mut VecDeque<(NoticeTx, Request)>,
    rx: &Receiver<(NoticeTx, Request)>,
    pending_throttle: Duration,
) -> Result<(NoticeTx, Request), ConnectionError> {
    if pending.is_empty() {
        return rx
            .recv_async()
            .await
            .map_err(|_| ConnectionError::RequestsDone);
    }

    // Throttle first and pop afterwards: if this future is dropped while
    // sleeping, the request is still in `pending`.
    time::sleep(pending_throttle).await;
    Ok(pending.pop_front().unwrap())
}
}
/// Opens the transport to the broker and then performs the MQTT handshake on it.
///
/// On success returns the connected [`Network`] together with the broker's
/// ConnAck. A failure in either step is returned as is.
async fn connect(
    mqtt_options: &MqttOptions,
    network_options: NetworkOptions,
) -> Result<(Network, ConnAck), ConnectionError> {
    let mut network = network_connect(mqtt_options, network_options).await?;

    mqtt_connect(mqtt_options, &mut network)
        .await
        .map(|connack| (network, connack))
}
/// Resolves `host` and tries to open a TCP connection to each resolved address
/// in turn, returning the first stream that connects.
///
/// Before connecting, the socket is configured from `network_options`: the
/// send/receive buffer sizes when set and, on Android, Fuchsia and Linux, the
/// device to bind to when set.
///
/// # Errors
///
/// Returns an error if name resolution fails, if a socket cannot be created or
/// configured, or — when every address was tried — the error of the last
/// failed connection attempt. If resolution yields no address at all, an
/// `InvalidInput` error is returned.
pub(crate) async fn socket_connect(
    host: String,
    network_options: NetworkOptions,
) -> io::Result<TcpStream> {
    let addrs = lookup_host(host).await?;
    let mut last_err = None;

    for addr in addrs {
        let socket = match addr {
            SocketAddr::V4(_) => TcpSocket::new_v4()?,
            SocketAddr::V6(_) => TcpSocket::new_v6()?,
        };

        // Socket option failures are ordinary I/O errors: report them to the
        // caller like every other socket setup failure here, don't panic.
        if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
            socket.set_send_buffer_size(send_buff_size)?;
        }
        if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
            socket.set_recv_buffer_size(recv_buffer_size)?;
        }

        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
        {
            if let Some(bind_device) = &network_options.bind_device {
                socket.bind_device(Some(bind_device.as_bytes()))?;
            }
        }

        match socket.connect(addr).await {
            Ok(s) => return Ok(s),
            // Remember the failure and move on to the next resolved address.
            Err(e) => {
                last_err = Some(e);
            }
        };
    }

    Err(last_err.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        )
    }))
}
/// Establishes the transport selected by `options.transport()` and wraps it in
/// a [`Network`] configured with the maximum incoming and outgoing packet sizes.
///
/// Unix sockets are connected directly. Every other transport starts from a
/// byte stream (through the proxy when one is configured and the `proxy`
/// feature is enabled, otherwise a plain TCP connection) and layers TLS and/or
/// websocket on top as required.
async fn network_connect(
options: &MqttOptions,
network_options: NetworkOptions,
) -> Result<Network, ConnectionError> {
// Unix transport: `broker_addr` is used as a filesystem path and no TCP
// connection is made at all.
#[cfg(unix)]
if matches!(options.transport(), Transport::Unix) {
let file = options.broker_addr.as_str();
let socket = UnixStream::connect(Path::new(file)).await?;
let network = Network::new(
socket,
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
);
return Ok(network);
}
// For websocket transports `broker_addr` is a URL, so host and port come from
// `split_url`; every other transport uses `options.broker_address()`.
let (domain, port) = match options.transport() {
#[cfg(feature = "websocket")]
Transport::Ws => split_url(&options.broker_addr)?,
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
Transport::Wss(_) => split_url(&options.broker_addr)?,
_ => options.broker_address(),
};
// Underlying byte stream: through the proxy when one is configured, otherwise
// a direct TCP connection to `domain:port`.
let tcp_stream: Box<dyn AsyncReadWrite> = {
#[cfg(feature = "proxy")]
match options.proxy() {
Some(proxy) => proxy.connect(&domain, port, network_options).await?,
None => {
let addr = format!("{domain}:{port}");
let tcp = socket_connect(addr, network_options).await?;
Box::new(tcp)
}
}
#[cfg(not(feature = "proxy"))]
{
let addr = format!("{domain}:{port}");
let tcp = socket_connect(addr, network_options).await?;
Box::new(tcp)
}
};
// Layer the selected transport on top of the byte stream.
let network = match options.transport() {
Transport::Tcp => Network::new(
tcp_stream,
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
),
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
Transport::Tls(tls_config) => {
let socket =
tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
.await?;
Network::new(
socket,
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
// Already handled by the early return at the top of this function.
#[cfg(unix)]
Transport::Unix => unreachable!(),
#[cfg(feature = "websocket")]
Transport::Ws => {
// Build the handshake request from the broker URL and request the "mqtt"
// subprotocol; the caller-supplied modifier, if any, may rewrite it.
let mut request = options.broker_addr.as_str().into_client_request()?;
request
.headers_mut()
.insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
if let Some(request_modifier) = options.request_modifier() {
request = request_modifier(request).await;
}
let (socket, response) =
async_tungstenite::tokio::client_async(request, tcp_stream).await?;
validate_response_headers(response)?;
Network::new(
WsStream::new(socket),
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
Transport::Wss(tls_config) => {
// Same request setup as plain websocket; the handshake additionally goes
// through a rustls connector built from `tls_config`.
let mut request = options.broker_addr.as_str().into_client_request()?;
request
.headers_mut()
.insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
if let Some(request_modifier) = options.request_modifier() {
request = request_modifier(request).await;
}
let connector = tls::rustls_connector(&tls_config).await?;
let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
request,
tcp_stream,
Some(connector),
)
.await?;
validate_response_headers(response)?;
Network::new(
WsStream::new(socket),
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
};
Ok(network)
}
/// Performs the MQTT handshake on an already established `network`.
///
/// Builds a connect packet from `options` (client id, keep-alive, clean
/// session flag, last will and credentials), sends it and reads the broker's
/// reply.
///
/// # Errors
///
/// Returns [`ConnectionError::ConnectionRefused`] when the ConnAck carries a
/// return code other than success, [`ConnectionError::NotConnAck`] when the
/// reply is some other packet, and propagates errors from sending or reading.
async fn mqtt_connect(
    options: &MqttOptions,
    network: &mut Network,
) -> Result<ConnAck, ConnectionError> {
    // The connect packet carries keep-alive as a 16-bit number of seconds.
    // Clamp instead of casting directly: a plain `as u16` wraps, so e.g.
    // 65536 s would be announced as 0.
    let keep_alive = options.keep_alive().as_secs().min(u64::from(u16::MAX)) as u16;

    let mut connect = Connect::new(options.client_id());
    connect.keep_alive = keep_alive;
    connect.clean_session = options.clean_session();
    connect.last_will = options.last_will();
    connect.login = options.credentials();

    network.connect(connect).await?;

    match network.read().await? {
        Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => Ok(connack),
        Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
        packet => Err(ConnectionError::NotConnAck(packet)),
    }
}