1use arrow_array::builder::BufferBuilder;
21use arrow_array::*;
22use arrow_buffer::buffer::NullBuffer;
23use arrow_buffer::ArrowNativeType;
24use arrow_buffer::{Buffer, MutableBuffer};
25use arrow_data::ArrayData;
26use arrow_schema::ArrowError;
27
28pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
30where
31 I: ArrowPrimitiveType,
32 O: ArrowPrimitiveType,
33 F: Fn(I::Native) -> O::Native,
34{
35 array.unary(op)
36}
37
38pub fn unary_mut<I, F>(
40 array: PrimitiveArray<I>,
41 op: F,
42) -> Result<PrimitiveArray<I>, PrimitiveArray<I>>
43where
44 I: ArrowPrimitiveType,
45 F: Fn(I::Native) -> I::Native,
46{
47 array.unary_mut(op)
48}
49
50pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>, ArrowError>
52where
53 I: ArrowPrimitiveType,
54 O: ArrowPrimitiveType,
55 F: Fn(I::Native) -> Result<O::Native, ArrowError>,
56{
57 array.try_unary(op)
58}
59
60pub fn try_unary_mut<I, F>(
62 array: PrimitiveArray<I>,
63 op: F,
64) -> Result<Result<PrimitiveArray<I>, ArrowError>, PrimitiveArray<I>>
65where
66 I: ArrowPrimitiveType,
67 F: Fn(I::Native) -> Result<I::Native, ArrowError>,
68{
69 array.try_unary_mut(op)
70}
71
72pub fn binary<A, B, F, O>(
105 a: &PrimitiveArray<A>,
106 b: &PrimitiveArray<B>,
107 op: F,
108) -> Result<PrimitiveArray<O>, ArrowError>
109where
110 A: ArrowPrimitiveType,
111 B: ArrowPrimitiveType,
112 O: ArrowPrimitiveType,
113 F: Fn(A::Native, B::Native) -> O::Native,
114{
115 if a.len() != b.len() {
116 return Err(ArrowError::ComputeError(
117 "Cannot perform binary operation on arrays of different length".to_string(),
118 ));
119 }
120
121 if a.is_empty() {
122 return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
123 }
124
125 let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());
126
127 let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r));
128 let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
134 Ok(PrimitiveArray::new(buffer.into(), nulls))
135}
136
137pub fn binary_mut<T, U, F>(
202 a: PrimitiveArray<T>,
203 b: &PrimitiveArray<U>,
204 op: F,
205) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
206where
207 T: ArrowPrimitiveType,
208 U: ArrowPrimitiveType,
209 F: Fn(T::Native, U::Native) -> T::Native,
210{
211 if a.len() != b.len() {
212 return Ok(Err(ArrowError::ComputeError(
213 "Cannot perform binary operation on arrays of different length".to_string(),
214 )));
215 }
216
217 if a.is_empty() {
218 return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
219 &T::DATA_TYPE,
220 ))));
221 }
222
223 let mut builder = a.into_builder()?;
224
225 builder
226 .values_slice_mut()
227 .iter_mut()
228 .zip(b.values())
229 .for_each(|(l, r)| *l = op(*l, *r));
230
231 let array = builder.finish();
232
233 let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref());
235
236 let array_builder = array.into_data().into_builder().nulls(nulls);
237
238 let array_data = unsafe { array_builder.build_unchecked() };
239 Ok(Ok(PrimitiveArray::<T>::from(array_data)))
240}
241
242pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
255 a: A,
256 b: B,
257 op: F,
258) -> Result<PrimitiveArray<O>, ArrowError>
259where
260 O: ArrowPrimitiveType,
261 F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
262{
263 if a.len() != b.len() {
264 return Err(ArrowError::ComputeError(
265 "Cannot perform a binary operation on arrays of different length".to_string(),
266 ));
267 }
268 if a.is_empty() {
269 return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
270 }
271 let len = a.len();
272
273 if a.null_count() == 0 && b.null_count() == 0 {
274 try_binary_no_nulls(len, a, b, op)
275 } else {
276 let nulls =
277 NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap();
278
279 let mut buffer = BufferBuilder::<O::Native>::new(len);
280 buffer.append_n_zeroed(len);
281 let slice = buffer.as_slice_mut();
282
283 nulls.try_for_each_valid_idx(|idx| {
284 unsafe {
285 *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))?
286 };
287 Ok::<_, ArrowError>(())
288 })?;
289
290 let values = buffer.finish().into();
291 Ok(PrimitiveArray::new(values, Some(nulls)))
292 }
293}
294
295pub fn try_binary_mut<T, F>(
306 a: PrimitiveArray<T>,
307 b: &PrimitiveArray<T>,
308 op: F,
309) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
310where
311 T: ArrowPrimitiveType,
312 F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
313{
314 if a.len() != b.len() {
315 return Ok(Err(ArrowError::ComputeError(
316 "Cannot perform binary operation on arrays of different length".to_string(),
317 )));
318 }
319 let len = a.len();
320
321 if a.is_empty() {
322 return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
323 &T::DATA_TYPE,
324 ))));
325 }
326
327 if a.null_count() == 0 && b.null_count() == 0 {
328 try_binary_no_nulls_mut(len, a, b, op)
329 } else {
330 let nulls =
331 create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
332 .unwrap();
333
334 let mut builder = a.into_builder()?;
335
336 let slice = builder.values_slice_mut();
337
338 let r = nulls.try_for_each_valid_idx(|idx| {
339 unsafe {
340 *slice.get_unchecked_mut(idx) =
341 op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
342 };
343 Ok::<_, ArrowError>(())
344 });
345 if let Err(err) = r {
346 return Ok(Err(err));
347 }
348 let array_builder = builder.finish().into_data().into_builder();
349 let array_data = unsafe { array_builder.nulls(Some(nulls)).build_unchecked() };
350 Ok(Ok(PrimitiveArray::<T>::from(array_data)))
351 }
352}
353
354fn create_union_null_buffer(
360 lhs: Option<&NullBuffer>,
361 rhs: Option<&NullBuffer>,
362) -> Option<NullBuffer> {
363 match (lhs, rhs) {
364 (Some(lhs), Some(rhs)) => Some(NullBuffer::new(lhs.inner() & rhs.inner())),
365 (Some(n), None) | (None, Some(n)) => Some(NullBuffer::new(n.inner() & n.inner())),
366 (None, None) => None,
367 }
368}
369
370#[inline(never)]
372fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
373 len: usize,
374 a: A,
375 b: B,
376 op: F,
377) -> Result<PrimitiveArray<O>, ArrowError>
378where
379 O: ArrowPrimitiveType,
380 F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
381{
382 let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
383 for idx in 0..len {
384 unsafe {
385 buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
386 };
387 }
388 Ok(PrimitiveArray::new(buffer.into(), None))
389}
390
391#[inline(never)]
393fn try_binary_no_nulls_mut<T, F>(
394 len: usize,
395 a: PrimitiveArray<T>,
396 b: &PrimitiveArray<T>,
397 op: F,
398) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
399where
400 T: ArrowPrimitiveType,
401 F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
402{
403 let mut builder = a.into_builder()?;
404 let slice = builder.values_slice_mut();
405
406 for idx in 0..len {
407 unsafe {
408 match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) {
409 Ok(value) => *slice.get_unchecked_mut(idx) = value,
410 Err(err) => return Ok(Err(err)),
411 };
412 };
413 }
414 Ok(Ok(builder.finish()))
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::types::*;
    use std::sync::Arc;

    #[test]
    // NOTE(review): kept from the original; confirm whether `unary` or `slice`
    // is still deprecated, and drop this allow if not.
    #[allow(deprecated)]
    fn test_unary_f64_slice() {
        // `unary` must honour the slice offset: only elements 1..5 are mapped.
        let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
        let input_slice = input.slice(1, 4);
        let result = unary(&input_slice, |n| n.round());
        assert_eq!(
            result,
            Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
        );
    }

    #[test]
    fn test_binary_mut() {
        // Nulls in `b` must propagate into the result.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap();

        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
        assert_eq!(c, expected);
    }

    #[test]
    fn test_binary_mut_null_buffer() {
        // An all-valid null buffer on `b` must give the same result as no
        // null buffer at all.
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);

        let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);

        let r1 = binary_mut(a, &b, |a, b| a + b).unwrap();

        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
        let b = Int32Array::new(
            vec![10, 11, 12, 13, 14].into(),
            Some(vec![true, true, true, true, true].into()),
        );

        let r2 = binary_mut(a, &b, |a, b| a + b).unwrap();
        assert_eq!(r1.unwrap(), r2.unwrap());
    }

    #[test]
    fn test_try_binary_mut() {
        // Null path: nulls in `b` propagate into the result.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();

        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
        assert_eq!(c, expected);

        // No-null fast path.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
        let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
        assert_eq!(c, expected);

        // An error from `op` on a valid slot surfaces as the inner `Err`.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let _ = try_binary_mut(a, &b, |l, r| {
            if l == 1 {
                Err(ArrowError::InvalidArgumentError("got error".to_string()))
            } else {
                Ok(l + r)
            }
        })
        .unwrap()
        .expect_err("should get an error");
    }

    #[test]
    fn test_try_binary_mut_null_buffer() {
        // An all-valid null buffer on `b` must give the same result as no
        // null buffer at all.
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);

        let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);

        let r1 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap();

        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
        let b = Int32Array::new(
            vec![10, 11, 12, 13, 14].into(),
            Some(vec![true, true, true, true, true].into()),
        );

        let r2 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap();
        assert_eq!(r1.unwrap(), r2.unwrap());
    }

    #[test]
    fn test_unary_dict_mut() {
        // Mutating dictionary values in place keeps the null value slot null.
        let values = Int32Array::from(vec![Some(10), Some(20), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
        let dictionary = DictionaryArray::new(keys, Arc::new(values));

        let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap();
        let typed = updated.downcast_dict::<Int32Array>().unwrap();
        assert_eq!(typed.value(0), 11);
        assert_eq!(typed.value(1), 11);
        assert_eq!(typed.value(2), 21);

        let values = updated.values();
        assert!(values.is_null(2));
    }
}