use std::{
collections::BTreeSet,
fmt,
future::Future,
io,
iter::once,
marker::PhantomData,
net::SocketAddr,
pin::{pin, Pin},
sync::Arc,
task::{Context, Poll},
};
use flume::TryRecvError;
use futures_lite::Stream;
use futures_sink::Sink;
use iroh_net::{NodeAddr, NodeId};
use pin_project::pin_project;
use quinn::Connection;
use serde::{de::DeserializeOwned, Serialize};
use tokio::{sync::oneshot, task::yield_now};
use tracing::{debug_span, Instrument};
use super::{
util::{FramedBincodeRead, FramedBincodeWrite},
StreamTypes,
};
use crate::{
transport::{ConnectionErrors, Connector, Listener, LocalAddr},
RpcMessage,
};
/// Largest frame, in bytes (16 MiB), accepted by the framed bincode reader/writer that wrap
/// every substream.
const MAX_FRAME_LENGTH: usize = 16 << 20;
/// State shared by every clone of an `IrohNetListener`; torn down in its `Drop` impl.
#[derive(Debug)]
struct ListenerInner {
/// Endpoint owned by the listener and closed on drop. `None` when the listener was built
/// from externally supplied connections or substreams.
endpoint: Option<iroh_net::Endpoint>,
/// Background task feeding `receiver`; aborted on drop. `None` when nothing is spawned.
task: Option<tokio::task::JoinHandle<()>>,
/// Addresses returned from `Listener::local_addr`.
local_addr: Vec<LocalAddr>,
/// Accepted bidirectional substreams waiting to be handed out by `accept`.
receiver: flume::Receiver<SocketInner>,
}
impl Drop for ListenerInner {
    /// Shuts the listener down.
    ///
    /// An owned endpoint is closed on a spawned task, and only if a tokio runtime is
    /// available to run it. The background task is aborted.
    fn drop(&mut self) {
        tracing::debug!("Dropping server endpoint");
        let owned_endpoint = self.endpoint.take();
        if let (Some(endpoint), Ok(runtime)) =
            (owned_endpoint, tokio::runtime::Handle::try_current())
        {
            let close = async move {
                let closed = endpoint.close(0u32.into(), b"Listener dropped").await;
                if let Err(e) = closed {
                    tracing::warn!(?e, "error closing listener");
                }
            };
            runtime.spawn(close.instrument(debug_span!("closing listener")));
        }
        if let Some(task) = self.task.take() {
            task.abort();
        }
    }
}
/// Which remote nodes an `IrohNetListener` accepts connections from.
#[derive(Debug, Clone)]
pub enum AccessControl {
/// Accept connections from any node.
Unrestricted,
/// Accept only connections from the listed node ids. Other peers, and peers whose node id
/// cannot be determined, are disconnected right after the handshake. The list must not be
/// empty: `IrohNetListener::new_with_access_control` rejects an empty list.
Allowed(Vec<NodeId>),
}
/// Server side of the iroh-net transport.
///
/// Hands out incoming bidirectional QUIC substreams as typed `(SendSink<Out>, RecvStream<In>)`
/// pairs. Clones share the same state. When the last clone is dropped, the background task
/// is aborted and an owned endpoint is closed.
#[derive(Debug)]
pub struct IrohNetListener<In: RpcMessage, Out: RpcMessage> {
// Shared with every clone.
inner: Arc<ListenerInner>,
// Ties the message types to the listener without storing any values.
_p: PhantomData<(In, Out)>,
}
impl<In: RpcMessage, Out: RpcMessage> IrohNetListener<In, Out> {
    /// Accepts bidirectional substreams on `connection` and forwards each one to `sender`.
    ///
    /// Returns when accepting fails (including when the peer closes the connection) or when
    /// every receiver of `sender` has been dropped.
    async fn connection_handler(connection: quinn::Connection, sender: flume::Sender<SocketInner>) {
        loop {
            tracing::debug!("Awaiting incoming bidi substream on existing connection...");
            let bidi_stream = match connection.accept_bi().await {
                Ok(bidi_stream) => bidi_stream,
                Err(quinn::ConnectionError::ApplicationClosed(e)) => {
                    tracing::debug!(?e, "Peer closed the connection");
                    break;
                }
                Err(e) => {
                    tracing::debug!(?e, "Error accepting stream");
                    break;
                }
            };
            tracing::debug!("Sending substream to be handled... {}", bidi_stream.0.id());
            if sender.send_async(bidi_stream).await.is_err() {
                tracing::debug!("Receiver dropped");
                break;
            }
        }
    }

