lunatic_distributed/quic/
quin.rs1use std::{net::SocketAddr, sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Result};
4use bytes::Bytes;
5use lunatic_process::{env::Environment, state::ProcessState};
6use quinn::{ClientConfig, Connecting, ConnectionError, Endpoint, ServerConfig};
7use rustls::server::AllowAnyAuthenticatedClient;
8use rustls_pemfile::Item;
9use wasmtime::ResourceLimiter;
10
11use crate::{distributed, DistributedCtx};
12
13pub struct SendStream {
14 pub stream: quinn::SendStream,
15}
16
17impl SendStream {
18 pub async fn send(&mut self, data: &mut [Bytes]) -> Result<()> {
19 self.stream.write_all_chunks(data).await?;
20 Ok(())
21 }
22}
23
24pub struct RecvStream {
25 pub stream: quinn::RecvStream,
26}
27
28impl RecvStream {
29 pub async fn receive(&mut self) -> Result<Bytes> {
30 let mut size = [0u8; 4];
31 self.stream.read_exact(&mut size).await?;
32 let size = u32::from_le_bytes(size);
33 let mut buffer = vec![0u8; size as usize];
34 self.stream.read_exact(&mut buffer).await?;
35 Ok(buffer.into())
36 }
37
38 pub fn id(&self) -> quinn::StreamId {
39 self.stream.id()
40 }
41}
42
43#[derive(Clone)]
44pub struct Client {
45 inner: Endpoint,
46}
47
48impl Client {
49 pub async fn connect(
50 &self,
51 addr: SocketAddr,
52 name: &str,
53 retry: u32,
54 ) -> Result<(SendStream, RecvStream)> {
55 for try_num in 1..(retry + 1) {
56 match self.connect_once(addr, name).await {
57 Ok(r) => return Ok(r),
58 Err(e) => {
59 log::error!("Error connecting to {name} at {addr}, try {try_num}. Error: {e}")
60 }
61 }
62 tokio::time::sleep(Duration::from_secs(2)).await;
63 }
64 Err(anyhow!("Failed to connect to {name} at {addr}"))
65 }
66
67 async fn connect_once(&self, addr: SocketAddr, name: &str) -> Result<(SendStream, RecvStream)> {
68 let conn = self.inner.connect(addr, name)?.await?;
69 let (send, recv) = conn.open_bi().await?;
70 Ok((SendStream { stream: send }, RecvStream { stream: recv }))
71 }
72}
73
74pub fn new_quic_client(ca_cert: &str, cert: &str, key: &str) -> Result<Client> {
75 let mut ca_cert = ca_cert.as_bytes();
76 let ca_cert = rustls_pemfile::read_one(&mut ca_cert)?.unwrap();
77 let ca_cert = match ca_cert {
78 Item::X509Certificate(ca_cert) => Ok(rustls::Certificate(ca_cert)),
79 _ => Err(anyhow!("Not a valid certificate.")),
80 }?;
81 let mut roots = rustls::RootCertStore::empty();
82 roots.add(&ca_cert)?;
83
84 let mut cert = cert.as_bytes();
85 let mut key = key.as_bytes();
86 let pk = rustls_pemfile::read_one(&mut key)?.unwrap();
87 let pk = match pk {
88 Item::PKCS8Key(key) => Ok(rustls::PrivateKey(key)),
89 _ => Err(anyhow!("Not a valid private key.")),
90 }?;
91 let cert = rustls_pemfile::read_one(&mut cert)?.unwrap();
92 let cert = match cert {
93 Item::X509Certificate(cert) => Ok(rustls::Certificate(cert)),
94 _ => Err(anyhow!("Not a valid certificate")),
95 }?;
96 let cert = vec![cert];
97
98 let client_crypto = rustls::ClientConfig::builder()
99 .with_safe_defaults()
100 .with_root_certificates(roots)
101 .with_single_cert(cert, pk)?;
102
103 let client_config = ClientConfig::new(Arc::new(client_crypto));
104 let mut endpoint = Endpoint::client("[::]:0".parse().unwrap())?;
105 endpoint.set_default_client_config(client_config);
106 Ok(Client { inner: endpoint })
107}
108
109pub fn new_quic_server(
110 addr: SocketAddr,
111 certs: Vec<String>,
112 key: &str,
113 ca_cert: &str,
114) -> Result<Endpoint> {
115 let mut ca_cert = ca_cert.as_bytes();
116 let ca_cert = rustls_pemfile::read_one(&mut ca_cert)?.unwrap();
117 let ca_cert = match ca_cert {
118 Item::X509Certificate(ca_cert) => Ok(rustls::Certificate(ca_cert)),
119 _ => Err(anyhow!("Not a valid certificate.")),
120 }?;
121 let mut roots = rustls::RootCertStore::empty();
122 roots.add(&ca_cert)?;
123
124 let mut key = key.as_bytes();
125 let pk = rustls_pemfile::read_one(&mut key)?.unwrap();
126 let pk = match pk {
127 Item::PKCS8Key(key) => Ok(rustls::PrivateKey(key)),
128 _ => Err(anyhow!("Not a valid private key.")),
129 }?;
130
131 let mut cert_chain = Vec::new();
132 for cert in certs {
133 let mut cert = cert.as_bytes();
134 let cert = rustls_pemfile::read_one(&mut cert)?.unwrap();
135 let cert = match cert {
136 Item::X509Certificate(cert) => Ok(rustls::Certificate(cert)),
137 _ => Err(anyhow!("Not a valid certificate")),
138 }?;
139 cert_chain.push(cert);
140 }
141
142 let server_crypto = rustls::ServerConfig::builder()
143 .with_safe_defaults()
144 .with_client_cert_verifier(AllowAnyAuthenticatedClient::new(roots))
145 .with_single_cert(cert_chain, pk)?;
146 let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
147 Arc::get_mut(&mut server_config.transport)
148 .unwrap()
149 .max_concurrent_uni_streams(0_u8.into());
150
151 Ok(quinn::Endpoint::server(server_config, addr)?)
152}
153
154pub async fn handle_node_server<T, E>(
155 quic_server: &mut Endpoint,
156 ctx: distributed::server::ServerCtx<T, E>,
157) -> Result<()>
158where
159 T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
160 E: Environment + 'static,
161{
162 while let Some(conn) = quic_server.accept().await {
163 tokio::spawn(handle_quic_connection_node(ctx.clone(), conn));
164 }
165 Err(anyhow!("Node server exited"))
166}
167
168async fn handle_quic_connection_node<T, E>(
169 ctx: distributed::server::ServerCtx<T, E>,
170 conn: Connecting,
171) -> Result<()>
172where
173 T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
174 E: Environment + 'static,
175{
176 log::info!("New node connection");
177 let conn = conn.await?;
178 log::info!("Remote {} connected", conn.remote_address());
179 loop {
180 if let Some(reason) = conn.close_reason() {
181 log::info!("Connection {} is closed: {reason}", conn.remote_address());
182 break;
183 }
184 let stream = conn.accept_bi().await;
185 log::info!("Stream from remote {} accepted", conn.remote_address());
186 match stream {
187 Ok((s, r)) => {
188 let send = SendStream { stream: s };
189 let recv = RecvStream { stream: r };
190 tokio::spawn(handle_quic_stream_node(ctx.clone(), send, recv));
191 }
192 Err(ConnectionError::LocallyClosed) => break,
193 Err(_) => {}
194 }
195 }
196 log::info!("Connection from remote {} closed", conn.remote_address());
197 Ok(())
198}
199
200async fn handle_quic_stream_node<T, E>(
201 ctx: distributed::server::ServerCtx<T, E>,
202 mut send: SendStream,
203 mut recv: RecvStream,
204) where
205 T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
206 E: Environment + 'static,
207{
208 while let Ok(bytes) = recv.receive().await {
209 if let Ok((msg_id, request)) =
210 rmp_serde::from_slice::<(u64, distributed::message::Request)>(&bytes)
211 {
212 distributed::server::handle_message(ctx.clone(), &mut send, msg_id, request).await;
213 } else {
214 log::debug!("Error deserializing request");
215 }
216 }
217}