lunatic_distributed/quic/quin.rs

1use std::{net::SocketAddr, sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Result};
4use bytes::Bytes;
5use lunatic_process::{env::Environment, state::ProcessState};
6use quinn::{ClientConfig, Connecting, ConnectionError, Endpoint, ServerConfig};
7use rustls::server::AllowAnyAuthenticatedClient;
8use rustls_pemfile::Item;
9use wasmtime::ResourceLimiter;
10
11use crate::{distributed, DistributedCtx};
12
13pub struct SendStream {
14    pub stream: quinn::SendStream,
15}
16
17impl SendStream {
18    pub async fn send(&mut self, data: &mut [Bytes]) -> Result<()> {
19        self.stream.write_all_chunks(data).await?;
20        Ok(())
21    }
22}
23
24pub struct RecvStream {
25    pub stream: quinn::RecvStream,
26}
27
28impl RecvStream {
29    pub async fn receive(&mut self) -> Result<Bytes> {
30        let mut size = [0u8; 4];
31        self.stream.read_exact(&mut size).await?;
32        let size = u32::from_le_bytes(size);
33        let mut buffer = vec![0u8; size as usize];
34        self.stream.read_exact(&mut buffer).await?;
35        Ok(buffer.into())
36    }
37
38    pub fn id(&self) -> quinn::StreamId {
39        self.stream.id()
40    }
41}
42
43#[derive(Clone)]
44pub struct Client {
45    inner: Endpoint,
46}
47
48impl Client {
49    pub async fn connect(
50        &self,
51        addr: SocketAddr,
52        name: &str,
53        retry: u32,
54    ) -> Result<(SendStream, RecvStream)> {
55        for try_num in 1..(retry + 1) {
56            match self.connect_once(addr, name).await {
57                Ok(r) => return Ok(r),
58                Err(e) => {
59                    log::error!("Error connecting to {name} at {addr}, try {try_num}. Error: {e}")
60                }
61            }
62            tokio::time::sleep(Duration::from_secs(2)).await;
63        }
64        Err(anyhow!("Failed to connect to {name} at {addr}"))
65    }
66
67    async fn connect_once(&self, addr: SocketAddr, name: &str) -> Result<(SendStream, RecvStream)> {
68        let conn = self.inner.connect(addr, name)?.await?;
69        let (send, recv) = conn.open_bi().await?;
70        Ok((SendStream { stream: send }, RecvStream { stream: recv }))
71    }
72}
73
74pub fn new_quic_client(ca_cert: &str, cert: &str, key: &str) -> Result<Client> {
75    let mut ca_cert = ca_cert.as_bytes();
76    let ca_cert = rustls_pemfile::read_one(&mut ca_cert)?.unwrap();
77    let ca_cert = match ca_cert {
78        Item::X509Certificate(ca_cert) => Ok(rustls::Certificate(ca_cert)),
79        _ => Err(anyhow!("Not a valid certificate.")),
80    }?;
81    let mut roots = rustls::RootCertStore::empty();
82    roots.add(&ca_cert)?;
83
84    let mut cert = cert.as_bytes();
85    let mut key = key.as_bytes();
86    let pk = rustls_pemfile::read_one(&mut key)?.unwrap();
87    let pk = match pk {
88        Item::PKCS8Key(key) => Ok(rustls::PrivateKey(key)),
89        _ => Err(anyhow!("Not a valid private key.")),
90    }?;
91    let cert = rustls_pemfile::read_one(&mut cert)?.unwrap();
92    let cert = match cert {
93        Item::X509Certificate(cert) => Ok(rustls::Certificate(cert)),
94        _ => Err(anyhow!("Not a valid certificate")),
95    }?;
96    let cert = vec![cert];
97
98    let client_crypto = rustls::ClientConfig::builder()
99        .with_safe_defaults()
100        .with_root_certificates(roots)
101        .with_single_cert(cert, pk)?;
102
103    let client_config = ClientConfig::new(Arc::new(client_crypto));
104    let mut endpoint = Endpoint::client("[::]:0".parse().unwrap())?;
105    endpoint.set_default_client_config(client_config);
106    Ok(Client { inner: endpoint })
107}
108
109pub fn new_quic_server(
110    addr: SocketAddr,
111    certs: Vec<String>,
112    key: &str,
113    ca_cert: &str,
114) -> Result<Endpoint> {
115    let mut ca_cert = ca_cert.as_bytes();
116    let ca_cert = rustls_pemfile::read_one(&mut ca_cert)?.unwrap();
117    let ca_cert = match ca_cert {
118        Item::X509Certificate(ca_cert) => Ok(rustls::Certificate(ca_cert)),
119        _ => Err(anyhow!("Not a valid certificate.")),
120    }?;
121    let mut roots = rustls::RootCertStore::empty();
122    roots.add(&ca_cert)?;
123
124    let mut key = key.as_bytes();
125    let pk = rustls_pemfile::read_one(&mut key)?.unwrap();
126    let pk = match pk {
127        Item::PKCS8Key(key) => Ok(rustls::PrivateKey(key)),
128        _ => Err(anyhow!("Not a valid private key.")),
129    }?;
130
131    let mut cert_chain = Vec::new();
132    for cert in certs {
133        let mut cert = cert.as_bytes();
134        let cert = rustls_pemfile::read_one(&mut cert)?.unwrap();
135        let cert = match cert {
136            Item::X509Certificate(cert) => Ok(rustls::Certificate(cert)),
137            _ => Err(anyhow!("Not a valid certificate")),
138        }?;
139        cert_chain.push(cert);
140    }
141
142    let server_crypto = rustls::ServerConfig::builder()
143        .with_safe_defaults()
144        .with_client_cert_verifier(AllowAnyAuthenticatedClient::new(roots))
145        .with_single_cert(cert_chain, pk)?;
146    let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
147    Arc::get_mut(&mut server_config.transport)
148        .unwrap()
149        .max_concurrent_uni_streams(0_u8.into());
150
151    Ok(quinn::Endpoint::server(server_config, addr)?)
152}
153
154pub async fn handle_node_server<T, E>(
155    quic_server: &mut Endpoint,
156    ctx: distributed::server::ServerCtx<T, E>,
157) -> Result<()>
158where
159    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
160    E: Environment + 'static,
161{
162    while let Some(conn) = quic_server.accept().await {
163        tokio::spawn(handle_quic_connection_node(ctx.clone(), conn));
164    }
165    Err(anyhow!("Node server exited"))
166}
167
168async fn handle_quic_connection_node<T, E>(
169    ctx: distributed::server::ServerCtx<T, E>,
170    conn: Connecting,
171) -> Result<()>
172where
173    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
174    E: Environment + 'static,
175{
176    log::info!("New node connection");
177    let conn = conn.await?;
178    log::info!("Remote {} connected", conn.remote_address());
179    loop {
180        if let Some(reason) = conn.close_reason() {
181            log::info!("Connection {} is closed: {reason}", conn.remote_address());
182            break;
183        }
184        let stream = conn.accept_bi().await;
185        log::info!("Stream from remote {} accepted", conn.remote_address());
186        match stream {
187            Ok((s, r)) => {
188                let send = SendStream { stream: s };
189                let recv = RecvStream { stream: r };
190                tokio::spawn(handle_quic_stream_node(ctx.clone(), send, recv));
191            }
192            Err(ConnectionError::LocallyClosed) => break,
193            Err(_) => {}
194        }
195    }
196    log::info!("Connection from remote {} closed", conn.remote_address());
197    Ok(())
198}
199
200async fn handle_quic_stream_node<T, E>(
201    ctx: distributed::server::ServerCtx<T, E>,
202    mut send: SendStream,
203    mut recv: RecvStream,
204) where
205    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
206    E: Environment + 'static,
207{
208    while let Ok(bytes) = recv.receive().await {
209        if let Ok((msg_id, request)) =
210            rmp_serde::from_slice::<(u64, distributed::message::Request)>(&bytes)
211        {
212            distributed::server::handle_message(ctx.clone(), &mut send, msg_id, request).await;
213        } else {
214            log::debug!("Error deserializing request");
215        }
216    }
217}