use std::sync::Arc;
use tokio::net::lookup_host;
use url::Url;
use quinn::{crypto::rustls::QuicClientConfig, rustls};
use rustls::{client::danger::ServerCertVerifier, pki_types::CertificateDer};
use crate::{ClientError, Session, ALPN};
pub enum CongestionControl {
Default,
Throughput,
LowLatency,
}
pub struct ClientBuilder {
provider: Arc<rustls::crypto::CryptoProvider>,
congestion_controller:
Option<Arc<dyn quinn::congestion::ControllerFactory + Send + Sync + 'static>>,
}
impl ClientBuilder {
pub fn new() -> Self {
Self {
provider: Arc::new(rustls::crypto::aws_lc_rs::default_provider()),
congestion_controller: None,
}
}
pub fn with_congestion_control(mut self, algorithm: CongestionControl) -> Self {
self.congestion_controller = match algorithm {
CongestionControl::LowLatency => {
Some(Arc::new(quinn::congestion::BbrConfig::default()))
}
CongestionControl::Throughput => {
Some(Arc::new(quinn::congestion::CubicConfig::default()))
}
CongestionControl::Default => None,
};
self
}
pub fn with_system_roots(self) -> Result<Client, ClientError> {
let mut roots = rustls::RootCertStore::empty();
let native = rustls_native_certs::load_native_certs();
for err in native.errors {
log::warn!("failed to load root cert: {:?}", err);
}
for cert in native.certs {
if let Err(err) = roots.add(cert) {
log::warn!("failed to add root cert: {:?}", err);
}
}
let crypto = self
.builder()
.with_root_certificates(roots)
.with_no_client_auth();
self.build(crypto)
}
pub fn with_server_certificates(
self,
certs: Vec<CertificateDer>,
) -> Result<Client, ClientError> {
let hashes = certs.iter().map(|cert| {
aws_lc_rs::digest::digest(&aws_lc_rs::digest::SHA256, cert)
.as_ref()
.to_vec()
});
self.with_server_certificate_hashes(hashes.collect())
}
pub fn with_server_certificate_hashes(
self,
hashes: Vec<Vec<u8>>,
) -> Result<Client, ClientError> {
let fingerprints = Arc::new(ServerFingerprints {
provider: self.provider.clone(),
fingerprints: hashes,
});
let crypto = self
.builder()
.dangerous()
.with_custom_certificate_verifier(fingerprints.clone())
.with_no_client_auth();
self.build(crypto)
}
pub unsafe fn with_no_certificate_verification(self) -> Result<Client, ClientError> {
let noop = NoCertificateVerification(self.provider.clone());
let crypto = self
.builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(noop))
.with_no_client_auth();
self.build(crypto)
}
fn builder(&self) -> rustls::ConfigBuilder<rustls::ClientConfig, rustls::WantsVerifier> {
rustls::ClientConfig::builder_with_provider(self.provider.clone())
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap()
}
fn build(self, mut crypto: rustls::ClientConfig) -> Result<Client, ClientError> {
crypto.alpn_protocols = vec![ALPN.to_vec()];
let client_config = QuicClientConfig::try_from(crypto).unwrap();
let mut client_config = quinn::ClientConfig::new(Arc::new(client_config));
let mut transport = quinn::TransportConfig::default();
if let Some(cc) = &self.congestion_controller {
transport.congestion_controller_factory(cc.clone());
}
client_config.transport_config(transport.into());
let client = quinn::Endpoint::client("[::]:0".parse().unwrap()).unwrap();
Ok(Client {
endpoint: client,
config: client_config,
})
}
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct Client {
endpoint: quinn::Endpoint,
config: quinn::ClientConfig,
}
impl Client {
pub fn new(endpoint: quinn::Endpoint, config: quinn::ClientConfig) -> Self {
Self { endpoint, config }
}
pub async fn connect(&self, url: &Url) -> Result<Session, ClientError> {
let host = url
.host()
.ok_or_else(|| ClientError::InvalidDnsName("".to_string()))?
.to_string();
let port = url.port().unwrap_or(443);
let mut remotes = match lookup_host((host.clone(), port)).await {
Ok(remotes) => remotes,
Err(_) => return Err(ClientError::InvalidDnsName(host)),
};
let remote = match remotes.next() {
Some(remote) => remote,
None => return Err(ClientError::InvalidDnsName(host)),
};
let conn = self
.endpoint
.connect_with(self.config.clone(), remote, &host)?;
let conn = conn.await?;
Session::connect(conn, url).await
}
}
impl Default for Client {
fn default() -> Self {
ClientBuilder::new().with_system_roots().unwrap()
}
}
#[derive(Debug)]
struct ServerFingerprints {
provider: Arc<rustls::crypto::CryptoProvider>,
fingerprints: Vec<Vec<u8>>,
}
impl ServerCertVerifier for ServerFingerprints {
fn verify_server_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let cert_hash = aws_lc_rs::digest::digest(&aws_lc_rs::digest::SHA256, end_entity);
if self
.fingerprints
.iter()
.any(|fingerprint| fingerprint == cert_hash.as_ref())
{
return Ok(rustls::client::danger::ServerCertVerified::assertion());
}
Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::UnknownIssuer,
))
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&self.provider.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&self.provider.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.provider
.signature_verification_algorithms
.supported_schemes()
}
}
#[derive(Debug)]
pub struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes()
}
}