cranelift_isle/
disjointsets.rs

1//! Implementation of [`DisjointSets`], to store disjoint sets and provide efficient operations to
2//! merge sets
3
4use std::collections::HashMap;
5use std::hash::Hash;
6
/// A union-find forest over items of type `T`.
///
/// Supports merging the sets that contain two items, and looking up a representative member of
/// the set that contains any given item. Items only become tracked when they are passed to
/// `merge`, which refuses to merge an item with itself, so every set stored here has at least
/// two members.
#[derive(Clone, Debug, Default)]
pub struct DisjointSets<T> {
    /// For every tracked item, its `(parent, rank)` pair. The representative of a set is the one
    /// item that is its own parent; `merge` compares ranks to decide which representative
    /// survives when two sets are joined.
    parent: HashMap<T, (T, u8)>,
}

impl<T: Copy + std::fmt::Debug + Eq + Hash> DisjointSets<T> {
    /// Look up the representative of the set that `x` belongs to, or `None` when `x` has never
    /// been passed to `merge`. While walking towards the representative, every visited item is
    /// re-pointed at its grandparent, which shortens the chain for later queries and keeps the
    /// amortized cost of a lookup near-constant.
    ///
    /// ```
    /// let mut sets = cranelift_isle::disjointsets::DisjointSets::default();
    /// sets.merge(1, 2);
    /// sets.merge(1, 3);
    /// sets.merge(2, 4);
    /// assert_eq!(sets.find_mut(3).unwrap(), sets.find_mut(4).unwrap());
    /// assert_eq!(sets.find_mut(10), None);
    /// ```
    pub fn find_mut(&mut self, mut x: T) -> Option<T> {
        loop {
            // An untracked item belongs to no set at all.
            let up = self.parent.get(&x)?.0;
            if up == x {
                return Some(x);
            }
            // Every parent is itself a key of the map, so indexing cannot fail here.
            let skip = self.parent[&up].0;
            // Second lookup of `x`, this time mutable, to splice out the intermediate parent.
            if let Some(entry) = self.parent.get_mut(&x) {
                entry.0 = skip;
            }
            x = skip;
        }
    }

    /// Look up the representative of the set that `x` belongs to, or `None` when `x` has never
    /// been passed to `merge`. Unlike `find_mut`, this leaves the structure untouched, so it does
    /// nothing to speed up later queries; prefer `find_mut` whenever a mutable borrow is
    /// available.
    ///
    /// ```
    /// let mut sets = cranelift_isle::disjointsets::DisjointSets::default();
    /// sets.merge(1, 2);
    /// sets.merge(1, 3);
    /// sets.merge(2, 4);
    /// assert_eq!(sets.find(3).unwrap(), sets.find(4).unwrap());
    /// assert_eq!(sets.find(10), None);
    /// ```
    pub fn find(&self, mut x: T) -> Option<T> {
        loop {
            match self.parent.get(&x) {
                None => return None,
                // An item that is its own parent is the representative.
                Some(&(up, _)) if up == x => return Some(x),
                Some(&(up, _)) => x = up,
            }
        }
    }

    /// Return the `(representative, rank)` entry of the set containing `item`, first registering
    /// `item` as its own rank-zero representative if it was not tracked yet.
    fn root_entry_or_insert(&mut self, item: T) -> (T, u8) {
        match self.find_mut(item) {
            Some(root) => self.parent[&root],
            None => {
                let fresh = (item, 0);
                self.parent.insert(item, fresh);
                fresh
            }
        }
    }

    /// Join the set containing `x` with the set containing `y`, registering either item first if
    /// it was not tracked yet. Does nothing when both are already in the same set. Runs in
    /// amortized near-constant time.
    ///
    /// # Panics
    ///
    /// Panics if `x == y`.
    pub fn merge(&mut self, x: T, y: T) {
        assert_ne!(x, y);
        let first = self.root_entry_or_insert(x);
        let second = self.root_entry_or_insert(y);

        if first == second {
            return;
        }

        // The representative with the higher rank stays on top; on a tie the first one wins.
        let (keep, absorb) = if first.1 < second.1 {
            (second, first)
        } else {
            (first, second)
        };

        self.parent.get_mut(&absorb.0).unwrap().0 = keep.0;
        if keep.1 == absorb.1 {
            // Joining two trees of equal rank makes the result one level deeper.
            let rank = &mut self.parent.get_mut(&keep.0).unwrap().1;
            *rank = rank.saturating_add(1);
        }
    }

    /// Report whether `x` and `y` currently belong to the same set. Yields `false` if either one
    /// is not a member of any set.
    ///
    /// ```
    /// let mut sets = cranelift_isle::disjointsets::DisjointSets::default();
    /// sets.merge(1, 2);
    /// sets.merge(1, 3);
    /// sets.merge(2, 4);
    /// sets.merge(5, 6);
    /// assert!(sets.in_same_set(2, 3));
    /// assert!(sets.in_same_set(1, 4));
    /// assert!(sets.in_same_set(3, 4));
    /// assert!(!sets.in_same_set(4, 5));
    /// ```
    pub fn in_same_set(&self, x: T, y: T) -> bool {
        match (self.find(x), self.find(y)) {
            (Some(left), Some(right)) => left == right,
            _ => false,
        }
    }

    /// Take the whole set containing `x` out of the structure and hand back its members in
    /// sorted order. An empty vector is returned when `x` is not in any set. Every tracked item
    /// is visited, so the cost grows with the total size of all sets rather than only the size
    /// of the removed one.
    ///
    /// ```
    /// let mut sets = cranelift_isle::disjointsets::DisjointSets::default();
    /// sets.merge(1, 2);
    /// sets.merge(1, 3);
    /// sets.merge(2, 4);
    /// assert_eq!(sets.remove_set_of(4), &[1, 2, 3, 4]);
    /// assert_eq!(sets.remove_set_of(1), &[]);
    /// assert!(sets.is_empty());
    /// ```
    pub fn remove_set_of(&mut self, x: T) -> Vec<T>
    where
        T: Ord,
    {
        let root = match self.find_mut(x) {
            Some(root) => root,
            None => return Vec::new(),
        };

        let mut members: Vec<T> = self.parent.keys().copied().collect();
        // Looking up with `find_mut` (not `find`) keeps shortening chains as we go, which is
        // what prevents this scan from degrading to quadratic time.
        members.retain(|&item| self.find_mut(item) == Some(root));
        for item in &members {
            self.parent.remove(item);
        }
        members.sort_unstable();
        members
    }

    /// Report whether no sets are stored at all. Constant time.
    ///
    /// ```
    /// let mut sets = cranelift_isle::disjointsets::DisjointSets::default();
    /// assert!(sets.is_empty());
    /// sets.merge(1, 2);
    /// assert!(!sets.is_empty());
    /// ```
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Count the items across all sets combined. Constant time.
    ///
    /// ```
    /// let mut sets = cranelift_isle::disjointsets::DisjointSets::default();
    /// sets.merge(1, 2);
    /// assert_eq!(sets.len(), 2);
    /// sets.merge(3, 4);
    /// sets.merge(3, 5);
    /// assert_eq!(sets.len(), 5);
    /// ```
    pub fn len(&self) -> usize {
        self.parent.len()
    }
}