use std::cmp::Ordering;
use std::collections::HashSet;
use std::hash::{BuildHasher, Hasher};
use std::{cmp, io};
use siphasher::sip::SipHasher24;
/// Golomb-Rice parameter: the number of low bits of every delta that are stored verbatim
/// (the rest of the delta is written in unary). Same value as BIP 158 basic filters.
pub const P: u8 = 19;
/// Range multiplier: for an `n`-element filter, element hashes are mapped onto `[0, n * M)`.
/// Same value as BIP 158 basic filters.
pub const M: u64 = 784_931;
/// A [`BuildHasher`] producing SipHash-2-4 hashers that all share one fixed `(k0, k1)` key pair.
///
/// Because the keys are fixed, every hasher it builds hashes identical input to the same value.
pub struct SipHasher24Builder {
    k0: u64,
    k1: u64,
}
impl SipHasher24Builder {
    /// Creates a builder whose hashers are keyed with `k0` and `k1`.
    pub fn new(k0: u64, k1: u64) -> SipHasher24Builder {
        Self { k0, k1 }
    }
}
impl BuildHasher for SipHasher24Builder {
    type Hasher = SipHasher24;
    /// Returns a fresh hasher seeded with this builder's key pair.
    fn build_hasher(&self) -> Self::Hasher {
        let &Self { k0, k1 } = self;
        SipHasher24::new_with_keys(k0, k1)
    }
}
/// Reads a serialized Golomb-coded set filter and answers membership queries against it.
///
/// The expected serialization is an 8-byte little-endian element count followed by the
/// Golomb-Rice coded deltas between the sorted, range-mapped element hashes.
pub struct GCSFilterReader<H> {
    filter: GCSFilter<H>,
    m: u64,
}
impl<H: BuildHasher> GCSFilterReader<H> {
    /// Creates a reader that hashes query elements with `hasher_builder` and decodes the filter
    /// using range multiplier `m` and Golomb-Rice parameter `p`.
    ///
    /// `m` and `p` must match the values the filter was written with.
    pub fn new(hasher_builder: H, m: u64, p: u8) -> GCSFilterReader<H> {
        GCSFilterReader {
            filter: GCSFilter::new(hasher_builder, p),
            m,
        }
    }

    /// Returns `true` if at least one element of `query` matches the filter read from `reader`.
    ///
    /// An empty query or an empty filter never matches. Distinct elements can map to the same
    /// value, so `true` may be a false positive.
    ///
    /// # Errors
    ///
    /// Returns an error if the filter body cannot be read or decoded, if the underlying reader
    /// fails while reading the header (a missing or truncated header is treated as an empty
    /// filter), or if the header's element count times `m` overflows `u64`.
    pub fn match_any(
        &self,
        reader: &mut dyn io::Read,
        query: &mut dyn Iterator<Item = &[u8]>,
    ) -> Result<bool, io::Error> {
        let (n_elements, mapped) = self.read_header_and_map(reader, query)?;
        if mapped.is_empty() || n_elements == 0 {
            return Ok(false);
        }
        let mut bits = BitStreamReader::new(reader);
        let mut value = self.filter.golomb_rice_decode(&mut bits)?;
        let mut remaining = n_elements - 1;
        for target in mapped {
            loop {
                match value.cmp(&target) {
                    Ordering::Equal => return Ok(true),
                    // Filter value is behind the target: advance to the next filter value.
                    Ordering::Less => {
                        if remaining == 0 {
                            return Ok(false);
                        }
                        value += self.filter.golomb_rice_decode(&mut bits)?;
                        remaining -= 1;
                    }
                    // Overshot: this target is absent, move on to the next (larger) target.
                    Ordering::Greater => break,
                }
            }
        }
        Ok(false)
    }

    /// Returns `true` if every element of `query` matches the filter read from `reader`.
    ///
    /// An empty query or an empty filter returns `false`. A `false` result is definitive; a
    /// `true` result may be a false positive.
    ///
    /// # Errors
    ///
    /// Same conditions as [`GCSFilterReader::match_any`].
    pub fn match_all(
        &self,
        reader: &mut dyn io::Read,
        query: &mut dyn Iterator<Item = &[u8]>,
    ) -> Result<bool, io::Error> {
        let (n_elements, mut mapped) = self.read_header_and_map(reader, query)?;
        // Repeated targets add nothing to an all-match check; drop them.
        mapped.dedup();
        if mapped.is_empty() || n_elements == 0 {
            return Ok(false);
        }
        let mut bits = BitStreamReader::new(reader);
        let mut value = self.filter.golomb_rice_decode(&mut bits)?;
        let mut remaining = n_elements - 1;
        for target in mapped {
            loop {
                match value.cmp(&target) {
                    Ordering::Equal => break,
                    // Filter value is behind the target: advance to the next filter value.
                    Ordering::Less => {
                        if remaining == 0 {
                            return Ok(false);
                        }
                        value += self.filter.golomb_rice_decode(&mut bits)?;
                        remaining -= 1;
                    }
                    // Overshot: this target is absent, so not everything matches.
                    Ordering::Greater => return Ok(false),
                }
            }
        }
        Ok(true)
    }

