async_native_tls/
tls_stream.rs1use std::io::{self, Read, Write};
2use std::marker::Unpin;
3use std::pin::Pin;
4use std::ptr::null_mut;
5use std::task::{Context, Poll};
6
7use crate::runtime::{AsyncRead, AsyncWrite};
8use crate::std_adapter::StdAdapter;
9
10#[derive(Debug)]
20pub struct TlsStream<S>(native_tls::TlsStream<StdAdapter<S>>);
21
22impl<S> TlsStream<S> {
23 pub(crate) fn new(stream: native_tls::TlsStream<StdAdapter<S>>) -> Self {
24 Self(stream)
25 }
26 fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
27 where
28 F: FnOnce(&mut native_tls::TlsStream<StdAdapter<S>>) -> R,
29 StdAdapter<S>: Read + Write,
30 {
31 self.0.get_mut().context = ctx as *mut _ as *mut ();
32 let g = Guard(self);
33 f(&mut (g.0).0)
34 }
35
36 pub fn get_ref(&self) -> &S
38 where
39 S: AsyncRead + AsyncWrite + Unpin,
40 {
41 &self.0.get_ref().inner
42 }
43
44 pub fn get_mut(&mut self) -> &mut S
46 where
47 S: AsyncRead + AsyncWrite + Unpin,
48 {
49 &mut self.0.get_mut().inner
50 }
51
52 pub fn buffered_read_size(&self) -> crate::Result<usize>
54 where
55 S: AsyncRead + AsyncWrite + Unpin,
56 {
57 self.0.buffered_read_size()
58 }
59
60 pub fn peer_certificate(&self) -> crate::Result<Option<crate::Certificate>>
62 where
63 S: AsyncRead + AsyncWrite + Unpin,
64 {
65 self.0.peer_certificate()
66 }
67
68 pub fn tls_server_end_point(&self) -> crate::Result<Option<Vec<u8>>>
70 where
71 S: AsyncRead + AsyncWrite + Unpin,
72 {
73 self.0.tls_server_end_point()
74 }
75}
76
/// async-std read support.
///
/// Each poll installs the task context in the adapter, performs one blocking-style `read`
/// on the TLS stream, and converts `WouldBlock` into `Poll::Pending`.
#[cfg(feature = "runtime-async-std")]
impl<S> AsyncRead for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted bytes into `buf`, returning how many were written.
    fn poll_read(
        mut self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `TlsStream` is `Unpin` here, so the pin can be reborrowed as a plain `&mut`.
        let this: &mut Self = &mut *self;
        this.with_context(ctx, move |tls| {
            let outcome = tls.read(buf);
            cvt(outcome)
        })
    }
}
90
/// Tokio read support.
///
/// Each poll installs the task context in the adapter, performs one blocking-style `read`
/// into the buffer's unfilled region, and converts `WouldBlock` into `Poll::Pending`.
#[cfg(feature = "runtime-tokio")]
impl<S> AsyncRead for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted bytes into the unfilled part of `buf`.
    ///
    /// The unfilled region is zero-initialized first (`initialize_unfilled`) because
    /// `std::io::Read` needs an initialized `&mut [u8]`. On success, `buf` is advanced by the
    /// number of bytes read.
    fn poll_read(
        mut self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let polled = self.with_context(ctx, |tls| {
            let dst = buf.initialize_unfilled();
            cvt(tls.read(dst))
        });
        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(result) => Poll::Ready(result.map(|n| buf.advance(n))),
        }
    }
}
111
112impl<S> AsyncWrite for TlsStream<S>
113where
114 S: AsyncRead + AsyncWrite + Unpin,
115{
116 fn poll_write(
117 mut self: Pin<&mut Self>,
118 ctx: &mut Context<'_>,
119 buf: &[u8],
120 ) -> Poll<io::Result<usize>> {
121 self.with_context(ctx, |s| cvt(s.write(buf)))
122 }
123
124 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
125 self.with_context(ctx, |s| cvt(s.flush()))
126 }
127
128 #[cfg(feature = "runtime-async-std")]
129 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
130 match self.with_context(ctx, |s| s.shutdown()) {
131 Ok(()) => Poll::Ready(Ok(())),
132 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
133 Err(e) => Poll::Ready(Err(e)),
134 }
135 }
136
137 #[cfg(feature = "runtime-tokio")]
138 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
139 match self.with_context(ctx, |s| s.shutdown()) {
140 Ok(()) => Poll::Ready(Ok(())),
141 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
142 Err(e) => Poll::Ready(Err(e)),
143 }
144 }
145}
146
147struct Guard<'a, S>(&'a mut TlsStream<S>)
148where
149 StdAdapter<S>: Read + Write;
150
151impl<S> Drop for Guard<'_, S>
152where
153 StdAdapter<S>: Read + Write,
154{
155 fn drop(&mut self) {
156 (self.0).0.get_mut().context = null_mut();
157 }
158}
159
/// Converts the result of a blocking-style I/O call into a poll result.
///
/// An error of kind `io::ErrorKind::WouldBlock` becomes `Poll::Pending`. Any other
/// result, success or error, is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }
    Poll::Ready(r)
}