1use core::{
6 cmp::{max, min},
7 fmt::{Debug, Formatter},
8 iter::Peekable,
9 ops::RangeInclusive,
10};
11use std::collections::BTreeMap;
12
13use types::Fixed;
14
/// A set of values stored as inclusive ranges.
///
/// Each entry of the underlying map is one range, keyed by its start and
/// holding its inclusive end. `insert` merges any ranges that overlap or are
/// adjacent, so ranges added through it end up disjoint and non-touching.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct RangeSet<T> {
    // Maps the start of each stored range to its inclusive end.
    ranges: BTreeMap<T, T>,
}
23
/// Adjacency test for ordered values.
///
/// `RangeSet` uses this to decide whether two ranges that do not overlap
/// still touch and should therefore be merged into one.
pub trait OrdAdjacency {
    /// Returns `true` if `self` and `rhs` are adjacent.
    ///
    /// The implementations in this file treat two values as adjacent when one
    /// of them is exactly one step above the other, in either order.
    fn are_adjacent(self, rhs: Self) -> bool;
}
29
30impl<T> RangeSet<T>
31where
32 T: Ord + Copy + OrdAdjacency,
33{
34 pub fn is_empty(&self) -> bool {
36 self.ranges.is_empty()
37 }
38
39 pub fn insert(&mut self, range: RangeInclusive<T>) {
41 if range.end() < range.start() {
42 return;
44 }
45
46 let mut start = *range.start();
47 let mut end = *range.end();
48
49 if let Some((prev_start, prev_end)) = self.prev_range(start) {
51 if range_is_subset(start, end, prev_start, prev_end) {
52 return;
53 }
54 if ranges_overlap_or_adjacent(start, end, prev_start, prev_end) {
55 start = min(start, prev_start);
56 end = max(end, prev_end);
57 self.ranges.remove(&prev_start);
58 }
59 };
60
61 loop {
63 let Some((next_start, next_end)) = self.next_range(start) else {
64 self.ranges.insert(start, end);
66 return;
67 };
68
69 if range_is_subset(start, end, next_start, next_end) {
70 return;
71 }
72 if ranges_overlap_or_adjacent(start, end, next_start, next_end) {
73 start = min(start, next_start);
74 end = max(end, next_end);
75 self.ranges.remove(&next_start);
76 } else {
77 self.ranges.insert(start, end);
78 return;
79 }
80 }
81 }
82
83 pub fn iter(&'_ self) -> impl Iterator<Item = RangeInclusive<T>> + '_ {
85 self.ranges.iter().map(|(a, b)| *a..=*b)
86 }
87
88 pub fn intersection<'a>(
90 &'a self,
91 other: &'a Self,
92 ) -> impl Iterator<Item = RangeInclusive<T>> + 'a {
93 IntersectionIter {
94 it_a: self.iter().peekable(),
95 it_b: other.iter().peekable(),
96 }
97 }
98
99 fn next_range(&self, start: T) -> Option<(T, T)> {
101 let (next_start, next_end) = self.ranges.range(start..).next()?;
102 Some((*next_start, *next_end))
103 }
104
105 fn prev_range(&self, start: T) -> Option<(T, T)> {
107 let (next_start, next_end) = self.ranges.range(..start).next_back()?;
108 Some((*next_start, *next_end))
109 }
110}
111
112impl<T> Extend<RangeInclusive<T>> for RangeSet<T>
113where
114 T: Copy + Ord + OrdAdjacency,
115{
116 fn extend<I: IntoIterator<Item = RangeInclusive<T>>>(&mut self, iter: I) {
117 iter.into_iter().for_each(|r| self.insert(r));
118 }
119}
120
121impl<T> FromIterator<RangeInclusive<T>> for RangeSet<T>
122where
123 T: Default + Copy + Ord + OrdAdjacency,
124{
125 fn from_iter<I: IntoIterator<Item = RangeInclusive<T>>>(iter: I) -> Self {
126 let mut result: Self = Default::default();
127 result.extend(iter);
128 result
129 }
130}
131
132impl<T> Debug for RangeSet<T>
133where
134 T: Debug,
135{
136 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
137 write!(f, "RangeSet {{")?;
138 for (start, end) in self.ranges.iter() {
139 write!(f, "[{:?}, {:?}], ", start, end)?;
140 }
141 write!(f, "}}")
142 }
143}
144
/// Iterator state for intersecting two streams of inclusive ranges.
///
/// Both inputs are wrapped in `Peekable` so the current range of each side
/// can be inspected before deciding which side to advance.
struct IntersectionIter<
    A: Iterator<Item = RangeInclusive<T>>,
    B: Iterator<Item = RangeInclusive<T>>,
    T,