    /// Reads the filter header from `reader` and maps each query element onto `[0, n * m)`.
    ///
    /// Returns the filter's element count and the mapped query values in ascending order.
    fn read_header_and_map(
        &self,
        reader: &mut dyn io::Read,
        query: &mut dyn Iterator<Item = &[u8]>,
    ) -> Result<(u64, Vec<u64>), io::Error> {
        let n_elements = Self::read_element_count(reader)?;
        // The count comes from the serialized filter and may be attacker-controlled. An unchecked
        // multiply would panic in debug builds and silently wrap in release builds.
        let nm = n_elements.checked_mul(self.m).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "filter element count overflows the hash range",
            )
        })?;
        let mut mapped: Vec<u64> = query
            .map(|e| map_to_range(self.filter.hash(e), nm))
            .collect();
        mapped.sort_unstable();
        Ok((n_elements, mapped))
    }

    /// Reads the 8-byte little-endian element count that prefixes a serialized filter.
    ///
    /// If the reader ends before 8 bytes are available, the filter is treated as empty (count 0).
    /// Any other I/O error is propagated instead of being silently ignored.
    fn read_element_count(reader: &mut dyn io::Read) -> Result<u64, io::Error> {
        let mut length_data = [0u8; 8];
        match reader.read_exact(&mut length_data) {
            Ok(()) => Ok(u64::from_le_bytes(length_data)),
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
            Err(e) => Err(e),
        }
    }
}
/// Maps a 64-bit `hash` uniformly onto `[0, nm)` using a widening multiply and keeping the
/// high 64 bits of the 128-bit product (no division or modulo needed).
///
/// Returns 0 when `nm` is 0.
fn map_to_range(hash: u64, nm: u64) -> u64 {
    let product = u128::from(hash) * u128::from(nm);
    (product >> 64) as u64
}
/// Collects elements and serializes them as a Golomb-coded set filter.
///
/// [`GCSFilterWriter::finish`] writes an 8-byte little-endian element count followed by the
/// Golomb-Rice coded deltas between the sorted, range-mapped element hashes.
pub struct GCSFilterWriter<'a, H> {
    filter: GCSFilter<H>,
    writer: &'a mut dyn io::Write,
    elements: HashSet<Vec<u8>>,
    m: u64,
}
impl<'a, H: BuildHasher> GCSFilterWriter<'a, H> {
    /// Creates a writer that outputs to `writer`, hashing elements with `hasher_builder` and
    /// encoding with range multiplier `m` and Golomb-Rice parameter `p`.
    pub fn new(
        writer: &'a mut dyn io::Write,
        hasher_builder: H,
        m: u64,
        p: u8,
    ) -> GCSFilterWriter<'a, H> {
        GCSFilterWriter {
            filter: GCSFilter::new(hasher_builder, p),
            writer,
            elements: HashSet::new(),
            m,
        }
    }

    /// Queues `element` for inclusion in the filter.
    ///
    /// Empty elements are ignored, and duplicate elements are stored only once.
    pub fn add_element(&mut self, element: &[u8]) {
        if !element.is_empty() {
            self.elements.insert(element.to_vec());
        }
    }

    /// Serializes the collected elements to the underlying writer.
    ///
    /// Returns the total number of bytes written (header plus coded body).
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from the underlying writer. Returns `InvalidInput` if the element
    /// count times `m` overflows `u64`.
    pub fn finish(&mut self) -> Result<usize, io::Error> {
        let nm = (self.elements.len() as u64)
            .checked_mul(self.m)
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "element count times m overflows u64",
                )
            })?;
        let mut mapped: Vec<u64> = self
            .elements
            .iter()
            .map(|e| map_to_range(self.filter.hash(e.as_slice()), nm))
            .collect();
        mapped.sort_unstable();
        // `write` may accept only part of the buffer. `write_all` guarantees the whole header is
        // written before the bit stream begins; otherwise the output would be corrupted.
        let header = (mapped.len() as u64).to_le_bytes();
        self.writer.write_all(&header)?;
        let mut wrote = header.len();
        let mut writer = BitStreamWriter::new(self.writer);
        let mut last = 0;
        for data in mapped {
            // `mapped` is sorted, so `data >= last` and the delta cannot underflow.
            wrote += self.filter.golomb_rice_encode(&mut writer, data - last)?;
            last = data;
        }
        // Emit the trailing partially filled byte, if any.
        wrote += writer.flush()?;
        Ok(wrote)
    }
}
/// Hashing and Golomb-Rice coding shared by the filter reader and writer.
struct GCSFilter<H> {
    hasher_builder: H,
    p: u8,
}
impl<H: BuildHasher> GCSFilter<H> {
    /// Wraps `hasher_builder` together with Golomb-Rice parameter `p`.
    fn new(hasher_builder: H, p: u8) -> GCSFilter<H> {
        Self { hasher_builder, p }
    }

