use std::{
convert::Infallible, error, fmt, io, marker::PhantomData, net::SocketAddr, pin::Pin, result,
sync::Arc, task::Poll,
};
use bytes::Bytes;
use flume::{Receiver, Sender};
use futures_lite::{Stream, StreamExt};
use futures_sink::Sink;
use hyper::{
client::{connect::Connect, HttpConnector, ResponseFuture},
server::conn::{AddrIncoming, AddrStream},
service::{make_service_fn, service_fn},
Body, Client, Request, Response, Server, StatusCode, Uri,
};
use tokio::{sync::mpsc, task::JoinHandle};
use tracing::{debug, event, trace, Level};
use crate::{
transport::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes},
RpcMessage,
};
/// State shared by every clone of a [`HyperConnector`].
struct HyperConnectionInner {
    /// Type-erased HTTP client used to issue requests (see [`Requester`]).
    client: Box<dyn Requester>,
    /// Channel configuration; also handed to every [`SendSink`] created by `open`.
    config: Arc<ChannelConfig>,
    /// Endpoint that every `open` call posts to.
    uri: Uri,
}

/// Client side of the hyper transport: each `open` issues one HTTP/2 POST request.
///
/// `In` is the message type received from the server, `Out` the type sent to it.
pub struct HyperConnector<In: RpcMessage, Out: RpcMessage> {
    inner: Arc<HyperConnectionInner>,
    _p: PhantomData<(In, Out)>,
}

impl<In: RpcMessage, Out: RpcMessage> Clone for HyperConnector<In, Out> {
    /// Cheap clone: only the shared `Arc` reference count is bumped.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self {
            inner,
            _p: PhantomData,
        }
    }
}

/// Object-safe wrapper around `hyper::Client::request`, letting `HyperConnectionInner`
/// store a client without being generic over the connector type.
trait Requester: Send + Sync + 'static {
    /// Issues `req` and returns hyper's pending response future.
    fn request(&self, req: Request<Body>) -> ResponseFuture;
}

impl<C: Connect + Clone + Send + Sync + 'static> Requester for Client<C, Body> {
    fn request(&self, req: Request<Body>) -> ResponseFuture {
        // Method-call resolution picks the inherent `Client::request` before this trait
        // method, so this delegates to hyper instead of recursing.
        self.request(req)
    }
}
impl<In: RpcMessage, Out: RpcMessage> HyperConnector<In, Out> {
pub fn new(uri: Uri) -> Self {
Self::with_config(uri, ChannelConfig::default())
}
pub fn with_config(uri: Uri, config: ChannelConfig) -> Self {
let mut connector = HttpConnector::new();
connector.set_nodelay(true);
Self::with_connector(connector, uri, Arc::new(config))
}
pub fn with_connector<C: Connect + Clone + Send + Sync + 'static>(
connector: C,
uri: Uri,
config: Arc<ChannelConfig>,
) -> Self {
let client = Client::builder()
.http2_only(true)
.http2_initial_connection_window_size(Some(config.max_frame_size))
.http2_initial_stream_window_size(Some(config.max_frame_size))
.http2_max_frame_size(Some(config.max_frame_size))
.http2_max_send_buf_size(config.max_frame_size.try_into().unwrap())
.build(connector);
Self {
inner: Arc::new(HyperConnectionInner {
client: Box::new(client),
uri,
config,
}),
_p: PhantomData,
}
}
}
impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for HyperConnector<In, Out> {
    /// Shows the endpoint and configuration under the type's actual name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The client is a type-erased `dyn Requester`, which does not require `Debug`, so
        // only the endpoint and the configuration are printed.
        f.debug_struct("HyperConnector")
            .field("uri", &self.inner.uri)
            .field("config", &self.inner.config)
            .finish()
    }
}

/// The per-request channel pair that the HTTP service hands to `HyperListener::accept`.
///
/// The first element yields decoded incoming messages (or receive errors). The second
/// element accepts already-framed outgoing bytes that become the response body.
type InternalChannel<In> = (
    Receiver<result::Result<In, RecvError>>,
    Sender<io::Result<Bytes>>,
);
/// Error returned when a [`ChannelConfig`] setter receives an out-of-range value.
#[derive(Debug, Clone)]
pub enum ChannelConfigError {
    /// The requested max frame size was outside `0x4000..=0xFF_FFFF`.
    InvalidMaxFrameSize(u32),
    /// The requested max payload size was outside `4096..=0xFF_FFFF`.
    InvalidMaxPayloadSize(usize),
}

impl fmt::Display for ChannelConfigError {
    /// Renders exactly the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}

impl error::Error for ChannelConfigError {}

/// Tuning parameters for a hyper channel (client or server).
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// HTTP/2 max frame size; also used for the initial flow-control windows and the
    /// max send buffer size.
    max_frame_size: u32,
    /// Upper bound on the serialized size of one outgoing message (checked by `SendSink`).
    max_payload_size: usize,
}

