use std::{
error,
fmt::{self, Debug},
marker::PhantomData,
pin::Pin,
result,
sync::Arc,
task::{self, Poll},
};
use futures_lite::{Future, Stream, StreamExt};
use futures_util::{SinkExt, TryStreamExt};
use pin_project::pin_project;
use tokio::{sync::oneshot, task::JoinSet};
use tokio_util::task::AbortOnDropHandle;
use tracing::{error, warn};
use crate::{
transport::{
self,
boxed::BoxableListener,
mapped::{ErrorOrMapError, MappedRecvStream, MappedSendSink, MappedStreamTypes},
ConnectionErrors, StreamTypes,
},
Listener, RpcMessage, Service,
};
/// Stream types suitable for serving service `S`: inbound messages are the
/// service's requests (`S::Req`) and outbound messages are its responses
/// (`S::Res`).
pub trait ChannelTypes<S: Service>: transport::StreamTypes<In = S::Req, Out = S::Res> {}

/// Any [`transport::StreamTypes`] whose `In`/`Out` match `S::Req`/`S::Res` is
/// automatically a [`ChannelTypes`] for `S`.
impl<S, T> ChannelTypes<S> for T
where
    S: Service,
    T: transport::StreamTypes<In = S::Req, Out = S::Res>,
{
}
pub type BoxedChannelTypes<S> = crate::transport::boxed::BoxedStreamTypes<
<S as crate::Service>::Req,
<S as crate::Service>::Res,
>;
pub type BoxedListener<S> =
crate::transport::boxed::BoxedListener<<S as crate::Service>::Req, <S as crate::Service>::Res>;
/// An RPC server for service `S` that accepts connections from the listener `C`.
///
/// `C` defaults to the type-erased [`BoxedListener`] for `S`.
#[derive(Debug)]
pub struct RpcServer<S, C = BoxedListener<S>> {
    // The listener that incoming connections are accepted from.
    source: C,
    _p: PhantomData<S>,
}

// Implemented by hand so that cloning only requires `C: Clone`;
// `#[derive(Clone)]` would additionally require `S: Clone`.
impl<S, C> Clone for RpcServer<S, C>
where
    C: Clone,
{
    fn clone(&self) -> Self {
        let source = C::clone(&self.source);
        Self {
            _p: PhantomData,
            source,
        }
    }
}
impl<S: Service, C: Listener<S>> RpcServer<S, C> {
    /// Creates a server that accepts connections from `source`.
    pub fn new(source: C) -> Self {
        Self {
            _p: PhantomData,
            source,
        }
    }

    /// Converts this server into one over the type-erased [`BoxedListener`].
    pub fn boxed(self) -> RpcServer<S, BoxedListener<S>>
    where
        C: BoxableListener<S::Req, S::Res>,
    {
        let erased = self.source.boxed();
        RpcServer::new(erased)
    }
}
/// A bidirectional channel for service `S`: a sink for outgoing messages and a
/// stream of incoming messages.
///
/// `C` defaults to the type-erased [`BoxedChannelTypes`] for `S`.
#[derive(Debug)]
pub struct RpcChannel<S: Service, C: ChannelTypes<S> = BoxedChannelTypes<S>> {
    /// Sink for outgoing messages (`C::Out`, which `ChannelTypes<S>` fixes to `S::Res`).
    pub send: C::SendSink,
    /// Stream of incoming messages (`C::In`, i.e. `S::Req`); items may be receive errors.
    pub recv: C::RecvStream,
    pub(crate) _p: PhantomData<S>,
}
impl<S, C> RpcChannel<S, C>
where
    S: Service,
    C: StreamTypes<In = S::Req, Out = S::Res>,
{
    /// Builds a channel from an already split send sink and receive stream.
    pub fn new(send: C::SendSink, recv: C::RecvStream) -> Self {
        Self {
            _p: PhantomData,
            recv,
            send,
        }
    }

    /// Erases the concrete transport types, producing a channel over
    /// [`BoxedChannelTypes`].
    ///
    /// Send and receive errors are converted with `Into`; the bounds require
    /// them to convert into `anyhow::Error`.
    pub fn boxed(self) -> RpcChannel<S, BoxedChannelTypes<S>>
    where
        C::SendError: Into<anyhow::Error> + Send + Sync + 'static,
        C::RecvError: Into<anyhow::Error> + Send + Sync + 'static,
    {
        let Self { send, recv, .. } = self;
        let erased_send = send.sink_map_err(|err| err.into());
        let erased_recv = recv.map_err(|err| err.into());
        RpcChannel::new(
            transport::boxed::SendSink::boxed(Box::new(erased_send)),
            transport::boxed::RecvStream::boxed(Box::new(erased_recv)),
        )
    }

    /// Re-types this channel for service `SNext`.
    ///
    /// Incoming `S::Req` values are converted into `SNext::Req` via `TryFrom`,
    /// and outgoing `SNext::Res` values are converted into `S::Res` via `From`.
    pub fn map<SNext>(self) -> RpcChannel<SNext, MappedStreamTypes<SNext::Req, SNext::Res, C>>
    where
        SNext: Service,
        SNext::Req: TryFrom<S::Req>,
        S::Res: From<SNext::Res>,
    {
        let Self { send, recv, .. } = self;
        let mapped_send = MappedSendSink::new(send);
        let mapped_recv = MappedRecvStream::new(recv);
        RpcChannel::new(mapped_send, mapped_recv)
    }
}
/// A freshly accepted connection whose first request has not been read yet.
///
/// Returned by [`RpcServer::accept`]. Call [`Accepting::read_first`] to
/// receive the initial request together with the channel.
pub struct Accepting<S: Service, C: Listener<S>> {
    send: C::SendSink,
    recv: C::RecvStream,
    _p: PhantomData<S>,
}

