protobuf_support/
toposort.rs1use std::collections::HashSet;
2use std::hash::Hash;
3
4#[derive(Debug, thiserror::Error)]
5#[error("Cycle detected")]
6pub struct TopoSortCycle;
7
8pub fn toposort<K, I>(
9 input: impl IntoIterator<Item = K>,
10 deps: impl Fn(&K) -> I,
11) -> Result<Vec<K>, TopoSortCycle>
12where
13 K: Eq + Hash + Clone,
14 I: Iterator<Item = K>,
15{
16 struct Ts<K, D, I>
17 where
18 K: Eq + Hash + Clone,
19 I: Iterator<Item = K>,
20 D: Fn(&K) -> I,
21 {
22 result_set: HashSet<K>,
23 result: Vec<K>,
24 deps: D,
25 stack: HashSet<K>,
26 }
27
28 impl<K, D, I> Ts<K, D, I>
29 where
30 K: Eq + Hash + Clone,
31 I: Iterator<Item = K>,
32 D: Fn(&K) -> I,
33 {
34 fn visit(&mut self, i: &K) -> Result<(), TopoSortCycle> {
35 if self.result_set.contains(i) {
36 return Ok(());
37 }
38
39 if !self.stack.insert(i.clone()) {
40 return Err(TopoSortCycle);
41 }
42 for dep in (self.deps)(i) {
43 self.visit(&dep)?;
44 }
45
46 let removed = self.stack.remove(i);
47 assert!(removed);
48
49 self.result.push(i.clone());
50 self.result_set.insert(i.clone());
51
52 Ok(())
53 }
54 }
55
56 let mut ts = Ts {
57 result: Vec::new(),
58 result_set: HashSet::new(),
59 deps,
60 stack: HashSet::new(),
61 };
62
63 for i in input {
64 ts.visit(&i)?;
65 }
66
67 Ok(ts.result)
68}
69
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::toposort::toposort;
    use crate::toposort::TopoSortCycle;

    /// Parses a tiny graph notation and topologically sorts it.
    ///
    /// `input` is a space-separated list of keys; a key written as `k->a,b`
    /// depends on `a` and `b`. Keys are passed to `toposort` in the order
    /// they are written.
    fn test_toposort(input: &str) -> Result<Vec<&str>, TopoSortCycle> {
        let mut keys: Vec<&str> = Vec::new();
        let mut edges: HashMap<&str, Vec<&str>> = HashMap::new();
        for part in input.split(' ') {
            match part.split_once("->") {
                Some((k, vs)) => {
                    keys.push(k);
                    edges.insert(k, vs.split(',').collect());
                }
                None => keys.push(part),
            }
        }

        toposort(keys, |k| {
            edges
                .get(k)
                .map(|v| v.as_slice())
                .unwrap_or_default()
                .iter()
                .copied()
        })
    }

    /// Asserts that sorting `input` yields exactly the space-separated keys of `expected`.
    fn test_toposort_check(input: &str, expected: &str) {
        let sorted = test_toposort(input).unwrap();
        let expected = expected.split(' ').collect::<Vec<_>>();
        assert_eq!(expected, sorted);
    }

    #[test]
    fn test() {
        test_toposort_check("1 2 3", "1 2 3");
        test_toposort_check("1->2 2->3 3", "3 2 1");
        test_toposort_check("1 2->1 3->2", "1 2 3");
        test_toposort_check("1->2,3 2->3 3", "3 2 1");
    }

    #[test]
    fn diamond() {
        // 4 is shared by 2 and 3 but must be emitted only once, first.
        test_toposort_check("1->2,3 2->4 3->4 4", "4 2 3 1");
    }

    #[test]
    fn duplicates_in_input() {
        test_toposort_check("1 1 2->1 2", "1 2");
    }

    #[test]
    fn dependency_not_in_input() {
        // Keys reachable only through `deps` are still part of the result.
        test_toposort_check("1->2", "2 1");
    }

    #[test]
    fn cycle() {
        assert!(test_toposort("1->1").is_err());
        assert!(test_toposort("1->2 2->1").is_err());
        assert!(test_toposort("1->2 2->3 3->1").is_err());
        // A cycle that does not include the starting key.
        assert!(test_toposort("0->1 1->2 2->1").is_err());
    }
}