1use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50
51use rustls::pki_types::ServerName;
52use rustls::server::AcceptedAlert;
53use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
55
/// Extracts the value from a ready `Poll`, or returns `Poll::Pending` from the
/// enclosing function when the inner poll is not ready yet.
///
/// Behaves like `std::task::ready!`; it is defined here before the `mod`
/// declarations so that submodules can use it by textual scoping.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
64
65pub mod client;
66mod common;
67use common::{MidHandshake, TlsState};
68pub mod server;
69
70#[derive(Clone)]
72pub struct TlsConnector {
73 inner: Arc<ClientConfig>,
74 #[cfg(feature = "early-data")]
75 early_data: bool,
76}
77
78#[derive(Clone)]
80pub struct TlsAcceptor {
81 inner: Arc<ServerConfig>,
82}
83
84impl From<Arc<ClientConfig>> for TlsConnector {
85 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
86 TlsConnector {
87 inner,
88 #[cfg(feature = "early-data")]
89 early_data: false,
90 }
91 }
92}
93
94impl From<Arc<ServerConfig>> for TlsAcceptor {
95 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
96 TlsAcceptor { inner }
97 }
98}
99
100impl TlsConnector {
101 #[cfg(feature = "early-data")]
106 pub fn early_data(mut self, flag: bool) -> TlsConnector {
107 self.early_data = flag;
108 self
109 }
110
111 #[inline]
112 pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
113 where
114 IO: AsyncRead + AsyncWrite + Unpin,
115 {
116 self.connect_with(domain, stream, |_| ())
117 }
118
119 pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
120 where
121 IO: AsyncRead + AsyncWrite + Unpin,
122 F: FnOnce(&mut ClientConnection),
123 {
124 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
125 Ok(session) => session,
126 Err(error) => {
127 return Connect(MidHandshake::Error {
128 io: stream,
129 error: io::Error::new(io::ErrorKind::Other, error),
132 });
133 }
134 };
135 f(&mut session);
136
137 Connect(MidHandshake::Handshaking(client::TlsStream {
138 io: stream,
139
140 #[cfg(not(feature = "early-data"))]
141 state: TlsState::Stream,
142
143 #[cfg(feature = "early-data")]
144 state: if self.early_data && session.early_data().is_some() {
145 TlsState::EarlyData(0, Vec::new())
146 } else {
147 TlsState::Stream
148 },
149
150 #[cfg(feature = "early-data")]
151 early_waker: None,
152
153 session,
154 }))
155 }
156
157 pub fn config(&self) -> &Arc<ClientConfig> {
159 &self.inner
160 }
161}
162
163impl TlsAcceptor {
164 #[inline]
165 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
166 where
167 IO: AsyncRead + AsyncWrite + Unpin,
168 {
169 self.accept_with(stream, |_| ())
170 }
171
172 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
173 where
174 IO: AsyncRead + AsyncWrite + Unpin,
175 F: FnOnce(&mut ServerConnection),
176 {
177 let mut session = match ServerConnection::new(self.inner.clone()) {
178 Ok(session) => session,
179 Err(error) => {
180 return Accept(MidHandshake::Error {
181 io: stream,
182 error: io::Error::new(io::ErrorKind::Other, error),
185 });
186 }
187 };
188 f(&mut session);
189
190 Accept(MidHandshake::Handshaking(server::TlsStream {
191 session,
192 io: stream,
193 state: TlsState::Stream,
194 }))
195 }
196
197 pub fn config(&self) -> &Arc<ServerConfig> {
199 &self.inner
200 }
201}
202
203pub struct LazyConfigAcceptor<IO> {
204 acceptor: rustls::server::Acceptor,
205 io: Option<IO>,
206 alert: Option<(rustls::Error, AcceptedAlert)>,
207}
208
209impl<IO> LazyConfigAcceptor<IO>
210where
211 IO: AsyncRead + AsyncWrite + Unpin,
212{
213 #[inline]
214 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
215 Self {
216 acceptor,
217 io: Some(io),
218 alert: None,
219 }
220 }
221
222 pub fn take_io(&mut self) -> Option<IO> {
264 self.io.take()
265 }
266}
267
268impl<IO> Future for LazyConfigAcceptor<IO>
269where
270 IO: AsyncRead + AsyncWrite + Unpin,
271{
272 type Output = Result<StartHandshake<IO>, io::Error>;
273
274 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
275 let this = self.get_mut();
276 loop {
277 let io = match this.io.as_mut() {
278 Some(io) => io,
279 None => {
280 return Poll::Ready(Err(io::Error::new(
281 io::ErrorKind::Other,
282 "acceptor cannot be polled after acceptance",
283 )))
284 }
285 };
286
287 if let Some((err, mut alert)) = this.alert.take() {
288 match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
289 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
290 this.alert = Some((err, alert));
291 return Poll::Pending;
292 }
293 Ok(0) | Err(_) => {
294 return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
295 }
296 Ok(_) => {
297 this.alert = Some((err, alert));
298 continue;
299 }
300 };
301 }
302
303 let mut reader = common::SyncReadAdapter { io, cx };
304 match this.acceptor.read_tls(&mut reader) {
305 Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
306 Ok(_) => {}
307 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
308 Err(e) => return Err(e).into(),
309 }
310
311 match this.acceptor.accept() {
312 Ok(Some(accepted)) => {
313 let io = this.io.take().unwrap();
314 return Poll::Ready(Ok(StartHandshake { accepted, io }));
315 }
316 Ok(None) => {}
317 Err((err, alert)) => {
318 this.alert = Some((err, alert));
319 }
320 }
321 }
322 }
323}
324
325pub struct StartHandshake<IO> {
326 accepted: rustls::server::Accepted,
327 io: IO,
328}
329
330impl<IO> StartHandshake<IO>
331where
332 IO: AsyncRead + AsyncWrite + Unpin,
333{
334 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
335 self.accepted.client_hello()
336 }
337
338 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
339 self.into_stream_with(config, |_| ())
340 }
341
342 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
343 where
344 F: FnOnce(&mut ServerConnection),
345 {
346 let mut conn = match self.accepted.into_connection(config) {
347 Ok(conn) => conn,
348 Err((error, alert)) => {
349 return Accept(MidHandshake::SendAlert {
350 io: self.io,
351 alert,
352 error: io::Error::new(io::ErrorKind::InvalidData, error),
355 });
356 }
357 };
358 f(&mut conn);
359
360 Accept(MidHandshake::Handshaking(server::TlsStream {
361 session: conn,
362 io: self.io,
363 state: TlsState::Stream,
364 }))
365 }
366}
367
368pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
371
372pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
375
376pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
378
379pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
381
382impl<IO> Connect<IO> {
383 #[inline]
384 pub fn into_fallible(self) -> FallibleConnect<IO> {
385 FallibleConnect(self.0)
386 }
387
388 pub fn get_ref(&self) -> Option<&IO> {
389 match &self.0 {
390 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
391 MidHandshake::SendAlert { io, .. } => Some(io),
392 MidHandshake::Error { io, .. } => Some(io),
393 MidHandshake::End => None,
394 }
395 }
396
397 pub fn get_mut(&mut self) -> Option<&mut IO> {
398 match &mut self.0 {
399 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
400 MidHandshake::SendAlert { io, .. } => Some(io),
401 MidHandshake::Error { io, .. } => Some(io),
402 MidHandshake::End => None,
403 }
404 }
405}
406
407impl<IO> Accept<IO> {
408 #[inline]
409 pub fn into_fallible(self) -> FallibleAccept<IO> {
410 FallibleAccept(self.0)
411 }
412
413 pub fn get_ref(&self) -> Option<&IO> {
414 match &self.0 {
415 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
416 MidHandshake::SendAlert { io, .. } => Some(io),
417 MidHandshake::Error { io, .. } => Some(io),
418 MidHandshake::End => None,
419 }
420 }
421
422 pub fn get_mut(&mut self) -> Option<&mut IO> {
423 match &mut self.0 {
424 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
425 MidHandshake::SendAlert { io, .. } => Some(io),
426 MidHandshake::Error { io, .. } => Some(io),
427 MidHandshake::End => None,
428 }
429 }
430}
431
432impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
433 type Output = io::Result<client::TlsStream<IO>>;
434
435 #[inline]
436 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
437 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
438 }
439}
440
441impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
442 type Output = io::Result<server::TlsStream<IO>>;
443
444 #[inline]
445 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
446 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
447 }
448}
449
450impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
451 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
452
453 #[inline]
454 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
455 Pin::new(&mut self.0).poll(cx)
456 }
457}
458
459impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
460 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
461
462 #[inline]
463 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
464 Pin::new(&mut self.0).poll(cx)
465 }
466}
467
468#[allow(clippy::large_enum_variant)] #[derive(Debug)]
474pub enum TlsStream<T> {
475 Client(client::TlsStream<T>),
476 Server(server::TlsStream<T>),
477}
478
479impl<T> TlsStream<T> {
480 pub fn get_ref(&self) -> (&T, &CommonState) {
481 use TlsStream::*;
482 match self {
483 Client(io) => {
484 let (io, session) = io.get_ref();
485 (io, session)
486 }
487 Server(io) => {
488 let (io, session) = io.get_ref();
489 (io, session)
490 }
491 }
492 }
493
494 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
495 use TlsStream::*;
496 match self {
497 Client(io) => {
498 let (io, session) = io.get_mut();
499 (io, &mut *session)
500 }
501 Server(io) => {
502 let (io, session) = io.get_mut();
503 (io, &mut *session)
504 }
505 }
506 }
507}
508
509impl<T> From<client::TlsStream<T>> for TlsStream<T> {
510 fn from(s: client::TlsStream<T>) -> Self {
511 Self::Client(s)
512 }
513}
514
515impl<T> From<server::TlsStream<T>> for TlsStream<T> {
516 fn from(s: server::TlsStream<T>) -> Self {
517 Self::Server(s)
518 }
519}
520
521#[cfg(unix)]
522impl<S> AsRawFd for TlsStream<S>
523where
524 S: AsRawFd,
525{
526 fn as_raw_fd(&self) -> RawFd {
527 self.get_ref().0.as_raw_fd()
528 }
529}
530
/// Exposes the raw socket handle of the wrapped IO object.
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        let (transport, _state) = self.get_ref();
        transport.as_raw_socket()
    }
}
540
541impl<T> AsyncRead for TlsStream<T>
542where
543 T: AsyncRead + AsyncWrite + Unpin,
544{
545 #[inline]
546 fn poll_read(
547 self: Pin<&mut Self>,
548 cx: &mut Context<'_>,
549 buf: &mut ReadBuf<'_>,
550 ) -> Poll<io::Result<()>> {
551 match self.get_mut() {
552 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
553 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
554 }
555 }
556}
557
558impl<T> AsyncBufRead for TlsStream<T>
559where
560 T: AsyncRead + AsyncWrite + Unpin,
561{
562 #[inline]
563 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
564 match self.get_mut() {
565 TlsStream::Client(x) => Pin::new(x).poll_fill_buf(cx),
566 TlsStream::Server(x) => Pin::new(x).poll_fill_buf(cx),
567 }
568 }
569
570 #[inline]
571 fn consume(self: Pin<&mut Self>, amt: usize) {
572 match self.get_mut() {
573 TlsStream::Client(x) => Pin::new(x).consume(amt),
574 TlsStream::Server(x) => Pin::new(x).consume(amt),
575 }
576 }
577}
578
579impl<T> AsyncWrite for TlsStream<T>
580where
581 T: AsyncRead + AsyncWrite + Unpin,
582{
583 #[inline]
584 fn poll_write(
585 self: Pin<&mut Self>,
586 cx: &mut Context<'_>,
587 buf: &[u8],
588 ) -> Poll<io::Result<usize>> {
589 match self.get_mut() {
590 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
591 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
592 }
593 }
594
595 #[inline]
596 fn poll_write_vectored(
597 self: Pin<&mut Self>,
598 cx: &mut Context<'_>,
599 bufs: &[io::IoSlice<'_>],
600 ) -> Poll<io::Result<usize>> {
601 match self.get_mut() {
602 TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
603 TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
604 }
605 }
606
607 #[inline]
608 fn is_write_vectored(&self) -> bool {
609 match self {
610 TlsStream::Client(x) => x.is_write_vectored(),
611 TlsStream::Server(x) => x.is_write_vectored(),
612 }
613 }
614
615 #[inline]
616 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
617 match self.get_mut() {
618 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
619 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
620 }
621 }
622
623 #[inline]
624 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
625 match self.get_mut() {
626 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
627 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
628 }
629 }
630}