quic_rpc/transport/
quinn.rs

1//! QUIC transport implementation based on [quinn](https://crates.io/crates/quinn)
2use std::{
3    fmt, io,
4    marker::PhantomData,
5    net::SocketAddr,
6    pin::Pin,
7    result,
8    sync::Arc,
9    task::{Context, Poll},
10};
11
12use futures_lite::{Future, Stream, StreamExt};
13use futures_sink::Sink;
14use futures_util::FutureExt;
15use pin_project::pin_project;
16use serde::{de::DeserializeOwned, Serialize};
17use tokio::sync::oneshot;
18use tracing::{debug_span, Instrument};
19
20use super::{
21    util::{FramedPostcardRead, FramedPostcardWrite},
22    StreamTypes,
23};
24use crate::{
25    transport::{ConnectionErrors, Connector, Listener, LocalAddr},
26    RpcMessage,
27};
28
/// Frame length limit handed to the postcard framing codecs used by [`SendSink`] and
/// [`RecvStream`]: 16 MiB (assuming the codecs count bytes).
const MAX_FRAME_LENGTH: usize = 16 * 1024 * 1024;
30
31#[derive(Debug)]
32struct ListenerInner {
33    endpoint: Option<quinn::Endpoint>,
34    task: Option<tokio::task::JoinHandle<()>>,
35    local_addr: [LocalAddr; 1],
36    receiver: flume::Receiver<SocketInner>,
37}
38
39impl Drop for ListenerInner {
40    fn drop(&mut self) {
41        tracing::debug!("Dropping listener");
42        if let Some(endpoint) = self.endpoint.take() {
43            endpoint.close(0u32.into(), b"Listener dropped");
44
45            if let Ok(handle) = tokio::runtime::Handle::try_current() {
46                // spawn a task to wait for the endpoint to notify peers that it is closing
47                let span = debug_span!("closing listener");
48                handle.spawn(
49                    async move {
50                        endpoint.wait_idle().await;
51                    }
52                    .instrument(span),
53                );
54            }
55        }
56        if let Some(task) = self.task.take() {
57            task.abort()
58        }
59    }
60}
61
62/// A listener using a quinn connection
63#[derive(Debug)]
64pub struct QuinnListener<In: RpcMessage, Out: RpcMessage> {
65    inner: Arc<ListenerInner>,
66    _p: PhantomData<(In, Out)>,
67}
68
69impl<In: RpcMessage, Out: RpcMessage> QuinnListener<In, Out> {
70    /// handles RPC requests from a connection
71    ///
72    /// to cleanly shutdown the handler, drop the receiver side of the sender.
73    async fn connection_handler(connection: quinn::Connection, sender: flume::Sender<SocketInner>) {
74        loop {
75            tracing::debug!("Awaiting incoming bidi substream on existing connection...");
76            let bidi_stream = match connection.accept_bi().await {
77                Ok(bidi_stream) => bidi_stream,
78                Err(quinn::ConnectionError::ApplicationClosed(e)) => {
79                    tracing::debug!("Peer closed the connection {:?}", e);
80                    break;
81                }
82                Err(e) => {
83                    tracing::debug!("Error accepting stream: {}", e);
84                    break;
85                }
86            };
87            tracing::debug!("Sending substream to be handled... {}", bidi_stream.0.id());
88            if sender.send_async(bidi_stream).await.is_err() {
89                tracing::debug!("Receiver dropped");
90                break;
91            }
92        }
93    }
94
95    async fn endpoint_handler(endpoint: quinn::Endpoint, sender: flume::Sender<SocketInner>) {
96        loop {
97            tracing::debug!("Waiting for incoming connection...");
98            let connecting = match endpoint.accept().await {
99                Some(connecting) => connecting,
100                None => break,
101            };
102            tracing::debug!("Awaiting connection from connect...");
103            let conection = match connecting.await {
104                Ok(conection) => conection,
105                Err(e) => {
106                    tracing::warn!("Error accepting connection: {}", e);
107                    continue;
108                }
109            };
110            tracing::debug!(
111                "Connection established from {:?}",
112                conection.remote_address()
113            );
114            tracing::debug!("Spawning connection handler...");
115            tokio::spawn(Self::connection_handler(conection, sender.clone()));
116        }
117    }
118
119    /// Create a new server channel, given a quinn endpoint.
120    ///
121    /// The endpoint must be a server endpoint.
122    ///
123    /// The server channel will take care of listening on the endpoint and spawning
124    /// handlers for new connections.
125    pub fn new(endpoint: quinn::Endpoint) -> io::Result<Self> {
126        let local_addr = endpoint.local_addr()?;
127        let (sender, receiver) = flume::bounded(16);
128        let task = tokio::spawn(Self::endpoint_handler(endpoint.clone(), sender));
129        Ok(Self {
130            inner: Arc::new(ListenerInner {
131                endpoint: Some(endpoint),
132                task: Some(task),
133                local_addr: [LocalAddr::Socket(local_addr)],
134                receiver,
135            }),
136            _p: PhantomData,
137        })
138    }
139
140    /// Create a new server channel, given just a source of incoming connections
141    ///
142    /// This is useful if you want to manage the quinn endpoint yourself,
143    /// use multiple endpoints, or use an endpoint for multiple protocols.
144    pub fn handle_connections(
145        incoming: flume::Receiver<quinn::Connection>,
146        local_addr: SocketAddr,
147    ) -> Self {
148        let (sender, receiver) = flume::bounded(16);
149        let task = tokio::spawn(async move {
150            // just grab all connections and spawn a handler for each one
151            while let Ok(connection) = incoming.recv_async().await {
152                tokio::spawn(Self::connection_handler(connection, sender.clone()));
153            }
154        });
155        Self {
156            inner: Arc::new(ListenerInner {
157                endpoint: None,
158                task: Some(task),
159                local_addr: [LocalAddr::Socket(local_addr)],
160                receiver,
161            }),
162            _p: PhantomData,
163        }
164    }
165
166    /// Create a new server channel, given just a source of incoming substreams
167    ///
168    /// This is useful if you want to manage the quinn endpoint yourself,
169    /// use multiple endpoints, or use an endpoint for multiple protocols.
170    pub fn handle_substreams(
171        receiver: flume::Receiver<SocketInner>,
172        local_addr: SocketAddr,
173    ) -> Self {
174        Self {
175            inner: Arc::new(ListenerInner {
176                endpoint: None,
177                task: None,
178                local_addr: [LocalAddr::Socket(local_addr)],
179                receiver,
180            }),
181            _p: PhantomData,
182        }
183    }
184}
185
186impl<In: RpcMessage, Out: RpcMessage> Clone for QuinnListener<In, Out> {
187    fn clone(&self) -> Self {
188        Self {
189            inner: self.inner.clone(),
190            _p: PhantomData,
191        }
192    }
193}
194
195impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for QuinnListener<In, Out> {
196    type SendError = io::Error;
197    type RecvError = io::Error;
198    type OpenError = quinn::ConnectionError;
199    type AcceptError = quinn::ConnectionError;
200}
201
202impl<In: RpcMessage, Out: RpcMessage> StreamTypes for QuinnListener<In, Out> {
203    type In = In;
204    type Out = Out;
205    type SendSink = self::SendSink<Out>;
206    type RecvStream = self::RecvStream<In>;
207}
208
209impl<In: RpcMessage, Out: RpcMessage> Listener for QuinnListener<In, Out> {
210    async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
211        let (send, recv) = self
212            .inner
213            .receiver
214            .recv_async()
215            .await
216            .map_err(|_| quinn::ConnectionError::LocallyClosed)?;
217        Ok((SendSink::new(send), RecvStream::new(recv)))
218    }
219
220    fn local_addr(&self) -> &[LocalAddr] {
221        &self.inner.local_addr
222    }
223}
224
225type SocketInner = (quinn::SendStream, quinn::RecvStream);
226
227#[derive(Debug)]
228struct ClientConnectionInner {
229    /// The quinn endpoint, we just keep a clone of this for information
230    endpoint: Option<quinn::Endpoint>,
231    /// The task that handles creating new connections
232    task: Option<tokio::task::JoinHandle<()>>,
233    /// The channel to receive new connections
234    sender: flume::Sender<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
235}
236
237impl Drop for ClientConnectionInner {
238    fn drop(&mut self) {
239        tracing::debug!("Dropping client connection");
240        if let Some(endpoint) = self.endpoint.take() {
241            endpoint.close(0u32.into(), b"client connection dropped");
242            if let Ok(handle) = tokio::runtime::Handle::try_current() {
243                // spawn a task to wait for the endpoint to notify peers that it is closing
244                let span = debug_span!("closing client endpoint");
245                handle.spawn(
246                    async move {
247                        endpoint.wait_idle().await;
248                    }
249                    .instrument(span),
250                );
251            }
252        }
253        // this should not be necessary, since the task would terminate when the receiver is dropped.
254        // but just to be on the safe side.
255        if let Some(task) = self.task.take() {
256            tracing::debug!("Aborting task");
257            task.abort();
258        }
259    }
260}
261
262/// A connection using a quinn connection
263pub struct QuinnConnector<In: RpcMessage, Out: RpcMessage> {
264    inner: Arc<ClientConnectionInner>,
265    _p: PhantomData<(In, Out)>,
266}
267
268impl<In: RpcMessage, Out: RpcMessage> QuinnConnector<In, Out> {
269    async fn single_connection_handler_inner(
270        connection: quinn::Connection,
271        requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
272    ) -> result::Result<(), flume::RecvError> {
273        loop {
274            tracing::debug!("Awaiting request for new bidi substream...");
275            let request = requests.recv_async().await?;
276            tracing::debug!("Got request for new bidi substream");
277            match connection.open_bi().await {
278                Ok(pair) => {
279                    tracing::debug!("Bidi substream opened");
280                    if request.send(Ok(pair)).is_err() {
281                        tracing::debug!("requester dropped");
282                    }
283                }
284                Err(e) => {
285                    tracing::warn!("error opening bidi substream: {}", e);
286                    if request.send(Err(e)).is_err() {
287                        tracing::debug!("requester dropped");
288                    }
289                }
290            }
291        }
292    }
293
294    async fn single_connection_handler(
295        connection: quinn::Connection,
296        requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
297    ) {
298        if Self::single_connection_handler_inner(connection, requests)
299            .await
300            .is_err()
301        {
302            tracing::info!("Single connection handler finished");
303        } else {
304            unreachable!()
305        }
306    }
307
308    /// Client connection handler.
309    ///
310    /// It will run until the send side of the channel is dropped.
311    /// All other errors are logged and handled internally.
312    /// It will try to keep a connection open at all times.
313    async fn reconnect_handler_inner(
314        endpoint: quinn::Endpoint,
315        addr: SocketAddr,
316        name: String,
317        requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
318    ) {
319        let reconnect = ReconnectHandler {
320            endpoint,
321            state: ConnectionState::NotConnected,
322            addr,
323            name,
324        };
325        tokio::pin!(reconnect);
326
327        let mut receiver = Receiver::new(&requests);
328
329        let mut pending_request: Option<
330            oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>,
331        > = None;
332        let mut connection = None;
333
334        enum Racer {
335            Reconnect(Result<quinn::Connection, ReconnectErr>),
336            Channel(Option<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>),
337        }
338
339        loop {
340            let mut conn_result = None;
341            let mut chann_result = None;
342            if !reconnect.connected() && pending_request.is_none() {
343                match futures_lite::future::race(
344                    reconnect.as_mut().map(Racer::Reconnect),
345                    receiver.next().map(Racer::Channel),
346                )
347                .await
348                {
349                    Racer::Reconnect(connection_result) => conn_result = Some(connection_result),
350                    Racer::Channel(channel_result) => {
351                        chann_result = Some(channel_result);
352                    }
353                }
354            } else if !reconnect.connected() {
355                // only need a new connection
356                conn_result = Some(reconnect.as_mut().await);
357            } else if pending_request.is_none() {
358                // there is a connection, just need a request
359                chann_result = Some(receiver.next().await);
360            }
361
362            if let Some(conn_result) = conn_result {
363                tracing::trace!("tick: connection result");
364                match conn_result {
365                    Ok(new_connection) => {
366                        connection = Some(new_connection);
367                    }
368                    Err(e) => {
369                        let connection_err = match e {
370                            ReconnectErr::Connect(e) => {
371                                // TODO(@divma): the type for now accepts only a
372                                // ConnectionError, not a ConnectError. I'm mapping this now to
373                                // some ConnectionError since before it was not even reported.
374                                // Maybe adjust the type?
375                                tracing::warn!(%e, "error calling connect");
376                                quinn::ConnectionError::Reset
377                            }
378                            ReconnectErr::Connection(e) => {
379                                tracing::warn!(%e, "failed to connect");
380                                e
381                            }
382                        };
383                        if let Some(request) = pending_request.take() {
384                            if request.send(Err(connection_err)).is_err() {
385                                tracing::debug!("requester dropped");
386                            }
387                        }
388                    }
389                }
390            }
391
392            if let Some(req) = chann_result {
393                tracing::trace!("tick: bidi request");
394                match req {
395                    Some(request) => pending_request = Some(request),
396                    None => {
397                        tracing::debug!("client dropped");
398                        if let Some(connection) = connection {
399                            connection.close(0u32.into(), b"requester dropped");
400                        }
401                        break;
402                    }
403                }
404            }
405
406            if let Some(connection) = connection.as_mut() {
407                if let Some(request) = pending_request.take() {
408                    match connection.open_bi().await {
409                        Ok(pair) => {
410                            tracing::debug!("Bidi substream opened");
411                            if request.send(Ok(pair)).is_err() {
412                                tracing::debug!("requester dropped");
413                            }
414                        }
415                        Err(e) => {
416                            tracing::warn!("error opening bidi substream: {}", e);
417                            tracing::warn!("recreating connection");
418                            // NOTE: the connection might be stale, so we recreate the
419                            // connection and set the request as pending instead of
420                            // sending the error as a response
421                            reconnect.set_not_connected();
422                            pending_request = Some(request);
423                        }
424                    }
425                }
426            }
427        }
428    }
429
430    async fn reconnect_handler(
431        endpoint: quinn::Endpoint,
432        addr: SocketAddr,
433        name: String,
434        requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
435    ) {
436        Self::reconnect_handler_inner(endpoint, addr, name, requests).await;
437        tracing::info!("Reconnect handler finished");
438    }
439
440    /// Create a new channel
441    pub fn from_connection(connection: quinn::Connection) -> Self {
442        let (sender, receiver) = flume::bounded(16);
443        let task = tokio::spawn(Self::single_connection_handler(connection, receiver));
444        Self {
445            inner: Arc::new(ClientConnectionInner {
446                endpoint: None,
447                task: Some(task),
448                sender,
449            }),
450            _p: PhantomData,
451        }
452    }
453
454    /// Create a new channel
455    pub fn new(endpoint: quinn::Endpoint, addr: SocketAddr, name: String) -> Self {
456        let (sender, receiver) = flume::bounded(16);
457        let task = tokio::spawn(Self::reconnect_handler(
458            endpoint.clone(),
459            addr,
460            name,
461            receiver,
462        ));
463        Self {
464            inner: Arc::new(ClientConnectionInner {
465                endpoint: Some(endpoint),
466                task: Some(task),
467                sender,
468            }),
469            _p: PhantomData,
470        }
471    }
472}
473
474struct ReconnectHandler {
475    endpoint: quinn::Endpoint,
476    state: ConnectionState,
477    addr: SocketAddr,
478    name: String,
479}
480
481impl ReconnectHandler {
482    pub fn set_not_connected(&mut self) {
483        self.state.set_not_connected()
484    }
485
486    pub fn connected(&self) -> bool {
487        matches!(self.state, ConnectionState::Connected(_))
488    }
489}
490
491enum ConnectionState {
492    /// There is no active connection. An attempt to connect will be made.
493    NotConnected,
494    /// Connecting to the remote.
495    Connecting(quinn::Connecting),
496    /// A connection is already established. In this state, no more connection attempts are made.
497    Connected(quinn::Connection),
498    /// Intermediate state while processing.
499    Poisoned,
500}
501
502impl ConnectionState {
503    pub fn poison(&mut self) -> ConnectionState {
504        std::mem::replace(self, ConnectionState::Poisoned)
505    }
506
507    pub fn set_not_connected(&mut self) {
508        *self = ConnectionState::NotConnected
509    }
510}
511
512enum ReconnectErr {
513    Connect(quinn::ConnectError),
514    Connection(quinn::ConnectionError),
515}
516
517impl Future for ReconnectHandler {
518    type Output = Result<quinn::Connection, ReconnectErr>;
519
520    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
521        match self.state.poison() {
522            ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) {
523                Ok(connecting) => {
524                    self.state = ConnectionState::Connecting(connecting);
525                    self.poll(cx)
526                }
527                Err(e) => {
528                    self.state = ConnectionState::NotConnected;
529                    Poll::Ready(Err(ReconnectErr::Connect(e)))
530                }
531            },
532            ConnectionState::Connecting(mut connecting) => match Pin::new(&mut connecting).poll(cx)
533            {
534                Poll::Ready(res) => match res {
535                    Ok(connection) => {
536                        self.state = ConnectionState::Connected(connection.clone());
537                        Poll::Ready(Ok(connection))
538                    }
539                    Err(e) => {
540                        self.state = ConnectionState::NotConnected;
541                        Poll::Ready(Err(ReconnectErr::Connection(e)))
542                    }
543                },
544                Poll::Pending => {
545                    self.state = ConnectionState::Connecting(connecting);
546                    Poll::Pending
547                }
548            },
549            ConnectionState::Connected(connection) => {
550                self.state = ConnectionState::Connected(connection.clone());
551                Poll::Ready(Ok(connection))
552            }
553            ConnectionState::Poisoned => unreachable!("poisoned connection state"),
554        }
555    }
556}
557
558/// Wrapper over [`flume::Receiver`] that can be used with [`tokio::select`].
559///
560/// NOTE: from https://github.com/zesterer/flume/issues/104:
561/// > If RecvFut is dropped without being polled, the item is never received.
562enum Receiver<'a, T>
563where
564    Self: 'a,
565{
566    PreReceive(&'a flume::Receiver<T>),
567    Receiving(&'a flume::Receiver<T>, flume::r#async::RecvFut<'a, T>),
568    Poisoned,
569}
570
571impl<'a, T> Receiver<'a, T> {
572    fn new(recv: &'a flume::Receiver<T>) -> Self {
573        Receiver::PreReceive(recv)
574    }
575
576    fn poison(&mut self) -> Self {
577        std::mem::replace(self, Self::Poisoned)
578    }
579}
580
581impl<T> Stream for Receiver<'_, T> {
582    type Item = T;
583
584    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
585        match self.poison() {
586            Receiver::PreReceive(recv) => {
587                let fut = recv.recv_async();
588                *self = Receiver::Receiving(recv, fut);
589                self.poll_next(cx)
590            }
591            Receiver::Receiving(recv, mut fut) => match Pin::new(&mut fut).poll(cx) {
592                Poll::Ready(Ok(t)) => {
593                    *self = Receiver::PreReceive(recv);
594                    Poll::Ready(Some(t))
595                }
596                Poll::Ready(Err(flume::RecvError::Disconnected)) => {
597                    *self = Receiver::PreReceive(recv);
598                    Poll::Ready(None)
599                }
600                Poll::Pending => {
601                    *self = Receiver::Receiving(recv, fut);
602                    Poll::Pending
603                }
604            },
605            Receiver::Poisoned => unreachable!("poisoned receiver state"),
606        }
607    }
608}
609
610impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for QuinnConnector<In, Out> {
611    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
612        f.debug_struct("ClientChannel")
613            .field("inner", &self.inner)
614            .finish()
615    }
616}
617
618impl<In: RpcMessage, Out: RpcMessage> Clone for QuinnConnector<In, Out> {
619    fn clone(&self) -> Self {
620        Self {
621            inner: self.inner.clone(),
622            _p: PhantomData,
623        }
624    }
625}
626
627impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for QuinnConnector<In, Out> {
628    type SendError = io::Error;
629    type RecvError = io::Error;
630    type OpenError = quinn::ConnectionError;
631    type AcceptError = quinn::ConnectionError;
632}
633
634impl<In: RpcMessage, Out: RpcMessage> StreamTypes for QuinnConnector<In, Out> {
635    type In = In;
636    type Out = Out;
637    type SendSink = self::SendSink<Out>;
638    type RecvStream = self::RecvStream<In>;
639}
640
641impl<In: RpcMessage, Out: RpcMessage> Connector for QuinnConnector<In, Out> {
642    async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
643        let (sender, receiver) = oneshot::channel();
644        self.inner
645            .sender
646            .send_async(sender)
647            .await
648            .map_err(|_| quinn::ConnectionError::LocallyClosed)?;
649        let (send, recv) = receiver
650            .await
651            .map_err(|_| quinn::ConnectionError::LocallyClosed)??;
652        Ok((SendSink::new(send), RecvStream::new(recv)))
653    }
654}
655
656/// A sink that wraps a quinn SendStream with length delimiting and postcard
657///
658/// If you want to send bytes directly, use [SendSink::into_inner] to get the
659/// underlying [quinn::SendStream].
660#[pin_project]
661pub struct SendSink<Out>(#[pin] FramedPostcardWrite<quinn::SendStream, Out>);
662
663impl<Out> fmt::Debug for SendSink<Out> {
664    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665        f.debug_struct("SendSink").finish()
666    }
667}
668
669impl<Out: Serialize> SendSink<Out> {
670    fn new(inner: quinn::SendStream) -> Self {
671        let inner = FramedPostcardWrite::new(inner, MAX_FRAME_LENGTH);
672        Self(inner)
673    }
674}
675
676impl<Out> SendSink<Out> {
677    /// Get the underlying [quinn::SendStream], which implements
678    /// [tokio::io::AsyncWrite] and can be used to send bytes directly.
679    pub fn into_inner(self) -> quinn::SendStream {
680        self.0.into_inner()
681    }
682}
683
684impl<Out: Serialize> Sink<Out> for SendSink<Out> {
685    type Error = io::Error;
686
687    fn poll_ready(
688        self: Pin<&mut Self>,
689        cx: &mut std::task::Context<'_>,
690    ) -> std::task::Poll<Result<(), Self::Error>> {
691        Pin::new(&mut self.project().0).poll_ready(cx)
692    }
693
694    fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
695        Pin::new(&mut self.project().0).start_send(item)
696    }
697
698    fn poll_flush(
699        self: Pin<&mut Self>,
700        cx: &mut std::task::Context<'_>,
701    ) -> std::task::Poll<Result<(), Self::Error>> {
702        Pin::new(&mut self.project().0).poll_flush(cx)
703    }
704
705    fn poll_close(
706        self: Pin<&mut Self>,
707        cx: &mut std::task::Context<'_>,
708    ) -> std::task::Poll<Result<(), Self::Error>> {
709        Pin::new(&mut self.project().0).poll_close(cx)
710    }
711}
712
713/// A stream that wraps a quinn RecvStream with length delimiting and postcard
714///
715/// If you want to receive bytes directly, use [RecvStream::into_inner] to get
716/// the underlying [quinn::RecvStream].
717#[pin_project]
718pub struct RecvStream<In>(#[pin] FramedPostcardRead<quinn::RecvStream, In>);
719
720impl<In> fmt::Debug for RecvStream<In> {
721    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
722        f.debug_struct("RecvStream").finish()
723    }
724}
725
726impl<In: DeserializeOwned> RecvStream<In> {
727    fn new(inner: quinn::RecvStream) -> Self {
728        let inner = FramedPostcardRead::new(inner, MAX_FRAME_LENGTH);
729        Self(inner)
730    }
731}
732
733impl<In> RecvStream<In> {
734    /// Get the underlying [quinn::RecvStream], which implements
735    /// [tokio::io::AsyncRead] and can be used to receive bytes directly.
736    pub fn into_inner(self) -> quinn::RecvStream {
737        self.0.into_inner()
738    }
739}
740
741impl<In: DeserializeOwned> Stream for RecvStream<In> {
742    type Item = result::Result<In, io::Error>;
743
744    fn poll_next(
745        self: Pin<&mut Self>,
746        cx: &mut std::task::Context<'_>,
747    ) -> std::task::Poll<Option<Self::Item>> {
748        Pin::new(&mut self.project().0).poll_next(cx)
749    }
750}
751
752/// Error for open. Currently just a quinn::ConnectionError
753pub type OpenError = quinn::ConnectionError;
754
755/// Error for accept. Currently just a quinn::ConnectionError
756pub type AcceptError = quinn::ConnectionError;
757
758/// CreateChannelError for quinn channels.
759#[derive(Debug, Clone)]
760pub enum CreateChannelError {
761    /// Something went wrong immediately when creating the quinn endpoint
762    Io(io::ErrorKind, String),
763    /// Error directly when calling connect on the quinn endpoint
764    Connect(quinn::ConnectError),
765    /// Error produced by the future returned by connect
766    Connection(quinn::ConnectionError),
767}
768
769impl From<io::Error> for CreateChannelError {
770    fn from(e: io::Error) -> Self {
771        CreateChannelError::Io(e.kind(), e.to_string())
772    }
773}
774
775impl From<quinn::ConnectionError> for CreateChannelError {
776    fn from(e: quinn::ConnectionError) -> Self {
777        CreateChannelError::Connection(e)
778    }
779}
780
781impl From<quinn::ConnectError> for CreateChannelError {
782    fn from(e: quinn::ConnectError) -> Self {
783        CreateChannelError::Connect(e)
784    }
785}
786
787impl fmt::Display for CreateChannelError {
788    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
789        fmt::Debug::fmt(self, f)
790    }
791}
792
793impl std::error::Error for CreateChannelError {}
794
795/// Get the handshake data from a quinn connection that uses rustls.
796pub fn get_handshake_data(
797    connection: &quinn::Connection,
798) -> Option<quinn::crypto::rustls::HandshakeData> {
799    let handshake_data = connection.handshake_data()?;
800    let tls_connection = handshake_data.downcast_ref::<quinn::crypto::rustls::HandshakeData>()?;
801    Some(quinn::crypto::rustls::HandshakeData {
802        protocol: tls_connection.protocol.clone(),
803        server_name: tls_connection.server_name.clone(),
804    })
805}
806
807#[cfg(feature = "test-utils")]
808mod quinn_setup_utils {
809    use std::{net::SocketAddr, sync::Arc};
810
811    use anyhow::Result;
812    use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, ServerConfig};
813
814    /// Builds default quinn client config and trusts given certificates.
815    ///
816    /// ## Args
817    ///
818    /// - server_certs: a list of trusted certificates in DER format.
819    pub fn configure_client(server_certs: &[&[u8]]) -> Result<ClientConfig> {
820        let mut certs = rustls::RootCertStore::empty();
821        for cert in server_certs {
822            let cert = rustls::pki_types::CertificateDer::from(cert.to_vec());
823            certs.add(cert)?;
824        }
825
826        let crypto_client_config = rustls::ClientConfig::builder_with_provider(Arc::new(
827            rustls::crypto::ring::default_provider(),
828        ))
829        .with_protocol_versions(&[&rustls::version::TLS13])
830        .expect("valid versions")
831        .with_root_certificates(certs)
832        .with_no_client_auth();
833        let quic_client_config =
834            quinn::crypto::rustls::QuicClientConfig::try_from(crypto_client_config)?;
835
836        Ok(ClientConfig::new(Arc::new(quic_client_config)))
837    }
838
839    /// Constructs a QUIC endpoint configured for use a client only.
840    ///
841    /// ## Args
842    ///
843    /// - server_certs: list of trusted certificates.
844    pub fn make_client_endpoint(bind_addr: SocketAddr, server_certs: &[&[u8]]) -> Result<Endpoint> {
845        let client_cfg = configure_client(server_certs)?;
846        let mut endpoint = Endpoint::client(bind_addr)?;
847        endpoint.set_default_client_config(client_cfg);
848        Ok(endpoint)
849    }
850
851    /// Create a server endpoint with a self-signed certificate
852    ///
853    /// Returns the server endpoint and the certificate in DER format
854    pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec<u8>)> {
855        let (server_config, server_cert) = configure_server()?;
856        let endpoint = Endpoint::server(server_config, bind_addr)?;
857        Ok((endpoint, server_cert))
858    }
859
860    /// Create a quinn server config with a self-signed certificate
861    ///
862    /// Returns the server config and the certificate in DER format
863    pub fn configure_server() -> anyhow::Result<(ServerConfig, Vec<u8>)> {
864        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])?;
865        let cert_der = cert.cert.der();
866        let priv_key = rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
867        let cert_chain = vec![cert_der.clone()];
868
869        let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key.into())?;
870        Arc::get_mut(&mut server_config.transport)
871            .unwrap()
872            .max_concurrent_uni_streams(0_u8.into());
873
874        Ok((server_config, cert_der.to_vec()))
875    }
876
877    /// Constructs a QUIC endpoint that trusts all certificates.
878    ///
879    /// This is useful for testing and local connections, but should be used with care.
880    pub fn make_insecure_client_endpoint(bind_addr: SocketAddr) -> Result<Endpoint> {
881        let crypto = rustls::ClientConfig::builder()
882            .dangerous()
883            .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
884            .with_no_client_auth();
885
886        let client_cfg = QuicClientConfig::try_from(crypto)?;
887        let client_cfg = ClientConfig::new(Arc::new(client_cfg));
888        let mut endpoint = Endpoint::client(bind_addr)?;
889        endpoint.set_default_client_config(client_cfg);
890        Ok(endpoint)
891    }
892
893    #[derive(Debug)]
894    struct SkipServerVerification;
895
896    impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
897        fn verify_server_cert(
898            &self,
899            _end_entity: &rustls::pki_types::CertificateDer<'_>,
900            _intermediates: &[rustls::pki_types::CertificateDer<'_>],
901            _server_name: &rustls::pki_types::ServerName<'_>,
902            _ocsp_response: &[u8],
903            _now: rustls::pki_types::UnixTime,
904        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
905            Ok(rustls::client::danger::ServerCertVerified::assertion())
906        }
907
908        fn verify_tls12_signature(
909            &self,
910            _message: &[u8],
911            _cert: &rustls::pki_types::CertificateDer<'_>,
912            _dss: &rustls::DigitallySignedStruct,
913        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
914            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
915        }
916
917        fn verify_tls13_signature(
918            &self,
919            _message: &[u8],
920            _cert: &rustls::pki_types::CertificateDer<'_>,
921            _dss: &rustls::DigitallySignedStruct,
922        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
923            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
924        }
925
926        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
927            use rustls::SignatureScheme::*;
928            // list them all, we don't care.
929            vec![
930                RSA_PKCS1_SHA1,
931                ECDSA_SHA1_Legacy,
932                RSA_PKCS1_SHA256,
933                ECDSA_NISTP256_SHA256,
934                RSA_PKCS1_SHA384,
935                ECDSA_NISTP384_SHA384,
936                RSA_PKCS1_SHA512,
937                ECDSA_NISTP521_SHA512,
938                RSA_PSS_SHA256,
939                RSA_PSS_SHA384,
940                RSA_PSS_SHA512,
941                ED25519,
942                ED448,
943            ]
944        }
945    }
946}
947#[cfg(feature = "test-utils")]
948pub use quinn_setup_utils::*;