// tokio_tungstenite/stream.rs

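//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! TLS support is optional: each TLS variant of [`MaybeTlsStream`] is compiled in only when its
//! crate feature is enabled (`native-tls` for `tokio-native-tls`, `__rustls-tls` for
//! `tokio-rustls`), so with no TLS feature enabled this module references no TLS implementation.
//! [`MaybeTlsStream`] implements tokio's `AsyncRead` and `AsyncWrite` traits by delegating to
//! whichever stream it wraps.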
6use std::{
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
/// A byte stream that is either used in the clear or wrapped in a TLS session.
///
/// Which TLS variants exist depends on the enabled crate features. The enum is
/// `#[non_exhaustive]`, so `match` expressions outside this crate need a wildcard arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum MaybeTlsStream<S> {
    /// The raw stream `S`, with no encryption applied.
    Plain(S),
    /// `S` wrapped in a `tokio-native-tls` TLS stream (only with the `native-tls` feature).
    #[cfg(feature = "native-tls")]
    NativeTls(tokio_native_tls::TlsStream<S>),
    /// `S` wrapped in a `tokio-rustls` client TLS stream (only with the `__rustls-tls` feature).
    #[cfg(feature = "__rustls-tls")]
    Rustls(tokio_rustls::client::TlsStream<S>),
}
26
27impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
28    fn poll_read(
29        self: Pin<&mut Self>,
30        cx: &mut Context<'_>,
31        buf: &mut ReadBuf<'_>,
32    ) -> Poll<std::io::Result<()>> {
33        match self.get_mut() {
34            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
35            #[cfg(feature = "native-tls")]
36            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
37            #[cfg(feature = "__rustls-tls")]
38            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_read(cx, buf),
39        }
40    }
41}
42
43impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
44    fn poll_write(
45        self: Pin<&mut Self>,
46        cx: &mut Context<'_>,
47        buf: &[u8],
48    ) -> Poll<Result<usize, std::io::Error>> {
49        match self.get_mut() {
50            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
51            #[cfg(feature = "native-tls")]
52            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
53            #[cfg(feature = "__rustls-tls")]
54            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_write(cx, buf),
55        }
56    }
57
58    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
59        match self.get_mut() {
60            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
61            #[cfg(feature = "native-tls")]
62            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
63            #[cfg(feature = "__rustls-tls")]
64            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
65        }
66    }
67
68    fn poll_shutdown(
69        self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71    ) -> Poll<Result<(), std::io::Error>> {
72        match self.get_mut() {
73            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
74            #[cfg(feature = "native-tls")]
75            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
76            #[cfg(feature = "__rustls-tls")]
77            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
78        }
79    }
80}