    /// Accept loop for an iroh-net endpoint; returns when `endpoint.accept()` yields `None`.
    ///
    /// Each incoming connection is handled on its own task. That task completes the
    /// handshake, applies `allowed_node_ids` (an empty set admits every node) and then
    /// forwards the connection's substreams to `sender`.
    async fn endpoint_handler(
        endpoint: iroh_net::Endpoint,
        sender: flume::Sender<SocketInner>,
        allowed_node_ids: BTreeSet<NodeId>,
    ) {
        // Shared by the per-connection tasks instead of copying the set for each of them.
        let allowed_node_ids = Arc::new(allowed_node_ids);
        loop {
            tracing::debug!("Waiting for incoming connection...");
            let Some(connecting) = endpoint.accept().await else {
                break;
            };
            let sender = sender.clone();
            let allowed_node_ids = Arc::clone(&allowed_node_ids);
            // Await the handshake off the accept loop. Awaiting it inline would let one slow
            // or unresponsive peer block every other incoming connection.
            tokio::spawn(async move {
                tracing::debug!("Awaiting connection from connect...");
                let connection = match connecting.await {
                    Ok(connection) => connection,
                    Err(e) => {
                        tracing::warn!(?e, "Error accepting connection");
                        return;
                    }
                };
                if !Self::check_access(&connection, &allowed_node_ids) {
                    return;
                }
                tracing::debug!(
                    "Connection established from {:?}",
                    connection.remote_address()
                );
                Self::connection_handler(connection, sender).await;
            });
        }
    }

    /// Applies the allow-list to a freshly established `connection`.
    ///
    /// An empty `allowed_node_ids` admits every node. Otherwise the connection is closed,
    /// and `false` is returned, when the peer's node id cannot be determined or is not
    /// listed.
    fn check_access(
        connection: &quinn::Connection,
        allowed_node_ids: &BTreeSet<NodeId>,
    ) -> bool {
        if allowed_node_ids.is_empty() {
            return true;
        }
        let client_node_id = match iroh_net::endpoint::get_remote_node_id(connection) {
            Ok(node_id) => node_id,
            Err(e) => {
                tracing::error!(
                    ?e,
                    "Failed to extract iroh-net node id from incoming connection from {:?}",
                    connection.remote_address()
                );
                connection.close(0u32.into(), b"failed to extract iroh-net node id");
                return false;
            }
        };
        if !allowed_node_ids.contains(&client_node_id) {
            connection.close(0u32.into(), b"forbidden node id");
            return false;
        }
        true
    }

    /// Creates a listener on `endpoint` that accepts connections from any node.
    ///
    /// The endpoint is closed when the last clone of the listener is dropped.
    ///
    /// # Errors
    /// Never returns an error; the `Result` mirrors `new_with_access_control`.
    ///
    /// # Panics
    /// Panics if called outside a tokio runtime, because the accept loop is started with
    /// `tokio::spawn`.
    pub fn new(endpoint: iroh_net::Endpoint) -> io::Result<Self> {
        Self::new_with_access_control(endpoint, AccessControl::Unrestricted)
    }

