use crate::mediator;
use crate::wrappers::*;
use crate::Error;
use crate::MakeTransport;
use futures_core::{
future::Future,
ready,
stream::TryStream,
task::{Context, Poll},
};
use futures_sink::Sink;
use pin_project::pin_project;
use std::collections::VecDeque;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{atomic, Arc};
use std::{error, fmt};
use tower_service::Service;
/// A factory that wraps a transport constructor `NT` so that freshly made transports can be
/// turned into [`Client`]s.
///
/// `Request` is carried only at the type level (through `PhantomData<fn(Request)>`); the maker
/// never stores a request value.
pub struct Maker<NT, Request> {
    /// The wrapped transport constructor.
    t_maker: NT,
    /// Zero-sized marker that ties this maker to its request type.
    _req: PhantomData<fn(Request)>,
}

impl<NT, Request> Maker<NT, Request> {
    /// Wraps the transport constructor `t` in a new `Maker`.
    pub fn new(t: NT) -> Self {
        Self {
            _req: PhantomData,
            t_maker: t,
        }
    }
}

impl<NT, Request> fmt::Debug for Maker<NT, Request>
where
    NT: fmt::Debug,
{
    /// Renders the maker as `Maker { t_maker: .. }`; the phantom marker is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Maker");
        repr.field("t_maker", &self.t_maker);
        repr.finish()
    }
}
/// Error returned when a new client could not be created.
#[derive(Debug)]
pub enum SpawnError<E> {
/// Spawning the client failed.
///
/// NOTE(review): no code visible in this file constructs this variant.
SpawnFailed,
/// Creating the underlying transport failed with the wrapped error.
Inner(E),
}
impl<NT, Target, Request> Service<Target> for Maker<NT, Request>
where
NT: MakeTransport<Target, Request>,
NT::Transport: 'static + Send,
Request: 'static + Send,
NT::Item: 'static + Send,
NT::SinkError: 'static + Send + Sync,
NT::Error: 'static + Send + Sync,
NT::Future: 'static + Send,
{
type Error = SpawnError<NT::MakeError>;
type Response = Client<NT::Transport, Error<NT::Transport, Request>, Request>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&mut self, target: Target) -> Self::Future {
let maker = self.t_maker.make_transport(target);
Box::pin(async move { Ok(Client::new(maker.await.map_err(SpawnError::Inner)?)) })
}
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.t_maker.poll_ready(cx).map_err(SpawnError::Inner)
}
}
/// Load reporting for `Maker`: the metric is a constant.
impl<NT, Request> tower_load::Load for Maker<NT, Request> {
    type Metric = u8;

