moq_native/
quic.rs

1use std::{net, sync::Arc, time};
2
3use anyhow::Context;
4use clap::Parser;
5use url::Url;
6
7use crate::tls;
8
9use futures::future::BoxFuture;
10use futures::stream::{FuturesUnordered, StreamExt};
11use futures::FutureExt;
12
13use web_transport::quinn as web_transport_quinn;
14
15#[derive(Parser, Clone)]
16pub struct Args {
17	/// Listen for UDP packets on the given address.
18	#[arg(long, default_value = "[::]:0")]
19	pub bind: net::SocketAddr,
20
21	#[command(flatten)]
22	pub tls: tls::Args,
23}
24
25impl Default for Args {
26	fn default() -> Self {
27		Self {
28			bind: "[::]:0".parse().unwrap(),
29			tls: Default::default(),
30		}
31	}
32}
33
34impl Args {
35	pub fn load(&self) -> anyhow::Result<Config> {
36		let tls = self.tls.load()?;
37		Ok(Config { bind: self.bind, tls })
38	}
39}
40
41pub struct Config {
42	pub bind: net::SocketAddr,
43	pub tls: tls::Config,
44}
45
46pub struct Endpoint {
47	pub client: Client,
48	pub server: Option<Server>,
49}
50
51impl Endpoint {
52	pub fn new(config: Config) -> anyhow::Result<Self> {
53		// Enable BBR congestion control
54		// TODO validate the implementation
55		let mut transport = quinn::TransportConfig::default();
56		transport.max_idle_timeout(Some(time::Duration::from_secs(9).try_into().unwrap()));
57		transport.keep_alive_interval(Some(time::Duration::from_secs(4))); // TODO make this smarter
58		transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
59		transport.mtu_discovery_config(None); // Disable MTU discovery
60		let transport = Arc::new(transport);
61
62		let mut server_config = None;
63
64		if let Some(mut config) = config.tls.server {
65			config.alpn_protocols = vec![web_transport::quinn::ALPN.to_vec(), moq_transfork::ALPN.to_vec()];
66			config.key_log = Arc::new(rustls::KeyLogFile::new());
67
68			let config: quinn::crypto::rustls::QuicServerConfig = config.try_into()?;
69			let mut config = quinn::ServerConfig::with_crypto(Arc::new(config));
70			config.transport_config(transport.clone());
71
72			server_config = Some(config);
73		}
74
75		// There's a bit more boilerplate to make a generic endpoint.
76		let runtime = quinn::default_runtime().context("no async runtime")?;
77		let endpoint_config = quinn::EndpointConfig::default();
78		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
79
80		// Create the generic QUIC endpoint.
81		let quic = quinn::Endpoint::new(endpoint_config, server_config.clone(), socket, runtime)
82			.context("failed to create QUIC endpoint")?;
83
84		let server = server_config.is_some().then(|| Server {
85			quic: quic.clone(),
86			accept: Default::default(),
87		});
88
89		let client = Client {
90			quic,
91			config: config.tls.client,
92			transport,
93		};
94
95		Ok(Self { client, server })
96	}
97}
98
99pub struct Server {
100	quic: quinn::Endpoint,
101	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Session>>>,
102}
103
104impl Server {
105	pub async fn accept(&mut self) -> Option<web_transport_quinn::Session> {
106		loop {
107			tokio::select! {
108				res = self.quic.accept() => {
109					let conn = res?;
110					self.accept.push(Self::accept_session(conn).boxed());
111				}
112				Some(res) = self.accept.next() => {
113					if let Ok(session) = res {
114						return Some(session)
115					}
116				}
117			}
118		}
119	}
120
121	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Session> {
122		let mut conn = conn.accept()?;
123
124		let handshake = conn
125			.handshake_data()
126			.await?
127			.downcast::<quinn::crypto::rustls::HandshakeData>()
128			.unwrap();
129
130		let alpn = handshake.protocol.context("missing ALPN")?;
131		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
132		let host = handshake.server_name.unwrap_or_default();
133
134		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
135
136		// Wait for the QUIC connection to be established.
137		let conn = conn.await.context("failed to establish QUIC connection")?;
138
139		let span = tracing::Span::current();
140		span.record("id", conn.stable_id()); // TODO can we get this earlier?
141
142		let session = match alpn.as_bytes() {
143			web_transport::quinn::ALPN => {
144				// Wait for the CONNECT request.
145				let request = web_transport::quinn::Request::accept(conn)
146					.await
147					.context("failed to receive WebTransport request")?;
148
149				// Accept the CONNECT request.
150				request
151					.ok()
152					.await
153					.context("failed to respond to WebTransport request")?
154			}
155			// A bit of a hack to pretend like we're a WebTransport session
156			moq_transfork::ALPN => conn.into(),
157			_ => anyhow::bail!("unsupported ALPN: {}", alpn),
158		};
159
160		Ok(session)
161	}
162
163	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
164		self.quic.local_addr().context("failed to get local address")
165	}
166}
167
168#[derive(Clone)]
169pub struct Client {
170	quic: quinn::Endpoint,
171	config: rustls::ClientConfig,
172	transport: Arc<quinn::TransportConfig>,
173}
174
175impl Client {
176	pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
177		let mut config = self.config.clone();
178
179		let host = url.host().context("invalid DNS name")?.to_string();
180		let port = url.port().unwrap_or(443);
181
182		// Look up the DNS entry.
183		let ip = tokio::net::lookup_host((host.clone(), port))
184			.await
185			.context("failed DNS lookup")?
186			.next()
187			.context("no DNS entries")?;
188
189		if url.scheme() == "http" {
190			// Perform a HTTP request to fetch the certificate fingerprint.
191			let mut fingerprint = url.clone();
192			fingerprint.set_path("/fingerprint");
193
194			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
195
196			let resp = reqwest::get(fingerprint.as_str())
197				.await
198				.context("failed to fetch fingerprint")?
199				.error_for_status()
200				.context("fingerprint request failed")?;
201
202			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
203			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
204
205			let verifier = tls::FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
206			config.dangerous().set_certificate_verifier(Arc::new(verifier));
207
208			url.set_scheme("https").expect("failed to set scheme");
209		}
210
211		let alpn = match url.scheme() {
212			"https" => web_transport::quinn::ALPN,
213			"moqf" => moq_transfork::ALPN,
214			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moqf'"),
215		};
216
217		// TODO support connecting to both ALPNs at the same time
218		config.alpn_protocols = vec![alpn.to_vec()];
219		config.key_log = Arc::new(rustls::KeyLogFile::new());
220
221		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
222		let mut config = quinn::ClientConfig::new(Arc::new(config));
223		config.transport_config(self.transport.clone());
224
225		tracing::debug!(%url, %ip, alpn = %String::from_utf8_lossy(alpn), "connecting");
226
227		let connection = self.quic.connect_with(config, ip, &host)?.await?;
228		tracing::Span::current().record("id", connection.stable_id());
229
230		let session = match url.scheme() {
231			"https" => web_transport::quinn::Session::connect(connection, &url).await?,
232			"moqf" => connection.into(),
233			_ => unreachable!(),
234		};
235
236		Ok(session)
237	}
238}