arrow_ord/
cmp.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Comparison kernels for `Array`s.
19//!
//! These kernels can leverage SIMD if available on your system. No runtime
//! detection is currently provided, so you should enable the specific SIMD intrinsics
//! yourself, for example with `RUSTFLAGS="-C target-feature=+avx2"`. See the
//! [`core::arch` documentation](https://doc.rust-lang.org/stable/core/arch/) for more information.
24//!
25
26use arrow_array::cast::AsArray;
27use arrow_array::types::{ByteArrayType, ByteViewType};
28use arrow_array::{
29    downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum,
30    FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray,
31};
32use arrow_buffer::bit_util::ceil;
33use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
34use arrow_schema::ArrowError;
35use arrow_select::take::take;
36use std::ops::Not;
37
/// The comparison operators implemented by the kernels in this module.
///
/// `Distinct` and `NotDistinct` correspond to SQL `IS DISTINCT FROM` and
/// `IS NOT DISTINCT FROM`, which treat nulls as comparable values instead of
/// propagating them to the result.
#[derive(Clone, Copy, Debug)]
enum Op {
    /// `left == right`
    Equal,
    /// `left != right`
    NotEqual,
    /// `left < right`
    Less,
    /// `left <= right`
    LessEqual,
    /// `left > right`
    Greater,
    /// `left >= right`
    GreaterEqual,
    /// `left IS DISTINCT FROM right`
    Distinct,
    /// `left IS NOT DISTINCT FROM right`
    NotDistinct,
}
49
50impl std::fmt::Display for Op {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            Op::Equal => write!(f, "=="),
54            Op::NotEqual => write!(f, "!="),
55            Op::Less => write!(f, "<"),
56            Op::LessEqual => write!(f, "<="),
57            Op::Greater => write!(f, ">"),
58            Op::GreaterEqual => write!(f, ">="),
59            Op::Distinct => write!(f, "IS DISTINCT FROM"),
60            Op::NotDistinct => write!(f, "IS NOT DISTINCT FROM"),
61        }
62    }
63}
64
65/// Perform `left == right` operation on two [`Datum`].
66///
67/// Comparing null values on either side will yield a null in the corresponding
68/// slot of the resulting [`BooleanArray`].
69///
70/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
71/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
72/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
73/// to treat them as equal, please normalize zeros before calling this kernel. See
74/// [`f32::total_cmp`] and [`f64::total_cmp`].
75///
76/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
77/// For comparisons involving nested types see [`crate::ord::make_comparator`]
78pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
79    compare_op(Op::Equal, lhs, rhs)
80}
81
82/// Perform `left != right` operation on two [`Datum`].
83///
84/// Comparing null values on either side will yield a null in the corresponding
85/// slot of the resulting [`BooleanArray`].
86///
87/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
88/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
89/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
90/// to treat them as equal, please normalize zeros before calling this kernel. See
91/// [`f32::total_cmp`] and [`f64::total_cmp`].
92///
93/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
94/// For comparisons involving nested types see [`crate::ord::make_comparator`]
95pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
96    compare_op(Op::NotEqual, lhs, rhs)
97}
98
99/// Perform `left < right` operation on two [`Datum`].
100///
101/// Comparing null values on either side will yield a null in the corresponding
102/// slot of the resulting [`BooleanArray`].
103///
104/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
105/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
106/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
107/// to treat them as equal, please normalize zeros before calling this kernel. See
108/// [`f32::total_cmp`] and [`f64::total_cmp`].
109///
110/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
111/// For comparisons involving nested types see [`crate::ord::make_comparator`]
112pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
113    compare_op(Op::Less, lhs, rhs)
114}
115
116/// Perform `left <= right` operation on two [`Datum`].
117///
118/// Comparing null values on either side will yield a null in the corresponding
119/// slot of the resulting [`BooleanArray`].
120///
121/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
122/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
123/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
124/// to treat them as equal, please normalize zeros before calling this kernel. See
125/// [`f32::total_cmp`] and [`f64::total_cmp`].
126///
127/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
128/// For comparisons involving nested types see [`crate::ord::make_comparator`]
129pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
130    compare_op(Op::LessEqual, lhs, rhs)
131}
132
133/// Perform `left > right` operation on two [`Datum`].
134///
135/// Comparing null values on either side will yield a null in the corresponding
136/// slot of the resulting [`BooleanArray`].
137///
138/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
139/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
140/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
141/// to treat them as equal, please normalize zeros before calling this kernel. See
142/// [`f32::total_cmp`] and [`f64::total_cmp`].
143///
144/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
145/// For comparisons involving nested types see [`crate::ord::make_comparator`]
146pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
147    compare_op(Op::Greater, lhs, rhs)
148}
149
150/// Perform `left >= right` operation on two [`Datum`].
151///
152/// Comparing null values on either side will yield a null in the corresponding
153/// slot of the resulting [`BooleanArray`].
154///
155/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
156/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
157/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
158/// to treat them as equal, please normalize zeros before calling this kernel. See
159/// [`f32::total_cmp`] and [`f64::total_cmp`].
160///
161/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
162/// For comparisons involving nested types see [`crate::ord::make_comparator`]
163pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
164    compare_op(Op::GreaterEqual, lhs, rhs)
165}
166
167/// Perform `left IS DISTINCT FROM right` operation on two [`Datum`]
168///
169/// [`distinct`] is similar to [`neq`], only differing in null handling. In particular, two
170/// operands are considered DISTINCT if they have a different value or if one of them is NULL
171/// and the other isn't. The result of [`distinct`] is never NULL.
172///
173/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
174/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
175/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
176/// to treat them as equal, please normalize zeros before calling this kernel. See
177/// [`f32::total_cmp`] and [`f64::total_cmp`].
178///
179/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
180/// For comparisons involving nested types see [`crate::ord::make_comparator`]
181pub fn distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
182    compare_op(Op::Distinct, lhs, rhs)
183}
184
185/// Perform `left IS NOT DISTINCT FROM right` operation on two [`Datum`]
186///
187/// [`not_distinct`] is similar to [`eq`], only differing in null handling. In particular, two
188/// operands are considered `NOT DISTINCT` if they have the same value or if both of them
189/// is NULL. The result of [`not_distinct`] is never NULL.
190///
191/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
192/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
193/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
194/// to treat them as equal, please normalize zeros before calling this kernel. See
195/// [`f32::total_cmp`] and [`f64::total_cmp`].
196///
197/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
198/// For comparisons involving nested types see [`crate::ord::make_comparator`]
199pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
200    compare_op(Op::NotDistinct, lhs, rhs)
201}
202
/// Perform `op` on the provided `Datum`
///
/// Shared implementation of [`eq`], [`neq`], [`lt`], [`lt_eq`], [`gt`], [`gt_eq`],
/// [`distinct`] and [`not_distinct`]. In order, it:
///
/// 1. checks that the inputs have equal lengths unless at least one side is a scalar,
/// 2. computes the logical null masks of the original (possibly dictionary-encoded) inputs,
/// 3. unwraps dictionaries to their values and rejects nested or mismatched data types,
/// 4. lazily compares the values via [`apply`], and
/// 5. combines the compared values with the null masks according to `op`.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if both inputs are arrays of different
/// lengths, if either (dictionary-unwrapped) data type is nested, or if the data types differ.
// NOTE(review): `#[inline(never)]` presumably keeps this large type dispatch from being
// inlined into every public wrapper above — confirm intent.
#[inline(never)]
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    use arrow_schema::DataType::*;
    // `l_s` / `r_s` are true when the corresponding side is a scalar
    let (l, l_s) = lhs.get();
    let (r, r_s) = rhs.get();

    let l_len = l.len();
    let r_len = r.len();

    // A scalar may be compared against an array of any length
    if l_len != r_len && !l_s && !r_s {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
        )));
    }

    // Output length: the left length, unless the left side is a scalar
    let len = match l_s {
        true => r_len,
        false => l_len,
    };

    // Computed before the dictionaries are unwrapped below, so these describe the
    // validity of each row; e.g. a dictionary row whose key points at a null value is null
    let l_nulls = l.logical_nulls();
    let r_nulls = r.logical_nulls();

    // For dictionaries, compare the dictionary values; `l_v` / `r_v` keep the dictionary
    // so that `apply` can map rows to values through the keys
    let l_v = l.as_any_dictionary_opt();
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
    let l_t = l.data_type();

    let r_v = r.as_any_dictionary_opt();
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
    let r_t = r.data_type();

    if r_t.is_nested() || l_t.is_nested() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Nested comparison: {l_t} {op} {r_t} (hint: use make_comparator instead)"
        )));
    } else if l_t != r_t {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Invalid comparison operation: {l_t} {op} {r_t}"
        )));
    }

    // Defer computation as may not be necessary
    // (e.g. comparing against a null scalar never needs the values, see below).
    // Both types are known to be equal and non-nested at this point; the `_` arm is
    // reached only for a type absent from this list — NOTE(review): confirm coverage.
    // `apply` returns `None` when there is nothing to compare (e.g. empty dictionary
    // values, or `Null` arrays); those slots are expected to be null, so all-false is used.
    let values = || -> BooleanBuffer {
        let d = downcast_primitive_array! {
            (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
            (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
            (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
            (Utf8View, Utf8View) => apply(op, l.as_string_view(), l_s, l_v, r.as_string_view(), r_s, r_v),
            (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
            (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
            (BinaryView, BinaryView) => apply(op, l.as_binary_view(), l_s, l_v, r.as_binary_view(), r_s, r_v),
            (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
            (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
            (Null, Null) => None,
            _ => unreachable!(),
        };
        d.unwrap_or_else(|| BooleanBuffer::new_unset(len))
    };

    // Discard null masks that contain no nulls, so `None` below means "no nulls on this side"
    let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
    let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
    Ok(match (l_nulls, l_s, r_nulls, r_s) {
        (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
            // Either both sides are scalar or neither side is scalar
            match op {
                Op::Distinct => {
                    // Processed 64 rows at a time. `l` / `r` are validity bits
                    // (1 = non-null) and `n` is the "not equal" bit from `values()`.
                    // Distinct if exactly one side is null (`l ^ r`), or both sides are
                    // non-null and the values differ (`l & r & n`). Bits past `len`
                    // are excluded by the length given to `BooleanBuffer::new`.
                    let values = values();
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let ne = values.bit_chunks().iter_padded();

                    let c = |((l, r), n)| ((l ^ r) | (l & r & n));
                    let buffer = l.zip(r).zip(ne).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                Op::NotDistinct => {
                    // `e` is the "equal" bit. Not distinct if both sides are null
                    // (`!(l | r)`), or both are non-null and the values are equal
                    // (`l & r & e`).
                    let values = values();
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let e = values.bit_chunks().iter_padded();

                    let c = |((l, r), e)| u64::not(l | r) | (l & r & e);
                    let buffer = l.zip(r).zip(e).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                // Ordinary comparisons: null wherever either side is null
                _ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))),
            }
        }
        (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
            // Scalar is null, other side is non-scalar and nullable
            // (`a` is the validity of the non-scalar side; the values are never needed)
            match op {
                // Distinct exactly where the array side is non-null
                Op::Distinct => a.into_inner().into(),
                // Not distinct exactly where the array side is also null
                Op::NotDistinct => a.into_inner().not().into(),
                // Any ordinary comparison with a null scalar is null
                _ => BooleanArray::new_null(len),
            }
        }
        (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => {
            // Only one side is nullable
            match is_scalar {
                true => match op {
                    // Scalar is null, other side is not nullable
                    Op::Distinct => BooleanBuffer::new_set(len).into(),
                    Op::NotDistinct => BooleanBuffer::new_unset(len).into(),
                    _ => BooleanArray::new_null(len),
                },
                false => match op {
                    Op::Distinct => {
                        // Distinct if this row is null (`!l`) or the values differ (`n`)
                        let values = values();
                        let l = nulls.inner().bit_chunks().iter_padded();
                        let ne = values.bit_chunks().iter_padded();
                        let c = |(l, n)| u64::not(l) | n;
                        let buffer = l.zip(ne).map(c).collect();
                        BooleanBuffer::new(buffer, 0, len).into()
                    }
                    // Not distinct only if this row is non-null and the values are equal
                    Op::NotDistinct => (nulls.inner() & &values()).into(),
                    _ => BooleanArray::new(values(), Some(nulls)),
                },
            }
        }
        // Neither side is nullable
        (None, _, None, _) => BooleanArray::new(values(), None),
    })
}
327
328/// Perform a potentially vectored `op` on the provided `ArrayOrd`
329fn apply<T: ArrayOrd>(
330    op: Op,
331    l: T,
332    l_s: bool,
333    l_v: Option<&dyn AnyDictionaryArray>,
334    r: T,
335    r_s: bool,
336    r_v: Option<&dyn AnyDictionaryArray>,
337) -> Option<BooleanBuffer> {
338    if l.len() == 0 || r.len() == 0 {
339        return None; // Handle empty dictionaries
340    }
341
342    if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) {
343        // Not scalar and at least one side has a dictionary, need to perform vectored comparison
344        let l_v = l_v
345            .map(|x| x.normalized_keys())
346            .unwrap_or_else(|| (0..l.len()).collect());
347
348        let r_v = r_v
349            .map(|x| x.normalized_keys())
350            .unwrap_or_else(|| (0..r.len()).collect());
351
352        assert_eq!(l_v.len(), r_v.len()); // Sanity check
353
354        Some(match op {
355            Op::Equal | Op::NotDistinct => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq),
356            Op::NotEqual | Op::Distinct => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_eq),
357            Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt),
358            Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true, T::is_lt),
359            Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false, T::is_lt),
360            Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_lt),
361        })
362    } else {
363        let l_s = l_s.then(|| l_v.map(|x| x.normalized_keys()[0]).unwrap_or_default());
364        let r_s = r_s.then(|| r_v.map(|x| x.normalized_keys()[0]).unwrap_or_default());
365
366        let buffer = match op {
367            Op::Equal | Op::NotDistinct => apply_op(l, l_s, r, r_s, false, T::is_eq),
368            Op::NotEqual | Op::Distinct => apply_op(l, l_s, r, r_s, true, T::is_eq),
369            Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt),
370            Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt),
371            Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt),
372            Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt),
373        };
374
375        // If a side had a dictionary, and was not scalar, we need to materialize this
376        Some(match (l_v, r_v) {
377            (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer),
378            (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer),
379            _ => buffer,
380        })
381    }
382}
383
384/// Perform a take operation on `buffer` with the given dictionary
385fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer {
386    let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap();
387    array.as_boolean().values().clone()
388}
389
390/// Invokes `f` with values `0..len` collecting the boolean results into a new `BooleanBuffer`
391///
392/// This is similar to [`MutableBuffer::collect_bool`] but with
393/// the option to efficiently negate the result
394fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> BooleanBuffer {
395    let mut buffer = MutableBuffer::new(ceil(len, 64) * 8);
396
397    let chunks = len / 64;
398    let remainder = len % 64;
399    for chunk in 0..chunks {
400        let mut packed = 0;
401        for bit_idx in 0..64 {
402            let i = bit_idx + chunk * 64;
403            packed |= (f(i) as u64) << bit_idx;
404        }
405        if neg {
406            packed = !packed
407        }
408
409        // SAFETY: Already allocated sufficient capacity
410        unsafe { buffer.push_unchecked(packed) }
411    }
412
413    if remainder != 0 {
414        let mut packed = 0;
415        for bit_idx in 0..remainder {
416            let i = bit_idx + chunks * 64;
417            packed |= (f(i) as u64) << bit_idx;
418        }
419        if neg {
420            packed = !packed
421        }
422
423        // SAFETY: Already allocated sufficient capacity
424        unsafe { buffer.push_unchecked(packed) }
425    }
426    BooleanBuffer::new(buffer.into(), 0, len)
427}
428
429/// Applies `op` to possibly scalar `ArrayOrd`
430///
431/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the scalar value in `l`
432/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the scalar value in `r`
433///
434/// If `neg` is true the result of `op` will be negated
435fn apply_op<T: ArrayOrd>(
436    l: T,
437    l_s: Option<usize>,
438    r: T,
439    r_s: Option<usize>,
440    neg: bool,
441    op: impl Fn(T::Item, T::Item) -> bool,
442) -> BooleanBuffer {
443    match (l_s, r_s) {
444        (None, None) => {
445            assert_eq!(l.len(), r.len());
446            collect_bool(l.len(), neg, |idx| unsafe {
447                op(l.value_unchecked(idx), r.value_unchecked(idx))
448            })
449        }
450        (Some(l_s), Some(r_s)) => {
451            let a = l.value(l_s);
452            let b = r.value(r_s);
453            std::iter::once(op(a, b) ^ neg).collect()
454        }
455        (Some(l_s), None) => {
456            let v = l.value(l_s);
457            collect_bool(r.len(), neg, |idx| op(v, unsafe { r.value_unchecked(idx) }))
458        }
459        (None, Some(r_s)) => {
460            let v = r.value(r_s);
461            collect_bool(l.len(), neg, |idx| op(unsafe { l.value_unchecked(idx) }, v))
462        }
463    }
464}
465
466/// Applies `op` to possibly scalar `ArrayOrd` with the given indices
467fn apply_op_vectored<T: ArrayOrd>(
468    l: T,
469    l_v: &[usize],
470    r: T,
471    r_v: &[usize],
472    neg: bool,
473    op: impl Fn(T::Item, T::Item) -> bool,
474) -> BooleanBuffer {
475    assert_eq!(l_v.len(), r_v.len());
476    collect_bool(l_v.len(), neg, |idx| unsafe {
477        let l_idx = *l_v.get_unchecked(idx);
478        let r_idx = *r_v.get_unchecked(idx);
479        op(l.value_unchecked(l_idx), r.value_unchecked(r_idx))
480    })
481}
482
/// Uniform, copyable access to the comparable values of an array, as used by the
/// comparison kernels in this module.
trait ArrayOrd {
    /// The value handed to [`ArrayOrd::is_eq`] and [`ArrayOrd::is_lt`].
    type Item: Copy;

