use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};
use futures_util::{future, TryFutureExt};
use openssl::pkcs12::ParsedPkcs12_2;
use openssl::pkey::{PKey, Private};
use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMethod, SslOptions};
use openssl::stack::Stack;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::X509;
use tokio_openssl::{self, SslStream as TokioTlsStream};
use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
use crate::tcp::TcpStream;
use crate::tcp::{Connect, DnsTcpStream};
use crate::xfer::BufDnsStreamHandle;
/// Installs a TLS client identity (leaf certificate, private key and extra chain
/// certificates) onto an OpenSSL context builder.
pub(crate) trait TlsIdentityExt {
    /// Installs the identity contained in a parsed PKCS#12 bundle.
    ///
    /// Each component of the bundle is optional; this simply forwards the pieces to
    /// [`TlsIdentityExt::identity_parts`].
    fn identity(&mut self, pkcs12: &ParsedPkcs12_2) -> io::Result<()> {
        let leaf = pkcs12.cert.as_ref();
        let key = pkcs12.pkey.as_ref();
        let extra_chain = pkcs12.ca.as_ref();
        self.identity_parts(leaf, key, extra_chain)
    }

    /// Installs the individual identity components.
    ///
    /// * `cert` - the leaf certificate, if any.
    /// * `pkey` - the private key matching `cert`, if any.
    /// * `chain` - additional certificates to send along with the leaf, if any.
    fn identity_parts(
        &mut self,
        cert: Option<&X509>,
        pkey: Option<&PKey<Private>>,
        chain: Option<&Stack<X509>>,
    ) -> io::Result<()>;
}
impl TlsIdentityExt for SslContextBuilder {
    /// Sets the certificate and private key when they are supplied, checks that the
    /// context's key matches its certificate, then appends every certificate in `chain`
    /// as an extra chain certificate.
    ///
    /// OpenSSL errors are converted into `io::Error` via `?`.
    fn identity_parts(
        &mut self,
        cert: Option<&X509>,
        pkey: Option<&PKey<Private>>,
        chain: Option<&Stack<X509>>,
    ) -> io::Result<()> {
        if let Some(leaf) = cert {
            self.set_certificate(leaf)?;
        }
        if let Some(key) = pkey {
            self.set_private_key(key)?;
        }

        // Always runs, even when neither `cert` nor `pkey` was supplied.
        self.check_private_key()?;

        if let Some(extra) = chain {
            extra
                .iter()
                .try_for_each(|link| self.add_extra_chain_cert(link.to_owned()))?;
        }
        Ok(())
    }
}
/// A [`TcpStream`] carried over a `tokio_openssl` TLS stream (`TokioTlsStream`), with the TLS
/// stream wrapped in [`AsyncIoTokioAsStd`].
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
/// A [`TlsStream`] whose transport `S` is first wrapped in [`AsyncIoStdAsTokio`]; this is the
/// stream type produced by [`TlsStreamBuilder`] and `tls_stream_from_existing_tls_stream`.
pub(crate) type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;
/// Builds the OpenSSL connector used for outgoing TLS connections.
///
/// * `certs` - certificates placed in the verification store that is installed with
///   `set_verify_cert_store`.
/// * `pkcs12` - optional client identity installed via [`TlsIdentityExt::identity`].
///
/// Protocol versions below TLS 1.2 (SSLv2, SSLv3, TLSv1, TLSv1.1) are disabled.
///
/// # Errors
///
/// OpenSSL failures while creating the builder or the verification store are returned as
/// `io::ErrorKind::ConnectionRefused` with a "tls error: ..." message. Identity errors are
/// propagated as produced by `identity`.
fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12_2>) -> io::Result<SslConnector> {
    /// Wraps an OpenSSL error in the `io::Error` shape used throughout this constructor.
    fn tls_err(e: impl std::fmt::Display) -> io::Error {
        io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
    }

    let mut builder = SslConnector::builder(SslMethod::tls()).map_err(tls_err)?;

    // Refuse every protocol version older than TLS 1.2.
    builder.set_options(
        SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1,
    );

    let mut trust = X509StoreBuilder::new().map_err(tls_err)?;
    for anchor in certs {
        trust.add_cert(anchor).map_err(tls_err)?;
    }
    builder
        .set_verify_cert_store(trust.build())
        .map_err(tls_err)?;

    if let Some(client_identity) = pkcs12 {
        builder.identity(&client_identity)?;
    }

    Ok(builder.build())
}
/// Wraps an already-established TLS stream so it can be used as a DNS stream.
///
/// Returns the framed stream for `peer_addr` together with the [`BufDnsStreamHandle`]
/// whose paired receiver feeds outbound messages into that stream.
pub fn tls_stream_from_existing_tls_stream<S: DnsTcpStream>(
    stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
    peer_addr: SocketAddr,
) -> (CompatTlsStream<S>, BufDnsStreamHandle) {
    let (handle, receiver) = BufDnsStreamHandle::new(peer_addr);
    (
        TcpStream::from_stream_with_receiver(stream, peer_addr, receiver),
        handle,
    )
}
/// Waits for the transport produced by `future`, then performs the TLS client handshake
/// over it using `tls_config`, with `dns_name` as the name given to `into_ssl`.
///
/// # Errors
///
/// * Transport failure or handshake failure: `io::ErrorKind::ConnectionRefused`.
/// * Failure to create the SSL session or wrap the transport: `io::ErrorKind::Other`.
///
/// All messages have the form "tls error: ...".
async fn connect_tls<S, F>(
    future: F,
    tls_config: ConnectConfiguration,
    dns_name: String,
) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error>
where
    S: DnsTcpStream,
    F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
{
    /// Builds an `io::Error` of the given kind with the shared "tls error" prefix.
    fn tls_error(kind: io::ErrorKind, e: impl std::fmt::Display) -> io::Error {
        io::Error::new(kind, format!("tls error: {e}"))
    }

    // The underlying transport must be up before any TLS state is created.
    let transport = future
        .await
        .map_err(|e| tls_error(io::ErrorKind::ConnectionRefused, e))?;

    let session = tls_config
        .into_ssl(&dns_name)
        .map_err(|e| tls_error(io::ErrorKind::Other, e))?;
    let mut tls = TokioTlsStream::new(session, AsyncIoStdAsTokio(transport))
        .map_err(|e| tls_error(io::ErrorKind::Other, e))?;

    // Drive the client handshake to completion before handing the stream out.
    Pin::new(&mut tls)
        .connect()
        .await
        .map_err(|e| tls_error(io::ErrorKind::ConnectionRefused, e))?;

    Ok(tls)
}
/// Builder for outgoing TLS-wrapped DNS streams over a transport of type `S`.
///
/// NOTE: the derived `Default` impl requires `S: Default`; `TlsStreamBuilder::new` has no such
/// bound.
#[derive(Default)]
pub struct TlsStreamBuilder<S> {
/// Certificates added with `add_ca`; passed to `new` for the verification store.
ca_chain: Vec<X509>,
/// Optional client identity (set only when the `mtls` feature is enabled).
identity: Option<ParsedPkcs12_2>,
/// Optional local address handed to `S::connect_with_bind` by `build`.
bind_addr: Option<SocketAddr>,
/// Ties the builder to the transport type `S` without storing one.
marker: PhantomData<S>,
}
impl<S: DnsTcpStream> TlsStreamBuilder<S> {
    /// Creates a builder with no extra CA certificates, no client identity and no bind
    /// address.
    pub fn new() -> Self {
        Self {
            ca_chain: vec![],
            identity: None,
            bind_addr: None,
            marker: PhantomData,
        }
    }

