use super::ChannelSendError;
use crate::intrusive_double_linked_list::ListNode;
use core::marker::PhantomData;
use core::pin::Pin;
use futures_core::future::{FusedFuture, Future};
use futures_core::task::{Context, Poll, Waker};
/// Result of a close request.
///
/// NOTE(review): the close operation that returns this lives outside this
/// file; the variant meanings below follow from their names — confirm there.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum CloseStatus {
    /// This call performed the close.
    NewlyClosed,
    /// The target had already been closed before this call.
    AlreadyClosed,
}

impl CloseStatus {
    /// Returns `true` exactly when `self` is [`CloseStatus::NewlyClosed`].
    pub fn is_newly_closed(self) -> bool {
        self == CloseStatus::NewlyClosed
    }

    /// Returns `true` exactly when `self` is [`CloseStatus::AlreadyClosed`].
    pub fn is_already_closed(self) -> bool {
        self == CloseStatus::AlreadyClosed
    }
}
/// Registration state of a receive operation's wait-queue entry.
///
/// NOTE(review): transitions other than the initial `Unregistered` are made
/// by the channel implementation (not visible here); the meanings of
/// `Registered` and `Notified` below are inferred from their names.
#[derive(PartialEq, Debug)]
pub enum RecvPollState {
    /// Initial state of a fresh entry (see `RecvWaitQueueEntry::new`).
    Unregistered,
    /// Presumably linked into the channel's wait queue — confirm.
    Registered,
    /// Presumably woken by the channel — confirm.
    Notified,
}

/// Wait-queue entry used by a pending receive operation.
#[derive(Debug)]
pub struct RecvWaitQueueEntry {
    /// Waker of the waiting task; `None` until one is stored.
    pub task: Option<Waker>,
    /// Current registration state of this entry.
    pub state: RecvPollState,
}

impl RecvWaitQueueEntry {
    /// Creates an entry with no waker in the `Unregistered` state.
    pub fn new() -> RecvWaitQueueEntry {
        RecvWaitQueueEntry {
            task: None,
            state: RecvPollState::Unregistered,
        }
    }
}

/// Equivalent to [`RecvWaitQueueEntry::new`]. Provided so the type works with
/// `Default`-based APIs (and satisfies `clippy::new_without_default`).
impl Default for RecvWaitQueueEntry {
    fn default() -> Self {
        Self::new()
    }
}
/// Registration state of a send operation's wait-queue entry.
///
/// NOTE(review): apart from the initial `Unregistered`, the states are set
/// by the channel implementation (not visible here); meanings are inferred
/// from the names.
#[derive(PartialEq, Debug)]
pub enum SendPollState {
    /// Initial state of a fresh entry (see `SendWaitQueueEntry::new`).
    Unregistered,
    /// Presumably linked into the channel's wait queue — confirm.
    Registered,
    /// Presumably the stored value was taken by the channel — confirm.
    SendComplete,
}

/// Wait-queue entry used by a pending send operation; it carries the value
/// that is being sent.
pub struct SendWaitQueueEntry<T> {
    /// Waker of the waiting task; `None` until one is stored.
    pub task: Option<Waker>,
    /// Current registration state of this entry.
    pub state: SendPollState,
    /// Value to send; `Some` on creation, may be taken later.
    pub value: Option<T>,
}

impl<T> core::fmt::Debug for SendWaitQueueEntry<T> {
    /// Formats `task` and `state` only: `value` is left out because `T` is
    /// not required to implement `Debug`.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = formatter.debug_struct("SendWaitQueueEntry");
        builder.field("task", &self.task);
        builder.field("state", &self.state);
        builder.finish()
    }
}

