1use std::fmt;
2use std::io;
3use std::io::IoSlice;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use hyper::rt::{Read, ReadBufCursor, Write};
8
9use hyper_util::{
10 client::legacy::connect::{Connected, Connection},
11 rt::TokioIo,
12};
13pub use tokio_native_tls::TlsStream;
14
15pub enum MaybeHttpsStream<T> {
17 Http(T),
19 Https(TokioIo<TlsStream<TokioIo<T>>>),
21}
22
23impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
26 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27 match self {
28 MaybeHttpsStream::Http(s) => f.debug_tuple("Http").field(s).finish(),
29 MaybeHttpsStream::Https(s) => f.debug_tuple("Https").field(s).finish(),
30 }
31 }
32}
33
34impl<T> From<T> for MaybeHttpsStream<T> {
35 fn from(inner: T) -> Self {
36 MaybeHttpsStream::Http(inner)
37 }
38}
39
40impl<T> From<TlsStream<TokioIo<T>>> for MaybeHttpsStream<T> {
41 fn from(inner: TlsStream<TokioIo<T>>) -> Self {
42 MaybeHttpsStream::Https(TokioIo::new(inner))
43 }
44}
45
46impl<T> From<TokioIo<TlsStream<TokioIo<T>>>> for MaybeHttpsStream<T> {
47 fn from(inner: TokioIo<TlsStream<TokioIo<T>>>) -> Self {
48 MaybeHttpsStream::Https(inner)
49 }
50}
51
52impl<T: Read + Write + Unpin> Read for MaybeHttpsStream<T> {
53 #[inline]
54 fn poll_read(
55 self: Pin<&mut Self>,
56 cx: &mut Context,
57 buf: ReadBufCursor<'_>,
58 ) -> Poll<Result<(), io::Error>> {
59 match Pin::get_mut(self) {
60 MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf),
61 MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf),
62 }
63 }
64}
65
66impl<T: Write + Read + Unpin> Write for MaybeHttpsStream<T> {
67 #[inline]
68 fn poll_write(
69 self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &[u8],
72 ) -> Poll<Result<usize, io::Error>> {
73 match Pin::get_mut(self) {
74 MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf),
75 MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf),
76 }
77 }
78
79 fn poll_write_vectored(
80 self: Pin<&mut Self>,
81 cx: &mut Context<'_>,
82 bufs: &[IoSlice<'_>],
83 ) -> Poll<Result<usize, io::Error>> {
84 match Pin::get_mut(self) {
85 MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs),
86 MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs),
87 }
88 }
89
90 fn is_write_vectored(&self) -> bool {
91 match self {
92 MaybeHttpsStream::Http(s) => s.is_write_vectored(),
93 MaybeHttpsStream::Https(s) => s.is_write_vectored(),
94 }
95 }
96
97 #[inline]
98 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
99 match Pin::get_mut(self) {
100 MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx),
101 MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx),
102 }
103 }
104
105 #[inline]
106 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
107 match Pin::get_mut(self) {
108 MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx),
109 MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx),
110 }
111 }
112}
113
114impl<T: Write + Read + Connection + Unpin> Connection for MaybeHttpsStream<T> {
115 fn connected(&self) -> Connected {
116 match self {
117 MaybeHttpsStream::Http(s) => s.connected(),
118 MaybeHttpsStream::Https(s) => {
119 let c = s.inner().get_ref().get_ref().get_ref().inner().connected();
120 #[cfg(feature = "alpn")]
121 {
122 if negotiated_h2(s.inner().get_ref()) {
123 return c.negotiated_h2();
124 }
125 }
126 c
127 }
128 }
129 }
130}
131
/// Returns `true` only when the stream reports a negotiated ALPN protocol
/// equal to the bytes `h2`.
///
/// An error from `negotiated_alpn` and the absence of a negotiated protocol
/// are both treated as "not h2" and yield `false`.
#[cfg(feature = "alpn")]
fn negotiated_h2<T: std::io::Read + std::io::Write>(s: &native_tls::TlsStream<T>) -> bool {
    matches!(s.negotiated_alpn(), Ok(Some(proto)) if proto == &b"h2"[..])
}