intuicio_nodes/
server.rs

1use crate::nodes::*;
2use intuicio_core::registry::Registry;
3use serde::{Deserialize, Serialize};
4use std::{
5    collections::{HashMap, HashSet},
6    error::Error,
7};
8use typid::ID;
9
10pub type NodeGraphId<T> = ID<NodeGraph<T>>;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RequestAdd<T: NodeDefinition> {
14    pub nodes: Vec<Node<T>>,
15    pub connections: Vec<NodeConnection<T>>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RequestRemove<T: NodeDefinition> {
20    pub nodes: Vec<NodeId<T>>,
21    pub connections: Vec<NodeConnection<T>>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct RequestUpdate<T: NodeDefinition> {
26    pub nodes: Vec<Node<T>>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RequestQueryRegion {
31    pub fx: i64,
32    pub fy: i64,
33    pub tx: i64,
34    pub ty: i64,
35    pub extrude: i64,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ResponseQuery<T: NodeDefinition> {
40    pub nodes: Vec<Node<T>>,
41    pub connections: Vec<NodeConnection<T>>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub enum NodeGraphServerError {
46    NodeGraphDoesNotExists(String),
47    NodeNotFound { graph: String, node: String },
48    ValidationErrors { graph: String, errors: Vec<String> },
49}
50
51impl std::fmt::Display for NodeGraphServerError {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            NodeGraphServerError::NodeGraphDoesNotExists(id) => {
55                write!(f, "Node graph does not exists: {}", id)
56            }
57            NodeGraphServerError::NodeNotFound { graph, node } => {
58                write!(f, "Node graph: {} does not have node: {}", graph, node)
59            }
60            NodeGraphServerError::ValidationErrors { graph, errors } => {
61                write!(f, "Node graph: {} validation errors:", graph)?;
62                for error in errors {
63                    write!(f, "{}", error)?;
64                }
65                Ok(())
66            }
67        }
68    }
69}
70
71impl Error for NodeGraphServerError {}
72
73pub struct NodeGraphServer<T: NodeDefinition + Clone> {
74    graphs: HashMap<NodeGraphId<T>, NodeGraph<T>>,
75}
76
77impl<T: NodeDefinition + Clone> Default for NodeGraphServer<T> {
78    fn default() -> Self {
79        Self {
80            graphs: Default::default(),
81        }
82    }
83}
84
85impl<T: NodeDefinition + Clone> NodeGraphServer<T> {
86    pub fn graph(&self, id: NodeGraphId<T>) -> Result<&NodeGraph<T>, NodeGraphServerError> {
87        self.graphs
88            .get(&id)
89            .ok_or_else(|| NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
90    }
91
92    pub fn graph_mut(
93        &mut self,
94        id: NodeGraphId<T>,
95    ) -> Result<&mut NodeGraph<T>, NodeGraphServerError> {
96        self.graphs
97            .get_mut(&id)
98            .ok_or_else(|| NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
99    }
100
101    pub fn create(&mut self) -> NodeGraphId<T> {
102        let id = NodeGraphId::new();
103        self.graphs.insert(id, NodeGraph::default());
104        id
105    }
106
107    pub fn destroy(&mut self, id: NodeGraphId<T>) -> Result<NodeGraph<T>, NodeGraphServerError> {
108        self.graphs
109            .remove(&id)
110            .ok_or_else(|| NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
111    }
112
113    pub fn list(&self) -> impl Iterator<Item = &NodeGraphId<T>> {
114        self.graphs.keys()
115    }
116
117    pub fn add(
118        &mut self,
119        id: NodeGraphId<T>,
120        request: RequestAdd<T>,
121        registry: &Registry,
122    ) -> Result<(), NodeGraphServerError> {
123        if let Some(graph) = self.graphs.get_mut(&id) {
124            for node in request.nodes {
125                graph.add_node(node, registry);
126            }
127            for connection in request.connections {
128                graph.connect_nodes(connection);
129            }
130            graph.refresh_spatial_cache();
131            Ok(())
132        } else {
133            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
134        }
135    }
136
137    pub fn remove(
138        &mut self,
139        id: NodeGraphId<T>,
140        request: RequestRemove<T>,
141        registry: &Registry,
142    ) -> Result<(), NodeGraphServerError> {
143        if let Some(graph) = self.graphs.get_mut(&id) {
144            for connection in request.connections {
145                graph.disconnect_nodes(
146                    connection.from_node,
147                    connection.to_node,
148                    &connection.from_pin,
149                    &connection.to_pin,
150                );
151            }
152            for id in request.nodes {
153                graph.remove_node(id, registry);
154            }
155            graph.refresh_spatial_cache();
156            Ok(())
157        } else {
158            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
159        }
160    }
161
162    pub fn update(
163        &mut self,
164        id: NodeGraphId<T>,
165        request: RequestUpdate<T>,
166    ) -> Result<(), NodeGraphServerError> {
167        if let Some(graph) = self.graphs.get_mut(&id) {
168            for source in &request.nodes {
169                if graph.node(source.id()).is_none() {
170                    return Err(NodeGraphServerError::NodeNotFound {
171                        graph: id.to_string(),
172                        node: source.id().to_string(),
173                    });
174                }
175            }
176            for source in request.nodes {
177                let id = source.id();
178                *graph.node_mut(id).unwrap() = source;
179            }
180            graph.refresh_spatial_cache();
181            Ok(())
182        } else {
183            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
184        }
185    }
186
187    pub fn clear(&mut self, id: NodeGraphId<T>) -> Result<(), NodeGraphServerError> {
188        if let Some(graph) = self.graphs.get_mut(&id) {
189            graph.clear();
190            graph.refresh_spatial_cache();
191            Ok(())
192        } else {
193            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
194        }
195    }
196
197    pub fn query_all(
198        &self,
199        graph: NodeGraphId<T>,
200    ) -> Result<ResponseQuery<T>, NodeGraphServerError> {
201        if let Some(graph) = self.graphs.get(&graph) {
202            Ok(ResponseQuery {
203                nodes: graph.nodes().cloned().collect(),
204                connections: graph.connections().cloned().collect(),
205            })
206        } else {
207            Err(NodeGraphServerError::NodeGraphDoesNotExists(
208                graph.to_string(),
209            ))
210        }
211    }
212
213    pub fn query_region(
214        &self,
215        graph: NodeGraphId<T>,
216        request: RequestQueryRegion,
217    ) -> Result<ResponseQuery<T>, NodeGraphServerError> {
218        if let Some(graph) = self.graphs.get(&graph) {
219            let RequestQueryRegion {
220                fx,
221                fy,
222                tx,
223                ty,
224                extrude,
225            } = request;
226            let nodes = graph
227                .query_region_nodes(fx, fy, tx, ty, extrude)
228                .filter_map(|id| graph.node(id))
229                .cloned()
230                .collect::<Vec<_>>();
231            let connections = nodes
232                .iter()
233                .flat_map(|node| graph.node_connections(node.id()))
234                .cloned()
235                .collect::<HashSet<_>>()
236                .into_iter()
237                .collect();
238            Ok(ResponseQuery { nodes, connections })
239        } else {
240            Err(NodeGraphServerError::NodeGraphDoesNotExists(
241                graph.to_string(),
242            ))
243        }
244    }
245
246    pub fn suggest_all_nodes(
247        x: i64,
248        y: i64,
249        registry: &Registry,
250    ) -> Vec<ResponseSuggestionNode<T>> {
251        NodeGraph::suggest_all_nodes(x, y, registry)
252    }
253
254    pub fn validate(
255        &self,
256        graph: NodeGraphId<T>,
257        registry: &Registry,
258    ) -> Result<(), NodeGraphServerError> {
259        if let Some(item) = self.graphs.get(&graph) {
260            match item.validate(registry) {
261                Ok(_) => Ok(()),
262                Err(errors) => Err(NodeGraphServerError::ValidationErrors {
263                    graph: graph.to_string(),
264                    errors: errors.into_iter().map(|error| error.to_string()).collect(),
265                }),
266            }
267        } else {
268            Err(NodeGraphServerError::NodeGraphDoesNotExists(
269                graph.to_string(),
270            ))
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use intuicio_core::prelude::*;
    use serde::{de::DeserializeOwned, Deserialize, Serialize};

    /// Dummy pin type information: displays as empty and is compatible with everything.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TypeInfo;

    impl std::fmt::Display for TypeInfo {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "")
        }
    }

    impl NodeTypeInfo for TypeInfo {
        fn type_query(&self) -> TypeQuery {
            Default::default()
        }

        fn are_compatible(&self, _: &Self) -> bool {
            true
        }
    }

    /// Minimal node set used to build test graphs.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    enum Nodes {
        Start,
        Expression(i32),
        Result,
        Convert(String),
    }

    impl NodeDefinition for Nodes {
        type TypeInfo = TypeInfo;

        fn node_label(&self, _: &Registry) -> String {
            format!("{:?}", self)
        }

        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![],
                Nodes::Expression(_) => vec![NodePin::property("Value")],
                Nodes::Result => vec![
                    NodePin::execute("In", false),
                    NodePin::parameter("Data", TypeInfo),
                ],
                Nodes::Convert(_) => vec![
                    NodePin::execute("In", false),
                    NodePin::property("Name"),
                    NodePin::parameter("Data in", TypeInfo),
                ],
            }
        }

        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![NodePin::execute("Out", false)],
                Nodes::Expression(_) => vec![NodePin::parameter("Data", TypeInfo)],
                Nodes::Result => vec![],
                Nodes::Convert(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data out", TypeInfo),
                ],
            }
        }

        fn node_is_start(&self, _: &Registry) -> bool {
            matches!(self, Self::Start)
        }

        fn node_suggestions(
            _: i64,
            _: i64,
            _: NodeSuggestion<Self>,
            _: &Registry,
        ) -> Vec<ResponseSuggestionNode<Self>> {
            vec![]
        }
    }

    /// Round-trips a value through JSON to mimic sending it to the server over the wire.
    fn mock_transfer<T: Serialize + DeserializeOwned>(value: T) -> T {
        let content = serde_json::to_string(&value).unwrap();
        serde_json::from_str(&content).unwrap()
    }

    #[test]
    fn test_server() {
        let registry = Registry::default().with_basic_types();
        let mut server = NodeGraphServer::default();
        let graph = server.create();
        let start = Node::new(0, 0, Nodes::Start);
        let expression = Node::new(0, 0, Nodes::Expression(42));
        let convert = Node::new(0, 0, Nodes::Convert("foo".to_owned()));
        let result = Node::new(0, 0, Nodes::Result);
        server
            .add(
                graph,
                mock_transfer(RequestAdd {
                    connections: vec![
                        NodeConnection::new(start.id(), convert.id(), "Out", "In"),
                        NodeConnection::new(convert.id(), result.id(), "Out", "In"),
                        NodeConnection::new(expression.id(), convert.id(), "Data", "Data in"),
                    ],
                    nodes: vec![
                        start.clone(),
                        expression.clone(),
                        convert.clone(),
                        result.clone(),
                    ],
                }),
                &registry,
            )
            .unwrap();
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 4);
        assert_eq!(temp.connections.len(), 3);
        server
            .remove(
                graph,
                RequestRemove {
                    nodes: vec![result.id(), convert.id()],
                    connections: vec![],
                },
                &registry,
            )
            .unwrap();
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 2);
        assert_eq!(temp.connections.len(), 0);
        // `convert` was removed above, so the whole update must be rejected.
        assert!(matches!(
            server.update(
                graph,
                mock_transfer(RequestUpdate {
                    nodes: vec![expression.clone(), convert.clone()],
                }),
            ),
            Err(NodeGraphServerError::NodeNotFound { .. })
        ));
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 2);
        assert_eq!(temp.connections.len(), 0);
        server
            .update(
                graph,
                mock_transfer(RequestUpdate {
                    nodes: vec![expression.clone(), start.clone()],
                }),
            )
            .unwrap();
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 2);
        assert_eq!(temp.connections.len(), 0);
    }

    /// Graph lifecycle: once a graph is destroyed, every graph-scoped operation
    /// on its id reports `NodeGraphDoesNotExists`.
    #[test]
    fn test_server_missing_graph() {
        let registry = Registry::default().with_basic_types();
        let mut server = NodeGraphServer::<Nodes>::default();
        assert_eq!(server.list().count(), 0);
        let graph = server.create();
        assert_eq!(server.list().count(), 1);
        assert!(server.graph(graph).is_ok());
        assert!(server.destroy(graph).is_ok());
        assert_eq!(server.list().count(), 0);
        assert!(matches!(
            server.graph(graph),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
        assert!(matches!(
            server.destroy(graph),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
        assert!(matches!(
            server.query_all(graph),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
        assert!(matches!(
            server.clear(graph),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
        assert!(matches!(
            server.add(
                graph,
                RequestAdd {
                    nodes: vec![],
                    connections: vec![],
                },
                &registry,
            ),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
        assert!(matches!(
            server.update(graph, RequestUpdate { nodes: vec![] }),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
        assert!(matches!(
            server.validate(graph, &registry),
            Err(NodeGraphServerError::NodeGraphDoesNotExists(_))
        ));
    }
}