use std::time::Duration;
use flume::TryRecvError;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SendError {
#[error("The channel is closed.")]
Disconnected,
#[error("The channel is full.")]
Full,
}
#[derive(Debug, Error, Copy, Clone, PartialEq, Eq)]
pub enum RecvError {
#[error("A timeout occured when attempting to receive a message.")]
Timeout,
#[error("All sender were dropped an no message are pending in the channel.")]
Disconnected,
}
impl From<flume::RecvTimeoutError> for RecvError {
fn from(flume_err: flume::RecvTimeoutError) -> Self {
match flume_err {
flume::RecvTimeoutError::Timeout => Self::Timeout,
flume::RecvTimeoutError::Disconnected => Self::Disconnected,
}
}
}
impl<T> From<flume::SendError<T>> for SendError {
fn from(_send_error: flume::SendError<T>) -> Self {
SendError::Disconnected
}
}
impl<T> From<flume::TrySendError<T>> for SendError {
fn from(try_send_error: flume::TrySendError<T>) -> Self {
match try_send_error {
flume::TrySendError::Full(_) => SendError::Full,
flume::TrySendError::Disconnected(_) => SendError::Disconnected,
}
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Priority {
High,
Low,
}
#[derive(Clone, Copy, Debug)]
pub enum QueueCapacity {
Bounded(usize),
Unbounded,
}
impl QueueCapacity {
pub(crate) fn create_channel<M>(&self) -> (flume::Sender<M>, flume::Receiver<M>) {
match *self {
QueueCapacity::Bounded(cap) => flume::bounded(cap),
QueueCapacity::Unbounded => flume::unbounded(),
}
}
}
pub fn channel<T>(queue_capacity: QueueCapacity) -> (Sender<T>, Receiver<T>) {
let (high_priority_tx, high_priority_rx) = flume::unbounded();
let (low_priority_tx, low_priority_rx) = queue_capacity.create_channel();
let receiver = Receiver {
low_priority_rx,
high_priority_rx,
_high_priority_tx: high_priority_tx.clone(),
pending: None,
};
let sender = Sender {
low_priority_tx,
high_priority_tx,
};
(sender, receiver)
}
pub struct Sender<T> {
low_priority_tx: flume::Sender<T>,
high_priority_tx: flume::Sender<T>,
}
impl<T> Sender<T> {
fn channel(&self, priority: Priority) -> &flume::Sender<T> {
match priority {
Priority::High => &self.high_priority_tx,
Priority::Low => &self.low_priority_tx,
}
}
pub async fn send(&self, msg: T, priority: Priority) -> Result<(), SendError> {
self.channel(priority).send_async(msg).await?;
Ok(())
}
}
pub struct Receiver<T> {
low_priority_rx: flume::Receiver<T>,
high_priority_rx: flume::Receiver<T>,
_high_priority_tx: flume::Sender<T>,
pending: Option<T>,
}
impl<T> Receiver<T> {
fn try_recv_high_priority_message(&self) -> Option<T> {
match self.high_priority_rx.try_recv() {
Ok(msg) => Some(msg),
Err(TryRecvError::Disconnected) => {
unreachable!(
"This can never happen, as the high priority Sender is owned by the Receiver."
);
}
Err(TryRecvError::Empty) => None,
}
}
pub async fn recv_high_priority_timeout(&mut self, duration: Duration) -> Result<T, RecvError> {
tokio::select! {
high_priority_msg_res = self.high_priority_rx.recv_async() => {
match high_priority_msg_res {
Ok(high_priority_msg) => { Ok(high_priority_msg) },
Err(_) => { unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.") }, }
}
_ = tokio::time::sleep(duration) => {
Err(RecvError::Timeout)
}
}
}
pub async fn recv_timeout(&mut self, duration: Duration) -> Result<T, RecvError> {
if let Some(msg) = self.try_recv_high_priority_message() {
return Ok(msg);
}
if let Some(pending_msg) = self.pending.take() {
return Ok(pending_msg);
}
tokio::select! {
high_priority_msg_res = self.high_priority_rx.recv_async() => {
match high_priority_msg_res {
Ok(high_priority_msg) => {
Ok(high_priority_msg)
},
Err(_) => {
unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.")
},
}
}
low_priority_msg_res = self.low_priority_rx.recv_async() => {
match low_priority_msg_res {
Ok(low_priority_msg) => {
if let Some(high_priority_msg) = self.try_recv_high_priority_message() {
self.pending = Some(low_priority_msg);
Ok(high_priority_msg)
} else {
Ok(low_priority_msg)
}
},
Err(flume::RecvError::Disconnected) => {
if let Some(high_priority_msg) = self.try_recv_high_priority_message() {
Ok(high_priority_msg)
} else {
Err(RecvError::Disconnected)
}
}
}
}
_ = tokio::time::sleep(duration) => {
Err(RecvError::Timeout)
}
}
}
pub fn drain_low_priority(&self) -> Vec<T> {
let mut messages = Vec::new();
while let Ok(msg) = self.low_priority_rx.try_recv() {
messages.push(msg);
}
messages
}
}
#[cfg(test)]
mod tests {
use std::time::{Duration, Instant};
use super::*;
const TEST_TIMEOUT: Duration = Duration::from_millis(100);
#[tokio::test]
async fn test_recv_timeout_prority() -> anyhow::Result<()> {
let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
sender.send(1, Priority::Low).await?;
sender.send(2, Priority::High).await?;
assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
assert_eq!(
receiver.recv_timeout(TEST_TIMEOUT).await,
Err(RecvError::Timeout)
);
Ok(())
}
#[tokio::test]
async fn test_recv_high_priority_timeout() -> anyhow::Result<()> {
let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
sender.send(1, Priority::Low).await?;
assert_eq!(
receiver.recv_high_priority_timeout(TEST_TIMEOUT).await,
Err(RecvError::Timeout)
);
Ok(())
}
#[tokio::test]
async fn test_recv_high_priority_ignore_disconnection() -> anyhow::Result<()> {
let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
std::mem::drop(sender);
assert_eq!(
receiver.recv_high_priority_timeout(TEST_TIMEOUT).await,
Err(RecvError::Timeout)
);
Ok(())
}
#[tokio::test]
async fn test_recv_disconnect() -> anyhow::Result<()> {
let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
std::mem::drop(sender);
assert_eq!(
receiver.recv_timeout(TEST_TIMEOUT).await,
Err(RecvError::Disconnected)
);
Ok(())
}
#[tokio::test]
async fn test_recv_timeout_simple() -> anyhow::Result<()> {
let (_sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
let start_time = Instant::now();
assert_eq!(
receiver.recv_timeout(TEST_TIMEOUT).await,
Err(RecvError::Timeout)
);
let elapsed = start_time.elapsed();
assert!(elapsed < crate::HEARTBEAT);
Ok(())
}
#[tokio::test]
async fn test_try_recv_prority_corner_case() -> anyhow::Result<()> {
let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
tokio::task::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
sender.send(1, Priority::High).await?;
sender.send(2, Priority::Low).await?;
Result::<(), SendError>::Ok(())
});
assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
assert_eq!(
receiver.recv_timeout(TEST_TIMEOUT).await,
Err(RecvError::Disconnected)
);
Ok(())
}
}