moq_native/
tls.rs

1use anyhow::Context;
2use clap::Parser;
3use ring::digest::{digest, SHA256};
4use rustls::crypto::ring::sign::any_supported_type;
5use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName, UnixTime};
6use rustls::server::{ClientHello, ResolvesServerCert};
7use rustls::sign::CertifiedKey;
8use rustls::RootCertStore;
9use std::fs;
10use std::io::{self, Cursor, Read};
11use std::path;
12use std::sync::Arc;
13
14#[derive(Parser, Clone, Default)]
15#[group(id = "tls")]
16pub struct Args {
17	/// Use the certificates at this path, encoded as PEM.
18	///
19	/// You can use this option multiple times for multiple certificates.
20	/// The first match for the provided SNI will be used, otherwise the last cert will be used.
21	/// You also need to provide the private key multiple times via `key``.
22	#[arg(long = "tls-cert", value_delimiter = ',')]
23	pub cert: Vec<path::PathBuf>,
24
25	/// Use the private key at this path, encoded as PEM.
26	///
27	/// There must be a key for every certificate provided via `cert`.
28	#[arg(long = "tls-key", value_delimiter = ',')]
29	pub key: Vec<path::PathBuf>,
30
31	/// Use the TLS root at this path, encoded as PEM.
32	///
33	/// This value can be provided multiple times for multiple roots.
34	/// If this is empty, system roots will be used instead
35	#[arg(long = "tls-root", value_delimiter = ',')]
36	pub root: Vec<path::PathBuf>,
37
38	/// Danger: Disable TLS certificate verification.
39	///
40	/// Fine for local development and between relays, but should be used in caution in production.
41	#[arg(long = "tls-disable-verify")]
42	pub disable_verify: bool,
43
44	/// Generate a self-signed certificate for the provided hostnames (comma separated).
45	///
46	/// This is useful for local development and testing.
47	/// This can be combined with the `/fingerprint` endpoint for clients to fetch the fingerprint.
48	#[arg(long = "tls-self-sign", value_delimiter = ',')]
49	pub self_sign: Vec<String>,
50}
51
52#[derive(Clone)]
53pub struct Config {
54	pub client: rustls::ClientConfig,
55	pub server: Option<rustls::ServerConfig>,
56	pub fingerprints: Vec<String>,
57}
58
59impl Args {
60	pub fn load(&self) -> anyhow::Result<Config> {
61		let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
62		let mut serve = ServeCerts::default();
63
64		// Load the certificate and key files based on their index.
65		anyhow::ensure!(
66			self.cert.len() == self.key.len(),
67			"--tls-cert and --tls-key counts differ"
68		);
69		for (chain, key) in self.cert.iter().zip(self.key.iter()) {
70			serve.load(chain, key)?;
71		}
72
73		if !self.self_sign.is_empty() {
74			serve.generate(&self.self_sign)?;
75		}
76
77		// Create a list of acceptable root certificates.
78		let mut roots = RootCertStore::empty();
79
80		if self.root.is_empty() {
81			let native = rustls_native_certs::load_native_certs();
82
83			// Log any errors that occurred while loading the native root certificates.
84			for err in native.errors {
85				tracing::warn!(?err, "failed to load root cert");
86			}
87
88			// Add the platform's native root certificates.
89			for cert in native.certs {
90				roots.add(cert).context("failed to add root cert")?;
91			}
92		} else {
93			// Add the specified root certificates.
94			for root in &self.root {
95				let root = fs::File::open(root).context("failed to open root cert file")?;
96				let mut root = io::BufReader::new(root);
97
98				let root = rustls_pemfile::certs(&mut root)
99					.next()
100					.context("no roots found")?
101					.context("failed to read root cert")?;
102
103				roots.add(root).context("failed to add root cert")?;
104			}
105		}
106
107		// Create the TLS configuration we'll use as a client (relay -> relay)
108		let mut client = rustls::ClientConfig::builder_with_provider(provider.clone())
109			.with_protocol_versions(&[&rustls::version::TLS13])?
110			.with_root_certificates(roots)
111			.with_no_client_auth();
112
113		// Allow disabling TLS verification altogether.
114		if self.disable_verify {
115			tracing::warn!("TLS server certificate verification is disabled");
116
117			let noop = NoCertificateVerification(provider.clone());
118			client.dangerous().set_certificate_verifier(Arc::new(noop));
119		}
120
121		let fingerprints = serve.fingerprints();
122
123		// Create the TLS configuration we'll use as a server (relay <- browser)
124		let server = if !serve.list.is_empty() {
125			Some(
126				rustls::ServerConfig::builder_with_provider(provider)
127					.with_protocol_versions(&[&rustls::version::TLS13])?
128					.with_no_client_auth()
129					.with_cert_resolver(Arc::new(serve)),
130			)
131		} else {
132			None
133		};
134
135		Ok(Config {
136			server,
137			client,
138			fingerprints,
139		})
140	}
141}
142
143#[derive(Default, Debug)]
144struct ServeCerts {
145	list: Vec<Arc<CertifiedKey>>,
146}
147
148impl ServeCerts {
149	// Load a certificate and cooresponding key from a file
150	pub fn load(&mut self, chain: &path::PathBuf, key: &path::PathBuf) -> anyhow::Result<()> {
151		// Read the PEM certificate chain
152		let chain = fs::File::open(chain).context("failed to open cert file")?;
153		let mut chain = io::BufReader::new(chain);
154
155		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
156			.collect::<Result<_, _>>()
157			.context("failed to read certs")?;
158
159		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
160
161		// Read the PEM private key
162		let mut keys = fs::File::open(key).context("failed to open key file")?;
163
164		// Read the keys into a Vec so we can parse it twice.
165		let mut buf = Vec::new();
166		keys.read_to_end(&mut buf)?;
167
168		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
169		let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
170
171		let certified = Arc::new(CertifiedKey::new(chain, key));
172		self.list.push(certified);
173
174		Ok(())
175	}
176
177	pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
178		let key_pair = rcgen::KeyPair::generate()?;
179
180		let mut params = rcgen::CertificateParams::new(hostnames)?;
181
182		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
183		// WebTransport certificates MUST be valid for two weeks at most.
184		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
185		params.not_after = params.not_before + time::Duration::days(14);
186
187		// Generate the certificate
188		let cert = params.self_signed(&key_pair)?;
189
190		// Convert the rcgen type to the rustls type.
191		let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
192		let key = any_supported_type(&key.into())?;
193
194		// Create a rustls::sign::CertifiedKey
195		let certified = CertifiedKey::new(vec![cert.into()], key);
196		self.list.push(Arc::new(certified));
197
198		Ok(())
199	}
200
201	// Return the SHA256 fingerprint of our certificates.
202	pub fn fingerprints(&self) -> Vec<String> {
203		self.list
204			.iter()
205			.map(|ck| {
206				let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
207				let fingerprint = hex::encode(fingerprint.as_ref());
208				fingerprint
209			})
210			.collect()
211	}
212}
213
214impl ResolvesServerCert for ServeCerts {
215	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
216		if let Some(name) = client_hello.server_name() {
217			if let Ok(dns_name) = webpki::DnsNameRef::try_from_ascii_str(name) {
218				for ck in &self.list {
219					// TODO I gave up on caching the parsed result because of lifetime hell.
220					// If this shows up on benchmarks, somebody should fix it.
221					let leaf = ck.end_entity_cert().expect("missing certificate");
222					let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
223
224					if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
225						return Some(ck.clone());
226					}
227				}
228			}
229		}
230
231		// Default to the last certificate if we couldn't find one.
232		self.list.last().cloned()
233	}
234}
235
236#[derive(Debug)]
237pub struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
238
239impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
240	fn verify_server_cert(
241		&self,
242		_end_entity: &CertificateDer<'_>,
243		_intermediates: &[CertificateDer<'_>],
244		_server_name: &ServerName<'_>,
245		_ocsp: &[u8],
246		_now: UnixTime,
247	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
248		Ok(rustls::client::danger::ServerCertVerified::assertion())
249	}
250
251	fn verify_tls12_signature(
252		&self,
253		message: &[u8],
254		cert: &CertificateDer<'_>,
255		dss: &rustls::DigitallySignedStruct,
256	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
257		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
258	}
259
260	fn verify_tls13_signature(
261		&self,
262		message: &[u8],
263		cert: &CertificateDer<'_>,
264		dss: &rustls::DigitallySignedStruct,
265	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
266		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
267	}
268
269	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
270		self.0.signature_verification_algorithms.supported_schemes()
271	}
272}
273
274// Verify the certificate matches a provided fingerprint.
275#[derive(Debug)]
276pub struct FingerprintVerifier {
277	provider: Arc<rustls::crypto::CryptoProvider>,
278	fingerprint: Vec<u8>,
279}
280
281impl FingerprintVerifier {
282	pub fn new(provider: Arc<rustls::crypto::CryptoProvider>, fingerprint: Vec<u8>) -> Self {
283		Self { provider, fingerprint }
284	}
285}
286
287impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
288	fn verify_server_cert(
289		&self,
290		end_entity: &CertificateDer<'_>,
291		_intermediates: &[CertificateDer<'_>],
292		_server_name: &ServerName<'_>,
293		_ocsp: &[u8],
294		_now: UnixTime,
295	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
296		let fingerprint = digest(&SHA256, end_entity);
297		if fingerprint.as_ref() == self.fingerprint.as_slice() {
298			Ok(rustls::client::danger::ServerCertVerified::assertion())
299		} else {
300			Err(rustls::Error::General("fingerprint mismatch".into()))
301		}
302	}
303
304	fn verify_tls12_signature(
305		&self,
306		message: &[u8],
307		cert: &CertificateDer<'_>,
308		dss: &rustls::DigitallySignedStruct,
309	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
311	}
312
313	fn verify_tls13_signature(
314		&self,
315		message: &[u8],
316		cert: &CertificateDer<'_>,
317		dss: &rustls::DigitallySignedStruct,
318	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
319		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
320	}
321
322	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
323		self.provider.signature_verification_algorithms.supported_schemes()
324	}
325}