use futures_lite::Stream;
use futures_sink::Sink;
use crate::{
transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint},
RpcMessage, Service,
};
use core::fmt;
use std::{error, fmt::Display, pin::Pin, result, task::Poll};
use super::ConnectionCommon;
/// Error type reported by the flume [`RecvStream`].
///
/// The enum has no variants, so no value of it can ever exist: the stream in
/// this module only ever yields `Ok` items.
#[derive(Debug)]
pub enum RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, _formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Uninhabited type: this body can never execute, and the empty match
        // lets the compiler prove it.
        match *self {}
    }
}
/// Sending half of a flume channel, exposed through the [`Sink`] trait.
pub struct SendSink<T: RpcMessage>(flume::r#async::SendSink<'static, T>);

impl<T: RpcMessage> fmt::Debug for SendSink<T> {
    /// Prints only the type name; the wrapped flume sink is not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = formatter.debug_struct("SendSink");
        out.finish()
    }
}
impl<T: RpcMessage> Sink<T> for SendSink<T> {
type Error = self::SendError;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_ready(cx)
.map_err(|_| SendError::ReceiverDropped)
}
fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
Pin::new(&mut self.0)
.start_send(item)
.map_err(|_| SendError::ReceiverDropped)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_flush(cx)
.map_err(|_| SendError::ReceiverDropped)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_close(cx)
.map_err(|_| SendError::ReceiverDropped)
}
}
/// Receiving half of a flume channel, exposed through the [`Stream`] trait.
pub struct RecvStream<T: RpcMessage>(flume::r#async::RecvStream<'static, T>);

impl<T: RpcMessage> fmt::Debug for RecvStream<T> {
    /// Prints only the type name; the wrapped flume stream is not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = formatter.debug_struct("RecvStream");
        out.finish()
    }
}
impl<T: RpcMessage> Stream for RecvStream<T> {
    type Item = result::Result<T, self::RecvError>;

    /// Polls the wrapped flume stream and wraps every received message in `Ok`.
    ///
    /// The stream ends (`None`) exactly when the flume stream ends, and it
    /// never yields an `Err`.
    fn poll_next(
        mut self: Pin<&mut Self>,
        ctx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.0);
        inner.poll_next(ctx).map(|next| next.map(Ok))
    }
}
impl error::Error for RecvError {}
/// Server side of an in-process flume transport.
///
/// Each item arriving on `stream` is a `(SendSink, RecvStream)` pair handed
/// over by a [`FlumeConnection`] when it opens a bidirectional stream.
pub struct FlumeServerEndpoint<S: Service> {
    // The nested generic tuple trips clippy's type-complexity lint; an alias
    // would only move the same type somewhere else.
    #[allow(clippy::type_complexity)]
    stream: flume::Receiver<(SendSink<S::Res>, RecvStream<S::Req>)>,
}

// Written by hand: `#[derive(Clone)]` would add an `S: Clone` bound that the
// service type may not satisfy.
impl<S: Service> Clone for FlumeServerEndpoint<S> {
    fn clone(&self) -> Self {
        let stream = self.stream.clone();
        Self { stream }
    }
}

impl<S: Service> fmt::Debug for FlumeServerEndpoint<S> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = formatter.debug_struct("FlumeServerEndpoint");
        out.field("stream", &self.stream);
        out.finish()
    }
}
// Error types surfaced by the server endpoint and by the streams it accepts.
impl<S: Service> ConnectionErrors for FlumeServerEndpoint<S> {
    type OpenError = self::AcceptBiError;
    type RecvError = self::RecvError;
    type SendError = self::SendError;
}

// The server receives requests and sends responses.
impl<S: Service> ConnectionCommon<S::Req, S::Res> for FlumeServerEndpoint<S> {
    type RecvStream = RecvStream<S::Req>;
    type SendSink = SendSink<S::Res>;
}
impl<S: Service> ServerEndpoint<S::Req, S::Res> for FlumeServerEndpoint<S> {
    /// Waits for the next `(SendSink, RecvStream)` pair sent by a client.
    ///
    /// # Errors
    ///
    /// Returns [`AcceptBiError::RemoteDropped`] whenever the flume receive
    /// fails. Per flume's docs, that happens once every sender is gone and the
    /// queue is empty.
    async fn accept_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> {
        let received = self.stream.recv_async().await;
        received.map_err(|_| AcceptBiError::RemoteDropped)
    }

    /// This transport always reports the single address `LocalAddr::Mem`.
    fn local_addr(&self) -> &[LocalAddr] {
        &[LocalAddr::Mem]
    }
}
// Error types surfaced by a client connection and by the streams it opens.
impl<S: Service> ConnectionErrors for FlumeConnection<S> {
    type OpenError = self::OpenBiError;
    type RecvError = self::RecvError;
    type SendError = self::SendError;
}

