use crate::filter::SlicesIterator;
use arrow_array::*;
use arrow_data::transform::MutableArrayData;
use arrow_schema::ArrowError;
pub fn zip(
mask: &BooleanArray,
truthy: &dyn Datum,
falsy: &dyn Datum,
) -> Result<ArrayRef, ArrowError> {
let (truthy, truthy_is_scalar) = truthy.get();
let (falsy, falsy_is_scalar) = falsy.get();
if truthy.data_type() != falsy.data_type() {
return Err(ArrowError::InvalidArgumentError(
"arguments need to have the same data type".into(),
));
}
if truthy_is_scalar && truthy.len() != 1 {
return Err(ArrowError::InvalidArgumentError(
"scalar arrays must have 1 element".into(),
));
}
if !truthy_is_scalar && truthy.len() != mask.len() {
return Err(ArrowError::InvalidArgumentError(
"all arrays should have the same length".into(),
));
}
if falsy_is_scalar && falsy.len() != 1 {
return Err(ArrowError::InvalidArgumentError(
"scalar arrays must have 1 element".into(),
));
}
if !falsy_is_scalar && falsy.len() != mask.len() {
return Err(ArrowError::InvalidArgumentError(
"all arrays should have the same length".into(),
));
}
let falsy = falsy.to_data();
let truthy = truthy.to_data();
let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());
let mut filled = 0;
SlicesIterator::new(mask).for_each(|(start, end)| {
if start > filled {
if falsy_is_scalar {
for _ in filled..start {
mutable.extend(1, 0, 1);
}
} else {
mutable.extend(1, filled, start);
}
}
if truthy_is_scalar {
for _ in start..end {
mutable.extend(0, 0, 1);
}
} else {
mutable.extend(0, start, end);
}
filled = end;
});
if filled < mask.len() {
if falsy_is_scalar {
for _ in filled..mask.len() {
mutable.extend(1, 0, 1);
}
} else {
mutable.extend(1, filled, mask.len());
}
}
let data = mutable.freeze();
Ok(make_array(data))
}
#[cfg(test)]
mod test {
    use super::*;

    /// Values shared by every test that needs a nullable array input.
    fn sample_array() -> Int32Array {
        Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)])
    }

    /// Builds a one-element `Int32` scalar holding `value`.
    fn int_scalar(value: i32) -> Scalar<Int32Array> {
        Scalar::new(Int32Array::from_value(value, 1))
    }

    /// Builds a one-element `Int32` scalar holding a null.
    fn null_scalar() -> Scalar<Int32Array> {
        Scalar::new(Int32Array::new_null(1))
    }

    /// Runs `zip` with the given mask and inputs and checks the `Int32` output
    /// against `expected`.
    fn assert_zip(
        mask: [bool; 5],
        truthy: &dyn Datum,
        falsy: &dyn Datum,
        expected: [Option<i32>; 5],
    ) {
        let mask = BooleanArray::from(mask.to_vec());
        let out = zip(&mask, truthy, falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(expected.to_vec());
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_one() {
        let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        assert_zip(
            [true, true, false, false, true],
            &sample_array(),
            &b,
            [Some(5), None, Some(6), Some(7), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_two() {
        let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        assert_zip(
            [false, false, true, true, false],
            &sample_array(),
            &b,
            [None, Some(3), Some(7), None, Some(3)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_1() {
        assert_zip(
            [true, true, false, false, true],
            &sample_array(),
            &int_scalar(42),
            [Some(5), None, Some(42), Some(42), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_2() {
        assert_zip(
            [false, false, true, true, false],
            &sample_array(),
            &int_scalar(42),
            [Some(42), Some(42), Some(7), None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_1() {
        assert_zip(
            [true, true, false, false, true],
            &int_scalar(42),
            &sample_array(),
            [Some(42), Some(42), Some(7), None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_2() {
        assert_zip(
            [false, false, true, true, false],
            &int_scalar(42),
            &sample_array(),
            [Some(5), None, Some(42), Some(42), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_both() {
        assert_zip(
            [true, true, false, false, true],
            &int_scalar(42),
            &int_scalar(123),
            [Some(42), Some(42), Some(123), Some(123), Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_none_1() {
        assert_zip(
            [true, true, false, false, true],
            &int_scalar(42),
            &null_scalar(),
            [Some(42), Some(42), None, None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_none_2() {
        assert_zip(
            [false, false, true, true, false],
            &int_scalar(42),
            &null_scalar(),
            [None, None, Some(42), Some(42), None],
        );
    }
}