1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
//! Run with `cargo run --all-features --example rustls_session` command.
//!
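//! To connect with a browser, navigate to https://localhost:3000 (expect a security warning,
//! since the example uses the self-signed certificate in `examples/self-signed-certs`).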

use axum::{middleware::AddExtension, routing::get, Extension, Router};
use axum_server::{
    accept::Accept,
    tls_rustls::{RustlsAcceptor, RustlsConfig},
};
use futures_util::future::BoxFuture;
use std::{io, net::SocketAddr, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use tower::Layer;

/// Serves `handler` over HTTPS on `127.0.0.1:3000`, using a [`CustomAcceptor`]
/// so each connection's TLS details are available to handlers.
///
/// # Panics
///
/// Panics if the certificate or key cannot be loaded. The PEM paths are relative
/// to the current working directory, so run the example from the crate root.
/// Also panics if the server returns an error while serving.
#[tokio::main]
async fn main() {
    let app = Router::new().route("/", get(handler));

    // `expect` (rather than `unwrap`) names the files involved, so a wrong working
    // directory produces an actionable message instead of a bare `NotFound`.
    let config = RustlsConfig::from_pem_file(
        "examples/self-signed-certs/cert.pem",
        "examples/self-signed-certs/key.pem",
    )
    .await
    .expect(
        "failed to load examples/self-signed-certs/{cert,key}.pem \
         (run the example from the crate root)",
    );

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

    println!("listening on {}", addr);

    // Wrap the stock rustls acceptor so the per-connection service gets `TlsData`.
    let acceptor = CustomAcceptor::new(RustlsAcceptor::new(config));
    let server = axum_server::bind(addr).acceptor(acceptor);

    server
        .serve(app.into_make_service())
        .await
        .expect("HTTPS server exited with an error");
}

/// Replies with the `Debug` rendering of the TLS data attached to the connection.
///
/// The extractor is formatted as-is (still wrapped in `Extension`), so the body
/// includes the wrapper's own `Debug` output around the `TlsData`.
async fn handler(tls_data: Extension<TlsData>) -> String {
    format!("{tls_data:?}")
}

/// Per-connection TLS details, captured after the handshake and handed to
/// handlers through an axum `Extension`.
#[derive(Clone, Debug)]
struct TlsData {
    /// SNI hostname requested by the client, if it sent one. The leading
    /// underscore silences the dead-code lint: in this example the field is
    /// only ever read through the derived `Debug` impl.
    _hostname: Option<Arc<str>>,
}

/// Acceptor that delegates the TLS handshake to a [`RustlsAcceptor`] and then
/// attaches [`TlsData`] to the per-connection service.
#[derive(Clone, Debug)]
struct CustomAcceptor {
    /// The stock rustls acceptor that performs the actual handshake.
    inner: RustlsAcceptor,
}

impl CustomAcceptor {
    /// Wraps `inner`; no additional configuration is applied.
    fn new(inner: RustlsAcceptor) -> Self {
        CustomAcceptor { inner }
    }
}

impl<I, S> Accept<I, S> for CustomAcceptor
where
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: Send + 'static,
{
    type Stream = TlsStream<I>;
    type Service = AddExtension<S, TlsData>;
    type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;

    fn accept(&self, stream: I, service: S) -> Self::Future {
        let acceptor = self.inner.clone();

        Box::pin(async move {
            let (stream, service) = acceptor.accept(stream, service).await?;
            let server_conn = stream.get_ref().1;
            let sni_hostname = TlsData {
                _hostname: server_conn.server_name().map(From::from),
            };
            let service = Extension(sni_hostname).layer(service);

            Ok((stream, service))
        })
    }
}