    /// Creates a listener on `endpoint` that admits only the nodes allowed by
    /// `access_control`.
    ///
    /// Rejected peers are disconnected right after the handshake, as are peers whose node id
    /// cannot be determined. The endpoint is closed when the last clone of the listener is
    /// dropped.
    ///
    /// # Errors
    /// Returns an error if `access_control` is `AccessControl::Allowed` with an empty list,
    /// since such a listener would reject every connection.
    ///
    /// # Panics
    /// Panics if called outside a tokio runtime, because the accept loop is started with
    /// `tokio::spawn`.
    pub fn new_with_access_control(
        endpoint: iroh_net::Endpoint,
        access_control: AccessControl,
    ) -> io::Result<Self> {
        let allowed_node_ids = match access_control {
            AccessControl::Unrestricted => BTreeSet::new(),
            AccessControl::Allowed(list) if list.is_empty() => {
                return Err(io::Error::other(
                    "Empty list of allowed nodes, \
                    endpoint would reject all connections",
                ));
            }
            AccessControl::Allowed(list) => BTreeSet::from_iter(list),
        };
        let (ipv4_socket_addr, maybe_ipv6_socket_addr) = endpoint.bound_sockets();
        let (sender, receiver) = flume::bounded(16);
        let task = tokio::spawn(Self::endpoint_handler(
            endpoint.clone(),
            sender,
            allowed_node_ids,
        ));
        Ok(Self {
            inner: Arc::new(ListenerInner {
                endpoint: Some(endpoint),
                task: Some(task),
                local_addr: once(LocalAddr::Socket(ipv4_socket_addr))
                    .chain(maybe_ipv6_socket_addr.map(LocalAddr::Socket))
                    .collect(),
                receiver,
            }),
            _p: PhantomData,
        })
    }

    /// Creates a listener fed by already-established connections arriving on `incoming`.
    ///
    /// Substreams are accepted on every connection received. `local_addr` is only reported
    /// through `Listener::local_addr`. No endpoint is owned, so none is closed on drop.
    ///
    /// # Panics
    /// Panics if called outside a tokio runtime.
    pub fn handle_connections(
        incoming: flume::Receiver<quinn::Connection>,
        local_addr: SocketAddr,
    ) -> Self {
        let (sender, receiver) = flume::bounded(16);
        let task = tokio::spawn(async move {
            while let Ok(connection) = incoming.recv_async().await {
                tokio::spawn(Self::connection_handler(connection, sender.clone()));
            }
        });
        Self {
            inner: Arc::new(ListenerInner {
                endpoint: None,
                task: Some(task),
                local_addr: vec![LocalAddr::Socket(local_addr)],
                receiver,
            }),
            _p: PhantomData,
        }
    }

