use std::{
error, fmt,
fmt::Debug,
pin::Pin,
task::{Context, Poll},
};
use futures_lite::Stream;
use futures_sink::Sink;
use pin_project::pin_project;
use super::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes};
/// A [`Connector`] built from up to two inner connectors of (possibly) different types.
///
/// Both inner connectors must agree on the message types (`In` / `Out`). When a channel
/// is opened, `a` is preferred; `b` is only used when `a` is `None`.
#[derive(Clone, Debug)]
pub struct CombinedConnector<A, B> {
    /// Preferred connector; used whenever it is present.
    pub a: Option<A>,
    /// Fallback connector; used only when `a` is `None`.
    pub b: Option<B>,
}

impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> CombinedConnector<A, B> {
    /// Builds a combined connector from two optional inner connectors.
    ///
    /// Passing `None` for both is accepted here; opening a channel on such a connector
    /// then fails with [`OpenError::NoChannel`].
    pub fn new(a: Option<A>, b: Option<B>) -> Self {
        CombinedConnector { a, b }
    }
}
/// A [`Listener`] that accepts channels from up to two inner listeners.
///
/// Both inner listeners must agree on the message types (`In` / `Out`).
#[derive(Clone, Debug)]
pub struct CombinedListener<A, B> {
    /// First inner listener, if any.
    pub a: Option<A>,
    /// Second inner listener, if any.
    pub b: Option<B>,
    /// Local addresses of `a` followed by those of `b`, captured once in `new`.
    ///
    /// NOTE(review): `a` and `b` are public, so reassigning them after construction
    /// leaves this cached list stale.
    local_addr: Vec<LocalAddr>,
}

impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> CombinedListener<A, B> {
    /// Builds a combined listener and snapshots the local addresses of both sides.
    pub fn new(a: Option<A>, b: Option<B>) -> Self {
        let from_a = a.iter().flat_map(|inner| inner.local_addr().iter().cloned());
        let from_b = b.iter().flat_map(|inner| inner.local_addr().iter().cloned());
        let local_addr = from_a.chain(from_b).collect();
        Self { a, b, local_addr }
    }

    /// Consumes the combined listener and hands back the two inner listeners.
    pub fn into_inner(self) -> (Option<A>, Option<B>) {
        let Self { a, b, .. } = self;
        (a, b)
    }
}
#[pin_project(project = SendSinkProj)]
pub enum SendSink<A: StreamTypes, B: StreamTypes> {
A(#[pin] A::SendSink),
B(#[pin] B::SendSink),
}
impl<A: StreamTypes, B: StreamTypes<In = A::In, Out = A::Out>> Sink<A::Out> for SendSink<A, B> {
type Error = self::SendError<A, B>;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.project() {
SendSinkProj::A(sink) => sink.poll_ready(cx).map_err(Self::Error::A),
SendSinkProj::B(sink) => sink.poll_ready(cx).map_err(Self::Error::B),
}
}
fn start_send(self: Pin<&mut Self>, item: A::Out) -> Result<(), Self::Error> {
match self.project() {
SendSinkProj::A(sink) => sink.start_send(item).map_err(Self::Error::A),
SendSinkProj::B(sink) => sink.start_send(item).map_err(Self::Error::B),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.project() {
SendSinkProj::A(sink) => sink.poll_flush(cx).map_err(Self::Error::A),
SendSinkProj::B(sink) => sink.poll_flush(cx).map_err(Self::Error::B),
}
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.project() {
SendSinkProj::A(sink) => sink.poll_close(cx).map_err(Self::Error::A),
SendSinkProj::B(sink) => sink.poll_close(cx).map_err(Self::Error::B),
}
}
}
/// Receiving half of a combined channel; forwards to the stream of whichever inner
/// transport the channel was opened on.
#[pin_project(project = ResStreamProj)]
pub enum RecvStream<A: StreamTypes, B: StreamTypes> {
    /// Receive stream of the `a` transport.
    A(#[pin] A::RecvStream),
    /// Receive stream of the `b` transport.
    B(#[pin] B::RecvStream),
}

impl<A: StreamTypes, B: StreamTypes<In = A::In, Out = A::Out>> Stream for RecvStream<A, B> {
    type Item = Result<A::In, RecvError<A, B>>;

    /// Polls the active inner stream; any error is wrapped in the matching `RecvError` side.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.project() {
            ResStreamProj::A(inner) => inner.poll_next(cx).map_err(RecvError::<A, B>::A),
            ResStreamProj::B(inner) => inner.poll_next(cx).map_err(RecvError::<A, B>::B),
        }
    }
}
/// Error raised while sending on a combined channel, tagged with the transport it came from.
#[derive(Debug)]
pub enum SendError<A: ConnectionErrors, B: ConnectionErrors> {
    /// Send error from the `a` transport.
    A(A::SendError),
    /// Send error from the `b` transport.
    B(B::SendError),
}

impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for SendError<A, B> {
    /// Displays the `Debug` representation; formatter flags are passed through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for SendError<A, B> {}
/// Error raised while receiving on a combined channel, tagged with the transport it came from.
#[derive(Debug)]
pub enum RecvError<A: ConnectionErrors, B: ConnectionErrors> {
    /// Receive error from the `a` transport.
    A(A::RecvError),
    /// Receive error from the `b` transport.
    B(B::RecvError),
}

impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for RecvError<A, B> {
    /// Displays the `Debug` representation; formatter flags are passed through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for RecvError<A, B> {}
/// Error raised while opening a combined channel.
#[derive(Debug)]
pub enum OpenError<A: ConnectionErrors, B: ConnectionErrors> {
    /// Open error from the `a` transport.
    A(A::OpenError),
    /// Open error from the `b` transport.
    B(B::OpenError),
    /// Neither inner transport is configured, so there is nothing to open a channel on.
    NoChannel,
}

impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for OpenError<A, B> {
    /// Displays the `Debug` representation; formatter flags are passed through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for OpenError<A, B> {}
/// Error raised while accepting a combined channel, tagged with the transport it came from.
#[derive(Debug)]
pub enum AcceptError<A: ConnectionErrors, B: ConnectionErrors> {
    /// Accept error from the `a` transport.
    A(A::AcceptError),
    /// Accept error from the `b` transport.
    B(B::AcceptError),
}

impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for AcceptError<A, B> {
    /// Displays the `Debug` representation; formatter flags are passed through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for AcceptError<A, B> {}
/// Error types of a [`CombinedConnector`]: each one wraps the corresponding error of
/// whichever inner connector produced it.
impl<A: ConnectionErrors, B: ConnectionErrors> ConnectionErrors for CombinedConnector<A, B> {
    type OpenError = self::OpenError<A, B>;
    type AcceptError = self::AcceptError<A, B>;
    type SendError = self::SendError<A, B>;
    type RecvError = self::RecvError<A, B>;
}

/// Stream types of a [`CombinedConnector`]. `In` / `Out` are taken from `A`, which `B`
/// is required to match.
impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> StreamTypes for CombinedConnector<A, B> {
    type In = A::In;
    type Out = A::Out;
    type SendSink = self::SendSink<A, B>;
    type RecvStream = self::RecvStream<A, B>;
}
impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> Connector for CombinedConnector<A, B> {
    /// Opens a new channel on the preferred inner connector.
    ///
    /// `a` always takes precedence; `b` is used only when `a` is `None`. The inner
    /// connectors are borrowed, not cloned: only one of them is ever used, so copying the
    /// whole combined connector on every call would be wasted work.
    ///
    /// # Errors
    ///
    /// * [`OpenError::A`] / [`OpenError::B`] if the selected inner connector fails to open.
    ///   A failure on `a` is returned as-is; `b` is not tried as a fallback.
    /// * [`OpenError::NoChannel`] if neither connector is configured.
    async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
        if let Some(a) = &self.a {
            let (send, recv) = a.open().await.map_err(OpenError::A)?;
            Ok((SendSink::A(send), RecvStream::A(recv)))
        } else if let Some(b) = &self.b {
            let (send, recv) = b.open().await.map_err(OpenError::B)?;
            Ok((SendSink::B(send), RecvStream::B(recv)))
        } else {
            Err(OpenError::NoChannel)
        }
    }
}
/// Error types of a [`CombinedListener`]: each one wraps the corresponding error of
/// whichever inner listener produced it.
impl<A: ConnectionErrors, B: ConnectionErrors> ConnectionErrors for CombinedListener<A, B> {
    type OpenError = self::OpenError<A, B>;
    type AcceptError = self::AcceptError<A, B>;
    type SendError = self::SendError<A, B>;
    type RecvError = self::RecvError<A, B>;
}

/// Stream types of a [`CombinedListener`]. `In` / `Out` are taken from `A`, which `B`
/// is required to match.
impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> StreamTypes for CombinedListener<A, B> {
    type In = A::In;
    type Out = A::Out;
    type SendSink = self::SendSink<A, B>;
    type RecvStream = self::RecvStream<A, B>;
}
impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> Listener for CombinedListener<A, B> {
    /// Accepts the next incoming channel from whichever inner listener yields one first.
    ///
    /// Both inner `accept` futures are raced with `tokio::select!`. The first to complete
    /// wins, even if it completes with an error; that error is returned as
    /// [`AcceptError::A`] / [`AcceptError::B`] without waiting on the other side. The
    /// losing future is dropped.
    ///
    /// A side that is `None` is represented by a never-completing future. Consequently, if
    /// both `a` and `b` are `None`, the returned future never resolves.
    async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
        let a_fut = async {
            if let Some(a) = &self.a {
                let (send, recv) = a.accept().await.map_err(AcceptError::A)?;
                Ok((SendSink::A(send), RecvStream::A(recv)))
            } else {
                // Disable this branch of the race: an absent listener never accepts.
                std::future::pending().await
            }
        };
        let b_fut = async {
            if let Some(b) = &self.b {
                let (send, recv) = b.accept().await.map_err(AcceptError::B)?;
                Ok((SendSink::B(send), RecvStream::B(recv)))
            } else {
                // Disable this branch of the race: an absent listener never accepts.
                std::future::pending().await
            }
        };
        // `select!` is already an expression usable in this async fn; no extra
        // `async move { .. }.await` wrapper is needed.
        tokio::select! {
            res = a_fut => res,
            res = b_fut => res,
        }
    }

    /// Local addresses of both inner listeners, as captured when this listener was built.
    fn local_addr(&self) -> &[LocalAddr] {
        &self.local_addr
    }
}
#[cfg(test)]
#[cfg(feature = "flume-transport")]
mod tests {
    use crate::transport::{
        combined::{CombinedConnector, OpenError},
        flume::FlumeConnector,
        Connector,
    };

    /// A combined connector with neither side configured must refuse to open a channel.
    #[tokio::test]
    async fn open_empty_channel() {
        type Unit = FlumeConnector<(), ()>;

        let connector: CombinedConnector<Unit, Unit> = CombinedConnector::new(None, None);
        let outcome = connector.open().await;

        assert!(matches!(outcome, Err(OpenError::NoChannel)));
    }
}