1use intuicio_core::{registry::Registry, types::TypeQuery};
2use rstar::{Envelope, Point, PointDistance, RTree, RTreeObject, AABB};
3use serde::{Deserialize, Serialize};
4use serde_intermediate::{
5 de::intermediate::DeserializeMode, error::Result as IntermediateResult, Intermediate,
6};
7use std::{
8 collections::{HashMap, HashSet},
9 error::Error,
10 fmt::Display,
11 hash::{Hash, Hasher},
12};
13use typid::ID;
14
15pub type NodeId<T> = ID<Node<T>>;
16pub type PropertyCastMode = DeserializeMode;
17
18#[derive(Debug, Default, Clone, PartialEq)]
19pub struct PropertyValue {
20 value: Intermediate,
21}
22
23impl PropertyValue {
24 pub fn new<T: Serialize>(value: &T) -> IntermediateResult<Self> {
25 Ok(Self {
26 value: serde_intermediate::to_intermediate(value)?,
27 })
28 }
29
30 pub fn get<'a, T: Deserialize<'a>>(&'a self, mode: PropertyCastMode) -> IntermediateResult<T> {
31 serde_intermediate::from_intermediate_as(&self.value, mode)
32 }
33
34 pub fn get_exact<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
35 self.get(PropertyCastMode::Exact)
36 }
37
38 pub fn get_interpret<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
39 self.get(PropertyCastMode::Interpret)
40 }
41
42 pub fn into_inner(self) -> Intermediate {
43 self.value
44 }
45}
46
47pub trait NodeTypeInfo:
48 Clone + std::fmt::Debug + Display + PartialEq + Serialize + for<'de> Deserialize<'de>
49{
50 fn type_query(&self) -> TypeQuery;
51 fn are_compatible(&self, other: &Self) -> bool;
52}
53
54pub trait NodeDefinition: Sized {
55 type TypeInfo: NodeTypeInfo;
56
57 fn node_label(&self, registry: &Registry) -> String;
58 fn node_pins_in(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
59 fn node_pins_out(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
60 fn node_is_start(&self, registry: &Registry) -> bool;
61 fn node_suggestions(
62 x: i64,
63 y: i64,
64 suggestion: NodeSuggestion<Self>,
65 registry: &Registry,
66 ) -> Vec<ResponseSuggestionNode<Self>>;
67
68 #[allow(unused_variables)]
69 fn validate_connection(
70 &self,
71 source: &Self,
72 registry: &Registry,
73 ) -> Result<(), Box<dyn Error>> {
74 Ok(())
75 }
76
77 #[allow(unused_variables)]
78 fn get_property(&self, name: &str) -> Option<PropertyValue> {
79 None
80 }
81
82 #[allow(unused_variables)]
83 fn set_property(&mut self, name: &str, value: PropertyValue) {}
84}
85
86#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
87#[serde(bound = "TI: NodeTypeInfo")]
88pub enum NodePin<TI: NodeTypeInfo> {
89 Execute { name: String, subscope: bool },
90 Parameter { name: String, type_info: TI },
91 Property { name: String },
92}
93
94impl<TI: NodeTypeInfo> NodePin<TI> {
95 pub fn execute(name: impl ToString, subscope: bool) -> Self {
96 Self::Execute {
97 name: name.to_string(),
98 subscope,
99 }
100 }
101
102 pub fn parameter(name: impl ToString, type_info: TI) -> Self {
103 Self::Parameter {
104 name: name.to_string(),
105 type_info,
106 }
107 }
108
109 pub fn property(name: impl ToString) -> Self {
110 Self::Property {
111 name: name.to_string(),
112 }
113 }
114
115 pub fn is_execute(&self) -> bool {
116 matches!(self, Self::Execute { .. })
117 }
118
119 pub fn is_parameter(&self) -> bool {
120 matches!(self, Self::Parameter { .. })
121 }
122
123 pub fn is_property(&self) -> bool {
124 matches!(self, Self::Property { .. })
125 }
126
127 pub fn name(&self) -> &str {
128 match self {
129 Self::Execute { name, .. }
130 | Self::Parameter { name, .. }
131 | Self::Property { name, .. } => name,
132 }
133 }
134
135 pub fn has_subscope(&self) -> bool {
136 matches!(self, Self::Execute { subscope: true, .. })
137 }
138
139 pub fn type_info(&self) -> Option<&TI> {
140 match self {
141 Self::Parameter { type_info, .. } => Some(type_info),
142 _ => None,
143 }
144 }
145}
146
147pub enum NodeSuggestion<'a, T: NodeDefinition> {
148 All,
149 NodeInputPin(&'a T, &'a NodePin<T::TypeInfo>),
150 NodeOutputPin(&'a T, &'a NodePin<T::TypeInfo>),
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ResponseSuggestionNode<T: NodeDefinition> {
155 pub category: String,
156 pub label: String,
157 pub node: Node<T>,
158}
159
160impl<T: NodeDefinition> ResponseSuggestionNode<T> {
161 pub fn new(category: impl ToString, node: Node<T>, registry: &Registry) -> Self {
162 Self {
163 category: category.to_string(),
164 label: node.data.node_label(registry),
165 node,
166 }
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct Node<T: NodeDefinition> {
172 id: NodeId<T>,
173 pub x: i64,
174 pub y: i64,
175 pub data: T,
176}
177
178impl<T: NodeDefinition> Node<T> {
179 pub fn new(x: i64, y: i64, data: T) -> Self {
180 Self {
181 id: Default::default(),
182 x,
183 y,
184 data,
185 }
186 }
187
188 pub fn id(&self) -> NodeId<T> {
189 self.id
190 }
191}
192
193#[derive(Clone, Serialize, Deserialize)]
194pub struct NodeConnection<T: NodeDefinition> {
195 pub from_node: NodeId<T>,
196 pub to_node: NodeId<T>,
197 pub from_pin: String,
198 pub to_pin: String,
199}
200
201impl<T: NodeDefinition> NodeConnection<T> {
202 pub fn new(from_node: NodeId<T>, to_node: NodeId<T>, from_pin: &str, to_pin: &str) -> Self {
203 Self {
204 from_node,
205 to_node,
206 from_pin: from_pin.to_owned(),
207 to_pin: to_pin.to_owned(),
208 }
209 }
210}
211
212impl<T: NodeDefinition> std::fmt::Debug for NodeConnection<T> {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("NodeConnection")
215 .field("from_node", &self.from_node)
216 .field("to_node", &self.to_node)
217 .field("from_pin", &self.from_pin)
218 .field("to_pin", &self.to_pin)
219 .finish()
220 }
221}
222
223impl<T: NodeDefinition> PartialEq for NodeConnection<T> {
224 fn eq(&self, other: &Self) -> bool {
225 self.from_node == other.from_node
226 && self.to_node == other.to_node
227 && self.from_pin == other.from_pin
228 && self.to_pin == other.to_pin
229 }
230}
231
232impl<T: NodeDefinition> Eq for NodeConnection<T> {}
233
234impl<T: NodeDefinition> Hash for NodeConnection<T> {
235 fn hash<H: Hasher>(&self, state: &mut H) {
236 self.from_node.hash(state);
237 self.to_node.hash(state);
238 self.from_pin.hash(state);
239 self.to_pin.hash(state);
240 }
241}
242
/// Problems reported while validating the connections of a graph.
/// Node identifiers and type infos are stored in their string form.
#[derive(Debug)]
pub enum ConnectionError {
    /// A connection starts and ends at the same node.
    InternalConnection(String),
    /// The source node of a connection is not part of the graph.
    SourceNodeNotFound(String),
    /// The target node of a connection is not part of the graph.
    TargetNodeNotFound(String),
    /// Neither endpoint of a connection is part of the graph.
    NodesNotFound { from: String, to: String },
    /// The source node has no output pin with the requested name.
    SourcePinNotFound { node: String, pin: String },
    /// The target node has no input pin with the requested name.
    TargetPinNotFound { node: String, pin: String },
    /// Two parameter pins are connected but their types are incompatible.
    MismatchTypes {
        from_node: String,
        from_pin: String,
        from_type_info: String,
        to_node: String,
        to_pin: String,
        to_type_info: String,
    },
    /// Pins of different kinds are connected.
    MismatchPins {
        from_node: String,
        from_pin: String,
        to_node: String,
        to_pin: String,
    },
    /// The given node takes part in a cycle.
    CycleNodeFound(String),
    /// Error returned by node-specific connection validation.
    Custom(Box<dyn Error>),
}
277
278impl std::fmt::Display for ConnectionError {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 match self {
281 Self::InternalConnection(node) => {
282 write!(f, "Trying to connect node: {} to itself", node)
283 }
284 Self::SourceNodeNotFound(node) => write!(f, "Source node: {} not found", node),
285 Self::TargetNodeNotFound(node) => write!(f, "Target node: {} not found", node),
286 Self::NodesNotFound { from, to } => {
287 write!(f, "Source: {} and target: {} nodes not found", from, to)
288 }
289 Self::SourcePinNotFound { node, pin } => {
290 write!(f, "Source pin: {} for node: {} not found", pin, node)
291 }
292 Self::TargetPinNotFound { node, pin } => {
293 write!(f, "Target pin: {} for node: {} not found", pin, node)
294 }
295 Self::MismatchTypes {
296 from_node,
297 from_pin,
298 from_type_info,
299 to_node,
300 to_pin,
301 to_type_info,
302 } => {
303 write!(
304 f,
305 "Source type: {} of pin: {} for node: {} does not match target type: {} of pin: {} for node: {}",
306 from_type_info, from_pin, from_node, to_type_info, to_pin, to_node
307 )
308 }
309 Self::MismatchPins {
310 from_node,
311 from_pin,
312 to_node,
313 to_pin,
314 } => {
315 write!(
316 f,
317 "Source pin: {} kind for node: {} does not match target pin: {} kind for node: {}",
318 from_pin, from_node, to_pin, to_node
319 )
320 }
321 Self::CycleNodeFound(node) => write!(f, "Found cycle node: {}", node),
322 Self::Custom(error) => error.fmt(f),
323 }
324 }
325}
326
327impl Error for ConnectionError {}
328
329#[derive(Debug)]
330pub enum NodeGraphError {
331 Connection(ConnectionError),
332 DuplicateFunctionInputNames(String),
333 DuplicateFunctionOutputNames(String),
334}
335
336impl std::fmt::Display for NodeGraphError {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 match self {
339 Self::Connection(connection) => connection.fmt(f),
340 Self::DuplicateFunctionInputNames(name) => {
341 write!(
342 f,
343 "Found duplicate `{}` function input with different types",
344 name
345 )
346 }
347 Self::DuplicateFunctionOutputNames(name) => {
348 write!(
349 f,
350 "Found duplicate `{}` function output with different types",
351 name
352 )
353 }
354 }
355 }
356}
357
358impl Error for NodeGraphError {}
359
360#[derive(Clone)]
361struct SpatialNode<T: NodeDefinition> {
362 id: NodeId<T>,
363 x: i64,
364 y: i64,
365}
366
367impl<T: NodeDefinition> RTreeObject for SpatialNode<T> {
368 type Envelope = AABB<[i64; 2]>;
369
370 fn envelope(&self) -> Self::Envelope {
371 AABB::from_point([self.x, self.y])
372 }
373}
374
375impl<T: NodeDefinition> PointDistance for SpatialNode<T> {
376 fn distance_2(
377 &self,
378 point: &<Self::Envelope as Envelope>::Point,
379 ) -> <<Self::Envelope as Envelope>::Point as Point>::Scalar {
380 let dx = self.x - point[0];
381 let dy = self.y - point[1];
382 dx * dx + dy * dy
383 }
384}
385
386#[derive(Clone, Serialize, Deserialize)]
387pub struct NodeGraph<T: NodeDefinition> {
388 nodes: Vec<Node<T>>,
389 connections: Vec<NodeConnection<T>>,
390 #[serde(skip, default)]
391 rtree: RTree<SpatialNode<T>>,
392}
393
394impl<T: NodeDefinition> Default for NodeGraph<T> {
395 fn default() -> Self {
396 Self {
397 nodes: vec![],
398 connections: vec![],
399 rtree: Default::default(),
400 }
401 }
402}
403
404impl<T: NodeDefinition> NodeGraph<T> {
405 pub fn clear(&mut self) {
406 self.nodes.clear();
407 self.connections.clear();
408 }
409
410 pub fn refresh_spatial_cache(&mut self) {
411 self.rtree = RTree::bulk_load(
412 self.nodes
413 .iter()
414 .map(|node| SpatialNode {
415 id: node.id,
416 x: node.x,
417 y: node.y,
418 })
419 .collect(),
420 );
421 }
422
423 pub fn query_nearest_nodes(&self, x: i64, y: i64) -> impl Iterator<Item = NodeId<T>> + '_ {
424 self.rtree
425 .nearest_neighbor_iter(&[x, y])
426 .map(|node| node.id)
427 }
428
429 pub fn query_region_nodes(
430 &self,
431 fx: i64,
432 fy: i64,
433 tx: i64,
434 ty: i64,
435 extrude: i64,
436 ) -> impl Iterator<Item = NodeId<T>> + '_ {
437 self.rtree
438 .locate_in_envelope(&AABB::from_corners(
439 [fx - extrude, fy - extrude],
440 [tx - extrude, ty - extrude],
441 ))
442 .map(|node| node.id)
443 }
444
445 pub fn suggest_all_nodes(
446 x: i64,
447 y: i64,
448 registry: &Registry,
449 ) -> Vec<ResponseSuggestionNode<T>> {
450 T::node_suggestions(x, y, NodeSuggestion::All, registry)
451 }
452
453 pub fn suggest_node_input_pin(
454 &self,
455 x: i64,
456 y: i64,
457 id: NodeId<T>,
458 name: &str,
459 registry: &Registry,
460 ) -> Vec<ResponseSuggestionNode<T>> {
461 if let Some(node) = self.node(id) {
462 if let Some(pin) = node
463 .data
464 .node_pins_in(registry)
465 .into_iter()
466 .find(|pin| pin.name() == name)
467 {
468 return T::node_suggestions(
469 x,
470 y,
471 NodeSuggestion::NodeInputPin(&node.data, &pin),
472 registry,
473 );
474 }
475 }
476 vec![]
477 }
478
479 pub fn suggest_node_output_pin(
480 &self,
481 x: i64,
482 y: i64,
483 id: NodeId<T>,
484 name: &str,
485 registry: &Registry,
486 ) -> Vec<ResponseSuggestionNode<T>> {
487 if let Some(node) = self.node(id) {
488 if let Some(pin) = node
489 .data
490 .node_pins_out(registry)
491 .into_iter()
492 .find(|pin| pin.name() == name)
493 {
494 return T::node_suggestions(
495 x,
496 y,
497 NodeSuggestion::NodeOutputPin(&node.data, &pin),
498 registry,
499 );
500 }
501 }
502 vec![]
503 }
504
505 pub fn node(&self, id: NodeId<T>) -> Option<&Node<T>> {
506 self.nodes.iter().find(|node| node.id == id)
507 }
508
509 pub fn node_mut(&mut self, id: NodeId<T>) -> Option<&mut Node<T>> {
510 self.nodes.iter_mut().find(|node| node.id == id)
511 }
512
513 pub fn nodes(&self) -> impl Iterator<Item = &Node<T>> {
514 self.nodes.iter()
515 }
516
517 pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<T>> {
518 self.nodes.iter_mut()
519 }
520
521 pub fn add_node(&mut self, node: Node<T>, registry: &Registry) -> Option<NodeId<T>> {
522 if node.data.node_is_start(registry)
523 && self
524 .nodes
525 .iter()
526 .any(|node| node.data.node_is_start(registry))
527 {
528 return None;
529 }
530 let id = node.id;
531 if let Some(index) = self.nodes.iter().position(|node| node.id == id) {
532 self.nodes.swap_remove(index);
533 }
534 self.nodes.push(node);
535 Some(id)
536 }
537
538 pub fn remove_node(&mut self, id: NodeId<T>, registry: &Registry) -> Option<Node<T>> {
539 if let Some(index) = self
540 .nodes
541 .iter()
542 .position(|node| node.id == id && !node.data.node_is_start(registry))
543 {
544 self.disconnect_node(id, None);
545 Some(self.nodes.swap_remove(index))
546 } else {
547 None
548 }
549 }
550
551 pub fn connect_nodes(&mut self, connection: NodeConnection<T>) {
552 if !self.connections.iter().any(|other| &connection == other) {
553 self.disconnect_node(connection.from_node, Some(&connection.from_pin));
554 self.disconnect_node(connection.to_node, Some(&connection.to_pin));
555 self.connections.push(connection);
556 }
557 }
558
559 pub fn disconnect_nodes(
560 &mut self,
561 from_node: NodeId<T>,
562 to_node: NodeId<T>,
563 from_pin: &str,
564 to_pin: &str,
565 ) {
566 if let Some(index) = self.connections.iter().position(|connection| {
567 connection.from_node == from_node
568 && connection.to_node == to_node
569 && connection.from_pin == from_pin
570 && connection.to_pin == to_pin
571 }) {
572 self.connections.swap_remove(index);
573 }
574 }
575
576 pub fn disconnect_node(&mut self, node: NodeId<T>, pin: Option<&str>) {
577 let to_remove = self
578 .connections
579 .iter()
580 .enumerate()
581 .filter_map(|(index, connection)| {
582 if let Some(pin) = pin {
583 if connection.from_node == node && connection.from_pin == pin {
584 return Some(index);
585 }
586 if connection.to_node == node && connection.to_pin == pin {
587 return Some(index);
588 }
589 } else if connection.from_node == node || connection.to_node == node {
590 return Some(index);
591 }
592 None
593 })
594 .collect::<Vec<_>>();
595 for index in to_remove.into_iter().rev() {
596 self.connections.swap_remove(index);
597 }
598 }
599
600 pub fn connections(&self) -> impl Iterator<Item = &NodeConnection<T>> {
601 self.connections.iter()
602 }
603
604 pub fn node_connections(&self, id: NodeId<T>) -> impl Iterator<Item = &NodeConnection<T>> {
605 self.connections
606 .iter()
607 .filter(move |connection| connection.from_node == id || connection.to_node == id)
608 }
609
610 pub fn node_connections_in<'a>(
611 &'a self,
612 id: NodeId<T>,
613 pin: Option<&'a str>,
614 ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
615 self.connections.iter().filter(move |connection| {
616 connection.to_node == id && pin.map(|pin| connection.to_pin == pin).unwrap_or(true)
617 })
618 }
619
620 pub fn node_connections_out<'a>(
621 &'a self,
622 id: NodeId<T>,
623 pin: Option<&'a str>,
624 ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
625 self.connections.iter().filter(move |connection| {
626 connection.from_node == id && pin.map(|pin| connection.from_pin == pin).unwrap_or(true)
627 })
628 }
629
630 pub fn node_neighbors_in<'a>(
631 &'a self,
632 id: NodeId<T>,
633 pin: Option<&'a str>,
634 ) -> impl Iterator<Item = NodeId<T>> + 'a {
635 self.node_connections_in(id, pin)
636 .map(move |connection| connection.from_node)
637 }
638
639 pub fn node_neighbors_out<'a>(
640 &'a self,
641 id: NodeId<T>,
642 pin: Option<&'a str>,
643 ) -> impl Iterator<Item = NodeId<T>> + 'a {
644 self.node_connections_out(id, pin)
645 .map(move |connection| connection.to_node)
646 }
647
648 pub fn validate(&self, registry: &Registry) -> Result<(), Vec<NodeGraphError>> {
649 let mut errors = self
650 .connections
651 .iter()
652 .filter_map(|connection| self.validate_connection(connection, registry))
653 .map(NodeGraphError::Connection)
654 .collect::<Vec<_>>();
655 if let Some(error) = self.detect_cycles() {
656 errors.push(NodeGraphError::Connection(error));
657 }
658 if errors.is_empty() {
659 Ok(())
660 } else {
661 Err(errors)
662 }
663 }
664
665 fn validate_connection(
666 &self,
667 connection: &NodeConnection<T>,
668 registry: &Registry,
669 ) -> Option<ConnectionError> {
670 if connection.from_node == connection.to_node {
671 return Some(ConnectionError::InternalConnection(
672 connection.from_node.to_string(),
673 ));
674 }
675 let from = self
676 .nodes
677 .iter()
678 .find(|node| node.id == connection.from_node);
679 let to = self.nodes.iter().find(|node| node.id == connection.to_node);
680 let (from_node, to_node) = match (from, to) {
681 (Some(from), Some(to)) => (from, to),
682 (Some(_), None) => {
683 return Some(ConnectionError::TargetNodeNotFound(
684 connection.to_node.to_string(),
685 ));
686 }
687 (None, Some(_)) => {
688 return Some(ConnectionError::SourceNodeNotFound(
689 connection.from_node.to_string(),
690 ));
691 }
692 (None, None) => {
693 return Some(ConnectionError::NodesNotFound {
694 from: connection.from_node.to_string(),
695 to: connection.to_node.to_string(),
696 });
697 }
698 };
699 let from_pins_out = from_node.data.node_pins_out(registry);
700 let from_pin = match from_pins_out
701 .iter()
702 .find(|pin| pin.name() == connection.from_pin)
703 {
704 Some(pin) => pin,
705 None => {
706 return Some(ConnectionError::SourcePinNotFound {
707 node: connection.from_node.to_string(),
708 pin: connection.from_pin.to_owned(),
709 })
710 }
711 };
712 let to_pins_in = to_node.data.node_pins_in(registry);
713 let to_pin = match to_pins_in
714 .iter()
715 .find(|pin| pin.name() == connection.to_pin)
716 {
717 Some(pin) => pin,
718 None => {
719 return Some(ConnectionError::TargetPinNotFound {
720 node: connection.to_node.to_string(),
721 pin: connection.to_pin.to_owned(),
722 })
723 }
724 };
725 match (from_pin, to_pin) {
726 (NodePin::Execute { .. }, NodePin::Execute { .. }) => {}
727 (NodePin::Parameter { type_info: a, .. }, NodePin::Parameter { type_info: b, .. }) => {
728 if !a.are_compatible(b) {
729 return Some(ConnectionError::MismatchTypes {
730 from_node: connection.from_node.to_string(),
731 from_pin: connection.from_pin.to_owned(),
732 to_node: connection.to_node.to_string(),
733 to_pin: connection.to_pin.to_owned(),
734 from_type_info: a.to_string(),
735 to_type_info: b.to_string(),
736 });
737 }
738 }
739 (NodePin::Property { .. }, NodePin::Property { .. }) => {}
740 _ => {
741 return Some(ConnectionError::MismatchPins {
742 from_node: connection.from_node.to_string(),
743 from_pin: connection.from_pin.to_owned(),
744 to_node: connection.to_node.to_string(),
745 to_pin: connection.to_pin.to_owned(),
746 });
747 }
748 }
749 if let Err(error) = to_node.data.validate_connection(&from_node.data, registry) {
750 return Some(ConnectionError::Custom(error));
751 }
752 None
753 }
754
755 fn detect_cycles(&self) -> Option<ConnectionError> {
756 let mut visited = HashSet::with_capacity(self.nodes.len());
757 let mut available = self.nodes.iter().map(|node| node.id).collect::<Vec<_>>();
758 while let Some(id) = available.first() {
759 if let Some(error) = self.detect_cycle(*id, &mut available, &mut visited) {
760 return Some(error);
761 }
762 available.swap_remove(0);
763 }
764 None
765 }
766
767 fn detect_cycle(
768 &self,
769 id: NodeId<T>,
770 available: &mut Vec<NodeId<T>>,
771 visited: &mut HashSet<NodeId<T>>,
772 ) -> Option<ConnectionError> {
773 if visited.contains(&id) {
774 return Some(ConnectionError::CycleNodeFound(id.to_string()));
775 }
776 visited.insert(id);
777 for id in self.node_neighbors_out(id, None) {
778 if let Some(index) = available.iter().position(|item| item == &id) {
779 available.swap_remove(index);
780 if let Some(error) = self.detect_cycle(id, available, visited) {
781 return Some(error);
782 }
783 }
784 }
785 None
786 }
787
788 pub fn visit<V: NodeGraphVisitor<T>>(
789 &self,
790 visitor: &mut V,
791 registry: &Registry,
792 ) -> Vec<V::Output> {
793 let starts = self
794 .nodes
795 .iter()
796 .filter(|node| node.data.node_is_start(registry))
797 .map(|node| node.id)
798 .collect::<HashSet<_>>();
799 let mut result = Vec::with_capacity(self.nodes.len());
800 for id in starts {
801 self.visit_statement(id, &mut result, visitor, registry);
802 }
803 result
804 }
805
806 fn visit_statement<V: NodeGraphVisitor<T>>(
807 &self,
808 id: NodeId<T>,
809 result: &mut Vec<V::Output>,
810 visitor: &mut V,
811 registry: &Registry,
812 ) {
813 if let Some(node) = self.node(id) {
814 let inputs = node
815 .data
816 .node_pins_in(registry)
817 .into_iter()
818 .filter(|pin| pin.is_parameter())
819 .filter_map(|pin| {
820 self.node_neighbors_in(id, Some(pin.name()))
821 .next()
822 .map(|id| (pin.name().to_owned(), id))
823 })
824 .filter_map(|(name, id)| {
825 self.visit_expression(id, visitor, registry)
826 .map(|input| (name, input))
827 })
828 .collect();
829 let pins_out = node.data.node_pins_out(registry);
830 let scopes = pins_out
831 .iter()
832 .filter(|pin| pin.has_subscope())
833 .filter_map(|pin| {
834 let id = self.node_neighbors_out(id, Some(pin.name())).next()?;
835 Some((id, pin.name().to_owned()))
836 })
837 .map(|(id, name)| {
838 let mut result = Vec::with_capacity(self.nodes.len());
839 self.visit_statement(id, &mut result, visitor, registry);
840 (name, result)
841 })
842 .collect();
843 if visitor.visit_statement(node, inputs, scopes, result) {
844 for pin in pins_out {
845 if pin.is_execute() && !pin.has_subscope() {
846 for id in self.node_neighbors_out(id, Some(pin.name())) {
847 self.visit_statement(id, result, visitor, registry);
848 }
849 }
850 }
851 }
852 }
853 }
854
855 fn visit_expression<V: NodeGraphVisitor<T>>(
856 &self,
857 id: NodeId<T>,
858 visitor: &mut V,
859 registry: &Registry,
860 ) -> Option<V::Input> {
861 if let Some(node) = self.node(id) {
862 let inputs = node
863 .data
864 .node_pins_in(registry)
865 .into_iter()
866 .filter(|pin| pin.is_parameter())
867 .filter_map(|pin| {
868 self.node_neighbors_in(id, Some(pin.name()))
869 .next()
870 .map(|id| (pin.name().to_owned(), id))
871 })
872 .filter_map(|(name, id)| {
873 self.visit_expression(id, visitor, registry)
874 .map(|input| (name, input))
875 })
876 .collect();
877 return visitor.visit_expression(node, inputs);
878 }
879 None
880 }
881}
882
883impl<T: NodeDefinition + std::fmt::Debug> std::fmt::Debug for NodeGraph<T> {
884 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
885 f.debug_struct("NodeGraph")
886 .field("nodes", &self.nodes)
887 .field("connections", &self.connections)
888 .finish()
889 }
890}
891
892pub trait NodeGraphVisitor<T: NodeDefinition> {
893 type Input;
894 type Output;
895
896 fn visit_statement(
897 &mut self,
898 node: &Node<T>,
899 inputs: HashMap<String, Self::Input>,
900 scopes: HashMap<String, Vec<Self::Output>>,
901 result: &mut Vec<Self::Output>,
902 ) -> bool;
903
904 fn visit_expression(
905 &mut self,
906 node: &Node<T>,
907 inputs: HashMap<String, Self::Input>,
908 ) -> Option<Self::Input>;
909}
910
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use intuicio_core::prelude::*;
    use std::collections::HashMap;

    /// Minimal "compiled" form the test visitor turns the graph into.
    #[derive(Debug, Clone, PartialEq)]
    enum Script {
        Literal(i32),
        Return,
        Call(String),
        Scope(Vec<Script>),
    }

    // Plain strings act as type infos: two types are compatible when equal.
    impl NodeTypeInfo for String {
        fn type_query(&self) -> TypeQuery {
            TypeQuery {
                name: Some(self.into()),
                ..Default::default()
            }
        }

        fn are_compatible(&self, other: &Self) -> bool {
            self == other
        }
    }

    /// Node kinds used by the test graph.
    #[derive(Debug, Clone)]
    enum Nodes {
        Start,
        Expression(i32),
        Result,
        Convert(String),
        Child,
    }

    impl NodeDefinition for Nodes {
        type TypeInfo = String;

        fn node_label(&self, _: &Registry) -> String {
            format!("{:?}", self)
        }

        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![],
                Nodes::Expression(_) => {
                    vec![NodePin::execute("In", false), NodePin::property("Value")]
                }
                Nodes::Result => vec![
                    NodePin::execute("In", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Convert(_) => vec![
                    NodePin::execute("In", false),
                    NodePin::property("Name"),
                    NodePin::parameter("Data in", "i32".to_owned()),
                ],
                Nodes::Child => vec![NodePin::execute("In", false)],
            }
        }

        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![NodePin::execute("Out", false)],
                Nodes::Expression(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Result => vec![],
                Nodes::Convert(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data out", "i32".to_owned()),
                ],
                Nodes::Child => vec![
                    NodePin::execute("Out", false),
                    NodePin::execute("Body", true),
                ],
            }
        }

        fn node_is_start(&self, _: &Registry) -> bool {
            matches!(self, Self::Start)
        }

        fn node_suggestions(
            _: i64,
            _: i64,
            _: NodeSuggestion<Self>,
            _: &Registry,
        ) -> Vec<ResponseSuggestionNode<Self>> {
            vec![]
        }

        fn get_property(&self, property_name: &str) -> Option<PropertyValue> {
            match self {
                Nodes::Expression(value) => match property_name {
                    "Value" => PropertyValue::new(value).ok(),
                    _ => None,
                },
                Nodes::Convert(name) => match property_name {
                    "Name" => PropertyValue::new(name).ok(),
                    _ => None,
                },
                _ => None,
            }
        }

        fn set_property(&mut self, property_name: &str, property_value: PropertyValue) {
            // Nested single-arm matches mirror the shape of `get_property`.
            #[allow(clippy::single_match)]
            match self {
                Nodes::Expression(value) => match property_name {
                    "Value" => {
                        if let Ok(v) = property_value.get_exact::<i32>() {
                            *value = v;
                        }
                    }
                    _ => {}
                },
                // Only the "Name" property is writable, matching what
                // `get_property` exposes; other names are ignored.
                Nodes::Convert(name) => match property_name {
                    "Name" => {
                        if let Ok(v) = property_value.get_exact::<String>() {
                            *name = v;
                        }
                    }
                    _ => {}
                },
                _ => {}
            }
        }
    }

    /// Visitor that turns every statement node into a `Script` item.
    struct CompileNodesToScript;

    impl NodeGraphVisitor<Nodes> for CompileNodesToScript {
        type Input = ();
        type Output = Script;

        fn visit_statement(
            &mut self,
            node: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
            mut scopes: HashMap<String, Vec<Self::Output>>,
            result: &mut Vec<Self::Output>,
        ) -> bool {
            match &node.data {
                Nodes::Result => result.push(Script::Return),
                Nodes::Convert(name) => result.push(Script::Call(name.to_owned())),
                Nodes::Child => {
                    if let Some(body) = scopes.remove("Body") {
                        result.push(Script::Scope(body));
                    }
                }
                Nodes::Expression(value) => result.push(Script::Literal(*value)),
                _ => {}
            }
            true
        }

        fn visit_expression(
            &mut self,
            _: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
        ) -> Option<Self::Input> {
            None
        }
    }

    #[test]
    fn test_nodes() {
        let registry = Registry::default().with_basic_types();
        let mut graph = NodeGraph::default();
        let start = graph
            .add_node(Node::new(0, 0, Nodes::Start), &registry)
            .unwrap();
        let expression_child = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert_child = graph
            .add_node(Node::new(0, 0, Nodes::Convert("foo".to_owned())), &registry)
            .unwrap();
        let result_child = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        let child = graph
            .add_node(Node::new(0, 0, Nodes::Child), &registry)
            .unwrap();
        let expression = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert = graph
            .add_node(Node::new(0, 0, Nodes::Convert("bar".to_owned())), &registry)
            .unwrap();
        let result = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        graph.connect_nodes(NodeConnection::new(start, child, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(child, expression_child, "Body", "In"));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Data",
            "Data in",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Data out",
            "Data",
        ));
        graph.connect_nodes(NodeConnection::new(child, expression, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Data", "Data in"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Data out", "Data"));
        graph.validate(&registry).unwrap();
        assert_eq!(
            graph.visit(&mut CompileNodesToScript, &registry),
            vec![
                Script::Scope(vec![
                    Script::Literal(42),
                    Script::Call("foo".to_owned()),
                    Script::Return
                ]),
                Script::Literal(42),
                Script::Call("bar".to_owned()),
                Script::Return
            ]
        );
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&42i32).unwrap(),
        );
        graph
            .node_mut(expression)
            .unwrap()
            .data
            .set_property("Value", PropertyValue::new(&10i32).unwrap());
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&10i32).unwrap(),
        );
    }
}