use std::{collections::HashMap, future::Future, net::SocketAddr, pin::Pin, sync::Arc};
use anyhow::{bail, ensure, Context as _, Result};
use bytes::Bytes;
use derive_more::Debug;
use futures_lite::FutureExt;
use http::{header::CONNECTION, response::Builder as ResponseBuilder};
use hyper::{
body::Incoming,
header::{HeaderValue, UPGRADE},
service::Service,
upgrade::Upgraded,
HeaderMap, Method, Request, Response, StatusCode,
};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls_acme::AcmeAcceptor;
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{debug, debug_span, error, info, info_span, warn, Instrument};
use tungstenite::handshake::derive_accept_key;
use crate::relay::{
http::{Protocol, LEGACY_RELAY_PATH, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION},
server::{
actor::{ClientConnHandler, ServerActorTask},
streams::MaybeTlsStream,
},
};
/// Response body used by every handler in this module: a fully buffered body.
type BytesBody = http_body_util::Full<hyper::body::Bytes>;
/// Type-erased, thread-safe error returned by the hyper services and request handlers.
type HyperError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// `Result` specialised to [`HyperError`].
type HyperResult<T> = core::result::Result<T, HyperError>;
/// A boxed request handler.
///
/// It receives the incoming request plus a response builder supplied by the router,
/// and produces the complete response.
type HyperHandler = Box<
    dyn Fn(Request<Incoming>, ResponseBuilder) -> HyperResult<Response<BytesBody>> + Send + Sync,
>;
fn body_empty() -> BytesBody {
http_body_util::Full::new(hyper::body::Bytes::new())
}
fn body_full(content: impl Into<hyper::body::Bytes>) -> BytesBody {
http_body_util::Full::new(content.into())
}
fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes)> {
match upgraded.downcast::<hyper_util::rt::TokioIo<MaybeTlsStream>>() {
Ok(parts) => Ok((parts.io.into_inner(), parts.read_buf)),
Err(_) => {
bail!("could not downcast the upgraded connection to MaybeTlsStream")
}
}
}
/// Hands an upgraded HTTP connection over to the relay via `conn_handler.accept`.
///
/// # Errors
///
/// - The upgraded I/O cannot be downcast to a [`MaybeTlsStream`].
/// - hyper had already buffered bytes beyond the HTTP request; handing over only the raw
///   stream would lose them, so this is rejected for now.
/// - `conn_handler.accept` fails.
async fn relay_connection_handler(
    protocol: Protocol,
    conn_handler: &ClientConnHandler,
    upgraded: Upgraded,
) -> Result<()> {
    debug!(?protocol, "relay_connection upgraded");
    let (stream, leftover) = downcast_upgrade(upgraded)?;
    ensure!(
        leftover.is_empty(),
        "can not deal with buffered data yet: {:?}",
        leftover
    );
    conn_handler.accept(protocol, stream).await
}
/// A running relay HTTP(S) server, as returned by [`ServerBuilder::spawn`].
///
/// Holds the address the listener is actually bound to, the handle of the accept-loop
/// task (an [`AbortOnDropHandle`]) and the token that stops the loop.
#[derive(Debug)]
pub struct Server {
    addr: SocketAddr,
    http_server_task: AbortOnDropHandle<()>,
    cancel_server_loop: CancellationToken,
}

impl Server {
    /// The socket address the server is listening on.
    ///
    /// When the server was configured with port 0, this is the OS-assigned port.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Returns a cloneable [`ServerHandle`] that can stop this server from elsewhere.
    pub fn handle(&self) -> ServerHandle {
        let cancel_token = self.cancel_server_loop.clone();
        ServerHandle { cancel_token }
    }

    /// Signals the accept loop to stop and returns immediately.
    ///
    /// Await [`Server::task_handle`] to wait until the shutdown has finished.
    pub fn shutdown(&self) {
        let token = &self.cancel_server_loop;
        token.cancel();
    }

    /// Mutable access to the accept-loop task, e.g. to await its completion.
    pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
        let Self {
            http_server_task, ..
        } = self;
        http_server_task
    }
}
/// Cloneable handle that can stop a [`Server`] without owning it.
#[derive(Debug, Clone)]
pub struct ServerHandle {
    cancel_token: CancellationToken,
}

