arrow_ord/
ord.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Contains functions and function factories to compare arrays.
19
20use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::*;
23use arrow_buffer::{ArrowNativeType, NullBuffer};
24use arrow_schema::{ArrowError, SortOptions};
25use std::cmp::Ordering;
26
/// A boxed comparator that can be sent to and shared between threads.
///
/// The first argument is an index into the left array and the second an index into the
/// right array; the result states how the left value orders relative to the right one.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Sync + Send>;
29
30/// If parent sort order is descending we need to invert the value of nulls_first so that
31/// when the parent is sorted based on the produced ranks, nulls are still ordered correctly
32fn child_opts(opts: SortOptions) -> SortOptions {
33    SortOptions {
34        descending: false,
35        nulls_first: opts.nulls_first != opts.descending,
36    }
37}
38
39fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
40where
41    A: Array + Clone,
42    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
43{
44    let l = l.logical_nulls().filter(|x| x.null_count() > 0);
45    let r = r.logical_nulls().filter(|x| x.null_count() > 0);
46    match (opts.nulls_first, opts.descending) {
47        (true, true) => compare_impl::<true, true, _>(l, r, cmp),
48        (true, false) => compare_impl::<true, false, _>(l, r, cmp),
49        (false, true) => compare_impl::<false, true, _>(l, r, cmp),
50        (false, false) => compare_impl::<false, false, _>(l, r, cmp),
51    }
52}
53
54fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
55    l: Option<NullBuffer>,
56    r: Option<NullBuffer>,
57    cmp: F,
58) -> DynComparator
59where
60    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
61{
62    let cmp = move |i, j| match DESCENDING {
63        true => cmp(i, j).reverse(),
64        false => cmp(i, j),
65    };
66
67    let (left_null, right_null) = match NULLS_FIRST {
68        true => (Ordering::Less, Ordering::Greater),
69        false => (Ordering::Greater, Ordering::Less),
70    };
71
72    match (l, r) {
73        (None, None) => Box::new(cmp),
74        (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
75            true => left_null,
76            false => cmp(i, j),
77        }),
78        (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
79            true => right_null,
80            false => cmp(i, j),
81        }),
82        (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) {
83            (true, true) => Ordering::Equal,
84            (true, false) => left_null,
85            (false, true) => right_null,
86            (false, false) => cmp(i, j),
87        }),
88    }
89}
90
91fn compare_primitive<T: ArrowPrimitiveType>(
92    left: &dyn Array,
93    right: &dyn Array,
94    opts: SortOptions,
95) -> DynComparator
96where
97    T::Native: ArrowNativeTypeOp,
98{
99    let left = left.as_primitive::<T>();
100    let right = right.as_primitive::<T>();
101    let l_values = left.values().clone();
102    let r_values = right.values().clone();
103
104    compare(&left, &right, opts, move |i, j| {
105        l_values[i].compare(r_values[j])
106    })
107}
108
109fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator {
110    let left = left.as_boolean();
111    let right = right.as_boolean();
112
113    let l_values = left.values().clone();
114    let r_values = right.values().clone();
115
116    compare(left, right, opts, move |i, j| {
117        l_values.value(i).cmp(&r_values.value(j))
118    })
119}
120
121fn compare_bytes<T: ByteArrayType>(
122    left: &dyn Array,
123    right: &dyn Array,
124    opts: SortOptions,
125) -> DynComparator {
126    let left = left.as_bytes::<T>();
127    let right = right.as_bytes::<T>();
128
129    let l = left.clone();
130    let r = right.clone();
131    compare(left, right, opts, move |i, j| {
132        let l: &[u8] = l.value(i).as_ref();
133        let r: &[u8] = r.value(j).as_ref();
134        l.cmp(r)
135    })
136}
137
138fn compare_byte_view<T: ByteViewType>(
139    left: &dyn Array,
140    right: &dyn Array,
141    opts: SortOptions,
142) -> DynComparator {
143    let left = left.as_byte_view::<T>();
144    let right = right.as_byte_view::<T>();
145
146    let l = left.clone();
147    let r = right.clone();
148    compare(left, right, opts, move |i, j| {
149        crate::cmp::compare_byte_view(&l, i, &r, j)
150    })
151}
152
153fn compare_dict<K: ArrowDictionaryKeyType>(
154    left: &dyn Array,
155    right: &dyn Array,
156    opts: SortOptions,
157) -> Result<DynComparator, ArrowError> {
158    let left = left.as_dictionary::<K>();
159    let right = right.as_dictionary::<K>();
160
161    let c_opts = child_opts(opts);
162    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
163    let left_keys = left.keys().values().clone();
164    let right_keys = right.keys().values().clone();
165
166    let f = compare(left, right, opts, move |i, j| {
167        let l = left_keys[i].as_usize();
168        let r = right_keys[j].as_usize();
169        cmp(l, r)
170    });
171    Ok(f)
172}
173
174fn compare_list<O: OffsetSizeTrait>(
175    left: &dyn Array,
176    right: &dyn Array,
177    opts: SortOptions,
178) -> Result<DynComparator, ArrowError> {
179    let left = left.as_list::<O>();
180    let right = right.as_list::<O>();
181
182    let c_opts = child_opts(opts);
183    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
184
185    let l_o = left.offsets().clone();
186    let r_o = right.offsets().clone();
187    let f = compare(left, right, opts, move |i, j| {
188        let l_end = l_o[i + 1].as_usize();
189        let l_start = l_o[i].as_usize();
190
191        let r_end = r_o[j + 1].as_usize();
192        let r_start = r_o[j].as_usize();
193
194        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
195            match cmp(i, j) {
196                Ordering::Equal => continue,
197                r => return r,
198            }
199        }
200        (l_end - l_start).cmp(&(r_end - r_start))
201    });
202    Ok(f)
203}
204
205fn compare_fixed_list(
206    left: &dyn Array,
207    right: &dyn Array,
208    opts: SortOptions,
209) -> Result<DynComparator, ArrowError> {
210    let left = left.as_fixed_size_list();
211    let right = right.as_fixed_size_list();
212
213    let c_opts = child_opts(opts);
214    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
215
216    let l_size = left.value_length().to_usize().unwrap();
217    let r_size = right.value_length().to_usize().unwrap();
218    let size_cmp = l_size.cmp(&r_size);
219
220    let f = compare(left, right, opts, move |i, j| {
221        let l_start = i * l_size;
222        let l_end = l_start + l_size;
223        let r_start = j * r_size;
224        let r_end = r_start + r_size;
225        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
226            match cmp(i, j) {
227                Ordering::Equal => continue,
228                r => return r,
229            }
230        }
231        size_cmp
232    });
233    Ok(f)
234}
235
236fn compare_struct(
237    left: &dyn Array,
238    right: &dyn Array,
239    opts: SortOptions,
240) -> Result<DynComparator, ArrowError> {
241    let left = left.as_struct();
242    let right = right.as_struct();
243
244    if left.columns().len() != right.columns().len() {
245        return Err(ArrowError::InvalidArgumentError(
246            "Cannot compare StructArray with different number of columns".to_string(),
247        ));
248    }
249
250    let c_opts = child_opts(opts);
251    let columns = left.columns().iter().zip(right.columns());
252    let comparators = columns
253        .map(|(l, r)| make_comparator(l, r, c_opts))
254        .collect::<Result<Vec<_>, _>>()?;
255
256    let f = compare(left, right, opts, move |i, j| {
257        for cmp in &comparators {
258            match cmp(i, j) {
259                Ordering::Equal => continue,
260                r => return r,
261            }
262        }
263        Ordering::Equal
264    });
265    Ok(f)
266}
267
268#[deprecated(since = "52.0.0", note = "Use make_comparator")]
269#[doc(hidden)]
270pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparator, ArrowError> {
271    make_comparator(left, right, SortOptions::default())
272}
273
274/// Returns a comparison function that compares two values at two different positions
275/// between the two arrays.
276///
277/// For comparing arrays element-wise, see also the vectorised kernels in [`crate::cmp`].
278///
279/// If `nulls_first` is true `NULL` values will be considered less than any non-null value,
280/// otherwise they will be considered greater.
281///
282/// # Basic Usage
283///
284/// ```
285/// # use std::cmp::Ordering;
286/// # use arrow_array::Int32Array;
287/// # use arrow_ord::ord::make_comparator;
288/// # use arrow_schema::SortOptions;
289/// #
290/// let array1 = Int32Array::from(vec![1, 2]);
291/// let array2 = Int32Array::from(vec![3, 4]);
292///
293/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
294/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
295/// assert_eq!(cmp(0, 1), Ordering::Less);
296///
297/// let array1 = Int32Array::from(vec![Some(1), None]);
298/// let array2 = Int32Array::from(vec![None, Some(2)]);
299/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
300///
301/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2)
302/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2)
303/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None
304/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None
305/// ```
306///
307/// # Postgres-compatible Nested Comparison
308///
309/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a NULL yields
310/// a NULL, many systems, including postgres, instead apply a total ordering to comparison of
311/// nested nulls. That is nulls within nested types are either greater than any value (postgres),
312/// or less than any value (Spark).
313///
314/// In particular
315///
316/// ```ignore
317/// { a: 1, b: null } == { a: 1, b: null } => true
318/// { a: 1, b: null } == { a: 1, b: 1 } => false
319/// { a: 1, b: null } == null => null
320/// null == null => null
321/// ```
322///
323/// This could be implemented as below
324///
325/// ```
326/// # use arrow_array::{Array, BooleanArray};
327/// # use arrow_buffer::NullBuffer;
328/// # use arrow_ord::cmp;
329/// # use arrow_ord::ord::make_comparator;
330/// # use arrow_schema::{ArrowError, SortOptions};
331/// fn eq(a: &dyn Array, b: &dyn Array) -> Result<BooleanArray, ArrowError> {
332///     if !a.data_type().is_nested() {
333///         return cmp::eq(&a, &b); // Use faster vectorised kernel
334///     }
335///
336///     let cmp = make_comparator(a, b, SortOptions::default())?;
337///     let len = a.len().min(b.len());
338///     let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
339///     let nulls = NullBuffer::union(a.nulls(), b.nulls());
340///     Ok(BooleanArray::new(values, nulls))
341/// }
342/// ````
343pub fn make_comparator(
344    left: &dyn Array,
345    right: &dyn Array,
346    opts: SortOptions,
347) -> Result<DynComparator, ArrowError> {
348    use arrow_schema::DataType::*;
349
350    macro_rules! primitive_helper {
351        ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => {
352            Ok(compare_primitive::<$t>($left, $right, $nulls_first))
353        };
354    }
355    downcast_primitive! {
356        left.data_type(), right.data_type() => (primitive_helper, left, right, opts),
357        (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)),
358        (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right, opts)),
359        (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right, opts)),
360        (Utf8View, Utf8View) => Ok(compare_byte_view::<StringViewType>(left, right, opts)),
361        (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right, opts)),
362        (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right, opts)),
363        (BinaryView, BinaryView) => Ok(compare_byte_view::<BinaryViewType>(left, right, opts)),
364        (FixedSizeBinary(_), FixedSizeBinary(_)) => {
365            let left = left.as_fixed_size_binary();
366            let right = right.as_fixed_size_binary();
367
368            let l = left.clone();
369            let r = right.clone();
370            Ok(compare(left, right, opts, move |i, j| {
371                l.value(i).cmp(r.value(j))
372            }))
373        },
374        (List(_), List(_)) => compare_list::<i32>(left, right, opts),
375        (LargeList(_), LargeList(_)) => compare_list::<i64>(left, right, opts),
376        (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts),
377        (Struct(_), Struct(_)) => compare_struct(left, right, opts),
378        (Dictionary(l_key, _), Dictionary(r_key, _)) => {
379             macro_rules! dict_helper {
380                ($t:ty, $left:expr, $right:expr, $opts: expr) => {
381                     compare_dict::<$t>($left, $right, $opts)
382                 };
383             }
384            downcast_integer! {
385                 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts),
386                 _ => unreachable!()
387             }
388        },
389        (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
390            true => format!("The data type type {lhs:?} has no natural order"),
391            false => "Can't compare arrays of different types".to_string(),
392        }))
393    }
394}
395
396#[cfg(test)]
397mod tests {
398    use super::*;
399    use arrow_array::builder::{Int32Builder, ListBuilder};
400    use arrow_buffer::{i256, IntervalDayTime, OffsetBuffer};
401    use arrow_schema::{DataType, Field, Fields};
402    use half::f16;
403    use std::sync::Arc;
404
405    #[test]
406    fn test_fixed_size_binary() {
407        let items = vec![vec![1u8], vec![2u8]];
408        let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
409
410        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
411
412        assert_eq!(Ordering::Less, cmp(0, 1));
413    }
414
415    #[test]
416    fn test_fixed_size_binary_fixed_size_binary() {
417        let items = vec![vec![1u8]];
418        let array1 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
419        let items = vec![vec![2u8]];
420        let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
421
422        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
423
424        assert_eq!(Ordering::Less, cmp(0, 0));
425    }
426
427    #[test]
428    fn test_i32() {
429        let array = Int32Array::from(vec![1, 2]);
430
431        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
432
433        assert_eq!(Ordering::Less, (cmp)(0, 1));
434    }
435
436    #[test]
437    fn test_i32_i32() {
438        let array1 = Int32Array::from(vec![1]);
439        let array2 = Int32Array::from(vec![2]);
440
441        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
442
443        assert_eq!(Ordering::Less, cmp(0, 0));
444    }
445
446    #[test]
447    fn test_f16() {
448        let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]);
449
450        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
451
452        assert_eq!(Ordering::Less, cmp(0, 1));
453    }
454
455    #[test]
456    fn test_f64() {
457        let array = Float64Array::from(vec![1.0, 2.0]);
458
459        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
460
461        assert_eq!(Ordering::Less, cmp(0, 1));
462    }
463
464    #[test]
465    fn test_f64_nan() {
466        let array = Float64Array::from(vec![1.0, f64::NAN]);
467
468        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
469
470        assert_eq!(Ordering::Less, cmp(0, 1));
471        assert_eq!(Ordering::Equal, cmp(1, 1));
472    }
473
474    #[test]
475    fn test_f64_zeros() {
476        let array = Float64Array::from(vec![-0.0, 0.0]);
477
478        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
479
480        assert_eq!(Ordering::Less, cmp(0, 1));
481        assert_eq!(Ordering::Greater, cmp(1, 0));
482    }
483
484    #[test]
485    fn test_interval_day_time() {
486        let array = IntervalDayTimeArray::from(vec![
487            // 0 days, 1 second
488            IntervalDayTimeType::make_value(0, 1000),
489            // 1 day, 2 milliseconds
490            IntervalDayTimeType::make_value(1, 2),
491            // 90M milliseconds (which is more than is in 1 day)
492            IntervalDayTimeType::make_value(0, 90_000_000),
493        ]);
494
495        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
496
497        assert_eq!(Ordering::Less, cmp(0, 1));
498        assert_eq!(Ordering::Greater, cmp(1, 0));
499
500        // somewhat confusingly, while 90M milliseconds is more than 1 day,
501        // it will compare less as the comparison is done on the underlying
502        // values not field by field
503        assert_eq!(Ordering::Greater, cmp(1, 2));
504        assert_eq!(Ordering::Less, cmp(2, 1));
505    }
506
507    #[test]
508    fn test_interval_year_month() {
509        let array = IntervalYearMonthArray::from(vec![
510            // 1 year, 0 months
511            IntervalYearMonthType::make_value(1, 0),
512            // 0 years, 13 months
513            IntervalYearMonthType::make_value(0, 13),
514            // 1 year, 1 month
515            IntervalYearMonthType::make_value(1, 1),
516        ]);
517
518        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
519
520        assert_eq!(Ordering::Less, cmp(0, 1));
521        assert_eq!(Ordering::Greater, cmp(1, 0));
522
523        // the underlying representation is months, so both quantities are the same
524        assert_eq!(Ordering::Equal, cmp(1, 2));
525        assert_eq!(Ordering::Equal, cmp(2, 1));
526    }
527
528    #[test]
529    fn test_interval_month_day_nano() {
530        let array = IntervalMonthDayNanoArray::from(vec![
531            // 100 days
532            IntervalMonthDayNanoType::make_value(0, 100, 0),
533            // 1 month
534            IntervalMonthDayNanoType::make_value(1, 0, 0),
535            // 100 day, 1 nanoseconds
536            IntervalMonthDayNanoType::make_value(0, 100, 2),
537        ]);
538
539        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
540
541        assert_eq!(Ordering::Less, cmp(0, 1));
542        assert_eq!(Ordering::Greater, cmp(1, 0));
543
544        // somewhat confusingly, while 100 days is more than 1 month in all cases
545        // it will compare less as the comparison is done on the underlying
546        // values not field by field
547        assert_eq!(Ordering::Greater, cmp(1, 2));
548        assert_eq!(Ordering::Less, cmp(2, 1));
549    }
550
551    #[test]
552    fn test_decimal() {
553        let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)]
554            .into_iter()
555            .collect::<Decimal128Array>()
556            .with_precision_and_scale(23, 6)
557            .unwrap();
558
559        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
560        assert_eq!(Ordering::Less, cmp(1, 0));
561        assert_eq!(Ordering::Greater, cmp(0, 2));
562    }
563
564    #[test]
565    fn test_decimali256() {
566        let array = vec![
567            Some(i256::from_i128(5_i128)),
568            Some(i256::from_i128(2_i128)),
569            Some(i256::from_i128(3_i128)),
570        ]
571        .into_iter()
572        .collect::<Decimal256Array>()
573        .with_precision_and_scale(53, 6)
574        .unwrap();
575
576        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
577        assert_eq!(Ordering::Less, cmp(1, 0));
578        assert_eq!(Ordering::Greater, cmp(0, 2));
579    }
580
581    #[test]
582    fn test_dict() {
583        let data = vec!["a", "b", "c", "a", "a", "c", "c"];
584        let array = data.into_iter().collect::<DictionaryArray<Int16Type>>();
585
586        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
587
588        assert_eq!(Ordering::Less, cmp(0, 1));
589        assert_eq!(Ordering::Equal, cmp(3, 4));
590        assert_eq!(Ordering::Greater, cmp(2, 3));
591    }
592
593    #[test]
594    fn test_multiple_dict() {
595        let d1 = vec!["a", "b", "c", "d"];
596        let a1 = d1.into_iter().collect::<DictionaryArray<Int16Type>>();
597        let d2 = vec!["e", "f", "g", "a"];
598        let a2 = d2.into_iter().collect::<DictionaryArray<Int16Type>>();
599
600        let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap();
601
602        assert_eq!(Ordering::Less, cmp(0, 0));
603        assert_eq!(Ordering::Equal, cmp(0, 3));
604        assert_eq!(Ordering::Greater, cmp(1, 3));
605    }
606
607    #[test]
608    fn test_primitive_dict() {
609        let values = Int32Array::from(vec![1_i32, 0, 2, 5]);
610        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
611        let array1 = DictionaryArray::new(keys, Arc::new(values));
612
613        let values = Int32Array::from(vec![2_i32, 3, 4, 5]);
614        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
615        let array2 = DictionaryArray::new(keys, Arc::new(values));
616
617        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
618
619        assert_eq!(Ordering::Less, cmp(0, 0));
620        assert_eq!(Ordering::Less, cmp(0, 3));
621        assert_eq!(Ordering::Equal, cmp(3, 3));
622        assert_eq!(Ordering::Greater, cmp(3, 1));
623        assert_eq!(Ordering::Greater, cmp(3, 2));
624    }
625
626    #[test]
627    fn test_float_dict() {
628        let values = Float32Array::from(vec![1.0, 0.5, 2.1, 5.5]);
629        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
630        let array1 = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();
631
632        let values = Float32Array::from(vec![1.2, 3.2, 4.0, 5.5]);
633        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
634        let array2 = DictionaryArray::new(keys, Arc::new(values));
635
636        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
637
638        assert_eq!(Ordering::Less, cmp(0, 0));
639        assert_eq!(Ordering::Less, cmp(0, 3));
640        assert_eq!(Ordering::Equal, cmp(3, 3));
641        assert_eq!(Ordering::Greater, cmp(3, 1));
642        assert_eq!(Ordering::Greater, cmp(3, 2));
643    }
644
645    #[test]
646    fn test_timestamp_dict() {
647        let values = TimestampSecondArray::from(vec![1, 0, 2, 5]);
648        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
649        let array1 = DictionaryArray::new(keys, Arc::new(values));
650
651        let values = TimestampSecondArray::from(vec![2, 3, 4, 5]);
652        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
653        let array2 = DictionaryArray::new(keys, Arc::new(values));
654
655        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
656
657        assert_eq!(Ordering::Less, cmp(0, 0));
658        assert_eq!(Ordering::Less, cmp(0, 3));
659        assert_eq!(Ordering::Equal, cmp(3, 3));
660        assert_eq!(Ordering::Greater, cmp(3, 1));
661        assert_eq!(Ordering::Greater, cmp(3, 2));
662    }
663
664    #[test]
665    fn test_interval_dict() {
666        let v1 = IntervalDayTime::new(0, 1);
667        let v2 = IntervalDayTime::new(0, 2);
668        let v3 = IntervalDayTime::new(12, 2);
669
670        let values = IntervalDayTimeArray::from(vec![Some(v1), Some(v2), None, Some(v3)]);
671        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
672        let array1 = DictionaryArray::new(keys, Arc::new(values));
673
674        let values = IntervalDayTimeArray::from(vec![Some(v3), Some(v2), None, Some(v1)]);
675        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
676        let array2 = DictionaryArray::new(keys, Arc::new(values));
677
678        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
679
680        assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3
681        assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1
682        assert_eq!(Ordering::Greater, cmp(3, 3)); // v3 vs v1
683        assert_eq!(Ordering::Greater, cmp(3, 1)); // v3 vs v2
684        assert_eq!(Ordering::Greater, cmp(3, 2)); // v3 vs v2
685    }
686
687    #[test]
688    fn test_duration_dict() {
689        let values = DurationSecondArray::from(vec![1, 0, 2, 5]);
690        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
691        let array1 = DictionaryArray::new(keys, Arc::new(values));
692
693        let values = DurationSecondArray::from(vec![2, 3, 4, 5]);
694        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
695        let array2 = DictionaryArray::new(keys, Arc::new(values));
696
697        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
698
699        assert_eq!(Ordering::Less, cmp(0, 0));
700        assert_eq!(Ordering::Less, cmp(0, 3));
701        assert_eq!(Ordering::Equal, cmp(3, 3));
702        assert_eq!(Ordering::Greater, cmp(3, 1));
703        assert_eq!(Ordering::Greater, cmp(3, 2));
704    }
705
706    #[test]
707    fn test_decimal_dict() {
708        let values = Decimal128Array::from(vec![1, 0, 2, 5]);
709        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
710        let array1 = DictionaryArray::new(keys, Arc::new(values));
711
712        let values = Decimal128Array::from(vec![2, 3, 4, 5]);
713        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
714        let array2 = DictionaryArray::new(keys, Arc::new(values));
715
716        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
717
718        assert_eq!(Ordering::Less, cmp(0, 0));
719        assert_eq!(Ordering::Less, cmp(0, 3));
720        assert_eq!(Ordering::Equal, cmp(3, 3));
721        assert_eq!(Ordering::Greater, cmp(3, 1));
722        assert_eq!(Ordering::Greater, cmp(3, 2));
723    }
724
725    #[test]
726    fn test_decimal256_dict() {
727        let values = Decimal256Array::from(vec![
728            i256::from_i128(1),
729            i256::from_i128(0),
730            i256::from_i128(2),
731            i256::from_i128(5),
732        ]);
733        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
734        let array1 = DictionaryArray::new(keys, Arc::new(values));
735
736        let values = Decimal256Array::from(vec![
737            i256::from_i128(2),
738            i256::from_i128(3),
739            i256::from_i128(4),
740            i256::from_i128(5),
741        ]);
742        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
743        let array2 = DictionaryArray::new(keys, Arc::new(values));
744
745        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
746
747        assert_eq!(Ordering::Less, cmp(0, 0));
748        assert_eq!(Ordering::Less, cmp(0, 3));
749        assert_eq!(Ordering::Equal, cmp(3, 3));
750        assert_eq!(Ordering::Greater, cmp(3, 1));
751        assert_eq!(Ordering::Greater, cmp(3, 2));
752    }
753
754    fn test_bytes_impl<T: ByteArrayType>() {
755        let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
756        let a = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
757        let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap();
758
759        assert_eq!(Ordering::Less, cmp(0, 1));
760        assert_eq!(Ordering::Greater, cmp(0, 2));
761        assert_eq!(Ordering::Equal, cmp(1, 1));
762    }
763
764    #[test]
765    fn test_bytes() {
766        test_bytes_impl::<Utf8Type>();
767        test_bytes_impl::<LargeUtf8Type>();
768        test_bytes_impl::<BinaryType>();
769        test_bytes_impl::<LargeBinaryType>();
770    }
771
772    #[test]
773    fn test_lists() {
774        let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
775        a.extend([
776            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
777            Some(vec![
778                Some(vec![Some(1), Some(2), Some(3)]),
779                Some(vec![Some(1)]),
780            ]),
781            Some(vec![]),
782        ]);
783        let a = a.finish();
784        let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
785        b.extend([
786            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
787            Some(vec![
788                Some(vec![Some(1), Some(2), None]),
789                Some(vec![Some(1)]),
790            ]),
791            Some(vec![
792                Some(vec![Some(1), Some(2), Some(3), Some(4)]),
793                Some(vec![Some(1)]),
794            ]),
795            None,
796        ]);
797        let b = b.finish();
798
799        let opts = SortOptions {
800            descending: false,
801            nulls_first: true,
802        };
803        let cmp = make_comparator(&a, &b, opts).unwrap();
804        assert_eq!(cmp(0, 0), Ordering::Equal);
805        assert_eq!(cmp(0, 1), Ordering::Less);
806        assert_eq!(cmp(0, 2), Ordering::Less);
807        assert_eq!(cmp(1, 2), Ordering::Less);
808        assert_eq!(cmp(1, 3), Ordering::Greater);
809        assert_eq!(cmp(2, 0), Ordering::Less);
810
811        let opts = SortOptions {
812            descending: true,
813            nulls_first: true,
814        };
815        let cmp = make_comparator(&a, &b, opts).unwrap();
816        assert_eq!(cmp(0, 0), Ordering::Equal);
817        assert_eq!(cmp(0, 1), Ordering::Less);
818        assert_eq!(cmp(0, 2), Ordering::Less);
819        assert_eq!(cmp(1, 2), Ordering::Greater);
820        assert_eq!(cmp(1, 3), Ordering::Greater);
821        assert_eq!(cmp(2, 0), Ordering::Greater);
822
823        let opts = SortOptions {
824            descending: true,
825            nulls_first: false,
826        };
827        let cmp = make_comparator(&a, &b, opts).unwrap();
828        assert_eq!(cmp(0, 0), Ordering::Equal);
829        assert_eq!(cmp(0, 1), Ordering::Greater);
830        assert_eq!(cmp(0, 2), Ordering::Greater);
831        assert_eq!(cmp(1, 2), Ordering::Greater);
832        assert_eq!(cmp(1, 3), Ordering::Less);
833        assert_eq!(cmp(2, 0), Ordering::Greater);
834
835        let opts = SortOptions {
836            descending: false,
837            nulls_first: false,
838        };
839        let cmp = make_comparator(&a, &b, opts).unwrap();
840        assert_eq!(cmp(0, 0), Ordering::Equal);
841        assert_eq!(cmp(0, 1), Ordering::Greater);
842        assert_eq!(cmp(0, 2), Ordering::Greater);
843        assert_eq!(cmp(1, 2), Ordering::Less);
844        assert_eq!(cmp(1, 3), Ordering::Less);
845        assert_eq!(cmp(2, 0), Ordering::Less);
846    }
847
848    #[test]
849    fn test_struct() {
850        let fields = Fields::from(vec![
851            Field::new("a", DataType::Int32, true),
852            Field::new_list("b", Field::new_list_field(DataType::Int32, true), true),
853        ]);
854
855        let a = Int32Array::from(vec![Some(1), Some(2), None, None]);
856        let mut b = ListBuilder::new(Int32Builder::new());
857        b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
858        let b = b.finish();
859
860        let nulls = Some(NullBuffer::from_iter([true, true, true, false]));
861        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
862        let s1 = StructArray::new(fields.clone(), values, nulls);
863
864        let a = Int32Array::from(vec![None, Some(2), None]);
865        let mut b = ListBuilder::new(Int32Builder::new());
866        b.extend([None, None, Some(vec![])]);
867        let b = b.finish();
868
869        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
870        let s2 = StructArray::new(fields.clone(), values, None);
871
872        let opts = SortOptions {
873            descending: false,
874            nulls_first: true,
875        };
876        let cmp = make_comparator(&s1, &s2, opts).unwrap();
877        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
878        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
879        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
880        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
881        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
882        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
883        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
884
885        let opts = SortOptions {
886            descending: true,
887            nulls_first: true,
888        };
889        let cmp = make_comparator(&s1, &s2, opts).unwrap();
890        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
891        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
892        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
893        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
894        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
895        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
896        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
897
898        let opts = SortOptions {
899            descending: true,
900            nulls_first: false,
901        };
902        let cmp = make_comparator(&s1, &s2, opts).unwrap();
903        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
904        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
905        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
906        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
907        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
908        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
909        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
910
911        let opts = SortOptions {
912            descending: false,
913            nulls_first: false,
914        };
915        let cmp = make_comparator(&s1, &s2, opts).unwrap();
916        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
917        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
918        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
919        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
920        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
921        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
922        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
923    }
924}