web3_async_native_tls/
acceptor.rs

1use std::fmt;
2use std::marker::Unpin;
3
4use crate::handshake::handshake;
5use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite};
6use crate::TlsStream;
7
8/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept`
9/// method.
10///
11/// # Example
12///
13/// ```no_run
14/// # #[cfg(feature = "runtime-async-std")]
15/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> { async_std::task::block_on(async {
16/// #
17/// use async_std::prelude::*;
18/// use async_std::net::TcpListener;
19/// use async_std::fs::File;
20/// use async_native_tls::TlsAcceptor;
21///
22/// let key = File::open("tests/identity.pfx").await?;
23/// let acceptor = TlsAcceptor::new(key, "hello").await?;
24/// let listener = TcpListener::bind("127.0.0.1:8443").await?;
25/// let mut incoming = listener.incoming();
26///
27/// while let Some(stream) = incoming.next().await {
28///     let acceptor = acceptor.clone();
29///     let stream = stream?;
30///     async_std::task::spawn(async move {
31///         let stream = acceptor.accept(stream).await.unwrap();
32///         // handle stream here
33///     });
34/// }
35/// #
36/// # Ok(()) }) }
37/// # #[cfg(feature = "runtime-tokio")]
38/// # fn main() {}
39/// ```
40#[derive(Clone)]
41pub struct TlsAcceptor(native_tls::TlsAcceptor);
42
43/// An error returned from creating an acceptor.
44#[derive(thiserror::Error, Debug)]
45pub enum Error {
46    /// NativeTls error.
47    #[error("NativeTls({})", 0)]
48    NativeTls(#[from] native_tls::Error),
49    /// Io error.
50    #[error("Io({})", 0)]
51    Io(#[from] std::io::Error),
52}
53
54impl TlsAcceptor {
55    /// Create a new TlsAcceptor based on an identity file and matching password.
56    pub async fn new<R, S>(mut file: R, password: S) -> Result<Self, Error>
57    where
58        R: AsyncRead + Unpin,
59        S: AsRef<str>,
60    {
61        let mut identity = vec![];
62        file.read_to_end(&mut identity).await?;
63
64        let identity = native_tls::Identity::from_pkcs12(&identity, password.as_ref())?;
65        Ok(TlsAcceptor(native_tls::TlsAcceptor::new(identity)?))
66    }
67
68    /// Accepts a new client connection with the provided stream.
69    ///
70    /// This function will internally call `TlsAcceptor::accept` to connect
71    /// the stream and returns a future representing the resolution of the
72    /// connection operation. The returned future will resolve to either
73    /// `TlsStream<S>` or `Error` depending if it's successful or not.
74    ///
75    /// This is typically used after a new socket has been accepted from a
76    /// `TcpListener`. That socket is then passed to this function to perform
77    /// the server half of accepting a client connection.
78    pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, native_tls::Error>
79    where
80        S: AsyncRead + AsyncWrite + Unpin,
81    {
82        let stream = handshake(move |s| self.0.accept(s), stream).await?;
83        Ok(stream)
84    }
85}
86
87impl fmt::Debug for TlsAcceptor {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.debug_struct("TlsAcceptor").finish()
90    }
91}
92
93impl From<native_tls::TlsAcceptor> for TlsAcceptor {
94    fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor {
95        TlsAcceptor(inner)
96    }
97}
98
#[cfg(all(test, feature = "runtime-async-std"))]
mod tests {
    use super::*;
    use crate::runtime::AsyncWriteExt;
    use crate::TlsConnector;
    use async_std::fs::File;
    use async_std::net::{TcpListener, TcpStream};
    use async_std::stream::StreamExt;

    /// Round-trips a TLS connection over loopback: the server accepts with
    /// `TlsAcceptor`, writes `hello`, and the client reads it back.
    #[async_std::test]
    async fn test_acceptor() {
        let key = File::open("tests/identity.pfx").await.unwrap();
        let acceptor = TlsAcceptor::new(key, "hello").await.unwrap();

        // Bind to port 0 so the OS picks a free port; a fixed port can
        // collide with another process or a concurrently running test.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        async_std::task::spawn(async move {
            let mut incoming = listener.incoming();

            while let Some(stream) = incoming.next().await {
                let acceptor = acceptor.clone();
                let stream = stream.unwrap();
                async_std::task::spawn(async move {
                    let mut stream = acceptor.accept(stream).await.unwrap();
                    stream.write_all(b"hello").await.unwrap();
                });
            }
        });

        // Connect to the exact address the listener is bound to.
        let stream = TcpStream::connect(addr).await.unwrap();
        // The test identity is not signed by a trusted CA, hence the opt-out.
        let connector = TlsConnector::new().danger_accept_invalid_certs(true);

        let mut stream = connector.connect("127.0.0.1", stream).await.unwrap();
        let mut res = Vec::new();
        stream.read_to_end(&mut res).await.unwrap();
        assert_eq!(res, b"hello");
    }
}