1#![allow(dead_code)]
2
3use std::path::PathBuf;
4
5use crate::error::Error;
6use crate::net::socket::WithSocket;
7use crate::net::Socket;
8
9#[cfg(feature = "_tls-rustls")]
10mod tls_rustls;
11
12#[cfg(feature = "_tls-native-tls")]
13mod tls_native_tls;
14
15mod util;
16
/// Source of certificate or key material used when configuring TLS.
///
/// The material is either carried in memory or referenced by a filesystem
/// path that is read on demand.
#[derive(Debug, Clone)]
pub enum CertificateInput {
    /// The raw bytes of the material, held in memory.
    Inline(Vec<u8>),
    /// Path of a file that contains the material.
    File(PathBuf),
}
25
26impl From<String> for CertificateInput {
27 fn from(value: String) -> Self {
28 let trimmed = value.trim();
29 if trimmed.starts_with("-----BEGIN CERTIFICATE-----")
31 && trimmed.contains("-----END CERTIFICATE-----")
32 {
33 CertificateInput::Inline(value.as_bytes().to_vec())
34 } else {
35 CertificateInput::File(PathBuf::from(value))
36 }
37 }
38}
39
40impl CertificateInput {
41 async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
42 use crate::fs;
43 match self {
44 CertificateInput::Inline(v) => Ok(v.clone()),
45 CertificateInput::File(path) => fs::read(path).await,
46 }
47 }
48}
49
50impl std::fmt::Display for CertificateInput {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
54 CertificateInput::File(path) => write!(f, "file: {}", path.display()),
55 }
56 }
57}
58
59pub struct TlsConfig<'a> {
60 pub accept_invalid_certs: bool,
61 pub accept_invalid_hostnames: bool,
62 pub hostname: &'a str,
63 pub root_cert_path: Option<&'a CertificateInput>,
64 pub client_cert_path: Option<&'a CertificateInput>,
65 pub client_key_path: Option<&'a CertificateInput>,
66}
67
68pub async fn handshake<S, Ws>(
69 socket: S,
70 config: TlsConfig<'_>,
71 with_socket: Ws,
72) -> crate::Result<Ws::Output>
73where
74 S: Socket,
75 Ws: WithSocket,
76{
77 #[cfg(feature = "_tls-native-tls")]
78 return Ok(with_socket
79 .with_socket(tls_native_tls::handshake(socket, config).await?)
80 .await);
81
82 #[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
83 return Ok(with_socket
84 .with_socket(tls_rustls::handshake(socket, config).await?)
85 .await);
86
87 #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
88 {
89 drop((socket, config, with_socket));
90 panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
91 }
92}
93
/// Reports whether this build includes at least one TLS backend.
///
/// The answer is fixed at compile time by the `_tls-native-tls` and
/// `_tls-rustls` features.
pub fn available() -> bool {
    let native_tls = cfg!(feature = "_tls-native-tls");
    let rustls = cfg!(feature = "_tls-rustls");
    native_tls || rustls
}
97
98pub fn error_if_unavailable() -> crate::Result<()> {
99 if !available() {
100 return Err(Error::tls(
101 "TLS upgrade required by connect options \
102 but SQLx was built without TLS support enabled",
103 ));
104 }
105
106 Ok(())
107}