use arrow_array::builder::BufferBuilder;
use arrow_array::types::ArrowDictionaryKeyType;
use arrow_array::*;
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::ArrayData;
use arrow_schema::ArrowError;
use std::sync::Arc;
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> O::Native,
{
array.unary(op)
}
pub fn unary_mut<I, F>(
array: PrimitiveArray<I>,
op: F,
) -> Result<PrimitiveArray<I>, PrimitiveArray<I>>
where
I: ArrowPrimitiveType,
F: Fn(I::Native) -> I::Native,
{
array.unary_mut(op)
}
pub fn try_unary<I, F, O>(
array: &PrimitiveArray<I>,
op: F,
) -> Result<PrimitiveArray<O>, ArrowError>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<O::Native, ArrowError>,
{
array.try_unary(op)
}
pub fn try_unary_mut<I, F>(
array: PrimitiveArray<I>,
op: F,
) -> Result<Result<PrimitiveArray<I>, ArrowError>, PrimitiveArray<I>>
where
I: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<I::Native, ArrowError>,
{
array.try_unary_mut(op)
}
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef, ArrowError>
where
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op);
Ok(Arc::new(array.with_values(Arc::new(values))))
}
fn try_unary_dict<K, F, T>(
array: &DictionaryArray<K>,
op: F,
) -> Result<ArrayRef, ArrowError>
where
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native, ArrowError>,
{
if !PrimitiveArray::<T>::is_compatible(&array.value_type()) {
return Err(ArrowError::CastError(format!(
"Cannot perform the unary operation of type {} on dictionary array of value type {}",
T::DATA_TYPE,
array.value_type()
)));
}
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?;
Ok(Arc::new(array.with_values(Arc::new(values))))
}
/// Applies an infallible unary function `op` to a dynamically-typed array.
///
/// Dictionary arrays are dispatched to `unary_dict`, which transforms the
/// dictionary values. Any other array must be a [`PrimitiveArray`]
/// compatible with `T`.
///
/// # Errors
/// Returns [`ArrowError::NotYetImplemented`] if a non-dictionary `array`
/// is not compatible with `T`.
#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    downcast_dictionary_array! {
        array => unary_dict::<_, F, T>(array, op),
        t => {
            if !PrimitiveArray::<T>::is_compatible(t) {
                return Err(ArrowError::NotYetImplemented(format!(
                    "Cannot perform unary operation of type {} on array of type {}",
                    T::DATA_TYPE,
                    t
                )));
            }
            let primitive = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
            Ok(Arc::new(unary::<T, F, T>(primitive, op)))
        }
    }
}
/// Applies a fallible unary function `op` to a dynamically-typed array.
///
/// Dictionary arrays are accepted only when their value data type is exactly
/// `T::DATA_TYPE`. They are then handed to `try_unary_dict`. Any other array
/// must be a [`PrimitiveArray`] compatible with `T`.
///
/// # Errors
/// Returns [`ArrowError::NotYetImplemented`] for an unsupported input type,
/// and propagates any error returned by `op`.
#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> Result<T::Native, ArrowError>,
{
    downcast_dictionary_array! {
        array => {
            if array.values().data_type() != &T::DATA_TYPE {
                return Err(ArrowError::NotYetImplemented(format!(
                    "Cannot perform unary operation on dictionary array of type {}",
                    array.data_type()
                )));
            }
            try_unary_dict::<_, F, T>(array, op)
        },
        t => {
            if !PrimitiveArray::<T>::is_compatible(t) {
                return Err(ArrowError::NotYetImplemented(format!(
                    "Cannot perform unary operation of type {} on array of type {}",
                    T::DATA_TYPE,
                    t
                )));
            }
            let primitive = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
            let result = try_unary::<T, F, T>(primitive, op)?;
            Ok(Arc::new(result))
        }
    }
}
/// Applies an infallible binary function `op` element-wise to `a` and `b`,
/// producing a new [`PrimitiveArray`] of type `O`.
///
/// A slot in the output is null if it is logically null in either input.
/// `op` is evaluated for every slot, including null ones. It should
/// therefore tolerate whatever values happen to be stored behind a null.
///
/// # Errors
/// Returns [`ArrowError::ComputeError`] if `a` and `b` differ in length.
pub fn binary<A, B, F, O>(
    a: &PrimitiveArray<A>,
    b: &PrimitiveArray<B>,
    op: F,
) -> Result<PrimitiveArray<O>, ArrowError>
where
    A: ArrowPrimitiveType,
    B: ArrowPrimitiveType,
    O: ArrowPrimitiveType,
    F: Fn(A::Native, B::Native) -> O::Native,
{
    if a.len() != b.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        ));
    }
    if a.is_empty() {
        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
    }

    let lhs = a.values();
    let rhs = b.values();
    let mapped = lhs.iter().zip(rhs.iter()).map(|(&l, &r)| op(l, r));
    // SAFETY: a zip of two slice iterators reports an exact length, which
    // satisfies the trusted-length contract.
    let buffer = unsafe { Buffer::from_trusted_len_iter(mapped) };

    let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());
    Ok(PrimitiveArray::new(buffer.into(), nulls))
}
/// Applies an infallible binary function `op` element-wise to `a` and `b`,
/// writing the results into `a`'s buffer in place.
///
/// The outer `Err(a)` is returned when `a` cannot be converted into a
/// builder for in-place mutation. The inner `Err` reports a length mismatch
/// ([`ArrowError::ComputeError`]). The output null buffer is the union of
/// both inputs' logical nulls. `op` is evaluated for every slot, including
/// null ones.
pub fn binary_mut<T, F>(
    a: PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    op: F,
) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native, T::Native) -> T::Native,
{
    if a.len() != b.len() {
        return Ok(Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        )));
    }
    if a.is_empty() {
        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(&T::DATA_TYPE))));
    }

    // Must be computed before `a` is consumed by `into_builder`.
    let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());
    let mut builder = a.into_builder()?;

    for (slot, &rhs) in builder.values_slice_mut().iter_mut().zip(b.values().iter()) {
        *slot = op(*slot, rhs);
    }

    let data_builder = builder.finish().into_data().into_builder().nulls(nulls);
    // SAFETY: the values come from a valid array of length `a.len()`, and
    // `nulls` is the union of two null buffers of that same (checked) length.
    let data = unsafe { data_builder.build_unchecked() };
    Ok(Ok(PrimitiveArray::<T>::from(data)))
}
/// Applies a fallible binary function `op` element-wise to `a` and `b`,
/// producing a new [`PrimitiveArray`] of type `O`.
///
/// `op` is invoked only for slots that are valid in both inputs. The output
/// null buffer is the union of the inputs' *logical* nulls.
///
/// # Errors
/// - [`ArrowError::ComputeError`] if `a` and `b` differ in length.
/// - Any error returned by `op`.
pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>, ArrowError>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
{
    if a.len() != b.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform a binary operation on arrays of different length".to_string(),
        ));
    }
    if a.is_empty() {
        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
    }
    let len = a.len();

    // Choose the dense vs. masked path from the *logical* nulls. These are the
    // same nulls the masked path applies. Physical `null_count()` can
    // under-report for encodings such as dictionary arrays, whose nulls may
    // live in the values rather than the keys. Relying on it would evaluate
    // `op` on null slots and silently drop those nulls from the output.
    let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
        .filter(|n| n.null_count() > 0);

    match nulls {
        None => try_binary_no_nulls(len, a, b, op),
        Some(nulls) => {
            let mut buffer = BufferBuilder::<O::Native>::new(len);
            buffer.append_n_zeroed(len);
            let slice = buffer.as_slice_mut();
            nulls.try_for_each_valid_idx(|idx| {
                // SAFETY: `idx` is a valid index of `nulls`, whose length is
                // `len == a.len() == b.len() == slice.len()`.
                unsafe {
                    *slice.get_unchecked_mut(idx) =
                        op(a.value_unchecked(idx), b.value_unchecked(idx))?
                };
                Ok::<_, ArrowError>(())
            })?;
            let values = buffer.finish().into();
            Ok(PrimitiveArray::new(values, Some(nulls)))
        }
    }
}
/// Applies a fallible binary function `op` element-wise to `a` and `b`,
/// writing the results into `a`'s buffer in place.
///
/// The outer `Err(a)` is returned when `a` cannot be converted into a
/// builder for in-place mutation. The inner `Err` carries either a length
/// mismatch ([`ArrowError::ComputeError`]) or an error returned by `op`.
/// When either input has nulls, `op` runs only on slots valid in both, and
/// the output nulls are the union of the inputs' logical nulls.
pub fn try_binary_mut<T, F>(
    a: PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    op: F,
) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
{
    if a.len() != b.len() {
        return Ok(Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        )));
    }
    let len = a.len();
    if a.is_empty() {
        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(&T::DATA_TYPE))));
    }
    if a.null_count() == 0 && b.null_count() == 0 {
        return try_binary_no_nulls_mut(len, a, b, op);
    }

    // At least one side has nulls, so the union is guaranteed to be `Some`.
    let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
        .unwrap();
    let mut builder = a.into_builder()?;
    let slice = builder.values_slice_mut();

    let outcome = nulls.try_for_each_valid_idx(|idx| {
        // SAFETY: `idx` is a valid index of `nulls`, whose length equals
        // `len == slice.len() == b.len()`.
        unsafe {
            let current = *slice.get_unchecked(idx);
            *slice.get_unchecked_mut(idx) = op(current, b.value_unchecked(idx))?;
        }
        Ok::<_, ArrowError>(())
    });
    if let Err(err) = outcome {
        return Ok(Err(err));
    }

    let data_builder = builder.finish().into_data().into_builder().nulls(Some(nulls));
    // SAFETY: values come from a valid array of length `len`, and `nulls` has
    // the same length.
    let data = unsafe { data_builder.build_unchecked() };
    Ok(Ok(PrimitiveArray::<T>::from(data)))
}
/// Dense path of `try_binary`: applies `op` to every index in `0..len` and
/// returns an array without a null buffer.
///
/// Callers must ensure `a.len() == b.len() == len`. `try_binary` guarantees
/// this by checking the lengths before calling.
#[inline(never)]
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
    len: usize,
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>, ArrowError>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
{
    let mut buffer = MutableBuffer::new(len * std::mem::size_of::<O::Native>());
    for idx in 0..len {
        // SAFETY: `idx < len`, and the caller guarantees both inputs have `len` items.
        let (lhs, rhs) = unsafe { (a.value_unchecked(idx), b.value_unchecked(idx)) };
        let value = op(lhs, rhs)?;
        // SAFETY: capacity for `len` values of `O::Native` was reserved above,
        // and at most `len` values are pushed.
        unsafe { buffer.push_unchecked(value) };
    }
    Ok(PrimitiveArray::new(buffer.into(), None))
}
/// Dense in-place path of `try_binary_mut`: overwrites each of the first
/// `len` values of `a` with `op(a[i], b[i])`.
///
/// Returns the outer `Err(a)` when `a` cannot be turned into a builder. It
/// stops at, and returns as the inner `Err`, the first error raised by `op`.
/// Callers must ensure `a.len() == b.len() == len`.
#[inline(never)]
fn try_binary_no_nulls_mut<T, F>(
    len: usize,
    a: PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    op: F,
) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
{
    let mut builder = a.into_builder()?;
    let values = builder.values_slice_mut();
    for (idx, slot) in values.iter_mut().enumerate().take(len) {
        // SAFETY: `idx < len == b.len()` per the caller's contract.
        let rhs = unsafe { b.value_unchecked(idx) };
        match op(*slot, rhs) {
            Ok(updated) => *slot = updated,
            Err(err) => return Ok(Err(err)),
        }
    }
    Ok(Ok(builder.finish()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::builder::*;
    use arrow_array::types::*;

    #[test]
    #[allow(deprecated)]
    fn test_unary_f64_slice() {
        let input =
            Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
        let input_slice = input.slice(1, 4);
        let result = unary(&input_slice, |n| n.round());
        assert_eq!(
            result,
            Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
        );
        let result = unary_dyn::<_, Float64Type>(&input_slice, |n| n + 1.0).unwrap();
        assert_eq!(
            result.as_any().downcast_ref::<Float64Array>().unwrap(),
            &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
        );
    }

    #[test]
    #[allow(deprecated)]
    fn test_unary_dict_and_unary_dyn() {
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        builder.append(5).unwrap();
        builder.append(6).unwrap();
        builder.append(7).unwrap();
        builder.append(8).unwrap();
        builder.append_null();
        builder.append(9).unwrap();
        let dictionary_array = builder.finish();
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        builder.append(6).unwrap();
        builder.append(7).unwrap();
        builder.append(8).unwrap();
        builder.append(9).unwrap();
        builder.append_null();
        builder.append(10).unwrap();
        let expected = builder.finish();
        let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
        assert_eq!(
            result
                .as_any()
                .downcast_ref::<DictionaryArray<Int8Type>>()
                .unwrap(),
            &expected
        );
        let result = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
        assert_eq!(
            result
                .as_any()
                .downcast_ref::<DictionaryArray<Int8Type>>()
                .unwrap(),
            &expected
        );
    }

    #[test]
    #[allow(deprecated)]
    fn test_try_unary_dyn_dictionary_value_type_mismatch() {
        // Int32 dictionary values cannot be processed as Int64.
        let values = Int32Array::from(vec![1, 2]);
        let keys = Int8Array::from_iter_values([0, 1, 0]);
        let dictionary = DictionaryArray::new(keys, Arc::new(values));
        let result = try_unary_dyn::<_, Int64Type>(&dictionary, |n| Ok(n + 1));
        assert!(matches!(result, Err(ArrowError::NotYetImplemented(_))));
    }

    #[test]
    fn test_binary_nulls_and_length_mismatch() {
        let a = Int32Array::from(vec![Some(1), None, Some(3)]);
        let b = Int32Array::from(vec![Some(10), Some(20), None]);
        let c = binary::<_, _, _, Int32Type>(&a, &b, |l: i32, r: i32| l + r).unwrap();
        assert_eq!(c, Int32Array::from(vec![Some(11), None, None]));

        let short = Int32Array::from(vec![1, 2]);
        let err = binary::<_, _, _, Int32Type>(&a, &short, |l: i32, r: i32| l + r)
            .unwrap_err();
        assert!(matches!(err, ArrowError::ComputeError(_)));
    }

    #[test]
    fn test_try_binary() {
        let checked_div = |l: i32, r: i32| {
            if r == 0 {
                Err(ArrowError::DivideByZero)
            } else {
                Ok(l / r)
            }
        };

        // Masked path: the null slot (where `b` is 0) must be skipped, not evaluated.
        let a = Int32Array::from(vec![Some(10), None, Some(30)]);
        let b = Int32Array::from(vec![Some(2), Some(0), Some(5)]);
        let c = try_binary::<_, _, _, Int32Type>(&a, &b, checked_div).unwrap();
        assert_eq!(c, Int32Array::from(vec![Some(5), None, Some(6)]));

        // Dense path: errors from `op` are propagated.
        let a = Int32Array::from(vec![10, 20]);
        let b = Int32Array::from(vec![2, 0]);
        let err = try_binary::<_, _, _, Int32Type>(&a, &b, checked_div).unwrap_err();
        assert!(matches!(err, ArrowError::DivideByZero));
    }

    #[test]
    fn test_binary_mut() {
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap();
        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
        assert_eq!(c, expected);
    }

    #[test]
    fn test_try_binary_mut() {
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
        assert_eq!(c, expected);
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
        let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
        assert_eq!(c, expected);
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let _ = try_binary_mut(a, &b, |l, r| {
            if l == 1 {
                Err(ArrowError::InvalidArgumentError("got error".to_string()))
            } else {
                Ok(l + r)
            }
        })
        .unwrap()
        .expect_err("op error should propagate");
    }

    #[test]
    fn test_unary_dict_mut() {
        let values = Int32Array::from(vec![Some(10), Some(20), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
        let dictionary = DictionaryArray::new(keys, Arc::new(values));
        let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap();
        let typed = updated.downcast_dict::<Int32Array>().unwrap();
        assert_eq!(typed.value(0), 11);
        assert_eq!(typed.value(1), 11);
        assert_eq!(typed.value(2), 21);
        let values = updated.values();
        assert!(values.is_null(2));
    }
}