1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::filter::SlicesIterator;
use arrow_array::*;
use arrow_data::transform::MutableArrayData;
use arrow_schema::ArrowError;

/// Zip two arrays by some boolean mask. Where the mask evaluates `true` values of `truthy`
/// are taken, where the mask evaluates `false` values of `falsy` are taken.
///
/// # Arguments
/// * `mask` - Boolean values used to determine from which array to take the values.
/// * `truthy` - Values of this array are taken if mask evaluates `true`
/// * `falsy` - Values of this array are taken if mask evaluates `false`
pub fn zip(
    mask: &BooleanArray,
    truthy: &dyn Array,
    falsy: &dyn Array,
) -> Result<ArrayRef, ArrowError> {
    if truthy.data_type() != falsy.data_type() {
        return Err(ArrowError::InvalidArgumentError(
            "arguments need to have the same data type".into(),
        ));
    }
    if truthy.len() != falsy.len() || falsy.len() != mask.len() {
        return Err(ArrowError::InvalidArgumentError(
            "all arrays should have the same length".into(),
        ));
    }
    let falsy = falsy.to_data();
    let truthy = truthy.to_data();

    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());

    // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
    // fill with falsy values

    // keep track of how much is filled
    let mut filled = 0;

    SlicesIterator::new(mask).for_each(|(start, end)| {
        // the gap needs to be filled with falsy values
        if start > filled {
            mutable.extend(1, filled, start);
        }
        // fill with truthy values
        mutable.extend(0, start, end);
        filled = end;
    });
    // the remaining part is falsy
    if filled < truthy.len() {
        mutable.extend(1, filled, truthy.len());
    }

    let data = mutable.freeze();
    Ok(make_array(data))
}

#[cfg(test)]
mod test {
    use super::*;

    /// Rows 0, 1 and 4 are selected from the truthy input (its null included), while rows 2
    /// and 3 are selected from the falsy input.
    #[test]
    fn test_zip_kernel() {
        let truthy_values = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
        let falsy_values = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        let selection = BooleanArray::from(vec![true, true, false, false, true]);

        let zipped = zip(&selection, &truthy_values, &falsy_values).unwrap();
        let zipped_ints = zipped
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        let expected = Int32Array::from(vec![Some(5), None, Some(6), Some(7), Some(1)]);
        assert_eq!(zipped_ints, &expected);
    }
}