lunatic_distributed/distributed/
client.rs1use anyhow::Result;
2use async_cell::sync::AsyncCell;
3use bytes::Bytes;
4use dashmap::DashMap;
5use lunatic_control::NodeInfo;
6use std::sync::{atomic, atomic::AtomicU64, Arc};
7use tokio::sync::mpsc::{self, unbounded_channel, UnboundedReceiver, UnboundedSender};
8
9use crate::{
10 control,
11 distributed::message::{ClientError, Request, Response},
12 quic::{self, RecvStream},
13};
14
15use super::message::Spawn;
16
17struct SendRequest {
18 msg_id: u64,
19 node_id: u64,
20 request: Request,
21}
22#[derive(Clone)]
23pub struct Client {
24 inner: Arc<InnerClient>,
25}
26
27pub struct InnerClient {
28 next_message_id: AtomicU64,
29 node_message_buffers: DashMap<u64, UnboundedSender<(u64, Request)>>,
30 pending_requests: DashMap<u64, Arc<AsyncCell<Response>>>,
31 control_client: control::Client,
32 quic_client: quic::Client,
33 tx: UnboundedSender<SendRequest>,
34}
35
36impl Client {
37 pub async fn new(
39 _node_id: u64,
40 control_client: control::Client,
41 quic_client: quic::Client,
42 ) -> Result<Client> {
43 let (tx, rx) = mpsc::unbounded_channel();
44 let client = Client {
45 inner: Arc::new(InnerClient {
46 next_message_id: AtomicU64::new(1),
47 node_message_buffers: DashMap::new(),
48 pending_requests: DashMap::new(),
49 control_client,
50 quic_client,
51 tx,
52 }),
53 };
54 tokio::spawn(forward_node_messages(client.clone(), rx));
55 Ok(client)
56 }
57
58 pub fn next_message_id(&self) -> u64 {
59 self.inner
60 .next_message_id
61 .fetch_add(1, atomic::Ordering::Relaxed)
62 }
63
64 async fn request(&self, node_id: u64, request: Request) -> Result<Response, ClientError> {
65 let msg_id = self.next_message_id();
66 self.inner
67 .tx
68 .send(SendRequest {
69 msg_id,
70 node_id,
71 request,
72 })
73 .map_err(|e| ClientError::Unexpected(e.to_string()))?;
74 let cell = AsyncCell::shared();
75 self.inner.pending_requests.insert(msg_id, cell.clone());
76 let response = cell.take().await;
77 self.inner.pending_requests.remove(&msg_id);
78 Ok(response)
79 }
80
81 pub async fn message_process(
82 &self,
83 node_id: u64,
84 environment_id: u64,
85 process_id: u64,
86 tag: Option<i64>,
87 data: Vec<u8>,
88 ) -> Result<(), ClientError> {
89 match self
90 .request(
91 node_id,
92 Request::Message {
93 environment_id,
94 process_id,
95 tag,
96 data,
97 },
98 )
99 .await
100 {
101 Ok(Response::Sent) => Ok(()),
102 Ok(Response::Error(error)) | Err(error) => Err(error),
103 Ok(_) => Err(ClientError::Unexpected(
104 "Invalid response type for send".to_string(),
105 )),
106 }
107 }
108
109 fn process_response(&self, id: u64, resp: Response) {
110 if let Some(e) = self.inner.pending_requests.get(&id) {
111 e.set(resp);
112 };
113 }
114
115 pub async fn spawn(&self, node_id: u64, spawn: Spawn) -> Result<u64, ClientError> {
116 match self.request(node_id, Request::Spawn(spawn)).await {
117 Ok(Response::Spawned(id)) => Ok(id),
118 Ok(Response::Error(error)) | Err(error) => Err(error),
119 Ok(_) => Err(ClientError::Unexpected(
120 "Invalid response type for spawn".to_string(),
121 )),
122 }
123 }
124}
125
126async fn reader_task(client: Client, mut recv: RecvStream) -> Result<()> {
127 loop {
128 match recv.receive().await {
129 Ok(bytes) => {
130 let (msg_id, response) =
131 rmp_serde::from_slice::<(u64, super::message::Response)>(&bytes)?;
132 client.process_response(msg_id, response);
133 Ok(())
134 }
135 Err(e) => {
136 log::debug!("Node connection error: {e}");
137 Err(e)
138 }
139 }?;
140 }
141}
142
143async fn forward_node_messages(client: Client, mut rx: UnboundedReceiver<SendRequest>) {
144 while let Some(SendRequest {
145 msg_id,
146 node_id,
147 request,
148 }) = rx.recv().await
149 {
150 if let Some(node_buf) = client.inner.node_message_buffers.get(&node_id) {
151 node_buf.value().send((msg_id, request)).ok();
152 } else {
153 let (send, recv) = unbounded_channel();
154 send.send((msg_id, request)).ok();
155 client.inner.node_message_buffers.insert(node_id, send);
156 tokio::spawn(manage_node_connection(node_id, client.clone(), recv));
157 }
158 }
159}
160
161async fn try_node_info_forever(node_id: u64, client: &Client) -> NodeInfo {
162 loop {
163 let node_info = client.inner.control_client.node_info(node_id);
164 if node_info.is_none() {
165 client.inner.control_client.refresh_nodes().await.ok();
166 } else {
167 return node_info.unwrap();
168 }
169 }
170}
171
172async fn manage_node_connection(
173 node_id: u64,
174 client: Client,
175 mut rx: UnboundedReceiver<(u64, Request)>,
176) {
177 let quic_client = client.inner.quic_client.clone();
178 let NodeInfo { address, name, .. } = try_node_info_forever(node_id, &client).await;
179 let (mut send, recv) = quic::try_connect_forever(&quic_client, address, &name).await;
180 tokio::spawn(reader_task(client.clone(), recv));
181 while let Some(msg) = rx.recv().await {
182 if let Ok(data) = rmp_serde::to_vec(&msg) {
183 let size = (data.len() as u32).to_le_bytes();
184 let size: Bytes = Bytes::copy_from_slice(&size[..]);
185 let bytes: Bytes = data.into();
186 while let Err(e) = send.send(&mut [size.clone(), bytes.clone()]).await {
187 log::debug!("Cannot send data to node: {e}, reconnecting...");
188 let (new_send, new_recv) =
189 quic::try_connect_forever(&quic_client, address, &name).await;
190 tokio::spawn(reader_task(client.clone(), new_recv));
191 send = new_send;
192 }
193 }
194 }
195}