1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::collections::HashSet;
use std::hash::Hash;
#[derive(Debug, thiserror::Error)]
#[error("Cycle detected")]
pub struct TopoSortCycle;
/// Topologically sorts `input` together with every node reachable from it through `deps`.
///
/// `deps(k)` yields the nodes that `k` depends on. The returned vector contains each
/// reachable node exactly once, and every node appears after all of its dependencies.
/// Nodes are emitted in depth-first post-order: roots are explored in `input` order and
/// dependencies in the order `deps` yields them. Nodes reachable only as dependencies
/// (not listed in `input`) are included as well. `deps` is called at most once per
/// distinct node.
///
/// The traversal keeps its DFS state in a heap-allocated stack instead of recursing,
/// so very long dependency chains cannot overflow the thread's call stack.
///
/// # Errors
///
/// Returns [`TopoSortCycle`] if a dependency cycle (including a self-dependency) is
/// reachable from any node in `input`.
pub fn toposort<K, I>(
    input: impl IntoIterator<Item = K>,
    deps: impl Fn(&K) -> I,
) -> Result<Vec<K>, TopoSortCycle>
where
    K: Eq + Hash + Clone,
    I: Iterator<Item = K>,
{
    // Nodes already emitted to `result`; meeting one again is a no-op.
    let mut done: HashSet<K> = HashSet::new();
    // Nodes on the current DFS path; meeting one of these again means a cycle.
    let mut on_path: HashSet<K> = HashSet::new();
    let mut result: Vec<K> = Vec::new();
    // Explicit DFS stack: each frame is a node plus the not-yet-visited rest of its deps.
    let mut frames: Vec<(K, I)> = Vec::new();

    for root in input {
        if done.contains(&root) {
            continue;
        }
        // Every frame is popped (and removed from `on_path`) before the next root starts,
        // so the path is empty here and this insert cannot signal a cycle.
        on_path.insert(root.clone());
        let root_deps = deps(&root);
        frames.push((root, root_deps));

        while let Some((_, pending)) = frames.last_mut() {
            match pending.next() {
                Some(dep) => {
                    if done.contains(&dep) {
                        continue;
                    }
                    if !on_path.insert(dep.clone()) {
                        return Err(TopoSortCycle);
                    }
                    let dep_deps = deps(&dep);
                    frames.push((dep, dep_deps));
                }
                None => {
                    // All dependencies of the top node are emitted; now emit the node.
                    let (node, _) = frames
                        .pop()
                        .expect("frames is non-empty: last_mut just returned Some");
                    let removed = on_path.remove(&node);
                    assert!(removed);
                    done.insert(node.clone());
                    result.push(node);
                }
            }
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::toposort;
    use super::TopoSortCycle;

    /// Parses a graph spec such as `"1->2,3 2->3 3"` and topologically sorts it.
    ///
    /// Space-separated tokens are root nodes, in input order; a token `k->a,b` also
    /// declares that `k` depends on `a` and `b`. Dependencies need not appear as tokens.
    fn test_toposort(input: &str) -> Result<Vec<&str>, TopoSortCycle> {
        let mut keys: Vec<&str> = Vec::new();
        let mut edges: HashMap<&str, Vec<&str>> = HashMap::new();
        for part in input.split(' ') {
            match part.split_once("->") {
                Some((k, vs)) => {
                    keys.push(k);
                    edges.insert(k, vs.split(',').collect());
                }
                None => keys.push(part),
            }
        }
        toposort(keys, |k| {
            edges
                .get(k)
                .map(|v| v.as_slice())
                .unwrap_or_default()
                .iter()
                .copied()
        })
    }

    /// Asserts that `input` sorts without error into the space-separated `expected` order.
    fn test_toposort_check(input: &str, expected: &str) {
        let sorted = test_toposort(input).unwrap();
        let expected = expected.split(' ').collect::<Vec<_>>();
        assert_eq!(expected, sorted);
    }

    #[test]
    fn test() {
        test_toposort_check("1 2 3", "1 2 3");
        test_toposort_check("1->2 2->3 3", "3 2 1");
        test_toposort_check("1 2->1 3->2", "1 2 3");
        test_toposort_check("1->2,3 2->3 3", "3 2 1");
    }

    #[test]
    fn shared_dependency_emitted_once() {
        test_toposort_check("1->2,3 2->4 3->4 4", "4 2 3 1");
    }

    #[test]
    fn duplicate_roots_emitted_once() {
        test_toposort_check("1 1 2 1", "1 2");
    }

    #[test]
    fn unlisted_dependency_included() {
        test_toposort_check("1->2", "2 1");
    }

    #[test]
    fn cycle() {
        assert!(test_toposort("1->1").is_err());
        assert!(test_toposort("1->2 2->1").is_err());
        assert!(test_toposort("1->2 2->3 3->1").is_err());
        // The cycle is only reachable from a root that is not first in the input.
        assert!(test_toposort("1 2->3 3->2").is_err());
    }
}