quic_rpc/transport/
hyper.rs

1//! http2 transport using [hyper]
2//!
3//! [hyper]: https://crates.io/crates/hyper/
4use std::{
5    convert::Infallible, error, fmt, io, marker::PhantomData, net::SocketAddr, pin::Pin, result,
6    sync::Arc, task::Poll,
7};
8
9use bytes::Bytes;
10use flume::{Receiver, Sender};
11use futures_lite::{Stream, StreamExt};
12use futures_sink::Sink;
13use hyper::{
14    client::{connect::Connect, HttpConnector, ResponseFuture},
15    server::conn::{AddrIncoming, AddrStream},
16    service::{make_service_fn, service_fn},
17    Body, Client, Request, Response, Server, StatusCode, Uri,
18};
19use tokio::{sync::mpsc, task::JoinHandle};
20use tracing::{debug, event, trace, Level};
21
22use crate::{
23    transport::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes},
24    RpcMessage,
25};
26
27struct HyperConnectionInner {
28    client: Box<dyn Requester>,
29    config: Arc<ChannelConfig>,
30    uri: Uri,
31}
32
33/// Hyper based connection to a server
34pub struct HyperConnector<In: RpcMessage, Out: RpcMessage> {
35    inner: Arc<HyperConnectionInner>,
36    _p: PhantomData<(In, Out)>,
37}
38
39impl<In: RpcMessage, Out: RpcMessage> Clone for HyperConnector<In, Out> {
40    fn clone(&self) -> Self {
41        Self {
42            inner: self.inner.clone(),
43            _p: PhantomData,
44        }
45    }
46}
47
48/// Trait so we don't have to drag around the hyper internals
49trait Requester: Send + Sync + 'static {
50    fn request(&self, req: Request<Body>) -> ResponseFuture;
51}
52
53impl<C: Connect + Clone + Send + Sync + 'static> Requester for Client<C, Body> {
54    fn request(&self, req: Request<Body>) -> ResponseFuture {
55        self.request(req)
56    }
57}
58
59impl<In: RpcMessage, Out: RpcMessage> HyperConnector<In, Out> {
60    /// create a client given an uri and the default configuration
61    pub fn new(uri: Uri) -> Self {
62        Self::with_config(uri, ChannelConfig::default())
63    }
64
65    /// create a client given an uri and a custom configuration
66    pub fn with_config(uri: Uri, config: ChannelConfig) -> Self {
67        let mut connector = HttpConnector::new();
68        connector.set_nodelay(true);
69        Self::with_connector(connector, uri, Arc::new(config))
70    }
71
72    /// create a client given an uri and a custom configuration
73    pub fn with_connector<C: Connect + Clone + Send + Sync + 'static>(
74        connector: C,
75        uri: Uri,
76        config: Arc<ChannelConfig>,
77    ) -> Self {
78        let client = Client::builder()
79            .http2_only(true)
80            .http2_initial_connection_window_size(Some(config.max_frame_size))
81            .http2_initial_stream_window_size(Some(config.max_frame_size))
82            .http2_max_frame_size(Some(config.max_frame_size))
83            .http2_max_send_buf_size(config.max_frame_size.try_into().unwrap())
84            .build(connector);
85        Self {
86            inner: Arc::new(HyperConnectionInner {
87                client: Box::new(client),
88                uri,
89                config,
90            }),
91            _p: PhantomData,
92        }
93    }
94}
95
96impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for HyperConnector<In, Out> {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("ClientChannel")
99            .field("uri", &self.inner.uri)
100            .field("config", &self.inner.config)
101            .finish()
102    }
103}
104
105/// A flume sender and receiver tuple.
106type InternalChannel<In> = (
107    Receiver<result::Result<In, RecvError>>,
108    Sender<io::Result<Bytes>>,
109);
110
/// Error when setting a channel configuration
#[derive(Debug, Clone)]
pub enum ChannelConfigError {
    /// The maximum frame size is outside the range allowed by HTTP/2 (`0x4000..=0xFFFFFF`).
    InvalidMaxFrameSize(u32),
    /// The maximum payload size is outside the supported range (`4096..=0xFFFFFF`).
    InvalidMaxPayloadSize(usize),
}