impl<S: Service, C: Listener<S>> Accepting<S, C> {
    /// Reads the first request from the connection and returns it together with
    /// an [`RpcChannel`] for the rest of the interaction.
    ///
    /// # Errors
    ///
    /// - [`RpcServerError::EarlyClose`] if the stream ends before any message.
    /// - [`RpcServerError::RecvError`] if receiving the first message fails.
    pub async fn read_first(self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
        let Self { send, mut recv, .. } = self;
        let first: S::Req = match recv.next().await {
            None => return Err(RpcServerError::EarlyClose),
            Some(received) => received.map_err(RpcServerError::RecvError)?,
        };
        let channel = RpcChannel::<S, C>::new(send, recv);
        Ok((first, channel))
    }
}
impl<S: Service, C: Listener<S>> RpcServer<S, C> {
    /// Waits for the listener to accept the next connection.
    ///
    /// Nothing is read from the connection yet. Call
    /// [`Accepting::read_first`] on the result to obtain the first request.
    ///
    /// # Errors
    ///
    /// Returns [`RpcServerError::Accept`] if the listener's `accept` fails.
    pub async fn accept(&self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
        match self.source.accept().await {
            Ok((send, recv)) => Ok(Accepting {
                send,
                recv,
                _p: PhantomData,
            }),
            Err(cause) => Err(RpcServerError::Accept(cause)),
        }
    }

    /// Consumes the server and returns the underlying listener.
    pub fn into_inner(self) -> C {
        let Self { source, .. } = self;
        source
    }

    /// Accepts connections forever, handling each one on its own task.
    ///
    /// Each accepted connection is spawned onto an internal [`JoinSet`]. The task
    /// reads the first request and passes it, with the channel, to `handler`.
    /// The following are logged with `warn!`:
    /// - accept failures
    /// - first-message failures
    /// - handler errors
    ///
    /// Panicked handler tasks are logged with `error!`. The loop never returns.
    /// Dropping this future drops the `JoinSet`, which aborts in-flight tasks.
    ///
    /// NOTE(review): an `accept` error is retried immediately with no delay. A
    /// listener that fails persistently will therefore make this loop spin.
    pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
    where
        S: Service,
        C: Listener<S>,
        Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        E: Into<anyhow::Error> + 'static,
    {
        let handler = Arc::new(handler);
        let mut tasks = JoinSet::new();
        loop {
            tokio::select! {
                // Reap finished handler tasks; only panics are reported.
                Some(res) = tasks.join_next(), if !tasks.is_empty() => {
                    match res {
                        Err(e) if e.is_panic() => {
                            error!("Panic handling RPC request: {e}");
                        }
                        _ => {}
                    }
                }
                accepted = self.accept() => {
                    match accepted {
                        Ok(accepting) => {
                            let handler = Arc::clone(&handler);
                            tasks.spawn(async move {
                                match accepting.read_first().await {
                                    Err(e) => {
                                        warn!("Error reading first message: {e}");
                                    }
                                    Ok((req, chan)) => {
                                        if let Err(cause) = handler(req, chan).await {
                                            warn!("Error handling RPC request: {}", cause.into());
                                        }
                                    }
                                }
                            });
                        }
                        Err(e) => {
                            warn!("Error accepting RPC request: {e}");
                        }
                    }
                }
            }
        }
    }

