// lunatic_distributed/distributed/client.rs

1use anyhow::Result;
2use async_cell::sync::AsyncCell;
3use bytes::Bytes;
4use dashmap::DashMap;
5use lunatic_control::NodeInfo;
6use std::sync::{atomic, atomic::AtomicU64, Arc};
7use tokio::sync::mpsc::{self, unbounded_channel, UnboundedReceiver, UnboundedSender};
8
9use crate::{
10    control,
11    distributed::message::{ClientError, Request, Response},
12    quic::{self, RecvStream},
13};
14
15use super::message::Spawn;
16
/// A request queued by `Client::request` for delivery to a node.
///
/// Consumed by `forward_node_messages`, which routes it to the connection
/// task for `node_id`.
struct SendRequest {
    /// Id used to match the node's response back to the waiting caller.
    msg_id: u64,
    /// Node the request is addressed to.
    node_id: u64,
    /// Payload to deliver.
    request: Request,
}
/// Handle for sending requests to other nodes and awaiting their responses.
///
/// Cloning is cheap: every clone shares the same `InnerClient` through an `Arc`.
#[derive(Clone)]
pub struct Client {
    inner: Arc<InnerClient>,
}
26
/// Shared state behind every clone of `Client`.
pub struct InnerClient {
    /// Counter backing `Client::next_message_id` (starts at 1 in `Client::new`).
    next_message_id: AtomicU64,
    /// Per-node senders feeding the `manage_node_connection` task of each node;
    /// entries are created lazily by `forward_node_messages`.
    node_message_buffers: DashMap<u64, UnboundedSender<(u64, Request)>>,
    /// Response slots keyed by message id; filled by `Client::process_response`
    /// and awaited by `Client::request`.
    pending_requests: DashMap<u64, Arc<AsyncCell<Response>>>,
    /// Used to look up node addresses (`node_info` / `refresh_nodes`).
    control_client: control::Client,
    /// Used to open connections to other nodes.
    quic_client: quic::Client,
    /// Queue of outgoing requests, drained by `forward_node_messages`.
    tx: UnboundedSender<SendRequest>,
}
35
36impl Client {
37    // TODO node_id?
38    pub async fn new(
39        _node_id: u64,
40        control_client: control::Client,
41        quic_client: quic::Client,
42    ) -> Result<Client> {
43        let (tx, rx) = mpsc::unbounded_channel();
44        let client = Client {
45            inner: Arc::new(InnerClient {
46                next_message_id: AtomicU64::new(1),
47                node_message_buffers: DashMap::new(),
48                pending_requests: DashMap::new(),
49                control_client,
50                quic_client,
51                tx,
52            }),
53        };
54        tokio::spawn(forward_node_messages(client.clone(), rx));
55        Ok(client)
56    }
57
58    pub fn next_message_id(&self) -> u64 {
59        self.inner
60            .next_message_id
61            .fetch_add(1, atomic::Ordering::Relaxed)
62    }
63
64    async fn request(&self, node_id: u64, request: Request) -> Result<Response, ClientError> {
65        let msg_id = self.next_message_id();
66        self.inner
67            .tx
68            .send(SendRequest {
69                msg_id,
70                node_id,
71                request,
72            })
73            .map_err(|e| ClientError::Unexpected(e.to_string()))?;
74        let cell = AsyncCell::shared();
75        self.inner.pending_requests.insert(msg_id, cell.clone());
76        let response = cell.take().await;
77        self.inner.pending_requests.remove(&msg_id);
78        Ok(response)
79    }
80
81    pub async fn message_process(
82        &self,
83        node_id: u64,
84        environment_id: u64,
85        process_id: u64,
86        tag: Option<i64>,
87        data: Vec<u8>,
88    ) -> Result<(), ClientError> {
89        match self
90            .request(
91                node_id,
92                Request::Message {
93                    environment_id,
94                    process_id,
95                    tag,
96                    data,
97                },
98            )
99            .await
100        {
101            Ok(Response::Sent) => Ok(()),
102            Ok(Response::Error(error)) | Err(error) => Err(error),
103            Ok(_) => Err(ClientError::Unexpected(
104                "Invalid response type for send".to_string(),
105            )),
106        }
107    }
108
109    fn process_response(&self, id: u64, resp: Response) {
110        if let Some(e) = self.inner.pending_requests.get(&id) {
111            e.set(resp);
112        };
113    }
114
115    pub async fn spawn(&self, node_id: u64, spawn: Spawn) -> Result<u64, ClientError> {
116        match self.request(node_id, Request::Spawn(spawn)).await {
117            Ok(Response::Spawned(id)) => Ok(id),
118            Ok(Response::Error(error)) | Err(error) => Err(error),
119            Ok(_) => Err(ClientError::Unexpected(
120                "Invalid response type for spawn".to_string(),
121            )),
122        }
123    }
124}
125
126async fn reader_task(client: Client, mut recv: RecvStream) -> Result<()> {
127    loop {
128        match recv.receive().await {
129            Ok(bytes) => {
130                let (msg_id, response) =
131                    rmp_serde::from_slice::<(u64, super::message::Response)>(&bytes)?;
132                client.process_response(msg_id, response);
133                Ok(())
134            }
135            Err(e) => {
136                log::debug!("Node connection error: {e}");
137                Err(e)
138            }
139        }?;
140    }
141}
142
143async fn forward_node_messages(client: Client, mut rx: UnboundedReceiver<SendRequest>) {
144    while let Some(SendRequest {
145        msg_id,
146        node_id,
147        request,
148    }) = rx.recv().await
149    {
150        if let Some(node_buf) = client.inner.node_message_buffers.get(&node_id) {
151            node_buf.value().send((msg_id, request)).ok();
152        } else {
153            let (send, recv) = unbounded_channel();
154            send.send((msg_id, request)).ok();
155            client.inner.node_message_buffers.insert(node_id, send);
156            tokio::spawn(manage_node_connection(node_id, client.clone(), recv));
157        }
158    }
159}
160
161async fn try_node_info_forever(node_id: u64, client: &Client) -> NodeInfo {
162    loop {
163        let node_info = client.inner.control_client.node_info(node_id);
164        if node_info.is_none() {
165            client.inner.control_client.refresh_nodes().await.ok();
166        } else {
167            return node_info.unwrap();
168        }
169    }
170}
171
172async fn manage_node_connection(
173    node_id: u64,
174    client: Client,
175    mut rx: UnboundedReceiver<(u64, Request)>,
176) {
177    let quic_client = client.inner.quic_client.clone();
178    let NodeInfo { address, name, .. } = try_node_info_forever(node_id, &client).await;
179    let (mut send, recv) = quic::try_connect_forever(&quic_client, address, &name).await;
180    tokio::spawn(reader_task(client.clone(), recv));
181    while let Some(msg) = rx.recv().await {
182        if let Ok(data) = rmp_serde::to_vec(&msg) {
183            let size = (data.len() as u32).to_le_bytes();
184            let size: Bytes = Bytes::copy_from_slice(&size[..]);
185            let bytes: Bytes = data.into();
186            while let Err(e) = send.send(&mut [size.clone(), bytes.clone()]).await {
187                log::debug!("Cannot send data to node: {e}, reconnecting...");
188                let (new_send, new_recv) =
189                    quic::try_connect_forever(&quic_client, address, &name).await;
190                tokio::spawn(reader_task(client.clone(), new_recv));
191                send = new_send;
192            }
193        }
194    }
195}