1use std::{
3 fmt, io,
4 marker::PhantomData,
5 net::SocketAddr,
6 pin::Pin,
7 result,
8 sync::Arc,
9 task::{Context, Poll},
10};
11
12use futures_lite::{Future, Stream, StreamExt};
13use futures_sink::Sink;
14use futures_util::FutureExt;
15use pin_project::pin_project;
16use serde::{de::DeserializeOwned, Serialize};
17use tokio::sync::oneshot;
18use tracing::{debug_span, Instrument};
19
20use super::{
21 util::{FramedPostcardRead, FramedPostcardWrite},
22 StreamTypes,
23};
24use crate::{
25 transport::{ConnectionErrors, Connector, Listener, LocalAddr},
26 RpcMessage,
27};
28
/// Maximum frame length, in bytes, handed to the postcard framing codecs wrapping each QUIC
/// stream (16 MiB).
const MAX_FRAME_LENGTH: usize = 16 * 1024 * 1024;
30
31#[derive(Debug)]
32struct ListenerInner {
33 endpoint: Option<quinn::Endpoint>,
34 task: Option<tokio::task::JoinHandle<()>>,
35 local_addr: [LocalAddr; 1],
36 receiver: flume::Receiver<SocketInner>,
37}
38
39impl Drop for ListenerInner {
40 fn drop(&mut self) {
41 tracing::debug!("Dropping listener");
42 if let Some(endpoint) = self.endpoint.take() {
43 endpoint.close(0u32.into(), b"Listener dropped");
44
45 if let Ok(handle) = tokio::runtime::Handle::try_current() {
46 let span = debug_span!("closing listener");
48 handle.spawn(
49 async move {
50 endpoint.wait_idle().await;
51 }
52 .instrument(span),
53 );
54 }
55 }
56 if let Some(task) = self.task.take() {
57 task.abort()
58 }
59 }
60}
61
62#[derive(Debug)]
64pub struct QuinnListener<In: RpcMessage, Out: RpcMessage> {
65 inner: Arc<ListenerInner>,
66 _p: PhantomData<(In, Out)>,
67}
68
69impl<In: RpcMessage, Out: RpcMessage> QuinnListener<In, Out> {
70 async fn connection_handler(connection: quinn::Connection, sender: flume::Sender<SocketInner>) {
74 loop {
75 tracing::debug!("Awaiting incoming bidi substream on existing connection...");
76 let bidi_stream = match connection.accept_bi().await {
77 Ok(bidi_stream) => bidi_stream,
78 Err(quinn::ConnectionError::ApplicationClosed(e)) => {
79 tracing::debug!("Peer closed the connection {:?}", e);
80 break;
81 }
82 Err(e) => {
83 tracing::debug!("Error accepting stream: {}", e);
84 break;
85 }
86 };
87 tracing::debug!("Sending substream to be handled... {}", bidi_stream.0.id());
88 if sender.send_async(bidi_stream).await.is_err() {
89 tracing::debug!("Receiver dropped");
90 break;
91 }
92 }
93 }
94
95 async fn endpoint_handler(endpoint: quinn::Endpoint, sender: flume::Sender<SocketInner>) {
96 loop {
97 tracing::debug!("Waiting for incoming connection...");
98 let connecting = match endpoint.accept().await {
99 Some(connecting) => connecting,
100 None => break,
101 };
102 tracing::debug!("Awaiting connection from connect...");
103 let conection = match connecting.await {
104 Ok(conection) => conection,
105 Err(e) => {
106 tracing::warn!("Error accepting connection: {}", e);
107 continue;
108 }
109 };
110 tracing::debug!(
111 "Connection established from {:?}",
112 conection.remote_address()
113 );
114 tracing::debug!("Spawning connection handler...");
115 tokio::spawn(Self::connection_handler(conection, sender.clone()));
116 }
117 }
118
119 pub fn new(endpoint: quinn::Endpoint) -> io::Result<Self> {
126 let local_addr = endpoint.local_addr()?;
127 let (sender, receiver) = flume::bounded(16);
128 let task = tokio::spawn(Self::endpoint_handler(endpoint.clone(), sender));
129 Ok(Self {
130 inner: Arc::new(ListenerInner {
131 endpoint: Some(endpoint),
132 task: Some(task),
133 local_addr: [LocalAddr::Socket(local_addr)],
134 receiver,
135 }),
136 _p: PhantomData,
137 })
138 }
139
140 pub fn handle_connections(
145 incoming: flume::Receiver<quinn::Connection>,
146 local_addr: SocketAddr,
147 ) -> Self {
148 let (sender, receiver) = flume::bounded(16);
149 let task = tokio::spawn(async move {
150 while let Ok(connection) = incoming.recv_async().await {
152 tokio::spawn(Self::connection_handler(connection, sender.clone()));
153 }
154 });
155 Self {
156 inner: Arc::new(ListenerInner {
157 endpoint: None,
158 task: Some(task),
159 local_addr: [LocalAddr::Socket(local_addr)],
160 receiver,
161 }),
162 _p: PhantomData,
163 }
164 }
165
166 pub fn handle_substreams(
171 receiver: flume::Receiver<SocketInner>,
172 local_addr: SocketAddr,
173 ) -> Self {
174 Self {
175 inner: Arc::new(ListenerInner {
176 endpoint: None,
177 task: None,
178 local_addr: [LocalAddr::Socket(local_addr)],
179 receiver,
180 }),
181 _p: PhantomData,
182 }
183 }
184}
185
186impl<In: RpcMessage, Out: RpcMessage> Clone for QuinnListener<In, Out> {
187 fn clone(&self) -> Self {
188 Self {
189 inner: self.inner.clone(),
190 _p: PhantomData,
191 }
192 }
193}
194
195impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for QuinnListener<In, Out> {
196 type SendError = io::Error;
197 type RecvError = io::Error;
198 type OpenError = quinn::ConnectionError;
199 type AcceptError = quinn::ConnectionError;
200}
201
202impl<In: RpcMessage, Out: RpcMessage> StreamTypes for QuinnListener<In, Out> {
203 type In = In;
204 type Out = Out;
205 type SendSink = self::SendSink<Out>;
206 type RecvStream = self::RecvStream<In>;
207}
208
209impl<In: RpcMessage, Out: RpcMessage> Listener for QuinnListener<In, Out> {
210 async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
211 let (send, recv) = self
212 .inner
213 .receiver
214 .recv_async()
215 .await
216 .map_err(|_| quinn::ConnectionError::LocallyClosed)?;
217 Ok((SendSink::new(send), RecvStream::new(recv)))
218 }
219
220 fn local_addr(&self) -> &[LocalAddr] {
221 &self.inner.local_addr
222 }
223}
224
225type SocketInner = (quinn::SendStream, quinn::RecvStream);
226
227#[derive(Debug)]
228struct ClientConnectionInner {
229 endpoint: Option<quinn::Endpoint>,
231 task: Option<tokio::task::JoinHandle<()>>,
233 sender: flume::Sender<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
235}
236
237impl Drop for ClientConnectionInner {
238 fn drop(&mut self) {
239 tracing::debug!("Dropping client connection");
240 if let Some(endpoint) = self.endpoint.take() {
241 endpoint.close(0u32.into(), b"client connection dropped");
242 if let Ok(handle) = tokio::runtime::Handle::try_current() {
243 let span = debug_span!("closing client endpoint");
245 handle.spawn(
246 async move {
247 endpoint.wait_idle().await;
248 }
249 .instrument(span),
250 );
251 }
252 }
253 if let Some(task) = self.task.take() {
256 tracing::debug!("Aborting task");
257 task.abort();
258 }
259 }
260}
261
262pub struct QuinnConnector<In: RpcMessage, Out: RpcMessage> {
264 inner: Arc<ClientConnectionInner>,
265 _p: PhantomData<(In, Out)>,
266}
267
268impl<In: RpcMessage, Out: RpcMessage> QuinnConnector<In, Out> {
269 async fn single_connection_handler_inner(
270 connection: quinn::Connection,
271 requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
272 ) -> result::Result<(), flume::RecvError> {
273 loop {
274 tracing::debug!("Awaiting request for new bidi substream...");
275 let request = requests.recv_async().await?;
276 tracing::debug!("Got request for new bidi substream");
277 match connection.open_bi().await {
278 Ok(pair) => {
279 tracing::debug!("Bidi substream opened");
280 if request.send(Ok(pair)).is_err() {
281 tracing::debug!("requester dropped");
282 }
283 }
284 Err(e) => {
285 tracing::warn!("error opening bidi substream: {}", e);
286 if request.send(Err(e)).is_err() {
287 tracing::debug!("requester dropped");
288 }
289 }
290 }
291 }
292 }
293
294 async fn single_connection_handler(
295 connection: quinn::Connection,
296 requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
297 ) {
298 if Self::single_connection_handler_inner(connection, requests)
299 .await
300 .is_err()
301 {
302 tracing::info!("Single connection handler finished");
303 } else {
304 unreachable!()
305 }
306 }
307
308 async fn reconnect_handler_inner(
314 endpoint: quinn::Endpoint,
315 addr: SocketAddr,
316 name: String,
317 requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
318 ) {
319 let reconnect = ReconnectHandler {
320 endpoint,
321 state: ConnectionState::NotConnected,
322 addr,
323 name,
324 };
325 tokio::pin!(reconnect);
326
327 let mut receiver = Receiver::new(&requests);
328
329 let mut pending_request: Option<
330 oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>,
331 > = None;
332 let mut connection = None;
333
334 enum Racer {
335 Reconnect(Result<quinn::Connection, ReconnectErr>),
336 Channel(Option<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>),
337 }
338
339 loop {
340 let mut conn_result = None;
341 let mut chann_result = None;
342 if !reconnect.connected() && pending_request.is_none() {
343 match futures_lite::future::race(
344 reconnect.as_mut().map(Racer::Reconnect),
345 receiver.next().map(Racer::Channel),
346 )
347 .await
348 {
349 Racer::Reconnect(connection_result) => conn_result = Some(connection_result),
350 Racer::Channel(channel_result) => {
351 chann_result = Some(channel_result);
352 }
353 }
354 } else if !reconnect.connected() {
355 conn_result = Some(reconnect.as_mut().await);
357 } else if pending_request.is_none() {
358 chann_result = Some(receiver.next().await);
360 }
361
362 if let Some(conn_result) = conn_result {
363 tracing::trace!("tick: connection result");
364 match conn_result {
365 Ok(new_connection) => {
366 connection = Some(new_connection);
367 }
368 Err(e) => {
369 let connection_err = match e {
370 ReconnectErr::Connect(e) => {
371 tracing::warn!(%e, "error calling connect");
376 quinn::ConnectionError::Reset
377 }
378 ReconnectErr::Connection(e) => {
379 tracing::warn!(%e, "failed to connect");
380 e
381 }
382 };
383 if let Some(request) = pending_request.take() {
384 if request.send(Err(connection_err)).is_err() {
385 tracing::debug!("requester dropped");
386 }
387 }
388 }
389 }
390 }
391
392 if let Some(req) = chann_result {
393 tracing::trace!("tick: bidi request");
394 match req {
395 Some(request) => pending_request = Some(request),
396 None => {
397 tracing::debug!("client dropped");
398 if let Some(connection) = connection {
399 connection.close(0u32.into(), b"requester dropped");
400 }
401 break;
402 }
403 }
404 }
405
406 if let Some(connection) = connection.as_mut() {
407 if let Some(request) = pending_request.take() {
408 match connection.open_bi().await {
409 Ok(pair) => {
410 tracing::debug!("Bidi substream opened");
411 if request.send(Ok(pair)).is_err() {
412 tracing::debug!("requester dropped");
413 }
414 }
415 Err(e) => {
416 tracing::warn!("error opening bidi substream: {}", e);
417 tracing::warn!("recreating connection");
418 reconnect.set_not_connected();
422 pending_request = Some(request);
423 }
424 }
425 }
426 }
427 }
428 }
429
430 async fn reconnect_handler(
431 endpoint: quinn::Endpoint,
432 addr: SocketAddr,
433 name: String,
434 requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
435 ) {
436 Self::reconnect_handler_inner(endpoint, addr, name, requests).await;
437 tracing::info!("Reconnect handler finished");
438 }
439
440 pub fn from_connection(connection: quinn::Connection) -> Self {
442 let (sender, receiver) = flume::bounded(16);
443 let task = tokio::spawn(Self::single_connection_handler(connection, receiver));
444 Self {
445 inner: Arc::new(ClientConnectionInner {
446 endpoint: None,
447 task: Some(task),
448 sender,
449 }),
450 _p: PhantomData,
451 }
452 }
453
454 pub fn new(endpoint: quinn::Endpoint, addr: SocketAddr, name: String) -> Self {
456 let (sender, receiver) = flume::bounded(16);
457 let task = tokio::spawn(Self::reconnect_handler(
458 endpoint.clone(),
459 addr,
460 name,
461 receiver,
462 ));
463 Self {
464 inner: Arc::new(ClientConnectionInner {
465 endpoint: Some(endpoint),
466 task: Some(task),
467 sender,
468 }),
469 _p: PhantomData,
470 }
471 }
472}
473
474struct ReconnectHandler {
475 endpoint: quinn::Endpoint,
476 state: ConnectionState,
477 addr: SocketAddr,
478 name: String,
479}
480
481impl ReconnectHandler {
482 pub fn set_not_connected(&mut self) {
483 self.state.set_not_connected()
484 }
485
486 pub fn connected(&self) -> bool {
487 matches!(self.state, ConnectionState::Connected(_))
488 }
489}
490
491enum ConnectionState {
492 NotConnected,
494 Connecting(quinn::Connecting),
496 Connected(quinn::Connection),
498 Poisoned,
500}
501
502impl ConnectionState {
503 pub fn poison(&mut self) -> ConnectionState {
504 std::mem::replace(self, ConnectionState::Poisoned)
505 }
506
507 pub fn set_not_connected(&mut self) {
508 *self = ConnectionState::NotConnected
509 }
510}
511
512enum ReconnectErr {
513 Connect(quinn::ConnectError),
514 Connection(quinn::ConnectionError),
515}
516
517impl Future for ReconnectHandler {
518 type Output = Result<quinn::Connection, ReconnectErr>;
519
520 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
521 match self.state.poison() {
522 ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) {
523 Ok(connecting) => {
524 self.state = ConnectionState::Connecting(connecting);
525 self.poll(cx)
526 }
527 Err(e) => {
528 self.state = ConnectionState::NotConnected;
529 Poll::Ready(Err(ReconnectErr::Connect(e)))
530 }
531 },
532 ConnectionState::Connecting(mut connecting) => match Pin::new(&mut connecting).poll(cx)
533 {
534 Poll::Ready(res) => match res {
535 Ok(connection) => {
536 self.state = ConnectionState::Connected(connection.clone());
537 Poll::Ready(Ok(connection))
538 }
539 Err(e) => {
540 self.state = ConnectionState::NotConnected;
541 Poll::Ready(Err(ReconnectErr::Connection(e)))
542 }
543 },
544 Poll::Pending => {
545 self.state = ConnectionState::Connecting(connecting);
546 Poll::Pending
547 }
548 },
549 ConnectionState::Connected(connection) => {
550 self.state = ConnectionState::Connected(connection.clone());
551 Poll::Ready(Ok(connection))
552 }
553 ConnectionState::Poisoned => unreachable!("poisoned connection state"),
554 }
555 }
556}
557
558enum Receiver<'a, T>
563where
564 Self: 'a,
565{
566 PreReceive(&'a flume::Receiver<T>),
567 Receiving(&'a flume::Receiver<T>, flume::r#async::RecvFut<'a, T>),
568 Poisoned,
569}
570
571impl<'a, T> Receiver<'a, T> {
572 fn new(recv: &'a flume::Receiver<T>) -> Self {
573 Receiver::PreReceive(recv)
574 }
575
576 fn poison(&mut self) -> Self {
577 std::mem::replace(self, Self::Poisoned)
578 }
579}
580
581impl<T> Stream for Receiver<'_, T> {
582 type Item = T;
583
584 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
585 match self.poison() {
586 Receiver::PreReceive(recv) => {
587 let fut = recv.recv_async();
588 *self = Receiver::Receiving(recv, fut);
589 self.poll_next(cx)
590 }
591 Receiver::Receiving(recv, mut fut) => match Pin::new(&mut fut).poll(cx) {
592 Poll::Ready(Ok(t)) => {
593 *self = Receiver::PreReceive(recv);
594 Poll::Ready(Some(t))
595 }
596 Poll::Ready(Err(flume::RecvError::Disconnected)) => {
597 *self = Receiver::PreReceive(recv);
598 Poll::Ready(None)
599 }
600 Poll::Pending => {
601 *self = Receiver::Receiving(recv, fut);
602 Poll::Pending
603 }
604 },
605 Receiver::Poisoned => unreachable!("poisoned receiver state"),
606 }
607 }
608}
609
610impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for QuinnConnector<In, Out> {
611 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
612 f.debug_struct("ClientChannel")
613 .field("inner", &self.inner)
614 .finish()
615 }
616}
617
618impl<In: RpcMessage, Out: RpcMessage> Clone for QuinnConnector<In, Out> {
619 fn clone(&self) -> Self {
620 Self {
621 inner: self.inner.clone(),
622 _p: PhantomData,
623 }
624 }
625}
626
627impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for QuinnConnector<In, Out> {
628 type SendError = io::Error;
629 type RecvError = io::Error;
630 type OpenError = quinn::ConnectionError;
631 type AcceptError = quinn::ConnectionError;
632}
633
634impl<In: RpcMessage, Out: RpcMessage> StreamTypes for QuinnConnector<In, Out> {
635 type In = In;
636 type Out = Out;
637 type SendSink = self::SendSink<Out>;
638 type RecvStream = self::RecvStream<In>;
639}
640
641impl<In: RpcMessage, Out: RpcMessage> Connector for QuinnConnector<In, Out> {
642 async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
643 let (sender, receiver) = oneshot::channel();
644 self.inner
645 .sender
646 .send_async(sender)
647 .await
648 .map_err(|_| quinn::ConnectionError::LocallyClosed)?;
649 let (send, recv) = receiver
650 .await
651 .map_err(|_| quinn::ConnectionError::LocallyClosed)??;
652 Ok((SendSink::new(send), RecvStream::new(recv)))
653 }
654}
655
656#[pin_project]
661pub struct SendSink<Out>(#[pin] FramedPostcardWrite<quinn::SendStream, Out>);
662
663impl<Out> fmt::Debug for SendSink<Out> {
664 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665 f.debug_struct("SendSink").finish()
666 }
667}
668
669impl<Out: Serialize> SendSink<Out> {
670 fn new(inner: quinn::SendStream) -> Self {
671 let inner = FramedPostcardWrite::new(inner, MAX_FRAME_LENGTH);
672 Self(inner)
673 }
674}
675
676impl<Out> SendSink<Out> {
677 pub fn into_inner(self) -> quinn::SendStream {
680 self.0.into_inner()
681 }
682}
683
684impl<Out: Serialize> Sink<Out> for SendSink<Out> {
685 type Error = io::Error;
686
687 fn poll_ready(
688 self: Pin<&mut Self>,
689 cx: &mut std::task::Context<'_>,
690 ) -> std::task::Poll<Result<(), Self::Error>> {
691 Pin::new(&mut self.project().0).poll_ready(cx)
692 }
693
694 fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
695 Pin::new(&mut self.project().0).start_send(item)
696 }
697
698 fn poll_flush(
699 self: Pin<&mut Self>,
700 cx: &mut std::task::Context<'_>,
701 ) -> std::task::Poll<Result<(), Self::Error>> {
702 Pin::new(&mut self.project().0).poll_flush(cx)
703 }
704
705 fn poll_close(
706 self: Pin<&mut Self>,
707 cx: &mut std::task::Context<'_>,
708 ) -> std::task::Poll<Result<(), Self::Error>> {
709 Pin::new(&mut self.project().0).poll_close(cx)
710 }
711}
712
713#[pin_project]
718pub struct RecvStream<In>(#[pin] FramedPostcardRead<quinn::RecvStream, In>);
719
720impl<In> fmt::Debug for RecvStream<In> {
721 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
722 f.debug_struct("RecvStream").finish()
723 }
724}
725
726impl<In: DeserializeOwned> RecvStream<In> {
727 fn new(inner: quinn::RecvStream) -> Self {
728 let inner = FramedPostcardRead::new(inner, MAX_FRAME_LENGTH);
729 Self(inner)
730 }
731}
732
733impl<In> RecvStream<In> {
734 pub fn into_inner(self) -> quinn::RecvStream {
737 self.0.into_inner()
738 }
739}
740
741impl<In: DeserializeOwned> Stream for RecvStream<In> {
742 type Item = result::Result<In, io::Error>;
743
744 fn poll_next(
745 self: Pin<&mut Self>,
746 cx: &mut std::task::Context<'_>,
747 ) -> std::task::Poll<Option<Self::Item>> {
748 Pin::new(&mut self.project().0).poll_next(cx)
749 }
750}
751
752pub type OpenError = quinn::ConnectionError;
754
755pub type AcceptError = quinn::ConnectionError;
757
758#[derive(Debug, Clone)]
760pub enum CreateChannelError {
761 Io(io::ErrorKind, String),
763 Connect(quinn::ConnectError),
765 Connection(quinn::ConnectionError),
767}
768
769impl From<io::Error> for CreateChannelError {
770 fn from(e: io::Error) -> Self {
771 CreateChannelError::Io(e.kind(), e.to_string())
772 }
773}
774
775impl From<quinn::ConnectionError> for CreateChannelError {
776 fn from(e: quinn::ConnectionError) -> Self {
777 CreateChannelError::Connection(e)
778 }
779}
780
781impl From<quinn::ConnectError> for CreateChannelError {
782 fn from(e: quinn::ConnectError) -> Self {
783 CreateChannelError::Connect(e)
784 }
785}
786
787impl fmt::Display for CreateChannelError {
788 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
789 fmt::Debug::fmt(self, f)
790 }
791}
792
793impl std::error::Error for CreateChannelError {}
794
795pub fn get_handshake_data(
797 connection: &quinn::Connection,
798) -> Option<quinn::crypto::rustls::HandshakeData> {
799 let handshake_data = connection.handshake_data()?;
800 let tls_connection = handshake_data.downcast_ref::<quinn::crypto::rustls::HandshakeData>()?;
801 Some(quinn::crypto::rustls::HandshakeData {
802 protocol: tls_connection.protocol.clone(),
803 server_name: tls_connection.server_name.clone(),
804 })
805}
806
/// Helpers for building QUIC endpoints in tests (compiled with the `test-utils` feature).
#[cfg(feature = "test-utils")]
mod quinn_setup_utils {
    use std::{net::SocketAddr, sync::Arc};