    /// Spawns [`RpcServer::accept_loop`] onto the tokio runtime.
    ///
    /// Dropping the returned [`AbortOnDropHandle`] aborts the loop task.
    pub fn spawn_accept_loop<Fun, Fut, E>(self, handler: Fun) -> AbortOnDropHandle<()>
    where
        S: Service,
        C: Listener<S>,
        Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        E: Into<anyhow::Error> + 'static,
    {
        let join_handle = tokio::spawn(self.accept_loop(handler));
        AbortOnDropHandle::new(join_handle)
    }
}
/// Borrows the listener this server accepts connections from.
impl<S: Service, C: Listener<S>> AsRef<C> for RpcServer<S, C> {
    fn as_ref(&self) -> &C {
        let Self { source, .. } = self;
        source
    }
}
/// A stream of update messages of type `T`, converted from the raw incoming
/// messages (`C::In`) of a connection.
///
/// Failures are not yielded as items. The first receive error or failed
/// conversion is sent once through the error future returned by
/// [`UpdateStream::new`].
#[pin_project]
#[derive(Debug)]
pub struct UpdateStream<C, T>(
    // The raw receive stream being adapted.
    #[pin] C::RecvStream,
    // Reports the first failure; taken (set to `None`) once used.
    Option<oneshot::Sender<RpcServerError<C>>>,
    PhantomData<T>,
)
where
    C: StreamTypes;

impl<C, T> UpdateStream<C, T>
where
    C: StreamTypes,
    T: TryFrom<C::In>,
{
    /// Wraps `recv` and returns the stream together with a future.
    ///
    /// The future resolves to the first error the stream encounters. If the
    /// stream is dropped without an error, the future never resolves.
    pub(crate) fn new(recv: C::RecvStream) -> (Self, UnwrapToPending<RpcServerError<C>>) {
        let (tx, rx) = oneshot::channel();
        let stream = Self(recv, Some(tx), PhantomData);
        (stream, UnwrapToPending(rx))
    }
}
/// Yields each incoming message converted into `T`.
///
/// A receive error or a failed `T::try_from` conversion is not yielded. The
/// first such failure is sent through the error channel, and the stream then
/// returns `Poll::Pending` for that poll instead of yielding an item or ending.
impl<C, T> Stream for UpdateStream<C, T>
where
    C: StreamTypes,
    T: TryFrom<C::In>,
{
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        let failure = match this.0.as_mut().poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Ready(Some(Ok(raw))) => match T::try_from(raw) {
                Ok(update) => return Poll::Ready(Some(update)),
                Err(_) => RpcServerError::UnexpectedUpdateMessage,
            },
            Poll::Ready(Some(Err(cause))) => RpcServerError::RecvError(cause),
        };
        // Only the first failure is reported; afterwards the sender is gone.
        if let Some(error_tx) = this.1.take() {
            // Ignore a closed receiver: nobody is waiting for the error.
            let _ = error_tx.send(failure);
        }
        // The inner stream returned Ready, so no wake-up is scheduled here.
        // The error future is expected to resolve instead.
        Poll::Pending
    }
}
/// Errors that can occur while serving RPC requests over connection type `C`.
pub enum RpcServerError<C: ConnectionErrors> {
    /// The listener failed to accept a connection.
    Accept(C::AcceptError),
    /// The receive stream ended before the first message was read.
    EarlyClose,
    /// NOTE(review): presumably the first message was not the expected request
    /// type; it is not constructed in the visible code. Confirm with callers.
    UnexpectedStartMessage,
    /// Receiving a message from the connection failed.
    RecvError(C::RecvError),
    /// Wraps a send error of the connection type `C`.
    SendError(C::SendError),
    /// An incoming update message could not be converted to the expected type.
    UnexpectedUpdateMessage,
}
impl<In: RpcMessage, Out: RpcMessage, C: ConnectionErrors>
    RpcServerError<MappedStreamTypes<In, Out, C>>
{
    /// Converts an error from a mapped channel back into the error type of the
    /// underlying connection `C`.
    ///
    /// A conversion failure on the receive side becomes
    /// [`RpcServerError::UnexpectedUpdateMessage`]. Every other variant is
    /// carried over unchanged.
    pub fn map_back(self) -> RpcServerError<C> {
        match self {
            Self::Accept(err) => RpcServerError::Accept(err),
            Self::SendError(err) => RpcServerError::SendError(err),
            Self::RecvError(ErrorOrMapError::Inner(err)) => RpcServerError::RecvError(err),
            Self::RecvError(ErrorOrMapError::Conversion) => {
                RpcServerError::UnexpectedUpdateMessage
            }
            Self::EarlyClose => RpcServerError::EarlyClose,
            Self::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage,
            Self::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage,
        }
    }
}
impl<C: ConnectionErrors> RpcServerError<C> {
    /// Re-types the error for another connection type `T`.
    ///
    /// The variant is kept, and any wrapped transport error is converted with
    /// `Into`.
    pub fn errors_into<T>(self) -> RpcServerError<T>
    where
        T: ConnectionErrors,
        C::SendError: Into<T::SendError>,
        C::RecvError: Into<T::RecvError>,
        C::AcceptError: Into<T::AcceptError>,
    {
        match self {
            Self::Accept(err) => RpcServerError::Accept(err.into()),
            Self::RecvError(err) => RpcServerError::RecvError(err.into()),
            Self::SendError(err) => RpcServerError::SendError(err.into()),
            Self::EarlyClose => RpcServerError::EarlyClose,
            Self::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage,
            Self::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage,
        }
    }
}
/// Formats the error as its variant name. Variants that wrap a transport error
/// also show that error, e.g. `RecvError(..)`.
///
/// Each variant is printed under its own name. Previously `Accept` was printed
/// as `Open`, and `UnexpectedUpdateMessage` was printed as
/// `UnexpectedStartMessage`, which made the two unit variants indistinguishable.
impl<C: ConnectionErrors> fmt::Debug for RpcServerError<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Accept(arg0) => f.debug_tuple("Accept").field(arg0).finish(),
            Self::EarlyClose => f.write_str("EarlyClose"),
            Self::RecvError(arg0) => f.debug_tuple("RecvError").field(arg0).finish(),
            Self::SendError(arg0) => f.debug_tuple("SendError").field(arg0).finish(),
            Self::UnexpectedStartMessage => f.write_str("UnexpectedStartMessage"),
            Self::UnexpectedUpdateMessage => f.write_str("UnexpectedUpdateMessage"),
        }
    }
}
impl<C: ConnectionErrors> fmt::Display for RpcServerError<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fmt::Debug::fmt(&self, f)
}
}
impl<C: ConnectionErrors> error::Error for RpcServerError<C> {}
/// A future that resolves to the value sent on a oneshot channel.
///
/// If the sender is dropped without sending, the future never resolves. It
/// returns `Poll::Pending` instead of surfacing the closed-channel error.
pub(crate) struct UnwrapToPending<T>(oneshot::Receiver<T>);