impl<T> SendWaitQueueEntry<T> {
    /// Creates an `Unregistered` entry without a waker that holds `value`.
    pub fn new(value: T) -> SendWaitQueueEntry<T> {
        let value = Some(value);
        SendWaitQueueEntry {
            value,
            task: None,
            state: SendPollState::Unregistered,
        }
    }
}
/// Channel-side operations that the send futures in this file drive.
pub trait ChannelSendAccess<T> {
/// Tries to complete the send whose value is stored in `wait_node`, or
/// registers `wait_node` (and the waker from `cx`) to be retried later.
///
/// How the futures in this file interpret the result:
/// - `(Poll::Ready(()), None)` resolves the future to `Ok(())`.
/// - `(Poll::Ready(()), Some(v))` resolves it to `Err(ChannelSendError(v))`,
///   i.e. the value is handed back (NOTE(review): presumably because the
///   channel is closed — confirm in the implementations).
/// - `(Poll::Pending, _)` keeps the future waiting; the second element is
///   discarded by the callers.
///
/// # Safety
///
/// NOTE(review): inferred from the callers in this file, confirm against the
/// implementations. Callers pass a `wait_node` that lives inside a pinned
/// future (obtained via `Pin::get_unchecked_mut`) and call
/// `remove_send_waiter` from `Drop`/`cancel` while the operation is
/// unfinished. Implementations presumably keep a reference to `wait_node`
/// after returning `Pending`, so callers must keep it at a stable address
/// until it has been removed or the call reported `Ready`.
unsafe fn send_or_register(
&self,
wait_node: &mut ListNode<SendWaitQueueEntry<T>>,
cx: &mut Context<'_>,
) -> (Poll<()>, Option<T>);
/// Detaches `wait_node` from the channel.
///
/// Called from the send futures' `Drop` and `cancel` whenever the operation
/// has not completed, which includes futures that were never polled (entry
/// still `Unregistered`); implementations presumably must tolerate that.
fn remove_send_waiter(
&self,
wait_node: &mut ListNode<SendWaitQueueEntry<T>>,
);
}
/// Channel-side operations that the receive futures in this file drive.
pub trait ChannelReceiveAccess<T> {
/// Tries to receive a value, or registers `wait_node` (and the waker from
/// `cx`) to be retried later.
///
/// A `Poll::Ready(x)` result becomes the output of the receive future;
/// `Poll::Pending` keeps the future waiting. NOTE(review): `Ready(None)` is
/// presumably reported once the channel is closed — confirm.
///
/// # Safety
///
/// NOTE(review): inferred from the callers in this file, confirm against the
/// implementations. Callers pass a `wait_node` that lives inside a pinned
/// future and call `remove_receive_waiter` from `Drop` while the operation
/// is unfinished. Implementations presumably keep a reference to
/// `wait_node` after returning `Pending`, so it must stay at a stable address
/// until removed or until the call reported `Ready`.
unsafe fn receive_or_register(
&self,
wait_node: &mut ListNode<RecvWaitQueueEntry>,
cx: &mut Context<'_>,
) -> Poll<Option<T>>;
/// Detaches `wait_node` from the channel.
///
/// Called from the receive futures' `Drop` whenever the operation has not
/// completed, including for futures that were never polled (entry still
/// `Unregistered`); implementations presumably must tolerate that.
fn remove_receive_waiter(
&self,
wait_node: &mut ListNode<RecvWaitQueueEntry>,
);
}
/// Future for a receive operation on a channel borrowed for `'a`.
///
/// Resolves to the `Option<T>` reported by the channel's
/// `receive_or_register`. Once it has resolved, `channel` is `None` and the
/// future reports itself as terminated.
#[must_use = "futures do nothing unless polled"]
pub struct ChannelReceiveFuture<'a, MutexType, T> {
    /// Channel to receive from; `None` after the future completed.
    pub(crate) channel: Option<&'a dyn ChannelReceiveAccess<T>>,
    /// Intrusive wait-queue node handed to the channel while waiting; only
    /// accessed through `Pin`, so it keeps a stable address.
    pub(crate) wait_node: ListNode<RecvWaitQueueEntry>,
    /// Records the channel's mutex type (used by the `Send` bound below).
    pub(crate) _phantom: PhantomData<MutexType>,
}

// NOTE(review): this asserts that the `dyn` channel reference may be used from
// another thread whenever `MutexType: Sync` and `T: Send`; confirm that every
// channel type behind the trait object upholds that.
unsafe impl<'a, MutexType, T> Send for ChannelReceiveFuture<'a, MutexType, T>
where
    MutexType: Sync,
    T: Send,
{
}

impl<'a, MutexType, T> core::fmt::Debug for ChannelReceiveFuture<'a, MutexType, T> {
    /// Prints only the type name; `MutexType` and `T` carry no `Debug` bound.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.write_str("ChannelReceiveFuture")
    }
}
impl<'a, MutexType, T> Future for ChannelReceiveFuture<'a, MutexType, T> {
    type Output = Option<T>;

