use std::io;
use std::ops::Bound;
use tantivy_fst::automaton::AlwaysMatch;
use tantivy_fst::Automaton;
use crate::dictionary::Dictionary;
use crate::{DeltaReader, SSTable, TermOrdinal};
pub struct StreamerBuilder<'a, TSSTable, A = AlwaysMatch>
where
A: Automaton,
A::State: Clone,
TSSTable: SSTable,
{
term_dict: &'a Dictionary<TSSTable>,
automaton: A,
lower: Bound<Vec<u8>>,
upper: Bound<Vec<u8>>,
limit: Option<u64>,
}
fn bound_as_byte_slice(bound: &Bound<Vec<u8>>) -> Bound<&[u8]> {
match bound.as_ref() {
Bound::Included(key) => Bound::Included(key.as_slice()),
Bound::Excluded(key) => Bound::Excluded(key.as_slice()),
Bound::Unbounded => Bound::Unbounded,
}
}
impl<'a, TSSTable, A> StreamerBuilder<'a, TSSTable, A>
where
A: Automaton,
A::State: Clone,
TSSTable: SSTable,
{
pub(crate) fn new(term_dict: &'a Dictionary<TSSTable>, automaton: A) -> Self {
StreamerBuilder {
term_dict,
automaton,
lower: Bound::Unbounded,
upper: Bound::Unbounded,
limit: None,
}
}
pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.lower = Bound::Included(bound.as_ref().to_owned());
self
}
pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.lower = Bound::Excluded(bound.as_ref().to_owned());
self
}
pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.upper = Bound::Included(bound.as_ref().to_owned());
self
}
pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.upper = Bound::Excluded(bound.as_ref().to_owned());
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.limit = Some(limit);
self
}
fn delta_reader(&self) -> io::Result<DeltaReader<TSSTable::ValueReader>> {
let key_range = (
bound_as_byte_slice(&self.lower),
bound_as_byte_slice(&self.upper),
);
self.term_dict
.sstable_delta_reader_for_key_range(key_range, self.limit)
}
async fn delta_reader_async(&self) -> io::Result<DeltaReader<TSSTable::ValueReader>> {
let key_range = (
bound_as_byte_slice(&self.lower),
bound_as_byte_slice(&self.upper),
);
self.term_dict
.sstable_delta_reader_for_key_range_async(key_range, self.limit)
.await
}
fn into_stream_given_delta_reader(
self,
delta_reader: DeltaReader<<TSSTable as SSTable>::ValueReader>,
) -> io::Result<Streamer<'a, TSSTable, A>> {
let start_state = self.automaton.start();
let start_key = bound_as_byte_slice(&self.lower);
let first_term = match start_key {
Bound::Included(key) | Bound::Excluded(key) => self
.term_dict
.sstable_index
.get_block_with_key(key)
.map(|block| block.first_ordinal)
.unwrap_or(0),
Bound::Unbounded => 0,
};
Ok(Streamer {
automaton: self.automaton,
states: vec![start_state],
delta_reader,
key: Vec::new(),
term_ord: first_term.checked_sub(1),
lower_bound: self.lower,
upper_bound: self.upper,
_lifetime: std::marker::PhantomData,
})
}
pub async fn into_stream_async(self) -> io::Result<Streamer<'a, TSSTable, A>> {
let delta_reader = self.delta_reader_async().await?;
self.into_stream_given_delta_reader(delta_reader)
}
pub fn into_stream(self) -> io::Result<Streamer<'a, TSSTable, A>> {
let delta_reader = self.delta_reader()?;
self.into_stream_given_delta_reader(delta_reader)
}
}
pub struct Streamer<'a, TSSTable, A = AlwaysMatch>
where
A: Automaton,
A::State: Clone,
TSSTable: SSTable,
{
automaton: A,
states: Vec<A::State>,
delta_reader: crate::DeltaReader<TSSTable::ValueReader>,
key: Vec<u8>,
term_ord: Option<TermOrdinal>,
lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>,
_lifetime: std::marker::PhantomData<&'a ()>,
}
impl<'a, TSSTable> Streamer<'a, TSSTable, AlwaysMatch>
where TSSTable: SSTable
{
pub fn empty() -> Self {
Streamer {
automaton: AlwaysMatch,
states: Vec::new(),
delta_reader: DeltaReader::empty(),
key: Vec::new(),
term_ord: None,
lower_bound: Bound::Unbounded,
upper_bound: Bound::Unbounded,
_lifetime: std::marker::PhantomData,
}
}
}
impl<'a, TSSTable, A> Streamer<'a, TSSTable, A>
where
A: Automaton,
A::State: Clone,
TSSTable: SSTable,
{
pub fn advance(&mut self) -> bool {
while self.delta_reader.advance().unwrap() {
self.term_ord = Some(
self.term_ord
.map(|term_ord| term_ord + 1u64)
.unwrap_or(0u64),
);
let common_prefix_len = self.delta_reader.common_prefix_len();
self.states.truncate(common_prefix_len + 1);
self.key.truncate(common_prefix_len);
let mut state: A::State = self.states.last().unwrap().clone();
for &b in self.delta_reader.suffix() {
state = self.automaton.accept(&state, b);
self.states.push(state.clone());
}
self.key.extend_from_slice(self.delta_reader.suffix());
let match_lower_bound = match &self.lower_bound {
Bound::Unbounded => true,
Bound::Included(lower_bound_key) => lower_bound_key[..] <= self.key[..],
Bound::Excluded(lower_bound_key) => lower_bound_key[..] < self.key[..],
};
if !match_lower_bound {
continue;
}
self.lower_bound = Bound::Unbounded;
let match_upper_bound = match &self.upper_bound {
Bound::Unbounded => true,
Bound::Included(upper_bound_key) => upper_bound_key[..] >= self.key[..],
Bound::Excluded(upper_bound_key) => upper_bound_key[..] > self.key[..],
};
if !match_upper_bound {
return false;
}
if self.automaton.is_match(&state) {
return true;
}
}
false
}
pub fn term_ord(&self) -> TermOrdinal {
self.term_ord.unwrap_or(0u64)
}
pub fn key(&self) -> &[u8] {
&self.key
}
pub fn value(&self) -> &TSSTable::Value {
self.delta_reader.value()
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> Option<(&[u8], &TSSTable::Value)> {
if self.advance() {
Some((self.key(), self.value()))
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use std::io;
use common::OwnedBytes;
use crate::{Dictionary, MonotonicU64SSTable};
fn create_test_dictionary() -> io::Result<Dictionary<MonotonicU64SSTable>> {
let mut dict_builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new())?;
dict_builder.insert(b"abaisance", &0)?;
dict_builder.insert(b"abalation", &1)?;
dict_builder.insert(b"abalienate", &2)?;
dict_builder.insert(b"abandon", &3)?;
let buffer = dict_builder.finish()?;
let owned_bytes = OwnedBytes::new(buffer);
Dictionary::from_bytes(owned_bytes)
}
#[test]
fn test_sstable_stream() -> io::Result<()> {
let dict = create_test_dictionary()?;
let mut streamer = dict.stream()?;
assert!(streamer.advance());
assert_eq!(streamer.key(), b"abaisance");
assert_eq!(streamer.value(), &0);
assert!(streamer.advance());
assert_eq!(streamer.key(), b"abalation");
assert_eq!(streamer.value(), &1);
assert!(streamer.advance());
assert_eq!(streamer.key(), b"abalienate");
assert_eq!(streamer.value(), &2);
assert!(streamer.advance());
assert_eq!(streamer.key(), b"abandon");
assert_eq!(streamer.value(), &3);
assert!(!streamer.advance());
Ok(())
}
#[test]
fn test_sstable_search() -> io::Result<()> {
let term_dict = create_test_dictionary()?;
let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap();
let mut term_streamer = term_dict.search(ptn).into_stream()?;
assert!(term_streamer.advance());
assert_eq!(term_streamer.key(), b"abalation");
assert_eq!(term_streamer.value(), &1u64);
assert!(term_streamer.advance());
assert_eq!(term_streamer.key(), b"abalienate");
assert_eq!(term_streamer.value(), &2u64);
assert!(!term_streamer.advance());
Ok(())
}
}