    /// Adds `ca` to the certificates used to verify the server.
    pub fn add_ca(&mut self, ca: X509) {
        self.ca_chain.push(ca);
    }

    /// Sets the client identity (certificate, private key and optional chain) presented for
    /// mutual TLS.
    ///
    /// The parameter type is `ParsedPkcs12_2`, the same type stored in the `identity` field.
    #[cfg(feature = "mtls")]
    pub fn identity(&mut self, pkcs12: ParsedPkcs12_2) {
        self.identity = Some(pkcs12);
    }

    /// Sets the local address the transport should bind to before connecting.
    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
        self.bind_addr = Some(bind_addr);
    }

    /// Builds the TLS configuration and returns:
    ///
    /// * a future that waits for the transport from `future`, performs the TLS handshake
    ///   with `dns_name` as the session name, and yields the framed stream for
    ///   `name_server`;
    /// * the [`BufDnsStreamHandle`] used to send messages on that stream.
    ///
    /// # Errors
    ///
    /// Configuration failures are not returned directly. They are reported through the
    /// returned future as `io::ErrorKind::ConnectionRefused`.
    #[allow(clippy::type_complexity)]
    pub fn build_with_future<F>(
        self,
        future: F,
        name_server: SocketAddr,
        dns_name: String,
    ) -> (
        Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
        BufDnsStreamHandle,
    )
    where
        F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
    {
        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);

        // NOTE(review): most errors from `new` already carry a "tls error:" prefix, so the
        // re-wrap below yields "tls error: tls error: ..." for them. This is kept as-is to
        // preserve existing messages.
        let tls_config = match new(self.ca_chain, self.identity) {
            Ok(c) => c,
            Err(e) => {
                return (
                    Box::pin(future::err(e).map_err(|e| {
                        io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
                    })),
                    message_sender,
                )
            }
        };

        let tls_config = match tls_config.configure() {
            Ok(c) => c,
            Err(e) => {
                return (
                    Box::pin(future::err(e).map_err(|e| {
                        io::Error::new(
                            io::ErrorKind::ConnectionRefused,
                            format!("tls config error: {e}"),
                        )
                    })),
                    message_sender,
                )
            }
        };

        let stream = Box::pin(connect_tls(future, tls_config, dns_name).map_ok(move |s| {
            TcpStream::from_stream_with_receiver(
                AsyncIoTokioAsStd(s),
                name_server,
                outbound_messages,
            )
        }));

        (stream, message_sender)
    }
}
impl<S: Connect> TlsStreamBuilder<S> {
    /// Connects `S` to `name_server` (from the configured bind address, if any) and layers
    /// TLS over it with `dns_name` as the session name.
    ///
    /// Returns the same pair as [`TlsStreamBuilder::build_with_future`]: the connection
    /// future and its [`BufDnsStreamHandle`].
    #[allow(clippy::type_complexity)]
    pub fn build(
        self,
        name_server: SocketAddr,
        dns_name: String,
    ) -> (
        Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
        BufDnsStreamHandle,
    ) {
        // Copy the bind address out first: `self` is consumed by `build_with_future`.
        let local_bind = self.bind_addr;
        let transport = S::connect_with_bind(name_server, local_bind);
        self.build_with_future(transport, name_server, dns_name)
    }
}