impl ServerHandle {
    /// Signals the server's accept loop to stop and returns immediately.
    pub fn shutdown(&self) {
        let Self { cancel_token } = self;
        cancel_token.cancel();
    }
}
/// TLS settings; when present the server serves HTTPS/WSS instead of HTTP/WS.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The rustls server configuration. It is passed to the ACME handshake when
    /// `acceptor` is [`TlsAcceptor::LetsEncrypt`].
    pub config: Arc<rustls::ServerConfig>,
    /// How incoming TLS connections are accepted.
    pub acceptor: TlsAcceptor,
}
/// Builder for a relay [`Server`]; finish with [`ServerBuilder::spawn`].
#[derive(derive_more::Debug)]
pub struct ServerBuilder {
    /// Address to bind the listening socket to.
    addr: SocketAddr,
    /// TLS settings; `None` serves plain HTTP/WS.
    tls_config: Option<TlsConfig>,
    /// Extra request handlers keyed by `(method, path)`.
    handlers: Handlers,
    /// Default headers added to responses built by the server.
    headers: HeaderMap,
    /// Custom fallback for unmatched requests; `None` uses the built-in 404 handler.
    // The closure is not `Debug`, so only its presence is printed.
    #[debug("{}", not_found_fn.as_ref().map_or("None", |_| "Some(Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>)"))]
    not_found_fn: Option<HyperHandler>,
}
impl ServerBuilder {
    /// Creates a builder for a server that will bind to `addr`.
    ///
    /// The builder starts with no TLS, no extra request handlers, no default headers
    /// and the built-in 404 handler.
    pub fn new(addr: SocketAddr) -> Self {
        Self {
            addr,
            tls_config: None,
            handlers: Default::default(),
            headers: HeaderMap::new(),
            not_found_fn: None,
        }
    }

    /// Sets the TLS configuration: `Some` serves HTTPS/WSS, `None` plain HTTP/WS.
    pub fn tls_config(mut self, config: Option<TlsConfig>) -> Self {
        self.tls_config = config;
        self
    }

    /// Registers `handler` for requests whose method and path match exactly.
    ///
    /// Registering the same `(method, uri_path)` pair again replaces the earlier
    /// handler. `GET` on the relay paths is always routed to the relay connection
    /// handler first, so handlers registered for those never run.
    pub fn request_handler(
        mut self,
        method: Method,
        uri_path: &'static str,
        handler: HyperHandler,
    ) -> Self {
        self.handlers.insert((method, uri_path), handler);
        self
    }

    /// Replaces the built-in 404 handler used for requests no other handler matches.
    ///
    /// The handler receives a response builder that already carries the default headers.
    // Not every build configuration calls this, hence the allow.
    #[allow(unused)]
    pub fn not_found_handler(mut self, handler: HyperHandler) -> Self {
        self.not_found_fn = Some(handler);
        self
    }

