// hyper_tls/stream.rs

1use std::fmt;
2use std::io;
3use std::io::IoSlice;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use hyper::rt::{Read, ReadBufCursor, Write};
8
9use hyper_util::{
10    client::legacy::connect::{Connected, Connection},
11    rt::TokioIo,
12};
13pub use tokio_native_tls::TlsStream;
14
15/// A stream that might be protected with TLS.
16pub enum MaybeHttpsStream<T> {
17    /// A stream over plain text.
18    Http(T),
19    /// A stream protected with TLS.
20    Https(TokioIo<TlsStream<TokioIo<T>>>),
21}
22
23// ===== impl MaybeHttpsStream =====
24
25impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
26    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27        match self {
28            MaybeHttpsStream::Http(s) => f.debug_tuple("Http").field(s).finish(),
29            MaybeHttpsStream::Https(s) => f.debug_tuple("Https").field(s).finish(),
30        }
31    }
32}
33
34impl<T> From<T> for MaybeHttpsStream<T> {
35    fn from(inner: T) -> Self {
36        MaybeHttpsStream::Http(inner)
37    }
38}
39
40impl<T> From<TlsStream<TokioIo<T>>> for MaybeHttpsStream<T> {
41    fn from(inner: TlsStream<TokioIo<T>>) -> Self {
42        MaybeHttpsStream::Https(TokioIo::new(inner))
43    }
44}
45
46impl<T> From<TokioIo<TlsStream<TokioIo<T>>>> for MaybeHttpsStream<T> {
47    fn from(inner: TokioIo<TlsStream<TokioIo<T>>>) -> Self {
48        MaybeHttpsStream::Https(inner)
49    }
50}
51
52impl<T: Read + Write + Unpin> Read for MaybeHttpsStream<T> {
53    #[inline]
54    fn poll_read(
55        self: Pin<&mut Self>,
56        cx: &mut Context,
57        buf: ReadBufCursor<'_>,
58    ) -> Poll<Result<(), io::Error>> {
59        match Pin::get_mut(self) {
60            MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf),
61            MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf),
62        }
63    }
64}
65
66impl<T: Write + Read + Unpin> Write for MaybeHttpsStream<T> {
67    #[inline]
68    fn poll_write(
69        self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &[u8],
72    ) -> Poll<Result<usize, io::Error>> {
73        match Pin::get_mut(self) {
74            MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf),
75            MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf),
76        }
77    }
78
79    fn poll_write_vectored(
80        self: Pin<&mut Self>,
81        cx: &mut Context<'_>,
82        bufs: &[IoSlice<'_>],
83    ) -> Poll<Result<usize, io::Error>> {
84        match Pin::get_mut(self) {
85            MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs),
86            MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs),
87        }
88    }
89
90    fn is_write_vectored(&self) -> bool {
91        match self {
92            MaybeHttpsStream::Http(s) => s.is_write_vectored(),
93            MaybeHttpsStream::Https(s) => s.is_write_vectored(),
94        }
95    }
96
97    #[inline]
98    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
99        match Pin::get_mut(self) {
100            MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx),
101            MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx),
102        }
103    }
104
105    #[inline]
106    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
107        match Pin::get_mut(self) {
108            MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx),
109            MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx),
110        }
111    }
112}
113
114impl<T: Write + Read + Connection + Unpin> Connection for MaybeHttpsStream<T> {
115    fn connected(&self) -> Connected {
116        match self {
117            MaybeHttpsStream::Http(s) => s.connected(),
118            MaybeHttpsStream::Https(s) => {
119                let c = s.inner().get_ref().get_ref().get_ref().inner().connected();
120                #[cfg(feature = "alpn")]
121                {
122                    if negotiated_h2(s.inner().get_ref()) {
123                        return c.negotiated_h2();
124                    }
125                }
126                c
127            }
128        }
129    }
130}
131
/// Returns `true` when ALPN on `s` selected the `h2` protocol.
///
/// An error from `negotiated_alpn` is treated as "not HTTP/2", and so is the
/// absence of any negotiated protocol. This is a best-effort check.
#[cfg(feature = "alpn")]
fn negotiated_h2<T: std::io::Read + std::io::Write>(s: &native_tls::TlsStream<T>) -> bool {
    matches!(s.negotiated_alpn(), Ok(Some(protocol)) if protocol == b"h2")
}