intuicio_nodes/
nodes.rs

1use intuicio_core::{registry::Registry, types::TypeQuery};
2use rstar::{Envelope, Point, PointDistance, RTree, RTreeObject, AABB};
3use serde::{Deserialize, Serialize};
4use serde_intermediate::{
5    de::intermediate::DeserializeMode, error::Result as IntermediateResult, Intermediate,
6};
7use std::{
8    collections::{HashMap, HashSet},
9    error::Error,
10    fmt::Display,
11    hash::{Hash, Hasher},
12};
13use typid::ID;
14
15pub type NodeId<T> = ID<Node<T>>;
16pub type PropertyCastMode = DeserializeMode;
17
18#[derive(Debug, Default, Clone, PartialEq)]
19pub struct PropertyValue {
20    value: Intermediate,
21}
22
23impl PropertyValue {
24    pub fn new<T: Serialize>(value: &T) -> IntermediateResult<Self> {
25        Ok(Self {
26            value: serde_intermediate::to_intermediate(value)?,
27        })
28    }
29
30    pub fn get<'a, T: Deserialize<'a>>(&'a self, mode: PropertyCastMode) -> IntermediateResult<T> {
31        serde_intermediate::from_intermediate_as(&self.value, mode)
32    }
33
34    pub fn get_exact<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
35        self.get(PropertyCastMode::Exact)
36    }
37
38    pub fn get_interpret<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
39        self.get(PropertyCastMode::Interpret)
40    }
41
42    pub fn into_inner(self) -> Intermediate {
43        self.value
44    }
45}
46
47pub trait NodeTypeInfo:
48    Clone + std::fmt::Debug + Display + PartialEq + Serialize + for<'de> Deserialize<'de>
49{
50    fn type_query(&self) -> TypeQuery;
51    fn are_compatible(&self, other: &Self) -> bool;
52}
53
54pub trait NodeDefinition: Sized {
55    type TypeInfo: NodeTypeInfo;
56
57    fn node_label(&self, registry: &Registry) -> String;
58    fn node_pins_in(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
59    fn node_pins_out(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
60    fn node_is_start(&self, registry: &Registry) -> bool;
61    fn node_suggestions(
62        x: i64,
63        y: i64,
64        suggestion: NodeSuggestion<Self>,
65        registry: &Registry,
66    ) -> Vec<ResponseSuggestionNode<Self>>;
67
68    #[allow(unused_variables)]
69    fn validate_connection(
70        &self,
71        source: &Self,
72        registry: &Registry,
73    ) -> Result<(), Box<dyn Error>> {
74        Ok(())
75    }
76
77    #[allow(unused_variables)]
78    fn get_property(&self, name: &str) -> Option<PropertyValue> {
79        None
80    }
81
82    #[allow(unused_variables)]
83    fn set_property(&mut self, name: &str, value: PropertyValue) {}
84}
85
86#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
87#[serde(bound = "TI: NodeTypeInfo")]
88pub enum NodePin<TI: NodeTypeInfo> {
89    Execute { name: String, subscope: bool },
90    Parameter { name: String, type_info: TI },
91    Property { name: String },
92}
93
94impl<TI: NodeTypeInfo> NodePin<TI> {
95    pub fn execute(name: impl ToString, subscope: bool) -> Self {
96        Self::Execute {
97            name: name.to_string(),
98            subscope,
99        }
100    }
101
102    pub fn parameter(name: impl ToString, type_info: TI) -> Self {
103        Self::Parameter {
104            name: name.to_string(),
105            type_info,
106        }
107    }
108
109    pub fn property(name: impl ToString) -> Self {
110        Self::Property {
111            name: name.to_string(),
112        }
113    }
114
115    pub fn is_execute(&self) -> bool {
116        matches!(self, Self::Execute { .. })
117    }
118
119    pub fn is_parameter(&self) -> bool {
120        matches!(self, Self::Parameter { .. })
121    }
122
123    pub fn is_property(&self) -> bool {
124        matches!(self, Self::Property { .. })
125    }
126
127    pub fn name(&self) -> &str {
128        match self {
129            Self::Execute { name, .. }
130            | Self::Parameter { name, .. }
131            | Self::Property { name, .. } => name,
132        }
133    }
134
135    pub fn has_subscope(&self) -> bool {
136        matches!(self, Self::Execute { subscope: true, .. })
137    }
138
139    pub fn type_info(&self) -> Option<&TI> {
140        match self {
141            Self::Parameter { type_info, .. } => Some(type_info),
142            _ => None,
143        }
144    }
145}
146
147pub enum NodeSuggestion<'a, T: NodeDefinition> {
148    All,
149    NodeInputPin(&'a T, &'a NodePin<T::TypeInfo>),
150    NodeOutputPin(&'a T, &'a NodePin<T::TypeInfo>),
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ResponseSuggestionNode<T: NodeDefinition> {
155    pub category: String,
156    pub label: String,
157    pub node: Node<T>,
158}
159
160impl<T: NodeDefinition> ResponseSuggestionNode<T> {
161    pub fn new(category: impl ToString, node: Node<T>, registry: &Registry) -> Self {
162        Self {
163            category: category.to_string(),
164            label: node.data.node_label(registry),
165            node,
166        }
167    }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct Node<T: NodeDefinition> {
172    id: NodeId<T>,
173    pub x: i64,
174    pub y: i64,
175    pub data: T,
176}
177
178impl<T: NodeDefinition> Node<T> {
179    pub fn new(x: i64, y: i64, data: T) -> Self {
180        Self {
181            id: Default::default(),
182            x,
183            y,
184            data,
185        }
186    }
187
188    pub fn id(&self) -> NodeId<T> {
189        self.id
190    }
191}
192
193#[derive(Clone, Serialize, Deserialize)]
194pub struct NodeConnection<T: NodeDefinition> {
195    pub from_node: NodeId<T>,
196    pub to_node: NodeId<T>,
197    pub from_pin: String,
198    pub to_pin: String,
199}
200
201impl<T: NodeDefinition> NodeConnection<T> {
202    pub fn new(from_node: NodeId<T>, to_node: NodeId<T>, from_pin: &str, to_pin: &str) -> Self {
203        Self {
204            from_node,
205            to_node,
206            from_pin: from_pin.to_owned(),
207            to_pin: to_pin.to_owned(),
208        }
209    }
210}
211
212impl<T: NodeDefinition> std::fmt::Debug for NodeConnection<T> {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        f.debug_struct("NodeConnection")
215            .field("from_node", &self.from_node)
216            .field("to_node", &self.to_node)
217            .field("from_pin", &self.from_pin)
218            .field("to_pin", &self.to_pin)
219            .finish()
220    }
221}
222
223impl<T: NodeDefinition> PartialEq for NodeConnection<T> {
224    fn eq(&self, other: &Self) -> bool {
225        self.from_node == other.from_node
226            && self.to_node == other.to_node
227            && self.from_pin == other.from_pin
228            && self.to_pin == other.to_pin
229    }
230}
231
232impl<T: NodeDefinition> Eq for NodeConnection<T> {}
233
234impl<T: NodeDefinition> Hash for NodeConnection<T> {
235    fn hash<H: Hasher>(&self, state: &mut H) {
236        self.from_node.hash(state);
237        self.to_node.hash(state);
238        self.from_pin.hash(state);
239        self.to_pin.hash(state);
240    }
241}
242
/// Reasons why a single node connection is invalid.
#[derive(Debug)]
pub enum ConnectionError {
    /// The connection starts and ends at the same node.
    InternalConnection(String),
    /// The source node id does not exist.
    SourceNodeNotFound(String),
    /// The target node id does not exist.
    TargetNodeNotFound(String),
    /// Neither node id exists.
    NodesNotFound {
        from: String,
        to: String,
    },
    /// The source node has no output pin with this name.
    SourcePinNotFound {
        node: String,
        pin: String,
    },
    /// The target node has no input pin with this name.
    TargetPinNotFound {
        node: String,
        pin: String,
    },
    /// Both pins are parameters but their types are not compatible.
    MismatchTypes {
        from_node: String,
        from_pin: String,
        from_type_info: String,
        to_node: String,
        to_pin: String,
        to_type_info: String,
    },
    /// The pins are of different kinds (execute / parameter / property).
    MismatchPins {
        from_node: String,
        from_pin: String,
        to_node: String,
        to_pin: String,
    },
    /// A node lies on a cycle of connections.
    CycleNodeFound(String),
    /// Error produced by a node definition's own connection validation.
    /// Displayed transparently (the wrapped message is shown as-is).
    Custom(Box<dyn Error>),
}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalConnection(node) => {
                write!(f, "Trying to connect node: {} to itself", node)
            }
            Self::SourceNodeNotFound(node) => write!(f, "Source node: {} not found", node),
            Self::TargetNodeNotFound(node) => write!(f, "Target node: {} not found", node),
            Self::NodesNotFound { from, to } => {
                write!(f, "Source: {} and target: {} nodes not found", from, to)
            }
            Self::SourcePinNotFound { node, pin } => {
                write!(f, "Source pin: {} for node: {} not found", pin, node)
            }
            Self::TargetPinNotFound { node, pin } => {
                write!(f, "Target pin: {} for node: {} not found", pin, node)
            }
            Self::MismatchTypes {
                from_node,
                from_pin,
                from_type_info,
                to_node,
                to_pin,
                to_type_info,
            } => {
                write!(
                    f,
                    "Source type: {} of pin: {} for node: {} does not match target type: {} of pin: {} for node: {}",
                    from_type_info, from_pin, from_node, to_type_info, to_pin, to_node
                )
            }
            Self::MismatchPins {
                from_node,
                from_pin,
                to_node,
                to_pin,
            } => {
                write!(
                    f,
                    "Source pin: {} kind for node: {} does not match target pin: {} kind for node: {}",
                    from_pin, from_node, to_pin, to_node
                )
            }
            Self::CycleNodeFound(node) => write!(f, "Found cycle node: {}", node),
            Self::Custom(error) => std::fmt::Display::fmt(error, f),
        }
    }
}

