async_native_tls/
lib.rs

1#![warn(
2    missing_debug_implementations,
3    missing_docs,
4    rust_2018_idioms,
5    unreachable_pub
6)]
7
8//! Async TLS streams
9//!
10//! # Examples
11//!
12//! To connect as a client to a remote server:
13//!
14//! ```rust
15//! # #[cfg(feature = "runtime-async-std")]
16//! # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
17//! #
18//! use async_std::prelude::*;
19//! use async_std::net::TcpStream;
20//!
21//! let stream = TcpStream::connect("google.com:443").await?;
22//! let mut stream = async_native_tls::connect("google.com", stream).await?;
23//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
24//!
25//! let mut res = Vec::new();
26//! stream.read_to_end(&mut res).await?;
27//! println!("{}", String::from_utf8_lossy(&res));
28//! #
29//! # Ok(()) }) }
30//! # #[cfg(feature = "runtime-tokio")]
31//! # fn main() {}
32//! ```
33
34#[cfg(not(any(feature = "runtime-tokio", feature = "runtime-async-std")))]
35compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
36
37#[cfg(all(feature = "runtime-tokio", feature = "runtime-async-std"))]
38compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
39
40mod acceptor;
41mod connector;
42mod handshake;
43mod runtime;
44mod std_adapter;
45mod tls_stream;
46
47pub use accept::accept;
48pub use acceptor::{Error as AcceptError, TlsAcceptor};
49pub use connect::{connect, TlsConnector};
50pub use host::Host;
51pub use tls_stream::TlsStream;
52
53#[doc(inline)]
54pub use native_tls::{Certificate, Error, Identity, Protocol, Result};
55
56mod accept {
57    use crate::runtime::{AsyncRead, AsyncWrite};
58
59    use crate::TlsStream;
60
61    /// One of accept of an incoming connection.
62    ///
63    /// # Example
64    ///
65    /// ```no_run
66    /// # #[cfg(feature = "runtime-async-std")]
67    /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
68    /// #
69    /// use async_std::prelude::*;
70    /// use async_std::net::TcpListener;
71    /// use async_std::fs::File;
72    ///
73    /// let listener = TcpListener::bind("0.0.0.0:8443").await?;
74    /// let (stream, _addr) = listener.accept().await?;
75    ///
76    /// let key = File::open("identity.pfx").await?;
77    /// let stream = async_native_tls::accept(key, "<password>", stream).await?;
78    /// // handle stream here
79    /// #
80    /// # Ok(()) }) }
81    /// # #[cfg(feature = "runtime-tokio")]
82    /// # fn main() {}
83    /// ```
84    pub async fn accept<R, S, T>(
85        file: R,
86        password: S,
87        stream: T,
88    ) -> Result<TlsStream<T>, crate::AcceptError>
89    where
90        R: AsyncRead + Unpin,
91        S: AsRef<str>,
92        T: AsyncRead + AsyncWrite + Unpin,
93    {
94        let acceptor = crate::TlsAcceptor::new(file, password).await?;
95        let stream = acceptor.accept(stream).await?;
96
97        Ok(stream)
98    }
99}
100
101mod host {
102    use url::Url;
103
104    /// The host part of a domain (without scheme, port and path).
105    ///
106    /// This is the argument to the [`connect`](crate::connect::connect) function. Strings and string slices are
107    /// converted into Hosts automatically, as is [Url](url::Url) with the `host-from-url` feature (enabled by default).
108    #[derive(Debug)]
109    pub struct Host(String);
110
111    impl Host {
112        /// The host as string. Consumes self.
113        #[allow(clippy::wrong_self_convention)]
114        pub fn as_string(self) -> String {
115            self.0
116        }
117    }
118
119    impl From<&str> for Host {
120        fn from(host: &str) -> Self {
121            Self(host.into())
122        }
123    }
124
125    impl From<String> for Host {
126        fn from(host: String) -> Self {
127            Self(host)
128        }
129    }
130
131    impl From<&String> for Host {
132        fn from(host: &String) -> Self {
133            Self(host.into())
134        }
135    }
136
137    impl From<Url> for Host {
138        fn from(url: Url) -> Self {
139            Self(
140                url.host_str()
141                    .expect("URL has to include a host part.")
142                    .into(),
143            )
144        }
145    }
146
147    impl From<&Url> for Host {
148        fn from(url: &Url) -> Self {
149            Self(
150                url.host_str()
151                    .expect("URL has to include a host part.")
152                    .into(),
153            )
154        }
155    }
156}
157
158mod connect {
159    use std::fmt::{self, Debug};
160
161    use crate::host::Host;
162    use crate::runtime::{AsyncRead, AsyncWrite};
163    use crate::TlsStream;
164    use crate::{Certificate, Identity, Protocol};
165
166    /// Connect a client to a remote server.
167    ///
168    /// # Examples
169    ///
170    /// ```
171    /// # #[cfg(feature = "runtime-async-std")]
172    /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
173    /// #
174    /// use async_std::prelude::*;
175    /// use async_std::net::TcpStream;
176    ///
177    /// let stream = TcpStream::connect("google.com:443").await?;
178    /// let mut stream = async_native_tls::connect("google.com", stream).await?;
179    /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
180    ///
181    /// let mut res = Vec::new();
182    /// stream.read_to_end(&mut res).await?;
183    /// println!("{}", String::from_utf8_lossy(&res));
184    /// #
185    /// # Ok(()) }) }
186    /// # #[cfg(feature = "runtime-tokio")]
187    /// # fn main() {}
188    /// ```
189    pub async fn connect<S>(host: impl Into<Host>, stream: S) -> native_tls::Result<TlsStream<S>>
190    where
191        S: AsyncRead + AsyncWrite + Unpin,
192    {
193        let stream = TlsConnector::new().connect(host, stream).await?;
194        Ok(stream)
195    }
196
197    /// Connect a client to a remote server.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// # #[cfg(feature = "runtime-async-std")]
203    /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
204    /// #
205    /// use async_std::prelude::*;
206    /// use async_std::net::TcpStream;
207    /// use async_native_tls::TlsConnector;
208    ///
209    /// let stream = TcpStream::connect("google.com:443").await?;
210    /// let mut stream = TlsConnector::new()
211    ///     .use_sni(true)
212    ///     .connect("google.com", stream)
213    ///     .await?;
214    /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
215    ///
216    /// let mut res = Vec::new();
217    /// stream.read_to_end(&mut res).await?;
218    /// println!("{}", String::from_utf8_lossy(&res));
219    /// #
220    /// # Ok(()) }) }
221    /// # #[cfg(feature = "runtime-tokio")]
222    /// # fn main() {}
223    /// ```
224    pub struct TlsConnector {
225        builder: native_tls::TlsConnectorBuilder,
226    }
227
228    impl Default for TlsConnector {
229        fn default() -> Self {
230            TlsConnector::new()
231        }
232    }
233
234    impl TlsConnector {
235        /// Create a new instance.
236        pub fn new() -> Self {
237            Self {
238                builder: native_tls::TlsConnector::builder(),
239            }
240        }
241
242        /// Sets the identity to be used for client certificate authentication.
243        pub fn identity(mut self, identity: Identity) -> Self {
244            self.builder.identity(identity);
245            self
246        }
247
248        /// Sets the minimum supported protocol version.
249        ///
250        /// A value of `None` enables support for the oldest protocols supported by the
251        /// implementation. Defaults to `Some(Protocol::Tlsv10)`.
252        pub fn min_protocol_version(mut self, protocol: Option<Protocol>) -> Self {
253            self.builder.min_protocol_version(protocol);
254            self
255        }
256
257        /// Sets the maximum supported protocol version.
258        ///
259        /// A value of `None` enables support for the newest protocols supported by the
260        /// implementation. Defaults to `None`.
261        pub fn max_protocol_version(mut self, protocol: Option<Protocol>) -> Self {
262            self.builder.max_protocol_version(protocol);
263            self
264        }
265
266        /// Adds a certificate to the set of roots that the connector will trust.
267        ///
268        /// The connector will use the system's trust root by default. This method can be used to
269        /// add to that set when communicating with servers not trusted by the system. Defaults to
270        /// an empty set.
271        pub fn add_root_certificate(mut self, cert: Certificate) -> Self {
272            self.builder.add_root_certificate(cert);
273            self
274        }
275
276        /// Request specific protocols through ALPN (Application-Layer Protocol Negotiation).
277        ///
278        /// Defaults to none
279        pub fn request_alpns(mut self, protocols: &[&str]) -> Self {
280            self.builder.request_alpns(protocols);
281            self
282        }
283
284        /// Controls the use of certificate validation.
285        ///
286        /// Defaults to false.
287        ///
288        /// # Warning
289        ///
290        /// You should think very carefully before using this method. If invalid certificates are
291        /// trusted, any certificate for any site will be trusted for use. This includes expired
292        /// certificates. This introduces significant vulnerabilities, and should only be used as a
293        /// last resort.
294        pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
295            self.builder
296                .danger_accept_invalid_certs(accept_invalid_certs);
297            self
298        }
299
300        /// Controls the use of Server Name Indication (SNI).
301        ///
302        /// Defaults to `true`.
303        pub fn use_sni(mut self, use_sni: bool) -> Self {
304            self.builder.use_sni(use_sni);
305            self
306        }
307
308        /// Controls the use of hostname verification.
309        ///
310        /// Defaults to `false`.
311        ///
312        /// # Warning
313        ///
314        /// You should think very carefully before using this method. If invalid hostnames are
315        /// trusted, any valid certificate for any site will be trusted for use. This introduces
316        /// significant vulnerabilities, and should only be used as a last resort.
317        pub fn danger_accept_invalid_hostnames(mut self, accept_invalid_hostnames: bool) -> Self {
318            self.builder
319                .danger_accept_invalid_hostnames(accept_invalid_hostnames);
320            self
321        }
322
323        /// Connect to a remote server.
324        ///
325        /// # Examples
326        ///
327        /// ```
328        /// # #[cfg(feature = "runtime-async-std")]
329        /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
330        /// #
331        /// use async_std::prelude::*;
332        /// use async_std::net::TcpStream;
333        /// use async_native_tls::TlsConnector;
334        ///
335        /// let stream = TcpStream::connect("google.com:443").await?;
336        /// let mut stream = TlsConnector::new()
337        ///     .use_sni(true)
338        ///     .connect("google.com", stream)
339        ///     .await?;
340        /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
341        ///
342        /// let mut res = Vec::new();
343        /// stream.read_to_end(&mut res).await?;
344        /// println!("{}", String::from_utf8_lossy(&res));
345        /// #
346        /// # Ok(()) }) }
347        /// # #[cfg(feature = "runtime-tokio")]
348        /// # fn main() {}
349        /// ```
350        pub async fn connect<S>(
351            &self,
352            host: impl Into<Host>,
353            stream: S,
354        ) -> native_tls::Result<TlsStream<S>>
355        where
356            S: AsyncRead + AsyncWrite + Unpin,
357        {
358            let host: Host = host.into();
359            let domain = host.as_string();
360            let connector = self.builder.build()?;
361            let connector = crate::connector::TlsConnector::from(connector);
362            let stream = connector.connect(&domain, stream).await?;
363            Ok(stream)
364        }
365    }
366
367    impl Debug for TlsConnector {
368        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369            f.debug_struct("TlsConnector").finish()
370        }
371    }
372
373    impl From<native_tls::TlsConnectorBuilder> for TlsConnector {
374        fn from(builder: native_tls::TlsConnectorBuilder) -> Self {
375            Self { builder }
376        }
377    }
378}