impl fmt::Display for ChannelConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `self` is already a reference; forwarding `&self` would needlessly format `&&Self`.
        fmt::Debug::fmt(self, f)
    }
}

impl error::Error for ChannelConfigError {}

/// Channel configuration
///
/// These settings apply to both client and server channels.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// The maximum HTTP/2 frame size to use.
    ///
    /// The transport also uses it for the initial HTTP/2 flow-control windows and the
    /// maximum send buffer size.
    max_frame_size: u32,
    /// The maximum size in bytes of one serialized message, excluding its 4-byte length
    /// prefix. It is checked when outgoing messages are serialized.
    // NOTE(review): incoming frames are not checked against this limit.
    max_payload_size: usize,
}

impl ChannelConfig {
    /// Smallest maximum frame size allowed by HTTP/2 (2^14).
    const MIN_FRAME_SIZE: u32 = 0x4000;
    /// Largest maximum frame size allowed by HTTP/2 (2^24 - 1).
    const MAX_FRAME_SIZE: u32 = 0xFF_FFFF;
    /// Smallest accepted maximum payload size.
    const MIN_PAYLOAD_SIZE: usize = 4096;
    /// Largest accepted maximum payload size (16 MiB - 1).
    const MAX_PAYLOAD_SIZE: usize = 0xFF_FFFF;

    /// Set the maximum frame size.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelConfigError::InvalidMaxFrameSize`] if `value` is not in
    /// `0x4000..=0xFFFFFF`.
    pub fn max_frame_size(mut self, value: u32) -> result::Result<Self, ChannelConfigError> {
        if !(Self::MIN_FRAME_SIZE..=Self::MAX_FRAME_SIZE).contains(&value) {
            return Err(ChannelConfigError::InvalidMaxFrameSize(value));
        }
        self.max_frame_size = value;
        Ok(self)
    }

    /// Set the maximum payload size.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelConfigError::InvalidMaxPayloadSize`] if `value` is not in
    /// `4096..=0xFFFFFF`.
    pub fn max_payload_size(mut self, value: usize) -> result::Result<Self, ChannelConfigError> {
        if !(Self::MIN_PAYLOAD_SIZE..=Self::MAX_PAYLOAD_SIZE).contains(&value) {
            return Err(ChannelConfigError::InvalidMaxPayloadSize(value));
        }
        self.max_payload_size = value;
        Ok(self)
    }
}