impl Error for ConnectionError {
    /// `Custom` already displays the wrapped error's message, so (to avoid printing
    /// it twice in an error chain) it forwards that error's *source* instead of
    /// returning the wrapped error itself. The other variants have no source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Custom(error) => (**error).source(),
            Self::InternalConnection(_)
            | Self::SourceNodeNotFound(_)
            | Self::TargetNodeNotFound(_)
            | Self::NodesNotFound { .. }
            | Self::SourcePinNotFound { .. }
            | Self::TargetPinNotFound { .. }
            | Self::MismatchTypes { .. }
            | Self::MismatchPins { .. }
            | Self::CycleNodeFound(_) => None,
        }
    }
}

/// Errors reported when validating a whole node graph.
#[derive(Debug)]
pub enum NodeGraphError {
    /// An invalid connection; displayed transparently.
    Connection(ConnectionError),
    DuplicateFunctionInputNames(String),
    DuplicateFunctionOutputNames(String),
}

impl std::fmt::Display for NodeGraphError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Connection(connection) => std::fmt::Display::fmt(connection, f),
            Self::DuplicateFunctionInputNames(name) => {
                write!(
                    f,
                    "Found duplicate `{}` function input with different types",
                    name
                )
            }
            Self::DuplicateFunctionOutputNames(name) => {
                write!(
                    f,
                    "Found duplicate `{}` function output with different types",
                    name
                )
            }
        }
    }
}