    /// Always reports a load of zero.
    fn load(&self) -> Self::Metric {
        u8::MIN
    }
}
/// Handle used to issue requests over a transport `T`.
///
/// Requests are handed to a spawned worker through `mediator`; the `in_flight` counter is shared
/// with that worker, which increments it when a request is sent and decrements it when the
/// matching response arrives.
pub struct Client<T, E, Request>
where
T: Sink<Request> + TryStream,
{
// Sending half of the channel that carries requests to the worker.
mediator: mediator::Sender<ClientRequest<T, Request>>,
// Number of requests sent but not yet answered; also exposed through `Load::load`.
in_flight: Arc<atomic::AtomicUsize>,
// Type-level marker for the error type `E`; holds no data.
_error: PhantomData<fn(E)>,
}
impl<T, E, Request> fmt::Debug for Client<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    /// Renders the client with its `mediator` and `in_flight` fields; the phantom error marker
    /// is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Client");
        repr.field("mediator", &self.mediator);
        repr.field("in_flight", &self.in_flight);
        repr.finish()
    }
}
// Bookkeeping for one request that has been written to the transport and is waiting for its
// response.
struct Pending<Item> {
// One-shot channel on which the response is delivered back to the caller's future.
tx: tokio::sync::oneshot::Sender<ClientResponse<Item>>,
// Tracing span of the originating request; reused when the response is forwarded.
span: tracing::Span,
}
// State of the worker future that owns the transport on behalf of a `Client`.
#[pin_project]
struct ClientInner<T, E, Request>
where
T: Sink<Request> + TryStream,
{
// Receiving half of the channel on which requests arrive.
mediator: mediator::Receiver<ClientRequest<T, Request>>,
// Requests already sent, oldest first; a response is matched with the front entry.
responses: VecDeque<Pending<T::Ok>>,
// The transport itself; structurally pinned so its poll methods can be called.
#[pin]
transport: T,
// Count of sent-but-unanswered requests, shared with the `Client` handle.
in_flight: Arc<atomic::AtomicUsize>,
// Set once the request channel yields `None`; no further requests are sent afterwards.
finish: bool,
// Set once closing the sink has completed; from then on only responses are read.
rx_only: bool,
// Never read: exists only to carry the error type parameter `E`.
#[allow(unused)]
error: PhantomData<fn(E)>,
}
impl<T, E, Request> Client<T, E, Request>
where
T: Sink<Request> + TryStream + Send + 'static,
E: From<Error<T, Request>>,
E: 'static + Send,
Request: 'static + Send,
T::Ok: 'static + Send,
{
pub fn new(transport: T) -> Self where {
Self::with_error_handler(transport, |_| {})
}
pub fn with_error_handler<F>(transport: T, on_service_error: F) -> Self
where
F: FnOnce(E) + Send + 'static,
{
let (tx, rx) = mediator::new();
let in_flight = Arc::new(atomic::AtomicUsize::new(0));
tokio::spawn({
let c = ClientInner {
mediator: rx,
responses: Default::default(),
transport,
in_flight: in_flight.clone(),
error: PhantomData::<fn(E)>,
finish: false,
rx_only: false,
};
async move {
if let Err(e) = c.await {
on_service_error(e);
}
}
});
Client {
mediator: tx,
in_flight,
_error: PhantomData,
}
}
}
// Worker future for a `Client`. Each poll runs three phases in order:
//   1. forward queued requests into the sink half of the transport,
//   2. flush (or, when shutting down, close) the sink,
//   3. read responses from the stream half and hand each one to the oldest pending request.
//
// It resolves to `Ok(())` once the request channel has ended, no requests remain in flight and
// the sink has been closed. It resolves to `Err` on any sink or stream error, when a response
// arrives while nothing is pending (`Error::Desynchronized`), or when the stream ends while
// responses are still expected (`Error::BrokenTransportRecv`).
impl<T, E, Request> Future for ClientInner<T, E, Request>
where
T: Sink<Request> + TryStream,
E: From<Error<T, Request>>,
E: 'static + Send,
Request: 'static + Send,
T::Ok: 'static + Send,
{
type Output = Result<(), E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Project the pin so the individual fields can be borrowed independently.
let this = self.project();
// The transport stays pinned; every call below re-borrows it through `as_mut()`.
let mut transport: Pin<_> = this.transport;
// Number of requests forwarded during this poll; bounds the send loop below.
let mut i = 0;
// Phase 1: only accept new requests while the request channel is still open.
if !*this.finish {
// Keep going for as long as the sink reports it can take another item.
while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
// A readiness error from the sink ends the worker.
if let Err(e) = r {
return Poll::Ready(Err(E::from(Error::from_sink_error(e))));
}
// The sink has room: see whether a request is waiting.
match this.mediator.try_recv(cx) {
Poll::Ready(Some(ClientRequest {
req,
span: _span,
res,
})) => {
// Enter the request's span so the trace events below are recorded inside it.
let guard = _span.enter();
tracing::trace!("request received by worker; sending to Sink");
transport
.as_mut()
.start_send(req)
.map_err(Error::from_sink_error)?;
tracing::trace!("request sent");
// Leave the span before moving it into the pending entry.
drop(guard);
// Remember who is waiting; responses are matched back in FIFO order.
this.responses.push_back(Pending {
tx: res,
span: _span,
});
this.in_flight.fetch_add(1, atomic::Ordering::AcqRel);
i += 1;
// After YIELD_EVERY forwarded requests, stop sending for this poll. The waker is
// triggered explicitly so the task is polled again, and `break` (rather than
// `return`) lets the flush and receive phases below still run.
if i == crate::YIELD_EVERY {
cx.waker().wake_by_ref();
break;
}
}
Poll::Ready(None) => {
// The request channel has ended: nothing more will be sent, begin shutting down.
*this.finish = true;
break;
}
Poll::Pending => {
// No request queued right now.
break;
}
}
}
}
// Phase 2: with requests outstanding and the sink not yet closed, push buffered data out.
// A `Pending` result is deliberately not returned here, because responses still have to be
// polled below; only errors are propagated (by `?`).
if this.in_flight.load(atomic::Ordering::Acquire) != 0 && !*this.rx_only {
if *this.finish {
// Shutting down: close the sink instead of merely flushing it.
let r = transport
.as_mut()
.poll_close(cx)
.map_err(Error::from_sink_error)?;
if r.is_ready() {
// Close has completed; from now on the worker only receives.
*this.rx_only = true;
}
} else {
let _ = transport
.as_mut()
.poll_flush(cx)
.map_err(Error::from_sink_error)?;
}
}
// Phase 3: poll the stream only while responses are actually expected.
while this.in_flight.load(atomic::Ordering::Acquire) != 0 {
// `ready!` returns `Poll::Pending` from `poll` when the stream has nothing yet.
match ready!(transport.as_mut().try_poll_next(cx))
.transpose()
.map_err(Error::from_stream_error)?
{
Some(r) => {
// A response with no pending request means the two sides are out of step.
let pending = this.responses.pop_front().ok_or(Error::Desynchronized)?;
tracing::trace!(parent: &pending.span, "response arrived; forwarding");
let sender = pending.tx;
// A failed send is ignored: the receiving end of the one-shot may already be gone.
let _ = sender.send(ClientResponse {
response: r,
span: pending.span,
});
this.in_flight.fetch_sub(1, atomic::Ordering::AcqRel);
}
None => {
// The stream ended while responses were still expected.
return Poll::Ready(Err(E::from(Error::BrokenTransportRecv(None))));
}
}
}
// The request channel has ended and every request has been answered: finish up.
if *this.finish && this.in_flight.load(atomic::Ordering::Acquire) == 0 {
if *this.rx_only {
// The sink was already closed in phase 2; nothing left to do.
} else {
// Wait for the close to complete before resolving.
ready!(transport.poll_close(cx)).map_err(Error::from_sink_error)?;
}
return Poll::Ready(Ok(()));
}
// Reaching this point means nothing is in flight and the request channel is still open.
// The send loop above stopped either on a `Pending` from `poll_ready`/`try_recv` or after the
// explicit `wake_by_ref`, so the task relies on one of those to be polled again.
Poll::Pending
}
}
impl<T, E, Request> Service<Request> for Client<T, E, Request>
where
T: Sink<Request> + TryStream,
E: From<Error<T, Request>>,
E: 'static + Send,
Request: 'static + Send,
T: 'static,
T::Ok: 'static + Send,
{
type Response = T::Ok;
type Error = E;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
Poll::Ready(ready!(self.mediator.poll_ready(cx)).map_err(|_| E::from(Error::ClientDropped)))
}
fn call(&mut self, req: Request) -> Self::Future {
let (tx, rx) = tokio::sync::oneshot::channel();
let span = tracing::Span::current();
tracing::trace!("issuing request");
let req = ClientRequest { req, span, res: tx };
let r = self.mediator.try_send(req);
Box::pin(async move {
match r {
Ok(()) => match rx.await {
Ok(r) => {
tracing::trace!(parent: &r.span, "response returned");
Ok(r.response)
}
Err(_) => Err(E::from(Error::ClientDropped)),
},
Err(_) => Err(E::from(Error::TransportFull)),
}
})
}
}
/// Load reporting for `Client`, based on the shared in-flight counter.
impl<T, E, Request> tower_load::Load for Client<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    type Metric = usize;

    /// Returns the current value of the in-flight request counter.
    fn load(&self) -> Self::Metric {
        let outstanding: &atomic::AtomicUsize = &self.in_flight;
        outstanding.load(atomic::Ordering::Acquire)
    }
}
impl<T> fmt::Display for SpawnError<T>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
SpawnError::SpawnFailed => write!(f, "error spawning multiplex client"),
SpawnError::Inner(ref te) => {
write!(f, "error making new multiplex transport: {:?}", te)
}
}
}
}
impl<T> error::Error for SpawnError<T>
where
T: error::Error,
{
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
SpawnError::SpawnFailed => None,
SpawnError::Inner(ref te) => Some(te),
}
}
fn description(&self) -> &str {
"error creating new multiplex client"
}
}