protobuf_support/
toposort.rs

1use std::collections::HashSet;
2use std::hash::Hash;
3
4#[derive(Debug, thiserror::Error)]
5#[error("Cycle detected")]
6pub struct TopoSortCycle;
7
8pub fn toposort<K, I>(
9    input: impl IntoIterator<Item = K>,
10    deps: impl Fn(&K) -> I,
11) -> Result<Vec<K>, TopoSortCycle>
12where
13    K: Eq + Hash + Clone,
14    I: Iterator<Item = K>,
15{
16    struct Ts<K, D, I>
17    where
18        K: Eq + Hash + Clone,
19        I: Iterator<Item = K>,
20        D: Fn(&K) -> I,
21    {
22        result_set: HashSet<K>,
23        result: Vec<K>,
24        deps: D,
25        stack: HashSet<K>,
26    }
27
28    impl<K, D, I> Ts<K, D, I>
29    where
30        K: Eq + Hash + Clone,
31        I: Iterator<Item = K>,
32        D: Fn(&K) -> I,
33    {
34        fn visit(&mut self, i: &K) -> Result<(), TopoSortCycle> {
35            if self.result_set.contains(i) {
36                return Ok(());
37            }
38
39            if !self.stack.insert(i.clone()) {
40                return Err(TopoSortCycle);
41            }
42            for dep in (self.deps)(i) {
43                self.visit(&dep)?;
44            }
45
46            let removed = self.stack.remove(i);
47            assert!(removed);
48
49            self.result.push(i.clone());
50            self.result_set.insert(i.clone());
51
52            Ok(())
53        }
54    }
55
56    let mut ts = Ts {
57        result: Vec::new(),
58        result_set: HashSet::new(),
59        deps,
60        stack: HashSet::new(),
61    };
62
63    for i in input {
64        ts.visit(&i)?;
65    }
66
67    Ok(ts.result)
68}
69
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::toposort::toposort;
    use crate::toposort::TopoSortCycle;

    /// Parses a tiny graph description and topologically sorts it.
    ///
    /// `input` is a space-separated list of entries. Each entry is either a
    /// bare key (`a`) or a key with a comma-separated dependency list
    /// (`a->b,c`). Keys are passed to `toposort` in the order they appear.
    fn test_toposort(input: &str) -> Result<Vec<&str>, TopoSortCycle> {
        let mut keys: Vec<&str> = Vec::new();
        let mut edges: HashMap<&str, Vec<&str>> = HashMap::new();
        for part in input.split(' ') {
            match part.split_once("->") {
                Some((k, vs)) => {
                    keys.push(k);
                    edges.insert(k, vs.split(',').collect());
                }
                None => keys.push(part),
            }
        }

        toposort(keys, |k| {
            edges
                .get(k)
                .map(|v| v.as_slice())
                .unwrap_or_default()
                .iter()
                .copied()
        })
    }

    /// Asserts that sorting `input` succeeds and yields the space-separated
    /// keys in `expected`, in that exact order.
    fn test_toposort_check(input: &str, expected: &str) {
        let sorted = test_toposort(input).unwrap();
        let expected = expected.split(' ').collect::<Vec<_>>();
        assert_eq!(expected, sorted);
    }

    #[test]
    fn test() {
        test_toposort_check("1 2 3", "1 2 3");
        test_toposort_check("1->2 2->3 3", "3 2 1");
        test_toposort_check("1 2->1 3->2", "1 2 3");
        test_toposort_check("1->2,3 2->3 3", "3 2 1");
    }

    #[test]
    fn diamond() {
        // 4 is shared by 2 and 3 but must be emitted only once, first.
        test_toposort_check("1->2,3 2->4 3->4 4", "4 2 3 1");
    }

    #[test]
    fn dependencies_not_listed_in_input_are_included() {
        test_toposort_check("1->2", "2 1");
    }

    #[test]
    fn duplicate_keys_are_emitted_once() {
        test_toposort_check("1 1 2->1 2", "1 2");
    }

    #[test]
    fn cycle() {
        assert!(test_toposort("1->1").is_err());
        assert!(test_toposort("1->2 2->1").is_err());
        assert!(test_toposort("1->2 2->3 3->1").is_err());
        // A cycle is reported even when it is only reachable from a later key.
        assert!(test_toposort("1 2->3 3->2").is_err());
    }
}