    /// Golomb-Rice encodes `n` into `writer` and returns the number of whole bytes flushed.
    ///
    /// The quotient `n >> p` is written in unary (that many one-bits, then a zero bit), followed
    /// by the low `p` bits of `n`.
    fn golomb_rice_encode(&self, writer: &mut BitStreamWriter, n: u64) -> Result<usize, io::Error> {
        let mut total = 0;
        let mut ones_left = n >> self.p;
        // A single write is capped at 64 bits, so long unary runs go out in chunks.
        while ones_left != 0 {
            let chunk = cmp::min(ones_left, 64);
            total += writer.write(u64::MAX, chunk as u8)?;
            ones_left -= chunk;
        }
        total += writer.write(0, 1)?;
        total += writer.write(n, self.p)?;
        Ok(total)
    }

    /// Decodes one Golomb-Rice value previously written by `golomb_rice_encode`.
    fn golomb_rice_decode(&self, reader: &mut BitStreamReader) -> Result<u64, io::Error> {
        // Count the unary-coded quotient up to its terminating zero bit.
        let mut quotient = 0u64;
        loop {
            match reader.read(1)? {
                1 => quotient += 1,
                _ => break,
            }
        }
        let remainder = reader.read(self.p)?;
        Ok((quotient << self.p) + remainder)
    }

    /// Hashes `element` with a fresh hasher from the configured builder.
    fn hash(&self, element: &[u8]) -> u64 {
        let mut state = self.hasher_builder.build_hasher();
        state.write(element);
        state.finish()
    }
}
/// Reads values of up to 64 bits from a byte stream, most-significant bit first.
pub struct BitStreamReader<'a> {
    /// The byte currently being consumed.
    buffer: [u8; 1],
    /// Bits of `buffer` already consumed; 8 means the next read must fetch a new byte.
    offset: u8,
    reader: &'a mut dyn io::Read,
}
impl<'a> BitStreamReader<'a> {
    /// Creates a bit reader over `reader`. No bytes are fetched until the first `read`.
    pub fn new(reader: &'a mut dyn io::Read) -> BitStreamReader {
        Self {
            buffer: [0u8],
            offset: 8,
            reader,
        }
    }

    /// Reads the next `nbits` bits and returns them right-aligned in a `u64`.
    ///
    /// # Errors
    ///
    /// Fails if `nbits` exceeds 64, or if the underlying reader cannot supply another byte.
    pub fn read(&mut self, nbits: u8) -> Result<u64, io::Error> {
        if nbits > 64 {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "can not read more than 64 bits at once",
            ));
        }
        let mut value = 0u64;
        let mut pending = nbits;
        while pending != 0 {
            if self.offset == 8 {
                self.reader.read_exact(&mut self.buffer)?;
                self.offset = 0;
            }
            let take = cmp::min(8 - self.offset, pending);
            // Drop the already-consumed high bits, then keep only the next `take` bits.
            let chunk = (self.buffer[0] << self.offset) >> (8 - take);
            value = (value << take) | u64::from(chunk);
            self.offset += take;
            pending -= take;
        }
        Ok(value)
    }
}
/// Writes values of up to 64 bits to a byte stream, most-significant bit first.
///
/// Bits are buffered until a whole byte is filled. Call `flush` to emit a trailing partial byte,
/// whose unused low bits are zero.
pub struct BitStreamWriter<'a> {
    /// The byte currently being filled.
    buffer: [u8; 1],
    /// Number of bits of `buffer` already filled.
    offset: u8,
    writer: &'a mut dyn io::Write,
}
impl<'a> BitStreamWriter<'a> {
    /// Creates a bit writer over `writer` with an empty buffer.
    pub fn new(writer: &'a mut dyn io::Write) -> BitStreamWriter {
        Self {
            buffer: [0u8],
            offset: 0,
            writer,
        }
    }

