web3_async_native_tls/
acceptor.rs1use std::fmt;
2use std::marker::Unpin;
3
4use crate::handshake::handshake;
5use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite};
6use crate::TlsStream;
7
8#[derive(Clone)]
41pub struct TlsAcceptor(native_tls::TlsAcceptor);
42
43#[derive(thiserror::Error, Debug)]
45pub enum Error {
46 #[error("NativeTls({})", 0)]
48 NativeTls(#[from] native_tls::Error),
49 #[error("Io({})", 0)]
51 Io(#[from] std::io::Error),
52}
53
54impl TlsAcceptor {
55 pub async fn new<R, S>(mut file: R, password: S) -> Result<Self, Error>
57 where
58 R: AsyncRead + Unpin,
59 S: AsRef<str>,
60 {
61 let mut identity = vec![];
62 file.read_to_end(&mut identity).await?;
63
64 let identity = native_tls::Identity::from_pkcs12(&identity, password.as_ref())?;
65 Ok(TlsAcceptor(native_tls::TlsAcceptor::new(identity)?))
66 }
67
68 pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, native_tls::Error>
79 where
80 S: AsyncRead + AsyncWrite + Unpin,
81 {
82 let stream = handshake(move |s| self.0.accept(s), stream).await?;
83 Ok(stream)
84 }
85}
86
87impl fmt::Debug for TlsAcceptor {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("TlsAcceptor").finish()
90 }
91}
92
93impl From<native_tls::TlsAcceptor> for TlsAcceptor {
94 fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor {
95 TlsAcceptor(inner)
96 }
97}
98
#[cfg(all(test, feature = "runtime-async-std"))]
mod tests {
    use super::*;
    use crate::runtime::AsyncWriteExt;
    use crate::TlsConnector;
    use async_std::fs::File;
    use async_std::net::{TcpListener, TcpStream};
    use async_std::stream::StreamExt;

    /// End-to-end check. A server built from `tests/identity.pfx` accepts a
    /// TLS connection from a `TlsConnector` client and writes `b"hello"`,
    /// which the client must read back verbatim.
    #[async_std::test]
    async fn test_acceptor() {
        let key = File::open("tests/identity.pfx").await.unwrap();
        let acceptor = TlsAcceptor::new(key, "hello").await.unwrap();

        // Port 0 lets the OS choose a free port. A hard-coded port fails
        // whenever it is already taken, e.g. by a parallel test run.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        async_std::task::spawn(async move {
            let mut incoming = listener.incoming();

            while let Some(stream) = incoming.next().await {
                let acceptor = acceptor.clone();
                let stream = stream.unwrap();
                async_std::task::spawn(async move {
                    let mut stream = acceptor.accept(stream).await.unwrap();
                    stream.write_all(b"hello").await.unwrap();
                });
            }
        });

        // Connect to the exact address the listener bound. The previous
        // "127.0.01" host string was a typo that only resolved by accident.
        let stream = TcpStream::connect(addr).await.unwrap();
        let connector = TlsConnector::new().danger_accept_invalid_certs(true);

        let mut stream = connector.connect("127.0.0.1", stream).await.unwrap();
        let mut res = Vec::new();
        stream.read_to_end(&mut res).await.unwrap();
        assert_eq!(res, b"hello");
    }
}