1use std::cmp;
2use std::collections::BinaryHeap;
3use std::iter::FromIterator;
4
5use crate::raw::Output;
6use crate::stream::{IntoStreamer, Streamer};
7
8type BoxedStream<'f> = Box<dyn for<'a> Streamer<'a, Item = (&'a [u8], Output)> + 'f>;
10
/// A value paired with the position of the stream that produced it.
///
/// The set operations in this module report one `IndexedValue` per
/// contributing stream for every key they emit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct IndexedValue {
    /// Position of the originating stream within the operation.
    pub index: usize,
    /// The numeric output that stream associated with the key.
    pub value: u64,
}
25
26pub struct OpBuilder<'f> {
44 streams: Vec<BoxedStream<'f>>,
45}
46
47impl<'f> Default for OpBuilder<'f> {
48 fn default() -> Self {
49 OpBuilder { streams: vec![] }
50 }
51}
52
53impl<'f> OpBuilder<'f> {
54 pub fn add<I, S>(mut self, stream: I) -> Self
62 where
63 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
64 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
65 {
66 self.push(stream);
67 self
68 }
69
70 pub fn push<I, S>(&mut self, stream: I)
75 where
76 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
77 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
78 {
79 self.streams.push(Box::new(stream.into_stream()));
80 }
81
82 #[inline]
86 pub fn chain(self) -> Chain<'f> {
87 Chain::new(self.streams)
88 }
89 #[inline]
99 pub fn union(self) -> Union<'f> {
100 Union {
101 heap: StreamHeap::new(self.streams),
102 outs: vec![],
103 cur_slot: None,
104 }
105 }
106
107 #[inline]
117 pub fn intersection(self) -> Intersection<'f> {
118 Intersection {
119 heap: StreamHeap::new(self.streams),
120 outs: vec![],
121 cur_slot: None,
122 }
123 }
124
125 #[inline]
137 pub fn difference(mut self) -> Difference<'f> {
138 let first = self.streams.swap_remove(0);
139 Difference {
140 set: first,
141 key: vec![],
142 heap: StreamHeap::new(self.streams),
143 outs: vec![],
144 }
145 }
146
147 #[inline]
164 pub fn symmetric_difference(self) -> SymmetricDifference<'f> {
165 SymmetricDifference {
166 heap: StreamHeap::new(self.streams),
167 outs: vec![],
168 cur_slot: None,
169 }
170 }
171}
172
173impl<'f, I, S> Extend<I> for OpBuilder<'f>
174where
175 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
176 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
177{
178 fn extend<T>(&mut self, it: T)
179 where
180 T: IntoIterator<Item = I>,
181 {
182 for stream in it {
183 self.push(stream);
184 }
185 }
186}
187
188impl<'f, I, S> FromIterator<I> for OpBuilder<'f>
189where
190 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
191 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
192{
193 fn from_iter<T>(it: T) -> Self
194 where
195 T: IntoIterator<Item = I>,
196 {
197 let mut op = OpBuilder::default();
198 op.extend(it);
199 op
200 }
201}
202
203pub struct Chain<'f> {
207 streams: Vec<BoxedStream<'f>>,
208 current_stream: BoxedStream<'f>,
209 key: Vec<u8>,
210}
211
212impl<'f> Chain<'f> {
213 fn new(mut streams: Vec<BoxedStream<'f>>) -> Self {
215 streams.reverse();
216
217 let current_stream = streams.pop().unwrap();
218 Chain {
219 streams,
220 current_stream,
221 key: vec![],
222 }
223 }
224}
225
226impl<'a, 'f> Streamer<'a> for Chain<'f> {
227 type Item = (&'a [u8], Output);
228
229 fn next(&'a mut self) -> Option<Self::Item> {
230 loop {
231 if let Some((key, val)) = self.current_stream.next() {
232 self.key.clear();
233 self.key.extend_from_slice(&key);
234 return Some((&self.key, val));
235 } else {
236 if let Some(next_stream) = self.streams.pop() {
237 self.current_stream = next_stream;
238 } else {
239 return None;
240 }
241 }
242 }
243 }
244}
245
246pub struct Union<'f> {
250 heap: StreamHeap<'f>,
251 outs: Vec<IndexedValue>,
252 cur_slot: Option<Slot>,
253}
254
255impl<'a, 'f> Streamer<'a> for Union<'f> {
256 type Item = (&'a [u8], &'a [IndexedValue]);
257
258 fn next(&'a mut self) -> Option<Self::Item> {
259 if let Some(slot) = self.cur_slot.take() {
260 self.heap.refill(slot);
261 }
262 let slot = match self.heap.pop() {
263 None => return None,
264 Some(slot) => {
265 self.cur_slot = Some(slot);
266 self.cur_slot.as_ref().unwrap()
267 }
268 };
269 self.outs.clear();
270 self.outs.push(slot.indexed_value());
271 while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
272 self.outs.push(slot2.indexed_value());
273 self.heap.refill(slot2);
274 }
275 Some((slot.input(), &self.outs))
276 }
277}
278
279pub struct Intersection<'f> {
284 heap: StreamHeap<'f>,
285 outs: Vec<IndexedValue>,
286 cur_slot: Option<Slot>,
287}
288
289impl<'a, 'f> Streamer<'a> for Intersection<'f> {
290 type Item = (&'a [u8], &'a [IndexedValue]);
291
292 fn next(&'a mut self) -> Option<Self::Item> {
293 if let Some(slot) = self.cur_slot.take() {
294 self.heap.refill(slot);
295 }
296 loop {
297 let slot = match self.heap.pop() {
298 None => return None,
299 Some(slot) => slot,
300 };
301 self.outs.clear();
302 self.outs.push(slot.indexed_value());
303 let mut popped: usize = 1;
304 while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
305 self.outs.push(slot2.indexed_value());
306 self.heap.refill(slot2);
307 popped += 1;
308 }
309 if popped < self.heap.num_slots() {
310 self.heap.refill(slot);
311 } else {
312 self.cur_slot = Some(slot);
313 let key = self.cur_slot.as_ref().unwrap().input();
314 return Some((key, &self.outs));
315 }
316 }
317 }
318}
319
320pub struct Difference<'f> {
329 set: BoxedStream<'f>,
330 key: Vec<u8>,
331 heap: StreamHeap<'f>,
332 outs: Vec<IndexedValue>,
333}
334
335impl<'a, 'f> Streamer<'a> for Difference<'f> {
336 type Item = (&'a [u8], &'a [IndexedValue]);
337
338 fn next(&'a mut self) -> Option<Self::Item> {
339 loop {
340 match self.set.next() {
341 None => return None,
342 Some((key, out)) => {
343 self.key.clear();
344 self.key.extend(key);
345 self.outs.clear();
346 self.outs.push(IndexedValue {
347 index: 0,
348 value: out.value(),
349 });
350 }
351 };
352 let mut unique = true;
353 while let Some(slot) = self.heap.pop_if_le(&self.key) {
354 if slot.input() == &*self.key {
355 unique = false;
356 }
357 self.heap.refill(slot);
358 }
359 if unique {
360 return Some((&self.key, &self.outs));
361 }
362 }
363 }
364}
365
366pub struct SymmetricDifference<'f> {
371 heap: StreamHeap<'f>,
372 outs: Vec<IndexedValue>,
373 cur_slot: Option<Slot>,
374}
375
376impl<'a, 'f> Streamer<'a> for SymmetricDifference<'f> {
377 type Item = (&'a [u8], &'a [IndexedValue]);
378
379 fn next(&'a mut self) -> Option<Self::Item> {
380 if let Some(slot) = self.cur_slot.take() {
381 self.heap.refill(slot);
382 }
383 loop {
384 let slot = match self.heap.pop() {
385 None => return None,
386 Some(slot) => slot,
387 };
388 self.outs.clear();
389 self.outs.push(slot.indexed_value());
390 let mut popped: usize = 1;
391 while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
392 self.outs.push(slot2.indexed_value());
393 self.heap.refill(slot2);
394 popped += 1;
395 }
396 if popped % 2 == 0 {
399 self.heap.refill(slot);
400 } else {
401 self.cur_slot = Some(slot);
402 let key = self.cur_slot.as_ref().unwrap().input();
403 return Some((key, &self.outs));
404 }
405 }
406 }
407}
408
409struct StreamHeap<'f> {
410 rdrs: Vec<BoxedStream<'f>>,
411 heap: BinaryHeap<Slot>,
412}
413
414impl<'f> StreamHeap<'f> {
415 fn new(streams: Vec<BoxedStream<'f>>) -> StreamHeap<'f> {
416 let mut u = StreamHeap {
417 rdrs: streams,
418 heap: BinaryHeap::new(),
419 };
420 for i in 0..u.rdrs.len() {
421 u.refill(Slot::new(i));
422 }
423 u
424 }
425
426 fn pop(&mut self) -> Option<Slot> {
427 self.heap.pop()
428 }
429
430 fn peek_is_duplicate(&self, key: &[u8]) -> bool {
431 self.heap.peek().map(|s| s.input() == key).unwrap_or(false)
432 }
433
434 fn pop_if_equal(&mut self, key: &[u8]) -> Option<Slot> {
435 if self.peek_is_duplicate(key) {
436 self.pop()
437 } else {
438 None
439 }
440 }
441
442 fn pop_if_le(&mut self, key: &[u8]) -> Option<Slot> {
443 if self.heap.peek().map(|s| s.input() <= key).unwrap_or(false) {
444 self.pop()
445 } else {
446 None
447 }
448 }
449
450 fn num_slots(&self) -> usize {
451 self.rdrs.len()
452 }
453
454 fn refill(&mut self, mut slot: Slot) {
455 if let Some((input, output)) = self.rdrs[slot.idx].next() {
456 slot.set_input(input);
457 slot.set_output(output);
458 self.heap.push(slot);
459 }
460 }
461}
462
463#[derive(Debug, Eq, PartialEq)]
464struct Slot {
465 idx: usize,
466 input: Vec<u8>,
467 output: Output,
468}
469
470impl Slot {
471 fn new(rdr_idx: usize) -> Slot {
472 Slot {
473 idx: rdr_idx,
474 input: Vec::with_capacity(64),
475 output: Output::zero(),
476 }
477 }
478
479 fn indexed_value(&self) -> IndexedValue {
480 IndexedValue {
481 index: self.idx,
482 value: self.output.value(),
483 }
484 }
485
486 fn input(&self) -> &[u8] {
487 &self.input
488 }
489
490 fn set_input(&mut self, input: &[u8]) {
491 self.input.clear();
492 self.input.extend(input);
493 }
494
495 fn set_output(&mut self, output: Output) {
496 self.output = output;
497 }
498}
499
500impl PartialOrd for Slot {
501 fn partial_cmp(&self, other: &Slot) -> Option<cmp::Ordering> {
502 (&self.input, self.output)
503 .partial_cmp(&(&other.input, other.output))
504 .map(|ord| ord.reverse())
505 }
506}
507
508impl Ord for Slot {
509 fn cmp(&self, other: &Slot) -> cmp::Ordering {
510 self.partial_cmp(other).unwrap()
511 }
512}
513
#[cfg(test)]
mod tests {
    use crate::raw::tests::{fst_map, fst_set};
    use crate::raw::Fst;
    use crate::raw::Output;
    use crate::stream::{IntoStreamer, Streamer};

    use super::OpBuilder;

    /// Shorthand for building an owned `String` in expected values.
    fn s(string: &str) -> String {
        string.to_owned()
    }

    // Generates a helper that builds one set FST per input list, runs `$op`
    // over all of them and returns the resulting keys.
    macro_rules! create_set_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<&str>>) -> Vec<String> {
                let fsts: Vec<Fst> = sets.into_iter().map(fst_set).collect();
                let builder: OpBuilder = fsts.iter().collect();
                let mut stream = builder.$op().into_stream();
                let mut found = Vec::new();
                while let Some((key, _)) = stream.next() {
                    let key = String::from_utf8(key.to_vec()).unwrap();
                    found.push(key);
                }
                found
            }
        };
    }

    // Generates a helper that builds one map FST per input list, runs `$op`
    // over all of them and returns each key with the sum of its values.
    macro_rules! create_map_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<(&str, u64)>>) -> Vec<(String, u64)> {
                let fsts: Vec<Fst> = sets.into_iter().map(fst_map).collect();
                let builder: OpBuilder = fsts.iter().collect();
                let mut stream = builder.$op().into_stream();
                let mut found = Vec::new();
                while let Some((key, outs)) = stream.next() {
                    let total: u64 = outs.iter().map(|iv| iv.value).sum();
                    let key = String::from_utf8(key.to_vec()).unwrap();
                    found.push((key, total));
                }
                found
            }
        };
    }

    // Like `create_map_op`, but for operations that yield a plain `Output`
    // per key instead of a slice of indexed values.
    macro_rules! create_map_op_chain {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<(&str, u64)>>) -> Vec<(String, Output)> {
                let fsts: Vec<Fst> = sets.into_iter().map(fst_map).collect();
                let builder: OpBuilder = fsts.iter().collect();
                let mut stream = builder.$op().into_stream();
                let mut found = Vec::new();
                while let Some((key, out)) = stream.next() {
                    let key = String::from_utf8(key.to_vec()).unwrap();
                    found.push((key, out));
                }
                found
            }
        };
    }

    create_set_op!(fst_union, union);
    create_set_op!(fst_intersection, intersection);
    create_set_op!(fst_symmetric_difference, symmetric_difference);
    create_set_op!(fst_difference, difference);
    create_set_op!(fst_chain, chain);
    create_map_op!(fst_union_map, union);
    create_map_op!(fst_intersection_map, intersection);
    create_map_op!(fst_symmetric_difference_map, symmetric_difference);
    create_map_op!(fst_difference_map, difference);
    create_map_op_chain!(fst_chain_map, chain);

    #[test]
    fn chain_set() {
        let got = fst_chain(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(got, vec!["a", "b", "c", "x", "y", "z"]);
    }

    #[test]
    fn chain_set_wrong_order() {
        // Chaining does not reorder or de-duplicate keys across streams.
        let got = fst_chain(vec![vec!["a", "b", "c", "z"], vec!["x", "y", "z"]]);
        assert_eq!(got, vec!["a", "b", "c", "z", "x", "y", "z"]);
    }

    #[test]
    fn chain_map() {
        let got = fst_chain_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        let expected = vec![
            (s("a"), Output::new(1)),
            (s("b"), Output::new(2)),
            (s("c"), Output::new(3)),
            (s("x"), Output::new(1)),
            (s("y"), Output::new(2)),
            (s("z"), Output::new(3)),
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn union_set() {
        let got = fst_union(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(got, vec!["a", "b", "c", "x", "y", "z"]);
    }

    #[test]
    fn union_set_dupes() {
        let got = fst_union(vec![vec!["aa", "b", "cc"], vec!["b", "cc", "z"]]);
        assert_eq!(got, vec!["aa", "b", "cc", "z"]);
    }

    #[test]
    fn union_map() {
        let got = fst_union_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        let expected = vec![
            (s("a"), 1),
            (s("b"), 2),
            (s("c"), 3),
            (s("x"), 1),
            (s("y"), 2),
            (s("z"), 3),
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn union_map_dupes() {
        let got = fst_union_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        let expected = vec![(s("aa"), 1), (s("b"), 4), (s("cc"), 5), (s("z"), 3)];
        assert_eq!(got, expected);
    }

    #[test]
    fn intersect_set() {
        let got = fst_intersection(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(got, Vec::<String>::new());
    }

    #[test]
    fn intersect_set_dupes() {
        let got = fst_intersection(vec![vec!["aa", "b", "cc"], vec!["b", "cc", "z"]]);
        assert_eq!(got, vec!["b", "cc"]);
    }

    #[test]
    fn intersect_map() {
        let got = fst_intersection_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(got, Vec::<(String, u64)>::new());
    }

    #[test]
    fn intersect_map_dupes() {
        let got = fst_intersection_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(got, vec![(s("b"), 4)]);
    }

    #[test]
    fn symmetric_difference() {
        let got =
            fst_symmetric_difference(vec![vec!["a", "b", "c"], vec!["a", "b"], vec!["a"]]);
        assert_eq!(got, vec!["a", "c"]);
    }

    #[test]
    fn symmetric_difference_map() {
        let got = fst_symmetric_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(got, vec![(s("a"), 3), (s("c"), 3)]);
    }

    #[test]
    fn difference() {
        let got = fst_difference(vec![vec!["a", "b", "c"], vec!["a", "b"], vec!["a"]]);
        assert_eq!(got, vec!["c"]);
    }

    #[test]
    fn difference2() {
        let got = fst_difference(vec![vec!["a", "c"], vec!["b", "c"]]);
        assert_eq!(got, vec!["a"]);

        let got = fst_difference(vec![vec!["bar", "foo"], vec!["baz", "foo"]]);
        assert_eq!(got, vec!["bar"]);
    }

    #[test]
    fn difference_map() {
        let got = fst_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(got, vec![(s("c"), 3)]);
    }
}