impl Error for NodeGraphError {
    /// `Connection` is displayed transparently, so it forwards the connection
    /// error's own source rather than repeating the connection error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Connection(connection) => connection.source(),
            Self::DuplicateFunctionInputNames(_) | Self::DuplicateFunctionOutputNames(_) => None,
        }
    }
}
359
360#[derive(Clone)]
361struct SpatialNode<T: NodeDefinition> {
362    id: NodeId<T>,
363    x: i64,
364    y: i64,
365}
366
367impl<T: NodeDefinition> RTreeObject for SpatialNode<T> {
368    type Envelope = AABB<[i64; 2]>;
369
370    fn envelope(&self) -> Self::Envelope {
371        AABB::from_point([self.x, self.y])
372    }
373}
374
375impl<T: NodeDefinition> PointDistance for SpatialNode<T> {
376    fn distance_2(
377        &self,
378        point: &<Self::Envelope as Envelope>::Point,
379    ) -> <<Self::Envelope as Envelope>::Point as Point>::Scalar {
380        let dx = self.x - point[0];
381        let dy = self.y - point[1];
382        dx * dx + dy * dy
383    }
384}
385
386#[derive(Clone, Serialize, Deserialize)]
387pub struct NodeGraph<T: NodeDefinition> {
388    nodes: Vec<Node<T>>,
389    connections: Vec<NodeConnection<T>>,
390    #[serde(skip, default)]
391    rtree: RTree<SpatialNode<T>>,
392}
393
394impl<T: NodeDefinition> Default for NodeGraph<T> {
395    fn default() -> Self {
396        Self {
397            nodes: vec![],
398            connections: vec![],
399            rtree: Default::default(),
400        }
401    }
402}
403
404impl<T: NodeDefinition> NodeGraph<T> {
405    pub fn clear(&mut self) {
406        self.nodes.clear();
407        self.connections.clear();
408    }
409
410    pub fn refresh_spatial_cache(&mut self) {
411        self.rtree = RTree::bulk_load(
412            self.nodes
413                .iter()
414                .map(|node| SpatialNode {
415                    id: node.id,
416                    x: node.x,
417                    y: node.y,
418                })
419                .collect(),
420        );
421    }
422
423    pub fn query_nearest_nodes(&self, x: i64, y: i64) -> impl Iterator<Item = NodeId<T>> + '_ {
424        self.rtree
425            .nearest_neighbor_iter(&[x, y])
426            .map(|node| node.id)
427    }
428
429    pub fn query_region_nodes(
430        &self,
431        fx: i64,
432        fy: i64,
433        tx: i64,
434        ty: i64,
435        extrude: i64,
436    ) -> impl Iterator<Item = NodeId<T>> + '_ {
437        self.rtree
438            .locate_in_envelope(&AABB::from_corners(
439                [fx - extrude, fy - extrude],
440                [tx - extrude, ty - extrude],
441            ))
442            .map(|node| node.id)
443    }
444
445    pub fn suggest_all_nodes(
446        x: i64,
447        y: i64,
448        registry: &Registry,
449    ) -> Vec<ResponseSuggestionNode<T>> {
450        T::node_suggestions(x, y, NodeSuggestion::All, registry)
451    }
452
453    pub fn suggest_node_input_pin(
454        &self,
455        x: i64,
456        y: i64,
457        id: NodeId<T>,
458        name: &str,
459        registry: &Registry,
460    ) -> Vec<ResponseSuggestionNode<T>> {
461        if let Some(node) = self.node(id) {
462            if let Some(pin) = node
463                .data
464                .node_pins_in(registry)
465                .into_iter()
466                .find(|pin| pin.name() == name)
467            {
468                return T::node_suggestions(
469                    x,
470                    y,
471                    NodeSuggestion::NodeInputPin(&node.data, &pin),
472                    registry,
473                );
474            }
475        }
476        vec![]
477    }
478
479    pub fn suggest_node_output_pin(
480        &self,
481        x: i64,
482        y: i64,
483        id: NodeId<T>,
484        name: &str,
485        registry: &Registry,
486    ) -> Vec<ResponseSuggestionNode<T>> {
487        if let Some(node) = self.node(id) {
488            if let Some(pin) = node
489                .data
490                .node_pins_out(registry)
491                .into_iter()
492                .find(|pin| pin.name() == name)
493            {
494                return T::node_suggestions(
495                    x,
496                    y,
497                    NodeSuggestion::NodeOutputPin(&node.data, &pin),
498                    registry,
499                );
500            }
501        }
502        vec![]
503    }
504
505    pub fn node(&self, id: NodeId<T>) -> Option<&Node<T>> {
506        self.nodes.iter().find(|node| node.id == id)
507    }
508
509    pub fn node_mut(&mut self, id: NodeId<T>) -> Option<&mut Node<T>> {
510        self.nodes.iter_mut().find(|node| node.id == id)
511    }
512
513    pub fn nodes(&self) -> impl Iterator<Item = &Node<T>> {
514        self.nodes.iter()
515    }
516
517    pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<T>> {
518        self.nodes.iter_mut()
519    }
520
521    pub fn add_node(&mut self, node: Node<T>, registry: &Registry) -> Option<NodeId<T>> {
522        if node.data.node_is_start(registry)
523            && self
524                .nodes
525                .iter()
526                .any(|node| node.data.node_is_start(registry))
527        {
528            return None;
529        }
530        let id = node.id;
531        if let Some(index) = self.nodes.iter().position(|node| node.id == id) {
532            self.nodes.swap_remove(index);
533        }
534        self.nodes.push(node);
535        Some(id)
536    }
537
538    pub fn remove_node(&mut self, id: NodeId<T>, registry: &Registry) -> Option<Node<T>> {
539        if let Some(index) = self
540            .nodes
541            .iter()
542            .position(|node| node.id == id && !node.data.node_is_start(registry))
543        {
544            self.disconnect_node(id, None);
545            Some(self.nodes.swap_remove(index))
546        } else {
547            None
548        }
549    }
550
551    pub fn connect_nodes(&mut self, connection: NodeConnection<T>) {
552        if !self.connections.iter().any(|other| &connection == other) {
553            self.disconnect_node(connection.from_node, Some(&connection.from_pin));
554            self.disconnect_node(connection.to_node, Some(&connection.to_pin));
555            self.connections.push(connection);
556        }
557    }
558
559    pub fn disconnect_nodes(
560        &mut self,
561        from_node: NodeId<T>,
562        to_node: NodeId<T>,
563        from_pin: &str,
564        to_pin: &str,
565    ) {
566        if let Some(index) = self.connections.iter().position(|connection| {
567            connection.from_node == from_node
568                && connection.to_node == to_node
569                && connection.from_pin == from_pin
570                && connection.to_pin == to_pin
571        }) {
572            self.connections.swap_remove(index);
573        }
574    }
575
576    pub fn disconnect_node(&mut self, node: NodeId<T>, pin: Option<&str>) {
577        let to_remove = self
578            .connections
579            .iter()
580            .enumerate()
581            .filter_map(|(index, connection)| {
582                if let Some(pin) = pin {
583                    if connection.from_node == node && connection.from_pin == pin {
584                        return Some(index);
585                    }
586                    if connection.to_node == node && connection.to_pin == pin {
587                        return Some(index);
588                    }
589                } else if connection.from_node == node || connection.to_node == node {
590                    return Some(index);
591                }
592                None
593            })
594            .collect::<Vec<_>>();
595        for index in to_remove.into_iter().rev() {
596            self.connections.swap_remove(index);
597        }
598    }
599
600    pub fn connections(&self) -> impl Iterator<Item = &NodeConnection<T>> {
601        self.connections.iter()
602    }
603
604    pub fn node_connections(&self, id: NodeId<T>) -> impl Iterator<Item = &NodeConnection<T>> {
605        self.connections
606            .iter()
607            .filter(move |connection| connection.from_node == id || connection.to_node == id)
608    }
609
610    pub fn node_connections_in<'a>(
611        &'a self,
612        id: NodeId<T>,
613        pin: Option<&'a str>,
614    ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
615        self.connections.iter().filter(move |connection| {
616            connection.to_node == id && pin.map(|pin| connection.to_pin == pin).unwrap_or(true)
617        })
618    }
619
620    pub fn node_connections_out<'a>(
621        &'a self,
622        id: NodeId<T>,
623        pin: Option<&'a str>,
624    ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
625        self.connections.iter().filter(move |connection| {
626            connection.from_node == id && pin.map(|pin| connection.from_pin == pin).unwrap_or(true)
627        })
628    }
629
630    pub fn node_neighbors_in<'a>(
631        &'a self,
632        id: NodeId<T>,
633        pin: Option<&'a str>,
634    ) -> impl Iterator<Item = NodeId<T>> + 'a {
635        self.node_connections_in(id, pin)
636            .map(move |connection| connection.from_node)
637    }
638
639    pub fn node_neighbors_out<'a>(
640        &'a self,
641        id: NodeId<T>,
642        pin: Option<&'a str>,
643    ) -> impl Iterator<Item = NodeId<T>> + 'a {
644        self.node_connections_out(id, pin)
645            .map(move |connection| connection.to_node)
646    }
647
648    pub fn validate(&self, registry: &Registry) -> Result<(), Vec<NodeGraphError>> {
649        let mut errors = self
650            .connections
651            .iter()
652            .filter_map(|connection| self.validate_connection(connection, registry))
653            .map(NodeGraphError::Connection)
654            .collect::<Vec<_>>();
655        if let Some(error) = self.detect_cycles() {
656            errors.push(NodeGraphError::Connection(error));
657        }
658        if errors.is_empty() {
659            Ok(())
660        } else {
661            Err(errors)
662        }
663    }
664
665    fn validate_connection(
666        &self,
667        connection: &NodeConnection<T>,
668        registry: &Registry,
669    ) -> Option<ConnectionError> {
670        if connection.from_node == connection.to_node {
671            return Some(ConnectionError::InternalConnection(
672                connection.from_node.to_string(),
673            ));
674        }
675        let from = self
676            .nodes
677            .iter()
678            .find(|node| node.id == connection.from_node);
679        let to = self.nodes.iter().find(|node| node.id == connection.to_node);
680        let (from_node, to_node) = match (from, to) {
681            (Some(from), Some(to)) => (from, to),
682            (Some(_), None) => {
683                return Some(ConnectionError::TargetNodeNotFound(
684                    connection.to_node.to_string(),
685                ));
686            }
687            (None, Some(_)) => {
688                return Some(ConnectionError::SourceNodeNotFound(
689                    connection.from_node.to_string(),
690                ));
691            }
692            (None, None) => {
693                return Some(ConnectionError::NodesNotFound {
694                    from: connection.from_node.to_string(),
695                    to: connection.to_node.to_string(),
696                });
697            }
698        };
699        let from_pins_out = from_node.data.node_pins_out(registry);
700        let from_pin = match from_pins_out
701            .iter()
702            .find(|pin| pin.name() == connection.from_pin)
703        {
704            Some(pin) => pin,
705            None => {
706                return Some(ConnectionError::SourcePinNotFound {
707                    node: connection.from_node.to_string(),
708                    pin: connection.from_pin.to_owned(),
709                })
710            }
711        };
712        let to_pins_in = to_node.data.node_pins_in(registry);
713        let to_pin = match to_pins_in
714            .iter()
715            .find(|pin| pin.name() == connection.to_pin)
716        {
717            Some(pin) => pin,
718            None => {
719                return Some(ConnectionError::TargetPinNotFound {
720                    node: connection.to_node.to_string(),
721                    pin: connection.to_pin.to_owned(),
722                })
723            }
724        };
725        match (from_pin, to_pin) {
726            (NodePin::Execute { .. }, NodePin::Execute { .. }) => {}
727            (NodePin::Parameter { type_info: a, .. }, NodePin::Parameter { type_info: b, .. }) => {
728                if !a.are_compatible(b) {
729                    return Some(ConnectionError::MismatchTypes {
730                        from_node: connection.from_node.to_string(),
731                        from_pin: connection.from_pin.to_owned(),
732                        to_node: connection.to_node.to_string(),
733                        to_pin: connection.to_pin.to_owned(),
734                        from_type_info: a.to_string(),
735                        to_type_info: b.to_string(),
736                    });
737                }
738            }
739            (NodePin::Property { .. }, NodePin::Property { .. }) => {}
740            _ => {
741                return Some(ConnectionError::MismatchPins {
742                    from_node: connection.from_node.to_string(),
743                    from_pin: connection.from_pin.to_owned(),
744                    to_node: connection.to_node.to_string(),
745                    to_pin: connection.to_pin.to_owned(),
746                });
747            }
748        }
749        if let Err(error) = to_node.data.validate_connection(&from_node.data, registry) {
750            return Some(ConnectionError::Custom(error));
751        }
752        None
753    }
754
755    fn detect_cycles(&self) -> Option<ConnectionError> {
756        let mut visited = HashSet::with_capacity(self.nodes.len());
757        let mut available = self.nodes.iter().map(|node| node.id).collect::<Vec<_>>();
758        while let Some(id) = available.first() {
759            if let Some(error) = self.detect_cycle(*id, &mut available, &mut visited) {
760                return Some(error);
761            }
762            available.swap_remove(0);
763        }
764        None
765    }
766
767    fn detect_cycle(
768        &self,
769        id: NodeId<T>,
770        available: &mut Vec<NodeId<T>>,
771        visited: &mut HashSet<NodeId<T>>,
772    ) -> Option<ConnectionError> {
773        if visited.contains(&id) {
774            return Some(ConnectionError::CycleNodeFound(id.to_string()));
775        }
776        visited.insert(id);
777        for id in self.node_neighbors_out(id, None) {
778            if let Some(index) = available.iter().position(|item| item == &id) {
779                available.swap_remove(index);
780                if let Some(error) = self.detect_cycle(id, available, visited) {
781                    return Some(error);
782                }
783            }
784        }
785        None
786    }
787
788    pub fn visit<V: NodeGraphVisitor<T>>(
789        &self,
790        visitor: &mut V,
791        registry: &Registry,
792    ) -> Vec<V::Output> {
793        let starts = self
794            .nodes
795            .iter()
796            .filter(|node| node.data.node_is_start(registry))
797            .map(|node| node.id)
798            .collect::<HashSet<_>>();
799        let mut result = Vec::with_capacity(self.nodes.len());
800        for id in starts {
801            self.visit_statement(id, &mut result, visitor, registry);
802        }
803        result
804    }
805
806    fn visit_statement<V: NodeGraphVisitor<T>>(
807        &self,
808        id: NodeId<T>,
809        result: &mut Vec<V::Output>,
810        visitor: &mut V,
811        registry: &Registry,
812    ) {
813        if let Some(node) = self.node(id) {
814            let inputs = node
815                .data
816                .node_pins_in(registry)
817                .into_iter()
818                .filter(|pin| pin.is_parameter())
819                .filter_map(|pin| {
820                    self.node_neighbors_in(id, Some(pin.name()))
821                        .next()
822                        .map(|id| (pin.name().to_owned(), id))
823                })
824                .filter_map(|(name, id)| {
825                    self.visit_expression(id, visitor, registry)
826                        .map(|input| (name, input))
827                })
828                .collect();
829            let pins_out = node.data.node_pins_out(registry);
830            let scopes = pins_out
831                .iter()
832                .filter(|pin| pin.has_subscope())
833                .filter_map(|pin| {
834                    let id = self.node_neighbors_out(id, Some(pin.name())).next()?;
835                    Some((id, pin.name().to_owned()))
836                })
837                .map(|(id, name)| {
838                    let mut result = Vec::with_capacity(self.nodes.len());
839                    self.visit_statement(id, &mut result, visitor, registry);
840                    (name, result)
841                })
842                .collect();
843            if visitor.visit_statement(node, inputs, scopes, result) {
844                for pin in pins_out {
845                    if pin.is_execute() && !pin.has_subscope() {
846                        for id in self.node_neighbors_out(id, Some(pin.name())) {
847                            self.visit_statement(id, result, visitor, registry);
848                        }
849                    }
850                }
851            }
852        }
853    }
854
855    fn visit_expression<V: NodeGraphVisitor<T>>(
856        &self,
857        id: NodeId<T>,
858        visitor: &mut V,
859        registry: &Registry,
860    ) -> Option<V::Input> {
861        if let Some(node) = self.node(id) {
862            let inputs = node
863                .data
864                .node_pins_in(registry)
865                .into_iter()
866                .filter(|pin| pin.is_parameter())
867                .filter_map(|pin| {
868                    self.node_neighbors_in(id, Some(pin.name()))
869                        .next()
870                        .map(|id| (pin.name().to_owned(), id))
871                })
872                .filter_map(|(name, id)| {
873                    self.visit_expression(id, visitor, registry)
874                        .map(|input| (name, input))
875                })
876                .collect();
877            return visitor.visit_expression(node, inputs);
878        }
879        None
880    }
881}
882
883impl<T: NodeDefinition + std::fmt::Debug> std::fmt::Debug for NodeGraph<T> {
884    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
885        f.debug_struct("NodeGraph")
886            .field("nodes", &self.nodes)
887            .field("connections", &self.connections)
888            .finish()
889    }
890}
891
892pub trait NodeGraphVisitor<T: NodeDefinition> {
893    type Input;
894    type Output;
895
896    fn visit_statement(
897        &mut self,
898        node: &Node<T>,
899        inputs: HashMap<String, Self::Input>,
900        scopes: HashMap<String, Vec<Self::Output>>,
901        result: &mut Vec<Self::Output>,
902    ) -> bool;
903
904    fn visit_expression(
905        &mut self,
906        node: &Node<T>,
907        inputs: HashMap<String, Self::Input>,
908    ) -> Option<Self::Input>;
909}
910
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use intuicio_core::prelude::*;
    use std::collections::HashMap;

    // Tiny target language the test graph is compiled into.
    #[derive(Debug, Clone, PartialEq)]
    enum Script {
        Literal(i32),
        Return,
        Call(String),
        Scope(Vec<Script>),
    }

    // Test-only type info: pin types are plain type-name strings, and two pins are
    // compatible only when their names are exactly equal.
    impl NodeTypeInfo for String {
        fn type_query(&self) -> TypeQuery {
            TypeQuery {
                name: Some(self.into()),
                ..Default::default()
            }
        }

        fn are_compatible(&self, other: &Self) -> bool {
            self == other
        }
    }

    // Node kinds used by the test graph.
    #[derive(Debug, Clone)]
    enum Nodes {
        Start,
        Expression(i32),
        Result,
        Convert(String),
        Child,
    }

    impl NodeDefinition for Nodes {
        type TypeInfo = String;

        fn node_label(&self, _: &Registry) -> String {
            format!("{:?}", self)
        }

        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![],
                Nodes::Expression(_) => {
                    vec![NodePin::execute("In", false), NodePin::property("Value")]
                }
                Nodes::Result => vec![
                    NodePin::execute("In", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Convert(_) => vec![
                    NodePin::execute("In", false),
                    NodePin::property("Name"),
                    NodePin::parameter("Data in", "i32".to_owned()),
                ],
                Nodes::Child => vec![NodePin::execute("In", false)],
            }
        }

        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![NodePin::execute("Out", false)],
                Nodes::Expression(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Result => vec![],
                Nodes::Convert(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data out", "i32".to_owned()),
                ],
                // "Body" is the execute pin flagged `true`; the visitor below receives its
                // lowered contents through the `scopes` map.
                Nodes::Child => vec![
                    NodePin::execute("Out", false),
                    NodePin::execute("Body", true),
                ],
            }
        }

        fn node_is_start(&self, _: &Registry) -> bool {
            matches!(self, Self::Start)
        }

        fn node_suggestions(
            _: i64,
            _: i64,
            _: NodeSuggestion<Self>,
            _: &Registry,
        ) -> Vec<ResponseSuggestionNode<Self>> {
            vec![]
        }

        fn get_property(&self, property_name: &str) -> Option<PropertyValue> {
            match self {
                Nodes::Expression(value) => match property_name {
                    "Value" => PropertyValue::new(value).ok(),
                    _ => None,
                },
                Nodes::Convert(name) => match property_name {
                    "Name" => PropertyValue::new(name).ok(),
                    _ => None,
                },
                _ => None,
            }
        }

        // Mirrors `get_property`: each node only accepts writes to the property pin it
        // declares ("Value" for Expression, "Name" for Convert). Writes to any other
        // name, or values that fail the exact cast, are ignored.
        fn set_property(&mut self, property_name: &str, property_value: PropertyValue) {
            match (self, property_name) {
                (Nodes::Expression(value), "Value") => {
                    if let Ok(v) = property_value.get_exact::<i32>() {
                        *value = v;
                    }
                }
                (Nodes::Convert(name), "Name") => {
                    if let Ok(v) = property_value.get_exact::<String>() {
                        *name = v;
                    }
                }
                _ => {}
            }
        }
    }

    // Visitor that lowers a node graph into a flat `Script` list, nesting scope bodies.
    struct CompileNodesToScript;

    impl NodeGraphVisitor<Nodes> for CompileNodesToScript {
        type Input = ();
        type Output = Script;

        fn visit_statement(
            &mut self,
            node: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
            mut scopes: HashMap<String, Vec<Self::Output>>,
            result: &mut Vec<Self::Output>,
        ) -> bool {
            match &node.data {
                Nodes::Result => result.push(Script::Return),
                Nodes::Convert(name) => result.push(Script::Call(name.to_owned())),
                Nodes::Child => {
                    if let Some(body) = scopes.remove("Body") {
                        result.push(Script::Scope(body));
                    }
                }
                Nodes::Expression(value) => result.push(Script::Literal(*value)),
                _ => {}
            }
            true
        }

        fn visit_expression(
            &mut self,
            _: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
        ) -> Option<Self::Input> {
            None
        }
    }

    #[test]
    fn test_nodes() {
        let registry = Registry::default().with_basic_types();
        let mut graph = NodeGraph::default();
        let start = graph
            .add_node(Node::new(0, 0, Nodes::Start), &registry)
            .unwrap();
        let expression_child = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert_child = graph
            .add_node(Node::new(0, 0, Nodes::Convert("foo".to_owned())), &registry)
            .unwrap();
        let result_child = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        let child = graph
            .add_node(Node::new(0, 0, Nodes::Child), &registry)
            .unwrap();
        let expression = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert = graph
            .add_node(Node::new(0, 0, Nodes::Convert("bar".to_owned())), &registry)
            .unwrap();
        let result = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        graph.connect_nodes(NodeConnection::new(start, child, "Out", "In"));
        // Nested scope: Child.Body -> Expression -> Convert -> Result.
        graph.connect_nodes(NodeConnection::new(child, expression_child, "Body", "In"));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Data",
            "Data in",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Data out",
            "Data",
        ));
        // Main flow after the scope: Child.Out -> Expression -> Convert -> Result.
        graph.connect_nodes(NodeConnection::new(child, expression, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Data", "Data in"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Data out", "Data"));
        graph.validate(&registry).unwrap();
        assert_eq!(
            graph.visit(&mut CompileNodesToScript, &registry),
            vec![
                Script::Scope(vec![
                    Script::Literal(42),
                    Script::Call("foo".to_owned()),
                    Script::Return
                ]),
                Script::Literal(42),
                Script::Call("bar".to_owned()),
                Script::Return
            ]
        );

        // Expression "Value" property round-trips through get/set.
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&42i32).unwrap(),
        );
        graph
            .node_mut(expression)
            .unwrap()
            .data
            .set_property("Value", PropertyValue::new(&10i32).unwrap());
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&10i32).unwrap(),
        );

        // Convert "Name" property round-trips through get/set.
        graph
            .node_mut(convert)
            .unwrap()
            .data
            .set_property("Name", PropertyValue::new(&"baz".to_owned()).unwrap());
        assert_eq!(
            graph
                .node(convert)
                .unwrap()
                .data
                .get_property("Name")
                .unwrap(),
            PropertyValue::new(&"baz".to_owned()).unwrap(),
        );

        // Writes to property names a node does not declare must be ignored.
        graph
            .node_mut(convert)
            .unwrap()
            .data
            .set_property("Value", PropertyValue::new(&"qux".to_owned()).unwrap());
        assert_eq!(
            graph
                .node(convert)
                .unwrap()
                .data
                .get_property("Name")
                .unwrap(),
            PropertyValue::new(&"baz".to_owned()).unwrap(),
        );
        graph
            .node_mut(expression)
            .unwrap()
            .data
            .set_property("Name", PropertyValue::new(&7i32).unwrap());
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&10i32).unwrap(),
        );
    }
}