    /// Creates a listener that hands out the substreams received on `receiver` as they are.
    ///
    /// No background task is started and no endpoint is owned. `local_addr` is only
    /// reported through `Listener::local_addr`.
    pub fn handle_substreams(
        receiver: flume::Receiver<SocketInner>,
        local_addr: SocketAddr,
    ) -> Self {
        Self {
            inner: Arc::new(ListenerInner {
                endpoint: None,
                task: None,
                local_addr: vec![LocalAddr::Socket(local_addr)],
                receiver,
            }),
            _p: PhantomData,
        }
    }
}
impl<In: RpcMessage, Out: RpcMessage> Clone for IrohNetListener<In, Out> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_p: PhantomData,
}
}
}
impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for IrohNetListener<In, Out> {
    // Connection-level failures are reported as raw quinn errors.
    type AcceptError = quinn::ConnectionError;
    type OpenError = quinn::ConnectionError;
    // The framed send/receive halves report failures as `io::Error`.
    type RecvError = io::Error;
    type SendError = io::Error;
}
impl<In: RpcMessage, Out: RpcMessage> StreamTypes for IrohNetListener<In, Out> {
    // Incoming items decode as `In`; outgoing items encode from `Out`.
    type RecvStream = RecvStream<In>;
    type SendSink = SendSink<Out>;
    type In = In;
    type Out = Out;
}
impl<In: RpcMessage, Out: RpcMessage> Listener for IrohNetListener<In, Out> {
    /// Waits for the next incoming substream and wraps both halves in typed framing.
    ///
    /// Fails with `LocallyClosed` once the substream channel is disconnected, i.e. nothing
    /// can deliver further substreams.
    async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
        match self.inner.receiver.recv_async().await {
            Ok((send, recv)) => Ok((SendSink::new(send), RecvStream::new(recv))),
            Err(_) => Err(quinn::ConnectionError::LocallyClosed),
        }
    }

    /// Addresses this listener reports as bound.
    fn local_addr(&self) -> &[LocalAddr] {
        self.inner.local_addr.as_slice()
    }
}
/// A bidirectional QUIC substream: its send half and its receive half.
type SocketInner = (quinn::SendStream, quinn::RecvStream);
/// State shared by every clone of an `IrohNetConnector`; torn down in its `Drop` impl.
#[derive(Debug)]
struct ClientConnectionInner {
/// Endpoint owned by the connector and closed on drop. `None` for connectors built from an
/// existing connection.
endpoint: Option<iroh_net::Endpoint>,
/// Background task answering substream requests; aborted on drop.
task: Option<tokio::task::JoinHandle<()>>,
/// Queue of substream requests. Each request carries the oneshot sender that receives
/// the opened substream or the error.
requests_tx: flume::Sender<oneshot::Sender<anyhow::Result<SocketInner>>>,
}
impl Drop for ClientConnectionInner {
    /// Shuts the connector down.
    ///
    /// An owned endpoint is closed on a spawned task, and only if a tokio runtime is
    /// available to run it. The request-handling task is aborted.
    fn drop(&mut self) {
        tracing::debug!("Dropping client connection");
        let owned_endpoint = self.endpoint.take();
        if let (Some(endpoint), Ok(runtime)) =
            (owned_endpoint, tokio::runtime::Handle::try_current())
        {
            let close = async move {
                let closed = endpoint
                    .close(0u32.into(), b"client connection dropped")
                    .await;
                if let Err(e) = closed {
                    tracing::warn!(?e, "error closing client endpoint");
                }
            };
            runtime.spawn(close.instrument(debug_span!("closing client endpoint")));
        }
        let Some(task) = self.task.take() else {
            return;
        };
        tracing::debug!("Aborting task");
        task.abort();
    }
}
/// Client side of the iroh-net transport.
///
/// Opens bidirectional QUIC substreams to a remote node and wraps them as typed
/// `(SendSink<Out>, RecvStream<In>)` pairs. Clones share one background task. When the last
/// clone is dropped, that task is aborted and an owned endpoint is closed.
pub struct IrohNetConnector<In: RpcMessage, Out: RpcMessage> {
// Shared with every clone.
inner: Arc<ClientConnectionInner>,
// Ties the message types to the connector without storing any values.
_p: PhantomData<(In, Out)>,
}
impl<In: RpcMessage, Out: RpcMessage> IrohNetConnector<In, Out> {
async fn single_connection_handler(
connection: quinn::Connection,
requests_rx: flume::Receiver<oneshot::Sender<anyhow::Result<SocketInner>>>,
) {
loop {
tracing::debug!("Awaiting request for new bidi substream...");
let Ok(request_tx) = requests_rx.recv_async().await else {
tracing::info!("Single connection handler finished");
return;
};
tracing::debug!("Got request for new bidi substream");
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request_tx.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}
Err(e) => {
tracing::warn!(?e, "error opening bidi substream");
if request_tx
.send(anyhow::Context::context(
Err(e),
"error opening bidi substream",
))
.is_err()
{
tracing::debug!("requester dropped");
}
}
}
}
}
async fn reconnect_handler_inner(
endpoint: iroh_net::Endpoint,
node_addr: NodeAddr,
alpn: Vec<u8>,
requests_rx: flume::Receiver<oneshot::Sender<anyhow::Result<SocketInner>>>,
) {
let mut reconnect = pin!(ReconnectHandler {
endpoint,
state: ConnectionState::NotConnected,
node_addr,
alpn,
});
let mut pending_request: Option<oneshot::Sender<anyhow::Result<SocketInner>>> = None;
let mut connection: Option<Connection> = None;
loop {
if pending_request.is_none() {
pending_request = match requests_rx.try_recv() {
Ok(req) => Some(req),
Err(TryRecvError::Empty) => None,
Err(TryRecvError::Disconnected) => {
tracing::debug!("client dropped");
if let Some(connection) = connection {
connection.close(0u32.into(), b"requester dropped");
}
break;
}
};
}
if !reconnect.connected() {
tracing::trace!("tick: connection result");
match reconnect.as_mut().await {
Ok(new_connection) => {
connection = Some(new_connection);
}
Err(e) => {
if let Some(request_ack_tx) = pending_request.take() {
if request_ack_tx.send(Err(e)).is_err() {
tracing::debug!("requester dropped");
}
}
yield_now().await;
}
}
} else if pending_request.is_none() {
let Ok(req) = requests_rx.recv_async().await else {
tracing::debug!("client dropped");
if let Some(connection) = connection {
connection.close(0u32.into(), b"requester dropped");
}
break;
};
tracing::trace!("tick: bidi request");
pending_request = Some(req);
}
if let Some(connection) = connection.as_mut() {
if let Some(request) = pending_request.take() {
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}
Err(e) => {
tracing::warn!(?e, "error opening bidi substream");
tracing::warn!("recreating connection");
reconnect.set_not_connected();
pending_request = Some(request);
}
}
}
}
}
}
async fn reconnect_handler(
endpoint: iroh_net::Endpoint,
addr: NodeAddr,
alpn: Vec<u8>,
requests_rx: flume::Receiver<oneshot::Sender<anyhow::Result<SocketInner>>>,
) {
Self::reconnect_handler_inner(endpoint, addr, alpn, requests_rx).await;
tracing::info!("Reconnect handler finished");
}
pub fn from_connection(connection: quinn::Connection) -> Self {
let (requests_tx, requests_rx) = flume::bounded(16);
let task = tokio::spawn(Self::single_connection_handler(connection, requests_rx));
Self {
inner: Arc::new(ClientConnectionInner {
endpoint: None,
task: Some(task),
requests_tx,
}),
_p: PhantomData,
}
}
pub fn new(
endpoint: iroh_net::Endpoint,
node_addr: impl Into<NodeAddr>,
alpn: Vec<u8>,
) -> Self {
let (requests_tx, requests_rx) = flume::bounded(16);
let task = tokio::spawn(Self::reconnect_handler(
endpoint.clone(),
node_addr.into(),
alpn,
requests_rx,
));
Self {
inner: Arc::new(ClientConnectionInner {
endpoint: Some(endpoint),
task: Some(task),
requests_tx,
}),
_p: PhantomData,
}
}
}
/// Future that resolves to a connection to `node_addr`.
///
/// When no connection is cached it dials over `endpoint` with `alpn`. While a connection is
/// cached, resolving it again returns a clone of that connection.
struct ReconnectHandler {
/// Endpoint used for dialing.
endpoint: iroh_net::Endpoint,
/// Current position in the connect/reconnect state machine.
state: ConnectionState,
/// Remote node to dial.
node_addr: NodeAddr,
/// ALPN identifier passed to `endpoint.connect`.
alpn: Vec<u8>,
}
impl ReconnectHandler {
    /// Forgets the cached connection or in-flight attempt; the next poll dials again.
    pub fn set_not_connected(&mut self) {
        self.state.set_not_connected();
    }