    /// Appends the low `nbits` bits of `data` and returns the number of whole bytes emitted.
    ///
    /// # Errors
    ///
    /// Fails if `nbits` exceeds 64 or if the underlying writer fails.
    pub fn write(&mut self, data: u64, nbits: u8) -> Result<usize, io::Error> {
        if nbits > 64 {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "can not write more than 64 bits at once",
            ));
        }
        let mut emitted = 0;
        let mut pending = nbits;
        while pending != 0 {
            let take = cmp::min(8 - self.offset, pending);
            // Left-align the `pending` lowest bits of `data`, then slide them into the free
            // low-order slots of the current byte.
            self.buffer[0] |= ((data << (64 - pending)) >> (56 + self.offset)) as u8;
            self.offset += take;
            pending -= take;
            if self.offset == 8 {
                emitted += self.flush()?;
            }
        }
        Ok(emitted)
    }

    /// Writes out the buffered byte, if any, and returns the number of bytes written (0 or 1).
    pub fn flush(&mut self) -> Result<usize, io::Error> {
        if self.offset == 0 {
            return Ok(0);
        }
        self.writer.write_all(&self.buffer)?;
        self.buffer = [0u8];
        self.offset = 0;
        Ok(1)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;
    use std::io::Cursor;

    /// Decodes a hex fixture, panicking on malformed input.
    fn unhex(s: &str) -> Vec<u8> {
        hex::decode(s).unwrap()
    }

    /// A reader configured with the same keys and parameters as the filter built in `test_filter`.
    fn new_reader() -> GCSFilterReader<SipHasher24Builder> {
        GCSFilterReader::new(SipHasher24Builder::new(0, 0), M, P)
    }

    /// Runs `match_any` for `query` against the serialized `filter`.
    fn any_matches(filter: &[u8], query: &[Vec<u8>]) -> bool {
        new_reader()
            .match_any(&mut Cursor::new(filter), &mut query.iter().map(|v| v.as_slice()))
            .unwrap()
    }

    /// Runs `match_all` for `query` against the serialized `filter`.
    fn all_match(filter: &[u8], query: &[Vec<u8>]) -> bool {
        new_reader()
            .match_all(&mut Cursor::new(filter), &mut query.iter().map(|v| v.as_slice()))
            .unwrap()
    }

    #[test]
    fn test_filter() {
        // Sixteen 3-byte patterns: 000000, 111111, ..., ffffff.
        let patterns: HashSet<Vec<u8>> = "0123456789abcdef"
            .chars()
            .map(|digit| unhex(&digit.to_string().repeat(6)))
            .collect();

        let mut out = Cursor::new(Vec::new());
        {
            let mut writer = GCSFilterWriter::new(&mut out, SipHasher24Builder::new(0, 0), M, P);
            for pattern in &patterns {
                writer.add_element(pattern);
            }
            writer.finish().unwrap();
        }
        let bytes = out.into_inner();

        // One member is enough for match_any; two non-members do not match.
        assert!(any_matches(&bytes, &[unhex("abcdef"), unhex("eeeeee")]));
        assert!(!any_matches(&bytes, &[unhex("abcdef"), unhex("123456")]));

        // Every member matches; adding a single non-member breaks match_all.
        let mut members: Vec<Vec<u8>> = patterns.iter().cloned().collect();
        assert!(all_match(&bytes, &members));
        members.push(unhex("abcdef"));
        assert!(!all_match(&bytes, &members));

        // Empty queries never match.
        assert!(!any_matches(&bytes, &[]));
        assert!(!all_match(&bytes, &[]));
    }

    #[test]
    fn test_bit_stream() {
        // (value, width) pairs written in order and then read back in the same order.
        let fields: [(u64, u8); 7] = [(0, 1), (2, 2), (6, 3), (11, 4), (1, 5), (32, 6), (7, 7)];

        let mut out = Cursor::new(Vec::new());
        {
            let mut writer = BitStreamWriter::new(&mut out);
            for &(value, width) in &fields {
                writer.write(value, width).unwrap();
            }
            writer.flush().unwrap();
        }
        let bytes = out.into_inner();

        let rendered: String = bytes[..4].iter().map(|b| format!("{:08b}", b)).collect();
        assert_eq!("01011010110000110000000001110000", rendered);

        let mut input = Cursor::new(bytes);
        let mut reader = BitStreamReader::new(&mut input);
        for &(value, width) in &fields {
            assert_eq!(reader.read(width).unwrap(), value);
        }
        // Only 4 padding bits remain, so a 5-bit read runs off the end.
        assert!(reader.read(5).is_err());
    }
}