impl<T> Future for UnwrapToPending<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let receiver = &mut self.get_mut().0;
        match Pin::new(receiver).poll(cx) {
            Poll::Ready(Ok(value)) => Poll::Ready(value),
            // The sender went away without a value: never resolve.
            Poll::Ready(Err(_)) | Poll::Pending => Poll::Pending,
        }
    }
}
/// Runs `f1` and `f2` concurrently and returns the output of whichever
/// completes first. The other future is dropped.
pub(crate) async fn race2<T, A, B>(f1: A, f2: B) -> T
where
    A: Future<Output = T>,
    B: Future<Output = T>,
{
    tokio::select! {
        out = f1 => out,
        out = f2 => out,
    }
}
/// Serves service `S` on the listener `conn`, handling one connection at a time.
///
/// For each accepted connection, the first request is read and passed to
/// `handler` together with the channel and a fresh clone of `target`. The
/// handler's future is awaited before the next connection is accepted, so
/// connections are processed sequentially. `_service_type` is unused; it only
/// lets callers pick `S` by value.
///
/// # Errors
///
/// Returns the first error from accepting, reading the first request, or the
/// handler. Otherwise the loop runs forever.
pub async fn run_server_loop<S, C, T, F, Fut>(
    _service_type: S,
    conn: C,
    target: T,
    mut handler: F,
) -> Result<(), RpcServerError<C>>
where
    S: Service,
    C: Listener<S>,
    T: Clone + Send + 'static,
    F: FnMut(RpcChannel<S, C>, S::Req, T) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), RpcServerError<C>>> + Send + 'static,
{
    let server = RpcServer::<S, C>::new(conn);
    loop {
        let accepting = server.accept().await?;
        let (req, chan) = accepting.read_first().await?;
        handler(chan, req, target.clone()).await?;
    }
}