    use anyhow::Result;
    use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint, ServerConfig};

    /// Builds a client configuration trusting only the given DER-encoded server certificates.
    ///
    /// TLS is restricted to 1.3 and uses the `ring` crypto provider.
    ///
    /// # Errors
    ///
    /// Fails if a certificate cannot be added to the root store or the rustls configuration
    /// cannot be turned into a QUIC configuration.
    pub fn configure_client(server_certs: &[&[u8]]) -> Result<ClientConfig> {
        let mut certs = rustls::RootCertStore::empty();
        for cert in server_certs {
            let cert = rustls::pki_types::CertificateDer::from(cert.to_vec());
            certs.add(cert)?;
        }

        let crypto_client_config = rustls::ClientConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_protocol_versions(&[&rustls::version::TLS13])
        .expect("valid versions")
        .with_root_certificates(certs)
        .with_no_client_auth();
        let quic_client_config =
            quinn::crypto::rustls::QuicClientConfig::try_from(crypto_client_config)?;

        Ok(ClientConfig::new(Arc::new(quic_client_config)))
    }

    /// Creates a client endpoint bound to `bind_addr`. Its default client configuration
    /// trusts the given DER-encoded server certificates.
    pub fn make_client_endpoint(bind_addr: SocketAddr, server_certs: &[&[u8]]) -> Result<Endpoint> {
        let client_cfg = configure_client(server_certs)?;
        let mut endpoint = Endpoint::client(bind_addr)?;
        endpoint.set_default_client_config(client_cfg);
        Ok(endpoint)
    }

    /// Creates a server endpoint bound to `bind_addr` with a freshly generated self-signed
    /// certificate for `localhost`.
    ///
    /// Returns the endpoint together with the DER-encoded certificate, which clients can pass
    /// to [`make_client_endpoint`].
    pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec<u8>)> {
        let (server_config, server_cert) = configure_server()?;
        let endpoint = Endpoint::server(server_config, bind_addr)?;
        Ok((endpoint, server_cert))
    }

    /// Generates a self-signed certificate for `localhost` and builds a server configuration
    /// from it that allows no unidirectional streams.
    ///
    /// Returns the configuration and the DER-encoded certificate.
    pub fn configure_server() -> anyhow::Result<(ServerConfig, Vec<u8>)> {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])?;
        let cert_der = cert.cert.der();
        let priv_key = rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
        let cert_chain = vec![cert_der.clone()];

        let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key.into())?;
        // This transport only uses bidirectional streams, so peers may not open
        // unidirectional ones.
        Arc::get_mut(&mut server_config.transport)
            .expect("freshly created transport config is not shared")
            .max_concurrent_uni_streams(0_u8.into());

        Ok((server_config, cert_der.to_vec()))
    }

    /// Creates a client endpoint bound to `bind_addr` that accepts **any** server certificate.
    ///
    /// Intended for tests only: server authentication is disabled entirely.
    pub fn make_insecure_client_endpoint(bind_addr: SocketAddr) -> Result<Endpoint> {
        // Pin the `ring` provider and TLS 1.3, matching `configure_client`. The plain
        // `rustls::ClientConfig::builder()` depends on a process-wide default crypto provider
        // and panics when none can be determined (e.g. several providers compiled in, none
        // installed).
        let crypto = rustls::ClientConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_protocol_versions(&[&rustls::version::TLS13])
        .expect("valid versions")
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
        .with_no_client_auth();

        let client_cfg = QuicClientConfig::try_from(crypto)?;
        let client_cfg = ClientConfig::new(Arc::new(client_cfg));
        let mut endpoint = Endpoint::client(bind_addr)?;
        endpoint.set_default_client_config(client_cfg);
        Ok(endpoint)
    }

    /// Certificate verifier that accepts every server certificate and handshake signature.
    /// Test-only.
    #[derive(Debug)]
    struct SkipServerVerification;

    impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
        fn verify_server_cert(
            &self,
            _end_entity: &rustls::pki_types::CertificateDer<'_>,
            _intermediates: &[rustls::pki_types::CertificateDer<'_>],
            _server_name: &rustls::pki_types::ServerName<'_>,
            _ocsp_response: &[u8],
            _now: rustls::pki_types::UnixTime,
        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
            Ok(rustls::client::danger::ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
            use rustls::SignatureScheme::*;
            vec![
                RSA_PKCS1_SHA1,
                ECDSA_SHA1_Legacy,
                RSA_PKCS1_SHA256,
                ECDSA_NISTP256_SHA256,
                RSA_PKCS1_SHA384,
                ECDSA_NISTP384_SHA384,
                RSA_PKCS1_SHA512,
                ECDSA_NISTP521_SHA512,
                RSA_PSS_SHA256,
                RSA_PSS_SHA384,
                RSA_PSS_SHA512,
                ED25519,
                ED448,
            ]
        }
    }
}

// Re-export the test helpers at this module's level when `test-utils` is enabled.
#[cfg(feature = "test-utils")]
pub use quinn_setup_utils::*;