hickory_proto/native_tls/
tls_stream.rs1use std::io;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::{future::Future, marker::PhantomData};
14
15use futures_util::TryFutureExt;
16use native_tls::Protocol::Tlsv12;
17use native_tls::{Certificate, Identity, TlsConnector};
18use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream as TokioTlsStream};
19
20use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
21use crate::tcp::TcpStream;
22use crate::tcp::{Connect, DnsTcpStream};
23use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
24
25pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;
27
28fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsConnector> {
29 let mut builder = TlsConnector::builder();
30 builder.min_protocol_version(Some(Tlsv12));
31
32 for cert in certs {
33 builder.add_root_certificate(cert);
34 }
35
36 if let Some(pkcs12) = pkcs12 {
37 builder.identity(pkcs12);
38 }
39 builder
40 .build()
41 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))
42}
43
44pub fn tls_from_stream<S: DnsTcpStream>(
48 stream: TokioTlsStream<AsyncIoStdAsTokio<S>>,
49 peer_addr: SocketAddr,
50) -> (TlsStream<S>, BufDnsStreamHandle) {
51 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
52
53 let stream = TcpStream::from_stream_with_receiver(
54 AsyncIoTokioAsStd(stream),
55 peer_addr,
56 outbound_messages,
57 );
58
59 (stream, message_sender)
60}
61
62#[derive(Default)]
64pub struct TlsStreamBuilder<S> {
65 ca_chain: Vec<Certificate>,
66 identity: Option<Identity>,
67 bind_addr: Option<SocketAddr>,
68 marker: PhantomData<S>,
69}
70
71impl<S: DnsTcpStream> TlsStreamBuilder<S> {
72 pub fn new() -> Self {
74 Self {
75 ca_chain: vec![],
76 identity: None,
77 bind_addr: None,
78 marker: PhantomData,
79 }
80 }
81
82 pub fn add_ca(&mut self, ca: Certificate) {
86 self.ca_chain.push(ca);
87 }
88
89 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
91 self.bind_addr = Some(bind_addr);
92 }
93
94 #[allow(clippy::type_complexity)]
101 pub fn build_with_future<F>(
102 self,
103 future: F,
104 name_server: SocketAddr,
105 dns_name: String,
106 ) -> (
107 Pin<Box<dyn Future<Output = Result<TlsStream<S>, io::Error>> + Send>>,
109 BufDnsStreamHandle,
110 )
111 where
112 S: DnsTcpStream,
113 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
114 {
115 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
116
117 let stream = self.inner_build(future, name_server, dns_name, outbound_messages);
118 (Box::pin(stream), message_sender)
119 }
120
121 async fn inner_build<F>(
122 self,
123 future: F,
124 name_server: SocketAddr,
125 dns_name: String,
126 outbound_messages: StreamReceiver,
127 ) -> Result<TlsStream<S>, io::Error>
128 where
129 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
130 {
131 use crate::native_tls::tls_stream;
132 let tcp_stream = future.await;
133
134 let ca_chain = self.ca_chain.clone();
135 let identity = self.identity;
136
137 let tcp_stream = match tcp_stream {
139 Ok(tcp_stream) => AsyncIoStdAsTokio(tcp_stream),
140 Err(err) => return Err(err),
141 };
142
143 let tls_connector = tls_stream::tls_new(ca_chain, identity)
146 .map(TokioTlsConnector::from)
147 .map_err(|e| {
148 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
149 })?;
150
151 let tls_connected = tls_connector
152 .connect(&dns_name, tcp_stream)
153 .map_err(|e| {
154 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
155 })
156 .await?;
157
158 Ok(TcpStream::from_stream_with_receiver(
159 AsyncIoTokioAsStd(tls_connected),
160 name_server,
161 outbound_messages,
162 ))
163 }
164}
165
166impl<S: Connect> TlsStreamBuilder<S> {
167 #[allow(clippy::type_complexity)]
193 pub fn build(
194 self,
195 name_server: SocketAddr,
196 dns_name: String,
197 ) -> (
198 Pin<Box<dyn Future<Output = Result<TlsStream<S>, io::Error>> + Send>>,
200 BufDnsStreamHandle,
201 ) {
202 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
203 let conn = S::connect_with_bind(name_server, self.bind_addr);
204 let stream = self.inner_build(conn, name_server, dns_name, outbound_messages);
205 (Box::pin(stream), message_sender)
206 }
207}