1use std::{
4 collections::BTreeSet,
5 fmt,
6 future::Future,
7 io,
8 iter::once,
9 marker::PhantomData,
10 net::SocketAddr,
11 pin::{pin, Pin},
12 sync::Arc,
13 task::{Context, Poll},
14};
15
16use flume::TryRecvError;
17use futures_lite::Stream;
18use futures_sink::Sink;
19use iroh::{endpoint::Connection, NodeAddr, NodeId};
20use pin_project::pin_project;
21use serde::{de::DeserializeOwned, Serialize};
22use tokio::{sync::oneshot, task::yield_now};
23use tracing::{debug_span, Instrument};
24
25use super::{
26 util::{FramedPostcardRead, FramedPostcardWrite},
27 StreamTypes,
28};
29use crate::{
30 transport::{ConnectionErrors, Connector, Listener, LocalAddr},
31 RpcMessage,
32};
33
/// Maximum frame length (16 Mi) handed to both the framed postcard writer (`SendSink::new`)
/// and reader (`RecvStream::new`). Presumably measured in bytes by the codec — confirm in `util`.
const MAX_FRAME_LENGTH: usize = 1024 * 1024 * 16;
35
/// Shared state behind every clone of an [`IrohListener`].
#[derive(Debug)]
struct ListenerInner {
    /// Endpoint owned by the listener. On drop it is closed in a spawned task if a Tokio runtime
    /// is available. `None` when the listener was built from external connections/substreams.
    endpoint: Option<iroh::Endpoint>,
    /// Background task that feeds `receiver`; aborted on drop.
    task: Option<tokio::task::JoinHandle<()>>,
    /// Addresses reported by `Listener::local_addr`.
    local_addr: Vec<LocalAddr>,
    /// Bidi stream pairs waiting to be handed out by `Listener::accept`.
    receiver: flume::Receiver<SocketInner>,
}
43
44impl Drop for ListenerInner {
45 fn drop(&mut self) {
46 tracing::debug!("Dropping server endpoint");
47 if let Some(endpoint) = self.endpoint.take() {
48 if let Ok(handle) = tokio::runtime::Handle::try_current() {
49 let span = debug_span!("closing listener");
51 handle.spawn(
52 async move {
53 endpoint.close().await;
56 }
57 .instrument(span),
58 );
59 }
60 }
61 if let Some(task) = self.task.take() {
62 task.abort()
63 }
64 }
65}
66
/// Which remote nodes an [`IrohListener`] accepts connections from.
#[derive(Debug, Clone)]
pub enum AccessControl {
    /// Accept connections from any node.
    Unrestricted,
    /// Accept connections only from the listed nodes. Any other connection is closed right after
    /// the handshake. An empty list is rejected by `IrohListener::new_with_access_control`.
    Allowed(Vec<NodeId>),
}
76
/// A [`Listener`] that hands out typed RPC channels for bidi streams accepted over iroh
/// connections.
///
/// Clones share the same underlying state (an `Arc`), so cloning is cheap.
#[derive(Debug)]
pub struct IrohListener<In: RpcMessage, Out: RpcMessage> {
    inner: Arc<ListenerInner>,
    _p: PhantomData<(In, Out)>,
}
83
84impl<In: RpcMessage, Out: RpcMessage> IrohListener<In, Out> {
85 async fn connection_handler(connection: Connection, sender: flume::Sender<SocketInner>) {
89 loop {
90 tracing::debug!("Awaiting incoming bidi substream on existing connection...");
91 let bidi_stream = match connection.accept_bi().await {
92 Ok(bidi_stream) => bidi_stream,
93 Err(quinn::ConnectionError::ApplicationClosed(e)) => {
94 tracing::debug!(?e, "Peer closed the connection");
95 break;
96 }
97 Err(e) => {
98 tracing::debug!(?e, "Error accepting stream");
99 break;
100 }
101 };
102 tracing::debug!("Sending substream to be handled... {}", bidi_stream.0.id());
103 if sender.send_async(bidi_stream).await.is_err() {
104 tracing::debug!("Receiver dropped");
105 break;
106 }
107 }
108 }
109
110 async fn endpoint_handler(
111 endpoint: iroh::Endpoint,
112 sender: flume::Sender<SocketInner>,
113 allowed_node_ids: BTreeSet<NodeId>,
114 ) {
115 loop {
116 tracing::debug!("Waiting for incoming connection...");
117 let connecting = match endpoint.accept().await {
118 Some(connecting) => connecting,
119 None => break,
120 };
121
122 tracing::debug!("Awaiting connection from connect...");
123 let connection = match connecting.await {
124 Ok(connection) => connection,
125 Err(e) => {
126 tracing::warn!(?e, "Error accepting connection");
127 continue;
128 }
129 };
130
131 if !allowed_node_ids.is_empty() {
136 let Ok(client_node_id) = connection.remote_node_id().map_err(|e| {
137 tracing::error!(
138 ?e,
139 "Failed to extract iroh node id from incoming connection",
140 )
141 }) else {
142 connection.close(0u32.into(), b"failed to extract iroh node id");
143 continue;
144 };
145
146 if !allowed_node_ids.contains(&client_node_id) {
147 connection.close(0u32.into(), b"forbidden node id");
148 continue;
149 }
150 }
151
152 tracing::debug!(
153 "Connection established from {:?}",
154 connection.remote_node_id()
155 );
156
157 tracing::debug!("Spawning connection handler...");
158 tokio::spawn(Self::connection_handler(connection, sender.clone()));
159 }
160 }
161
162 pub fn new(endpoint: iroh::Endpoint) -> io::Result<Self> {
167 Self::new_with_access_control(endpoint, AccessControl::Unrestricted)
168 }
169
170 pub fn new_with_access_control(
175 endpoint: iroh::Endpoint,
176 access_control: AccessControl,
177 ) -> io::Result<Self> {
178 let allowed_node_ids = match access_control {
179 AccessControl::Unrestricted => BTreeSet::new(),
180 AccessControl::Allowed(list) if list.is_empty() => {
181 return Err(io::Error::other(
182 "Empty list of allowed nodes, \
183 endpoint would reject all connections",
184 ));
185 }
186 AccessControl::Allowed(list) => BTreeSet::from_iter(list),
187 };
188
189 let (ipv4_socket_addr, maybe_ipv6_socket_addr) = endpoint.bound_sockets();
190 let (sender, receiver) = flume::bounded(16);
191 let task = tokio::spawn(Self::endpoint_handler(
192 endpoint.clone(),
193 sender,
194 allowed_node_ids,
195 ));
196
197 Ok(Self {
198 inner: Arc::new(ListenerInner {
199 endpoint: Some(endpoint),
200 task: Some(task),
201 local_addr: once(LocalAddr::Socket(ipv4_socket_addr))
202 .chain(maybe_ipv6_socket_addr.map(LocalAddr::Socket))
203 .collect(),
204 receiver,
205 }),
206 _p: PhantomData,
207 })
208 }
209
210 pub fn handle_connections(
215 incoming: flume::Receiver<Connection>,
216 local_addr: SocketAddr,
217 ) -> Self {
218 let (sender, receiver) = flume::bounded(16);
219 let task = tokio::spawn(async move {
220 while let Ok(connection) = incoming.recv_async().await {
222 tokio::spawn(Self::connection_handler(connection, sender.clone()));
223 }
224 });
225 Self {
226 inner: Arc::new(ListenerInner {
227 endpoint: None,
228 task: Some(task),
229 local_addr: vec![LocalAddr::Socket(local_addr)],
230 receiver,
231 }),
232 _p: PhantomData,
233 }
234 }
235
236 pub fn handle_substreams(
241 receiver: flume::Receiver<SocketInner>,
242 local_addr: SocketAddr,
243 ) -> Self {
244 Self {
245 inner: Arc::new(ListenerInner {
246 endpoint: None,
247 task: None,
248 local_addr: vec![LocalAddr::Socket(local_addr)],
249 receiver,
250 }),
251 _p: PhantomData,
252 }
253 }
254}
255
256impl<In: RpcMessage, Out: RpcMessage> Clone for IrohListener<In, Out> {
257 fn clone(&self) -> Self {
258 Self {
259 inner: self.inner.clone(),
260 _p: PhantomData,
261 }
262 }
263}
264
265impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for IrohListener<In, Out> {
266 type SendError = io::Error;
267 type RecvError = io::Error;
268 type OpenError = quinn::ConnectionError;
269 type AcceptError = quinn::ConnectionError;
270}
271
272impl<In: RpcMessage, Out: RpcMessage> StreamTypes for IrohListener<In, Out> {
273 type In = In;
274 type Out = Out;
275 type SendSink = SendSink<Out>;
276 type RecvStream = RecvStream<In>;
277}
278
279impl<In: RpcMessage, Out: RpcMessage> Listener for IrohListener<In, Out> {
280 async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
281 let (send, recv) = self
282 .inner
283 .receiver
284 .recv_async()
285 .await
286 .map_err(|_| quinn::ConnectionError::LocallyClosed)?;
287
288 Ok((SendSink::new(send), RecvStream::new(recv)))
289 }
290
291 fn local_addr(&self) -> &[LocalAddr] {
292 &self.inner.local_addr
293 }
294}
295
/// A QUIC bidi stream as `(send half, receive half)`.
type SocketInner = (quinn::SendStream, quinn::RecvStream);
297
/// Shared state behind every clone of an [`IrohConnector`].
#[derive(Debug)]
struct ClientConnectionInner {
    /// Endpoint owned by the connector. On drop it is closed in a spawned task if a Tokio
    /// runtime is available. `None` for connectors created from an existing connection.
    endpoint: Option<iroh::Endpoint>,
    /// Background task serving stream requests; aborted on drop.
    task: Option<tokio::task::JoinHandle<()>>,
    /// Sends stream requests to the background task. Each request carries the oneshot sender on
    /// which the opened stream pair (or an error) is returned.
    requests_tx: flume::Sender<oneshot::Sender<anyhow::Result<SocketInner>>>,
}
307
308impl Drop for ClientConnectionInner {
309 fn drop(&mut self) {
310 tracing::debug!("Dropping client connection");
311 if let Some(endpoint) = self.endpoint.take() {
312 if let Ok(handle) = tokio::runtime::Handle::try_current() {
313 let span = debug_span!("closing client endpoint");
315 handle.spawn(
316 async move {
317 endpoint.close().await;
318 }
319 .instrument(span),
320 );
321 }
322 }
323 if let Some(task) = self.task.take() {
326 tracing::debug!("Aborting task");
327 task.abort();
328 }
329 }
330}
331
/// A [`Connector`] that opens typed RPC channels as bidi streams over an iroh connection.
///
/// Clones share the same underlying state (an `Arc`) and background task.
pub struct IrohConnector<In: RpcMessage, Out: RpcMessage> {
    inner: Arc<ClientConnectionInner>,
    _p: PhantomData<(In, Out)>,
}
337
338impl<In: RpcMessage, Out: RpcMessage> IrohConnector<In, Out> {
339 async fn single_connection_handler(
340 connection: Connection,
341 requests_rx: flume::Receiver<oneshot::Sender<anyhow::Result<SocketInner>>>,
342 ) {
343 loop {
344 tracing::debug!("Awaiting request for new bidi substream...");
345 let Ok(request_tx) = requests_rx.recv_async().await else {
346 tracing::info!("Single connection handler finished");
347 return;
348 };
349
350 tracing::debug!("Got request for new bidi substream");
351 match connection.open_bi().await {
352 Ok(pair) => {
353 tracing::debug!("Bidi substream opened");
354 if request_tx.send(Ok(pair)).is_err() {
355 tracing::debug!("requester dropped");
356 }
357 }
358 Err(e) => {
359 tracing::warn!(?e, "error opening bidi substream");
360 if request_tx
361 .send(anyhow::Context::context(
362 Err(e),
363 "error opening bidi substream",
364 ))
365 .is_err()
366 {
367 tracing::debug!("requester dropped");
368 }
369 }
370 }
371 }
372 }
373
374 async fn reconnect_handler_inner(
380 endpoint: iroh::Endpoint,
381 node_addr: NodeAddr,
382 alpn: Vec<u8>,
383 requests_rx: flume::Receiver<oneshot::Sender<anyhow::Result<SocketInner>>>,
384 ) {
385 let mut reconnect = pin!(ReconnectHandler {
386 endpoint,
387 state: ConnectionState::NotConnected,
388 node_addr,
389 alpn,
390 });
391
392 let mut pending_request: Option<oneshot::Sender<anyhow::Result<SocketInner>>> = None;
393 let mut connection: Option<Connection> = None;
394
395 loop {
396 if pending_request.is_none() {
398 pending_request = match requests_rx.try_recv() {
399 Ok(req) => Some(req),
400 Err(TryRecvError::Empty) => None,
401 Err(TryRecvError::Disconnected) => {
402 tracing::debug!("client dropped");
403 if let Some(connection) = connection {
404 connection.close(0u32.into(), b"requester dropped");
405 }
406 break;
407 }
408 };
409 }
410
411 if !reconnect.connected() {
413 tracing::trace!("tick: connection result");
414 match reconnect.as_mut().await {
415 Ok(new_connection) => {
416 connection = Some(new_connection);
417 }
418 Err(e) => {
419 if let Some(request_ack_tx) = pending_request.take() {
421 if request_ack_tx.send(Err(e)).is_err() {
422 tracing::debug!("requester dropped");
423 }
424 }
425
426 yield_now().await;
430 }
431 }
432 } else if pending_request.is_none() {
434 let Ok(req) = requests_rx.recv_async().await else {
435 tracing::debug!("client dropped");
436 if let Some(connection) = connection {
437 connection.close(0u32.into(), b"requester dropped");
438 }
439 break;
440 };
441
442 tracing::trace!("tick: bidi request");
443 pending_request = Some(req);
444 }
445
446 if let Some(connection) = connection.as_mut() {
448 if let Some(request) = pending_request.take() {
449 match connection.open_bi().await {
450 Ok(pair) => {
451 tracing::debug!("Bidi substream opened");
452 if request.send(Ok(pair)).is_err() {
453 tracing::debug!("requester dropped");
454 }
455 }
456 Err(e) => {
457 tracing::warn!(?e, "error opening bidi substream");
458 tracing::warn!("recreating connection");
459 reconnect.set_not_connected();
463 pending_request = Some(request);
464 }
465 }
466 }
467 }
468 }
469 }
470
471 async fn reconnect_handler(
472 endpoint: iroh::Endpoint,
473 addr: NodeAddr,
474 alpn: Vec<u8>,
475 requests_rx: flume::Receiver<oneshot::Sender<anyhow::Result<SocketInner>>>,
476 ) {
477 Self::reconnect_handler_inner(endpoint, addr, alpn, requests_rx).await;
478 tracing::info!("Reconnect handler finished");
479 }
480
481 pub fn from_connection(connection: Connection) -> Self {
483 let (requests_tx, requests_rx) = flume::bounded(16);
484 let task = tokio::spawn(Self::single_connection_handler(connection, requests_rx));
485 Self {
486 inner: Arc::new(ClientConnectionInner {
487 endpoint: None,
488 task: Some(task),
489 requests_tx,
490 }),
491 _p: PhantomData,
492 }
493 }
494
495 pub fn new(endpoint: iroh::Endpoint, node_addr: impl Into<NodeAddr>, alpn: Vec<u8>) -> Self {
497 let (requests_tx, requests_rx) = flume::bounded(16);
498 let task = tokio::spawn(Self::reconnect_handler(
499 endpoint.clone(),
500 node_addr.into(),
501 alpn,
502 requests_rx,
503 ));
504 Self {
505 inner: Arc::new(ClientConnectionInner {
506 endpoint: Some(endpoint),
507 task: Some(task),
508 requests_tx,
509 }),
510 _p: PhantomData,
511 }
512 }
513}
514
/// Future that yields a connection to `node_addr`, starting a new connect attempt on `endpoint`
/// whenever it is polled while not connected.
struct ReconnectHandler {
    endpoint: iroh::Endpoint,
    state: ConnectionState,
    node_addr: NodeAddr,
    /// ALPN protocol identifier passed to `Endpoint::connect`.
    alpn: Vec<u8>,
}
521
522impl ReconnectHandler {
523 pub fn set_not_connected(&mut self) {
524 self.state.set_not_connected()
525 }
526
527 pub fn connected(&self) -> bool {
528 matches!(self.state, ConnectionState::Connected(_))
529 }
530}
531
/// State machine driven by [`ReconnectHandler`]'s `poll`.
enum ConnectionState {
    /// No connection and no attempt in flight.
    NotConnected,
    /// A connection attempt is in flight.
    Connecting(Pin<Box<dyn Future<Output = anyhow::Result<Connection>> + Send>>),
    /// A connection is established.
    Connected(Connection),
    /// Placeholder left behind while the state is taken out during `poll`; observing it is a bug.
    Poisoned,
}
542
543impl ConnectionState {
544 pub fn poison(&mut self) -> Self {
545 std::mem::replace(self, Self::Poisoned)
546 }
547
548 pub fn set_not_connected(&mut self) {
549 *self = Self::NotConnected
550 }
551}
552
553impl Future for ReconnectHandler {
554 type Output = anyhow::Result<Connection>;
555
556 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
557 match self.state.poison() {
558 ConnectionState::NotConnected => {
559 self.state = ConnectionState::Connecting(Box::pin({
560 let endpoint = self.endpoint.clone();
561 let node_addr = self.node_addr.clone();
562 let alpn = self.alpn.clone();
563 async move { endpoint.connect(node_addr, &alpn).await }
564 }));
565 self.poll(cx)
566 }
567
568 ConnectionState::Connecting(mut connecting) => match connecting.as_mut().poll(cx) {
569 Poll::Ready(res) => match res {
570 Ok(connection) => {
571 self.state = ConnectionState::Connected(connection.clone());
572 Poll::Ready(Ok(connection))
573 }
574 Err(e) => {
575 self.state = ConnectionState::NotConnected;
576 Poll::Ready(Err(e))
577 }
578 },
579 Poll::Pending => {
580 self.state = ConnectionState::Connecting(connecting);
581 Poll::Pending
582 }
583 },
584
585 ConnectionState::Connected(connection) => {
586 self.state = ConnectionState::Connected(connection.clone());
587 Poll::Ready(Ok(connection))
588 }
589
590 ConnectionState::Poisoned => unreachable!("poisoned connection state"),
591 }
592 }
593}
594
595impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for IrohConnector<In, Out> {
596 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
597 f.debug_struct("ClientChannel")
598 .field("inner", &self.inner)
599 .finish()
600 }
601}
602
603impl<In: RpcMessage, Out: RpcMessage> Clone for IrohConnector<In, Out> {
604 fn clone(&self) -> Self {
605 Self {
606 inner: self.inner.clone(),
607 _p: PhantomData,
608 }
609 }
610}
611
612impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for IrohConnector<In, Out> {
613 type SendError = io::Error;
614 type RecvError = io::Error;
615 type OpenError = anyhow::Error;
616 type AcceptError = anyhow::Error;
617}
618
619impl<In: RpcMessage, Out: RpcMessage> StreamTypes for IrohConnector<In, Out> {
620 type In = In;
621 type Out = Out;
622 type SendSink = SendSink<Out>;
623 type RecvStream = RecvStream<In>;
624}
625
626impl<In: RpcMessage, Out: RpcMessage> Connector for IrohConnector<In, Out> {
627 async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
628 let (request_ack_tx, request_ack_rx) = oneshot::channel();
629
630 self.inner
631 .requests_tx
632 .send_async(request_ack_tx)
633 .await
634 .map_err(|_| quinn::ConnectionError::LocallyClosed)?;
635
636 let (send, recv) = request_ack_rx
637 .await
638 .map_err(|_| quinn::ConnectionError::LocallyClosed)??;
639
640 Ok((SendSink::new(send), RecvStream::new(recv)))
641 }
642}
643
/// A [`Sink`] of `Out` messages, encoded by a [`FramedPostcardWrite`] onto a QUIC send stream.
#[pin_project]
pub struct SendSink<Out>(#[pin] FramedPostcardWrite<quinn::SendStream, Out>);
650
651impl<Out> fmt::Debug for SendSink<Out> {
652 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
653 f.debug_struct("SendSink").finish()
654 }
655}
656
657impl<Out: Serialize> SendSink<Out> {
658 fn new(inner: quinn::SendStream) -> Self {
659 let inner = FramedPostcardWrite::new(inner, MAX_FRAME_LENGTH);
660 Self(inner)
661 }
662}
663
664impl<Out> SendSink<Out> {
665 pub fn into_inner(self) -> quinn::SendStream {
668 self.0.into_inner()
669 }
670}
671
672impl<Out: Serialize> Sink<Out> for SendSink<Out> {
673 type Error = io::Error;
674
675 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
676 Pin::new(&mut self.project().0).poll_ready(cx)
677 }
678
679 fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
680 Pin::new(&mut self.project().0).start_send(item)
681 }
682
683 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
684 Pin::new(&mut self.project().0).poll_flush(cx)
685 }
686
687 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
688 Pin::new(&mut self.project().0).poll_close(cx)
689 }
690}
691
/// A [`Stream`] of `In` messages, decoded by a [`FramedPostcardRead`] from a QUIC receive stream.
#[pin_project]
pub struct RecvStream<In>(#[pin] FramedPostcardRead<quinn::RecvStream, In>);
698
699impl<In> fmt::Debug for RecvStream<In> {
700 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701 f.debug_struct("RecvStream").finish()
702 }
703}
704
705impl<In: DeserializeOwned> RecvStream<In> {
706 fn new(inner: quinn::RecvStream) -> Self {
707 let inner = FramedPostcardRead::new(inner, MAX_FRAME_LENGTH);
708 Self(inner)
709 }
710}
711
712impl<In> RecvStream<In> {
713 pub fn into_inner(self) -> quinn::RecvStream {
716 self.0.into_inner()
717 }
718}
719
720impl<In: DeserializeOwned> Stream for RecvStream<In> {
721 type Item = Result<In, io::Error>;
722
723 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
724 Pin::new(&mut self.project().0).poll_next(cx)
725 }
726}
727
/// Error type for opening a bidi stream; the same type as [`IrohConnector`]'s `OpenError`.
pub type OpenBiError = anyhow::Error;
730
/// Error returned by [`IrohListener`]'s `accept`; the same type as its `AcceptError`.
pub type AcceptError = quinn::ConnectionError;