impl ChannelConfig {
    /// Sets the HTTP/2 max frame size.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelConfigError::InvalidMaxFrameSize`] unless `value` lies in
    /// `0x4000..=0xFF_FFFF`, the range HTTP/2 allows for `SETTINGS_MAX_FRAME_SIZE`.
    pub fn max_frame_size(mut self, value: u32) -> result::Result<Self, ChannelConfigError> {
        if (0x4000..=0xFF_FFFF).contains(&value) {
            self.max_frame_size = value;
            Ok(self)
        } else {
            Err(ChannelConfigError::InvalidMaxFrameSize(value))
        }
    }

    /// Sets the maximum serialized size of a single outgoing message.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelConfigError::InvalidMaxPayloadSize`] unless `value` lies in
    /// `4096..=0xFF_FFFF`. The upper bound keeps every length prefix within a `u32`.
    pub fn max_payload_size(mut self, value: usize) -> result::Result<Self, ChannelConfigError> {
        // `..=0xFF_FFFF` is the same set as the exclusive bound `..16 MiB`.
        if (4096..=0xFF_FFFF).contains(&value) {
            self.max_payload_size = value;
            Ok(self)
        } else {
            Err(ChannelConfigError::InvalidMaxPayloadSize(value))
        }
    }
}

impl Default for ChannelConfig {
    /// Both limits start at their largest accepted value, `0xFF_FFFF`.
    fn default() -> Self {
        Self {
            max_frame_size: 0xFF_FFFF,
            max_payload_size: 0xFF_FFFF,
        }
    }
}
#[derive(Debug)]
pub struct HyperListener<In: RpcMessage, Out: RpcMessage> {
channel: Receiver<InternalChannel<In>>,
config: Arc<ChannelConfig>,
stop_tx: mpsc::Sender<()>,
local_addr: [LocalAddr; 1],
_p: PhantomData<(In, Out)>,
}
impl<In: RpcMessage, Out: RpcMessage> HyperListener<In, Out> {
pub fn serve(addr: &SocketAddr) -> hyper::Result<Self> {
Self::serve_with_config(addr, Default::default())
}
pub fn serve_with_config(addr: &SocketAddr, config: ChannelConfig) -> hyper::Result<Self> {
let (accept_tx, accept_rx) = flume::bounded(32);
let service = make_service_fn(move |socket: &AddrStream| {
let remote_addr = socket.remote_addr();
event!(Level::TRACE, "Connection from {:?}", remote_addr);
let accept_tx = accept_tx.clone();
async move {
let one_req_service = service_fn(move |req: Request<Body>| {
Self::handle_one_http2_request(req, accept_tx.clone())
});
Ok::<_, Infallible>(one_req_service)
}
});
let mut incoming = AddrIncoming::bind(addr)?;
incoming.set_nodelay(true);
let server = Server::builder(incoming)
.http2_only(true)
.http2_initial_connection_window_size(Some(config.max_frame_size))
.http2_initial_stream_window_size(Some(config.max_frame_size))
.http2_max_frame_size(Some(config.max_frame_size))
.http2_max_send_buf_size(config.max_frame_size.try_into().unwrap())
.serve(service);
let local_addr = server.local_addr();
let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
let server = server.with_graceful_shutdown(async move {
stop_rx.recv().await;
});
tokio::spawn(server);
Ok(Self {
channel: accept_rx,
config: Arc::new(config),
stop_tx,
local_addr: [LocalAddr::Socket(local_addr)],
_p: PhantomData,
})
}
async fn handle_one_http2_request(
req: Request<Body>,
accept_tx: Sender<InternalChannel<In>>,
) -> Result<Response<Body>, String> {
let (req_tx, req_rx) = flume::bounded::<result::Result<In, RecvError>>(32);
let (res_tx, res_rx) = flume::bounded::<io::Result<Bytes>>(32);
accept_tx
.send_async((req_rx, res_tx))
.await
.map_err(|_e| "unable to send")?;
spawn_recv_forwarder(req.into_body(), req_tx);
let response = Response::builder()
.status(StatusCode::OK)
.body(Body::wrap_stream(res_rx.into_stream()))
.map_err(|_| "unable to set body")?;
Ok(response)
}
}
/// Returns the payload of the first length-prefixed frame in `buf`, if it is complete.
///
/// A frame is a 4-byte big-endian `u32` length followed by that many payload bytes.
/// Returns `None` while `buf` does not yet hold the whole header and payload. Bytes after
/// the first frame are ignored.
fn try_get_length_prefixed(buf: &[u8]) -> Option<&[u8]> {
    let header = buf.get(..4)?;
    // Lossless wherever `usize` is at least 32 bits wide.
    let len = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
    // Slice relative to the payload start rather than computing `4 + len`. That sum can
    // overflow on 32-bit targets when a peer announces a length near `u32::MAX`.
    buf[4..].get(..len)
}
/// Decodes and forwards every complete length-prefixed message at the start of `buffer`.
///
/// Each payload is decoded with postcard. A decode failure is forwarded as
/// `Err(RecvError::DeserializeError)` instead of aborting.
///
/// Returns the number of bytes consumed, which is always a whole number of frames. Returns
/// `Err(())` if the receiving side of `req_tx` has been dropped.
async fn try_forward_all<In: RpcMessage>(
    buffer: &[u8],
    req_tx: &Sender<Result<In, RecvError>>,
) -> result::Result<usize, ()> {
    let mut consumed = 0;
    loop {
        let Some(payload) = try_get_length_prefixed(&buffer[consumed..]) else {
            return Ok(consumed);
        };
        // Skip the 4-byte length header as well as the payload.
        consumed += 4 + payload.len();
        let decoded = postcard::from_bytes::<In>(payload).map_err(RecvError::DeserializeError);
        if req_tx.send_async(decoded).await.is_err() {
            trace!("Flume receiver dropped");
            return Err(());
        }
    }
}
fn spawn_recv_forwarder<In: RpcMessage>(
req: Body,
req_tx: Sender<result::Result<In, RecvError>>,
) -> JoinHandle<result::Result<(), ()>> {
tokio::spawn(async move {
let mut stream = req;
let mut buf = Vec::new();
while let Some(chunk) = stream.next().await {
match chunk.as_ref() {
Ok(chunk) => {
event!(Level::TRACE, "Server got {} bytes", chunk.len());
if buf.is_empty() {
let sent = try_forward_all(chunk, &req_tx).await?;
buf.extend_from_slice(&chunk[sent..]);
} else {
buf.extend_from_slice(chunk);
}
}
Err(cause) => {
debug!("Network error: {}", cause);
break;
}
};
let sent = try_forward_all(&buf, &req_tx).await?;
buf.drain(..sent);
}
Ok(())
})
}
impl<In: RpcMessage, Out: RpcMessage> Clone for HyperListener<In, Out> {
fn clone(&self) -> Self {
Self {
channel: self.channel.clone(),
stop_tx: self.stop_tx.clone(),
local_addr: self.local_addr.clone(),
config: self.config.clone(),
_p: PhantomData,
}
}
}
/// Stream of decoded messages (or receive errors) belonging to one RPC stream.
pub struct RecvStream<Res: RpcMessage> {
    recv: flume::r#async::RecvStream<'static, result::Result<Res, RecvError>>,
}

