datafusion_physical_expr/expressions/
binary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod kernels;
19
20use std::hash::Hash;
21use std::{any::Any, sync::Arc};
22
23use crate::expressions::binary::kernels::concat_elements_utf8view;
24use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
25use crate::PhysicalExpr;
26
27use arrow::array::*;
28use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
29use arrow::compute::kernels::cmp::*;
30use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar};
31use arrow::compute::kernels::concat_elements::concat_elements_utf8;
32use arrow::compute::{cast, ilike, like, nilike, nlike};
33use arrow::datatypes::*;
34use arrow::error::ArrowError;
35use datafusion_common::cast::as_boolean_array;
36use datafusion_common::{internal_err, Result, ScalarValue};
37use datafusion_expr::binary::BinaryTypeCoercer;
38use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
39use datafusion_expr::sort_properties::ExprProperties;
40use datafusion_expr::statistics::Distribution::{Bernoulli, Gaussian};
41use datafusion_expr::statistics::{
42    combine_bernoullis, combine_gaussians, create_bernoulli_from_comparison,
43    new_generic_from_binary_op, Distribution,
44};
45use datafusion_expr::{ColumnarValue, Operator};
46use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
47
48use kernels::{
49    bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
50    bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
51    bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar,
52};
53
/// Binary expression
///
/// Combines two child [`PhysicalExpr`]s with an [`Operator`].
#[derive(Debug, Clone, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand
    left: Arc<dyn PhysicalExpr>,
    /// Operator applied to the two operands
    op: Operator,
    /// Right-hand operand
    right: Arc<dyn PhysicalExpr>,
    /// Specifies whether an error is returned on overflow or not.
    /// `evaluate` consults this flag for `Plus`, `Minus` and `Multiply` to
    /// choose between the checked and the wrapping arithmetic kernels.
    fail_on_overflow: bool,
}
63
64// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
65impl PartialEq for BinaryExpr {
66    fn eq(&self, other: &Self) -> bool {
67        self.left.eq(&other.left)
68            && self.op.eq(&other.op)
69            && self.right.eq(&other.right)
70            && self.fail_on_overflow.eq(&other.fail_on_overflow)
71    }
72}
73impl Hash for BinaryExpr {
74    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
75        self.left.hash(state);
76        self.op.hash(state);
77        self.right.hash(state);
78        self.fail_on_overflow.hash(state);
79    }
80}
81
82impl BinaryExpr {
83    /// Create new binary expression
84    pub fn new(
85        left: Arc<dyn PhysicalExpr>,
86        op: Operator,
87        right: Arc<dyn PhysicalExpr>,
88    ) -> Self {
89        Self {
90            left,
91            op,
92            right,
93            fail_on_overflow: false,
94        }
95    }
96
97    /// Create new binary expression with explicit fail_on_overflow value
98    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
99        Self {
100            left: self.left,
101            op: self.op,
102            right: self.right,
103            fail_on_overflow,
104        }
105    }
106
107    /// Get the left side of the binary expression
108    pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
109        &self.left
110    }
111
112    /// Get the right side of the binary expression
113    pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
114        &self.right
115    }
116
117    /// Get the operator for this binary expression
118    pub fn op(&self) -> &Operator {
119        &self.op
120    }
121}
122
123impl std::fmt::Display for BinaryExpr {
124    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
125        // Put parentheses around child binary expressions so that we can see the difference
126        // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
127        // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
128        // equivalent and the parentheses are not necessary.
129
130        fn write_child(
131            f: &mut std::fmt::Formatter,
132            expr: &dyn PhysicalExpr,
133            precedence: u8,
134        ) -> std::fmt::Result {
135            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
136                let p = child.op.precedence();
137                if p == 0 || p < precedence {
138                    write!(f, "({child})")?;
139                } else {
140                    write!(f, "{child}")?;
141                }
142            } else {
143                write!(f, "{expr}")?;
144            }
145
146            Ok(())
147        }
148
149        let precedence = self.op.precedence();
150        write_child(f, self.left.as_ref(), precedence)?;
151        write!(f, " {} ", self.op)?;
152        write_child(f, self.right.as_ref(), precedence)
153    }
154}
155
156/// Invoke a boolean kernel on a pair of arrays
157#[inline]
158fn boolean_op(
159    left: &dyn Array,
160    right: &dyn Array,
161    op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>,
162) -> Result<Arc<(dyn Array + 'static)>, ArrowError> {
163    let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array");
164    let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array");
165    op(ll, rr).map(|t| Arc::new(t) as _)
166}
167
/// Invoke a flag-taking string kernel (`$OP`) on a pair of string arrays,
/// selecting the concrete array type from the data type of `$LEFT`.
///
/// * `$NOT`  - when true, the boolean result is negated
/// * `$FLAG` - when true, the case-insensitive flag `"i"` is passed to `$OP`
///
/// `Utf8View` inputs are downcast to `StringViewArray`. Downcasting them to
/// `StringArray` (as `Utf8` inputs are) would fail, because the two types
/// are distinct array implementations.
macro_rules! binary_string_array_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
        match $LEFT.data_type() {
            DataType::Utf8 => {
                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
            },
            DataType::Utf8View => {
                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG)
            },
            DataType::LargeUtf8 => {
                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
            },
            other => internal_err!(
                "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array",
                other, stringify!($OP)
            ),
        }
    }};
}
184
/// Invoke a compute kernel on a pair of string arrays with flags
///
/// Both `$LEFT` and `$RIGHT` are downcast to `$ARRAYTYPE` (panicking if the
/// downcast fails). When `$FLAG` is true, a flag array holding `"i"` for
/// every row of the left array is passed to `$OP`. When `$NOT` is true, the
/// boolean result is negated. Kernel errors, including one from the
/// negation, are propagated with `?` instead of panicking.
macro_rules! compute_utf8_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op failed to downcast array");

        // The kernel takes one flag per row, so repeat "i" for each row.
        let flag = if $FLAG {
            Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
        } else {
            None
        };
        let mut array = $OP(ll, rr, flag.as_ref())?;
        if $NOT {
            array = not(&array)?;
        }
        Ok(Arc::new(array))
    }};
}
209
/// Invoke a flag-taking string kernel (`$OP`) on a string array and a scalar,
/// selecting the concrete array type from the data type of `$LEFT`.
///
/// Evaluates to `Some(Result<Arc<dyn Array>>)`.
///
/// `Utf8View` inputs (and `Utf8View` dictionary values) are downcast to
/// `StringViewArray`. Downcasting them to `StringArray` would fail, because
/// the two types are distinct array implementations.
macro_rules! binary_string_array_flag_op_scalar {
    ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
        // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value,
        // the query can be optimized in such a way that operands will be dicts, so we need to support it here
        let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Utf8 => {
                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
            },
            DataType::Utf8View => {
                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG)
            },
            DataType::LargeUtf8 => {
                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
            },
            DataType::Dictionary(_, _) => {
                let values = $LEFT.as_any_dictionary().values();

                // Evaluate the kernel once over the dictionary values, then
                // expand the boolean result through the keys.
                match values.data_type() {
                    DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG),
                    DataType::Utf8View => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringViewArray, $NOT, $FLAG),
                    DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG),
                    other => internal_err!(
                        "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array",
                        other, stringify!($OP)
                    ),
                }.map(
                    // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before
                    |evaluated_values| downcast_dictionary_array! {
                        $LEFT => {
                            let unpacked_dict = evaluated_values.take_iter($LEFT.keys().iter().map(|opt| opt.map(|v| v as _))).collect::<BooleanArray>();
                            Arc::new(unpacked_dict) as _
                        },
                        _ => unreachable!(),
                    }
                )
            },
            other => internal_err!(
                "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array",
                other, stringify!($OP)
            ),
        };
        Some(result)
    }};
}
250
/// Invoke a compute kernel on a data array and a scalar value with flag
///
/// `$LEFT` is downcast to `$ARRAYTYPE` (panicking if the downcast fails) and
/// `$RIGHT` must yield a non-null string from `try_as_str()`; otherwise the
/// enclosing function returns an internal error. The kernel invoked is
/// `<$OP>_scalar`. When `$FLAG` is true it receives the flag `"i"`, and when
/// `$NOT` is true the boolean result is negated. Kernel errors, including
/// one from the negation, are propagated with `?` instead of panicking.
macro_rules! compute_utf8_flag_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op_scalar failed to downcast array");

        let string_value = match $RIGHT.try_as_str() {
            Some(Some(string_value)) => string_value,
            // null literal or non string
            _ => return internal_err!(
                        "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
                        $RIGHT, stringify!($OP)
                    )
        };

        let flag = $FLAG.then_some("i");
        let mut array =
            paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?;
        if $NOT {
            array = not(&array)?;
        }

        Ok(Arc::new(array))
    }};
}
278
279impl PhysicalExpr for BinaryExpr {
    /// Return a reference to Any that can be used for downcasting
    /// (for example to recover the concrete `BinaryExpr`, as the `Display`
    /// implementation does for child expressions).
    fn as_any(&self) -> &dyn Any {
        self
    }
284
285    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
286        BinaryTypeCoercer::new(
287            &self.left.data_type(input_schema)?,
288            &self.op,
289            &self.right.data_type(input_schema)?,
290        )
291        .get_result_type()
292    }
293
294    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
295        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
296    }
297
298    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
299        use arrow::compute::kernels::numeric::*;
300
301        let lhs = self.left.evaluate(batch)?;
302        let rhs = self.right.evaluate(batch)?;
303        let left_data_type = lhs.data_type();
304        let right_data_type = rhs.data_type();
305
306        let schema = batch.schema();
307        let input_schema = schema.as_ref();
308
309        if left_data_type.is_nested() {
310            if right_data_type != left_data_type {
311                return internal_err!("type mismatch");
312            }
313            return apply_cmp_for_nested(self.op, &lhs, &rhs);
314        }
315
316        match self.op {
317            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
318            Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
319            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
320            Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
321            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
322            Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
323            Operator::Divide => return apply(&lhs, &rhs, div),
324            Operator::Modulo => return apply(&lhs, &rhs, rem),
325            Operator::Eq => return apply_cmp(&lhs, &rhs, eq),
326            Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
327            Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
328            Operator::Gt => return apply_cmp(&lhs, &rhs, gt),
329            Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq),
330            Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq),
331            Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct),
332            Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct),
333            Operator::LikeMatch => return apply_cmp(&lhs, &rhs, like),
334            Operator::ILikeMatch => return apply_cmp(&lhs, &rhs, ilike),
335            Operator::NotLikeMatch => return apply_cmp(&lhs, &rhs, nlike),
336            Operator::NotILikeMatch => return apply_cmp(&lhs, &rhs, nilike),
337            _ => {}
338        }
339
340        let result_type = self.data_type(input_schema)?;
341
342        // Attempt to use special kernels if one input is scalar and the other is an array
343        let scalar_result = match (&lhs, &rhs) {
344            (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
345                // if left is array and right is literal(not NULL) - use scalar operations
346                if scalar.is_null() {
347                    None
348                } else {
349                    self.evaluate_array_scalar(array, scalar.clone())?.map(|r| {
350                        r.and_then(|a| to_result_type_array(&self.op, a, &result_type))
351                    })
352                }
353            }
354            (_, _) => None, // default to array implementation
355        };
356
357        if let Some(result) = scalar_result {
358            return result.map(ColumnarValue::Array);
359        }
360
361        // if both arrays or both literals - extract arrays and continue execution
362        let (left, right) = (
363            lhs.into_array(batch.num_rows())?,
364            rhs.into_array(batch.num_rows())?,
365        );
366        self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
367            .map(ColumnarValue::Array)
368    }
369
370    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
371        vec![&self.left, &self.right]
372    }
373
374    fn with_new_children(
375        self: Arc<Self>,
376        children: Vec<Arc<dyn PhysicalExpr>>,
377    ) -> Result<Arc<dyn PhysicalExpr>> {
378        Ok(Arc::new(
379            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
380                .with_fail_on_overflow(self.fail_on_overflow),
381        ))
382    }
383
384    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
385        // Get children intervals:
386        let left_interval = children[0];
387        let right_interval = children[1];
388        // Calculate current node's interval:
389        apply_operator(&self.op, left_interval, right_interval)
390    }
391
392    fn propagate_constraints(
393        &self,
394        interval: &Interval,
395        children: &[&Interval],
396    ) -> Result<Option<Vec<Interval>>> {
397        // Get children intervals.
398        let left_interval = children[0];
399        let right_interval = children[1];
400
401        if self.op.eq(&Operator::And) {
402            if interval.eq(&Interval::CERTAINLY_TRUE) {
403                // A certainly true logical conjunction can only derive from possibly
404                // true operands. Otherwise, we prove infeasibility.
405                Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE)
406                    && !right_interval.eq(&Interval::CERTAINLY_FALSE))
407                .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE]))
408            } else if interval.eq(&Interval::CERTAINLY_FALSE) {
409                // If the logical conjunction is certainly false, one of the
410                // operands must be false. However, it's not always possible to
411                // determine which operand is false, leading to different scenarios.
412
413                // If one operand is certainly true and the other one is uncertain,
414                // then the latter must be certainly false.
415                if left_interval.eq(&Interval::CERTAINLY_TRUE)
416                    && right_interval.eq(&Interval::UNCERTAIN)
417                {
418                    Ok(Some(vec![
419                        Interval::CERTAINLY_TRUE,
420                        Interval::CERTAINLY_FALSE,
421                    ]))
422                } else if right_interval.eq(&Interval::CERTAINLY_TRUE)
423                    && left_interval.eq(&Interval::UNCERTAIN)
424                {
425                    Ok(Some(vec![
426                        Interval::CERTAINLY_FALSE,
427                        Interval::CERTAINLY_TRUE,
428                    ]))
429                }
430                // If both children are uncertain, or if one is certainly false,
431                // we cannot conclusively refine their intervals. In this case,
432                // propagation does not result in any interval changes.
433                else {
434                    Ok(Some(vec![]))
435                }
436            } else {
437                // An uncertain logical conjunction result can not shrink the
438                // end-points of its children.
439                Ok(Some(vec![]))
440            }
441        } else if self.op.eq(&Operator::Or) {
442            if interval.eq(&Interval::CERTAINLY_FALSE) {
443                // A certainly false logical conjunction can only derive from certainly
444                // false operands. Otherwise, we prove infeasibility.
445                Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE)
446                    && !right_interval.eq(&Interval::CERTAINLY_TRUE))
447                .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE]))
448            } else if interval.eq(&Interval::CERTAINLY_TRUE) {
449                // If the logical disjunction is certainly true, one of the
450                // operands must be true. However, it's not always possible to
451                // determine which operand is true, leading to different scenarios.
452
453                // If one operand is certainly false and the other one is uncertain,
454                // then the latter must be certainly true.
455                if left_interval.eq(&Interval::CERTAINLY_FALSE)
456                    && right_interval.eq(&Interval::UNCERTAIN)
457                {
458                    Ok(Some(vec![
459                        Interval::CERTAINLY_FALSE,
460                        Interval::CERTAINLY_TRUE,
461                    ]))
462                } else if right_interval.eq(&Interval::CERTAINLY_FALSE)
463                    && left_interval.eq(&Interval::UNCERTAIN)
464                {
465                    Ok(Some(vec![
466                        Interval::CERTAINLY_TRUE,
467                        Interval::CERTAINLY_FALSE,
468                    ]))
469                }
470                // If both children are uncertain, or if one is certainly true,
471                // we cannot conclusively refine their intervals. In this case,
472                // propagation does not result in any interval changes.
473                else {
474                    Ok(Some(vec![]))
475                }
476            } else {
477                // An uncertain logical disjunction result can not shrink the
478                // end-points of its children.
479                Ok(Some(vec![]))
480            }
481        } else if self.op.supports_propagation() {
482            Ok(
483                propagate_comparison(&self.op, interval, left_interval, right_interval)?
484                    .map(|(left, right)| vec![left, right]),
485            )
486        } else {
487            Ok(
488                propagate_arithmetic(&self.op, interval, left_interval, right_interval)?
489                    .map(|(left, right)| vec![left, right]),
490            )
491        }
492    }
493
494    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
495        let (left, right) = (children[0], children[1]);
496
497        if self.op.is_numerical_operators() {
498            // We might be able to construct the output statistics more accurately,
499            // without falling back to an unknown distribution, if we are dealing
500            // with Gaussian distributions and numerical operations.
501            if let (Gaussian(left), Gaussian(right)) = (left, right) {
502                if let Some(result) = combine_gaussians(&self.op, left, right)? {
503                    return Ok(Gaussian(result));
504                }
505            }
506        } else if self.op.is_logic_operator() {
507            // If we are dealing with logical operators, we expect (and can only
508            // operate on) Bernoulli distributions.
509            return if let (Bernoulli(left), Bernoulli(right)) = (left, right) {
510                combine_bernoullis(&self.op, left, right).map(Bernoulli)
511            } else {
512                internal_err!(
513                    "Logical operators are only compatible with `Bernoulli` distributions"
514                )
515            };
516        } else if self.op.supports_propagation() {
517            // If we are handling comparison operators, we expect (and can only
518            // operate on) numeric distributions.
519            return create_bernoulli_from_comparison(&self.op, left, right);
520        }
521        // Fall back to an unknown distribution with only summary statistics:
522        new_generic_from_binary_op(&self.op, left, right)
523    }
524
525    /// For each operator, [`BinaryExpr`] has distinct rules.
526    /// TODO: There may be rules specific to some data types and expression ranges.
527    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
528        let (l_order, l_range) = (children[0].sort_properties, &children[0].range);
529        let (r_order, r_range) = (children[1].sort_properties, &children[1].range);
530        match self.op() {
531            Operator::Plus => Ok(ExprProperties {
532                sort_properties: l_order.add(&r_order),
533                range: l_range.add(r_range)?,
534                preserves_lex_ordering: false,
535            }),
536            Operator::Minus => Ok(ExprProperties {
537                sort_properties: l_order.sub(&r_order),
538                range: l_range.sub(r_range)?,
539                preserves_lex_ordering: false,
540            }),
541            Operator::Gt => Ok(ExprProperties {
542                sort_properties: l_order.gt_or_gteq(&r_order),
543                range: l_range.gt(r_range)?,
544                preserves_lex_ordering: false,
545            }),
546            Operator::GtEq => Ok(ExprProperties {
547                sort_properties: l_order.gt_or_gteq(&r_order),
548                range: l_range.gt_eq(r_range)?,
549                preserves_lex_ordering: false,
550            }),
551            Operator::Lt => Ok(ExprProperties {
552                sort_properties: r_order.gt_or_gteq(&l_order),
553                range: l_range.lt(r_range)?,
554                preserves_lex_ordering: false,
555            }),
556            Operator::LtEq => Ok(ExprProperties {
557                sort_properties: r_order.gt_or_gteq(&l_order),
558                range: l_range.lt_eq(r_range)?,
559                preserves_lex_ordering: false,
560            }),
561            Operator::And => Ok(ExprProperties {
562                sort_properties: r_order.and_or(&l_order),
563                range: l_range.and(r_range)?,
564                preserves_lex_ordering: false,
565            }),
566            Operator::Or => Ok(ExprProperties {
567                sort_properties: r_order.and_or(&l_order),
568                range: l_range.or(r_range)?,
569                preserves_lex_ordering: false,
570            }),
571            _ => Ok(ExprProperties::new_unknown()),
572        }
573    }
574}
575
576/// Casts dictionary array to result type for binary numerical operators. Such operators
577/// between array and scalar produce a dictionary array other than primitive array of the
578/// same operators between array and array. This leads to inconsistent result types causing
579/// errors in the following query execution. For such operators between array and scalar,
580/// we cast the dictionary array to primitive array.
581fn to_result_type_array(
582    op: &Operator,
583    array: ArrayRef,
584    result_type: &DataType,
585) -> Result<ArrayRef> {
586    if array.data_type() == result_type {
587        Ok(array)
588    } else if op.is_numerical_operators() {
589        match array.data_type() {
590            DataType::Dictionary(_, value_type) => {
591                if value_type.as_ref() == result_type {
592                    Ok(cast(&array, result_type)?)
593                } else {
594                    internal_err!(
595                            "Incompatible Dictionary value type {value_type:?} with result type {result_type:?} of Binary operator {op:?}"
596                        )
597                }
598            }
599            _ => Ok(array),
600        }
601    } else {
602        Ok(array)
603    }
604}
605
606impl BinaryExpr {
607    /// Evaluate the expression of the left input is an array and
608    /// right is literal - use scalar operations
609    fn evaluate_array_scalar(
610        &self,
611        array: &dyn Array,
612        scalar: ScalarValue,
613    ) -> Result<Option<Result<ArrayRef>>> {
614        use Operator::*;
615        let scalar_result = match &self.op {
616            RegexMatch => binary_string_array_flag_op_scalar!(
617                array,
618                scalar,
619                regexp_is_match,
620                false,
621                false
622            ),
623            RegexIMatch => binary_string_array_flag_op_scalar!(
624                array,
625                scalar,
626                regexp_is_match,
627                false,
628                true
629            ),
630            RegexNotMatch => binary_string_array_flag_op_scalar!(
631                array,
632                scalar,
633                regexp_is_match,
634                true,
635                false
636            ),
637            RegexNotIMatch => binary_string_array_flag_op_scalar!(
638                array,
639                scalar,
640                regexp_is_match,
641                true,
642                true
643            ),
644            BitwiseAnd => bitwise_and_dyn_scalar(array, scalar),
645            BitwiseOr => bitwise_or_dyn_scalar(array, scalar),
646            BitwiseXor => bitwise_xor_dyn_scalar(array, scalar),
647            BitwiseShiftRight => bitwise_shift_right_dyn_scalar(array, scalar),
648            BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(array, scalar),
649            // if scalar operation is not supported - fallback to array implementation
650            _ => None,
651        };
652
653        Ok(scalar_result)
654    }
655
656    fn evaluate_with_resolved_args(
657        &self,
658        left: Arc<dyn Array>,
659        left_data_type: &DataType,
660        right: Arc<dyn Array>,
661        right_data_type: &DataType,
662    ) -> Result<ArrayRef> {
663        use Operator::*;
664        match &self.op {
665            IsDistinctFrom | IsNotDistinctFrom | Lt | LtEq | Gt | GtEq | Eq | NotEq
666            | Plus | Minus | Multiply | Divide | Modulo | LikeMatch | ILikeMatch
667            | NotLikeMatch | NotILikeMatch => unreachable!(),
668            And => {
669                if left_data_type == &DataType::Boolean {
670                    Ok(boolean_op(&left, &right, and_kleene)?)
671                } else {
672                    internal_err!(
673                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
674                        self.op,
675                        left.data_type(),
676                        right.data_type()
677                    )
678                }
679            }
680            Or => {
681                if left_data_type == &DataType::Boolean {
682                    Ok(boolean_op(&left, &right, or_kleene)?)
683                } else {
684                    internal_err!(
685                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
686                        self.op,
687                        left_data_type,
688                        right_data_type
689                    )
690                }
691            }
692            RegexMatch => {
693                binary_string_array_flag_op!(left, right, regexp_is_match, false, false)
694            }
695            RegexIMatch => {
696                binary_string_array_flag_op!(left, right, regexp_is_match, false, true)
697            }
698            RegexNotMatch => {
699                binary_string_array_flag_op!(left, right, regexp_is_match, true, false)
700            }
701            RegexNotIMatch => {
702                binary_string_array_flag_op!(left, right, regexp_is_match, true, true)
703            }
704            BitwiseAnd => bitwise_and_dyn(left, right),
705            BitwiseOr => bitwise_or_dyn(left, right),
706            BitwiseXor => bitwise_xor_dyn(left, right),
707            BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
708            BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
709            StringConcat => concat_elements(left, right),
710            AtArrow | ArrowAt => {
711                unreachable!("ArrowAt and AtArrow should be rewritten to function")
712            }
713        }
714    }
715}
716
717fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
718    Ok(match left.data_type() {
719        DataType::Utf8 => Arc::new(concat_elements_utf8(
720            left.as_string::<i32>(),
721            right.as_string::<i32>(),
722        )?),
723        DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
724            left.as_string::<i64>(),
725            right.as_string::<i64>(),
726        )?),
727        DataType::Utf8View => Arc::new(concat_elements_utf8view(
728            left.as_string_view(),
729            right.as_string_view(),
730        )?),
731        other => {
732            return internal_err!(
733                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
734            );
735        }
736    })
737}
738
739/// Create a binary expression whose arguments are correctly coerced.
740/// This function errors if it is not possible to coerce the arguments
741/// to computational types supported by the operator.
742pub fn binary(
743    lhs: Arc<dyn PhysicalExpr>,
744    op: Operator,
745    rhs: Arc<dyn PhysicalExpr>,
746    _input_schema: &Schema,
747) -> Result<Arc<dyn PhysicalExpr>> {
748    Ok(Arc::new(BinaryExpr::new(lhs, op, rhs)))
749}
750
751/// Create a similar to expression
752pub fn similar_to(
753    negated: bool,
754    case_insensitive: bool,
755    expr: Arc<dyn PhysicalExpr>,
756    pattern: Arc<dyn PhysicalExpr>,
757) -> Result<Arc<dyn PhysicalExpr>> {
758    let binary_op = match (negated, case_insensitive) {
759        (false, false) => Operator::RegexMatch,
760        (false, true) => Operator::RegexIMatch,
761        (true, false) => Operator::RegexNotMatch,
762        (true, true) => Operator::RegexNotIMatch,
763    };
764    Ok(Arc::new(BinaryExpr::new(expr, binary_op, pattern)))
765}
766
767#[cfg(test)]
768mod tests {
769    use super::*;
770    use crate::expressions::{col, lit, try_cast, Column, Literal};
771
772    use datafusion_common::plan_datafusion_err;
773
774    /// Performs a binary operation, applying any type coercion necessary
775    fn binary_op(
776        left: Arc<dyn PhysicalExpr>,
777        op: Operator,
778        right: Arc<dyn PhysicalExpr>,
779        schema: &Schema,
780    ) -> Result<Arc<dyn PhysicalExpr>> {
781        let left_type = left.data_type(schema)?;
782        let right_type = right.data_type(schema)?;
783        let (lhs, rhs) =
784            BinaryTypeCoercer::new(&left_type, &op, &right_type).get_input_types()?;
785
786        let left_expr = try_cast(left, schema, lhs)?;
787        let right_expr = try_cast(right, schema, rhs)?;
788        binary(left_expr, op, right_expr, schema)
789    }
790
791    #[test]
792    fn binary_comparison() -> Result<()> {
793        let schema = Schema::new(vec![
794            Field::new("a", DataType::Int32, false),
795            Field::new("b", DataType::Int32, false),
796        ]);
797        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
798        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
799
800        // expression: "a < b"
801        let lt = binary(
802            col("a", &schema)?,
803            Operator::Lt,
804            col("b", &schema)?,
805            &schema,
806        )?;
807        let batch =
808            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
809
810        let result = lt
811            .evaluate(&batch)?
812            .into_array(batch.num_rows())
813            .expect("Failed to convert to array");
814        assert_eq!(result.len(), 5);
815
816        let expected = [false, false, true, true, true];
817        let result =
818            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
819        for (i, &expected_item) in expected.iter().enumerate().take(5) {
820            assert_eq!(result.value(i), expected_item);
821        }
822
823        Ok(())
824    }
825
826    #[test]
827    fn binary_nested() -> Result<()> {
828        let schema = Schema::new(vec![
829            Field::new("a", DataType::Int32, false),
830            Field::new("b", DataType::Int32, false),
831        ]);
832        let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
833        let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
834
835        // expression: "a < b OR a == b"
836        let expr = binary(
837            binary(
838                col("a", &schema)?,
839                Operator::Lt,
840                col("b", &schema)?,
841                &schema,
842            )?,
843            Operator::Or,
844            binary(
845                col("a", &schema)?,
846                Operator::Eq,
847                col("b", &schema)?,
848                &schema,
849            )?,
850            &schema,
851        )?;
852        let batch =
853            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
854
855        assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}"));
856
857        let result = expr
858            .evaluate(&batch)?
859            .into_array(batch.num_rows())
860            .expect("Failed to convert to array");
861        assert_eq!(result.len(), 5);
862
863        let expected = [true, true, false, true, false];
864        let result =
865            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
866        for (i, &expected_item) in expected.iter().enumerate().take(5) {
867            assert_eq!(result.value(i), expected_item);
868        }
869
870        Ok(())
871    }
872
873    // runs an end-to-end test of physical type coercion:
874    // 1. construct a record batch with two columns of type A and B
875    //  (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
876    // 2. construct a physical expression of A OP B
877    // 3. evaluate the expression
878    // 4. verify that the resulting expression is of type C
879    // 5. verify that the results of evaluation are $VEC
880    macro_rules! test_coercion {
881        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr,) => {{
882            let schema = Schema::new(vec![
883                Field::new("a", $A_TYPE, false),
884                Field::new("b", $B_TYPE, false),
885            ]);
886            let a = $A_ARRAY::from($A_VEC);
887            let b = $B_ARRAY::from($B_VEC);
888            let (lhs, rhs) = BinaryTypeCoercer::new(&$A_TYPE, &$OP, &$B_TYPE).get_input_types()?;
889
890            let left = try_cast(col("a", &schema)?, &schema, lhs)?;
891            let right = try_cast(col("b", &schema)?, &schema, rhs)?;
892
893            // verify that we can construct the expression
894            let expression = binary(left, $OP, right, &schema)?;
895            let batch = RecordBatch::try_new(
896                Arc::new(schema.clone()),
897                vec![Arc::new(a), Arc::new(b)],
898            )?;
899
900            // verify that the expression's type is correct
901            assert_eq!(expression.data_type(&schema)?, $C_TYPE);
902
903            // compute
904            let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array");
905
906            // verify that the array's data_type is correct
907            assert_eq!(*result.data_type(), $C_TYPE);
908
909            // verify that the data itself is downcastable
910            let result = result
911                .as_any()
912                .downcast_ref::<$C_ARRAY>()
913                .expect("failed to downcast");
914            // verify that the result itself is correct
915            for (i, x) in $VEC.iter().enumerate() {
916                let v = result.value(i);
917                assert_eq!(
918                    v,
919                    *x,
920                    "Unexpected output at position {i}:\n\nActual:\n{v}\n\nExpected:\n{x}"
921                );
922            }
923        }};
924    }
925
926    #[test]
927    fn test_type_coercion() -> Result<()> {
928        test_coercion!(
929            Int32Array,
930            DataType::Int32,
931            vec![1i32, 2i32],
932            UInt32Array,
933            DataType::UInt32,
934            vec![1u32, 2u32],
935            Operator::Plus,
936            Int32Array,
937            DataType::Int32,
938            [2i32, 4i32],
939        );
940        test_coercion!(
941            Int32Array,
942            DataType::Int32,
943            vec![1i32],
944            UInt16Array,
945            DataType::UInt16,
946            vec![1u16],
947            Operator::Plus,
948            Int32Array,
949            DataType::Int32,
950            [2i32],
951        );
952        test_coercion!(
953            Float32Array,
954            DataType::Float32,
955            vec![1f32],
956            UInt16Array,
957            DataType::UInt16,
958            vec![1u16],
959            Operator::Plus,
960            Float32Array,
961            DataType::Float32,
962            [2f32],
963        );
964        test_coercion!(
965            Float32Array,
966            DataType::Float32,
967            vec![2f32],
968            UInt16Array,
969            DataType::UInt16,
970            vec![1u16],
971            Operator::Multiply,
972            Float32Array,
973            DataType::Float32,
974            [2f32],
975        );
976        test_coercion!(
977            StringArray,
978            DataType::Utf8,
979            vec!["1994-12-13", "1995-01-26"],
980            Date32Array,
981            DataType::Date32,
982            vec![9112, 9156],
983            Operator::Eq,
984            BooleanArray,
985            DataType::Boolean,
986            [true, true],
987        );
988        test_coercion!(
989            StringArray,
990            DataType::Utf8,
991            vec!["1994-12-13", "1995-01-26"],
992            Date32Array,
993            DataType::Date32,
994            vec![9113, 9154],
995            Operator::Lt,
996            BooleanArray,
997            DataType::Boolean,
998            [true, false],
999        );
1000        test_coercion!(
1001            StringArray,
1002            DataType::Utf8,
1003            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
1004            Date64Array,
1005            DataType::Date64,
1006            vec![787322096000, 791083425000],
1007            Operator::Eq,
1008            BooleanArray,
1009            DataType::Boolean,
1010            [true, true],
1011        );
1012        test_coercion!(
1013            StringArray,
1014            DataType::Utf8,
1015            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
1016            Date64Array,
1017            DataType::Date64,
1018            vec![787322096001, 791083424999],
1019            Operator::Lt,
1020            BooleanArray,
1021            DataType::Boolean,
1022            [true, false],
1023        );
1024        test_coercion!(
1025            StringViewArray,
1026            DataType::Utf8View,
1027            vec!["abc"; 5],
1028            StringArray,
1029            DataType::Utf8,
1030            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1031            Operator::RegexMatch,
1032            BooleanArray,
1033            DataType::Boolean,
1034            [true, false, true, false, false],
1035        );
1036        test_coercion!(
1037            StringViewArray,
1038            DataType::Utf8View,
1039            vec!["abc"; 5],
1040            StringArray,
1041            DataType::Utf8,
1042            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1043            Operator::RegexIMatch,
1044            BooleanArray,
1045            DataType::Boolean,
1046            [true, true, true, true, false],
1047        );
1048        test_coercion!(
1049            StringArray,
1050            DataType::Utf8,
1051            vec!["abc"; 5],
1052            StringViewArray,
1053            DataType::Utf8View,
1054            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1055            Operator::RegexNotMatch,
1056            BooleanArray,
1057            DataType::Boolean,
1058            [false, true, false, true, true],
1059        );
1060        test_coercion!(
1061            StringArray,
1062            DataType::Utf8,
1063            vec!["abc"; 5],
1064            StringViewArray,
1065            DataType::Utf8View,
1066            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1067            Operator::RegexNotIMatch,
1068            BooleanArray,
1069            DataType::Boolean,
1070            [false, false, false, false, true],
1071        );
1072        test_coercion!(
1073            StringArray,
1074            DataType::Utf8,
1075            vec!["abc"; 5],
1076            StringArray,
1077            DataType::Utf8,
1078            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1079            Operator::RegexMatch,
1080            BooleanArray,
1081            DataType::Boolean,
1082            [true, false, true, false, false],
1083        );
1084        test_coercion!(
1085            StringArray,
1086            DataType::Utf8,
1087            vec!["abc"; 5],
1088            StringArray,
1089            DataType::Utf8,
1090            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1091            Operator::RegexIMatch,
1092            BooleanArray,
1093            DataType::Boolean,
1094            [true, true, true, true, false],
1095        );
1096        test_coercion!(
1097            StringArray,
1098            DataType::Utf8,
1099            vec!["abc"; 5],
1100            StringArray,
1101            DataType::Utf8,
1102            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1103            Operator::RegexNotMatch,
1104            BooleanArray,
1105            DataType::Boolean,
1106            [false, true, false, true, true],
1107        );
1108        test_coercion!(
1109            StringArray,
1110            DataType::Utf8,
1111            vec!["abc"; 5],
1112            StringArray,
1113            DataType::Utf8,
1114            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1115            Operator::RegexNotIMatch,
1116            BooleanArray,
1117            DataType::Boolean,
1118            [false, false, false, false, true],
1119        );
1120        test_coercion!(
1121            LargeStringArray,
1122            DataType::LargeUtf8,
1123            vec!["abc"; 5],
1124            LargeStringArray,
1125            DataType::LargeUtf8,
1126            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1127            Operator::RegexMatch,
1128            BooleanArray,
1129            DataType::Boolean,
1130            [true, false, true, false, false],
1131        );
1132        test_coercion!(
1133            LargeStringArray,
1134            DataType::LargeUtf8,
1135            vec!["abc"; 5],
1136            LargeStringArray,
1137            DataType::LargeUtf8,
1138            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1139            Operator::RegexIMatch,
1140            BooleanArray,
1141            DataType::Boolean,
1142            [true, true, true, true, false],
1143        );
1144        test_coercion!(
1145            LargeStringArray,
1146            DataType::LargeUtf8,
1147            vec!["abc"; 5],
1148            LargeStringArray,
1149            DataType::LargeUtf8,
1150            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1151            Operator::RegexNotMatch,
1152            BooleanArray,
1153            DataType::Boolean,
1154            [false, true, false, true, true],
1155        );
1156        test_coercion!(
1157            LargeStringArray,
1158            DataType::LargeUtf8,
1159            vec!["abc"; 5],
1160            LargeStringArray,
1161            DataType::LargeUtf8,
1162            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1163            Operator::RegexNotIMatch,
1164            BooleanArray,
1165            DataType::Boolean,
1166            [false, false, false, false, true],
1167        );
1168        test_coercion!(
1169            StringArray,
1170            DataType::Utf8,
1171            vec!["abc"; 5],
1172            StringArray,
1173            DataType::Utf8,
1174            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1175            Operator::LikeMatch,
1176            BooleanArray,
1177            DataType::Boolean,
1178            [true, false, false, true, false],
1179        );
1180        test_coercion!(
1181            StringArray,
1182            DataType::Utf8,
1183            vec!["abc"; 5],
1184            StringArray,
1185            DataType::Utf8,
1186            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1187            Operator::ILikeMatch,
1188            BooleanArray,
1189            DataType::Boolean,
1190            [true, true, false, true, true],
1191        );
1192        test_coercion!(
1193            StringArray,
1194            DataType::Utf8,
1195            vec!["abc"; 5],
1196            StringArray,
1197            DataType::Utf8,
1198            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1199            Operator::NotLikeMatch,
1200            BooleanArray,
1201            DataType::Boolean,
1202            [false, true, true, false, true],
1203        );
1204        test_coercion!(
1205            StringArray,
1206            DataType::Utf8,
1207            vec!["abc"; 5],
1208            StringArray,
1209            DataType::Utf8,
1210            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1211            Operator::NotILikeMatch,
1212            BooleanArray,
1213            DataType::Boolean,
1214            [false, false, true, false, false],
1215        );
1216        test_coercion!(
1217            LargeStringArray,
1218            DataType::LargeUtf8,
1219            vec!["abc"; 5],
1220            LargeStringArray,
1221            DataType::LargeUtf8,
1222            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1223            Operator::LikeMatch,
1224            BooleanArray,
1225            DataType::Boolean,
1226            [true, false, false, true, false],
1227        );
1228        test_coercion!(
1229            LargeStringArray,
1230            DataType::LargeUtf8,
1231            vec!["abc"; 5],
1232            LargeStringArray,
1233            DataType::LargeUtf8,
1234            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1235            Operator::ILikeMatch,
1236            BooleanArray,
1237            DataType::Boolean,
1238            [true, true, false, true, true],
1239        );
1240        test_coercion!(
1241            LargeStringArray,
1242            DataType::LargeUtf8,
1243            vec!["abc"; 5],
1244            LargeStringArray,
1245            DataType::LargeUtf8,
1246            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1247            Operator::NotLikeMatch,
1248            BooleanArray,
1249            DataType::Boolean,
1250            [false, true, true, false, true],
1251        );
1252        test_coercion!(
1253            LargeStringArray,
1254            DataType::LargeUtf8,
1255            vec!["abc"; 5],
1256            LargeStringArray,
1257            DataType::LargeUtf8,
1258            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1259            Operator::NotILikeMatch,
1260            BooleanArray,
1261            DataType::Boolean,
1262            [false, false, true, false, false],
1263        );
1264        test_coercion!(
1265            Int16Array,
1266            DataType::Int16,
1267            vec![1i16, 2i16, 3i16],
1268            Int64Array,
1269            DataType::Int64,
1270            vec![10i64, 4i64, 5i64],
1271            Operator::BitwiseAnd,
1272            Int64Array,
1273            DataType::Int64,
1274            [0i64, 0i64, 1i64],
1275        );
1276        test_coercion!(
1277            UInt16Array,
1278            DataType::UInt16,
1279            vec![1u16, 2u16, 3u16],
1280            UInt64Array,
1281            DataType::UInt64,
1282            vec![10u64, 4u64, 5u64],
1283            Operator::BitwiseAnd,
1284            UInt64Array,
1285            DataType::UInt64,
1286            [0u64, 0u64, 1u64],
1287        );
1288        test_coercion!(
1289            Int16Array,
1290            DataType::Int16,
1291            vec![3i16, 2i16, 3i16],
1292            Int64Array,
1293            DataType::Int64,
1294            vec![10i64, 6i64, 5i64],
1295            Operator::BitwiseOr,
1296            Int64Array,
1297            DataType::Int64,
1298            [11i64, 6i64, 7i64],
1299        );
1300        test_coercion!(
1301            UInt16Array,
1302            DataType::UInt16,
1303            vec![1u16, 2u16, 3u16],
1304            UInt64Array,
1305            DataType::UInt64,
1306            vec![10u64, 4u64, 5u64],
1307            Operator::BitwiseOr,
1308            UInt64Array,
1309            DataType::UInt64,
1310            [11u64, 6u64, 7u64],
1311        );
1312        test_coercion!(
1313            Int16Array,
1314            DataType::Int16,
1315            vec![3i16, 2i16, 3i16],
1316            Int64Array,
1317            DataType::Int64,
1318            vec![10i64, 6i64, 5i64],
1319            Operator::BitwiseXor,
1320            Int64Array,
1321            DataType::Int64,
1322            [9i64, 4i64, 6i64],
1323        );
1324        test_coercion!(
1325            UInt16Array,
1326            DataType::UInt16,
1327            vec![3u16, 2u16, 3u16],
1328            UInt64Array,
1329            DataType::UInt64,
1330            vec![10u64, 6u64, 5u64],
1331            Operator::BitwiseXor,
1332            UInt64Array,
1333            DataType::UInt64,
1334            [9u64, 4u64, 6u64],
1335        );
1336        test_coercion!(
1337            Int16Array,
1338            DataType::Int16,
1339            vec![4i16, 27i16, 35i16],
1340            Int64Array,
1341            DataType::Int64,
1342            vec![2i64, 3i64, 4i64],
1343            Operator::BitwiseShiftRight,
1344            Int64Array,
1345            DataType::Int64,
1346            [1i64, 3i64, 2i64],
1347        );
1348        test_coercion!(
1349            UInt16Array,
1350            DataType::UInt16,
1351            vec![4u16, 27u16, 35u16],
1352            UInt64Array,
1353            DataType::UInt64,
1354            vec![2u64, 3u64, 4u64],
1355            Operator::BitwiseShiftRight,
1356            UInt64Array,
1357            DataType::UInt64,
1358            [1u64, 3u64, 2u64],
1359        );
1360        test_coercion!(
1361            Int16Array,
1362            DataType::Int16,
1363            vec![2i16, 3i16, 4i16],
1364            Int64Array,
1365            DataType::Int64,
1366            vec![4i64, 12i64, 7i64],
1367            Operator::BitwiseShiftLeft,
1368            Int64Array,
1369            DataType::Int64,
1370            [32i64, 12288i64, 512i64],
1371        );
1372        test_coercion!(
1373            UInt16Array,
1374            DataType::UInt16,
1375            vec![2u16, 3u16, 4u16],
1376            UInt64Array,
1377            DataType::UInt64,
1378            vec![4u64, 12u64, 7u64],
1379            Operator::BitwiseShiftLeft,
1380            UInt64Array,
1381            DataType::UInt64,
1382            [32u64, 12288u64, 512u64],
1383        );
1384        Ok(())
1385    }
1386
1387    // Note it would be nice to use the same test_coercion macro as
1388    // above, but sadly the type of the values of the dictionary are
1389    // not encoded in the rust type of the DictionaryArray. Thus there
1390    // is no way at the time of this writing to create a dictionary
1391    // array using the `From` trait
1392    #[test]
1393    fn test_dictionary_type_to_array_coercion() -> Result<()> {
1394        // Test string  a string dictionary
1395        let dict_type =
1396            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1397        let string_type = DataType::Utf8;
1398
1399        // build dictionary
1400        let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
1401
1402        dict_builder.append("one")?;
1403        dict_builder.append_null();
1404        dict_builder.append("three")?;
1405        dict_builder.append("four")?;
1406        let dict_array = Arc::new(dict_builder.finish()) as ArrayRef;
1407
1408        let str_array = Arc::new(StringArray::from(vec![
1409            Some("not one"),
1410            Some("two"),
1411            None,
1412            Some("four"),
1413        ])) as ArrayRef;
1414
1415        let schema = Arc::new(Schema::new(vec![
1416            Field::new("a", dict_type.clone(), true),
1417            Field::new("b", string_type.clone(), true),
1418        ]));
1419
1420        // Test 1: a = b
1421        let result = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
1422        apply_logic_op(&schema, &dict_array, &str_array, Operator::Eq, result)?;
1423
1424        // Test 2: now test the other direction
1425        // b = a
1426        let schema = Arc::new(Schema::new(vec![
1427            Field::new("a", string_type, true),
1428            Field::new("b", dict_type, true),
1429        ]));
1430        let result = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
1431        apply_logic_op(&schema, &str_array, &dict_array, Operator::Eq, result)?;
1432
1433        Ok(())
1434    }
1435
1436    #[test]
1437    fn plus_op() -> Result<()> {
1438        let schema = Schema::new(vec![
1439            Field::new("a", DataType::Int32, false),
1440            Field::new("b", DataType::Int32, false),
1441        ]);
1442        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1443        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1444
1445        apply_arithmetic::<Int32Type>(
1446            Arc::new(schema),
1447            vec![Arc::new(a), Arc::new(b)],
1448            Operator::Plus,
1449            Int32Array::from(vec![2, 4, 7, 12, 21]),
1450        )?;
1451
1452        Ok(())
1453    }
1454
1455    #[test]
1456    fn plus_op_dict() -> Result<()> {
1457        let schema = Schema::new(vec![
1458            Field::new(
1459                "a",
1460                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1461                true,
1462            ),
1463            Field::new(
1464                "b",
1465                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1466                true,
1467            ),
1468        ]);
1469
1470        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1471        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
1472        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
1473
1474        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1475        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
1476        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
1477
1478        apply_arithmetic::<Int32Type>(
1479            Arc::new(schema),
1480            vec![Arc::new(a), Arc::new(b)],
1481            Operator::Plus,
1482            Int32Array::from(vec![Some(2), None, Some(4), Some(8), None]),
1483        )?;
1484
1485        Ok(())
1486    }
1487
1488    #[test]
1489    fn plus_op_dict_decimal() -> Result<()> {
1490        let schema = Schema::new(vec![
1491            Field::new(
1492                "a",
1493                DataType::Dictionary(
1494                    Box::new(DataType::Int8),
1495                    Box::new(DataType::Decimal128(10, 0)),
1496                ),
1497                true,
1498            ),
1499            Field::new(
1500                "b",
1501                DataType::Dictionary(
1502                    Box::new(DataType::Int8),
1503                    Box::new(DataType::Decimal128(10, 0)),
1504                ),
1505                true,
1506            ),
1507        ]);
1508
1509        let value = 123;
1510        let decimal_array = Arc::new(create_decimal_array(
1511            &[
1512                Some(value),
1513                Some(value + 2),
1514                Some(value - 1),
1515                Some(value + 1),
1516            ],
1517            10,
1518            0,
1519        ));
1520
1521        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
1522        let a = DictionaryArray::try_new(keys, decimal_array)?;
1523
1524        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
1525        let decimal_array = Arc::new(create_decimal_array(
1526            &[
1527                Some(value + 1),
1528                Some(value + 3),
1529                Some(value),
1530                Some(value + 2),
1531            ],
1532            10,
1533            0,
1534        ));
1535        let b = DictionaryArray::try_new(keys, decimal_array)?;
1536
1537        apply_arithmetic(
1538            Arc::new(schema),
1539            vec![Arc::new(a), Arc::new(b)],
1540            Operator::Plus,
1541            create_decimal_array(&[Some(247), None, None, Some(247), Some(246)], 11, 0),
1542        )?;
1543
1544        Ok(())
1545    }
1546
1547    #[test]
1548    fn plus_op_scalar() -> Result<()> {
1549        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1550        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1551
1552        apply_arithmetic_scalar(
1553            Arc::new(schema),
1554            vec![Arc::new(a)],
1555            Operator::Plus,
1556            ScalarValue::Int32(Some(1)),
1557            Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
1558        )?;
1559
1560        Ok(())
1561    }
1562
1563    #[test]
1564    fn plus_op_dict_scalar() -> Result<()> {
1565        let schema = Schema::new(vec![Field::new(
1566            "a",
1567            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1568            true,
1569        )]);
1570
1571        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
1572
1573        dict_builder.append(1)?;
1574        dict_builder.append_null();
1575        dict_builder.append(2)?;
1576        dict_builder.append(5)?;
1577
1578        let a = dict_builder.finish();
1579
1580        let expected: PrimitiveArray<Int32Type> =
1581            PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]);
1582
1583        apply_arithmetic_scalar(
1584            Arc::new(schema),
1585            vec![Arc::new(a)],
1586            Operator::Plus,
1587            ScalarValue::Dictionary(
1588                Box::new(DataType::Int8),
1589                Box::new(ScalarValue::Int32(Some(1))),
1590            ),
1591            Arc::new(expected),
1592        )?;
1593
1594        Ok(())
1595    }
1596
1597    #[test]
1598    fn plus_op_dict_scalar_decimal() -> Result<()> {
1599        let schema = Schema::new(vec![Field::new(
1600            "a",
1601            DataType::Dictionary(
1602                Box::new(DataType::Int8),
1603                Box::new(DataType::Decimal128(10, 0)),
1604            ),
1605            true,
1606        )]);
1607
1608        let value = 123;
1609        let decimal_array = Arc::new(create_decimal_array(
1610            &[Some(value), None, Some(value - 1), Some(value + 1)],
1611            10,
1612            0,
1613        ));
1614
1615        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
1616        let a = DictionaryArray::try_new(keys, decimal_array)?;
1617
1618        let decimal_array = Arc::new(create_decimal_array(
1619            &[
1620                Some(value + 1),
1621                Some(value),
1622                None,
1623                Some(value + 2),
1624                Some(value + 1),
1625            ],
1626            11,
1627            0,
1628        ));
1629
1630        apply_arithmetic_scalar(
1631            Arc::new(schema),
1632            vec![Arc::new(a)],
1633            Operator::Plus,
1634            ScalarValue::Dictionary(
1635                Box::new(DataType::Int8),
1636                Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
1637            ),
1638            decimal_array,
1639        )?;
1640
1641        Ok(())
1642    }
1643
1644    #[test]
1645    fn minus_op() -> Result<()> {
1646        let schema = Arc::new(Schema::new(vec![
1647            Field::new("a", DataType::Int32, false),
1648            Field::new("b", DataType::Int32, false),
1649        ]));
1650        let a = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
1651        let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1652
1653        apply_arithmetic::<Int32Type>(
1654            Arc::clone(&schema),
1655            vec![
1656                Arc::clone(&a) as Arc<dyn Array>,
1657                Arc::clone(&b) as Arc<dyn Array>,
1658            ],
1659            Operator::Minus,
1660            Int32Array::from(vec![0, 0, 1, 4, 11]),
1661        )?;
1662
1663        // should handle have negative values in result (for signed)
1664        apply_arithmetic::<Int32Type>(
1665            schema,
1666            vec![b, a],
1667            Operator::Minus,
1668            Int32Array::from(vec![0, 0, -1, -4, -11]),
1669        )?;
1670
1671        Ok(())
1672    }
1673
1674    #[test]
1675    fn minus_op_dict() -> Result<()> {
1676        let schema = Schema::new(vec![
1677            Field::new(
1678                "a",
1679                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1680                true,
1681            ),
1682            Field::new(
1683                "b",
1684                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1685                true,
1686            ),
1687        ]);
1688
1689        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1690        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
1691        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
1692
1693        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1694        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
1695        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
1696
1697        apply_arithmetic::<Int32Type>(
1698            Arc::new(schema),
1699            vec![Arc::new(a), Arc::new(b)],
1700            Operator::Minus,
1701            Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]),
1702        )?;
1703
1704        Ok(())
1705    }
1706
1707    #[test]
1708    fn minus_op_dict_decimal() -> Result<()> {
1709        let schema = Schema::new(vec![
1710            Field::new(
1711                "a",
1712                DataType::Dictionary(
1713                    Box::new(DataType::Int8),
1714                    Box::new(DataType::Decimal128(10, 0)),
1715                ),
1716                true,
1717            ),
1718            Field::new(
1719                "b",
1720                DataType::Dictionary(
1721                    Box::new(DataType::Int8),
1722                    Box::new(DataType::Decimal128(10, 0)),
1723                ),
1724                true,
1725            ),
1726        ]);
1727
1728        let value = 123;
1729        let decimal_array = Arc::new(create_decimal_array(
1730            &[
1731                Some(value),
1732                Some(value + 2),
1733                Some(value - 1),
1734                Some(value + 1),
1735            ],
1736            10,
1737            0,
1738        ));
1739
1740        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
1741        let a = DictionaryArray::try_new(keys, decimal_array)?;
1742
1743        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
1744        let decimal_array = Arc::new(create_decimal_array(
1745            &[
1746                Some(value + 1),
1747                Some(value + 3),
1748                Some(value),
1749                Some(value + 2),
1750            ],
1751            10,
1752            0,
1753        ));
1754        let b = DictionaryArray::try_new(keys, decimal_array)?;
1755
1756        apply_arithmetic(
1757            Arc::new(schema),
1758            vec![Arc::new(a), Arc::new(b)],
1759            Operator::Minus,
1760            create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0),
1761        )?;
1762
1763        Ok(())
1764    }
1765
1766    #[test]
1767    fn minus_op_scalar() -> Result<()> {
1768        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1769        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1770
1771        apply_arithmetic_scalar(
1772            Arc::new(schema),
1773            vec![Arc::new(a)],
1774            Operator::Minus,
1775            ScalarValue::Int32(Some(1)),
1776            Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4])),
1777        )?;
1778
1779        Ok(())
1780    }
1781
1782    #[test]
1783    fn minus_op_dict_scalar() -> Result<()> {
1784        let schema = Schema::new(vec![Field::new(
1785            "a",
1786            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1787            true,
1788        )]);
1789
1790        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
1791
1792        dict_builder.append(1)?;
1793        dict_builder.append_null();
1794        dict_builder.append(2)?;
1795        dict_builder.append(5)?;
1796
1797        let a = dict_builder.finish();
1798
1799        let expected: PrimitiveArray<Int32Type> =
1800            PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]);
1801
1802        apply_arithmetic_scalar(
1803            Arc::new(schema),
1804            vec![Arc::new(a)],
1805            Operator::Minus,
1806            ScalarValue::Dictionary(
1807                Box::new(DataType::Int8),
1808                Box::new(ScalarValue::Int32(Some(1))),
1809            ),
1810            Arc::new(expected),
1811        )?;
1812
1813        Ok(())
1814    }
1815
1816    #[test]
1817    fn minus_op_dict_scalar_decimal() -> Result<()> {
1818        let schema = Schema::new(vec![Field::new(
1819            "a",
1820            DataType::Dictionary(
1821                Box::new(DataType::Int8),
1822                Box::new(DataType::Decimal128(10, 0)),
1823            ),
1824            true,
1825        )]);
1826
1827        let value = 123;
1828        let decimal_array = Arc::new(create_decimal_array(
1829            &[Some(value), None, Some(value - 1), Some(value + 1)],
1830            10,
1831            0,
1832        ));
1833
1834        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
1835        let a = DictionaryArray::try_new(keys, decimal_array)?;
1836
1837        let decimal_array = Arc::new(create_decimal_array(
1838            &[
1839                Some(value - 1),
1840                Some(value - 2),
1841                None,
1842                Some(value),
1843                Some(value - 1),
1844            ],
1845            11,
1846            0,
1847        ));
1848
1849        apply_arithmetic_scalar(
1850            Arc::new(schema),
1851            vec![Arc::new(a)],
1852            Operator::Minus,
1853            ScalarValue::Dictionary(
1854                Box::new(DataType::Int8),
1855                Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
1856            ),
1857            decimal_array,
1858        )?;
1859
1860        Ok(())
1861    }
1862
1863    #[test]
1864    fn multiply_op() -> Result<()> {
1865        let schema = Arc::new(Schema::new(vec![
1866            Field::new("a", DataType::Int32, false),
1867            Field::new("b", DataType::Int32, false),
1868        ]));
1869        let a = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
1870        let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
1871
1872        apply_arithmetic::<Int32Type>(
1873            schema,
1874            vec![a, b],
1875            Operator::Multiply,
1876            Int32Array::from(vec![8, 32, 128, 512, 2048]),
1877        )?;
1878
1879        Ok(())
1880    }
1881
1882    #[test]
1883    fn multiply_op_dict() -> Result<()> {
1884        let schema = Schema::new(vec![
1885            Field::new(
1886                "a",
1887                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1888                true,
1889            ),
1890            Field::new(
1891                "b",
1892                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1893                true,
1894            ),
1895        ]);
1896
1897        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1898        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
1899        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
1900
1901        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1902        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
1903        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
1904
1905        apply_arithmetic::<Int32Type>(
1906            Arc::new(schema),
1907            vec![Arc::new(a), Arc::new(b)],
1908            Operator::Multiply,
1909            Int32Array::from(vec![Some(1), None, Some(4), Some(16), None]),
1910        )?;
1911
1912        Ok(())
1913    }
1914
1915    #[test]
1916    fn multiply_op_dict_decimal() -> Result<()> {
1917        let schema = Schema::new(vec![
1918            Field::new(
1919                "a",
1920                DataType::Dictionary(
1921                    Box::new(DataType::Int8),
1922                    Box::new(DataType::Decimal128(10, 0)),
1923                ),
1924                true,
1925            ),
1926            Field::new(
1927                "b",
1928                DataType::Dictionary(
1929                    Box::new(DataType::Int8),
1930                    Box::new(DataType::Decimal128(10, 0)),
1931                ),
1932                true,
1933            ),
1934        ]);
1935
1936        let value = 123;
1937        let decimal_array = Arc::new(create_decimal_array(
1938            &[
1939                Some(value),
1940                Some(value + 2),
1941                Some(value - 1),
1942                Some(value + 1),
1943            ],
1944            10,
1945            0,
1946        )) as ArrayRef;
1947
1948        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
1949        let a = DictionaryArray::try_new(keys, decimal_array)?;
1950
1951        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
1952        let decimal_array = Arc::new(create_decimal_array(
1953            &[
1954                Some(value + 1),
1955                Some(value + 3),
1956                Some(value),
1957                Some(value + 2),
1958            ],
1959            10,
1960            0,
1961        ));
1962        let b = DictionaryArray::try_new(keys, decimal_array)?;
1963
1964        apply_arithmetic(
1965            Arc::new(schema),
1966            vec![Arc::new(a), Arc::new(b)],
1967            Operator::Multiply,
1968            create_decimal_array(
1969                &[Some(15252), None, None, Some(15252), Some(15129)],
1970                21,
1971                0,
1972            ),
1973        )?;
1974
1975        Ok(())
1976    }
1977
1978    #[test]
1979    fn multiply_op_scalar() -> Result<()> {
1980        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1981        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1982
1983        apply_arithmetic_scalar(
1984            Arc::new(schema),
1985            vec![Arc::new(a)],
1986            Operator::Multiply,
1987            ScalarValue::Int32(Some(2)),
1988            Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10])),
1989        )?;
1990
1991        Ok(())
1992    }
1993
1994    #[test]
1995    fn multiply_op_dict_scalar() -> Result<()> {
1996        let schema = Schema::new(vec![Field::new(
1997            "a",
1998            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1999            true,
2000        )]);
2001
2002        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2003
2004        dict_builder.append(1)?;
2005        dict_builder.append_null();
2006        dict_builder.append(2)?;
2007        dict_builder.append(5)?;
2008
2009        let a = dict_builder.finish();
2010
2011        let expected: PrimitiveArray<Int32Type> =
2012            PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]);
2013
2014        apply_arithmetic_scalar(
2015            Arc::new(schema),
2016            vec![Arc::new(a)],
2017            Operator::Multiply,
2018            ScalarValue::Dictionary(
2019                Box::new(DataType::Int8),
2020                Box::new(ScalarValue::Int32(Some(2))),
2021            ),
2022            Arc::new(expected),
2023        )?;
2024
2025        Ok(())
2026    }
2027
2028    #[test]
2029    fn multiply_op_dict_scalar_decimal() -> Result<()> {
2030        let schema = Schema::new(vec![Field::new(
2031            "a",
2032            DataType::Dictionary(
2033                Box::new(DataType::Int8),
2034                Box::new(DataType::Decimal128(10, 0)),
2035            ),
2036            true,
2037        )]);
2038
2039        let value = 123;
2040        let decimal_array = Arc::new(create_decimal_array(
2041            &[Some(value), None, Some(value - 1), Some(value + 1)],
2042            10,
2043            0,
2044        ));
2045
2046        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2047        let a = DictionaryArray::try_new(keys, decimal_array)?;
2048
2049        let decimal_array = Arc::new(create_decimal_array(
2050            &[Some(246), Some(244), None, Some(248), Some(246)],
2051            21,
2052            0,
2053        ));
2054
2055        apply_arithmetic_scalar(
2056            Arc::new(schema),
2057            vec![Arc::new(a)],
2058            Operator::Multiply,
2059            ScalarValue::Dictionary(
2060                Box::new(DataType::Int8),
2061                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2062            ),
2063            decimal_array,
2064        )?;
2065
2066        Ok(())
2067    }
2068
2069    #[test]
2070    fn divide_op() -> Result<()> {
2071        let schema = Arc::new(Schema::new(vec![
2072            Field::new("a", DataType::Int32, false),
2073            Field::new("b", DataType::Int32, false),
2074        ]));
2075        let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
2076        let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
2077
2078        apply_arithmetic::<Int32Type>(
2079            schema,
2080            vec![a, b],
2081            Operator::Divide,
2082            Int32Array::from(vec![4, 8, 16, 32, 64]),
2083        )?;
2084
2085        Ok(())
2086    }
2087
2088    #[test]
2089    fn divide_op_dict() -> Result<()> {
2090        let schema = Schema::new(vec![
2091            Field::new(
2092                "a",
2093                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2094                true,
2095            ),
2096            Field::new(
2097                "b",
2098                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2099                true,
2100            ),
2101        ]);
2102
2103        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2104
2105        dict_builder.append(1)?;
2106        dict_builder.append_null();
2107        dict_builder.append(2)?;
2108        dict_builder.append(5)?;
2109        dict_builder.append(0)?;
2110
2111        let a = dict_builder.finish();
2112
2113        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
2114        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
2115        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
2116
2117        apply_arithmetic::<Int32Type>(
2118            Arc::new(schema),
2119            vec![Arc::new(a), Arc::new(b)],
2120            Operator::Divide,
2121            Int32Array::from(vec![Some(1), None, Some(1), Some(1), Some(0)]),
2122        )?;
2123
2124        Ok(())
2125    }
2126
2127    #[test]
2128    fn divide_op_dict_decimal() -> Result<()> {
2129        let schema = Schema::new(vec![
2130            Field::new(
2131                "a",
2132                DataType::Dictionary(
2133                    Box::new(DataType::Int8),
2134                    Box::new(DataType::Decimal128(10, 0)),
2135                ),
2136                true,
2137            ),
2138            Field::new(
2139                "b",
2140                DataType::Dictionary(
2141                    Box::new(DataType::Int8),
2142                    Box::new(DataType::Decimal128(10, 0)),
2143                ),
2144                true,
2145            ),
2146        ]);
2147
2148        let value = 123;
2149        let decimal_array = Arc::new(create_decimal_array(
2150            &[
2151                Some(value),
2152                Some(value + 2),
2153                Some(value - 1),
2154                Some(value + 1),
2155            ],
2156            10,
2157            0,
2158        ));
2159
2160        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2161        let a = DictionaryArray::try_new(keys, decimal_array)?;
2162
2163        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2164        let decimal_array = Arc::new(create_decimal_array(
2165            &[
2166                Some(value + 1),
2167                Some(value + 3),
2168                Some(value),
2169                Some(value + 2),
2170            ],
2171            10,
2172            0,
2173        ));
2174        let b = DictionaryArray::try_new(keys, decimal_array)?;
2175
2176        apply_arithmetic(
2177            Arc::new(schema),
2178            vec![Arc::new(a), Arc::new(b)],
2179            Operator::Divide,
2180            create_decimal_array(
2181                &[
2182                    Some(9919), // 0.9919
2183                    None,
2184                    None,
2185                    Some(10081), // 1.0081
2186                    Some(10000), // 1.0
2187                ],
2188                14,
2189                4,
2190            ),
2191        )?;
2192
2193        Ok(())
2194    }
2195
2196    #[test]
2197    fn divide_op_scalar() -> Result<()> {
2198        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2199        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2200
2201        apply_arithmetic_scalar(
2202            Arc::new(schema),
2203            vec![Arc::new(a)],
2204            Operator::Divide,
2205            ScalarValue::Int32(Some(2)),
2206            Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2])),
2207        )?;
2208
2209        Ok(())
2210    }
2211
2212    #[test]
2213    fn divide_op_dict_scalar() -> Result<()> {
2214        let schema = Schema::new(vec![Field::new(
2215            "a",
2216            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2217            true,
2218        )]);
2219
2220        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2221
2222        dict_builder.append(1)?;
2223        dict_builder.append_null();
2224        dict_builder.append(2)?;
2225        dict_builder.append(5)?;
2226
2227        let a = dict_builder.finish();
2228
2229        let expected: PrimitiveArray<Int32Type> =
2230            PrimitiveArray::from(vec![Some(0), None, Some(1), Some(2)]);
2231
2232        apply_arithmetic_scalar(
2233            Arc::new(schema),
2234            vec![Arc::new(a)],
2235            Operator::Divide,
2236            ScalarValue::Dictionary(
2237                Box::new(DataType::Int8),
2238                Box::new(ScalarValue::Int32(Some(2))),
2239            ),
2240            Arc::new(expected),
2241        )?;
2242
2243        Ok(())
2244    }
2245
2246    #[test]
2247    fn divide_op_dict_scalar_decimal() -> Result<()> {
2248        let schema = Schema::new(vec![Field::new(
2249            "a",
2250            DataType::Dictionary(
2251                Box::new(DataType::Int8),
2252                Box::new(DataType::Decimal128(10, 0)),
2253            ),
2254            true,
2255        )]);
2256
2257        let value = 123;
2258        let decimal_array = Arc::new(create_decimal_array(
2259            &[Some(value), None, Some(value - 1), Some(value + 1)],
2260            10,
2261            0,
2262        ));
2263
2264        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2265        let a = DictionaryArray::try_new(keys, decimal_array)?;
2266
2267        let decimal_array = Arc::new(create_decimal_array(
2268            &[Some(615000), Some(610000), None, Some(620000), Some(615000)],
2269            14,
2270            4,
2271        ));
2272
2273        apply_arithmetic_scalar(
2274            Arc::new(schema),
2275            vec![Arc::new(a)],
2276            Operator::Divide,
2277            ScalarValue::Dictionary(
2278                Box::new(DataType::Int8),
2279                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2280            ),
2281            decimal_array,
2282        )?;
2283
2284        Ok(())
2285    }
2286
2287    #[test]
2288    fn modulus_op() -> Result<()> {
2289        let schema = Arc::new(Schema::new(vec![
2290            Field::new("a", DataType::Int32, false),
2291            Field::new("b", DataType::Int32, false),
2292        ]));
2293        let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
2294        let b = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32]));
2295
2296        apply_arithmetic::<Int32Type>(
2297            schema,
2298            vec![a, b],
2299            Operator::Modulo,
2300            Int32Array::from(vec![0, 0, 2, 8, 0]),
2301        )?;
2302
2303        Ok(())
2304    }
2305
2306    #[test]
2307    fn modulus_op_dict() -> Result<()> {
2308        let schema = Schema::new(vec![
2309            Field::new(
2310                "a",
2311                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2312                true,
2313            ),
2314            Field::new(
2315                "b",
2316                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2317                true,
2318            ),
2319        ]);
2320
2321        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2322
2323        dict_builder.append(1)?;
2324        dict_builder.append_null();
2325        dict_builder.append(2)?;
2326        dict_builder.append(5)?;
2327        dict_builder.append(0)?;
2328
2329        let a = dict_builder.finish();
2330
2331        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
2332        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
2333        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
2334
2335        apply_arithmetic::<Int32Type>(
2336            Arc::new(schema),
2337            vec![Arc::new(a), Arc::new(b)],
2338            Operator::Modulo,
2339            Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]),
2340        )?;
2341
2342        Ok(())
2343    }
2344
2345    #[test]
2346    fn modulus_op_dict_decimal() -> Result<()> {
2347        let schema = Schema::new(vec![
2348            Field::new(
2349                "a",
2350                DataType::Dictionary(
2351                    Box::new(DataType::Int8),
2352                    Box::new(DataType::Decimal128(10, 0)),
2353                ),
2354                true,
2355            ),
2356            Field::new(
2357                "b",
2358                DataType::Dictionary(
2359                    Box::new(DataType::Int8),
2360                    Box::new(DataType::Decimal128(10, 0)),
2361                ),
2362                true,
2363            ),
2364        ]);
2365
2366        let value = 123;
2367        let decimal_array = Arc::new(create_decimal_array(
2368            &[
2369                Some(value),
2370                Some(value + 2),
2371                Some(value - 1),
2372                Some(value + 1),
2373            ],
2374            10,
2375            0,
2376        ));
2377
2378        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2379        let a = DictionaryArray::try_new(keys, decimal_array)?;
2380
2381        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2382        let decimal_array = Arc::new(create_decimal_array(
2383            &[
2384                Some(value + 1),
2385                Some(value + 3),
2386                Some(value),
2387                Some(value + 2),
2388            ],
2389            10,
2390            0,
2391        ));
2392        let b = DictionaryArray::try_new(keys, decimal_array)?;
2393
2394        apply_arithmetic(
2395            Arc::new(schema),
2396            vec![Arc::new(a), Arc::new(b)],
2397            Operator::Modulo,
2398            create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0),
2399        )?;
2400
2401        Ok(())
2402    }
2403
2404    #[test]
2405    fn modulus_op_scalar() -> Result<()> {
2406        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2407        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2408
2409        apply_arithmetic_scalar(
2410            Arc::new(schema),
2411            vec![Arc::new(a)],
2412            Operator::Modulo,
2413            ScalarValue::Int32(Some(2)),
2414            Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1])),
2415        )?;
2416
2417        Ok(())
2418    }
2419
2420    #[test]
2421    fn modules_op_dict_scalar() -> Result<()> {
2422        let schema = Schema::new(vec![Field::new(
2423            "a",
2424            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2425            true,
2426        )]);
2427
2428        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2429
2430        dict_builder.append(1)?;
2431        dict_builder.append_null();
2432        dict_builder.append(2)?;
2433        dict_builder.append(5)?;
2434
2435        let a = dict_builder.finish();
2436
2437        let expected: PrimitiveArray<Int32Type> =
2438            PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);
2439
2440        apply_arithmetic_scalar(
2441            Arc::new(schema),
2442            vec![Arc::new(a)],
2443            Operator::Modulo,
2444            ScalarValue::Dictionary(
2445                Box::new(DataType::Int8),
2446                Box::new(ScalarValue::Int32(Some(2))),
2447            ),
2448            Arc::new(expected),
2449        )?;
2450
2451        Ok(())
2452    }
2453
2454    #[test]
2455    fn modulus_op_dict_scalar_decimal() -> Result<()> {
2456        let schema = Schema::new(vec![Field::new(
2457            "a",
2458            DataType::Dictionary(
2459                Box::new(DataType::Int8),
2460                Box::new(DataType::Decimal128(10, 0)),
2461            ),
2462            true,
2463        )]);
2464
2465        let value = 123;
2466        let decimal_array = Arc::new(create_decimal_array(
2467            &[Some(value), None, Some(value - 1), Some(value + 1)],
2468            10,
2469            0,
2470        ));
2471
2472        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2473        let a = DictionaryArray::try_new(keys, decimal_array)?;
2474
2475        let decimal_array = Arc::new(create_decimal_array(
2476            &[Some(1), Some(0), None, Some(0), Some(1)],
2477            10,
2478            0,
2479        ));
2480
2481        apply_arithmetic_scalar(
2482            Arc::new(schema),
2483            vec![Arc::new(a)],
2484            Operator::Modulo,
2485            ScalarValue::Dictionary(
2486                Box::new(DataType::Int8),
2487                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2488            ),
2489            decimal_array,
2490        )?;
2491
2492        Ok(())
2493    }
2494
2495    fn apply_arithmetic<T: ArrowNumericType>(
2496        schema: SchemaRef,
2497        data: Vec<ArrayRef>,
2498        op: Operator,
2499        expected: PrimitiveArray<T>,
2500    ) -> Result<()> {
2501        let arithmetic_op =
2502            binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
2503        let batch = RecordBatch::try_new(schema, data)?;
2504        let result = arithmetic_op
2505            .evaluate(&batch)?
2506            .into_array(batch.num_rows())
2507            .expect("Failed to convert to array");
2508
2509        assert_eq!(result.as_ref(), &expected);
2510        Ok(())
2511    }
2512
2513    fn apply_arithmetic_scalar(
2514        schema: SchemaRef,
2515        data: Vec<ArrayRef>,
2516        op: Operator,
2517        literal: ScalarValue,
2518        expected: ArrayRef,
2519    ) -> Result<()> {
2520        let lit = Arc::new(Literal::new(literal));
2521        let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
2522        let batch = RecordBatch::try_new(schema, data)?;
2523        let result = arithmetic_op
2524            .evaluate(&batch)?
2525            .into_array(batch.num_rows())
2526            .expect("Failed to convert to array");
2527
2528        assert_eq!(&result, &expected);
2529        Ok(())
2530    }
2531
2532    fn apply_logic_op(
2533        schema: &SchemaRef,
2534        left: &ArrayRef,
2535        right: &ArrayRef,
2536        op: Operator,
2537        expected: BooleanArray,
2538    ) -> Result<()> {
2539        let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
2540        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
2541        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
2542        let result = op
2543            .evaluate(&batch)?
2544            .into_array(batch.num_rows())
2545            .expect("Failed to convert to array");
2546
2547        assert_eq!(result.as_ref(), &expected);
2548        Ok(())
2549    }
2550
2551    // Test `scalar <op> arr` produces expected
2552    fn apply_logic_op_scalar_arr(
2553        schema: &SchemaRef,
2554        scalar: &ScalarValue,
2555        arr: &ArrayRef,
2556        op: Operator,
2557        expected: &BooleanArray,
2558    ) -> Result<()> {
2559        let scalar = lit(scalar.clone());
2560        let op = binary_op(scalar, op, col("a", schema)?, schema)?;
2561        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2562        let result = op
2563            .evaluate(&batch)?
2564            .into_array(batch.num_rows())
2565            .expect("Failed to convert to array");
2566        assert_eq!(result.as_ref(), expected);
2567
2568        Ok(())
2569    }
2570
2571    // Test `arr <op> scalar` produces expected
2572    fn apply_logic_op_arr_scalar(
2573        schema: &SchemaRef,
2574        arr: &ArrayRef,
2575        scalar: &ScalarValue,
2576        op: Operator,
2577        expected: &BooleanArray,
2578    ) -> Result<()> {
2579        let scalar = lit(scalar.clone());
2580        let op = binary_op(col("a", schema)?, op, scalar, schema)?;
2581        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2582        let result = op
2583            .evaluate(&batch)?
2584            .into_array(batch.num_rows())
2585            .expect("Failed to convert to array");
2586        assert_eq!(result.as_ref(), expected);
2587
2588        Ok(())
2589    }
2590
2591    #[test]
2592    fn and_with_nulls_op() -> Result<()> {
2593        let schema = Schema::new(vec![
2594            Field::new("a", DataType::Boolean, true),
2595            Field::new("b", DataType::Boolean, true),
2596        ]);
2597        let a = Arc::new(BooleanArray::from(vec![
2598            Some(true),
2599            Some(false),
2600            None,
2601            Some(true),
2602            Some(false),
2603            None,
2604            Some(true),
2605            Some(false),
2606            None,
2607        ])) as ArrayRef;
2608        let b = Arc::new(BooleanArray::from(vec![
2609            Some(true),
2610            Some(true),
2611            Some(true),
2612            Some(false),
2613            Some(false),
2614            Some(false),
2615            None,
2616            None,
2617            None,
2618        ])) as ArrayRef;
2619
2620        let expected = BooleanArray::from(vec![
2621            Some(true),
2622            Some(false),
2623            None,
2624            Some(false),
2625            Some(false),
2626            Some(false),
2627            None,
2628            Some(false),
2629            None,
2630        ]);
2631        apply_logic_op(&Arc::new(schema), &a, &b, Operator::And, expected)?;
2632
2633        Ok(())
2634    }
2635
2636    #[test]
2637    fn regex_with_nulls() -> Result<()> {
2638        let schema = Schema::new(vec![
2639            Field::new("a", DataType::Utf8, true),
2640            Field::new("b", DataType::Utf8, true),
2641        ]);
2642        let a = Arc::new(StringArray::from(vec![
2643            Some("abc"),
2644            None,
2645            Some("abc"),
2646            None,
2647            Some("abc"),
2648        ])) as ArrayRef;
2649        let b = Arc::new(StringArray::from(vec![
2650            Some("^a"),
2651            Some("^A"),
2652            None,
2653            None,
2654            Some("^(b|c)"),
2655        ])) as ArrayRef;
2656
2657        let regex_expected =
2658            BooleanArray::from(vec![Some(true), None, None, None, Some(false)]);
2659        let regex_not_expected =
2660            BooleanArray::from(vec![Some(false), None, None, None, Some(true)]);
2661        apply_logic_op(
2662            &Arc::new(schema.clone()),
2663            &a,
2664            &b,
2665            Operator::RegexMatch,
2666            regex_expected.clone(),
2667        )?;
2668        apply_logic_op(
2669            &Arc::new(schema.clone()),
2670            &a,
2671            &b,
2672            Operator::RegexIMatch,
2673            regex_expected.clone(),
2674        )?;
2675        apply_logic_op(
2676            &Arc::new(schema.clone()),
2677            &a,
2678            &b,
2679            Operator::RegexNotMatch,
2680            regex_not_expected.clone(),
2681        )?;
2682        apply_logic_op(
2683            &Arc::new(schema),
2684            &a,
2685            &b,
2686            Operator::RegexNotIMatch,
2687            regex_not_expected.clone(),
2688        )?;
2689
2690        let schema = Schema::new(vec![
2691            Field::new("a", DataType::LargeUtf8, true),
2692            Field::new("b", DataType::LargeUtf8, true),
2693        ]);
2694        let a = Arc::new(LargeStringArray::from(vec![
2695            Some("abc"),
2696            None,
2697            Some("abc"),
2698            None,
2699            Some("abc"),
2700        ])) as ArrayRef;
2701        let b = Arc::new(LargeStringArray::from(vec![
2702            Some("^a"),
2703            Some("^A"),
2704            None,
2705            None,
2706            Some("^(b|c)"),
2707        ])) as ArrayRef;
2708
2709        apply_logic_op(
2710            &Arc::new(schema.clone()),
2711            &a,
2712            &b,
2713            Operator::RegexMatch,
2714            regex_expected.clone(),
2715        )?;
2716        apply_logic_op(
2717            &Arc::new(schema.clone()),
2718            &a,
2719            &b,
2720            Operator::RegexIMatch,
2721            regex_expected,
2722        )?;
2723        apply_logic_op(
2724            &Arc::new(schema.clone()),
2725            &a,
2726            &b,
2727            Operator::RegexNotMatch,
2728            regex_not_expected.clone(),
2729        )?;
2730        apply_logic_op(
2731            &Arc::new(schema),
2732            &a,
2733            &b,
2734            Operator::RegexNotIMatch,
2735            regex_not_expected,
2736        )?;
2737
2738        Ok(())
2739    }
2740
2741    #[test]
2742    fn or_with_nulls_op() -> Result<()> {
2743        let schema = Schema::new(vec![
2744            Field::new("a", DataType::Boolean, true),
2745            Field::new("b", DataType::Boolean, true),
2746        ]);
2747        let a = Arc::new(BooleanArray::from(vec![
2748            Some(true),
2749            Some(false),
2750            None,
2751            Some(true),
2752            Some(false),
2753            None,
2754            Some(true),
2755            Some(false),
2756            None,
2757        ])) as ArrayRef;
2758        let b = Arc::new(BooleanArray::from(vec![
2759            Some(true),
2760            Some(true),
2761            Some(true),
2762            Some(false),
2763            Some(false),
2764            Some(false),
2765            None,
2766            None,
2767            None,
2768        ])) as ArrayRef;
2769
2770        let expected = BooleanArray::from(vec![
2771            Some(true),
2772            Some(true),
2773            Some(true),
2774            Some(true),
2775            Some(false),
2776            None,
2777            Some(true),
2778            None,
2779            None,
2780        ]);
2781        apply_logic_op(&Arc::new(schema), &a, &b, Operator::Or, expected)?;
2782
2783        Ok(())
2784    }
2785
2786    /// Returns (schema, a: BooleanArray, b: BooleanArray) with all possible inputs
2787    ///
2788    /// a: [true, true, true,  NULL, NULL, NULL,  false, false, false]
2789    /// b: [true, NULL, false, true, NULL, false, true,  NULL,  false]
2790    fn bool_test_arrays() -> (SchemaRef, ArrayRef, ArrayRef) {
2791        let schema = Schema::new(vec![
2792            Field::new("a", DataType::Boolean, true),
2793            Field::new("b", DataType::Boolean, true),
2794        ]);
2795        let a: BooleanArray = [
2796            Some(true),
2797            Some(true),
2798            Some(true),
2799            None,
2800            None,
2801            None,
2802            Some(false),
2803            Some(false),
2804            Some(false),
2805        ]
2806        .iter()
2807        .collect();
2808        let b: BooleanArray = [
2809            Some(true),
2810            None,
2811            Some(false),
2812            Some(true),
2813            None,
2814            Some(false),
2815            Some(true),
2816            None,
2817            Some(false),
2818        ]
2819        .iter()
2820        .collect();
2821        (Arc::new(schema), Arc::new(a), Arc::new(b))
2822    }
2823
2824    /// Returns (schema, BooleanArray) with [true, NULL, false]
2825    fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) {
2826        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
2827        let a: BooleanArray = [Some(true), None, Some(false)].iter().collect();
2828        (Arc::new(schema), Arc::new(a))
2829    }
2830
2831    #[test]
2832    fn eq_op_bool() {
2833        let (schema, a, b) = bool_test_arrays();
2834        let expected = [
2835            Some(true),
2836            None,
2837            Some(false),
2838            None,
2839            None,
2840            None,
2841            Some(false),
2842            None,
2843            Some(true),
2844        ]
2845        .iter()
2846        .collect();
2847        apply_logic_op(&schema, &a, &b, Operator::Eq, expected).unwrap();
2848    }
2849
2850    #[test]
2851    fn eq_op_bool_scalar() {
2852        let (schema, a) = scalar_bool_test_array();
2853        let expected = [Some(true), None, Some(false)].iter().collect();
2854        apply_logic_op_scalar_arr(
2855            &schema,
2856            &ScalarValue::from(true),
2857            &a,
2858            Operator::Eq,
2859            &expected,
2860        )
2861        .unwrap();
2862        apply_logic_op_arr_scalar(
2863            &schema,
2864            &a,
2865            &ScalarValue::from(true),
2866            Operator::Eq,
2867            &expected,
2868        )
2869        .unwrap();
2870
2871        let expected = [Some(false), None, Some(true)].iter().collect();
2872        apply_logic_op_scalar_arr(
2873            &schema,
2874            &ScalarValue::from(false),
2875            &a,
2876            Operator::Eq,
2877            &expected,
2878        )
2879        .unwrap();
2880        apply_logic_op_arr_scalar(
2881            &schema,
2882            &a,
2883            &ScalarValue::from(false),
2884            Operator::Eq,
2885            &expected,
2886        )
2887        .unwrap();
2888    }
2889
2890    #[test]
2891    fn neq_op_bool() {
2892        let (schema, a, b) = bool_test_arrays();
2893        let expected = [
2894            Some(false),
2895            None,
2896            Some(true),
2897            None,
2898            None,
2899            None,
2900            Some(true),
2901            None,
2902            Some(false),
2903        ]
2904        .iter()
2905        .collect();
2906        apply_logic_op(&schema, &a, &b, Operator::NotEq, expected).unwrap();
2907    }
2908
2909    #[test]
2910    fn neq_op_bool_scalar() {
2911        let (schema, a) = scalar_bool_test_array();
2912        let expected = [Some(false), None, Some(true)].iter().collect();
2913        apply_logic_op_scalar_arr(
2914            &schema,
2915            &ScalarValue::from(true),
2916            &a,
2917            Operator::NotEq,
2918            &expected,
2919        )
2920        .unwrap();
2921        apply_logic_op_arr_scalar(
2922            &schema,
2923            &a,
2924            &ScalarValue::from(true),
2925            Operator::NotEq,
2926            &expected,
2927        )
2928        .unwrap();
2929
2930        let expected = [Some(true), None, Some(false)].iter().collect();
2931        apply_logic_op_scalar_arr(
2932            &schema,
2933            &ScalarValue::from(false),
2934            &a,
2935            Operator::NotEq,
2936            &expected,
2937        )
2938        .unwrap();
2939        apply_logic_op_arr_scalar(
2940            &schema,
2941            &a,
2942            &ScalarValue::from(false),
2943            Operator::NotEq,
2944            &expected,
2945        )
2946        .unwrap();
2947    }
2948
2949    #[test]
2950    fn lt_op_bool() {
2951        let (schema, a, b) = bool_test_arrays();
2952        let expected = [
2953            Some(false),
2954            None,
2955            Some(false),
2956            None,
2957            None,
2958            None,
2959            Some(true),
2960            None,
2961            Some(false),
2962        ]
2963        .iter()
2964        .collect();
2965        apply_logic_op(&schema, &a, &b, Operator::Lt, expected).unwrap();
2966    }
2967
2968    #[test]
2969    fn lt_op_bool_scalar() {
2970        let (schema, a) = scalar_bool_test_array();
2971        let expected = [Some(false), None, Some(false)].iter().collect();
2972        apply_logic_op_scalar_arr(
2973            &schema,
2974            &ScalarValue::from(true),
2975            &a,
2976            Operator::Lt,
2977            &expected,
2978        )
2979        .unwrap();
2980
2981        let expected = [Some(false), None, Some(true)].iter().collect();
2982        apply_logic_op_arr_scalar(
2983            &schema,
2984            &a,
2985            &ScalarValue::from(true),
2986            Operator::Lt,
2987            &expected,
2988        )
2989        .unwrap();
2990
2991        let expected = [Some(true), None, Some(false)].iter().collect();
2992        apply_logic_op_scalar_arr(
2993            &schema,
2994            &ScalarValue::from(false),
2995            &a,
2996            Operator::Lt,
2997            &expected,
2998        )
2999        .unwrap();
3000
3001        let expected = [Some(false), None, Some(false)].iter().collect();
3002        apply_logic_op_arr_scalar(
3003            &schema,
3004            &a,
3005            &ScalarValue::from(false),
3006            Operator::Lt,
3007            &expected,
3008        )
3009        .unwrap();
3010    }
3011
3012    #[test]
3013    fn lt_eq_op_bool() {
3014        let (schema, a, b) = bool_test_arrays();
3015        let expected = [
3016            Some(true),
3017            None,
3018            Some(false),
3019            None,
3020            None,
3021            None,
3022            Some(true),
3023            None,
3024            Some(true),
3025        ]
3026        .iter()
3027        .collect();
3028        apply_logic_op(&schema, &a, &b, Operator::LtEq, expected).unwrap();
3029    }
3030
3031    #[test]
3032    fn lt_eq_op_bool_scalar() {
3033        let (schema, a) = scalar_bool_test_array();
3034        let expected = [Some(true), None, Some(false)].iter().collect();
3035        apply_logic_op_scalar_arr(
3036            &schema,
3037            &ScalarValue::from(true),
3038            &a,
3039            Operator::LtEq,
3040            &expected,
3041        )
3042        .unwrap();
3043
3044        let expected = [Some(true), None, Some(true)].iter().collect();
3045        apply_logic_op_arr_scalar(
3046            &schema,
3047            &a,
3048            &ScalarValue::from(true),
3049            Operator::LtEq,
3050            &expected,
3051        )
3052        .unwrap();
3053
3054        let expected = [Some(true), None, Some(true)].iter().collect();
3055        apply_logic_op_scalar_arr(
3056            &schema,
3057            &ScalarValue::from(false),
3058            &a,
3059            Operator::LtEq,
3060            &expected,
3061        )
3062        .unwrap();
3063
3064        let expected = [Some(false), None, Some(true)].iter().collect();
3065        apply_logic_op_arr_scalar(
3066            &schema,
3067            &a,
3068            &ScalarValue::from(false),
3069            Operator::LtEq,
3070            &expected,
3071        )
3072        .unwrap();
3073    }
3074
3075    #[test]
3076    fn gt_op_bool() {
3077        let (schema, a, b) = bool_test_arrays();
3078        let expected = [
3079            Some(false),
3080            None,
3081            Some(true),
3082            None,
3083            None,
3084            None,
3085            Some(false),
3086            None,
3087            Some(false),
3088        ]
3089        .iter()
3090        .collect();
3091        apply_logic_op(&schema, &a, &b, Operator::Gt, expected).unwrap();
3092    }
3093
3094    #[test]
3095    fn gt_op_bool_scalar() {
3096        let (schema, a) = scalar_bool_test_array();
3097        let expected = [Some(false), None, Some(true)].iter().collect();
3098        apply_logic_op_scalar_arr(
3099            &schema,
3100            &ScalarValue::from(true),
3101            &a,
3102            Operator::Gt,
3103            &expected,
3104        )
3105        .unwrap();
3106
3107        let expected = [Some(false), None, Some(false)].iter().collect();
3108        apply_logic_op_arr_scalar(
3109            &schema,
3110            &a,
3111            &ScalarValue::from(true),
3112            Operator::Gt,
3113            &expected,
3114        )
3115        .unwrap();
3116
3117        let expected = [Some(false), None, Some(false)].iter().collect();
3118        apply_logic_op_scalar_arr(
3119            &schema,
3120            &ScalarValue::from(false),
3121            &a,
3122            Operator::Gt,
3123            &expected,
3124        )
3125        .unwrap();
3126
3127        let expected = [Some(true), None, Some(false)].iter().collect();
3128        apply_logic_op_arr_scalar(
3129            &schema,
3130            &a,
3131            &ScalarValue::from(false),
3132            Operator::Gt,
3133            &expected,
3134        )
3135        .unwrap();
3136    }
3137
3138    #[test]
3139    fn gt_eq_op_bool() {
3140        let (schema, a, b) = bool_test_arrays();
3141        let expected = [
3142            Some(true),
3143            None,
3144            Some(true),
3145            None,
3146            None,
3147            None,
3148            Some(false),
3149            None,
3150            Some(true),
3151        ]
3152        .iter()
3153        .collect();
3154        apply_logic_op(&schema, &a, &b, Operator::GtEq, expected).unwrap();
3155    }
3156
3157    #[test]
3158    fn gt_eq_op_bool_scalar() {
3159        let (schema, a) = scalar_bool_test_array();
3160        let expected = [Some(true), None, Some(true)].iter().collect();
3161        apply_logic_op_scalar_arr(
3162            &schema,
3163            &ScalarValue::from(true),
3164            &a,
3165            Operator::GtEq,
3166            &expected,
3167        )
3168        .unwrap();
3169
3170        let expected = [Some(true), None, Some(false)].iter().collect();
3171        apply_logic_op_arr_scalar(
3172            &schema,
3173            &a,
3174            &ScalarValue::from(true),
3175            Operator::GtEq,
3176            &expected,
3177        )
3178        .unwrap();
3179
3180        let expected = [Some(false), None, Some(true)].iter().collect();
3181        apply_logic_op_scalar_arr(
3182            &schema,
3183            &ScalarValue::from(false),
3184            &a,
3185            Operator::GtEq,
3186            &expected,
3187        )
3188        .unwrap();
3189
3190        let expected = [Some(true), None, Some(true)].iter().collect();
3191        apply_logic_op_arr_scalar(
3192            &schema,
3193            &a,
3194            &ScalarValue::from(false),
3195            Operator::GtEq,
3196            &expected,
3197        )
3198        .unwrap();
3199    }
3200
3201    #[test]
3202    fn is_distinct_from_op_bool() {
3203        let (schema, a, b) = bool_test_arrays();
3204        let expected = [
3205            Some(false),
3206            Some(true),
3207            Some(true),
3208            Some(true),
3209            Some(false),
3210            Some(true),
3211            Some(true),
3212            Some(true),
3213            Some(false),
3214        ]
3215        .iter()
3216        .collect();
3217        apply_logic_op(&schema, &a, &b, Operator::IsDistinctFrom, expected).unwrap();
3218    }
3219
3220    #[test]
3221    fn is_not_distinct_from_op_bool() {
3222        let (schema, a, b) = bool_test_arrays();
3223        let expected = [
3224            Some(true),
3225            Some(false),
3226            Some(false),
3227            Some(false),
3228            Some(true),
3229            Some(false),
3230            Some(false),
3231            Some(false),
3232            Some(true),
3233        ]
3234        .iter()
3235        .collect();
3236        apply_logic_op(&schema, &a, &b, Operator::IsNotDistinctFrom, expected).unwrap();
3237    }
3238
3239    #[test]
3240    fn relatively_deeply_nested() {
3241        // Reproducer for https://github.com/apache/datafusion/issues/419
3242
3243        // where even relatively shallow binary expressions overflowed
3244        // the stack in debug builds
3245
3246        let input: Vec<_> = vec![1, 2, 3, 4, 5].into_iter().map(Some).collect();
3247        let a: Int32Array = input.iter().collect();
3248
3249        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(a) as _)]).unwrap();
3250        let schema = batch.schema();
3251
3252        // build a left deep tree ((((a + a) + a) + a ....
3253        let tree_depth: i32 = 100;
3254        let expr = (0..tree_depth)
3255            .map(|_| col("a", schema.as_ref()).unwrap())
3256            .reduce(|l, r| binary(l, Operator::Plus, r, &schema).unwrap())
3257            .unwrap();
3258
3259        let result = expr
3260            .evaluate(&batch)
3261            .expect("evaluation")
3262            .into_array(batch.num_rows())
3263            .expect("Failed to convert to array");
3264
3265        let expected: Int32Array = input
3266            .into_iter()
3267            .map(|i| i.map(|i| i * tree_depth))
3268            .collect();
3269        assert_eq!(result.as_ref(), &expected);
3270    }
3271
3272    fn create_decimal_array(
3273        array: &[Option<i128>],
3274        precision: u8,
3275        scale: i8,
3276    ) -> Decimal128Array {
3277        let mut decimal_builder = Decimal128Builder::with_capacity(array.len());
3278        for value in array.iter().copied() {
3279            decimal_builder.append_option(value)
3280        }
3281        decimal_builder
3282            .finish()
3283            .with_precision_and_scale(precision, scale)
3284            .unwrap()
3285    }
3286
3287    #[test]
3288    fn comparison_dict_decimal_scalar_expr_test() -> Result<()> {
3289        // scalar of decimal compare with dictionary decimal array
3290        let value_i128 = 123;
3291        let decimal_scalar = ScalarValue::Dictionary(
3292            Box::new(DataType::Int8),
3293            Box::new(ScalarValue::Decimal128(Some(value_i128), 25, 3)),
3294        );
3295        let schema = Arc::new(Schema::new(vec![Field::new(
3296            "a",
3297            DataType::Dictionary(
3298                Box::new(DataType::Int8),
3299                Box::new(DataType::Decimal128(25, 3)),
3300            ),
3301            true,
3302        )]));
3303        let decimal_array = Arc::new(create_decimal_array(
3304            &[
3305                Some(value_i128),
3306                None,
3307                Some(value_i128 - 1),
3308                Some(value_i128 + 1),
3309            ],
3310            25,
3311            3,
3312        ));
3313
3314        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
3315        let dictionary =
3316            Arc::new(DictionaryArray::try_new(keys, decimal_array)?) as ArrayRef;
3317
3318        // array = scalar
3319        apply_logic_op_arr_scalar(
3320            &schema,
3321            &dictionary,
3322            &decimal_scalar,
3323            Operator::Eq,
3324            &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3325        )
3326        .unwrap();
3327        // array != scalar
3328        apply_logic_op_arr_scalar(
3329            &schema,
3330            &dictionary,
3331            &decimal_scalar,
3332            Operator::NotEq,
3333            &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3334        )
3335        .unwrap();
3336        //  array < scalar
3337        apply_logic_op_arr_scalar(
3338            &schema,
3339            &dictionary,
3340            &decimal_scalar,
3341            Operator::Lt,
3342            &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3343        )
3344        .unwrap();
3345
3346        //  array <= scalar
3347        apply_logic_op_arr_scalar(
3348            &schema,
3349            &dictionary,
3350            &decimal_scalar,
3351            Operator::LtEq,
3352            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3353        )
3354        .unwrap();
3355        // array > scalar
3356        apply_logic_op_arr_scalar(
3357            &schema,
3358            &dictionary,
3359            &decimal_scalar,
3360            Operator::Gt,
3361            &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3362        )
3363        .unwrap();
3364
3365        // array >= scalar
3366        apply_logic_op_arr_scalar(
3367            &schema,
3368            &dictionary,
3369            &decimal_scalar,
3370            Operator::GtEq,
3371            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3372        )
3373        .unwrap();
3374
3375        Ok(())
3376    }
3377
3378    #[test]
3379    fn comparison_decimal_expr_test() -> Result<()> {
3380        // scalar of decimal compare with decimal array
3381        let value_i128 = 123;
3382        let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3);
3383        let schema = Arc::new(Schema::new(vec![Field::new(
3384            "a",
3385            DataType::Decimal128(25, 3),
3386            true,
3387        )]));
3388        let decimal_array = Arc::new(create_decimal_array(
3389            &[
3390                Some(value_i128),
3391                None,
3392                Some(value_i128 - 1),
3393                Some(value_i128 + 1),
3394            ],
3395            25,
3396            3,
3397        )) as ArrayRef;
3398        // array = scalar
3399        apply_logic_op_arr_scalar(
3400            &schema,
3401            &decimal_array,
3402            &decimal_scalar,
3403            Operator::Eq,
3404            &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3405        )
3406        .unwrap();
3407        // array != scalar
3408        apply_logic_op_arr_scalar(
3409            &schema,
3410            &decimal_array,
3411            &decimal_scalar,
3412            Operator::NotEq,
3413            &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3414        )
3415        .unwrap();
3416        //  array < scalar
3417        apply_logic_op_arr_scalar(
3418            &schema,
3419            &decimal_array,
3420            &decimal_scalar,
3421            Operator::Lt,
3422            &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3423        )
3424        .unwrap();
3425
3426        //  array <= scalar
3427        apply_logic_op_arr_scalar(
3428            &schema,
3429            &decimal_array,
3430            &decimal_scalar,
3431            Operator::LtEq,
3432            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3433        )
3434        .unwrap();
3435        // array > scalar
3436        apply_logic_op_arr_scalar(
3437            &schema,
3438            &decimal_array,
3439            &decimal_scalar,
3440            Operator::Gt,
3441            &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3442        )
3443        .unwrap();
3444
3445        // array >= scalar
3446        apply_logic_op_arr_scalar(
3447            &schema,
3448            &decimal_array,
3449            &decimal_scalar,
3450            Operator::GtEq,
3451            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3452        )
3453        .unwrap();
3454
3455        // scalar of different data type with decimal array
3456        let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
3457        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
3458        // scalar == array
3459        apply_logic_op_scalar_arr(
3460            &schema,
3461            &decimal_scalar,
3462            &(Arc::new(Int64Array::from(vec![Some(124), None])) as ArrayRef),
3463            Operator::Eq,
3464            &BooleanArray::from(vec![Some(false), None]),
3465        )
3466        .unwrap();
3467
3468        // array != scalar
3469        apply_logic_op_arr_scalar(
3470            &schema,
3471            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(1)])) as ArrayRef),
3472            &decimal_scalar,
3473            Operator::NotEq,
3474            &BooleanArray::from(vec![Some(true), None, Some(true)]),
3475        )
3476        .unwrap();
3477
3478        // array < scalar
3479        apply_logic_op_arr_scalar(
3480            &schema,
3481            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3482            &decimal_scalar,
3483            Operator::Lt,
3484            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3485        )
3486        .unwrap();
3487
3488        // array > scalar
3489        apply_logic_op_arr_scalar(
3490            &schema,
3491            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3492            &decimal_scalar,
3493            Operator::Gt,
3494            &BooleanArray::from(vec![Some(false), None, Some(true)]),
3495        )
3496        .unwrap();
3497
3498        let schema =
3499            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
3500        // array == scalar
3501        apply_logic_op_arr_scalar(
3502            &schema,
3503            &(Arc::new(Float64Array::from(vec![Some(123.456), None, Some(123.457)]))
3504                as ArrayRef),
3505            &decimal_scalar,
3506            Operator::Eq,
3507            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3508        )
3509        .unwrap();
3510
3511        // array <= scalar
3512        apply_logic_op_arr_scalar(
3513            &schema,
3514            &(Arc::new(Float64Array::from(vec![
3515                Some(123.456),
3516                None,
3517                Some(123.457),
3518                Some(123.45),
3519            ])) as ArrayRef),
3520            &decimal_scalar,
3521            Operator::LtEq,
3522            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3523        )
3524        .unwrap();
3525        // array >= scalar
3526        apply_logic_op_arr_scalar(
3527            &schema,
3528            &(Arc::new(Float64Array::from(vec![
3529                Some(123.456),
3530                None,
3531                Some(123.457),
3532                Some(123.45),
3533            ])) as ArrayRef),
3534            &decimal_scalar,
3535            Operator::GtEq,
3536            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3537        )
3538        .unwrap();
3539
3540        let value: i128 = 123;
3541        let decimal_array = Arc::new(create_decimal_array(
3542            &[Some(value), None, Some(value - 1), Some(value + 1)],
3543            10,
3544            0,
3545        )) as ArrayRef;
3546
3547        // comparison array op for decimal array
3548        let schema = Arc::new(Schema::new(vec![
3549            Field::new("a", DataType::Decimal128(10, 0), true),
3550            Field::new("b", DataType::Decimal128(10, 0), true),
3551        ]));
3552        let right_decimal_array = Arc::new(create_decimal_array(
3553            &[
3554                Some(value - 1),
3555                Some(value),
3556                Some(value + 1),
3557                Some(value + 1),
3558            ],
3559            10,
3560            0,
3561        )) as ArrayRef;
3562
3563        apply_logic_op(
3564            &schema,
3565            &decimal_array,
3566            &right_decimal_array,
3567            Operator::Eq,
3568            BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3569        )
3570        .unwrap();
3571
3572        apply_logic_op(
3573            &schema,
3574            &decimal_array,
3575            &right_decimal_array,
3576            Operator::NotEq,
3577            BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3578        )
3579        .unwrap();
3580
3581        apply_logic_op(
3582            &schema,
3583            &decimal_array,
3584            &right_decimal_array,
3585            Operator::Lt,
3586            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3587        )
3588        .unwrap();
3589
3590        apply_logic_op(
3591            &schema,
3592            &decimal_array,
3593            &right_decimal_array,
3594            Operator::LtEq,
3595            BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3596        )
3597        .unwrap();
3598
3599        apply_logic_op(
3600            &schema,
3601            &decimal_array,
3602            &right_decimal_array,
3603            Operator::Gt,
3604            BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3605        )
3606        .unwrap();
3607
3608        apply_logic_op(
3609            &schema,
3610            &decimal_array,
3611            &right_decimal_array,
3612            Operator::GtEq,
3613            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3614        )
3615        .unwrap();
3616
3617        // compare decimal array with other array type
3618        let value: i64 = 123;
3619        let schema = Arc::new(Schema::new(vec![
3620            Field::new("a", DataType::Int64, true),
3621            Field::new("b", DataType::Decimal128(10, 0), true),
3622        ]));
3623
3624        let int64_array = Arc::new(Int64Array::from(vec![
3625            Some(value),
3626            Some(value - 1),
3627            Some(value),
3628            Some(value + 1),
3629        ])) as ArrayRef;
3630
3631        // eq: int64array == decimal array
3632        apply_logic_op(
3633            &schema,
3634            &int64_array,
3635            &decimal_array,
3636            Operator::Eq,
3637            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3638        )
3639        .unwrap();
3640        // neq: int64array != decimal array
3641        apply_logic_op(
3642            &schema,
3643            &int64_array,
3644            &decimal_array,
3645            Operator::NotEq,
3646            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3647        )
3648        .unwrap();
3649
3650        let schema = Arc::new(Schema::new(vec![
3651            Field::new("a", DataType::Float64, true),
3652            Field::new("b", DataType::Decimal128(10, 2), true),
3653        ]));
3654
3655        let value: i128 = 123;
3656        let decimal_array = Arc::new(create_decimal_array(
3657            &[
3658                Some(value), // 1.23
3659                None,
3660                Some(value - 1), // 1.22
3661                Some(value + 1), // 1.24
3662            ],
3663            10,
3664            2,
3665        )) as ArrayRef;
3666        let float64_array = Arc::new(Float64Array::from(vec![
3667            Some(1.23),
3668            Some(1.22),
3669            Some(1.23),
3670            Some(1.24),
3671        ])) as ArrayRef;
3672        // lt: float64array < decimal array
3673        apply_logic_op(
3674            &schema,
3675            &float64_array,
3676            &decimal_array,
3677            Operator::Lt,
3678            BooleanArray::from(vec![Some(false), None, Some(false), Some(false)]),
3679        )
3680        .unwrap();
3681        // lt_eq: float64array <= decimal array
3682        apply_logic_op(
3683            &schema,
3684            &float64_array,
3685            &decimal_array,
3686            Operator::LtEq,
3687            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3688        )
3689        .unwrap();
3690        // gt: float64array > decimal array
3691        apply_logic_op(
3692            &schema,
3693            &float64_array,
3694            &decimal_array,
3695            Operator::Gt,
3696            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3697        )
3698        .unwrap();
3699        apply_logic_op(
3700            &schema,
3701            &float64_array,
3702            &decimal_array,
3703            Operator::GtEq,
3704            BooleanArray::from(vec![Some(true), None, Some(true), Some(true)]),
3705        )
3706        .unwrap();
3707        // is distinct: float64array is distinct decimal array
3708        // TODO: now we do not refactor the `is distinct or is not distinct` rule of coercion.
3709        // traced by https://github.com/apache/datafusion/issues/1590
3710        // the decimal array will be casted to float64array
3711        apply_logic_op(
3712            &schema,
3713            &float64_array,
3714            &decimal_array,
3715            Operator::IsDistinctFrom,
3716            BooleanArray::from(vec![Some(false), Some(true), Some(true), Some(false)]),
3717        )
3718        .unwrap();
3719        // is not distinct
3720        apply_logic_op(
3721            &schema,
3722            &float64_array,
3723            &decimal_array,
3724            Operator::IsNotDistinctFrom,
3725            BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]),
3726        )
3727        .unwrap();
3728
3729        Ok(())
3730    }
3731
3732    fn apply_decimal_arithmetic_op(
3733        schema: &SchemaRef,
3734        left: &ArrayRef,
3735        right: &ArrayRef,
3736        op: Operator,
3737        expected: ArrayRef,
3738    ) -> Result<()> {
3739        let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
3740        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
3741        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
3742        let result = arithmetic_op
3743            .evaluate(&batch)?
3744            .into_array(batch.num_rows())
3745            .expect("Failed to convert to array");
3746
3747        assert_eq!(result.as_ref(), expected.as_ref());
3748        Ok(())
3749    }
3750
3751    #[test]
3752    fn arithmetic_decimal_expr_test() -> Result<()> {
3753        let schema = Arc::new(Schema::new(vec![
3754            Field::new("a", DataType::Int32, true),
3755            Field::new("b", DataType::Decimal128(10, 2), true),
3756        ]));
3757        let value: i128 = 123;
3758        let decimal_array = Arc::new(create_decimal_array(
3759            &[
3760                Some(value), // 1.23
3761                None,
3762                Some(value - 1), // 1.22
3763                Some(value + 1), // 1.24
3764            ],
3765            10,
3766            2,
3767        )) as ArrayRef;
3768        let int32_array = Arc::new(Int32Array::from(vec![
3769            Some(123),
3770            Some(122),
3771            Some(123),
3772            Some(124),
3773        ])) as ArrayRef;
3774
3775        // add: Int32array add decimal array
3776        let expect = Arc::new(create_decimal_array(
3777            &[Some(12423), None, Some(12422), Some(12524)],
3778            13,
3779            2,
3780        )) as ArrayRef;
3781        apply_decimal_arithmetic_op(
3782            &schema,
3783            &int32_array,
3784            &decimal_array,
3785            Operator::Plus,
3786            expect,
3787        )
3788        .unwrap();
3789
3790        // subtract: decimal array subtract int32 array
3791        let schema = Arc::new(Schema::new(vec![
3792            Field::new("a", DataType::Decimal128(10, 2), true),
3793            Field::new("b", DataType::Int32, true),
3794        ]));
3795        let expect = Arc::new(create_decimal_array(
3796            &[Some(-12177), None, Some(-12178), Some(-12276)],
3797            13,
3798            2,
3799        )) as ArrayRef;
3800        apply_decimal_arithmetic_op(
3801            &schema,
3802            &decimal_array,
3803            &int32_array,
3804            Operator::Minus,
3805            expect,
3806        )
3807        .unwrap();
3808
3809        // multiply: decimal array multiply int32 array
3810        let expect = Arc::new(create_decimal_array(
3811            &[Some(15129), None, Some(15006), Some(15376)],
3812            21,
3813            2,
3814        )) as ArrayRef;
3815        apply_decimal_arithmetic_op(
3816            &schema,
3817            &decimal_array,
3818            &int32_array,
3819            Operator::Multiply,
3820            expect,
3821        )
3822        .unwrap();
3823
3824        // divide: int32 array divide decimal array
3825        let schema = Arc::new(Schema::new(vec![
3826            Field::new("a", DataType::Int32, true),
3827            Field::new("b", DataType::Decimal128(10, 2), true),
3828        ]));
3829        let expect = Arc::new(create_decimal_array(
3830            &[Some(1000000), None, Some(1008196), Some(1000000)],
3831            16,
3832            4,
3833        )) as ArrayRef;
3834        apply_decimal_arithmetic_op(
3835            &schema,
3836            &int32_array,
3837            &decimal_array,
3838            Operator::Divide,
3839            expect,
3840        )
3841        .unwrap();
3842
3843        // modulus: int32 array modulus decimal array
3844        let schema = Arc::new(Schema::new(vec![
3845            Field::new("a", DataType::Int32, true),
3846            Field::new("b", DataType::Decimal128(10, 2), true),
3847        ]));
3848        let expect = Arc::new(create_decimal_array(
3849            &[Some(000), None, Some(100), Some(000)],
3850            10,
3851            2,
3852        )) as ArrayRef;
3853        apply_decimal_arithmetic_op(
3854            &schema,
3855            &int32_array,
3856            &decimal_array,
3857            Operator::Modulo,
3858            expect,
3859        )
3860        .unwrap();
3861
3862        Ok(())
3863    }
3864
3865    #[test]
3866    fn arithmetic_decimal_float_expr_test() -> Result<()> {
3867        let schema = Arc::new(Schema::new(vec![
3868            Field::new("a", DataType::Float64, true),
3869            Field::new("b", DataType::Decimal128(10, 2), true),
3870        ]));
3871        let value: i128 = 123;
3872        let decimal_array = Arc::new(create_decimal_array(
3873            &[
3874                Some(value), // 1.23
3875                None,
3876                Some(value - 1), // 1.22
3877                Some(value + 1), // 1.24
3878            ],
3879            10,
3880            2,
3881        )) as ArrayRef;
3882        let float64_array = Arc::new(Float64Array::from(vec![
3883            Some(123.0),
3884            Some(122.0),
3885            Some(123.0),
3886            Some(124.0),
3887        ])) as ArrayRef;
3888
3889        // add: float64 array add decimal array
3890        let expect = Arc::new(Float64Array::from(vec![
3891            Some(124.23),
3892            None,
3893            Some(124.22),
3894            Some(125.24),
3895        ])) as ArrayRef;
3896        apply_decimal_arithmetic_op(
3897            &schema,
3898            &float64_array,
3899            &decimal_array,
3900            Operator::Plus,
3901            expect,
3902        )
3903        .unwrap();
3904
3905        // subtract: decimal array subtract float64 array
3906        let schema = Arc::new(Schema::new(vec![
3907            Field::new("a", DataType::Float64, true),
3908            Field::new("b", DataType::Decimal128(10, 2), true),
3909        ]));
3910        let expect = Arc::new(Float64Array::from(vec![
3911            Some(121.77),
3912            None,
3913            Some(121.78),
3914            Some(122.76),
3915        ])) as ArrayRef;
3916        apply_decimal_arithmetic_op(
3917            &schema,
3918            &float64_array,
3919            &decimal_array,
3920            Operator::Minus,
3921            expect,
3922        )
3923        .unwrap();
3924
3925        // multiply: decimal array multiply float64 array
3926        let expect = Arc::new(Float64Array::from(vec![
3927            Some(151.29),
3928            None,
3929            Some(150.06),
3930            Some(153.76),
3931        ])) as ArrayRef;
3932        apply_decimal_arithmetic_op(
3933            &schema,
3934            &float64_array,
3935            &decimal_array,
3936            Operator::Multiply,
3937            expect,
3938        )
3939        .unwrap();
3940
3941        // divide: float64 array divide decimal array
3942        let schema = Arc::new(Schema::new(vec![
3943            Field::new("a", DataType::Float64, true),
3944            Field::new("b", DataType::Decimal128(10, 2), true),
3945        ]));
3946        let expect = Arc::new(Float64Array::from(vec![
3947            Some(100.0),
3948            None,
3949            Some(100.81967213114754),
3950            Some(100.0),
3951        ])) as ArrayRef;
3952        apply_decimal_arithmetic_op(
3953            &schema,
3954            &float64_array,
3955            &decimal_array,
3956            Operator::Divide,
3957            expect,
3958        )
3959        .unwrap();
3960
3961        // modulus: float64 array modulus decimal array
3962        let schema = Arc::new(Schema::new(vec![
3963            Field::new("a", DataType::Float64, true),
3964            Field::new("b", DataType::Decimal128(10, 2), true),
3965        ]));
3966        let expect = Arc::new(Float64Array::from(vec![
3967            Some(1.7763568394002505e-15),
3968            None,
3969            Some(1.0000000000000027),
3970            Some(8.881784197001252e-16),
3971        ])) as ArrayRef;
3972        apply_decimal_arithmetic_op(
3973            &schema,
3974            &float64_array,
3975            &decimal_array,
3976            Operator::Modulo,
3977            expect,
3978        )
3979        .unwrap();
3980
3981        Ok(())
3982    }
3983
3984    #[test]
3985    fn arithmetic_divide_zero() -> Result<()> {
3986        // other data type
3987        let schema = Arc::new(Schema::new(vec![
3988            Field::new("a", DataType::Int32, true),
3989            Field::new("b", DataType::Int32, true),
3990        ]));
3991        let a = Arc::new(Int32Array::from(vec![100]));
3992        let b = Arc::new(Int32Array::from(vec![0]));
3993
3994        let err = apply_arithmetic::<Int32Type>(
3995            schema,
3996            vec![a, b],
3997            Operator::Divide,
3998            Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]),
3999        )
4000        .unwrap_err();
4001
4002        let _expected = plan_datafusion_err!("Divide by zero");
4003
4004        assert!(matches!(err, ref _expected), "{err}");
4005
4006        // decimal
4007        let schema = Arc::new(Schema::new(vec![
4008            Field::new("a", DataType::Decimal128(25, 3), true),
4009            Field::new("b", DataType::Decimal128(25, 3), true),
4010        ]));
4011        let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3));
4012        let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3));
4013
4014        let err = apply_arithmetic::<Decimal128Type>(
4015            schema,
4016            vec![left_decimal_array, right_decimal_array],
4017            Operator::Divide,
4018            create_decimal_array(
4019                &[Some(12345670000000000000000000000000000), None],
4020                38,
4021                29,
4022            ),
4023        )
4024        .unwrap_err();
4025
4026        assert!(matches!(err, ref _expected), "{err}");
4027
4028        Ok(())
4029    }
4030
4031    #[test]
4032    fn bitwise_array_test() -> Result<()> {
4033        let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4034        let right =
4035            Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
4036        let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?;
4037        let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
4038        assert_eq!(result.as_ref(), &expected);
4039
4040        result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?;
4041        let expected = Int32Array::from(vec![Some(13), None, Some(15)]);
4042        assert_eq!(result.as_ref(), &expected);
4043
4044        result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?;
4045        let expected = Int32Array::from(vec![Some(13), None, Some(12)]);
4046        assert_eq!(result.as_ref(), &expected);
4047
4048        let left =
4049            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4050        let right =
4051            Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
4052        let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?;
4053        let expected = UInt32Array::from(vec![Some(0), None, Some(3)]);
4054        assert_eq!(result.as_ref(), &expected);
4055
4056        result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?;
4057        let expected = UInt32Array::from(vec![Some(13), None, Some(15)]);
4058        assert_eq!(result.as_ref(), &expected);
4059
4060        result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?;
4061        let expected = UInt32Array::from(vec![Some(13), None, Some(12)]);
4062        assert_eq!(result.as_ref(), &expected);
4063
4064        Ok(())
4065    }
4066
4067    #[test]
4068    fn bitwise_shift_array_test() -> Result<()> {
4069        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
4070        let modules =
4071            Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;
4072        let mut result =
4073            bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4074
4075        let expected = Int32Array::from(vec![Some(8), None, Some(2560)]);
4076        assert_eq!(result.as_ref(), &expected);
4077
4078        result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?;
4079        assert_eq!(result.as_ref(), &input);
4080
4081        let input =
4082            Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
4083        let modules =
4084            Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;
4085        let mut result =
4086            bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4087
4088        let expected = UInt32Array::from(vec![Some(8), None, Some(2560)]);
4089        assert_eq!(result.as_ref(), &expected);
4090
4091        result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?;
4092        assert_eq!(result.as_ref(), &input);
4093        Ok(())
4094    }
4095
4096    #[test]
4097    fn bitwise_shift_array_overflow_test() -> Result<()> {
4098        let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef;
4099        let modules = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef;
4100        let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4101
4102        let expected = Int32Array::from(vec![Some(32)]);
4103        assert_eq!(result.as_ref(), &expected);
4104
4105        let input = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef;
4106        let modules = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef;
4107        let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4108
4109        let expected = UInt32Array::from(vec![Some(32)]);
4110        assert_eq!(result.as_ref(), &expected);
4111        Ok(())
4112    }
4113
4114    #[test]
4115    fn bitwise_scalar_test() -> Result<()> {
4116        let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4117        let right = ScalarValue::from(3i32);
4118        let mut result = bitwise_and_dyn_scalar(&left, right.clone()).unwrap()?;
4119        let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
4120        assert_eq!(result.as_ref(), &expected);
4121
4122        result = bitwise_or_dyn_scalar(&left, right.clone()).unwrap()?;
4123        let expected = Int32Array::from(vec![Some(15), None, Some(11)]);
4124        assert_eq!(result.as_ref(), &expected);
4125
4126        result = bitwise_xor_dyn_scalar(&left, right).unwrap()?;
4127        let expected = Int32Array::from(vec![Some(15), None, Some(8)]);
4128        assert_eq!(result.as_ref(), &expected);
4129
4130        let left =
4131            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4132        let right = ScalarValue::from(3u32);
4133        let mut result = bitwise_and_dyn_scalar(&left, right.clone()).unwrap()?;
4134        let expected = UInt32Array::from(vec![Some(0), None, Some(3)]);
4135        assert_eq!(result.as_ref(), &expected);
4136
4137        result = bitwise_or_dyn_scalar(&left, right.clone()).unwrap()?;
4138        let expected = UInt32Array::from(vec![Some(15), None, Some(11)]);
4139        assert_eq!(result.as_ref(), &expected);
4140
4141        result = bitwise_xor_dyn_scalar(&left, right).unwrap()?;
4142        let expected = UInt32Array::from(vec![Some(15), None, Some(8)]);
4143        assert_eq!(result.as_ref(), &expected);
4144        Ok(())
4145    }
4146
4147    #[test]
4148    fn bitwise_shift_scalar_test() -> Result<()> {
4149        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4150        let module = ScalarValue::from(10i32);
4151        let mut result =
4152            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4153
4154        let expected = Int32Array::from(vec![Some(2048), None, Some(4096)]);
4155        assert_eq!(result.as_ref(), &expected);
4156
4157        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4158        assert_eq!(result.as_ref(), &input);
4159
4160        let input = Arc::new(UInt32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4161        let module = ScalarValue::from(10u32);
4162        let mut result =
4163            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4164
4165        let expected = UInt32Array::from(vec![Some(2048), None, Some(4096)]);
4166        assert_eq!(result.as_ref(), &expected);
4167
4168        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4169        assert_eq!(result.as_ref(), &input);
4170        Ok(())
4171    }
4172
4173    #[test]
4174    fn test_display_and_or_combo() {
4175        let expr = BinaryExpr::new(
4176            Arc::new(BinaryExpr::new(
4177                lit(ScalarValue::from(1)),
4178                Operator::And,
4179                lit(ScalarValue::from(2)),
4180            )),
4181            Operator::And,
4182            Arc::new(BinaryExpr::new(
4183                lit(ScalarValue::from(3)),
4184                Operator::And,
4185                lit(ScalarValue::from(4)),
4186            )),
4187        );
4188        assert_eq!(expr.to_string(), "1 AND 2 AND 3 AND 4");
4189
4190        let expr = BinaryExpr::new(
4191            Arc::new(BinaryExpr::new(
4192                lit(ScalarValue::from(1)),
4193                Operator::Or,
4194                lit(ScalarValue::from(2)),
4195            )),
4196            Operator::Or,
4197            Arc::new(BinaryExpr::new(
4198                lit(ScalarValue::from(3)),
4199                Operator::Or,
4200                lit(ScalarValue::from(4)),
4201            )),
4202        );
4203        assert_eq!(expr.to_string(), "1 OR 2 OR 3 OR 4");
4204
4205        let expr = BinaryExpr::new(
4206            Arc::new(BinaryExpr::new(
4207                lit(ScalarValue::from(1)),
4208                Operator::And,
4209                lit(ScalarValue::from(2)),
4210            )),
4211            Operator::Or,
4212            Arc::new(BinaryExpr::new(
4213                lit(ScalarValue::from(3)),
4214                Operator::And,
4215                lit(ScalarValue::from(4)),
4216            )),
4217        );
4218        assert_eq!(expr.to_string(), "1 AND 2 OR 3 AND 4");
4219
4220        let expr = BinaryExpr::new(
4221            Arc::new(BinaryExpr::new(
4222                lit(ScalarValue::from(1)),
4223                Operator::Or,
4224                lit(ScalarValue::from(2)),
4225            )),
4226            Operator::And,
4227            Arc::new(BinaryExpr::new(
4228                lit(ScalarValue::from(3)),
4229                Operator::Or,
4230                lit(ScalarValue::from(4)),
4231            )),
4232        );
4233        assert_eq!(expr.to_string(), "(1 OR 2) AND (3 OR 4)");
4234    }
4235
4236    #[test]
4237    fn test_to_result_type_array() {
4238        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
4239        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
4240        let dictionary =
4241            Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef;
4242
4243        // Casting Dictionary to Int32
4244        let casted = to_result_type_array(
4245            &Operator::Plus,
4246            Arc::clone(&dictionary),
4247            &DataType::Int32,
4248        )
4249        .unwrap();
4250        assert_eq!(
4251            &casted,
4252            &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)]))
4253                as ArrayRef)
4254        );
4255
4256        // Array has same datatype as result type, no casting
4257        let casted = to_result_type_array(
4258            &Operator::Plus,
4259            Arc::clone(&dictionary),
4260            dictionary.data_type(),
4261        )
4262        .unwrap();
4263        assert_eq!(&casted, &dictionary);
4264
4265        // Not numerical operator, no casting
4266        let casted = to_result_type_array(
4267            &Operator::Eq,
4268            Arc::clone(&dictionary),
4269            &DataType::Int32,
4270        )
4271        .unwrap();
4272        assert_eq!(&casted, &dictionary);
4273    }
4274
4275    #[test]
4276    fn test_add_with_overflow() -> Result<()> {
4277        // create test data
4278        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4279        let r = Arc::new(Int32Array::from(vec![2, 1]));
4280        let schema = Arc::new(Schema::new(vec![
4281            Field::new("l", DataType::Int32, false),
4282            Field::new("r", DataType::Int32, false),
4283        ]));
4284        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4285
4286        // create expression
4287        let expr = BinaryExpr::new(
4288            Arc::new(Column::new("l", 0)),
4289            Operator::Plus,
4290            Arc::new(Column::new("r", 1)),
4291        )
4292        .with_fail_on_overflow(true);
4293
4294        // evaluate expression
4295        let result = expr.evaluate(&batch);
4296        assert!(result
4297            .err()
4298            .unwrap()
4299            .to_string()
4300            .contains("Overflow happened on: 2147483647 + 1"));
4301        Ok(())
4302    }
4303
4304    #[test]
4305    fn test_subtract_with_overflow() -> Result<()> {
4306        // create test data
4307        let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
4308        let r = Arc::new(Int32Array::from(vec![2, 1]));
4309        let schema = Arc::new(Schema::new(vec![
4310            Field::new("l", DataType::Int32, false),
4311            Field::new("r", DataType::Int32, false),
4312        ]));
4313        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4314
4315        // create expression
4316        let expr = BinaryExpr::new(
4317            Arc::new(Column::new("l", 0)),
4318            Operator::Minus,
4319            Arc::new(Column::new("r", 1)),
4320        )
4321        .with_fail_on_overflow(true);
4322
4323        // evaluate expression
4324        let result = expr.evaluate(&batch);
4325        assert!(result
4326            .err()
4327            .unwrap()
4328            .to_string()
4329            .contains("Overflow happened on: -2147483648 - 1"));
4330        Ok(())
4331    }
4332
4333    #[test]
4334    fn test_mul_with_overflow() -> Result<()> {
4335        // create test data
4336        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4337        let r = Arc::new(Int32Array::from(vec![2, 2]));
4338        let schema = Arc::new(Schema::new(vec![
4339            Field::new("l", DataType::Int32, false),
4340            Field::new("r", DataType::Int32, false),
4341        ]));
4342        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4343
4344        // create expression
4345        let expr = BinaryExpr::new(
4346            Arc::new(Column::new("l", 0)),
4347            Operator::Multiply,
4348            Arc::new(Column::new("r", 1)),
4349        )
4350        .with_fail_on_overflow(true);
4351
4352        // evaluate expression
4353        let result = expr.evaluate(&batch);
4354        assert!(result
4355            .err()
4356            .unwrap()
4357            .to_string()
4358            .contains("Overflow happened on: 2147483647 * 2"));
4359        Ok(())
4360    }
4361
4362    /// Test helper for SIMILAR TO binary operation
4363    fn apply_similar_to(
4364        schema: &SchemaRef,
4365        va: Vec<&str>,
4366        vb: Vec<&str>,
4367        negated: bool,
4368        case_insensitive: bool,
4369        expected: &BooleanArray,
4370    ) -> Result<()> {
4371        let a = StringArray::from(va);
4372        let b = StringArray::from(vb);
4373        let op = similar_to(
4374            negated,
4375            case_insensitive,
4376            col("a", schema)?,
4377            col("b", schema)?,
4378        )?;
4379        let batch =
4380            RecordBatch::try_new(Arc::clone(schema), vec![Arc::new(a), Arc::new(b)])?;
4381        let result = op
4382            .evaluate(&batch)?
4383            .into_array(batch.num_rows())
4384            .expect("Failed to convert to array");
4385        assert_eq!(result.as_ref(), expected);
4386
4387        Ok(())
4388    }
4389
4390    #[test]
4391    fn test_similar_to() {
4392        let schema = Arc::new(Schema::new(vec![
4393            Field::new("a", DataType::Utf8, false),
4394            Field::new("b", DataType::Utf8, false),
4395        ]));
4396
4397        let expected = [Some(true), Some(false)].iter().collect();
4398        // case-sensitive
4399        apply_similar_to(
4400            &schema,
4401            vec!["hello world", "Hello World"],
4402            vec!["hello.*", "hello.*"],
4403            false,
4404            false,
4405            &expected,
4406        )
4407        .unwrap();
4408        // case-insensitive
4409        apply_similar_to(
4410            &schema,
4411            vec!["hello world", "bye"],
4412            vec!["hello.*", "hello.*"],
4413            false,
4414            true,
4415            &expected,
4416        )
4417        .unwrap();
4418    }
4419
4420    pub fn binary_expr(
4421        left: Arc<dyn PhysicalExpr>,
4422        op: Operator,
4423        right: Arc<dyn PhysicalExpr>,
4424        schema: &Schema,
4425    ) -> Result<BinaryExpr> {
4426        Ok(binary_op(left, op, right, schema)?
4427            .as_any()
4428            .downcast_ref::<BinaryExpr>()
4429            .unwrap()
4430            .clone())
4431    }
4432
4433    /// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation.
4434    #[test]
4435    fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> {
4436        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4437        let a = Arc::new(Column::new("a", 0)) as _;
4438        let b = lit(ScalarValue::from(12.0));
4439
4440        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4441        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4442        let (left_mean, right_mean) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4443        let (left_med, right_med) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4444
4445        for children in [
4446            vec![
4447                &Distribution::new_uniform(left_interval.clone())?,
4448                &Distribution::new_uniform(right_interval.clone())?,
4449            ],
4450            vec![
4451                &Distribution::new_generic(
4452                    left_mean.clone(),
4453                    left_med.clone(),
4454                    ScalarValue::Float64(None),
4455                    left_interval.clone(),
4456                )?,
4457                &Distribution::new_uniform(right_interval.clone())?,
4458            ],
4459            vec![
4460                &Distribution::new_uniform(right_interval.clone())?,
4461                &Distribution::new_generic(
4462                    right_mean.clone(),
4463                    right_med.clone(),
4464                    ScalarValue::Float64(None),
4465                    right_interval.clone(),
4466                )?,
4467            ],
4468            vec![
4469                &Distribution::new_generic(
4470                    left_mean.clone(),
4471                    left_med.clone(),
4472                    ScalarValue::Float64(None),
4473                    left_interval.clone(),
4474                )?,
4475                &Distribution::new_generic(
4476                    right_mean.clone(),
4477                    right_med.clone(),
4478                    ScalarValue::Float64(None),
4479                    right_interval.clone(),
4480                )?,
4481            ],
4482        ] {
4483            let ops = vec![
4484                Operator::Plus,
4485                Operator::Minus,
4486                Operator::Multiply,
4487                Operator::Divide,
4488            ];
4489
4490            for op in ops {
4491                let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
4492                assert_eq!(
4493                    expr.evaluate_statistics(&children)?,
4494                    new_generic_from_binary_op(&op, children[0], children[1])?
4495                );
4496            }
4497        }
4498        Ok(())
4499    }
4500
4501    #[test]
4502    fn test_evaluate_statistics_bernoulli() -> Result<()> {
4503        let schema = &Schema::new(vec![
4504            Field::new("a", DataType::Int64, false),
4505            Field::new("b", DataType::Int64, false),
4506        ]);
4507        let a = Arc::new(Column::new("a", 0)) as _;
4508        let b = Arc::new(Column::new("b", 1)) as _;
4509        let eq = Arc::new(binary_expr(
4510            Arc::clone(&a),
4511            Operator::Eq,
4512            Arc::clone(&b),
4513            schema,
4514        )?);
4515        let neq = Arc::new(binary_expr(a, Operator::NotEq, b, schema)?);
4516
4517        let left_stat = &Distribution::new_uniform(Interval::make(Some(0), Some(7))?)?;
4518        let right_stat = &Distribution::new_uniform(Interval::make(Some(4), Some(11))?)?;
4519
4520        // Intervals: [0, 7], [4, 11].
4521        // The intersection is [4, 7], so the probability of equality is 4 / 64 = 1 / 16.
4522        assert_eq!(
4523            eq.evaluate_statistics(&[left_stat, right_stat])?,
4524            Distribution::new_bernoulli(ScalarValue::from(1.0 / 16.0))?
4525        );
4526
4527        // The probability of being distinct is 1 - 1 / 16 = 15 / 16.
4528        assert_eq!(
4529            neq.evaluate_statistics(&[left_stat, right_stat])?,
4530            Distribution::new_bernoulli(ScalarValue::from(15.0 / 16.0))?
4531        );
4532
4533        Ok(())
4534    }
4535
4536    #[test]
4537    fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> {
4538        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4539        let a = Arc::new(Column::new("a", 0)) as _;
4540        let b = lit(ScalarValue::from(12.0));
4541
4542        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4543        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4544
4545        let parent = Distribution::new_uniform(Interval::make(Some(-432.), Some(432.))?)?;
4546        let children = vec![
4547            vec![
4548                Distribution::new_uniform(left_interval.clone())?,
4549                Distribution::new_uniform(right_interval.clone())?,
4550            ],
4551            vec![
4552                Distribution::new_generic(
4553                    ScalarValue::from(6.),
4554                    ScalarValue::from(6.),
4555                    ScalarValue::Float64(None),
4556                    left_interval.clone(),
4557                )?,
4558                Distribution::new_uniform(right_interval.clone())?,
4559            ],
4560            vec![
4561                Distribution::new_uniform(left_interval.clone())?,
4562                Distribution::new_generic(
4563                    ScalarValue::from(12.),
4564                    ScalarValue::from(12.),
4565                    ScalarValue::Float64(None),
4566                    right_interval.clone(),
4567                )?,
4568            ],
4569            vec![
4570                Distribution::new_generic(
4571                    ScalarValue::from(6.),
4572                    ScalarValue::from(6.),
4573                    ScalarValue::Float64(None),
4574                    left_interval.clone(),
4575                )?,
4576                Distribution::new_generic(
4577                    ScalarValue::from(12.),
4578                    ScalarValue::from(12.),
4579                    ScalarValue::Float64(None),
4580                    right_interval.clone(),
4581                )?,
4582            ],
4583        ];
4584
4585        let ops = vec![
4586            Operator::Plus,
4587            Operator::Minus,
4588            Operator::Multiply,
4589            Operator::Divide,
4590        ];
4591
4592        for child_view in children {
4593            let child_refs = child_view.iter().collect::<Vec<_>>();
4594            for op in &ops {
4595                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4596                assert_eq!(
4597                    expr.propagate_statistics(&parent, child_refs.as_slice())?,
4598                    Some(child_view.clone())
4599                );
4600            }
4601        }
4602        Ok(())
4603    }
4604
4605    #[test]
4606    fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> {
4607        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4608        let a = Arc::new(Column::new("a", 0)) as _;
4609        let b = lit(ScalarValue::from(12.0));
4610
4611        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4612        let right_interval = Interval::make(Some(6.0), Some(18.0))?;
4613
4614        let one = ScalarValue::from(1.0);
4615        let parent = Distribution::new_bernoulli(one)?;
4616        let children = vec![
4617            vec![
4618                Distribution::new_uniform(left_interval.clone())?,
4619                Distribution::new_uniform(right_interval.clone())?,
4620            ],
4621            vec![
4622                Distribution::new_generic(
4623                    ScalarValue::from(6.),
4624                    ScalarValue::from(6.),
4625                    ScalarValue::Float64(None),
4626                    left_interval.clone(),
4627                )?,
4628                Distribution::new_uniform(right_interval.clone())?,
4629            ],
4630            vec![
4631                Distribution::new_uniform(left_interval.clone())?,
4632                Distribution::new_generic(
4633                    ScalarValue::from(12.),
4634                    ScalarValue::from(12.),
4635                    ScalarValue::Float64(None),
4636                    right_interval.clone(),
4637                )?,
4638            ],
4639            vec![
4640                Distribution::new_generic(
4641                    ScalarValue::from(6.),
4642                    ScalarValue::from(6.),
4643                    ScalarValue::Float64(None),
4644                    left_interval.clone(),
4645                )?,
4646                Distribution::new_generic(
4647                    ScalarValue::from(12.),
4648                    ScalarValue::from(12.),
4649                    ScalarValue::Float64(None),
4650                    right_interval.clone(),
4651                )?,
4652            ],
4653        ];
4654
4655        let ops = vec![
4656            Operator::Eq,
4657            Operator::Gt,
4658            Operator::GtEq,
4659            Operator::Lt,
4660            Operator::LtEq,
4661        ];
4662
4663        for child_view in children {
4664            let child_refs = child_view.iter().collect::<Vec<_>>();
4665            for op in &ops {
4666                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4667                assert!(expr
4668                    .propagate_statistics(&parent, child_refs.as_slice())?
4669                    .is_some());
4670            }
4671        }
4672
4673        Ok(())
4674    }
4675}