// The client receives responses and sends requests.
impl<S: Service> ConnectionCommon<S::Res, S::Req> for FlumeConnection<S> {
    type RecvStream = RecvStream<S::Res>;
    type SendSink = SendSink<S::Req>;
}
impl<S: Service> Connection<S::Res, S::Req> for FlumeConnection<S> {
    /// Opens a new bidirectional stream to the server endpoint.
    ///
    /// The method creates two bounded flume channels: one carries requests to
    /// the server, the other carries responses back. It sends the server's
    /// halves over `self.sink` and returns the client's halves.
    ///
    /// # Errors
    ///
    /// Returns [`OpenBiError::RemoteDropped`] if handing the server's halves
    /// to the endpoint fails.
    async fn open_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
        // Buffer depth of each per-stream channel. It is fixed and independent
        // of the `buffer` passed to `connection`.
        const STREAM_CAPACITY: usize = 128;

        let (request_tx, request_rx) = flume::bounded::<S::Req>(STREAM_CAPACITY);
        let (response_tx, response_rx) = flume::bounded::<S::Res>(STREAM_CAPACITY);

        // Server half: sends responses, receives requests.
        let server_half = (
            SendSink(response_tx.into_sink()),
            RecvStream(request_rx.into_stream()),
        );
        // Client half: sends requests, receives responses.
        let client_half = (
            SendSink(request_tx.into_sink()),
            RecvStream(response_rx.into_stream()),
        );

        let delivery = self.sink.send_async(server_half).await;
        delivery.map_err(|_| OpenBiError::RemoteDropped)?;
        Ok(client_half)
    }
}
/// Client side of an in-process flume transport.
///
/// Every call to `open_bi` sends a fresh `(SendSink, RecvStream)` pair to the
/// matching [`FlumeServerEndpoint`] through `sink`.
pub struct FlumeConnection<S: Service> {
    // The nested generic tuple trips clippy's type-complexity lint; an alias
    // would only move the same type somewhere else.
    #[allow(clippy::type_complexity)]
    sink: flume::Sender<(SendSink<S::Res>, RecvStream<S::Req>)>,
}

// Written by hand: `#[derive(Clone)]` would add an `S: Clone` bound that the
// service type may not satisfy.
impl<S: Service> Clone for FlumeConnection<S> {
    fn clone(&self) -> Self {
        Self {
            sink: self.sink.clone(),
        }
    }
}

impl<S: Service> fmt::Debug for FlumeConnection<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report the real type name so Debug output names a type that exists,
        // consistent with `FlumeServerEndpoint`'s Debug impl.
        f.debug_struct("FlumeConnection")
            .field("sink", &self.sink)
            .finish()
    }
}
/// Error returned by the flume server endpoint's `accept_bi`.
#[derive(Debug)]
pub enum AcceptBiError {
    /// Receiving the next stream pair from the client side failed.
    RemoteDropped,
}

impl fmt::Display for AcceptBiError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing text is the derived Debug text (the variant name).
        write!(f, "{self:?}")
    }
}

impl error::Error for AcceptBiError {}
/// Error returned when sending on a flume [`SendSink`] fails.
#[derive(Debug)]
pub enum SendError {
    /// The underlying flume sink reported an error; the receiving end is
    /// assumed to be gone.
    ReceiverDropped,
}

impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing text is the derived Debug text (the variant name).
        write!(f, "{self:?}")
    }
}

impl error::Error for SendError {}
/// Error returned by the flume client connection's `open_bi`.
#[derive(Debug)]
pub enum OpenBiError {
    /// Handing the new stream pair to the server endpoint failed.
    RemoteDropped,
}

impl fmt::Display for OpenBiError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing text is the derived Debug text (the variant name).
        write!(f, "{self:?}")
    }
}

impl error::Error for OpenBiError {}
/// Error type for creating a flume channel.
///
/// The enum has no variants, so no value of it can exist. Presumably channel
/// creation for this transport cannot fail — confirm at the use site.
#[derive(Debug, Clone, Copy)]
pub enum CreateChannelError {}

impl fmt::Display for CreateChannelError {
    fn fmt(&self, _formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Uninhabited type: this body can never execute.
        match *self {}
    }
}

impl error::Error for CreateChannelError {}
/// Creates a connected server endpoint / client connection pair.
///
/// `buffer` is the capacity of the bounded flume channel over which the
/// client hands newly opened streams to the server. It does not affect the
/// per-stream message buffers created by `open_bi`.
pub fn connection<S: Service>(buffer: usize) -> (FlumeServerEndpoint<S>, FlumeConnection<S>) {
    let (sink, stream) = flume::bounded(buffer);
    let server = FlumeServerEndpoint { stream };
    let client = FlumeConnection { sink };
    (server, client)
}