arrow_ord/
rank.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Provides `rank` function to assign a rank to each value in an array
19
20use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::{
23    downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray, GenericByteArray,
24};
25use arrow_buffer::NullBuffer;
26use arrow_schema::{ArrowError, DataType, SortOptions};
27use std::cmp::Ordering;
28
29/// Whether `arrow_ord::rank` can rank an array of given data type.
30pub(crate) fn can_rank(data_type: &DataType) -> bool {
31    data_type.is_primitive()
32        || matches!(
33            data_type,
34            DataType::Boolean
35                | DataType::Utf8
36                | DataType::LargeUtf8
37                | DataType::Binary
38                | DataType::LargeBinary
39        )
40}
41
42/// Assigns a rank to each value in `array` based on its position in the sorted order
43///
44/// Where values are equal, they will be assigned the highest of their ranks,
45/// leaving gaps in the overall rank assignment
46///
47/// ```
48/// # use arrow_array::StringArray;
49/// # use arrow_ord::rank::rank;
50/// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None, Some("bar")]);
51/// let ranks = rank(&array, None).unwrap();
52/// assert_eq!(ranks, &[5, 2, 5, 2, 3]);
53/// ```
54pub fn rank(array: &dyn Array, options: Option<SortOptions>) -> Result<Vec<u32>, ArrowError> {
55    let options = options.unwrap_or_default();
56    let ranks = downcast_primitive_array! {
57        array => primitive_rank(array.values(), array.nulls(), options),
58        DataType::Boolean => boolean_rank(array.as_boolean(), options),
59        DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
60        DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), options),
61        DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), options),
62        DataType::LargeBinary => bytes_rank(array.as_bytes::<LargeBinaryType>(), options),
63        d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank")))
64    };
65    Ok(ranks)
66}
67
68#[inline(never)]
69fn primitive_rank<T: ArrowNativeTypeOp>(
70    values: &[T],
71    nulls: Option<&NullBuffer>,
72    options: SortOptions,
73) -> Vec<u32> {
74    let len: u32 = values.len().try_into().unwrap();
75    let to_sort = match nulls.filter(|n| n.null_count() > 0) {
76        Some(n) => n
77            .valid_indices()
78            .map(|idx| (values[idx], idx as u32))
79            .collect(),
80        None => values.iter().copied().zip(0..len).collect(),
81    };
82    rank_impl(values.len(), to_sort, options, T::compare, T::is_eq)
83}
84
85#[inline(never)]
86fn bytes_rank<T: ByteArrayType>(array: &GenericByteArray<T>, options: SortOptions) -> Vec<u32> {
87    let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) {
88        Some(n) => n
89            .valid_indices()
90            .map(|idx| (array.value(idx).as_ref(), idx as u32))
91            .collect(),
92        None => (0..array.len())
93            .map(|idx| (array.value(idx).as_ref(), idx as u32))
94            .collect(),
95    };
96    rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq)
97}
98
99fn rank_impl<T, C, E>(
100    len: usize,
101    mut valid: Vec<(T, u32)>,
102    options: SortOptions,
103    compare: C,
104    eq: E,
105) -> Vec<u32>
106where
107    T: Copy,
108    C: Fn(T, T) -> Ordering,
109    E: Fn(T, T) -> bool,
110{
111    // We can use an unstable sort as we combine equal values later
112    valid.sort_unstable_by(|a, b| compare(a.0, b.0));
113    if options.descending {
114        valid.reverse();
115    }
116
117    let (mut valid_rank, null_rank) = match options.nulls_first {
118        true => (len as u32, (len - valid.len()) as u32),
119        false => (valid.len() as u32, len as u32),
120    };
121
122    let mut out: Vec<_> = vec![null_rank; len];
123    if let Some(v) = valid.last() {
124        out[v.1 as usize] = valid_rank;
125    }
126
127    let mut count = 1; // Number of values in rank
128    for w in valid.windows(2).rev() {
129        match eq(w[0].0, w[1].0) {
130            true => {
131                count += 1;
132                out[w[0].1 as usize] = valid_rank;
133            }
134            false => {
135                valid_rank -= count;
136                count = 1;
137                out[w[0].1 as usize] = valid_rank
138            }
139        }
140    }
141
142    out
143}
144
/// Maps a boolean slot to its position in the three-entry rank table
///
/// The position is:
/// * `2` when `is_null` is true, whatever `value` is
/// * `1` when `is_null` is false and `value` is true
/// * `0` when `is_null` is false and `value` is false
///
/// Non-null slots therefore map to the numeric value of the boolean itself
#[inline]
fn get_boolean_rank_index(value: bool, is_null: bool) -> usize {
    let null_bit = usize::from(is_null);
    // Clear the value bit for null slots so that both null inputs yield 2
    let value_bit = usize::from(value) & (null_bit ^ 1);
    value_bit | (null_bit << 1)
}
158
159#[inline(never)]
160fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec<u32> {
161    let null_count = array.null_count() as u32;
162    let true_count = array.true_count() as u32;
163    let false_count = array.len() as u32 - null_count - true_count;
164
165    // Rank values for [false, true, null] in that order
166    //
167    // The value for a rank is last value rank + own value count
168    // this means that if we have the following order: `false`, `true` and then `null`
169    // the ranks will be:
170    // - false: false_count
171    // - true: false_count + true_count
172    // - null: false_count + true_count + null_count
173    //
174    // If we have the following order: `null`, `false` and then `true`
175    // the ranks will be:
176    // - false: null_count + false_count
177    // - true: null_count + false_count + true_count
178    // - null: null_count
179    //
180    // You will notice that the last rank is always the total length of the array but we don't use it for readability on how the rank is calculated
181    let ranks_index: [u32; 3] = match (options.descending, options.nulls_first) {
182        // The order is null, true, false
183        (true, true) => [
184            null_count + true_count + false_count,
185            null_count + true_count,
186            null_count,
187        ],
188        // The order is true, false, null
189        (true, false) => [
190            true_count + false_count,
191            true_count,
192            true_count + false_count + null_count,
193        ],
194        // The order is null, false, true
195        (false, true) => [
196            null_count + false_count,
197            null_count + false_count + true_count,
198            null_count,
199        ],
200        // The order is false, true, null
201        (false, false) => [
202            false_count,
203            false_count + true_count,
204            false_count + true_count + null_count,
205        ],
206    };
207
208    match array.nulls().filter(|n| n.null_count() > 0) {
209        Some(n) => array
210            .values()
211            .iter()
212            .zip(n.iter())
213            .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value, !is_valid)])
214            .collect::<Vec<u32>>(),
215        None => array
216            .values()
217            .iter()
218            .map(|value| ranks_index[value as usize])
219            .collect::<Vec<u32>>(),
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::*;

    /// Descending order, nulls ranked first
    const DESCENDING: SortOptions = SortOptions {
        descending: true,
        nulls_first: true,
    };

    /// Ascending order, nulls ranked last
    const NULLS_LAST: SortOptions = SortOptions {
        descending: false,
        nulls_first: false,
    };

    /// Descending order, nulls ranked last
    const NULLS_LAST_DESCENDING: SortOptions = SortOptions {
        descending: true,
        nulls_first: false,
    };

    /// Ranks `array` using `options`, panicking if ranking fails
    fn ranks_of(array: &dyn Array, options: Option<SortOptions>) -> Vec<u32> {
        rank(array, options).unwrap()
    }

    #[test]
    fn test_primitive() {
        let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]);
        assert_eq!(ranks_of(&a, None), &[3, 3, 1, 5, 5, 6]);
        assert_eq!(ranks_of(&a, Some(DESCENDING)), &[6, 6, 1, 4, 4, 2]);
        assert_eq!(ranks_of(&a, Some(NULLS_LAST)), &[2, 2, 6, 4, 4, 5]);
        assert_eq!(ranks_of(&a, Some(NULLS_LAST_DESCENDING)), &[5, 5, 6, 3, 3, 1]);

        // Null slots backed by non-zero values must still be ranked as nulls
        let nulls = NullBuffer::from(vec![true, true, false, true, false, false]);
        let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(nulls));
        assert_eq!(ranks_of(&a, None), &[4, 6, 3, 6, 3, 3]);
    }

    #[test]
    fn test_get_boolean_rank_index() {
        let cases = [
            ((true, true), 2),
            ((false, true), 2),
            ((true, false), 1),
            ((false, false), 0),
        ];
        for ((value, is_null), expected) in cases {
            assert_eq!(get_boolean_rank_index(value, is_null), expected);
        }
    }

    #[test]
    fn test_nullable_booleans() {
        let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), Some(false)]);
        assert_eq!(ranks_of(&a, None), &[5, 5, 1, 3, 3]);
        assert_eq!(ranks_of(&a, Some(DESCENDING)), &[3, 3, 1, 5, 5]);
        assert_eq!(ranks_of(&a, Some(NULLS_LAST)), &[4, 4, 5, 2, 2]);
        assert_eq!(ranks_of(&a, Some(NULLS_LAST_DESCENDING)), &[2, 2, 5, 4, 4]);

        // A null slot backed by a `true` value must still be ranked as null
        let nulls = NullBuffer::from(vec![true, true, false, true, true]);
        let a = BooleanArray::new(vec![true, true, true, false, false].into(), Some(nulls));
        assert_eq!(ranks_of(&a, None), &[5, 5, 1, 3, 3]);
    }

    #[test]
    fn test_booleans() {
        let a = BooleanArray::from(vec![true, false, false, false, true]);
        assert_eq!(ranks_of(&a, None), &[5, 3, 3, 3, 5]);
        assert_eq!(ranks_of(&a, Some(DESCENDING)), &[2, 5, 5, 5, 2]);
        assert_eq!(ranks_of(&a, Some(NULLS_LAST)), &[5, 3, 3, 3, 5]);
        assert_eq!(ranks_of(&a, Some(NULLS_LAST_DESCENDING)), &[2, 5, 5, 5, 2]);
    }

    #[test]
    fn test_bytes() {
        let strings = vec!["foo", "fo", "bar", "bar"];

        let values = StringArray::from(strings.clone());
        assert_eq!(ranks_of(&values, None), &[4, 3, 2, 2]);

        let values = LargeStringArray::from(strings);
        assert_eq!(ranks_of(&values, None), &[4, 3, 2, 2]);

        let binaries: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]];

        let values = LargeBinaryArray::from(binaries.clone());
        assert_eq!(ranks_of(&values, None), &[3, 1, 4, 3]);

        let values = BinaryArray::from(binaries);
        assert_eq!(ranks_of(&values, None), &[3, 1, 4, 3]);
    }
}