    /// Returns `true` if `l` equals `r`.
    fn is_eq(l: Self::Item, r: Self::Item) -> bool;

    /// Returns `true` if `l` orders strictly before `r`.
    fn is_lt(l: Self::Item, r: Self::Item) -> bool;

    /// Returns the number of values.
    fn len(&self) -> usize;

    /// Returns the value at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    fn value(&self, idx: usize) -> Self::Item {
        assert!(idx < self.len());
        // SAFETY: bounds checked by the assertion above
        unsafe { self.value_unchecked(idx) }
    }

    /// Returns the value at `idx` without bounds checking.
    ///
    /// # Safety
    ///
    /// The caller must ensure `idx < self.len()`.
    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item;
}
502
503impl ArrayOrd for &BooleanArray {
504    type Item = bool;
505
506    fn len(&self) -> usize {
507        Array::len(self)
508    }
509
510    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
511        BooleanArray::value_unchecked(self, idx)
512    }
513
514    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
515        l == r
516    }
517
518    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
519        !l & r
520    }
521}
522
523impl<T: ArrowNativeTypeOp> ArrayOrd for &[T] {
524    type Item = T;
525
526    fn len(&self) -> usize {
527        (*self).len()
528    }
529
530    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
531        *self.get_unchecked(idx)
532    }
533
534    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
535        l.is_eq(r)
536    }
537
538    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
539        l.is_lt(r)
540    }
541}
542
543impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray<T> {
544    type Item = &'a [u8];
545
546    fn len(&self) -> usize {
547        Array::len(self)
548    }
549
550    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
551        GenericByteArray::value_unchecked(self, idx).as_ref()
552    }
553
554    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
555        l == r
556    }
557
558    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
559        l < r
560    }
561}
562
563impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> {
564    /// This is the item type for the GenericByteViewArray::compare
565    /// Item.0 is the array, Item.1 is the index
566    type Item = (&'a GenericByteViewArray<T>, usize);
567
568    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
569        // # Safety
570        // The index is within bounds as it is checked in value()
571        let l_view = unsafe { l.0.views().get_unchecked(l.1) };
572        let l_len = *l_view as u32;
573
574        let r_view = unsafe { r.0.views().get_unchecked(r.1) };
575        let r_len = *r_view as u32;
576        // This is a fast path for equality check.
577        // We don't need to look at the actual bytes to determine if they are equal.
578        if l_len != r_len {
579            return false;
580        }
581
582        unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_eq() }
583    }
584
585    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
586        // # Safety
587        // The index is within bounds as it is checked in value()
588        unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_lt() }
589    }
590
591    fn len(&self) -> usize {
592        Array::len(self)
593    }
594
595    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
596        (self, idx)
597    }
598}
599
600impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
601    type Item = &'a [u8];
602
603    fn len(&self) -> usize {
604        Array::len(self)
605    }
606
607    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
608        FixedSizeBinaryArray::value_unchecked(self, idx)
609    }
610
611    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
612        l == r
613    }
614
615    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
616        l < r
617    }
618}
619
620/// Compares two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
621pub fn compare_byte_view<T: ByteViewType>(
622    left: &GenericByteViewArray<T>,
623    left_idx: usize,
624    right: &GenericByteViewArray<T>,
625    right_idx: usize,
626) -> std::cmp::Ordering {
627    assert!(left_idx < left.len());
628    assert!(right_idx < right.len());
629    unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) }
630}
631
632/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
633///
634/// Comparing two ByteView types are non-trivial.
635/// It takes a bit of patience to understand why we don't just compare two &[u8] directly.
636///
637/// ByteView types give us the following two advantages, and we need to be careful not to lose them:
638/// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view.
639///     Meaning that reading one array element requires only one memory access
640///     (two memory access required for StringArray, one for offset buffer, the other for value buffer).
641///
642/// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray,
643///     thanks to the inlined 4 bytes.
644///     Consider equality check:
645///     If the first four bytes of the two strings are different, we can return false immediately (with just one memory access).
646///
647/// If we directly compare two &[u8], we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary.
648/// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer,
649///   e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string.
650///
651/// # Order check flow
652/// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view.
653/// (2) if any of the string is larger than 12 bytes, we need to compare the full string.
654///     (2.1) if the inlined 4 bytes are different, we can return the result immediately.
655///     (2.2) o.w., we need to compare the full string.
656///
657/// # Safety
658/// The left/right_idx must within range of each array
659#[deprecated(
660    since = "52.2.0",
661    note = "Use `GenericByteViewArray::compare_unchecked` instead"
662)]
663pub unsafe fn compare_byte_view_unchecked<T: ByteViewType>(
664    left: &GenericByteViewArray<T>,
665    left_idx: usize,
666    right: &GenericByteViewArray<T>,
667    right_idx: usize,
668) -> std::cmp::Ordering {
669    let l_view = left.views().get_unchecked(left_idx);
670    let l_len = *l_view as u32;
671
672    let r_view = right.views().get_unchecked(right_idx);
673    let r_len = *r_view as u32;
674
675    if l_len <= 12 && r_len <= 12 {
676        let l_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
677        let r_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
678        return l_data.cmp(r_data);
679    }
680
681    // one of the string is larger than 12 bytes,
682    // we then try to compare the inlined data first
683    let l_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, 4) };
684    let r_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, 4) };
685    if r_inlined_data != l_inlined_data {
686        return l_inlined_data.cmp(r_inlined_data);
687    }
688
689    // unfortunately, we need to compare the full data
690    let l_full_data: &[u8] = unsafe { left.value_unchecked(left_idx).as_ref() };
691    let r_full_data: &[u8] = unsafe { right.value_unchecked(right_idx).as_ref() };
692
693    l_full_data.cmp(r_full_data)
694}
695
696#[cfg(test)]
697mod tests {
698    use std::sync::Arc;
699
700    use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray};
701
702    use super::*;
703
704    #[test]
705    fn test_null_dict() {
706        let a = DictionaryArray::new(Int32Array::new_null(10), Arc::new(Int32Array::new_null(0)));
707        let r = eq(&a, &a).unwrap();
708        assert_eq!(r.null_count(), 10);
709
710        let a = DictionaryArray::new(
711            Int32Array::from(vec![1, 2, 3, 4, 5, 6]),
712            Arc::new(Int32Array::new_null(10)),
713        );
714        let r = eq(&a, &a).unwrap();
715        assert_eq!(r.null_count(), 6);
716
717        let scalar =
718            DictionaryArray::new(Int32Array::new_null(1), Arc::new(Int32Array::new_null(0)));
719        let r = eq(&a, &Scalar::new(&scalar)).unwrap();
720        assert_eq!(r.null_count(), 6);
721
722        let scalar =
723            DictionaryArray::new(Int32Array::new_null(1), Arc::new(Int32Array::new_null(0)));
724        let r = eq(&Scalar::new(&scalar), &Scalar::new(&scalar)).unwrap();
725        assert_eq!(r.null_count(), 1);
726
727        let a = DictionaryArray::new(
728            Int32Array::from(vec![0, 1, 2]),
729            Arc::new(Int32Array::from(vec![3, 2, 1])),
730        );
731        let r = eq(&a, &Scalar::new(&scalar)).unwrap();
732        assert_eq!(r.null_count(), 3);
733    }
734
735    #[test]
736    fn is_distinct_from_non_nulls() {
737        let left_int_array = Int32Array::from(vec![0, 1, 2, 3, 4]);
738        let right_int_array = Int32Array::from(vec![4, 3, 2, 1, 0]);
739
740        assert_eq!(
741            BooleanArray::from(vec![true, true, false, true, true,]),
742            distinct(&left_int_array, &right_int_array).unwrap()
743        );
744        assert_eq!(
745            BooleanArray::from(vec![false, false, true, false, false,]),
746            not_distinct(&left_int_array, &right_int_array).unwrap()
747        );
748    }
749
750    #[test]
751    fn is_distinct_from_nulls() {
752        // [0, 0, NULL, 0, 0, 0]
753        let left_int_array = Int32Array::new(
754            vec![0, 0, 1, 3, 0, 0].into(),
755            Some(NullBuffer::from(vec![true, true, false, true, true, true])),
756        );
757        // [0, NULL, NULL, NULL, 0, NULL]
758        let right_int_array = Int32Array::new(
759            vec![0; 6].into(),
760            Some(NullBuffer::from(vec![
761                true, false, false, false, true, false,
762            ])),
763        );
764
765        assert_eq!(
766            BooleanArray::from(vec![false, true, false, true, false, true,]),
767            distinct(&left_int_array, &right_int_array).unwrap()
768        );
769
770        assert_eq!(
771            BooleanArray::from(vec![true, false, true, false, true, false,]),
772            not_distinct(&left_int_array, &right_int_array).unwrap()
773        );
774    }
775
776    #[test]
777    fn test_distinct_scalar() {
778        let a = Int32Array::new_scalar(12);
779        let b = Int32Array::new_scalar(12);
780        assert!(!distinct(&a, &b).unwrap().value(0));
781        assert!(not_distinct(&a, &b).unwrap().value(0));
782
783        let a = Int32Array::new_scalar(12);
784        let b = Int32Array::new_null(1);
785        assert!(distinct(&a, &b).unwrap().value(0));
786        assert!(!not_distinct(&a, &b).unwrap().value(0));
787        assert!(distinct(&b, &a).unwrap().value(0));
788        assert!(!not_distinct(&b, &a).unwrap().value(0));
789
790        let b = Scalar::new(b);
791        assert!(distinct(&a, &b).unwrap().value(0));
792        assert!(!not_distinct(&a, &b).unwrap().value(0));
793
794        assert!(!distinct(&b, &b).unwrap().value(0));
795        assert!(not_distinct(&b, &b).unwrap().value(0));
796
797        let a = Int32Array::new(
798            vec![0, 1, 2, 3].into(),
799            Some(vec![false, false, true, true].into()),
800        );
801        let expected = BooleanArray::from(vec![false, false, true, true]);
802        assert_eq!(distinct(&a, &b).unwrap(), expected);
803        assert_eq!(distinct(&b, &a).unwrap(), expected);
804
805        let expected = BooleanArray::from(vec![true, true, false, false]);
806        assert_eq!(not_distinct(&a, &b).unwrap(), expected);
807        assert_eq!(not_distinct(&b, &a).unwrap(), expected);
808
809        let b = Int32Array::new_scalar(1);
810        let expected = BooleanArray::from(vec![true; 4]);
811        assert_eq!(distinct(&a, &b).unwrap(), expected);
812        assert_eq!(distinct(&b, &a).unwrap(), expected);
813        let expected = BooleanArray::from(vec![false; 4]);
814        assert_eq!(not_distinct(&a, &b).unwrap(), expected);
815        assert_eq!(not_distinct(&b, &a).unwrap(), expected);
816
817        let b = Int32Array::new_scalar(3);
818        let expected = BooleanArray::from(vec![true, true, true, false]);
819        assert_eq!(distinct(&a, &b).unwrap(), expected);
820        assert_eq!(distinct(&b, &a).unwrap(), expected);
821        let expected = BooleanArray::from(vec![false, false, false, true]);
822        assert_eq!(not_distinct(&a, &b).unwrap(), expected);
823        assert_eq!(not_distinct(&b, &a).unwrap(), expected);
824    }
825
826    #[test]
827    fn test_scalar_negation() {
828        let a = Int32Array::new_scalar(54);
829        let b = Int32Array::new_scalar(54);
830        let r = eq(&a, &b).unwrap();
831        assert!(r.value(0));
832
833        let r = neq(&a, &b).unwrap();
834        assert!(!r.value(0))
835    }
836
837    #[test]
838    fn test_scalar_empty() {
839        let a = Int32Array::new_null(0);
840        let b = Int32Array::new_scalar(23);
841        let r = eq(&a, &b).unwrap();
842        assert_eq!(r.len(), 0);
843        let r = eq(&b, &a).unwrap();
844        assert_eq!(r.len(), 0);
845    }
846
847    #[test]
848    fn test_dictionary_nulls() {
849        let values = StringArray::from(vec![Some("us-west"), Some("us-east")]);
850        let nulls = NullBuffer::from(vec![false, true, true]);
851
852        let key_values = vec![100i32, 1i32, 0i32].into();
853        let keys = Int32Array::new(key_values, Some(nulls));
854        let col = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();
855
856        neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap();
857    }
858}