tokio_rustls/
client.rs

1use std::io::{self, BufRead as _};
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7#[cfg(feature = "early-data")]
8use std::task::Waker;
9use std::task::{Context, Poll};
10
11use rustls::ClientConnection;
12use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
13
14use crate::common::{IoSession, Stream, TlsState};
15
16/// A wrapper around an underlying raw stream which implements the TLS or SSL
17/// protocol.
18#[derive(Debug)]
19pub struct TlsStream<IO> {
20    pub(crate) io: IO,
21    pub(crate) session: ClientConnection,
22    pub(crate) state: TlsState,
23
24    #[cfg(feature = "early-data")]
25    pub(crate) early_waker: Option<Waker>,
26}
27
28impl<IO> TlsStream<IO> {
29    #[inline]
30    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
31        (&self.io, &self.session)
32    }
33
34    #[inline]
35    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
36        (&mut self.io, &mut self.session)
37    }
38
39    #[inline]
40    pub fn into_inner(self) -> (IO, ClientConnection) {
41        (self.io, self.session)
42    }
43}
44
45#[cfg(unix)]
46impl<S> AsRawFd for TlsStream<S>
47where
48    S: AsRawFd,
49{
50    fn as_raw_fd(&self) -> RawFd {
51        self.get_ref().0.as_raw_fd()
52    }
53}
54
/// Exposes the raw socket handle of the wrapped I/O object (Windows only).
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        let (inner, _session) = self.get_ref();
        inner.as_raw_socket()
    }
}
64
65impl<IO> IoSession for TlsStream<IO> {
66    type Io = IO;
67    type Session = ClientConnection;
68
69    #[inline]
70    fn skip_handshake(&self) -> bool {
71        self.state.is_early_data()
72    }
73
74    #[inline]
75    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
76        (&mut self.state, &mut self.io, &mut self.session)
77    }
78
79    #[inline]
80    fn into_io(self) -> Self::Io {
81        self.io
82    }
83}
84
#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Records the current task's waker so that a read attempted during the
    /// early-data phase is woken later.
    ///
    /// While early data is being sent the TLS connection is not really
    /// established yet, so reads ignore readiness and stay `Pending` until a write
    /// completes the handshake. To avoid losing that wake-up, the waker is stored
    /// here and woken once the stream leaves the early-data state. The stored
    /// waker is only replaced when it would not already wake the current task.
    fn poll_early_data(&mut self, cx: &mut Context<'_>) {
        let current = cx.waker();
        let already_registered =
            matches!(&self.early_waker, Some(stored) if current.will_wake(stored));

        if !already_registered {
            self.early_waker = Some(current.clone());
        }
    }
}
107
108impl<IO> AsyncRead for TlsStream<IO>
109where
110    IO: AsyncRead + AsyncWrite + Unpin,
111{
112    fn poll_read(
113        mut self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115        buf: &mut ReadBuf<'_>,
116    ) -> Poll<io::Result<()>> {
117        let data = ready!(self.as_mut().poll_fill_buf(cx))?;
118        let len = data.len().min(buf.remaining());
119        buf.put_slice(&data[..len]);
120        self.consume(len);
121        Poll::Ready(Ok(()))
122    }
123}
124
125impl<IO> AsyncBufRead for TlsStream<IO>
126where
127    IO: AsyncRead + AsyncWrite + Unpin,
128{
129    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
130        match self.state {
131            #[cfg(feature = "early-data")]
132            TlsState::EarlyData(..) => {
133                self.get_mut().poll_early_data(cx);
134                Poll::Pending
135            }
136            TlsState::Stream | TlsState::WriteShutdown => {
137                let this = self.get_mut();
138                let stream =
139                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
140
141                match stream.poll_fill_buf(cx) {
142                    Poll::Ready(Ok(buf)) => {
143                        if buf.is_empty() {
144                            this.state.shutdown_read();
145                        }
146
147                        Poll::Ready(Ok(buf))
148                    }
149                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
150                        this.state.shutdown_read();
151                        Poll::Ready(Err(err))
152                    }
153                    output => output,
154                }
155            }
156            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
157        }
158    }
159
160    fn consume(mut self: Pin<&mut Self>, amt: usize) {
161        self.session.reader().consume(amt);
162    }
163}
164
165impl<IO> AsyncWrite for TlsStream<IO>
166where
167    IO: AsyncRead + AsyncWrite + Unpin,
168{
169    /// Note: that it does not guarantee the final data to be sent.
170    /// To be cautious, you must manually call `flush`.
171    fn poll_write(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &[u8],
175    ) -> Poll<io::Result<usize>> {
176        let this = self.get_mut();
177        let mut stream =
178            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
179
180        #[cfg(feature = "early-data")]
181        {
182            let bufs = [io::IoSlice::new(buf)];
183            let written = ready!(poll_handle_early_data(
184                &mut this.state,
185                &mut stream,
186                &mut this.early_waker,
187                cx,
188                &bufs
189            ))?;
190            if written != 0 {
191                return Poll::Ready(Ok(written));
192            }
193        }
194
195        stream.as_mut_pin().poll_write(cx, buf)
196    }
197
198    /// Note: that it does not guarantee the final data to be sent.
199    /// To be cautious, you must manually call `flush`.
200    fn poll_write_vectored(
201        self: Pin<&mut Self>,
202        cx: &mut Context<'_>,
203        bufs: &[io::IoSlice<'_>],
204    ) -> Poll<io::Result<usize>> {
205        let this = self.get_mut();
206        let mut stream =
207            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
208
209        #[cfg(feature = "early-data")]
210        {
211            let written = ready!(poll_handle_early_data(
212                &mut this.state,
213                &mut stream,
214                &mut this.early_waker,
215                cx,
216                bufs
217            ))?;
218            if written != 0 {
219                return Poll::Ready(Ok(written));
220            }
221        }
222
223        stream.as_mut_pin().poll_write_vectored(cx, bufs)
224    }
225
226    #[inline]
227    fn is_write_vectored(&self) -> bool {
228        true
229    }
230
231    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
232        let this = self.get_mut();
233        let mut stream =
234            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
235
236        #[cfg(feature = "early-data")]
237        ready!(poll_handle_early_data(
238            &mut this.state,
239            &mut stream,
240            &mut this.early_waker,
241            cx,
242            &[]
243        ))?;
244
245        stream.as_mut_pin().poll_flush(cx)
246    }
247
248    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
249        #[cfg(feature = "early-data")]
250        {
251            // complete handshake
252            if matches!(self.state, TlsState::EarlyData(..)) {
253                ready!(self.as_mut().poll_flush(cx))?;
254            }
255        }
256
257        if self.state.writeable() {
258            self.session.send_close_notify();
259            self.state.shutdown_write();
260        }
261
262        let this = self.get_mut();
263        let mut stream =
264            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
265        stream.as_mut_pin().poll_shutdown(cx)
266    }
267}
268
/// Drives the client side of TLS 1.3 early data (0-RTT) for the `AsyncWrite`
/// methods of [`TlsStream`] (`poll_write`, `poll_write_vectored`, `poll_flush`).
///
/// If `state` is not `TlsState::EarlyData`, nothing happens and
/// `Poll::Ready(Ok(0))` is returned. Otherwise:
///
/// 1. As much of `bufs` as rustls accepts is written as early data, and every
///    accepted byte is also appended to the state's replay buffer. If any byte
///    was accepted, `Poll::Ready(Ok(n))` is returned immediately.
/// 2. If nothing was accepted (empty `bufs`, early data unavailable, or rustls
///    taking no more), the handshake is driven to completion.
/// 3. If the server did not accept the early data, the replay buffer is re-sent
///    over the established connection. This resumes from the stored position, so
///    the step can be interrupted by `Pending` and continued later.
/// 4. `state` becomes `TlsState::Stream`, and any waker parked by a read during
///    the early-data phase is woken.
///
/// # Returns
/// `Poll::Ready(Ok(n))` with `n > 0` when `n` bytes from `bufs` were consumed as
/// early data. `Poll::Ready(Ok(0))` when the caller should continue with the
/// regular stream.
///
/// # Errors
/// Returns errors from:
/// - the early-data writer, but only when no byte of `bufs` was consumed in this
///   call;
/// - the handshake;
/// - re-sending rejected early data.
///
/// A re-send that makes no progress yields `io::ErrorKind::WriteZero` instead of
/// looping forever.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // Try to send the caller's bytes as early data.
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs.iter().filter(|buf| !buf.is_empty()) {
                let len = match early_data.write(buf) {
                    // rustls accepts no more early data.
                    Ok(0) => break,
                    Ok(n) => n,
                    // Bytes from earlier slices are already committed to the session
                    // and recorded in `data`, so they must be reported as written.
                    // `io::Write` promises that an `Err` means nothing was written.
                    // If the error persists, the caller's next attempt reports it.
                    Err(_) if written != 0 => break,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                // Remember the bytes so they can be re-sent if early data is rejected.
                data.extend_from_slice(&buf[..len]);

                // Short write: rustls took only part of this slice, so stop here.
                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // Complete the handshake.
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // Early data was not accepted: re-send it as ordinary application data.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                if len == 0 {
                    // No progress on a non-empty buffer: fail rather than spin forever.
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
                *pos += len;
            }
        }

        // Early-data phase finished: switch to the regular stream.
        *state = TlsState::Stream;

        // Wake a reader that was parked while the stream was in the early-data state.
        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}