use core::fmt;
use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, task::Poll};
use futures_lite::{Future, Stream};
use futures_sink::Sink;
use super::StreamTypes;
use crate::{
transport::{ConnectionErrors, Connector, Listener, LocalAddr},
RpcMessage,
};
#[derive(Debug)]
pub enum RecvError {}
impl fmt::Display for RecvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
pub struct SendSink<T: RpcMessage>(pub(crate) flume::r#async::SendSink<'static, T>);
impl<T: RpcMessage> fmt::Debug for SendSink<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SendSink").finish()
}
}
impl<T: RpcMessage> Sink<T> for SendSink<T> {
type Error = self::SendError;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_ready(cx)
.map_err(|_| SendError::ReceiverDropped)
}
fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
Pin::new(&mut self.0)
.start_send(item)
.map_err(|_| SendError::ReceiverDropped)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_flush(cx)
.map_err(|_| SendError::ReceiverDropped)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_close(cx)
.map_err(|_| SendError::ReceiverDropped)
}
}
pub struct RecvStream<T: RpcMessage>(pub(crate) flume::r#async::RecvStream<'static, T>);
impl<T: RpcMessage> fmt::Debug for RecvStream<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RecvStream").finish()
}
}
impl<T: RpcMessage> Stream for RecvStream<T> {
type Item = result::Result<T, self::RecvError>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.0).poll_next(cx) {
Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
impl error::Error for RecvError {}
pub struct FlumeListener<In: RpcMessage, Out: RpcMessage> {
#[allow(clippy::type_complexity)]
stream: flume::Receiver<(SendSink<Out>, RecvStream<In>)>,
}
impl<In: RpcMessage, Out: RpcMessage> Clone for FlumeListener<In, Out> {
fn clone(&self) -> Self {
Self {
stream: self.stream.clone(),
}
}
}
impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for FlumeListener<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FlumeListener")
.field("stream", &self.stream)
.finish()
}
}
impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for FlumeListener<In, Out> {
type SendError = self::SendError;
type RecvError = self::RecvError;
type OpenError = self::OpenError;
type AcceptError = self::AcceptError;
}
type Socket<In, Out> = (self::SendSink<Out>, self::RecvStream<In>);
pub struct OpenFuture<In: RpcMessage, Out: RpcMessage> {
inner: flume::r#async::SendFut<'static, Socket<Out, In>>,
res: Option<Socket<In, Out>>,
}
impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for OpenFuture<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OpenFuture").finish()
}
}
impl<In: RpcMessage, Out: RpcMessage> OpenFuture<In, Out> {
fn new(inner: flume::r#async::SendFut<'static, Socket<Out, In>>, res: Socket<In, Out>) -> Self {
Self {
inner,
res: Some(res),
}
}
}
impl<In: RpcMessage, Out: RpcMessage> Future for OpenFuture<In, Out> {
type Output = result::Result<Socket<In, Out>, self::OpenError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
match Pin::new(&mut self.inner).poll(cx) {
Poll::Ready(Ok(())) => self
.res
.take()
.map(|x| Poll::Ready(Ok(x)))
.unwrap_or(Poll::Pending),
Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenError::RemoteDropped)),
Poll::Pending => Poll::Pending,
}
}
}
pub struct AcceptFuture<In: RpcMessage, Out: RpcMessage> {
wrapped: flume::r#async::RecvFut<'static, (SendSink<Out>, RecvStream<In>)>,
_p: PhantomData<(In, Out)>,
}
impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for AcceptFuture<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AcceptFuture").finish()
}
}
impl<In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<In, Out> {
type Output = result::Result<(SendSink<Out>, RecvStream<In>), AcceptError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.wrapped).poll(cx) {
Poll::Ready(Ok((send, recv))) => Poll::Ready(Ok((send, recv))),
Poll::Ready(Err(_)) => Poll::Ready(Err(AcceptError::RemoteDropped)),
Poll::Pending => Poll::Pending,
}
}
}
impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeListener<In, Out> {
type In = In;
type Out = Out;
type SendSink = SendSink<Out>;
type RecvStream = RecvStream<In>;
}
impl<In: RpcMessage, Out: RpcMessage> Listener for FlumeListener<In, Out> {
#[allow(refining_impl_trait)]
fn accept(&self) -> AcceptFuture<In, Out> {
AcceptFuture {
wrapped: self.stream.clone().into_recv_async(),
_p: PhantomData,
}
}
fn local_addr(&self) -> &[LocalAddr] {
&[LocalAddr::Mem]
}
}
impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for FlumeConnector<In, Out> {
type SendError = self::SendError;
type RecvError = self::RecvError;
type OpenError = self::OpenError;
type AcceptError = self::AcceptError;
}
impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeConnector<In, Out> {
type In = In;
type Out = Out;
type SendSink = SendSink<Out>;
type RecvStream = RecvStream<In>;
}
impl<In: RpcMessage, Out: RpcMessage> Connector for FlumeConnector<In, Out> {
#[allow(refining_impl_trait)]
fn open(&self) -> OpenFuture<In, Out> {
let (local_send, remote_recv) = flume::bounded::<Out>(128);
let (remote_send, local_recv) = flume::bounded::<In>(128);
let remote_chan = (
SendSink(remote_send.into_sink()),
RecvStream(remote_recv.into_stream()),
);
let local_chan = (
SendSink(local_send.into_sink()),
RecvStream(local_recv.into_stream()),
);
OpenFuture::new(self.sink.clone().into_send_async(remote_chan), local_chan)
}
}
pub struct FlumeConnector<In: RpcMessage, Out: RpcMessage> {
#[allow(clippy::type_complexity)]
sink: flume::Sender<(SendSink<In>, RecvStream<Out>)>,
}
impl<In: RpcMessage, Out: RpcMessage> Clone for FlumeConnector<In, Out> {
fn clone(&self) -> Self {
Self {
sink: self.sink.clone(),
}
}
}
impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for FlumeConnector<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FlumeClientChannel")
.field("sink", &self.sink)
.finish()
}
}
#[derive(Debug)]
pub enum AcceptError {
RemoteDropped,
}
impl fmt::Display for AcceptError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl error::Error for AcceptError {}
#[derive(Debug)]
pub enum SendError {
ReceiverDropped,
}
impl Display for SendError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl std::error::Error for SendError {}
#[derive(Debug)]
pub enum OpenError {
RemoteDropped,
}
impl Display for OpenError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl std::error::Error for OpenError {}
#[derive(Debug, Clone, Copy)]
pub enum CreateChannelError {}
impl Display for CreateChannelError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl std::error::Error for CreateChannelError {}
pub fn channel<Req: RpcMessage, Res: RpcMessage>(
buffer: usize,
) -> (FlumeListener<Req, Res>, FlumeConnector<Res, Req>) {
let (sink, stream) = flume::bounded(buffer);
(FlumeListener { stream }, FlumeConnector { sink })
}