web3_async_native_tls/
lib.rs

// Crate-wide lints: public types must implement `Debug` and be documented, 2018-edition
// idioms are enforced, and `pub` items that are not actually reachable are reported.
#![warn(missing_debug_implementations)]
#![warn(missing_docs)]
#![warn(rust_2018_idioms)]
#![warn(unreachable_pub)]
7
8//! Async TLS streams
9//!
10//! # Examples
11//!
12//! To connect as a client to a remote server:
13//!
14//! ```rust
15//! # #[cfg(feature = "runtime-async-std")]
16//! # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
17//! #
18//! use async_std::prelude::*;
19//! use async_std::net::TcpStream;
20//!
21//! let stream = TcpStream::connect("google.com:443").await?;
22//! let mut stream = async_native_tls::connect("google.com", stream).await?;
23//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
24//!
25//! let mut res = Vec::new();
26//! stream.read_to_end(&mut res).await?;
27//! println!("{}", String::from_utf8_lossy(&res));
28//! #
29//! # Ok(()) }) }
30//! # #[cfg(feature = "runtime-tokio")]
31//! # fn main() {}
32//! ```
33
34#[cfg(not(any(feature = "runtime-tokio", feature = "runtime-async-std")))]
35compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
36
37#[cfg(all(feature = "runtime-tokio", feature = "runtime-async-std"))]
38compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
39
40mod acceptor;
41mod connector;
42mod handshake;
43mod runtime;
44mod std_adapter;
45mod tls_stream;
46
47pub use accept::accept;
48pub use acceptor::{Error as AcceptError, TlsAcceptor};
49pub use connect::{connect, TlsConnector};
50pub use host::Host;
51pub use tls_stream::TlsStream;
52
53#[doc(inline)]
54pub use native_tls::{Certificate, Error, Identity, Protocol, Result};
55
56mod accept {
57    use crate::runtime::{AsyncRead, AsyncWrite};
58
59    use crate::TlsStream;
60
61    /// One of accept of an incoming connection.
62    ///
63    /// # Example
64    ///
65    /// ```no_run
66    /// # #[cfg(feature = "runtime-async-std")]
67    /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
68    /// #
69    /// use async_std::prelude::*;
70    /// use async_std::net::TcpListener;
71    /// use async_std::fs::File;
72    ///
73    /// let listener = TcpListener::bind("0.0.0.0:8443").await?;
74    /// let (stream, _addr) = listener.accept().await?;
75    ///
76    /// let key = File::open("identity.pfx").await?;
77    /// let stream = async_native_tls::accept(key, "<password>", stream).await?;
78    /// // handle stream here
79    /// #
80    /// # Ok(()) }) }
81    /// # #[cfg(feature = "runtime-tokio")]
82    /// # fn main() {}
83    /// ```
84    pub async fn accept<R, S, T>(
85        file: R,
86        password: S,
87        stream: T,
88    ) -> Result<TlsStream<T>, crate::AcceptError>
89    where
90        R: AsyncRead + Unpin,
91        S: AsRef<str>,
92        T: AsyncRead + AsyncWrite + Unpin,
93    {
94        let acceptor = crate::TlsAcceptor::new(file, password).await?;
95        let stream = acceptor.accept(stream).await?;
96
97        Ok(stream)
98    }
99}
100
101mod host {
102    use url::Url;
103
104    /// The host part of a domain (without scheme, port and path).
105    ///
106    /// This is the argument to the [`connect`](crate::connect::connect) function. Strings and string slices are
107    /// converted into Hosts automatically, as is [Url](url::Url) with the `host-from-url` feature (enabled by default).
108    #[derive(Debug)]
109    pub struct Host(String);
110
111    impl Host {
112        /// The host as string. Consumes self.
113        #[allow(clippy::wrong_self_convention)]
114        pub fn as_string(self) -> String {
115            self.0
116        }
117    }
118
119    impl From<&str> for Host {
120        fn from(host: &str) -> Self {
121            Self(host.into())
122        }
123    }
124
125    impl From<String> for Host {
126        fn from(host: String) -> Self {
127            Self(host)
128        }
129    }
130
131    impl From<&String> for Host {
132        fn from(host: &String) -> Self {
133            Self(host.into())
134        }
135    }
136
137    impl From<Url> for Host {
138        fn from(url: Url) -> Self {
139            Self(
140                url.host_str()
141                    .expect("URL has to include a host part.")
142                    .into(),
143            )
144        }
145    }
146
147    impl From<&Url> for Host {
148        fn from(url: &Url) -> Self {
149            Self(
150                url.host_str()
151                    .expect("URL has to include a host part.")
152                    .into(),
153            )
154        }
155    }
156}
157
158mod connect {
159    use std::fmt::{self, Debug};
160
161    use crate::host::Host;
162    use crate::runtime::{AsyncRead, AsyncWrite};
163    use crate::TlsStream;
164    use crate::{Certificate, Identity, Protocol};
165
166    /// Connect a client to a remote server.
167    ///
168    /// # Examples
169    ///
170    /// ```
171    /// # #[cfg(feature = "runtime-async-std")]
172    /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
173    /// #
174    /// use async_std::prelude::*;
175    /// use async_std::net::TcpStream;
176    ///
177    /// let stream = TcpStream::connect("google.com:443").await?;
178    /// let mut stream = async_native_tls::connect("google.com", stream).await?;
179    /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
180    ///
181    /// let mut res = Vec::new();
182    /// stream.read_to_end(&mut res).await?;
183    /// println!("{}", String::from_utf8_lossy(&res));
184    /// #
185    /// # Ok(()) }) }
186    /// # #[cfg(feature = "runtime-tokio")]
187    /// # fn main() {}
188    /// ```
189    pub async fn connect<S>(host: impl Into<Host>, stream: S) -> native_tls::Result<TlsStream<S>>
190    where
191        S: AsyncRead + AsyncWrite + Unpin,
192    {
193        let stream = TlsConnector::new().connect(host, stream).await?;
194        Ok(stream)
195    }
196
197    /// Connect a client to a remote server.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// # #[cfg(feature = "runtime-async-std")]
203    /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
204    /// #
205    /// use async_std::prelude::*;
206    /// use async_std::net::TcpStream;
207    /// use async_native_tls::TlsConnector;
208    ///
209    /// let stream = TcpStream::connect("google.com:443").await?;
210    /// let mut stream = TlsConnector::new()
211    ///     .use_sni(true)
212    ///     .connect("google.com", stream)
213    ///     .await?;
214    /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
215    ///
216    /// let mut res = Vec::new();
217    /// stream.read_to_end(&mut res).await?;
218    /// println!("{}", String::from_utf8_lossy(&res));
219    /// #
220    /// # Ok(()) }) }
221    /// # #[cfg(feature = "runtime-tokio")]
222    /// # fn main() {}
223    /// ```
224    pub struct TlsConnector {
225        builder: native_tls::TlsConnectorBuilder,
226    }
227
228    impl Default for TlsConnector {
229        fn default() -> Self {
230            TlsConnector::new()
231        }
232    }
233
234    impl TlsConnector {
235        /// Create a new instance.
236        pub fn new() -> Self {
237            Self {
238                builder: native_tls::TlsConnector::builder(),
239            }
240        }
241
242        /// Sets the identity to be used for client certificate authentication.
243        pub fn identity(mut self, identity: Identity) -> Self {
244            self.builder.identity(identity);
245            self
246        }
247
248        /// Sets the minimum supported protocol version.
249        ///
250        /// A value of `None` enables support for the oldest protocols supported by the
251        /// implementation. Defaults to `Some(Protocol::Tlsv10)`.
252        pub fn min_protocol_version(mut self, protocol: Option<Protocol>) -> Self {
253            self.builder.min_protocol_version(protocol);
254            self
255        }
256
257        /// Sets the maximum supported protocol version.
258        ///
259        /// A value of `None` enables support for the newest protocols supported by the
260        /// implementation. Defaults to `None`.
261        pub fn max_protocol_version(mut self, protocol: Option<Protocol>) -> Self {
262            self.builder.max_protocol_version(protocol);
263            self
264        }
265
266        /// Adds a certificate to the set of roots that the connector will trust.
267        ///
268        /// The connector will use the system's trust root by default. This method can be used to
269        /// add to that set when communicating with servers not trusted by the system. Defaults to
270        /// an empty set.
271        pub fn add_root_certificate(mut self, cert: Certificate) -> Self {
272            self.builder.add_root_certificate(cert);
273            self
274        }
275
276        /// Controls the use of certificate validation.
277        ///
278        /// Defaults to false.
279        ///
280        /// # Warning
281        ///
282        /// You should think very carefully before using this method. If invalid certificates are
283        /// trusted, any certificate for any site will be trusted for use. This includes expired
284        /// certificates. This introduces significant vulnerabilities, and should only be used as a
285        /// last resort.
286        pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
287            self.builder
288                .danger_accept_invalid_certs(accept_invalid_certs);
289            self
290        }
291
292        /// Controls the use of Server Name Indication (SNI).
293        ///
294        /// Defaults to `true`.
295        pub fn use_sni(mut self, use_sni: bool) -> Self {
296            self.builder.use_sni(use_sni);
297            self
298        }
299
300        /// Controls the use of hostname verification.
301        ///
302        /// Defaults to `false`.
303        ///
304        /// # Warning
305        ///
306        /// You should think very carefully before using this method. If invalid hostnames are
307        /// trusted, any valid certificate for any site will be trusted for use. This introduces
308        /// significant vulnerabilities, and should only be used as a last resort.
309        pub fn danger_accept_invalid_hostnames(mut self, accept_invalid_hostnames: bool) -> Self {
310            self.builder
311                .danger_accept_invalid_hostnames(accept_invalid_hostnames);
312            self
313        }
314
315        /// Connect to a remote server.
316        ///
317        /// # Examples
318        ///
319        /// ```
320        /// # #[cfg(feature = "runtime-async-std")]
321        /// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
322        /// #
323        /// use async_std::prelude::*;
324        /// use async_std::net::TcpStream;
325        /// use async_native_tls::TlsConnector;
326        ///
327        /// let stream = TcpStream::connect("google.com:443").await?;
328        /// let mut stream = TlsConnector::new()
329        ///     .use_sni(true)
330        ///     .connect("google.com", stream)
331        ///     .await?;
332        /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
333        ///
334        /// let mut res = Vec::new();
335        /// stream.read_to_end(&mut res).await?;
336        /// println!("{}", String::from_utf8_lossy(&res));
337        /// #
338        /// # Ok(()) }) }
339        /// # #[cfg(feature = "runtime-tokio")]
340        /// # fn main() {}
341        /// ```
342        pub async fn connect<S>(
343            &self,
344            host: impl Into<Host>,
345            stream: S,
346        ) -> native_tls::Result<TlsStream<S>>
347        where
348            S: AsyncRead + AsyncWrite + Unpin,
349        {
350            let host: Host = host.into();
351            let domain = host.as_string();
352            let connector = self.builder.build()?;
353            let connector = crate::connector::TlsConnector::from(connector);
354            let stream = connector.connect(&domain, stream).await?;
355            Ok(stream)
356        }
357    }
358
359    impl Debug for TlsConnector {
360        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361            f.debug_struct("TlsConnector").finish()
362        }
363    }
364
365    impl From<native_tls::TlsConnectorBuilder> for TlsConnector {
366        fn from(builder: native_tls::TlsConnectorBuilder) -> Self {
367            Self { builder }
368        }
369    }
370}