> {
    // Ranges of the left-hand input.
    it_a: Peekable<A>,
    // Ranges of the right-hand input.
    it_b: Peekable<B>,
}
153
154impl<A, B, T> Iterator for IntersectionIter<A, B, T>
155where
156 A: Iterator<Item = RangeInclusive<T>>,
157 B: Iterator<Item = RangeInclusive<T>>,
158 T: Ord + Copy,
159{
160 type Item = RangeInclusive<T>;
161
162 fn next(&mut self) -> Option<Self::Item> {
163 loop {
164 let (Some(a), Some(b)) = (self.it_a.peek(), self.it_b.peek()) else {
165 return None;
166 };
167
168 let a = a.clone();
169 let b = b.clone();
170
171 match range_intersection(&a, &b) {
172 Some(intersection) => {
173 self.step_iterators(&a, &b);
174 return Some(intersection);
175 }
176 None => self.step_iterators(&a, &b),
177 }
178 }
179 }
180}
181
182impl<A, B, T> IntersectionIter<A, B, T>
183where
184 A: Iterator<Item = RangeInclusive<T>>,
185 B: Iterator<Item = RangeInclusive<T>>,
186 T: Ord,
187{
188 fn step_iterators(&mut self, a: &RangeInclusive<T>, b: &RangeInclusive<T>) {
189 if a.end() <= b.end() {
190 self.it_a.next();
191 }
192
193 if a.end() >= b.end() {
194 self.it_b.next();
195 }
196 }
197}
198
199impl OrdAdjacency for u32 {
200 fn are_adjacent(self, rhs: u32) -> bool {
201 matches!(self.checked_add(1).map(|r| r == rhs), Some(true))
202 || matches!(rhs.checked_add(1).map(|r| r == self), Some(true))
203 }
204}
205
206impl OrdAdjacency for u16 {
207 fn are_adjacent(self, rhs: u16) -> bool {
208 matches!(self.checked_add(1).map(|r| r == rhs), Some(true))
209 || matches!(rhs.checked_add(1).map(|r| r == self), Some(true))
210 }
211}
212
213impl OrdAdjacency for Fixed {
214 fn are_adjacent(self, rhs: Fixed) -> bool {
215 matches!(
216 self.checked_add(Fixed::EPSILON).map(|r| r == rhs),
217 Some(true)
218 ) || matches!(
219 rhs.checked_add(Fixed::EPSILON).map(|r| r == self),
220 Some(true)
221 )
222 }
223}
224
/// Returns the values common to `a` and `b` as an inclusive range, or `None`
/// if they share no values.
///
/// The candidate intersection runs from the larger of the two starts to the
/// smaller of the two ends. It is returned only when it is non-empty, so an
/// input whose start is above its end (an empty range) never produces a
/// reversed range as the result.
fn range_intersection<T: Ord + Copy>(
    a: &RangeInclusive<T>,
    b: &RangeInclusive<T>,
) -> Option<RangeInclusive<T>> {
    let start = *max(a.start(), b.start());
    let end = *min(a.end(), b.end());

    if start <= end {
        Some(start..=end)
    } else {
        None
    }
}
236
237fn ranges_overlap_or_adjacent<T>(a_start: T, a_end: T, b_start: T, b_end: T) -> bool
241where
242 T: Ord + OrdAdjacency,
243{
244 (a_start <= b_end && b_start <= a_end)
245 || (a_end.are_adjacent(b_start))
246 || (b_end.are_adjacent(a_start))
247}
248
/// Returns `true` if the inclusive range `[a_start, a_end]` lies entirely
/// within `[b_start, b_end]`.
fn range_is_subset<T>(a_start: T, a_end: T, b_start: T, b_end: T) -> bool
where
    T: Ord,
{
    // Reads left to right as b_start <= a_start .. a_end <= b_end.
    b_start <= a_start && a_end <= b_end
}
258
#[cfg(test)]
mod test {
    use super::*;