impl<Res: RpcMessage> RecvStream<Res> {
    /// Wraps the receiving half of a flume channel of decoded messages.
    pub fn new(recv: flume::Receiver<result::Result<Res, RecvError>>) -> Self {
        let recv = recv.into_stream();
        Self { recv }
    }
}

impl<In: RpcMessage> Clone for RecvStream<In> {
    /// The clone reads from the same underlying flume channel.
    fn clone(&self) -> Self {
        Self {
            recv: Clone::clone(&self.recv),
        }
    }
}

impl<Res: RpcMessage> Stream for RecvStream<Res> {
    type Item = Result<Res, RecvError>;

    /// Delegates to the inner flume stream.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.recv).poll_next(cx)
    }
}
pub struct SendSink<Out: RpcMessage> {
sink: flume::r#async::SendSink<'static, io::Result<Bytes>>,
config: Arc<ChannelConfig>,
_p: PhantomData<Out>,
}
impl<Out: RpcMessage> SendSink<Out> {
fn new(sender: flume::Sender<io::Result<Bytes>>, config: Arc<ChannelConfig>) -> Self {
Self {
sink: sender.into_sink(),
config,
_p: PhantomData,
}
}
fn serialize(&self, item: Out) -> Result<Bytes, SendError> {
let mut data = Vec::with_capacity(1024);
data.extend_from_slice(&[0u8; 4]);
let mut data = postcard::to_extend(&item, data).map_err(SendError::SerializeError)?;
let len = data.len() - 4;
if len > self.config.max_payload_size {
return Err(SendError::SizeError(len));
}
let len: u32 = len.try_into().expect("max_payload_size fits into u32");
data[0..4].copy_from_slice(&len.to_be_bytes());
Ok(data.into())
}
pub fn into_inner(self) -> flume::r#async::SendSink<'static, io::Result<Bytes>> {
self.sink
}
}
impl<Out: RpcMessage> Sink<Out> for SendSink<Out> {
    type Error = SendError;

    /// Delegates readiness to the underlying flume sink.
    fn poll_ready(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let sink = &mut self.get_mut().sink;
        Pin::new(sink)
            .poll_ready(cx)
            .map_err(|_| SendError::ReceiverDropped)
    }