    /// Receives a value, or registers the wait node and `cx`'s waker with the
    /// channel if none is available yet.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        // SAFETY: fields are only reborrowed in place; nothing is moved out of
        // the pinned future, so `wait_node` keeps its address.
        let this = unsafe { self.get_unchecked_mut() };
        let channel = this
            .channel
            .expect("polled ChannelReceiveFuture after completion");

        // SAFETY: `wait_node` is pinned (see above), and `Drop` unregisters it
        // via `remove_receive_waiter` while `this.channel` is still set.
        let outcome = unsafe { channel.receive_or_register(&mut this.wait_node, cx) };
        match outcome {
            Poll::Ready(received) => {
                // Finished: forget the channel so the future is terminated and
                // `Drop` has nothing left to unregister.
                this.channel = None;
                Poll::Ready(received)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

impl<'a, MutexType, T> FusedFuture for ChannelReceiveFuture<'a, MutexType, T> {
    /// Terminated once the future has produced its output.
    fn is_terminated(&self) -> bool {
        self.channel.is_none()
    }
}

impl<'a, MutexType, T> Drop for ChannelReceiveFuture<'a, MutexType, T> {
    /// If the future is dropped before completing, asks the channel to
    /// detach the wait node so it is (presumably) no longer referenced.
    fn drop(&mut self) {
        if let Some(pending_on) = self.channel {
            pending_on.remove_receive_waiter(&mut self.wait_node);
        }
    }
}
/// Future for a send operation on a channel borrowed for `'a`.
///
/// Resolves to `Ok(())` when the channel completes the send without handing
/// the value back, or to `Err(ChannelSendError(value))` when it returns the
/// value. After it resolves (or is cancelled), `channel` is `None`.
#[must_use = "futures do nothing unless polled"]
pub struct ChannelSendFuture<'a, MutexType, T> {
    /// Channel to send to; `None` after completion or `cancel`.
    pub(crate) channel: Option<&'a dyn ChannelSendAccess<T>>,
    /// Intrusive wait-queue node that holds the value being sent; only
    /// accessed through `Pin` while registered.
    pub(crate) wait_node: ListNode<SendWaitQueueEntry<T>>,
    /// Records the channel's mutex type (used by the `Send` bound below).
    pub(crate) _phantom: PhantomData<MutexType>,
}

// NOTE(review): this asserts that the `dyn` channel reference may be used from
// another thread whenever `MutexType: Sync` and `T: Send`; confirm that every
// channel type behind the trait object upholds that.
unsafe impl<'a, MutexType, T> Send for ChannelSendFuture<'a, MutexType, T>
where
    MutexType: Sync,
    T: Send,
{
}

impl<'a, MutexType, T> core::fmt::Debug for ChannelSendFuture<'a, MutexType, T> {
    /// Prints only the type name; `MutexType` and `T` carry no `Debug` bound.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.write_str("ChannelSendFuture")
    }
}
impl<'a, MutexType, T> ChannelSendFuture<'a, MutexType, T> {
    /// Cancels the send.
    ///
    /// If the future has not completed yet, its wait node is detached from
    /// the channel and the value still stored in the node is returned. The
    /// result is `None` when the future had already completed or been
    /// cancelled, or when the node no longer holds a value.
    pub fn cancel(&mut self) -> Option<T> {
        let channel = self.channel.take()?;
        channel.remove_send_waiter(&mut self.wait_node);
        self.wait_node.value.take()
    }
}
impl<'a, MutexType, T> Future for ChannelSendFuture<'a, MutexType, T> {
    type Output = Result<(), ChannelSendError<T>>;

    /// Sends the stored value, or registers the wait node and `cx`'s waker
    /// with the channel if the send cannot complete yet.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelSendError<T>>> {
        // SAFETY: fields are only reborrowed in place; nothing is moved out of
        // the pinned future, so `wait_node` keeps its address.
        let this = unsafe { self.get_unchecked_mut() };
        let channel = this
            .channel
            .expect("polled ChannelSendFuture after completion");

        // SAFETY: `wait_node` is pinned (see above), and `Drop`/`cancel`
        // unregister it while `this.channel` is still set.
        let (progress, returned) =
            unsafe { channel.send_or_register(&mut this.wait_node, cx) };
        if progress.is_pending() {
            return Poll::Pending;
        }

