use arrow_array::cast::AsArray;
use arrow_array::types::ByteArrayType;
use arrow_array::{
downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray,
Datum, FixedSizeBinaryArray, GenericByteArray,
};
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
use arrow_schema::ArrowError;
use arrow_select::take::take;
use std::ops::Not;
/// A comparison operator understood by the comparison kernels in this module.
#[derive(Clone, Copy, Debug)]
enum Op {
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `IS DISTINCT FROM` (null-aware inequality)
    Distinct,
    /// `IS NOT DISTINCT FROM` (null-aware equality)
    NotDistinct,
}

impl std::fmt::Display for Op {
    /// Writes the operator's textual form, as used in error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Equal => "==",
            Self::NotEqual => "!=",
            Self::Less => "<",
            Self::LessEqual => "<=",
            Self::Greater => ">",
            Self::GreaterEqual => ">=",
            Self::Distinct => "IS DISTINCT FROM",
            Self::NotDistinct => "IS NOT DISTINCT FROM",
        };
        f.write_str(symbol)
    }
}
/// Performs an `==` comparison between two [`Datum`]s.
///
/// An output row is null where either input row is null.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::Equal;
    compare_op(op, lhs, rhs)
}

/// Performs a `!=` comparison between two [`Datum`]s.
///
/// An output row is null where either input row is null.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::NotEqual;
    compare_op(op, lhs, rhs)
}

/// Performs a `<` comparison between two [`Datum`]s.
///
/// An output row is null where either input row is null.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::Less;
    compare_op(op, lhs, rhs)
}

/// Performs a `<=` comparison between two [`Datum`]s.
///
/// An output row is null where either input row is null.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::LessEqual;
    compare_op(op, lhs, rhs)
}

/// Performs a `>` comparison between two [`Datum`]s.
///
/// An output row is null where either input row is null.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::Greater;
    compare_op(op, lhs, rhs)
}

/// Performs a `>=` comparison between two [`Datum`]s.
///
/// An output row is null where either input row is null.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::GreaterEqual;
    compare_op(op, lhs, rhs)
}

/// Performs an `IS DISTINCT FROM` comparison between two [`Datum`]s.
///
/// The output never contains nulls. Two nulls are not distinct from each other, and
/// a null is distinct from any non-null value.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::Distinct;
    compare_op(op, lhs, rhs)
}