    /// Adds default headers that are applied to responses built by the server.
    ///
    /// For every header name present in `headers`, all previously configured values
    /// for that name are replaced by *all* values from `headers`. Names not present in
    /// `headers` are kept.
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        // `extend` keeps multi-valued headers intact. Inserting value by value (as done
        // previously) let each later value of the same name overwrite the earlier one.
        self.headers.extend(headers);
        self
    }

    /// Starts the relay actor, binds the listening socket and spawns the accept loop.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be bound or its local address cannot be read.
    pub async fn spawn(self) -> Result<Server> {
        let relay_server = ServerActorTask::new();
        let relay_handler = relay_server.client_conn_handler(self.headers.clone());
        let not_found_fn: HyperHandler = match self.not_found_fn {
            Some(f) => f,
            None => Box::new(|_req: Request<Incoming>, res: ResponseBuilder| {
                // `res` comes from `Inner::default_response`, which already carries the
                // default headers. `Builder::header` appends, so adding them again here
                // would send every default header twice.
                let body = body_full("Not Found");
                let r = res.status(StatusCode::NOT_FOUND).body(body)?;
                HyperResult::Ok(r)
            }),
        };
        let service = RelayService::new(self.handlers, relay_handler, not_found_fn, self.headers);
        let server_state = ServerState {
            addr: self.addr,
            tls_config: self.tls_config,
            server: relay_server,
            service,
        };
        server_state.serve().await
    }
}
/// Everything needed to start serving; consumed by [`ServerState::serve`].
#[derive(Debug)]
struct ServerState {
    /// Address to bind to (may use port 0 to let the OS choose).
    addr: SocketAddr,
    /// TLS settings; `None` serves plain HTTP/WS.
    tls_config: Option<TlsConfig>,
    /// The relay actor task; it is closed when the accept loop stops.
    server: ServerActorTask,
    /// The hyper service that routes each request.
    service: RelayService,
}
impl ServerState {
    /// Binds the listening socket and spawns the accept loop.
    ///
    /// Returns a [`Server`] carrying:
    /// - the actually bound address (relevant when `addr` used port 0),
    /// - the accept-loop task,
    /// - the cancellation token that stops the loop.
    ///
    /// Each accepted TCP connection is served on its own task inside a `JoinSet`.
    /// Finished tasks are reaped as they complete, so the set does not hold one entry per
    /// connection for the whole lifetime of the server. On cancellation the relay actor
    /// is closed first, then all in-flight connection tasks are shut down.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be bound or its local address cannot be read.
    async fn serve(self) -> Result<Server> {
        let ServerState {
            addr,
            tls_config,
            server,
            service,
        } = self;
        let listener = TcpListener::bind(&addr)
            .await
            .with_context(|| format!("failed to bind server socket to {addr}"))?;
        let cancel_server_loop = CancellationToken::new();
        // Re-read the address so callers see the port the OS actually assigned.
        let addr = listener.local_addr()?;
        let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS");
        info!("[{http_str}] relay: serving on {addr}");
        let cancel = cancel_server_loop.clone();
        let task = tokio::task::spawn(
            async move {
                let mut set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
                loop {
                    tokio::select! {
                        biased;
                        _ = cancel.cancelled() => {
                            break;
                        }
                        // Reap finished connection tasks. A JoinSet keeps every completed
                        // task until it is joined, so without this branch memory grows by
                        // one entry per connection until shutdown.
                        Some(res) = set.join_next(), if !set.is_empty() => {
                            if let Err(err) = res {
                                if err.is_panic() {
                                    error!("[{http_str}] relay: connection task panicked: {err}");
                                }
                            }
                        }
                        res = listener.accept() => match res {
                            Ok((stream, peer_addr)) => {
                                debug!("[{http_str}] relay: Connection opened from {peer_addr}");
                                let tls_config = tls_config.clone();
                                let service = service.clone();
                                set.spawn(async move {
                                    if let Err(error) = service
                                        .handle_connection(stream, tls_config)
                                        .await
                                    {
                                        // A peer hanging up mid-handshake is routine, not a
                                        // server failure.
                                        match error.downcast_ref::<std::io::Error>() {
                                            Some(io_error) if io_error.kind() == std::io::ErrorKind::UnexpectedEof => {
                                                debug!(reason=?error, "[{http_str}] relay: peer disconnected");
                                            },
                                            _ => {
                                                error!(?error, "[{http_str}] relay: failed to handle connection");
                                            }
                                        }
                                    }
                                }.instrument(info_span!("conn", peer = %peer_addr)));
                            }
                            Err(err) => {
                                error!("[{http_str}] relay: failed to accept connection: {err}");
                            }
                        }
                    }
                }
                server.close().await;
                set.shutdown().await;
                debug!("[{http_str}] relay: server has been shutdown.");
            }
            .instrument(info_span!("relay-http-serve")),
        );
        Ok(Server {
            addr,
            http_server_task: AbortOnDropHandle::new(task),
            cancel_server_loop,
        })
    }
}
impl Service<Request<Incoming>> for ClientConnHandler {
type Response = Response<BytesBody>;
type Error = hyper::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, mut req: Request<Incoming>) -> Self::Future {
let closure_conn_handler = self.clone();
let mut builder = Response::builder();
for (key, value) in self.default_headers.iter() {
builder = builder.header(key, value);
}
async move {
{
let Some(protocol) = req.headers().get(UPGRADE).and_then(Protocol::parse_header)
else {
return Ok(builder
.status(StatusCode::BAD_REQUEST)
.body(body_empty())
.expect("valid body"));
};
let websocket_headers = if protocol == Protocol::Websocket {
let Some(key) = req.headers().get("Sec-WebSocket-Key").cloned() else {
warn!("missing header Sec-WebSocket-Key for websocket relay protocol");
return Ok(builder
.status(StatusCode::BAD_REQUEST)
.body(body_empty())
.expect("valid body"));
};
let Some(version) = req.headers().get("Sec-WebSocket-Version").cloned() else {
warn!("missing header Sec-WebSocket-Version for websocket relay protocol");
return Ok(builder
.status(StatusCode::BAD_REQUEST)
.body(body_empty())
.expect("valid body"));
};
if version.as_bytes() != SUPPORTED_WEBSOCKET_VERSION.as_bytes() {
warn!("invalid header Sec-WebSocket-Version: {:?}", version);
return Ok(builder
.status(StatusCode::BAD_REQUEST)
.header("Sec-WebSocket-Version", SUPPORTED_WEBSOCKET_VERSION)
.body(body_empty())
.expect("valid body"));
}
Some((key, version))
} else {
None
};
debug!("upgrading protocol: {:?}", protocol);
tokio::task::spawn(
async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
if let Err(e) = relay_connection_handler(
protocol,
&closure_conn_handler,
upgraded,
)
.await
{
warn!(
"upgrade to \"{}\": io error: {:?}",
e,
protocol.upgrade_header()
);
} else {
debug!("upgrade to \"{}\" success", protocol.upgrade_header());
};
}
Err(e) => warn!("upgrade error: {:?}", e),
}
}
.instrument(debug_span!("handler")),
);
builder = builder
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(UPGRADE, HeaderValue::from_static(protocol.upgrade_header()));
if let Some((key, _version)) = websocket_headers {
Ok(builder
.header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
.header(CONNECTION, "upgrade")
.body(body_full("switching to websocket protocol"))
.expect("valid body"))
} else {
Ok(builder.body(body_empty()).expect("valid body"))
}
}
}
.boxed()
}
}
impl Service<Request<Incoming>> for RelayService {
    type Response = Response<BytesBody>;
    type Error = HyperError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Routes a request to the matching handler.
    ///
    /// - `GET` on a relay path goes to the relay connection handler.
    /// - An exact `(method, path)` match goes to the registered handler.
    /// - Anything else goes to the not-found handler.
    ///
    /// Registered and not-found handlers receive a builder preloaded with the default
    /// headers.
    fn call(&self, req: Request<Incoming>) -> Self::Future {
        let inner = &self.0;
        let is_relay_upgrade = *req.method() == hyper::Method::GET
            && matches!(req.uri().path(), LEGACY_RELAY_PATH | RELAY_PATH);
        if is_relay_upgrade {
            let relay_handler = inner.relay_handler.clone();
            return Box::pin(async move { relay_handler.call(req).await.map_err(Into::into) });
        }
        // Clone the URI so the lookup key does not borrow `req`, which the handler consumes.
        let uri = req.uri().clone();
        let route = (req.method().clone(), uri.path());
        let handler = inner.handlers.get(&route).unwrap_or(&inner.not_found_fn);
        let outcome = handler(req, inner.default_response());
        Box::pin(async move { outcome })
    }
}
/// The hyper service serving every connection. Cloning it only clones the inner `Arc`.
#[derive(Clone, Debug)]
struct RelayService(Arc<Inner>);
/// Shared state behind [`RelayService`].
#[derive(derive_more::Debug)]
struct Inner {
    /// Handles `GET` requests on the relay paths (protocol upgrade).
    pub relay_handler: ClientConnHandler,
    /// Fallback for requests that match no registered handler.
    // The closure is not `Debug`, so a fixed description is printed instead.
    #[debug("Box<Fn(ResponseBuilder) -> Result<Response<BytesBody>> + Send + Sync + 'static>")]
    pub not_found_fn: HyperHandler,
    /// Handlers keyed by exact `(method, path)`.
    pub handlers: Handlers,
    /// Default headers copied into every builder from `default_response`.
    pub headers: HeaderMap,
}
impl Inner {
    /// Returns a new response builder preloaded with every configured default header.
    fn default_response(&self) -> ResponseBuilder {
        self.headers
            .iter()
            .fold(Response::builder(), |builder, (name, value)| {
                builder.header(name.clone(), value.clone())
            })
    }
}
/// How the server accepts incoming TLS connections.
#[derive(Clone, derive_more::Debug)]
pub enum TlsAcceptor {
    /// ACME-based acceptor. Its `accept` may consume a TLS-ALPN-01 validation request
    /// instead of yielding a handshake to complete.
    LetsEncrypt(#[debug("tokio_rustls_acme::AcmeAcceptor")] AcmeAcceptor),
    /// A pre-configured `tokio_rustls` acceptor.
    Manual(#[debug("tokio_rustls::TlsAcceptor")] tokio_rustls::TlsAcceptor),
}
impl RelayService {
    /// Bundles the routing table, relay handler, fallback handler and default headers
    /// into a service that is cheap to clone.
    fn new(
        handlers: Handlers,
        relay_handler: ClientConnHandler,
        not_found_fn: HyperHandler,
        headers: HeaderMap,
    ) -> Self {
        let inner = Inner {
            relay_handler,
            not_found_fn,
            handlers,
            headers,
        };
        Self(Arc::new(inner))
    }

    /// Serves one accepted TCP connection. When `tls_config` is set, the TLS handshake
    /// is performed first.
    async fn handle_connection(
        self,
        stream: TcpStream,
        tls_config: Option<TlsConfig>,
    ) -> Result<()> {
        let Some(tls_config) = tls_config else {
            debug!("HTTP: serve connection");
            return self.serve_connection(MaybeTlsStream::Plain(stream)).await;
        };
        self.tls_serve_connection(stream, tls_config).await
    }

    /// Completes the TLS handshake with the configured acceptor, then serves HTTP over it.
    ///
    /// With the ACME acceptor, a connection that turns out to be a TLS-ALPN-01
    /// validation request is logged and finished without serving HTTP.
    async fn tls_serve_connection(self, stream: TcpStream, tls_config: TlsConfig) -> Result<()> {
        let TlsConfig { acceptor, config } = tls_config;
        let (tls_stream, serve_context) = match acceptor {
            TlsAcceptor::LetsEncrypt(acme) => {
                let Some(start_handshake) = acme.accept(stream).await? else {
                    info!("TLS[acme]: received TLS-ALPN-01 validation request");
                    return Ok(());
                };
                debug!("TLS[acme]: start handshake");
                let tls_stream = start_handshake
                    .into_stream(config)
                    .await
                    .context("TLS[acme] handshake")?;
                (tls_stream, "TLS[acme] serve connection")
            }
            TlsAcceptor::Manual(manual) => {
                debug!("TLS[manual]: accept");
                let tls_stream = manual.accept(stream).await.context("TLS[manual] accept")?;
                (tls_stream, "TLS[manual] serve connection")
            }
        };
        self.serve_connection(MaybeTlsStream::Tls(tls_stream))
            .await
            .context(serve_context)?;
        Ok(())
    }

    /// Runs hyper's HTTP/1 server over `io` with this service.
    ///
    /// Upgrades are enabled because the relay protocol switches away from HTTP.
    async fn serve_connection<I>(self, io: I) -> Result<()>
    where
        I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static,
    {
        let http = hyper::server::conn::http1::Builder::new();
        let connection = http
            .serve_connection(hyper_util::rt::TokioIo::new(io), self)
            .with_upgrades();
        connection.await?;
        Ok(())
    }
}
/// Request handlers keyed by exact `(method, path)`.
#[derive(Default)]
struct Handlers(HashMap<(Method, &'static str), HyperHandler>);
impl std::fmt::Debug for Handlers {
    /// Lists the registered `(method, path)` keys; the handlers themselves are not `Debug`.
    ///
    /// Writes straight to the formatter. Folding with `format!` rebuilt the whole string
    /// for every key (quadratic allocation). The output text is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("HashMap<")?;
        for (method, uri) in self.0.keys() {
            write!(
                f,
                "\n({method},{uri}): Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>"
            )?;
        }
        f.write_str(">")
    }
}
impl std::ops::Deref for Handlers {
type Target = HashMap<(Method, &'static str), HyperHandler>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for Handlers {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use bytes::Bytes;
    use reqwest::Url;
    use tokio::{sync::mpsc, task::JoinHandle};
    use tracing::{info, info_span, Instrument};
    use tracing_subscriber::{prelude::*, EnvFilter};

    use super::*;
    use crate::{
        key::{PublicKey, SecretKey},
        relay::client::{conn::ReceivedMessage, Client, ClientBuilder},
    };

    /// Builds a [`TlsConfig`] from a freshly generated self-signed certificate for
    /// `localhost`. It uses the `ring` crypto provider and a manual (non-ACME) acceptor.
    pub(crate) fn make_tls_config() -> TlsConfig {
        let subject_alt_names = vec!["localhost".to_string()];
        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let rustls_certificate =
            rustls::pki_types::CertificateDer::from(cert.serialize_der().unwrap());
        let rustls_key =
            rustls::pki_types::PrivatePkcs8KeyDer::from(cert.get_key_pair().serialize_der());
        let rustls_key = rustls::pki_types::PrivateKeyDer::from(rustls_key);
        let config = rustls::ServerConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_safe_default_protocol_versions()
        .expect("protocols supported by ring")
        .with_no_client_auth()
        .with_single_cert(vec![(rustls_certificate)], rustls_key)
        .expect("cert is right");
        let config = Arc::new(config);
        // The same config backs the acceptor and is also stored on `TlsConfig::config`.
        let acceptor = tokio_rustls::TlsAcceptor::from(config.clone());
        TlsConfig {
            config,
            acceptor: TlsAcceptor::Manual(acceptor),
        }
    }

    /// End-to-end test over plain HTTP.
    ///
    /// Two clients connect to a relay bound to an OS-assigned port on 127.0.0.1, each
    /// pings it, and they exchange one message in each direction through the relay.
    #[tokio::test]
    async fn test_http_clients_and_server() -> Result<()> {
        let _guard = iroh_test::logging::setup();
        let a_key = SecretKey::generate();
        let b_key = SecretKey::generate();
        // Port 0: the OS picks a free port, and `server.addr()` reports the real one.
        let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .spawn()
            .await?;
        let addr = server.addr();
        let port = addr.port();
        let addr = {
            if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
                ipv4_addr
            } else {
                anyhow::bail!("cannot get ipv4 addr from socket addr {addr:?}");
            }
        };
        info!("addr: {addr}:{port}");
        let relay_addr: Url = format!("http://{addr}:{port}").parse().unwrap();
        // Each client is created while its own span is entered.
        let (a_key, mut a_recv, client_a_task, client_a) = {
            let span = info_span!("client-a");
            let _guard = span.enter();
            create_test_client(a_key, relay_addr.clone())
        };
        info!("created client {a_key:?}");
        let (b_key, mut b_recv, client_b_task, client_b) = {
            let span = info_span!("client-b");
            let _guard = span.enter();
            create_test_client(b_key, relay_addr)
        };
        info!("created client {b_key:?}");
        info!("ping a");
        client_a.ping().await?;
        info!("ping b");
        client_b.ping().await?;
        info!("sending message from a to b");
        let msg = Bytes::from_static(b"hi there, client b!");
        client_a.send(b_key, msg.clone()).await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) = b_recv.recv().await.expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);
        info!("sending message from b to a");
        let msg = Bytes::from_static(b"right back at ya, client b!");
        client_b.send(a_key, msg.clone()).await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) = a_recv.recv().await.expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);
        client_a.close().await?;
        client_a_task.abort();
        client_b.close().await?;
        client_b_task.abort();
        // Only signals shutdown; unlike the HTTPS test, this does not await the server task.
        server.shutdown();
        Ok(())
    }

    /// Builds a relay client for `server_url` with TLS certificate verification disabled.
    ///
    /// Also spawns a reader task that forwards every received packet as `(source, data)`
    /// on a channel with capacity 10. The reader task stops when the client yields
    /// nothing or an error.
    ///
    /// Returns `(public key, packet receiver, reader task handle, client)`.
    fn create_test_client(
        key: SecretKey,
        server_url: Url,
    ) -> (
        PublicKey,
        mpsc::Receiver<(PublicKey, Bytes)>,
        JoinHandle<()>,
        Client,
    ) {
        let client = ClientBuilder::new(server_url).insecure_skip_cert_verify(true);
        let dns_resolver = crate::dns::default_resolver();
        let (client, mut client_reader) = client.build(key.clone(), dns_resolver.clone());
        let public_key = key.public();
        let (received_msg_s, received_msg_r) = tokio::sync::mpsc::channel(10);
        let client_reader_task = tokio::spawn(
            async move {
                loop {
                    info!("waiting for message on {:?}", key.public());
                    match client_reader.recv().await {
                        None => {
                            info!("client received nothing");
                            return;
                        }
                        Some(Err(e)) => {
                            info!("client {:?} `recv` error {e}", key.public());
                            return;
                        }
                        Some(Ok(msg)) => {
                            info!("got message on {:?}: {msg:?}", key.public());
                            // Only data packets are forwarded; other message kinds are
                            // logged above and otherwise ignored.
                            if let ReceivedMessage::ReceivedPacket { source, data } = msg {
                                received_msg_s
                                    .send((source, data))
                                    .await
                                    .unwrap_or_else(|err| {
                                        panic!(
                                            "client {:?}, error sending message over channel: {:?}",
                                            key.public(),
                                            err
                                        )
                                    });
                            }
                        }
                    }
                }
            }
            .instrument(info_span!("test-client-reader")),
        );
        (public_key, received_msg_r, client_reader_task, client)
    }

    /// End-to-end test over TLS with a self-signed `localhost` certificate.
    ///
    /// Same message exchange as the plain HTTP test. It additionally awaits the server
    /// task after shutdown.
    #[tokio::test]
    async fn test_https_clients_and_server() -> Result<()> {
        tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
            .with(EnvFilter::from_default_env())
            .try_init()
            .ok();
        let a_key = SecretKey::generate();
        let b_key = SecretKey::generate();
        let tls_config = make_tls_config();
        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .tls_config(Some(tls_config))
            .spawn()
            .await?;
        let addr = server.addr();
        let port = addr.port();
        let addr = {
            if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
                ipv4_addr
            } else {
                anyhow::bail!("cannot get ipv4 addr from socket addr {addr:?}");
            }
        };
        info!("Relay listening on: {addr}:{port}");
        // Connect via `localhost` so the hostname matches the certificate's SAN.
        let url: Url = format!("https://localhost:{port}").parse().unwrap();
        let (a_key, mut a_recv, client_a_task, client_a) = create_test_client(a_key, url.clone());
        info!("created client {a_key:?}");
        let (b_key, mut b_recv, client_b_task, client_b) = create_test_client(b_key, url);
        info!("created client {b_key:?}");
        client_a.ping().await?;
        client_b.ping().await?;
        info!("sending message from a to b");
        let msg = Bytes::from_static(b"hi there, client b!");
        client_a.send(b_key, msg.clone()).await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) = b_recv.recv().await.expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);
        info!("sending message from b to a");
        let msg = Bytes::from_static(b"right back at ya, client b!");
        client_b.send(a_key, msg.clone()).await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) = a_recv.recv().await.expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);
        // Stop the server and wait for its accept loop to finish before closing clients.
        server.shutdown();
        server.task_handle().await?;
        client_a.close().await?;
        client_a_task.abort();
        client_b.close().await?;
        client_b_task.abort();
        Ok(())
    }
}