        // Finished: forget the channel so the future is terminated.
        this.channel = None;
        let outcome = match returned {
            None => Ok(()),
            // The channel handed the value back instead of accepting it.
            Some(value) => Err(ChannelSendError(value)),
        };
        Poll::Ready(outcome)
    }
}

impl<'a, MutexType, T> FusedFuture for ChannelSendFuture<'a, MutexType, T> {
    /// Terminated once the future has resolved or been cancelled.
    fn is_terminated(&self) -> bool {
        self.channel.is_none()
    }
}

impl<'a, MutexType, T> Drop for ChannelSendFuture<'a, MutexType, T> {
    /// If the future is dropped before completing, asks the channel to
    /// detach the wait node so it is (presumably) no longer referenced.
    fn drop(&mut self) {
        if let Some(pending_on) = self.channel {
            pending_on.remove_send_waiter(&mut self.wait_node);
        }
    }
}
#[cfg(feature = "alloc")]
mod if_alloc {
    use super::*;

    /// Variants of the channel futures that own an `Arc` to their channel
    /// instead of borrowing it, so they have no lifetime parameter.
    pub mod shared {
        use super::*;

        /// Future for a receive operation on an `Arc`-shared channel.
        ///
        /// Resolves to the `Option<T>` reported by the channel's
        /// `receive_or_register`. Once it has resolved, `channel` is `None`
        /// and the future reports itself as terminated.
        #[must_use = "futures do nothing unless polled"]
        pub struct ChannelReceiveFuture<MutexType, T> {
            /// Channel to receive from; `None` after the future completed.
            pub(crate) channel:
                Option<alloc::sync::Arc<dyn ChannelReceiveAccess<T>>>,
            /// Intrusive wait-queue node handed to the channel while
            /// waiting; only accessed through `Pin`.
            pub(crate) wait_node: ListNode<RecvWaitQueueEntry>,
            /// Records the channel's mutex type (used by the `Send` bound).
            pub(crate) _phantom: PhantomData<MutexType>,
        }

        // NOTE(review): this asserts that the `dyn` channel may be used from
        // another thread whenever `MutexType: Sync` and `T: Send`; confirm
        // that every channel type behind the trait object upholds that.
        unsafe impl<MutexType: Sync, T: Send> Send
            for ChannelReceiveFuture<MutexType, T>
        {
        }

        impl<MutexType, T> core::fmt::Debug for ChannelReceiveFuture<MutexType, T> {
            /// Prints only the type name; `MutexType` and `T` carry no
            /// `Debug` bound.
            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                f.debug_struct("ChannelReceiveFuture").finish()
            }
        }

        impl<MutexType, T> Future for ChannelReceiveFuture<MutexType, T> {
            type Output = Option<T>;

            /// Receives a value, or registers the wait node and `cx`'s waker
            /// with the channel if none is available yet.
            ///
            /// # Panics
            /// Panics if polled again after it has returned `Poll::Ready`.
            fn poll(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<T>> {
                // SAFETY: fields are only reborrowed in place; nothing is
                // moved out of the pinned future, so `wait_node` keeps the
                // address the channel may have recorded.
                let mut_self: &mut ChannelReceiveFuture<MutexType, T> =
                    unsafe { Pin::get_unchecked_mut(self) };

                // Borrow the channel rather than `take()`-ing it: if the call
                // below unwinds after linking `wait_node`, `channel` is still
                // `Some`, so `Drop` still asks the channel to detach the node.
                let poll_res = {
                    let channel = mut_self
                        .channel
                        .as_ref()
                        .expect("polled ChannelReceiveFuture after completion");
                    // SAFETY: `wait_node` is pinned (see above) and is
                    // unregistered by `Drop` while `channel` is still set.
                    unsafe {
                        channel.receive_or_register(&mut mut_self.wait_node, cx)
                    }
                };

                if poll_res.is_ready() {
                    // Completed: release the channel and mark the future as
                    // terminated.
                    mut_self.channel = None;
                }

                poll_res
            }
        }

        impl<MutexType, T> FusedFuture for ChannelReceiveFuture<MutexType, T> {
            /// Terminated once the future has produced its output.
            fn is_terminated(&self) -> bool {
                self.channel.is_none()
            }
        }

        impl<MutexType, T> Drop for ChannelReceiveFuture<MutexType, T> {
            /// If the future is dropped before completing, asks the channel
            /// to detach the wait node.
            fn drop(&mut self) {
                if let Some(channel) = &self.channel {
                    channel.remove_receive_waiter(&mut self.wait_node);
                }
            }
        }

