use std::{fmt, io, num::NonZeroU16};
use ntex_util::future::Either;
use crate::v5::codec::DisconnectReasonCode;
#[derive(Debug, thiserror::Error)]
pub enum MqttError<E> {
#[error("Service error")]
Service(E),
#[error("Mqtt handshake error: {}", _0)]
Handshake(#[from] HandshakeError<E>),
}
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError<E> {
#[error("Handshake service error")]
Service(E),
#[error("Mqtt protocol error: {}", _0)]
Protocol(#[from] ProtocolError),
#[error("Handshake timeout")]
Timeout,
#[error("Peer is disconnected, error: {:?}", _0)]
Disconnected(Option<io::Error>),
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ProtocolError {
#[error("Decoding error: {0:?}")]
Decode(#[from] DecodeError),
#[error("Encoding error: {0:?}")]
Encode(#[from] EncodeError),
#[error("Protocol violation: {0}")]
ProtocolViolation(#[from] ProtocolViolationError),
#[error("Keep Alive timeout")]
KeepAliveTimeout,
#[error("Read frame timeout")]
ReadTimeout,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
#[error(transparent)]
pub struct ProtocolViolationError {
inner: ViolationInner,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
enum ViolationInner {
#[error("{message}")]
Common { reason: DisconnectReasonCode, message: &'static str },
#[error("{message}; received packet with type `{packet_type:b}`")]
UnexpectedPacket { packet_type: u8, message: &'static str },
}
impl ProtocolViolationError {
    /// Disconnect reason code associated with this violation.
    ///
    /// `Common` violations carry their own code; unexpected-packet violations always
    /// map to `DisconnectReasonCode::ProtocolError`.
    pub(crate) fn reason(&self) -> DisconnectReasonCode {
        match &self.inner {
            ViolationInner::UnexpectedPacket { .. } => DisconnectReasonCode::ProtocolError,
            // `DisconnectReasonCode` is `Copy` (required by `ViolationInner: Copy`).
            ViolationInner::Common { reason, .. } => *reason,
        }
    }
}
impl ProtocolError {
    /// Builds a protocol violation with an explicit disconnect `reason` code and `message`.
    pub(crate) fn violation(reason: DisconnectReasonCode, message: &'static str) -> Self {
        let inner = ViolationInner::Common { reason, message };
        Self::ProtocolViolation(ProtocolViolationError { inner })
    }

    /// Builds a protocol violation whose reason code is `DisconnectReasonCode::ProtocolError`.
    pub fn generic_violation(message: &'static str) -> Self {
        Self::violation(DisconnectReasonCode::ProtocolError, message)
    }

    /// Builds a violation describing an unexpected packet of type `packet_type`.
    pub(crate) fn unexpected_packet(packet_type: u8, message: &'static str) -> Self {
        let inner = ViolationInner::UnexpectedPacket { packet_type, message };
        Self::ProtocolViolation(ProtocolViolationError { inner })
    }

    /// Generic violation for a PUBACK whose packet id does not match the next id expected
    /// from the sending order of PUBLISH packets (MQTT-4.6.0-2).
    pub(crate) fn packet_id_mismatch() -> Self {
        const MESSAGE: &str = "Packet id of PUBACK packet does not match expected next value according to sending order of PUBLISH packets [MQTT-4.6.0-2]";
        Self::generic_violation(MESSAGE)
    }
}
/// Reports an I/O error as a handshake-stage disconnect that carries the error.
impl<E> From<io::Error> for MqttError<E> {
    fn from(err: io::Error) -> Self {
        let disconnected = HandshakeError::Disconnected(Some(err));
        Self::Handshake(disconnected)
    }
}
impl<E> From<Either<io::Error, io::Error>> for MqttError<E> {
fn from(err: Either<io::Error, io::Error>) -> Self {
MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner())))
}
}
/// Wraps an encode error as a handshake-stage protocol error.
impl<E> From<EncodeError> for MqttError<E> {
    fn from(err: EncodeError) -> Self {
        let protocol = ProtocolError::Encode(err);
        Self::Handshake(HandshakeError::Protocol(protocol))
    }
}
/// Maps an I/O failure to `Disconnected` and a decode failure to a protocol error.
impl<E> From<Either<DecodeError, io::Error>> for HandshakeError<E> {
    fn from(err: Either<DecodeError, io::Error>) -> Self {
        match err {
            Either::Right(io_err) => Self::Disconnected(Some(io_err)),
            Either::Left(decode_err) => Self::Protocol(ProtocolError::Decode(decode_err)),
        }
    }
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, thiserror::Error)]
pub enum DecodeError {
#[error("Invalid protocol")]
InvalidProtocol,
#[error("Invalid length")]
InvalidLength,
#[error("Malformed packet")]
MalformedPacket,
#[error("Unsupported protocol level")]
UnsupportedProtocolLevel,
#[error("Connect frame's reserved flag is set")]
ConnectReservedFlagSet,
#[error("ConnectAck frame's reserved flag is set")]
ConnAckReservedFlagSet,
#[error("Invalid client id")]
InvalidClientId,
#[error("Unsupported packet type")]
UnsupportedPacketType,
#[error("Packet id is required")]
PacketIdRequired,
#[error("Max size exceeded")]
MaxSizeExceeded,
#[error("utf8 error")]
Utf8Error,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, thiserror::Error)]
pub enum EncodeError {
#[error("Packet is bigger than peer's Maximum Packet Size")]
OverMaxPacketSize,
#[error("Invalid length")]
InvalidLength,
#[error("Malformed packet")]
MalformedPacket,
#[error("Packet id is required")]
PacketIdRequired,
#[error("Unsupported version")]
UnsupportedVersion,
}
#[derive(Debug, PartialEq, Eq, Copy, Clone, thiserror::Error)]
pub enum SendPacketError {
#[error("Encoding error {:?}", _0)]
Encode(#[from] EncodeError),
#[error("Provided packet id is in use")]
PacketIdInUse(NonZeroU16),
#[error("Peer is disconnected")]
Disconnected,
}
#[derive(Debug, thiserror::Error)]
pub enum ClientError<T: fmt::Debug> {
#[error("Connect ack failed: {:?}", _0)]
Ack(T),
#[error("Protocol error: {:?}", _0)]
Protocol(#[from] ProtocolError),
#[error("Handshake timeout")]
HandshakeTimeout,
#[error("Peer disconnected")]
Disconnected(Option<std::io::Error>),
#[error("Connect error: {}", _0)]
Connect(#[from] ntex_net::connect::ConnectError),
}
impl<T: fmt::Debug> From<EncodeError> for ClientError<T> {
fn from(err: EncodeError) -> Self {
ClientError::Protocol(ProtocolError::Encode(err))
}
}
/// Maps an I/O failure to `Disconnected` and a decode failure to a protocol error.
impl<T: fmt::Debug> From<Either<DecodeError, std::io::Error>> for ClientError<T> {
    fn from(err: Either<DecodeError, std::io::Error>) -> Self {
        match err {
            Either::Right(io_err) => Self::Disconnected(Some(io_err)),
            Either::Left(decode_err) => Self::Protocol(ProtocolError::Decode(decode_err)),
        }
    }
}