async_native_tls/
tls_stream.rs

1use std::io::{self, Read, Write};
2use std::marker::Unpin;
3use std::pin::Pin;
4use std::ptr::null_mut;
5use std::task::{Context, Poll};
6
7use crate::runtime::{AsyncRead, AsyncWrite};
8use crate::std_adapter::StdAdapter;
9
10/// A stream managing a TLS session.
11///
12/// A wrapper around an underlying raw stream which implements the TLS or SSL
13/// protocol.
14///
15/// A `TlsStream<S>` represents a handshake that has been completed successfully
16/// and both the server and the client are ready for receiving and sending
17/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
18/// to a `TlsStream` are encrypted when passing through to `S`.
19#[derive(Debug)]
20pub struct TlsStream<S>(native_tls::TlsStream<StdAdapter<S>>);
21
22impl<S> TlsStream<S> {
23    pub(crate) fn new(stream: native_tls::TlsStream<StdAdapter<S>>) -> Self {
24        Self(stream)
25    }
26    fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
27    where
28        F: FnOnce(&mut native_tls::TlsStream<StdAdapter<S>>) -> R,
29        StdAdapter<S>: Read + Write,
30    {
31        self.0.get_mut().context = ctx as *mut _ as *mut ();
32        let g = Guard(self);
33        f(&mut (g.0).0)
34    }
35
36    /// Returns a shared reference to the inner stream.
37    pub fn get_ref(&self) -> &S
38    where
39        S: AsyncRead + AsyncWrite + Unpin,
40    {
41        &self.0.get_ref().inner
42    }
43
44    /// Returns a mutable reference to the inner stream.
45    pub fn get_mut(&mut self) -> &mut S
46    where
47        S: AsyncRead + AsyncWrite + Unpin,
48    {
49        &mut self.0.get_mut().inner
50    }
51
52    /// Returns the number of bytes that can be read without resulting in any network calls.
53    pub fn buffered_read_size(&self) -> crate::Result<usize>
54    where
55        S: AsyncRead + AsyncWrite + Unpin,
56    {
57        self.0.buffered_read_size()
58    }
59
60    /// Returns the peer's leaf certificate, if available.
61    pub fn peer_certificate(&self) -> crate::Result<Option<crate::Certificate>>
62    where
63        S: AsyncRead + AsyncWrite + Unpin,
64    {
65        self.0.peer_certificate()
66    }
67
68    /// Returns the tls-server-end-point channel binding data as defined in [RFC 5929](https://tools.ietf.org/html/rfc5929).
69    pub fn tls_server_end_point(&self) -> crate::Result<Option<Vec<u8>>>
70    where
71        S: AsyncRead + AsyncWrite + Unpin,
72    {
73        self.0.tls_server_end_point()
74    }
75}
76
#[cfg(feature = "runtime-async-std")]
impl<S> AsyncRead for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted bytes into `buf`.
    ///
    /// A `WouldBlock` from the synchronous TLS layer is reported as
    /// `Poll::Pending`; every other outcome is returned as `Poll::Ready`.
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        this.with_context(ctx, |tls| cvt(tls.read(buf)))
    }
}
90
#[cfg(feature = "runtime-tokio")]
impl<S> AsyncRead for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted bytes into the unfilled part of `buf`.
    ///
    /// The unfilled region is first made fully initialised (via
    /// `ReadBuf::initialize_unfilled`) so it can be passed to the synchronous
    /// reader. The filled cursor is then advanced by the number of bytes
    /// produced. A `WouldBlock` maps to `Poll::Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = Pin::get_mut(self);
        let polled = this.with_context(ctx, |tls| cvt(tls.read(buf.initialize_unfilled())));
        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(result) => Poll::Ready(result.map(|n| buf.advance(n))),
        }
    }
}
111
112impl<S> AsyncWrite for TlsStream<S>
113where
114    S: AsyncRead + AsyncWrite + Unpin,
115{
116    fn poll_write(
117        mut self: Pin<&mut Self>,
118        ctx: &mut Context<'_>,
119        buf: &[u8],
120    ) -> Poll<io::Result<usize>> {
121        self.with_context(ctx, |s| cvt(s.write(buf)))
122    }
123
124    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
125        self.with_context(ctx, |s| cvt(s.flush()))
126    }
127
128    #[cfg(feature = "runtime-async-std")]
129    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
130        match self.with_context(ctx, |s| s.shutdown()) {
131            Ok(()) => Poll::Ready(Ok(())),
132            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
133            Err(e) => Poll::Ready(Err(e)),
134        }
135    }
136
137    #[cfg(feature = "runtime-tokio")]
138    fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
139        match self.with_context(ctx, |s| s.shutdown()) {
140            Ok(()) => Poll::Ready(Ok(())),
141            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
142            Err(e) => Poll::Ready(Err(e)),
143        }
144    }
145}
146
147struct Guard<'a, S>(&'a mut TlsStream<S>)
148where
149    StdAdapter<S>: Read + Write;
150
151impl<S> Drop for Guard<'_, S>
152where
153    StdAdapter<S>: Read + Write,
154{
155    fn drop(&mut self) {
156        (self.0).0.get_mut().context = null_mut();
157    }
158}
159
/// Converts a synchronous I/O result into a poll result.
///
/// An error of kind `WouldBlock` becomes `Poll::Pending`. Any other result,
/// success or error, is wrapped unchanged in `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }
    Poll::Ready(r)
}