1use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::{
23 downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray, GenericByteArray,
24};
25use arrow_buffer::NullBuffer;
26use arrow_schema::{ArrowError, DataType, SortOptions};
27use std::cmp::Ordering;
28
29pub(crate) fn can_rank(data_type: &DataType) -> bool {
31 data_type.is_primitive()
32 || matches!(
33 data_type,
34 DataType::Boolean
35 | DataType::Utf8
36 | DataType::LargeUtf8
37 | DataType::Binary
38 | DataType::LargeBinary
39 )
40}
41
42pub fn rank(array: &dyn Array, options: Option<SortOptions>) -> Result<Vec<u32>, ArrowError> {
55 let options = options.unwrap_or_default();
56 let ranks = downcast_primitive_array! {
57 array => primitive_rank(array.values(), array.nulls(), options),
58 DataType::Boolean => boolean_rank(array.as_boolean(), options),
59 DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
60 DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), options),
61 DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), options),
62 DataType::LargeBinary => bytes_rank(array.as_bytes::<LargeBinaryType>(), options),
63 d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank")))
64 };
65 Ok(ranks)
66}
67
68#[inline(never)]
69fn primitive_rank<T: ArrowNativeTypeOp>(
70 values: &[T],
71 nulls: Option<&NullBuffer>,
72 options: SortOptions,
73) -> Vec<u32> {
74 let len: u32 = values.len().try_into().unwrap();
75 let to_sort = match nulls.filter(|n| n.null_count() > 0) {
76 Some(n) => n
77 .valid_indices()
78 .map(|idx| (values[idx], idx as u32))
79 .collect(),
80 None => values.iter().copied().zip(0..len).collect(),
81 };
82 rank_impl(values.len(), to_sort, options, T::compare, T::is_eq)
83}
84
85#[inline(never)]
86fn bytes_rank<T: ByteArrayType>(array: &GenericByteArray<T>, options: SortOptions) -> Vec<u32> {
87 let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) {
88 Some(n) => n
89 .valid_indices()
90 .map(|idx| (array.value(idx).as_ref(), idx as u32))
91 .collect(),
92 None => (0..array.len())
93 .map(|idx| (array.value(idx).as_ref(), idx as u32))
94 .collect(),
95 };
96 rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq)
97}
98
99fn rank_impl<T, C, E>(
100 len: usize,
101 mut valid: Vec<(T, u32)>,
102 options: SortOptions,
103 compare: C,
104 eq: E,
105) -> Vec<u32>
106where
107 T: Copy,
108 C: Fn(T, T) -> Ordering,
109 E: Fn(T, T) -> bool,
110{
111 valid.sort_unstable_by(|a, b| compare(a.0, b.0));
113 if options.descending {
114 valid.reverse();
115 }
116
117 let (mut valid_rank, null_rank) = match options.nulls_first {
118 true => (len as u32, (len - valid.len()) as u32),
119 false => (valid.len() as u32, len as u32),
120 };
121
122 let mut out: Vec<_> = vec![null_rank; len];
123 if let Some(v) = valid.last() {
124 out[v.1 as usize] = valid_rank;
125 }
126
127 let mut count = 1; for w in valid.windows(2).rev() {
129 match eq(w[0].0, w[1].0) {
130 true => {
131 count += 1;
132 out[w[0].1 as usize] = valid_rank;
133 }
134 false => {
135 valid_rank -= count;
136 count = 1;
137 out[w[0].1 as usize] = valid_rank
138 }
139 }
140 }
141
142 out
143}
144
/// Maps a boolean slot to its index in a `[false, true, null]` rank lookup table.
///
/// Null slots always map to `2`, whatever their underlying `value`. Valid slots map to
/// `0` for `false` and `1` for `true`.
#[inline]
fn get_boolean_rank_index(value: bool, is_null: bool) -> usize {
    // Branch-free: a null sets bit 1 and masks out `value`, so it always lands on slot 2.
    let null_bit = usize::from(is_null) << 1;
    let value_bit = usize::from(value & !is_null);
    null_bit | value_bit
}
158
159#[inline(never)]
160fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec<u32> {
161 let null_count = array.null_count() as u32;
162 let true_count = array.true_count() as u32;
163 let false_count = array.len() as u32 - null_count - true_count;
164
165 let ranks_index: [u32; 3] = match (options.descending, options.nulls_first) {
182 (true, true) => [
184 null_count + true_count + false_count,
185 null_count + true_count,
186 null_count,
187 ],
188 (true, false) => [
190 true_count + false_count,
191 true_count,
192 true_count + false_count + null_count,
193 ],
194 (false, true) => [
196 null_count + false_count,
197 null_count + false_count + true_count,
198 null_count,
199 ],
200 (false, false) => [
202 false_count,
203 false_count + true_count,
204 false_count + true_count + null_count,
205 ],
206 };
207
208 match array.nulls().filter(|n| n.null_count() > 0) {
209 Some(n) => array
210 .values()
211 .iter()
212 .zip(n.iter())
213 .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value, !is_valid)])
214 .collect::<Vec<u32>>(),
215 None => array
216 .values()
217 .iter()
218 .map(|value| ranks_index[value as usize])
219 .collect::<Vec<u32>>(),
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::*;

    #[test]
    fn test_primitive() {
        let descending = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let nulls_last_descending = SortOptions {
            descending: true,
            nulls_first: false,
        };

        let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]);
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[3, 3, 1, 5, 5, 6]);

        let res = rank(&a, Some(descending)).unwrap();
        assert_eq!(res, &[6, 6, 1, 4, 4, 2]);

        let res = rank(&a, Some(nulls_last)).unwrap();
        assert_eq!(res, &[2, 2, 6, 4, 4, 5]);

        let res = rank(&a, Some(nulls_last_descending)).unwrap();
        assert_eq!(res, &[5, 5, 6, 3, 3, 1]);

        // Null slots with non-null underlying values must still be ranked as nulls.
        let nulls = NullBuffer::from(vec![true, true, false, true, false, false]);
        let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(nulls));
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[4, 6, 3, 6, 3, 3]);
    }

    #[test]
    fn test_get_boolean_rank_index() {
        assert_eq!(get_boolean_rank_index(true, true), 2);
        assert_eq!(get_boolean_rank_index(false, true), 2);
        assert_eq!(get_boolean_rank_index(true, false), 1);
        assert_eq!(get_boolean_rank_index(false, false), 0);
    }

    #[test]
    fn test_nullable_booleans() {
        let descending = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let nulls_last_descending = SortOptions {
            descending: true,
            nulls_first: false,
        };

        let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), Some(false)]);
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[5, 5, 1, 3, 3]);

        let res = rank(&a, Some(descending)).unwrap();
        assert_eq!(res, &[3, 3, 1, 5, 5]);

        let res = rank(&a, Some(nulls_last)).unwrap();
        assert_eq!(res, &[4, 4, 5, 2, 2]);

        let res = rank(&a, Some(nulls_last_descending)).unwrap();
        assert_eq!(res, &[2, 2, 5, 4, 4]);

        // A null slot whose underlying value bit is `true` must still rank as null.
        let nulls = NullBuffer::from(vec![true, true, false, true, true]);
        let a = BooleanArray::new(vec![true, true, true, false, false].into(), Some(nulls));
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[5, 5, 1, 3, 3]);
    }

    #[test]
    fn test_booleans() {
        let descending = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let nulls_last_descending = SortOptions {
            descending: true,
            nulls_first: false,
        };

        let a = BooleanArray::from(vec![true, false, false, false, true]);
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[5, 3, 3, 3, 5]);

        let res = rank(&a, Some(descending)).unwrap();
        assert_eq!(res, &[2, 5, 5, 5, 2]);

        let res = rank(&a, Some(nulls_last)).unwrap();
        assert_eq!(res, &[5, 3, 3, 3, 5]);

        let res = rank(&a, Some(nulls_last_descending)).unwrap();
        assert_eq!(res, &[2, 5, 5, 5, 2]);
    }

    #[test]
    fn test_bytes() {
        let v = vec!["foo", "fo", "bar", "bar"];
        let values = StringArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[4, 3, 2, 2]);

        let values = LargeStringArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[4, 3, 2, 2]);

        let v: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]];
        let values = LargeBinaryArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[3, 1, 4, 3]);

        let values = BinaryArray::from(v);
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[3, 1, 4, 3]);
    }

    #[test]
    fn test_nullable_bytes() {
        let v = vec![Some("foo"), None, Some("foo"), None, Some("bar")];

        // Default options rank nulls first: the two nulls share rank 2.
        let values = StringArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[5, 2, 5, 2, 3]);

        let nulls_last_descending = SortOptions {
            descending: true,
            nulls_first: false,
        };
        let res = rank(&values, Some(nulls_last_descending)).unwrap();
        assert_eq!(res, &[2, 5, 2, 5, 3]);

        let values = LargeStringArray::from(v);
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[5, 2, 5, 2, 3]);
    }

    #[test]
    fn test_empty() {
        let res = rank(&Int32Array::from(Vec::<i32>::new()), None).unwrap();
        assert!(res.is_empty());

        let res = rank(&BooleanArray::from(Vec::<bool>::new()), None).unwrap();
        assert!(res.is_empty());

        let res = rank(&StringArray::from(Vec::<&str>::new()), None).unwrap();
        assert!(res.is_empty());
    }

    #[test]
    fn test_can_rank() {
        assert!(can_rank(&DataType::Int32));
        assert!(can_rank(&DataType::Float64));
        assert!(can_rank(&DataType::Boolean));
        assert!(can_rank(&DataType::Utf8));
        assert!(can_rank(&DataType::LargeBinary));
        assert!(!can_rank(&DataType::Null));
    }

    #[test]
    fn test_unsupported_type() {
        let array = NullArray::new(3);
        assert!(rank(&array, None).is_err());
    }
}