    /// Returns `true` while an established connection is cached.
    pub fn connected(&self) -> bool {
        matches!(&self.state, ConnectionState::Connected(..))
    }
}
/// Connection lifecycle tracked by `ReconnectHandler`.
enum ConnectionState {
/// No connection and no attempt in flight.
NotConnected,
/// A connection attempt is in flight.
Connecting(Pin<Box<dyn Future<Output = anyhow::Result<quinn::Connection>> + Send>>),
/// A cached established connection, which may since have gone stale.
Connected(quinn::Connection),
/// Placeholder while `poll` has moved the real state out. It only remains in place if
/// `poll` unwinds mid-transition.
Poisoned,
}
impl ConnectionState {
pub fn poison(&mut self) -> Self {
std::mem::replace(self, Self::Poisoned)
}
pub fn set_not_connected(&mut self) {
*self = Self::NotConnected
}
}
impl Future for ReconnectHandler {
type Output = anyhow::Result<quinn::Connection>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state.poison() {
ConnectionState::NotConnected => {
self.state = ConnectionState::Connecting(Box::pin({
let endpoint = self.endpoint.clone();
let node_addr = self.node_addr.clone();
let alpn = self.alpn.clone();
async move { endpoint.connect(node_addr, &alpn).await }
}));
self.poll(cx)
}
ConnectionState::Connecting(mut connecting) => match connecting.as_mut().poll(cx) {
Poll::Ready(res) => match res {
Ok(connection) => {
self.state = ConnectionState::Connected(connection.clone());
Poll::Ready(Ok(connection))
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(e))
}
},
Poll::Pending => {
self.state = ConnectionState::Connecting(connecting);
Poll::Pending
}
},
ConnectionState::Connected(connection) => {
self.state = ConnectionState::Connected(connection.clone());
Poll::Ready(Ok(connection))
}
ConnectionState::Poisoned => unreachable!("poisoned connection state"),
}
}
}
impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for IrohNetConnector<In, Out> {
    /// Formats the connector under its real type name (matching the derived `Debug` of
    /// `IrohNetListener`) and shows the shared inner state. The phantom type parameters
    /// carry no data and are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IrohNetConnector")
            .field("inner", &self.inner)
            .finish()
    }
}
impl<In: RpcMessage, Out: RpcMessage> Clone for IrohNetConnector<In, Out> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_p: PhantomData,
}
}
}
impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for IrohNetConnector<In, Out> {
    // Substreams are opened by the background task, which reports `anyhow` errors.
    type AcceptError = anyhow::Error;
    type OpenError = anyhow::Error;
    // The framed send/receive halves report failures as `io::Error`.
    type RecvError = io::Error;
    type SendError = io::Error;
}
impl<In: RpcMessage, Out: RpcMessage> StreamTypes for IrohNetConnector<In, Out> {
    // Incoming items decode as `In`; outgoing items encode from `Out`.
    type RecvStream = RecvStream<In>;
    type SendSink = SendSink<Out>;
    type In = In;
    type Out = Out;
}
impl<In: RpcMessage, Out: RpcMessage> Connector for IrohNetConnector<In, Out> {
async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
let (request_ack_tx, request_ack_rx) = oneshot::channel();
self.inner
.requests_tx
.send_async(request_ack_tx)
.await
.map_err(|_| quinn::ConnectionError::LocallyClosed)?;
let (send, recv) = request_ack_rx
.await
.map_err(|_| quinn::ConnectionError::LocallyClosed)??;
Ok((SendSink::new(send), RecvStream::new(recv)))
}
}
/// Sending half of a substream. It encodes `Out` values through `FramedBincodeWrite` on a
/// QUIC send stream; frames are capped at `MAX_FRAME_LENGTH` bytes when built via `new`.
#[pin_project]
pub struct SendSink<Out>(#[pin] FramedBincodeWrite<quinn::SendStream, Out>);
impl<Out> fmt::Debug for SendSink<Out> {
    /// Prints only the type name; the wrapped stream is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SendSink")
    }
}
impl<Out: Serialize> SendSink<Out> {
    /// Wraps `inner` in bincode framing limited to `MAX_FRAME_LENGTH` bytes per frame.
    fn new(inner: quinn::SendStream) -> Self {
        Self(FramedBincodeWrite::new(inner, MAX_FRAME_LENGTH))
    }
}
impl<Out> SendSink<Out> {
pub fn into_inner(self) -> quinn::SendStream {
self.0.into_inner()
}
}
impl<Out: Serialize> Sink<Out> for SendSink<Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.project().0).poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
Pin::new(&mut self.project().0).start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.project().0).poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.project().0).poll_close(cx)
}
}
/// Receiving half of a substream. It decodes `In` values through `FramedBincodeRead` from a
/// QUIC receive stream; frames are capped at `MAX_FRAME_LENGTH` bytes when built via `new`.
#[pin_project]
pub struct RecvStream<In>(#[pin] FramedBincodeRead<quinn::RecvStream, In>);
impl<In> fmt::Debug for RecvStream<In> {
    /// Prints only the type name; the wrapped stream is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RecvStream")
    }
}
impl<In: DeserializeOwned> RecvStream<In> {
    /// Wraps `inner` in bincode framing limited to `MAX_FRAME_LENGTH` bytes per frame.
    fn new(inner: quinn::RecvStream) -> Self {
        Self(FramedBincodeRead::new(inner, MAX_FRAME_LENGTH))
    }
}
impl<In> RecvStream<In> {
pub fn into_inner(self) -> quinn::RecvStream {
self.0.into_inner()
}
}
impl<In: DeserializeOwned> Stream for RecvStream<In> {
type Item = Result<In, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.project().0).poll_next(cx)
}
}
/// Error returned when `IrohNetConnector` fails to open a bidirectional substream.
pub type OpenBiError = anyhow::Error;
/// Error returned when `IrohNetListener` fails to accept a substream.
pub type AcceptError = quinn::ConnectionError;