/// Performs an `IS NOT DISTINCT FROM` comparison between two [`Datum`]s.
///
/// The output never contains nulls. Two nulls compare as not distinct (`true`), and a
/// null against a non-null value yields `false`.
///
/// # Errors
///
/// Fails if neither side is a scalar and the lengths differ, or if the value types
/// are incompatible.
pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    let op = Op::NotDistinct;
    compare_op(op, lhs, rhs)
}
/// Evaluates `op` between `lhs` and `rhs`, either of which may be a scalar.
///
/// Each side is unwrapped through at most one level of dictionary encoding before its
/// value type is inspected. When one side is a scalar, its single value is compared
/// against every row of the other side. The output length is the non-scalar side's
/// length, or the right side's length when both are scalars.
///
/// Null handling:
/// * For the ordering and (in)equality operators, an output row is null whenever
///   either input row is null.
/// * `Op::Distinct` / `Op::NotDistinct` never produce nulls. Two nulls are not
///   distinct from each other, and a null is distinct from any non-null value.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when:
/// * neither side is a scalar and their lengths differ;
/// * the (dictionary-unwrapped) value types differ, or are nested;
/// * the value type has no comparison kernel here, for example a dictionary whose
///   values are themselves dictionary-encoded. This is reported as an error, not a
///   panic, whenever the comparison values are actually needed.
#[inline(never)]
fn compare_op(
    op: Op,
    lhs: &dyn Datum,
    rhs: &dyn Datum,
) -> Result<BooleanArray, ArrowError> {
    use arrow_schema::DataType::*;
    let (l, l_s) = lhs.get();
    let (r, r_s) = rhs.get();

    let l_len = l.len();
    let r_len = r.len();
    if l_len != r_len && !l_s && !r_s {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
        )));
    }

    // Output length follows the non-scalar side (the right side if both are scalars)
    let len = match l_s {
        true => r_len,
        false => l_len,
    };

    // Capture logical nulls before unwrapping dictionaries, so null keys are accounted for
    let l_nulls = l.logical_nulls();
    let r_nulls = r.logical_nulls();

    let l_v = l.as_any_dictionary_opt();
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
    let l_t = l.data_type();

    let r_v = r.as_any_dictionary_opt();
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
    let r_t = r.data_type();

    if l_t != r_t || l_t.is_nested() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Invalid comparison operation: {l_t} {op} {r_t}"
        )));
    }

    // Deferred, because some null-handling paths below never need the comparison results.
    // Value types without a kernel surface here as an error rather than a panic.
    let values = || -> Result<BooleanBuffer, ArrowError> {
        let d = downcast_primitive_array! {
            (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
            (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
            (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
            (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
            (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
            (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
            (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
            (Null, Null) => None,
            _ => {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "Invalid comparison operation: {l_t} {op} {r_t}"
                )))
            }
        };
        // `None` means there was nothing to compare (e.g. empty values or the Null type)
        Ok(d.unwrap_or_else(|| BooleanBuffer::new_unset(len)))
    };

    // Only validity buffers that actually contain nulls need special handling
    let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
    let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
    Ok(match (l_nulls, l_s, r_nulls, r_s) {
        (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
            // Both sides nullable, and either both scalar or both non-scalar
            match op {
                Op::Distinct => {
                    let values = values()?;
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let ne = values.bit_chunks().iter_padded();

                    // Distinct if exactly one side is valid, or both valid and unequal
                    let c = |((l, r), n)| ((l ^ r) | (l & r & n));
                    let buffer = l.zip(r).zip(ne).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                Op::NotDistinct => {
                    let values = values()?;
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let e = values.bit_chunks().iter_padded();

                    // Not distinct if both null, or both valid and equal
                    let c = |((l, r), e)| u64::not(l | r) | (l & r & e);
                    let buffer = l.zip(r).zip(e).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                _ => BooleanArray::new(values()?, NullBuffer::union(Some(&l), Some(&r))),
            }
        }
        (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
            // The scalar is null and the other (non-scalar) side is nullable
            match op {
                Op::Distinct => a.into_inner().into(),
                Op::NotDistinct => a.into_inner().not().into(),
                _ => BooleanArray::new_null(len),
            }
        }
        (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => {
            // Exactly one side has nulls
            match is_scalar {
                true => match op {
                    // The nullable side is a null scalar; the other side has no nulls
                    Op::Distinct => BooleanBuffer::new_set(len).into(),
                    Op::NotDistinct => BooleanBuffer::new_unset(len).into(),
                    _ => BooleanArray::new_null(len),
                },
                false => match op {
                    Op::Distinct => {
                        let values = values()?;
                        let l = nulls.inner().bit_chunks().iter_padded();
                        let ne = values.bit_chunks().iter_padded();
                        // Distinct if this row is null, or valid and unequal
                        let c = |(l, n)| u64::not(l) | n;
                        let buffer = l.zip(ne).map(c).collect();
                        BooleanBuffer::new(buffer, 0, len).into()
                    }
                    // Not distinct only if this row is valid and equal
                    Op::NotDistinct => (nulls.inner() & &values()?).into(),
                    _ => BooleanArray::new(values()?, Some(nulls)),
                },
            }
        }
        // Neither side has nulls
        (None, _, None, _) => BooleanArray::new(values()?, None),
    })
}
/// Runs `op` over the value arrays `l` and `r`, mapping the results back to logical rows.
///
/// `l_s` / `r_s` flag scalar sides. `l_v` / `r_v` carry the dictionary a side was
/// unwrapped from, if any. Its keys map logical rows to positions in `l` / `r`.
///
/// Returns `None` when either value array is empty, leaving the caller to substitute a
/// default buffer. Otherwise the result has one bit per logical output row.
///
/// The `<=`, `>` and `>=` operators are expressed through `is_lt` by swapping operands
/// and/or negating the result.
fn apply<T: ArrayOrd>(
    op: Op,
    l: T,
    l_s: bool,
    l_v: Option<&dyn AnyDictionaryArray>,
    r: T,
    r_s: bool,
    r_v: Option<&dyn AnyDictionaryArray>,
) -> Option<BooleanBuffer> {
    // An empty values array (e.g. a dictionary with no values) has nothing to compare
    if l.len() == 0 || r.len() == 0 {
        return None;
    }

    if l_s || r_s || (l_v.is_none() && r_v.is_none()) {
        // Compare directly over the values. A scalar side contributes a single index:
        // its normalized dictionary key, or 0 for a plain scalar.
        let l_idx = if l_s {
            Some(l_v.map_or(0, |d| d.normalized_keys()[0]))
        } else {
            None
        };
        let r_idx = if r_s {
            Some(r_v.map_or(0, |d| d.normalized_keys()[0]))
        } else {
            None
        };

        let bits = match op {
            Op::Equal | Op::NotDistinct => apply_op(l, l_idx, r, r_idx, false, T::is_eq),
            Op::NotEqual | Op::Distinct => apply_op(l, l_idx, r, r_idx, true, T::is_eq),
            Op::Less => apply_op(l, l_idx, r, r_idx, false, T::is_lt),
            // a <= b  is  !(b < a)
            Op::LessEqual => apply_op(r, r_idx, l, l_idx, true, T::is_lt),
            // a > b  is  b < a
            Op::Greater => apply_op(r, r_idx, l, l_idx, false, T::is_lt),
            // a >= b  is  !(a < b)
            Op::GreaterEqual => apply_op(l, l_idx, r, r_idx, true, T::is_lt),
        };

        // A non-scalar dictionary side was compared per dictionary value; expand the
        // result to one bit per row using that dictionary's keys.
        Some(match (l_v, l_idx, r_v, r_idx) {
            (Some(dict), None, _, _) => take_bits(dict, bits),
            (_, _, Some(dict), None) => take_bits(dict, bits),
            _ => bits,
        })
    } else {
        // Both sides are arrays and at least one is a dictionary: compare row by row
        // through the key indirection (identity keys for a non-dictionary side).
        let l_keys: Vec<usize> = match l_v {
            Some(dict) => dict.normalized_keys(),
            None => (0..l.len()).collect(),
        };
        let r_keys: Vec<usize> = match r_v {
            Some(dict) => dict.normalized_keys(),
            None => (0..r.len()).collect(),
        };
        // Sanity check: both sides describe the same number of rows
        assert_eq!(l_keys.len(), r_keys.len());

        Some(match op {
            Op::Equal | Op::NotDistinct => {
                apply_op_vectored(l, &l_keys, r, &r_keys, false, T::is_eq)
            }
            Op::NotEqual | Op::Distinct => {
                apply_op_vectored(l, &l_keys, r, &r_keys, true, T::is_eq)
            }
            Op::Less => apply_op_vectored(l, &l_keys, r, &r_keys, false, T::is_lt),
            Op::LessEqual => apply_op_vectored(r, &r_keys, l, &l_keys, true, T::is_lt),
            Op::Greater => apply_op_vectored(r, &r_keys, l, &l_keys, false, T::is_lt),
            Op::GreaterEqual => apply_op_vectored(l, &l_keys, r, &r_keys, true, T::is_lt),
        })
    }
}
/// Expands a per-value result buffer into a per-row buffer using a dictionary's keys.
///
/// `buffer` holds one bit per dictionary value. The result holds one bit per key of `v`,
/// gathered with the `take` kernel.
fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer {
    let per_value = BooleanArray::new(buffer, None);
    let per_row = take(&per_value, v.keys(), None).unwrap();
    per_row.as_boolean().values().clone()
}
/// Packs `len` boolean results into a [`BooleanBuffer`], 64 rows per `u64` word.
///
/// `f(i)` supplies the bit for row `i`; it is called exactly once per row, in ascending
/// order. When `neg` is true each packed word is inverted, so the buffer holds `!f(i)`.
/// In that case the padding bits past `len` in the final word are also set.
fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> BooleanBuffer {
    let full_words = len / 64;
    let tail_bits = len % 64;

    // Packs `width` consecutive results starting at row `base` into one word
    let pack = |base: usize, width: usize| -> u64 {
        let mut word = 0u64;
        for bit in 0..width {
            word |= (f(base + bit) as u64) << bit;
        }
        if neg {
            !word
        } else {
            word
        }
    };

    let mut words = MutableBuffer::new(ceil(len, 64) * 8);
    for w in 0..full_words {
        let word = pack(w * 64, 64);
        // SAFETY: `words` was allocated with room for `ceil(len, 64)` u64 words
        unsafe { words.push_unchecked(word) }
    }
    if tail_bits != 0 {
        let word = pack(full_words * 64, tail_bits);
        // SAFETY: as above; this is the last of the `ceil(len, 64)` words
        unsafe { words.push_unchecked(word) }
    }

    BooleanBuffer::new(words.into(), 0, len)
}
/// Applies `op` element-wise to `l` and `r`, optionally negating each result.
///
/// A `Some(i)` index marks that side as a scalar whose value is `value(i)`. It is
/// broadcast against every element of the other side. With both sides scalar, the
/// result has a single bit. With neither side scalar, the two arrays must have equal
/// lengths.
///
/// # Panics
///
/// Panics if both sides are arrays of different lengths, or a scalar index is out of
/// bounds.
fn apply_op<T: ArrayOrd>(
    l: T,
    l_s: Option<usize>,
    r: T,
    r_s: Option<usize>,
    neg: bool,
    op: impl Fn(T::Item, T::Item) -> bool,
) -> BooleanBuffer {
    match (l_s, r_s) {
        (Some(li), Some(ri)) => {
            let single = op(l.value(li), r.value(ri)) ^ neg;
            std::iter::once(single).collect()
        }
        (Some(li), None) => {
            let fixed = l.value(li);
            // SAFETY: `idx` ranges over `0..r.len()`
            collect_bool(r.len(), neg, |idx| op(fixed, unsafe { r.value_unchecked(idx) }))
        }
        (None, Some(ri)) => {
            let fixed = r.value(ri);
            // SAFETY: `idx` ranges over `0..l.len()`
            collect_bool(l.len(), neg, |idx| op(unsafe { l.value_unchecked(idx) }, fixed))
        }
        (None, None) => {
            assert_eq!(l.len(), r.len());
            // SAFETY: `idx < l.len() == r.len()` (asserted above)
            collect_bool(l.len(), neg, |idx| unsafe {
                op(l.value_unchecked(idx), r.value_unchecked(idx))
            })
        }
    }
}
/// Applies `op` row by row through key indirection: row `i` compares
/// `l.value(l_v[i])` with `r.value(r_v[i])`, optionally negating the result.
///
/// # Panics
///
/// Panics if `l_v` and `r_v` have different lengths.
fn apply_op_vectored<T: ArrayOrd>(
    l: T,
    l_v: &[usize],
    r: T,
    r_v: &[usize],
    neg: bool,
    op: impl Fn(T::Item, T::Item) -> bool,
) -> BooleanBuffer {
    assert_eq!(l_v.len(), r_v.len());
    let rows = l_v.len();
    collect_bool(rows, neg, |row| {
        // SAFETY: `row < rows`, the length of both key slices (asserted above).
        // The callers are presumed to supply keys that index within `l` / `r`.
        unsafe {
            let li = *l_v.get_unchecked(row);
            let ri = *r_v.get_unchecked(row);
            op(l.value_unchecked(li), r.value_unchecked(ri))
        }
    })
}
/// Random access to, and ordering of, the values of a comparable array.
trait ArrayOrd {
    /// The element type produced by `value` / `value_unchecked`
    type Item: Copy + Default;

    /// Number of values
    fn len(&self) -> usize;

    /// Returns `true` if `l` equals `r`
    fn is_eq(l: Self::Item, r: Self::Item) -> bool;

    /// Returns `true` if `l` is strictly less than `r`
    fn is_lt(l: Self::Item, r: Self::Item) -> bool;

    /// Returns the value at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    fn value(&self, idx: usize) -> Self::Item {
        assert!(idx < self.len());
        // SAFETY: the bound was checked on the line above
        unsafe { self.value_unchecked(idx) }
    }

    /// Returns the value at `idx` without a bounds check.
    ///
    /// # Safety
    ///
    /// `idx` must be less than `self.len()`.
    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item;
}
impl<'a> ArrayOrd for &'a BooleanArray {
type Item = bool;
fn len(&self) -> usize {
Array::len(self)
}
unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
BooleanArray::value_unchecked(self, idx)
}
fn is_eq(l: Self::Item, r: Self::Item) -> bool {
l == r
}
fn is_lt(l: Self::Item, r: Self::Item) -> bool {
!l & r
}
}
/// Native-value slices compare via [`ArrowNativeTypeOp`]'s `is_eq` / `is_lt`.
impl<T: ArrowNativeTypeOp> ArrayOrd for &[T] {
    type Item = T;

    fn len(&self) -> usize {
        // Calls the inherent slice `len`; `self.len()` would recurse into this trait method
        <[T]>::len(self)
    }

    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
        *<[T]>::get_unchecked(self, idx)
    }

    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
        <T as ArrowNativeTypeOp>::is_eq(l, r)
    }

    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
        <T as ArrowNativeTypeOp>::is_lt(l, r)
    }
}
impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray<T> {
type Item = &'a [u8];
fn len(&self) -> usize {
Array::len(self)
}
unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
GenericByteArray::value_unchecked(self, idx).as_ref()
}
fn is_eq(l: Self::Item, r: Self::Item) -> bool {
l == r
}
fn is_lt(l: Self::Item, r: Self::Item) -> bool {
l < r
}
}
impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
type Item = &'a [u8];
fn len(&self) -> usize {
Array::len(self)
}
unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
FixedSizeBinaryArray::value_unchecked(self, idx)
}
fn is_eq(l: Self::Item, r: Self::Item) -> bool {
l == r
}
fn is_lt(l: Self::Item, r: Self::Item) -> bool {
l < r
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{DictionaryArray, Int32Array, Scalar};
    use std::sync::Arc;

    /// Null keys, null dictionary values and null dictionary scalars all yield null outputs.
    #[test]
    fn test_null_dict() {
        let dict =
            |keys: Int32Array, values: Int32Array| DictionaryArray::new(keys, Arc::new(values));

        // Every key is null and the dictionary has no values
        let null_keys = dict(Int32Array::new_null(10), Int32Array::new_null(0));
        let out = eq(&null_keys, &null_keys).unwrap();
        assert_eq!(out.null_count(), 10);

        // Valid keys that point at null values
        let null_values = dict(Int32Array::from(vec![1, 2, 3, 4, 5, 6]), Int32Array::new_null(10));
        let out = eq(&null_values, &null_values).unwrap();
        assert_eq!(out.null_count(), 6);

        // A dictionary scalar whose only key is null
        let null_scalar = dict(Int32Array::new_null(1), Int32Array::new_null(0));
        let out = eq(&null_values, &Scalar::new(&null_scalar)).unwrap();
        assert_eq!(out.null_count(), 6);

        let out = eq(&Scalar::new(&null_scalar), &Scalar::new(&null_scalar)).unwrap();
        assert_eq!(out.null_count(), 1);

        // A fully valid dictionary compared with the null scalar
        let valid = dict(Int32Array::from(vec![0, 1, 2]), Int32Array::from(vec![3, 2, 1]));
        let out = eq(&valid, &Scalar::new(&null_scalar)).unwrap();
        assert_eq!(out.null_count(), 3);
    }

    /// Without nulls, `distinct` / `not_distinct` behave like `neq` / `eq`.
    #[test]
    fn is_distinct_from_non_nulls() {
        let ascending = Int32Array::from(vec![0, 1, 2, 3, 4]);
        let descending = Int32Array::from(vec![4, 3, 2, 1, 0]);

        let expected_distinct = BooleanArray::from(vec![true, true, false, true, true]);
        assert_eq!(expected_distinct, distinct(&ascending, &descending).unwrap());

        let expected_same = BooleanArray::from(vec![false, false, true, false, false]);
        assert_eq!(expected_same, not_distinct(&ascending, &descending).unwrap());
    }

    /// Nulls on both sides: null vs null is not distinct, null vs value is distinct.
    #[test]
    fn is_distinct_from_nulls() {
        let lhs_validity = NullBuffer::from(vec![true, true, false, true, true, true]);
        let lhs = Int32Array::new(vec![0, 0, 1, 3, 0, 0].into(), Some(lhs_validity));

        let rhs_validity = NullBuffer::from(vec![true, false, false, false, true, false]);
        let rhs = Int32Array::new(vec![0; 6].into(), Some(rhs_validity));

        let expected_distinct = BooleanArray::from(vec![false, true, false, true, false, true]);
        assert_eq!(expected_distinct, distinct(&lhs, &rhs).unwrap());

        let expected_same = BooleanArray::from(vec![true, false, true, false, true, false]);
        assert_eq!(expected_same, not_distinct(&lhs, &rhs).unwrap());
    }

    /// Distinctness with scalar operands, null and non-null, in both argument orders.
    #[test]
    fn test_distinct_scalar() {
        // Two equal non-null scalars
        let twelve = Int32Array::new_scalar(12);
        let other_twelve = Int32Array::new_scalar(12);
        assert!(!distinct(&twelve, &other_twelve).unwrap().value(0));
        assert!(not_distinct(&twelve, &other_twelve).unwrap().value(0));

        // A non-null scalar against a one-element null array
        let single_null = Int32Array::new_null(1);
        assert!(distinct(&twelve, &single_null).unwrap().value(0));
        assert!(!not_distinct(&twelve, &single_null).unwrap().value(0));
        assert!(distinct(&single_null, &twelve).unwrap().value(0));
        assert!(!not_distinct(&single_null, &twelve).unwrap().value(0));

        // The same null, now as a scalar
        let null_scalar = Scalar::new(single_null);
        assert!(distinct(&twelve, &null_scalar).unwrap().value(0));
        assert!(!not_distinct(&twelve, &null_scalar).unwrap().value(0));
        assert!(!distinct(&null_scalar, &null_scalar).unwrap().value(0));
        assert!(not_distinct(&null_scalar, &null_scalar).unwrap().value(0));

        // A partially-null array: [null, null, 2, 3]
        let partly_null = Int32Array::new(
            vec![0, 1, 2, 3].into(),
            Some(vec![false, false, true, true].into()),
        );

        // ...against the null scalar
        let want = BooleanArray::from(vec![false, false, true, true]);
        assert_eq!(distinct(&partly_null, &null_scalar).unwrap(), want);
        assert_eq!(distinct(&null_scalar, &partly_null).unwrap(), want);
        let want = BooleanArray::from(vec![true, true, false, false]);
        assert_eq!(not_distinct(&partly_null, &null_scalar).unwrap(), want);
        assert_eq!(not_distinct(&null_scalar, &partly_null).unwrap(), want);

        // ...against a scalar that matches no row
        let one = Int32Array::new_scalar(1);
        let want = BooleanArray::from(vec![true; 4]);
        assert_eq!(distinct(&partly_null, &one).unwrap(), want);
        assert_eq!(distinct(&one, &partly_null).unwrap(), want);
        let want = BooleanArray::from(vec![false; 4]);
        assert_eq!(not_distinct(&partly_null, &one).unwrap(), want);
        assert_eq!(not_distinct(&one, &partly_null).unwrap(), want);

        // ...against a scalar that matches the last row
        let three = Int32Array::new_scalar(3);
        let want = BooleanArray::from(vec![true, true, true, false]);
        assert_eq!(distinct(&partly_null, &three).unwrap(), want);
        assert_eq!(distinct(&three, &partly_null).unwrap(), want);
        let want = BooleanArray::from(vec![false, false, false, true]);
        assert_eq!(not_distinct(&partly_null, &three).unwrap(), want);
        assert_eq!(not_distinct(&three, &partly_null).unwrap(), want);
    }

    /// Scalar-vs-scalar comparisons honor negated operators.
    #[test]
    fn test_scalar_negation() {
        let lhs = Int32Array::new_scalar(54);
        let rhs = Int32Array::new_scalar(54);

        let equal = eq(&lhs, &rhs).unwrap();
        assert!(equal.value(0));

        let not_equal = neq(&lhs, &rhs).unwrap();
        assert!(!not_equal.value(0))
    }

    /// Comparing an empty array with a scalar yields an empty result in either order.
    #[test]
    fn test_scalar_empty() {
        let empty = Int32Array::new_null(0);
        let scalar = Int32Array::new_scalar(23);

        assert_eq!(eq(&empty, &scalar).unwrap().len(), 0);
        assert_eq!(eq(&scalar, &empty).unwrap().len(), 0);
    }
}