1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright 2015-2022 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! HTTP/3 related server items

use std::{io, net::SocketAddr, sync::Arc};

use bytes::Bytes;
use h3::server::{Connection, RequestStream};
use h3_quinn::{BidiStream, Endpoint};
use http::Request;
use quinn::crypto::rustls::QuicServerConfig;
use quinn::{EndpointConfig, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::ServerConfig as TlsServerConfig;
use rustls::version::TLS13;

use crate::{error::ProtoError, udp::UdpSocket};

use super::ALPN_H3;

/// A DNS-over-HTTP/3 Server, see `H3ClientStream` for the client counterpart.
///
/// Wraps a QUIC endpoint whose TLS configuration only allows TLS 1.3 and advertises the
/// HTTP/3 ALPN identifier.
pub struct H3Server {
    endpoint: Endpoint,
}

impl H3Server {
    /// Bind a new UDP socket on `name_server` and construct a server on it.
    ///
    /// # Arguments
    ///
    /// * `name_server` - local address the server socket is bound to
    /// * `cert` - DER-encoded certificate chain presented to clients
    /// * `key` - DER-encoded private key for the certificate chain
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be bound, or if [`H3Server::with_socket`] fails.
    pub async fn new(
        name_server: SocketAddr,
        cert: Vec<CertificateDer<'static>>,
        key: PrivateKeyDer<'static>,
    ) -> Result<Self, ProtoError> {
        // setup a new socket for the server to use
        let socket = <tokio::net::UdpSocket as UdpSocket>::bind(name_server).await?;
        Self::with_socket(socket, cert, key)
    }

    /// Construct the new server with an existing socket.
    ///
    /// The TLS configuration uses the `ring` crypto provider, allows only TLS 1.3,
    /// does not request client certificates, and advertises only the HTTP/3 ALPN identifier.
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) if:
    /// * the crypto provider does not support TLS 1.3,
    /// * the certificate chain / private key is rejected by rustls,
    /// * the resulting rustls configuration cannot be used for QUIC,
    /// * the socket cannot be converted to a std socket, or the QUIC endpoint cannot be created.
    pub fn with_socket(
        socket: tokio::net::UdpSocket,
        cert: Vec<CertificateDer<'static>>,
        key: PrivateKeyDer<'static>,
    ) -> Result<Self, ProtoError> {
        // Both builder steps return the same rustls error type, so both propagate with `?`.
        let mut config = TlsServerConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_protocol_versions(&[&TLS13])?
        .with_no_client_auth()
        .with_single_cert(cert, key)?;

        config.alpn_protocols = vec![ALPN_H3.to_vec()];

        // The conversion fails when the rustls config is unsuitable for QUIC; report that to
        // the caller rather than panicking inside library code.
        let quic_config = QuicServerConfig::try_from(config)
            .map_err(|e| ProtoError::from(format!("invalid QUIC server config: {e}")))?;

        let mut server_config = ServerConfig::with_crypto(Arc::new(quic_config));
        server_config.transport = Arc::new(super::transport());

        let socket = socket.into_std()?;

        let endpoint = Endpoint::new(
            EndpointConfig::default(),
            Some(server_config),
            socket,
            Arc::new(quinn::TokioRuntime),
        )?;

        Ok(Self { endpoint })
    }

    /// Accept the next incoming connection.
    ///
    /// The QUIC handshake and the HTTP/3 connection setup are awaited inside this call before
    /// it returns.
    ///
    /// # Returns
    ///
    /// A remote connection that could accept many potential requests, together with the remote
    /// socket address. Returns `Ok(None)` when the endpoint yields no further incoming
    /// connections.
    ///
    /// # Errors
    ///
    /// Returns an error if the QUIC handshake or the HTTP/3 connection setup for this
    /// incoming connection fails.
    pub async fn accept(&mut self) -> Result<Option<(H3Connection, SocketAddr)>, ProtoError> {
        let connecting = match self.endpoint.accept().await {
            Some(conn) => conn,
            None => return Ok(None),
        };

        let remote_addr = connecting.remote_address();
        let connection = connecting.await?;
        Ok(Some((
            H3Connection {
                connection: Connection::new(h3_quinn::Connection::new(connection))
                    .await
                    .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?,
            },
            remote_addr,
        )))
    }

    /// Returns the address this server is listening on.
    ///
    /// This can be useful in tests, where a random port can be associated with the server by
    /// binding on `127.0.0.1:0` and then getting the associated port address with this function.
    pub fn local_addr(&self) -> Result<SocketAddr, io::Error> {
        self.endpoint.local_addr()
    }
}

/// An HTTP/3 connection accepted by [`H3Server`], from which individual requests are read.
pub struct H3Connection {
    /// Server-side h3 connection layered over the h3_quinn QUIC connection.
    connection: Connection<h3_quinn::Connection, Bytes>,
}

impl H3Connection {
    /// Wait for the client's next request.
    ///
    /// Yields `None` when h3 reports no further request (`Ok(None)`). Otherwise yields the request
    /// headers paired with the bidirectional stream for that request. An h3 error is converted
    /// into a `ProtoError` and yielded as `Some(Err(..))`.
    pub async fn accept(
        &mut self,
    ) -> Option<Result<(Request<()>, RequestStream<BidiStream<Bytes>, Bytes>), ProtoError>> {
        // Result<Option<T>, E> -> Option<Result<T, E>>
        self.connection
            .accept()
            .await
            .map_err(|err| ProtoError::from(format!("h3 request failed: {err}")))
            .transpose()
    }

    /// Shut the connection down by calling h3's `shutdown` with `0`.
    ///
    /// NOTE(review): `0` is presumably the number of further in-flight requests still
    /// allowed. Confirm this against the h3 documentation.
    pub async fn shutdown(&mut self) -> Result<(), ProtoError> {
        let outcome = self.connection.shutdown(0).await;
        outcome.map_err(|err| ProtoError::from(format!("h3 connection shutdown failed: {err}")))
    }
}