1use anyhow::Context;
2use clap::Parser;
3use ring::digest::{digest, SHA256};
4use rustls::crypto::ring::sign::any_supported_type;
5use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName, UnixTime};
6use rustls::server::{ClientHello, ResolvesServerCert};
7use rustls::sign::CertifiedKey;
8use rustls::RootCertStore;
9use std::fs;
10use std::io::{self, Cursor, Read};
11use std::path;
12use std::sync::Arc;
13
14#[derive(Parser, Clone, Default)]
15#[group(id = "tls")]
16pub struct Args {
17 #[arg(long = "tls-cert", value_delimiter = ',')]
23 pub cert: Vec<path::PathBuf>,
24
25 #[arg(long = "tls-key", value_delimiter = ',')]
29 pub key: Vec<path::PathBuf>,
30
31 #[arg(long = "tls-root", value_delimiter = ',')]
36 pub root: Vec<path::PathBuf>,
37
38 #[arg(long = "tls-disable-verify")]
42 pub disable_verify: bool,
43
44 #[arg(long = "tls-self-sign", value_delimiter = ',')]
49 pub self_sign: Vec<String>,
50}
51
52#[derive(Clone)]
53pub struct Config {
54 pub client: rustls::ClientConfig,
55 pub server: Option<rustls::ServerConfig>,
56 pub fingerprints: Vec<String>,
57}
58
59impl Args {
60 pub fn load(&self) -> anyhow::Result<Config> {
61 let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
62 let mut serve = ServeCerts::default();
63
64 anyhow::ensure!(
66 self.cert.len() == self.key.len(),
67 "--tls-cert and --tls-key counts differ"
68 );
69 for (chain, key) in self.cert.iter().zip(self.key.iter()) {
70 serve.load(chain, key)?;
71 }
72
73 if !self.self_sign.is_empty() {
74 serve.generate(&self.self_sign)?;
75 }
76
77 let mut roots = RootCertStore::empty();
79
80 if self.root.is_empty() {
81 let native = rustls_native_certs::load_native_certs();
82
83 for err in native.errors {
85 tracing::warn!(?err, "failed to load root cert");
86 }
87
88 for cert in native.certs {
90 roots.add(cert).context("failed to add root cert")?;
91 }
92 } else {
93 for root in &self.root {
95 let root = fs::File::open(root).context("failed to open root cert file")?;
96 let mut root = io::BufReader::new(root);
97
98 let root = rustls_pemfile::certs(&mut root)
99 .next()
100 .context("no roots found")?
101 .context("failed to read root cert")?;
102
103 roots.add(root).context("failed to add root cert")?;
104 }
105 }
106
107 let mut client = rustls::ClientConfig::builder_with_provider(provider.clone())
109 .with_protocol_versions(&[&rustls::version::TLS13])?
110 .with_root_certificates(roots)
111 .with_no_client_auth();
112
113 if self.disable_verify {
115 tracing::warn!("TLS server certificate verification is disabled");
116
117 let noop = NoCertificateVerification(provider.clone());
118 client.dangerous().set_certificate_verifier(Arc::new(noop));
119 }
120
121 let fingerprints = serve.fingerprints();
122
123 let server = if !serve.list.is_empty() {
125 Some(
126 rustls::ServerConfig::builder_with_provider(provider)
127 .with_protocol_versions(&[&rustls::version::TLS13])?
128 .with_no_client_auth()
129 .with_cert_resolver(Arc::new(serve)),
130 )
131 } else {
132 None
133 };
134
135 Ok(Config {
136 server,
137 client,
138 fingerprints,
139 })
140 }
141}
142
143#[derive(Default, Debug)]
144struct ServeCerts {
145 list: Vec<Arc<CertifiedKey>>,
146}
147
148impl ServeCerts {
149 pub fn load(&mut self, chain: &path::PathBuf, key: &path::PathBuf) -> anyhow::Result<()> {
151 let chain = fs::File::open(chain).context("failed to open cert file")?;
153 let mut chain = io::BufReader::new(chain);
154
155 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
156 .collect::<Result<_, _>>()
157 .context("failed to read certs")?;
158
159 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
160
161 let mut keys = fs::File::open(key).context("failed to open key file")?;
163
164 let mut buf = Vec::new();
166 keys.read_to_end(&mut buf)?;
167
168 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
169 let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
170
171 let certified = Arc::new(CertifiedKey::new(chain, key));
172 self.list.push(certified);
173
174 Ok(())
175 }
176
177 pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
178 let key_pair = rcgen::KeyPair::generate()?;
179
180 let mut params = rcgen::CertificateParams::new(hostnames)?;
181
182 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
185 params.not_after = params.not_before + time::Duration::days(14);
186
187 let cert = params.self_signed(&key_pair)?;
189
190 let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
192 let key = any_supported_type(&key.into())?;
193
194 let certified = CertifiedKey::new(vec![cert.into()], key);
196 self.list.push(Arc::new(certified));
197
198 Ok(())
199 }
200
201 pub fn fingerprints(&self) -> Vec<String> {
203 self.list
204 .iter()
205 .map(|ck| {
206 let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
207 let fingerprint = hex::encode(fingerprint.as_ref());
208 fingerprint
209 })
210 .collect()
211 }
212}
213
214impl ResolvesServerCert for ServeCerts {
215 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
216 if let Some(name) = client_hello.server_name() {
217 if let Ok(dns_name) = webpki::DnsNameRef::try_from_ascii_str(name) {
218 for ck in &self.list {
219 let leaf = ck.end_entity_cert().expect("missing certificate");
222 let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
223
224 if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
225 return Some(ck.clone());
226 }
227 }
228 }
229 }
230
231 self.list.last().cloned()
233 }
234}
235
236#[derive(Debug)]
237pub struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
238
239impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
240 fn verify_server_cert(
241 &self,
242 _end_entity: &CertificateDer<'_>,
243 _intermediates: &[CertificateDer<'_>],
244 _server_name: &ServerName<'_>,
245 _ocsp: &[u8],
246 _now: UnixTime,
247 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
248 Ok(rustls::client::danger::ServerCertVerified::assertion())
249 }
250
251 fn verify_tls12_signature(
252 &self,
253 message: &[u8],
254 cert: &CertificateDer<'_>,
255 dss: &rustls::DigitallySignedStruct,
256 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
257 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
258 }
259
260 fn verify_tls13_signature(
261 &self,
262 message: &[u8],
263 cert: &CertificateDer<'_>,
264 dss: &rustls::DigitallySignedStruct,
265 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
266 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
267 }
268
269 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
270 self.0.signature_verification_algorithms.supported_schemes()
271 }
272}
273
274#[derive(Debug)]
276pub struct FingerprintVerifier {
277 provider: Arc<rustls::crypto::CryptoProvider>,
278 fingerprint: Vec<u8>,
279}
280
281impl FingerprintVerifier {
282 pub fn new(provider: Arc<rustls::crypto::CryptoProvider>, fingerprint: Vec<u8>) -> Self {
283 Self { provider, fingerprint }
284 }
285}
286
287impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
288 fn verify_server_cert(
289 &self,
290 end_entity: &CertificateDer<'_>,
291 _intermediates: &[CertificateDer<'_>],
292 _server_name: &ServerName<'_>,
293 _ocsp: &[u8],
294 _now: UnixTime,
295 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
296 let fingerprint = digest(&SHA256, end_entity);
297 if fingerprint.as_ref() == self.fingerprint.as_slice() {
298 Ok(rustls::client::danger::ServerCertVerified::assertion())
299 } else {
300 Err(rustls::Error::General("fingerprint mismatch".into()))
301 }
302 }
303
304 fn verify_tls12_signature(
305 &self,
306 message: &[u8],
307 cert: &CertificateDer<'_>,
308 dss: &rustls::DigitallySignedStruct,
309 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
311 }
312
313 fn verify_tls13_signature(
314 &self,
315 message: &[u8],
316 cert: &CertificateDer<'_>,
317 dss: &rustls::DigitallySignedStruct,
318 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
319 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
320 }
321
322 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
323 self.provider.signature_verification_algorithms.supported_schemes()
324 }
325}