    /// Serializes `item` and enqueues the resulting frame.
    ///
    /// If serialization fails, an `io::Error` carrying the error text is still pushed into
    /// the sink, and then the original error is returned. Presumably the pushed error is
    /// meant to abort the HTTP body.
    fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
        let this = self.get_mut();
        let (payload, outcome) = match this.serialize(item) {
            Ok(bytes) => (Ok(bytes), Ok(())),
            Err(cause) => {
                let io_err = io::Error::new(io::ErrorKind::Other, cause.to_string());
                (Err(io_err), Err(cause))
            }
        };
        Pin::new(&mut this.sink)
            .start_send(payload)
            .map_err(|_| SendError::ReceiverDropped)?;
        outcome
    }

    /// Delegates flushing to the underlying flume sink.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let sink = &mut self.get_mut().sink;
        Pin::new(sink)
            .poll_flush(cx)
            .map_err(|_| SendError::ReceiverDropped)
    }

    /// Delegates closing to the underlying flume sink.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let sink = &mut self.get_mut().sink;
        Pin::new(sink)
            .poll_close(cx)
            .map_err(|_| SendError::ReceiverDropped)
    }
}
/// Errors produced while sending on a [`SendSink`].
#[derive(Debug)]
pub enum SendError {
    /// postcard failed to encode the message.
    SerializeError(postcard::Error),
    /// The encoded payload (length in bytes) exceeded `max_payload_size`.
    SizeError(usize),
    /// The underlying flume sink reported an error, so the frame could not be delivered.
    ReceiverDropped,
}

impl fmt::Display for SendError {
    /// Renders exactly the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}

impl error::Error for SendError {}
/// Errors delivered through a [`RecvStream`].
#[derive(Debug)]
pub enum RecvError {
    /// A received frame could not be decoded by postcard.
    DeserializeError(postcard::Error),
    /// A transport-level error reported by hyper.
    NetworkError(hyper::Error),
}

impl fmt::Display for RecvError {
    /// Renders exactly the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}

impl error::Error for RecvError {}
#[derive(Debug)]
pub enum OpenError {
HyperHttp(hyper::http::Error),
Hyper(hyper::Error),
RemoteDropped,
}
impl fmt::Display for OpenError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl std::error::Error for OpenError {}
/// Errors from accepting a stream on a [`HyperListener`].
#[derive(Debug)]
pub enum AcceptError {
    /// HTTP-level error; not produced by `HyperListener::accept` in this module.
    Hyper(hyper::http::Error),
    /// The accept queue is closed, so no further streams can arrive.
    RemoteDropped,
}

impl fmt::Display for AcceptError {
    /// Renders exactly the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}

impl error::Error for AcceptError {}
impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for HyperConnector<In, Out> {
type SendError = self::SendError;
type RecvError = self::RecvError;
type OpenError = OpenError;
type AcceptError = AcceptError;
}
impl<In: RpcMessage, Out: RpcMessage> StreamTypes for HyperConnector<In, Out> {
type In = In;
type Out = Out;
type RecvStream = self::RecvStream<In>;
type SendSink = self::SendSink<Out>;
}
impl<In: RpcMessage, Out: RpcMessage> Connector for HyperConnector<In, Out> {
async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
let (out_tx, out_rx) = flume::bounded::<io::Result<Bytes>>(32);
let req: Request<Body> = Request::post(&self.inner.uri)
.body(Body::wrap_stream(out_rx.into_stream()))
.map_err(OpenError::HyperHttp)?;
let res = self
.inner
.client
.request(req)
.await
.map_err(OpenError::Hyper)?;
let (in_tx, in_rx) = flume::bounded::<result::Result<In, RecvError>>(32);
spawn_recv_forwarder(res.into_body(), in_tx);
let out_tx = self::SendSink::new(out_tx, self.inner.config.clone());
let in_rx = self::RecvStream::new(in_rx);
Ok((out_tx, in_rx))
}
}
impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for HyperListener<In, Out> {
type SendError = self::SendError;
type RecvError = self::RecvError;
type OpenError = AcceptError;
type AcceptError = AcceptError;
}
impl<In: RpcMessage, Out: RpcMessage> StreamTypes for HyperListener<In, Out> {
type In = In;
type Out = Out;
type RecvStream = self::RecvStream<In>;
type SendSink = self::SendSink<Out>;
}
impl<In: RpcMessage, Out: RpcMessage> Listener for HyperListener<In, Out> {
fn local_addr(&self) -> &[LocalAddr] {
&self.local_addr
}
async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
let (recv, send) = self
.channel
.recv_async()
.await
.map_err(|_| AcceptError::RemoteDropped)?;
Ok((
SendSink::new(send, self.config.clone()),
RecvStream::new(recv),
))
}
}