lunatic_distributed/distributed/
server.rs

1use std::{net::SocketAddr, sync::Arc};
2
3use anyhow::{anyhow, Result};
4
5use lunatic_process::{
6    env::{Environment, Environments},
7    message::{DataMessage, Message},
8    runtimes::{wasmtime::WasmtimeRuntime, Modules, RawWasm},
9    state::ProcessState,
10    Signal,
11};
12use rcgen::*;
13use wasmtime::ResourceLimiter;
14
15use crate::{
16    distributed::message::{Request, Response},
17    quic::{self, SendStream},
18    DistributedCtx, DistributedProcessState,
19};
20
21use super::message::{ClientError, Spawn};
22
23pub struct ServerCtx<T, E: Environment> {
24    pub envs: Arc<dyn Environments<Env = E>>,
25    pub modules: Modules<T>,
26    pub distributed: DistributedProcessState,
27    pub runtime: WasmtimeRuntime,
28}
29
30impl<T: 'static, E: Environment> Clone for ServerCtx<T, E> {
31    fn clone(&self) -> Self {
32        Self {
33            envs: self.envs.clone(),
34            modules: self.modules.clone(),
35            distributed: self.distributed.clone(),
36            runtime: self.runtime.clone(),
37        }
38    }
39}
40
41pub fn test_root_cert() -> String {
42    crate::control::cert::TEST_ROOT_CERT.to_string()
43}
44
45pub fn root_cert(ca_cert: &str) -> Result<String> {
46    let cert = std::fs::read(ca_cert)?;
47    Ok(std::str::from_utf8(&cert)?.to_string())
48}
49
50pub fn gen_node_cert(node_name: &str) -> Result<Certificate> {
51    let mut params = CertificateParams::new(vec![node_name.to_string()]);
52    params
53        .distinguished_name
54        .push(DnType::OrganizationName, "Lunatic Inc.");
55    params.distinguished_name.push(DnType::CommonName, "Node");
56    Certificate::from_params(params)
57        .map_err(|_| anyhow!("Error while generating node certificate."))
58}
59
60pub async fn node_server<T, E>(
61    ctx: ServerCtx<T, E>,
62    socket: SocketAddr,
63    ca_cert: String,
64    certs: Vec<String>,
65    key: String,
66) -> Result<()>
67where
68    T: ProcessState + ResourceLimiter + DistributedCtx<E> + Send + Sync + 'static,
69    E: Environment + 'static,
70{
71    let mut quic_server = quic::new_quic_server(socket, certs, &key, &ca_cert)?;
72    if let Err(e) = quic::handle_node_server(&mut quic_server, ctx.clone()).await {
73        log::error!("Node server stopped {e}")
74    };
75    Ok(())
76}
77
78pub async fn handle_message<T, E>(
79    ctx: ServerCtx<T, E>,
80    send: &mut SendStream,
81    msg_id: u64,
82    msg: Request,
83) where
84    T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + Sync + 'static,
85    E: Environment + 'static,
86{
87    if let Err(e) = handle_message_err(ctx, send, msg_id, msg).await {
88        log::error!("Error handling message: {e}");
89    }
90}
91
92async fn handle_message_err<T, E>(
93    ctx: ServerCtx<T, E>,
94    send: &mut SendStream,
95    msg_id: u64,
96    msg: Request,
97) -> Result<()>
98where
99    T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + Sync + 'static,
100    E: Environment + 'static,
101{
102    match msg {
103        Request::Spawn(spawn) => {
104            match handle_spawn(ctx, spawn).await {
105                Ok(Ok(id)) => {
106                    let mut data = super::message::pack_response(msg_id, Response::Spawned(id));
107                    send.send(&mut data).await?;
108                }
109                Ok(Err(client_error)) => {
110                    let mut data =
111                        super::message::pack_response(msg_id, Response::Error(client_error));
112                    send.send(&mut data).await?;
113                }
114                Err(error) => {
115                    let mut data = super::message::pack_response(
116                        msg_id,
117                        Response::Error(ClientError::Unexpected(error.to_string())),
118                    );
119                    send.send(&mut data).await?
120                }
121            };
122        }
123        Request::Message {
124            environment_id,
125            process_id,
126            tag,
127            data,
128        } => match handle_process_message(ctx, environment_id, process_id, tag, data).await {
129            Ok(_) => {
130                let mut data = super::message::pack_response(msg_id, Response::Sent);
131                send.send(&mut data).await?;
132            }
133            Err(error) => {
134                let mut data = super::message::pack_response(msg_id, Response::Error(error));
135                send.send(&mut data).await?;
136            }
137        },
138    };
139    Ok(())
140}
141
142async fn handle_spawn<T, E>(ctx: ServerCtx<T, E>, spawn: Spawn) -> Result<Result<u64, ClientError>>
143where
144    T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + Sync + 'static,
145    E: Environment + 'static,
146{
147    let Spawn {
148        environment_id,
149        module_id,
150        function,
151        params,
152        config,
153    } = spawn;
154
155    let config: T::Config = rmp_serde::from_slice(&config[..])?;
156    let config = Arc::new(config);
157
158    let module = match ctx.modules.get(module_id) {
159        Some(module) => module,
160        None => {
161            if let Ok(bytes) = ctx
162                .distributed
163                .control
164                .get_module(module_id, environment_id)
165                .await
166            {
167                let wasm = RawWasm::new(Some(module_id), bytes);
168                ctx.modules.compile(ctx.runtime.clone(), wasm).await??
169            } else {
170                return Ok(Err(ClientError::ModuleNotFound));
171            }
172        }
173    };
174
175    let env = ctx.envs.get(environment_id).await;
176
177    let env = match env {
178        Some(env) => env,
179        None => ctx.envs.create(environment_id).await,
180    };
181
182    env.can_spawn_next_process().await?;
183
184    let distributed = ctx.distributed.clone();
185    let runtime = ctx.runtime.clone();
186    let state = T::new_dist_state(env.clone(), distributed, runtime, module.clone(), config)?;
187    let params: Vec<wasmtime::Val> = params.into_iter().map(Into::into).collect();
188    let (_handle, proc) = lunatic_process::wasm::spawn_wasm(
189        env,
190        ctx.runtime,
191        &module,
192        state,
193        &function,
194        params,
195        None,
196    )
197    .await?;
198    Ok(Ok(proc.id()))
199}
200
201async fn handle_process_message<T, E>(
202    ctx: ServerCtx<T, E>,
203    environment_id: u64,
204    process_id: u64,
205    tag: Option<i64>,
206    data: Vec<u8>,
207) -> std::result::Result<(), ClientError>
208where
209    T: ProcessState + DistributedCtx<E> + ResourceLimiter + Send + 'static,
210    E: Environment,
211{
212    let env = ctx.envs.get(environment_id).await;
213    if let Some(env) = env {
214        if let Some(proc) = env.get_process(process_id) {
215            proc.send(Signal::Message(Message::Data(DataMessage::new_from_vec(
216                tag, data,
217            ))));
218        } else {
219            return Err(ClientError::ProcessNotFound);
220        }
221    }
222    Ok(())
223}