1use std::io::{self, BufRead as _};
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7#[cfg(feature = "early-data")]
8use std::task::Waker;
9use std::task::{Context, Poll};
10
11use rustls::ClientConnection;
12use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
13
14use crate::common::{IoSession, Stream, TlsState};
15
/// Client side of a TLS connection: the wrapped I/O object `io`, the rustls
/// [`ClientConnection`] driving the protocol, and this crate's [`TlsState`]
/// bookkeeping of which halves are still open (and, with the `early-data`
/// feature, whether 0-RTT data is still pending).
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The wrapped underlying I/O object.
    pub(crate) io: IO,
    /// The rustls client connection holding the TLS session.
    pub(crate) session: ClientConnection,
    /// Read/write shutdown and early-data state tracked by this wrapper.
    pub(crate) state: TlsState,

    /// Waker of the last task that tried to read while early data was still
    /// pending; it is woken once the handshake has completed and the state
    /// switches to `TlsState::Stream` (see `poll_handle_early_data`).
    #[cfg(feature = "early-data")]
    pub(crate) early_waker: Option<Waker>,
}
27
28impl<IO> TlsStream<IO> {
29 #[inline]
30 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
31 (&self.io, &self.session)
32 }
33
34 #[inline]
35 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
36 (&mut self.io, &mut self.session)
37 }
38
39 #[inline]
40 pub fn into_inner(self) -> (IO, ClientConnection) {
41 (self.io, self.session)
42 }
43}
44
45#[cfg(unix)]
46impl<S> AsRawFd for TlsStream<S>
47where
48 S: AsRawFd,
49{
50 fn as_raw_fd(&self) -> RawFd {
51 self.get_ref().0.as_raw_fd()
52 }
53}
54
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped I/O object.
    fn as_raw_socket(&self) -> RawSocket {
        let (inner, _session) = self.get_ref();
        inner.as_raw_socket()
    }
}
64
65impl<IO> IoSession for TlsStream<IO> {
66 type Io = IO;
67 type Session = ClientConnection;
68
69 #[inline]
70 fn skip_handshake(&self) -> bool {
71 self.state.is_early_data()
72 }
73
74 #[inline]
75 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
76 (&mut self.state, &mut self.io, &mut self.session)
77 }
78
79 #[inline]
80 fn into_io(self) -> Self::Io {
81 self.io
82 }
83}
84
#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Parks the current task until early data has been dealt with.
    ///
    /// Stores the task's waker in `early_waker` so it can be woken once the handshake has
    /// completed (see `poll_handle_early_data`). An already stored waker is kept when it
    /// would wake the same task, which avoids cloning on every repeated poll.
    fn poll_early_data(&mut self, cx: &mut Context<'_>) {
        let current = cx.waker();
        let already_registered = match &self.early_waker {
            Some(stored) => current.will_wake(stored),
            None => false,
        };

        if !already_registered {
            self.early_waker = Some(current.clone());
        }
    }
}
107
108impl<IO> AsyncRead for TlsStream<IO>
109where
110 IO: AsyncRead + AsyncWrite + Unpin,
111{
112 fn poll_read(
113 mut self: Pin<&mut Self>,
114 cx: &mut Context<'_>,
115 buf: &mut ReadBuf<'_>,
116 ) -> Poll<io::Result<()>> {
117 let data = ready!(self.as_mut().poll_fill_buf(cx))?;
118 let len = data.len().min(buf.remaining());
119 buf.put_slice(&data[..len]);
120 self.consume(len);
121 Poll::Ready(Ok(()))
122 }
123}
124
125impl<IO> AsyncBufRead for TlsStream<IO>
126where
127 IO: AsyncRead + AsyncWrite + Unpin,
128{
129 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
130 match self.state {
131 #[cfg(feature = "early-data")]
132 TlsState::EarlyData(..) => {
133 self.get_mut().poll_early_data(cx);
134 Poll::Pending
135 }
136 TlsState::Stream | TlsState::WriteShutdown => {
137 let this = self.get_mut();
138 let stream =
139 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
140
141 match stream.poll_fill_buf(cx) {
142 Poll::Ready(Ok(buf)) => {
143 if buf.is_empty() {
144 this.state.shutdown_read();
145 }
146
147 Poll::Ready(Ok(buf))
148 }
149 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
150 this.state.shutdown_read();
151 Poll::Ready(Err(err))
152 }
153 output => output,
154 }
155 }
156 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
157 }
158 }
159
160 fn consume(mut self: Pin<&mut Self>, amt: usize) {
161 self.session.reader().consume(amt);
162 }
163}
164
165impl<IO> AsyncWrite for TlsStream<IO>
166where
167 IO: AsyncRead + AsyncWrite + Unpin,
168{
169 fn poll_write(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 buf: &[u8],
175 ) -> Poll<io::Result<usize>> {
176 let this = self.get_mut();
177 let mut stream =
178 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
179
180 #[cfg(feature = "early-data")]
181 {
182 let bufs = [io::IoSlice::new(buf)];
183 let written = ready!(poll_handle_early_data(
184 &mut this.state,
185 &mut stream,
186 &mut this.early_waker,
187 cx,
188 &bufs
189 ))?;
190 if written != 0 {
191 return Poll::Ready(Ok(written));
192 }
193 }
194
195 stream.as_mut_pin().poll_write(cx, buf)
196 }
197
198 fn poll_write_vectored(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 bufs: &[io::IoSlice<'_>],
204 ) -> Poll<io::Result<usize>> {
205 let this = self.get_mut();
206 let mut stream =
207 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
208
209 #[cfg(feature = "early-data")]
210 {
211 let written = ready!(poll_handle_early_data(
212 &mut this.state,
213 &mut stream,
214 &mut this.early_waker,
215 cx,
216 bufs
217 ))?;
218 if written != 0 {
219 return Poll::Ready(Ok(written));
220 }
221 }
222
223 stream.as_mut_pin().poll_write_vectored(cx, bufs)
224 }
225
226 #[inline]
227 fn is_write_vectored(&self) -> bool {
228 true
229 }
230
231 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
232 let this = self.get_mut();
233 let mut stream =
234 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
235
236 #[cfg(feature = "early-data")]
237 ready!(poll_handle_early_data(
238 &mut this.state,
239 &mut stream,
240 &mut this.early_waker,
241 cx,
242 &[]
243 ))?;
244
245 stream.as_mut_pin().poll_flush(cx)
246 }
247
248 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
249 #[cfg(feature = "early-data")]
250 {
251 if matches!(self.state, TlsState::EarlyData(..)) {
253 ready!(self.as_mut().poll_flush(cx))?;
254 }
255 }
256
257 if self.state.writeable() {
258 self.session.send_close_notify();
259 self.state.shutdown_write();
260 }
261
262 let this = self.get_mut();
263 let mut stream =
264 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
265 stream.as_mut_pin().poll_shutdown(cx)
266 }
267}
268
/// Drives a stream that may still be in the early-data (0-RTT) state.
///
/// Does nothing and returns `Ok(0)` unless `state` is `TlsState::EarlyData`. Otherwise:
///
/// 1. If `session.early_data()` yields a writer, as much of `bufs` as it accepts is written
///    as early data and a copy is appended to the state's buffer. A non-zero count is
///    returned immediately.
/// 2. Otherwise the handshake is driven to completion.
/// 3. If the session reports that early data was not accepted, the buffered copy is re-sent
///    as ordinary data. The saved position lets a `Pending` resume on the next poll.
/// 4. The state becomes `TlsState::Stream` and any reader parked in `early_waker` is woken.
///
/// Returning `Ok(0)` tells the caller to continue with its regular write/flush path.
///
/// # Errors
///
/// Propagates errors from the early-data writer, the handshake and the fallback write.
/// A fallback write that makes no progress yields `io::ErrorKind::WriteZero` instead of
/// looping forever.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // Offer the caller's bytes as early data while the session still allows it; keep a
        // copy so they can be replayed if the peer does not accept early data.
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs {
                if buf.is_empty() {
                    continue;
                }

                let len = match early_data.write(buf) {
                    // Nothing accepted (presumably the early-data limit is used up).
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                data.extend_from_slice(&buf[..len]);

                // Partial write: stop so later slices are not sent ahead of this one's tail.
                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // No more early data can be sent now; finish the handshake.
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // Replay early data the peer did not accept as ordinary application data.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                if len == 0 {
                    // No progress would otherwise spin this loop forever.
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
                *pos += len;
            }
        }

        *state = TlsState::Stream;

        // Wake a reader that was parked while early data was pending.
        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}