        /// Future for a send operation on an `Arc`-shared channel.
        ///
        /// Resolves to `Ok(())` when the channel completes the send without
        /// handing the value back, or to `Err(ChannelSendError(value))` when
        /// it returns the value. After it resolves (or is cancelled),
        /// `channel` is `None`.
        #[must_use = "futures do nothing unless polled"]
        pub struct ChannelSendFuture<MutexType, T> {
            /// Channel to send to; `None` after completion or `cancel`.
            pub(crate) channel:
                Option<alloc::sync::Arc<dyn ChannelSendAccess<T>>>,
            /// Intrusive wait-queue node that holds the value being sent.
            pub(crate) wait_node: ListNode<SendWaitQueueEntry<T>>,
            /// Records the channel's mutex type (used by the `Send` bound).
            pub(crate) _phantom: PhantomData<MutexType>,
        }

        impl<MutexType, T> ChannelSendFuture<MutexType, T> {
            /// Cancels the send.
            ///
            /// If the future has not completed yet, its wait node is detached
            /// from the channel and the value still stored in the node is
            /// returned. The result is `None` when the future had already
            /// completed or been cancelled, or when the node no longer holds a
            /// value.
            pub fn cancel(&mut self) -> Option<T> {
                let channel = self.channel.take()?;
                channel.remove_send_waiter(&mut self.wait_node);
                self.wait_node.value.take()
            }
        }

        // NOTE(review): same thread-safety assertion as for the receive
        // future above; confirm against the channel implementations.
        unsafe impl<MutexType: Sync, T: Send> Send for ChannelSendFuture<MutexType, T> {}

        impl<MutexType, T> core::fmt::Debug for ChannelSendFuture<MutexType, T> {
            /// Prints only the type name; `MutexType` and `T` carry no
            /// `Debug` bound.
            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                f.debug_struct("ChannelSendFuture").finish()
            }
        }

        impl<MutexType, T> Future for ChannelSendFuture<MutexType, T> {
            type Output = Result<(), ChannelSendError<T>>;

            /// Sends the stored value, or registers the wait node and `cx`'s
            /// waker with the channel if the send cannot complete yet.
            ///
            /// # Panics
            /// Panics if polled again after it has returned `Poll::Ready`.
            fn poll(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), ChannelSendError<T>>> {
                // SAFETY: fields are only reborrowed in place; nothing is
                // moved out of the pinned future, so `wait_node` keeps the
                // address the channel may have recorded.
                let mut_self: &mut ChannelSendFuture<MutexType, T> =
                    unsafe { Pin::get_unchecked_mut(self) };

                // Borrow the channel rather than `take()`-ing it: if the call
                // below unwinds after linking `wait_node`, `channel` is still
                // `Some`, so `Drop` still asks the channel to detach the node.
                let (poll_state, unsent_value) = {
                    let channel = mut_self
                        .channel
                        .as_ref()
                        .expect("polled ChannelSendFuture after completion");
                    // SAFETY: `wait_node` is pinned (see above) and is
                    // unregistered by `Drop`/`cancel` while `channel` is set.
                    unsafe {
                        channel.send_or_register(&mut mut_self.wait_node, cx)
                    }
                };

                match poll_state {
                    Poll::Ready(()) => {
                        // Completed: release the channel and mark the future
                        // as terminated.
                        mut_self.channel = None;
                        match unsent_value {
                            // The channel handed the value back instead of
                            // accepting it.
                            Some(v) => Poll::Ready(Err(ChannelSendError(v))),
                            None => Poll::Ready(Ok(())),
                        }
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        }

        impl<MutexType, T> FusedFuture for ChannelSendFuture<MutexType, T> {
            /// Terminated once the future has resolved or been cancelled.
            fn is_terminated(&self) -> bool {
                self.channel.is_none()
            }
        }

        impl<MutexType, T> Drop for ChannelSendFuture<MutexType, T> {
            /// If the future is dropped before completing, asks the channel
            /// to detach the wait node.
            fn drop(&mut self) {
                if let Some(channel) = &self.channel {
                    channel.remove_send_waiter(&mut self.wait_node);
                }
            }
        }
    }
}
#[cfg(feature = "alloc")]
pub use self::if_alloc::*;