impl Default for ChannelConfig {
    /// Uses the largest allowed frame size and payload size.
    fn default() -> Self {
        Self {
            max_frame_size: Self::MAX_FRAME_SIZE,
            max_payload_size: Self::MAX_PAYLOAD_SIZE,
        }
    }
}
166
167/// A listener using a hyper server
168///
169/// Each request made by the any client connection this channel will yield a `(recv, send)`
170/// pair which allows receiving the request and sending the response.  Both these are
171/// channels themselves to support streaming requests and responses.
172///
173/// Creating this spawns a tokio task which runs the server, once dropped this task is shut
174/// down: no new connections will be accepted and existing channels will stop.
175#[derive(Debug)]
176pub struct HyperListener<In: RpcMessage, Out: RpcMessage> {
177    /// The channel.
178    channel: Receiver<InternalChannel<In>>,
179    /// The configuration.
180    config: Arc<ChannelConfig>,
181    /// The sender to stop the server.
182    ///
183    /// We never send anything over this really, simply dropping it makes the receiver
184    /// complete and will shut down the hyper server.
185    stop_tx: mpsc::Sender<()>,
186    /// The local address this server is bound to.
187    ///
188    /// This is useful when the listen address uses a random port, `:0`, to find out which
189    /// port was bound by the kernel.
190    local_addr: [LocalAddr; 1],
191    /// Phantom data for service
192    _p: PhantomData<(In, Out)>,
193}
194
195impl<In: RpcMessage, Out: RpcMessage> HyperListener<In, Out> {
196    /// Creates a server listening on the [`SocketAddr`], with the default configuration.
197    pub fn serve(addr: &SocketAddr) -> hyper::Result<Self> {
198        Self::serve_with_config(addr, Default::default())
199    }
200
201    /// Creates a server listening on the [`SocketAddr`] with a custom configuration.
202    pub fn serve_with_config(addr: &SocketAddr, config: ChannelConfig) -> hyper::Result<Self> {
203        let (accept_tx, accept_rx) = flume::bounded(32);
204
205        // The hyper "MakeService" which is called for each connection that is made to the
206        // server.  It creates another Service which handles a single request.
207        let service = make_service_fn(move |socket: &AddrStream| {
208            let remote_addr = socket.remote_addr();
209            event!(Level::TRACE, "Connection from {:?}", remote_addr);
210
211            // Need a new accept_tx to move to the future on every call of this FnMut.
212            let accept_tx = accept_tx.clone();
213            async move {
214                let one_req_service = service_fn(move |req: Request<Body>| {
215                    // This closure is an FnMut as well, so clone accept_tx once more.
216                    Self::handle_one_http2_request(req, accept_tx.clone())
217                });
218                Ok::<_, Infallible>(one_req_service)
219            }
220        });
221
222        let mut incoming = AddrIncoming::bind(addr)?;
223        incoming.set_nodelay(true);
224        let server = Server::builder(incoming)
225            .http2_only(true)
226            .http2_initial_connection_window_size(Some(config.max_frame_size))
227            .http2_initial_stream_window_size(Some(config.max_frame_size))
228            .http2_max_frame_size(Some(config.max_frame_size))
229            .http2_max_send_buf_size(config.max_frame_size.try_into().unwrap())
230            .serve(service);
231        let local_addr = server.local_addr();
232
233        let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
234        let server = server.with_graceful_shutdown(async move {
235            // If the sender is dropped this will also gracefully terminate the server.
236            stop_rx.recv().await;
237        });
238        tokio::spawn(server);
239
240        Ok(Self {
241            channel: accept_rx,
242            config: Arc::new(config),
243            stop_tx,
244            local_addr: [LocalAddr::Socket(local_addr)],
245            _p: PhantomData,
246        })
247    }
248
249    /// Handles a single HTTP2 request.
250    ///
251    /// This creates the channels to communicate the (optionally streaming) request and
252    /// response and sends them to the [`ServerChannel`].
253    async fn handle_one_http2_request(
254        req: Request<Body>,
255        accept_tx: Sender<InternalChannel<In>>,
256    ) -> Result<Response<Body>, String> {
257        let (req_tx, req_rx) = flume::bounded::<result::Result<In, RecvError>>(32);
258        let (res_tx, res_rx) = flume::bounded::<io::Result<Bytes>>(32);
259        accept_tx
260            .send_async((req_rx, res_tx))
261            .await
262            .map_err(|_e| "unable to send")?;
263
264        spawn_recv_forwarder(req.into_body(), req_tx);
265        // Create a response with the response body channel as the response body
266        let response = Response::builder()
267            .status(StatusCode::OK)
268            .body(Body::wrap_stream(res_rx.into_stream()))
269            .map_err(|_| "unable to set body")?;
270        Ok(response)
271    }
272}
273
/// Tries to split one length-prefixed frame off the front of `buf`.
///
/// A frame is a 4-byte big-endian length followed by that many payload bytes. Returns the
/// payload without the prefix if `buf` starts with a complete frame, or `None` if more
/// data is needed. Any bytes after the frame are ignored.
///
/// The declared length is compared against the remaining buffer instead of computing
/// `4 + len`. With the addition, a hostile prefix such as `0xFFFF_FFFF` would overflow
/// `usize` on 32-bit targets.
fn try_get_length_prefixed(buf: &[u8]) -> Option<&[u8]> {
    let prefix: [u8; 4] = buf.get(..4)?.try_into().ok()?;
    let len = usize::try_from(u32::from_be_bytes(prefix)).ok()?;
    buf[4..].get(..len)
}
284
285/// Try forward all frames as deserialized messages from the buffer to the sender.
286///
287/// On success, returns the number of forwarded bytes.
288/// On forward error, returns the unit error.
289///
290/// Deserialization errors don't cause an error, they will be sent.
291/// On error the number of consumed bytes is not returned. There is nothing to do but
292/// to stop the forwarder since there is nowhere to forward to anymore.
293async fn try_forward_all<In: RpcMessage>(
294    buffer: &[u8],
295    req_tx: &Sender<Result<In, RecvError>>,
296) -> result::Result<usize, ()> {
297    let mut sent = 0;
298    while let Some(msg) = try_get_length_prefixed(&buffer[sent..]) {
299        sent += msg.len() + 4;
300        let item = postcard::from_bytes::<In>(msg).map_err(RecvError::DeserializeError);
301        if let Err(_cause) = req_tx.send_async(item).await {
302            // The receiver is gone, so we can't send any more data.
303            //
304            // This is a normal way for an interaction to end, when the server side is done processing
305            // the request and drops the receiver.
306            //
307            // don't log the cause. It does not contain any useful information.
308            trace!("Flume receiver dropped");
309            return Err(());
310        }
311    }
312    Ok(sent)
313}
314
315/// Spawns a task which forwards requests from the network to a flume channel.
316///
317/// This task will read chunks from the network, split them into length prefixed
318/// frames, deserialize those frames, and send the result to the flume channel.
319///
320/// If there is a network error or the flume channel closes or the request
321/// stream is simply ended this task will terminate.
322///
323/// So it is fine to ignore the returned [`JoinHandle`].
324///
325/// The HTTP2 request comes from *req* and the data is sent to `req_tx`.
326fn spawn_recv_forwarder<In: RpcMessage>(
327    req: Body,
328    req_tx: Sender<result::Result<In, RecvError>>,
329) -> JoinHandle<result::Result<(), ()>> {
330    tokio::spawn(async move {
331        let mut stream = req;
332        let mut buf = Vec::new();
333
334        while let Some(chunk) = stream.next().await {
335            match chunk.as_ref() {
336                Ok(chunk) => {
337                    event!(Level::TRACE, "Server got {} bytes", chunk.len());
338                    if buf.is_empty() {
339                        // try to forward directly from buffer
340                        let sent = try_forward_all(chunk, &req_tx).await?;
341                        // add just the rest, if any
342                        buf.extend_from_slice(&chunk[sent..]);
343                    } else {
344                        // no choice but to add it all
345                        buf.extend_from_slice(chunk);
346                    }
347                }
348                Err(cause) => {
349                    // Indicates that the connection has been closed on the client side.
350                    // This is a normal occurrence, e.g. when the client has raced the RPC
351                    // call with something else and has droppped the future.
352                    debug!("Network error: {}", cause);
353                    break;
354                }
355            };
356            let sent = try_forward_all(&buf, &req_tx).await?;
357            // remove the forwarded bytes.
358            // Frequently this will be the entire buffer, so no memcpy but just set the size to 0
359            buf.drain(..sent);
360        }
361        Ok(())
362    })
363}
364
365// This does not want or need RpcMessage to be clone but still want to clone the
366// ServerChannel and it's containing channels itself.  The derive macro can't cope with this
367// so this needs to be written by hand.
368impl<In: RpcMessage, Out: RpcMessage> Clone for HyperListener<In, Out> {
369    fn clone(&self) -> Self {
370        Self {
371            channel: self.channel.clone(),
372            stop_tx: self.stop_tx.clone(),
373            local_addr: self.local_addr.clone(),
374            config: self.config.clone(),
375            _p: PhantomData,
376        }
377    }
378}
379
380/// Receive stream for hyper channels.
381///
382/// This is a newtype wrapper around a [`flume::async::RecvStream`] of deserialized
383/// messages.
384pub struct RecvStream<Res: RpcMessage> {
385    recv: flume::r#async::RecvStream<'static, result::Result<Res, RecvError>>,
386}
387
388impl<Res: RpcMessage> RecvStream<Res> {
389    /// Creates a new [`RecvStream`] from a [`flume::Receiver`].
390    pub fn new(recv: flume::Receiver<result::Result<Res, RecvError>>) -> Self {
391        Self {
392            recv: recv.into_stream(),
393        }
394    }
395
396    // we can not write into_inner, since all we got is a stream of already
397    // framed and deserialize messages. Might want to change that...
398}
399
400impl<In: RpcMessage> Clone for RecvStream<In> {
401    fn clone(&self) -> Self {
402        Self {
403            recv: self.recv.clone(),
404        }
405    }
406}
407
408impl<Res: RpcMessage> Stream for RecvStream<Res> {
409    type Item = Result<Res, RecvError>;
410
411    fn poll_next(
412        mut self: Pin<&mut Self>,
413        cx: &mut std::task::Context<'_>,
414    ) -> Poll<Option<Self::Item>> {
415        Pin::new(&mut self.recv).poll_next(cx)
416    }
417}
418
419/// Send sink for hyper channels
420pub struct SendSink<Out: RpcMessage> {
421    sink: flume::r#async::SendSink<'static, io::Result<Bytes>>,
422    config: Arc<ChannelConfig>,
423    _p: PhantomData<Out>,
424}
425
426impl<Out: RpcMessage> SendSink<Out> {
427    fn new(sender: flume::Sender<io::Result<Bytes>>, config: Arc<ChannelConfig>) -> Self {
428        Self {
429            sink: sender.into_sink(),
430            config,
431            _p: PhantomData,
432        }
433    }
434    fn serialize(&self, item: Out) -> Result<Bytes, SendError> {
435        let mut data = Vec::with_capacity(1024);
436        data.extend_from_slice(&[0u8; 4]);
437        let mut data = postcard::to_extend(&item, data).map_err(SendError::SerializeError)?;
438        let len = data.len() - 4;
439        if len > self.config.max_payload_size {
440            return Err(SendError::SizeError(len));
441        }
442        let len: u32 = len.try_into().expect("max_payload_size fits into u32");
443        data[0..4].copy_from_slice(&len.to_be_bytes());
444        Ok(data.into())
445    }
446
447    /// Consumes the [`SendSink`] and returns the underlying [`flume::async::SendSink`].
448    ///
449    /// This is useful if you want to send raw [bytes::Bytes] without framing
450    /// directly to the channel.
451    pub fn into_inner(self) -> flume::r#async::SendSink<'static, io::Result<Bytes>> {
452        self.sink
453    }
454}
455
456impl<Out: RpcMessage> Sink<Out> for SendSink<Out> {
457    type Error = SendError;
458
459    fn poll_ready(
460        mut self: Pin<&mut Self>,
461        cx: &mut std::task::Context<'_>,
462    ) -> Poll<Result<(), Self::Error>> {
463        Pin::new(&mut self.sink)
464            .poll_ready(cx)
465            .map_err(|_| SendError::ReceiverDropped)
466    }
467
468    fn start_send(mut self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
469        // figure out what to send and what to return
470        let (send, res) = match self.serialize(item) {
471            Ok(data) => (Ok(data), Ok(())),
472            Err(cause) => (
473                Err(io::Error::new(io::ErrorKind::Other, cause.to_string())),
474                Err(cause),
475            ),
476        };
477        // attempt sending
478        Pin::new(&mut self.sink)
479            .start_send(send)
480            .map_err(|_| SendError::ReceiverDropped)?;
481        res
482    }
483
484    fn poll_flush(
485        mut self: Pin<&mut Self>,
486        cx: &mut std::task::Context<'_>,
487    ) -> Poll<Result<(), Self::Error>> {
488        Pin::new(&mut self.sink)
489            .poll_flush(cx)
490            .map_err(|_| SendError::ReceiverDropped)
491    }
492
493    fn poll_close(
494        mut self: Pin<&mut Self>,
495        cx: &mut std::task::Context<'_>,
496    ) -> Poll<Result<(), Self::Error>> {
497        Pin::new(&mut self.sink)
498            .poll_close(cx)
499            .map_err(|_| SendError::ReceiverDropped)
500    }
501}
502
503/// Send error for hyper channels.
504#[derive(Debug)]
505pub enum SendError {
506    /// Error when postcard serializing the message.
507    SerializeError(postcard::Error),
508    /// The message is too large to be sent.
509    SizeError(usize),
510    /// The connection has been closed.
511    ReceiverDropped,
512}
513
514impl fmt::Display for SendError {
515    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
516        fmt::Debug::fmt(&self, f)
517    }
518}
519
520impl error::Error for SendError {}
521
522/// Receive error for hyper channels.
523#[derive(Debug)]
524pub enum RecvError {
525    /// Error when postcard deserializing the message.
526    DeserializeError(postcard::Error),
527    /// Hyper network error.
528    NetworkError(hyper::Error),
529}
530
531impl fmt::Display for RecvError {
532    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
533        fmt::Debug::fmt(&self, f)
534    }
535}
536
537impl error::Error for RecvError {}
538
539/// OpenError for hyper channels.
540#[derive(Debug)]
541pub enum OpenError {
542    /// Hyper http error
543    HyperHttp(hyper::http::Error),
544    /// Generic hyper error
545    Hyper(hyper::Error),
546    /// The remote side of the channel was dropped
547    RemoteDropped,
548}
549
550impl fmt::Display for OpenError {
551    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552        fmt::Debug::fmt(self, f)
553    }
554}
555
556impl std::error::Error for OpenError {}
557
558/// AcceptError for hyper channels.
559///
560/// There is not much that can go wrong with hyper channels.
561#[derive(Debug)]
562pub enum AcceptError {
563    /// Hyper error
564    Hyper(hyper::http::Error),
565    /// The remote side of the channel was dropped
566    RemoteDropped,
567}
568
569impl fmt::Display for AcceptError {
570    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
571        fmt::Debug::fmt(self, f)
572    }
573}
574
575impl error::Error for AcceptError {}
576
577impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for HyperConnector<In, Out> {
578    type SendError = self::SendError;
579
580    type RecvError = self::RecvError;
581
582    type OpenError = OpenError;
583
584    type AcceptError = AcceptError;
585}
586
587impl<In: RpcMessage, Out: RpcMessage> StreamTypes for HyperConnector<In, Out> {
588    type In = In;
589    type Out = Out;
590    type RecvStream = self::RecvStream<In>;
591    type SendSink = self::SendSink<Out>;
592}
593
594impl<In: RpcMessage, Out: RpcMessage> Connector for HyperConnector<In, Out> {
595    async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
596        let (out_tx, out_rx) = flume::bounded::<io::Result<Bytes>>(32);
597        let req: Request<Body> = Request::post(&self.inner.uri)
598            .body(Body::wrap_stream(out_rx.into_stream()))
599            .map_err(OpenError::HyperHttp)?;
600        let res = self
601            .inner
602            .client
603            .request(req)
604            .await
605            .map_err(OpenError::Hyper)?;
606        let (in_tx, in_rx) = flume::bounded::<result::Result<In, RecvError>>(32);
607        spawn_recv_forwarder(res.into_body(), in_tx);
608
609        let out_tx = self::SendSink::new(out_tx, self.inner.config.clone());
610        let in_rx = self::RecvStream::new(in_rx);
611        Ok((out_tx, in_rx))
612    }
613}
614
615impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for HyperListener<In, Out> {
616    type SendError = self::SendError;
617    type RecvError = self::RecvError;
618    type OpenError = AcceptError;
619    type AcceptError = AcceptError;
620}
621
622impl<In: RpcMessage, Out: RpcMessage> StreamTypes for HyperListener<In, Out> {
623    type In = In;
624    type Out = Out;
625    type RecvStream = self::RecvStream<In>;
626    type SendSink = self::SendSink<Out>;
627}
628
629impl<In: RpcMessage, Out: RpcMessage> Listener for HyperListener<In, Out> {
630    fn local_addr(&self) -> &[LocalAddr] {
631        &self.local_addr
632    }
633
634    async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
635        let (recv, send) = self
636            .channel
637            .recv_async()
638            .await
639            .map_err(|_| AcceptError::RemoteDropped)?;
640        Ok((
641            SendSink::new(send, self.config.clone()),
642            RecvStream::new(recv),
643        ))
644    }
645}