    /// Returns the ranges currently held by `set`, in iteration order.
    fn contents(set: &RangeSet<u32>) -> Vec<RangeInclusive<u32>> {
        set.iter().collect()
    }

    /// Inserts `ranges` one at a time, in order, into an empty set and
    /// returns the ranges the set holds afterwards.
    fn insert_all(ranges: &[RangeInclusive<u32>]) -> Vec<RangeInclusive<u32>> {
        let mut set: RangeSet<u32> = Default::default();
        for range in ranges {
            set.insert(range.clone());
        }
        contents(&set)
    }

    #[test]
    // The reversed range is the input under test.
    #[allow(clippy::reversed_empty_ranges)]
    fn insert_invalid() {
        assert_eq!(insert_all(&[12..=11]), vec![]);
    }

    #[test]
    fn insert_non_overlapping() {
        assert_eq!(
            insert_all(&[11..=11, 2..=3, 6..=9]),
            vec![2..=3, 6..=9, 11..=11]
        );
    }

    #[test]
    fn insert_subset_before() {
        assert_eq!(insert_all(&[2..=8, 3..=7]), vec![2..=8]);
    }

    #[test]
    fn insert_subset_after() {
        assert_eq!(insert_all(&[2..=8, 2..=7, 2..=8]), vec![2..=8]);
    }

    #[test]
    fn insert_overlapping_before() {
        assert_eq!(insert_all(&[2..=8, 7..=11]), vec![2..=11]);
    }

    #[test]
    fn insert_overlapping_after() {
        assert_eq!(insert_all(&[10..=14, 7..=11]), vec![7..=14]);
        assert_eq!(insert_all(&[10..=14, 10..=17]), vec![10..=17]);
    }

    #[test]
    fn insert_overlapping_multiple_after() {
        assert_eq!(insert_all(&[10..=14, 16..=17, 7..=16]), vec![7..=17]);
        assert_eq!(insert_all(&[10..=14, 16..=17, 10..=16]), vec![10..=17]);
        assert_eq!(insert_all(&[10..=14, 16..=17, 10..=17]), vec![10..=17]);
    }

    #[test]
    fn insert_overlapping_before_and_after() {
        assert_eq!(
            insert_all(&[6..=8, 10..=14, 16..=20, 7..=19]),
            vec![6..=20]
        );
    }

    #[test]
    fn insert_joins_adjacent() {
        assert_eq!(insert_all(&[6..=8, 9..=10]), vec![6..=10]);
        assert_eq!(insert_all(&[9..=10, 6..=8]), vec![6..=10]);
        assert_eq!(insert_all(&[6..=8, 10..=10, 9..=9]), vec![6..=10]);
    }

    #[test]
    fn from_iter_and_extend() {
        let mut set: RangeSet<u32> = [2..=5, 13..=64, 7..=9].into_iter().collect();
        assert_eq!(contents(&set), vec![2..=5, 7..=9, 13..=64]);

        set.extend([6..=17, 100..=101]);
        assert_eq!(contents(&set), vec![2..=64, 100..=101]);
    }

    #[test]
    fn intersection() {
        let a: RangeSet<u32> = [2..=5, 7..=9, 13..=64].into_iter().collect();
        let b: RangeSet<u32> = [1..=3, 5..=8, 13..=64, 67..=69].into_iter().collect();

        let expected = vec![2..=3, 5..=5, 7..=8, 13..=64];

        // The result must not depend on which set is the receiver.
        for (lhs, rhs) in [(&a, &b), (&b, &a)] {
            assert_eq!(lhs.intersection(rhs).collect::<Vec<_>>(), expected);
        }
    }
}