datafusion_physical_expr/expressions/
in_list.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Implementation of `InList` expressions: [`InListExpr`]
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::hash::{Hash, Hasher};
23use std::sync::Arc;
24
25use crate::physical_expr::physical_exprs_bag_equal;
26use crate::PhysicalExpr;
27
28use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
29use arrow::array::*;
30use arrow::buffer::BooleanBuffer;
31use arrow::compute::kernels::boolean::{not, or_kleene};
32use arrow::compute::take;
33use arrow::datatypes::*;
34use arrow::util::bit_iterator::BitIndexIterator;
35use arrow::{downcast_dictionary_array, downcast_primitive_array};
36use datafusion_common::cast::{
37    as_boolean_array, as_generic_binary_array, as_string_array,
38};
39use datafusion_common::hash_utils::HashValue;
40use datafusion_common::{
41    exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue,
42};
43use datafusion_expr::ColumnarValue;
44use datafusion_physical_expr_common::datum::compare_with_eq;
45
46use ahash::RandomState;
47use datafusion_common::HashMap;
48use hashbrown::hash_map::RawEntryMut;
49
/// Physical expression for `expr [NOT] IN (list...)`
pub struct InListExpr {
    /// Expression whose value is tested for membership
    expr: Arc<dyn PhysicalExpr>,
    /// Expressions making up the list that is searched
    list: Vec<Arc<dyn PhysicalExpr>>,
    /// `true` for `NOT IN`, `false` for `IN`
    negated: bool,
    /// Optional pre-built [`Set`]; when present, `evaluate` probes it
    /// instead of comparing against each entry of `list`
    static_filter: Option<Arc<dyn Set>>,
}
57
58impl Debug for InListExpr {
59    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
60        f.debug_struct("InListExpr")
61            .field("expr", &self.expr)
62            .field("list", &self.list)
63            .field("negated", &self.negated)
64            .finish()
65    }
66}
67
/// A type-erased container of array elements
pub trait Set: Send + Sync {
    /// Tests every element of `v` for membership in the set and returns the
    /// per-element outcome; `negated` selects `NOT IN` semantics.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
    /// Reports whether the set holds a null
    fn has_nulls(&self) -> bool;
}
73
/// Hash index over the values of an array: the map keys are positions
/// into that array rather than the values themselves.
struct ArrayHashSet {
    /// Hasher state used to hash values both when building and when probing
    state: RandomState,
    /// Used to provide a lookup from value to in list index
    ///
    /// Note: usize::hash is not used, instead the raw entry
    /// API is used to store entries w.r.t their value
    map: HashMap<usize, (), ()>,
}
82
/// A typed array of list values together with the hash index built over it
struct ArraySet<T> {
    /// The list values; indices stored in `hash_set` refer to this array
    array: T,
    /// Index from value hash to position in `array`
    hash_set: ArrayHashSet,
}
87
88impl<T> ArraySet<T>
89where
90    T: Array + From<ArrayData>,
91{
92    fn new(array: &T, hash_set: ArrayHashSet) -> Self {
93        Self {
94            array: downcast_array(array),
95            hash_set,
96        }
97    }
98}
99
100impl<T> Set for ArraySet<T>
101where
102    T: Array + 'static,
103    for<'a> &'a T: ArrayAccessor,
104    for<'a> <&'a T as ArrayAccessor>::Item: IsEqual,
105{
106    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
107        downcast_dictionary_array! {
108            v => {
109                let values_contains = self.contains(v.values().as_ref(), negated)?;
110                let result = take(&values_contains, v.keys(), None)?;
111                return Ok(downcast_array(result.as_ref()))
112            }
113            _ => {}
114        }
115
116        let v = v.as_any().downcast_ref::<T>().unwrap();
117        let in_array = &self.array;
118        let has_nulls = in_array.null_count() != 0;
119
120        Ok(ArrayIter::new(v)
121            .map(|v| {
122                v.and_then(|v| {
123                    let hash = v.hash_one(&self.hash_set.state);
124                    let contains = self
125                        .hash_set
126                        .map
127                        .raw_entry()
128                        .from_hash(hash, |idx| in_array.value(*idx).is_equal(&v))
129                        .is_some();
130
131                    match contains {
132                        true => Some(!negated),
133                        false if has_nulls => None,
134                        false => Some(negated),
135                    }
136                })
137            })
138            .collect())
139    }
140
141    fn has_nulls(&self) -> bool {
142        self.array.null_count() != 0
143    }
144}
145
146/// Computes an [`ArrayHashSet`] for the provided [`Array`] if there
147/// are nulls present or there are more than the configured number of
148/// elements.
149///
150/// Note: This is split into a separate function as higher-rank trait bounds currently
151/// cause type inference to misbehave
152fn make_hash_set<T>(array: T) -> ArrayHashSet
153where
154    T: ArrayAccessor,
155    T::Item: IsEqual,
156{
157    let state = RandomState::new();
158    let mut map: HashMap<usize, (), ()> =
159        HashMap::with_capacity_and_hasher(array.len(), ());
160
161    let insert_value = |idx| {
162        let value = array.value(idx);
163        let hash = value.hash_one(&state);
164        if let RawEntryMut::Vacant(v) = map
165            .raw_entry_mut()
166            .from_hash(hash, |x| array.value(*x).is_equal(&value))
167        {
168            v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state));
169        }
170    };
171
172    match array.nulls() {
173        Some(nulls) => {
174            BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
175                .for_each(insert_value)
176        }
177        None => (0..array.len()).for_each(insert_value),
178    }
179
180    ArrayHashSet { state, map }
181}
182
/// Creates an `Arc<dyn Set>` from the values of `array`.
///
/// Primitive, boolean, string and binary arrays are supported. Any other
/// data type yields a "not implemented" error. Dictionary arrays are not
/// expected here and hit `unreachable!`.
fn make_set(array: &dyn Array) -> Result<Arc<dyn Set>> {
    Ok(downcast_primitive_array! {
        array => Arc::new(ArraySet::new(array, make_hash_set(array))),
        DataType::Boolean => {
            let array = as_boolean_array(array)?;
            Arc::new(ArraySet::new(array, make_hash_set(array)))
        },
        DataType::Utf8 => {
            let array = as_string_array(array)?;
            Arc::new(ArraySet::new(array, make_hash_set(array)))
        }
        DataType::LargeUtf8 => {
            let array = as_largestring_array(array);
            Arc::new(ArraySet::new(array, make_hash_set(array)))
        }
        DataType::Binary => {
            let array = as_generic_binary_array::<i32>(array)?;
            Arc::new(ArraySet::new(array, make_hash_set(array)))
        }
        DataType::LargeBinary => {
            let array = as_generic_binary_array::<i64>(array)?;
            Arc::new(ArraySet::new(array, make_hash_set(array)))
        }
        DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
        d => return not_impl_err!("DataType::{d} not supported in InList")
    })
}
211
212/// Evaluates the list of expressions into an array, flattening any dictionaries
213fn evaluate_list(
214    list: &[Arc<dyn PhysicalExpr>],
215    batch: &RecordBatch,
216) -> Result<ArrayRef> {
217    let scalars = list
218        .iter()
219        .map(|expr| {
220            expr.evaluate(batch).and_then(|r| match r {
221                ColumnarValue::Array(_) => {
222                    exec_err!("InList expression must evaluate to a scalar")
223                }
224                // Flatten dictionary values
225                ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v),
226                ColumnarValue::Scalar(s) => Ok(s),
227            })
228        })
229        .collect::<Result<Vec<_>>>()?;
230
231    ScalarValue::iter_to_array(scalars)
232}
233
234fn try_cast_static_filter_to_set(
235    list: &[Arc<dyn PhysicalExpr>],
236    schema: &Schema,
237) -> Result<Arc<dyn Set>> {
238    let batch = RecordBatch::new_empty(Arc::new(schema.clone()));
239    make_set(evaluate_list(list, &batch)?.as_ref())
240}
241
/// Custom equality check function which is used with [`ArrayHashSet`] for existence check.
trait IsEqual: HashValue {
    /// Returns `true` when `self` and `other` count as the same set member
    fn is_equal(&self, other: &Self) -> bool;
}
246
/// References compare by delegating to the implementation of the referent.
impl<T: IsEqual + ?Sized> IsEqual for &T {
    fn is_equal(&self, other: &Self) -> bool {
        T::is_equal(self, other)
    }
}
252
/// Implements [`IsEqual`] for each listed type using `==`.
macro_rules! is_equal {
    ($($t:ty),+) => {
        $(impl IsEqual for $t {
            fn is_equal(&self, other: &Self) -> bool {
                self == other
            }
        })*
    };
}
is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64);
is_equal!(bool, str, [u8]);
is_equal!(IntervalDayTime, IntervalMonthDayNano);
265
/// Implements [`IsEqual`] for floating point types by comparing the raw bit
/// patterns, so a NaN matches a NaN with the same bits while `0.0` and
/// `-0.0` are treated as different values.
macro_rules! is_equal_float {
    ($($t:ty),+) => {
        $(impl IsEqual for $t {
            fn is_equal(&self, other: &Self) -> bool {
                self.to_bits() == other.to_bits()
            }
        })*
    };
}
is_equal_float!(half::f16, f32, f64);
276
impl InListExpr {
    /// Create a new InList expression
    ///
    /// * `expr` - the expression whose value is looked up
    /// * `list` - the expressions forming the list to search in
    /// * `negated` - `true` for `NOT IN`
    /// * `static_filter` - optional pre-built [`Set`]; it is stored as given
    ///   and is not checked against `list` here
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        list: Vec<Arc<dyn PhysicalExpr>>,
        negated: bool,
        static_filter: Option<Arc<dyn Set>>,
    ) -> Self {
        Self {
            expr,
            list,
            negated,
            static_filter,
        }
    }

    /// Input expression
    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
        &self.expr
    }

    /// List to search in
    pub fn list(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.list
    }

    /// Is this negated e.g. NOT IN LIST
    pub fn negated(&self) -> bool {
        self.negated
    }
}
308
309impl std::fmt::Display for InListExpr {
310    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
311        if self.negated {
312            if self.static_filter.is_some() {
313                write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list)
314            } else {
315                write!(f, "{} NOT IN ({:?})", self.expr, self.list)
316            }
317        } else if self.static_filter.is_some() {
318            write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list)
319        } else {
320            write!(f, "{} IN ({:?})", self.expr, self.list)
321        }
322    }
323}
324
325impl PhysicalExpr for InListExpr {
326    /// Return a reference to Any that can be used for downcasting
327    fn as_any(&self) -> &dyn Any {
328        self
329    }
330
331    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
332        Ok(DataType::Boolean)
333    }
334
335    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
336        if self.expr.nullable(input_schema)? {
337            return Ok(true);
338        }
339
340        if let Some(static_filter) = &self.static_filter {
341            Ok(static_filter.has_nulls())
342        } else {
343            for expr in &self.list {
344                if expr.nullable(input_schema)? {
345                    return Ok(true);
346                }
347            }
348            Ok(false)
349        }
350    }
351
352    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
353        let num_rows = batch.num_rows();
354        let value = self.expr.evaluate(batch)?;
355        let r = match &self.static_filter {
356            Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?,
357            None => {
358                let value = value.into_array(num_rows)?;
359                let is_nested = value.data_type().is_nested();
360                let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold(
361                    BooleanArray::new(BooleanBuffer::new_unset(num_rows), None),
362                    |result, expr| -> Result<BooleanArray> {
363                        let rhs = compare_with_eq(
364                            &value,
365                            &expr?.into_array(num_rows)?,
366                            is_nested,
367                        )?;
368                        Ok(or_kleene(&result, &rhs)?)
369                    },
370                )?;
371
372                if self.negated {
373                    not(&found)?
374                } else {
375                    found
376                }
377            }
378        };
379        Ok(ColumnarValue::Array(Arc::new(r)))
380    }
381
382    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
383        let mut children = vec![];
384        children.push(&self.expr);
385        children.extend(&self.list);
386        children
387    }
388
389    fn with_new_children(
390        self: Arc<Self>,
391        children: Vec<Arc<dyn PhysicalExpr>>,
392    ) -> Result<Arc<dyn PhysicalExpr>> {
393        // assume the static_filter will not change during the rewrite process
394        Ok(Arc::new(InListExpr::new(
395            Arc::clone(&children[0]),
396            children[1..].to_vec(),
397            self.negated,
398            self.static_filter.clone(),
399        )))
400    }
401}
402
403impl PartialEq for InListExpr {
404    fn eq(&self, other: &Self) -> bool {
405        self.expr.eq(&other.expr)
406            && physical_exprs_bag_equal(&self.list, &other.list)
407            && self.negated == other.negated
408    }
409}
410
411impl Eq for InListExpr {}
412
413impl Hash for InListExpr {
414    fn hash<H: Hasher>(&self, state: &mut H) {
415        self.expr.hash(state);
416        self.negated.hash(state);
417        self.list.hash(state);
418        // Add `self.static_filter` when hash is available
419    }
420}
421
422/// Creates a unary expression InList
423pub fn in_list(
424    expr: Arc<dyn PhysicalExpr>,
425    list: Vec<Arc<dyn PhysicalExpr>>,
426    negated: &bool,
427    schema: &Schema,
428) -> Result<Arc<dyn PhysicalExpr>> {
429    // check the data type
430    let expr_data_type = expr.data_type(schema)?;
431    for list_expr in list.iter() {
432        let list_expr_data_type = list_expr.data_type(schema)?;
433        if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) {
434            return internal_err!(
435                "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}"
436            );
437        }
438    }
439    let static_filter = try_cast_static_filter_to_set(&list, schema).ok();
440    Ok(Arc::new(InListExpr::new(
441        expr,
442        list,
443        *negated,
444        static_filter,
445    )))
446}
447
448#[cfg(test)]
449mod tests {
450
451    use super::*;
452    use crate::expressions;
453    use crate::expressions::{col, lit, try_cast};
454    use datafusion_common::plan_err;
455    use datafusion_expr::type_coercion::binary::comparison_coercion;
456
457    type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
458
459    // Try to do the type coercion for list physical expr.
460    // It's just used in the test
461    fn in_list_cast(
462        expr: Arc<dyn PhysicalExpr>,
463        list: Vec<Arc<dyn PhysicalExpr>>,
464        input_schema: &Schema,
465    ) -> Result<InListCastResult> {
466        let expr_type = &expr.data_type(input_schema)?;
467        let list_types: Vec<DataType> = list
468            .iter()
469            .map(|list_expr| list_expr.data_type(input_schema).unwrap())
470            .collect();
471        let result_type = get_coerce_type(expr_type, &list_types);
472        match result_type {
473            None => plan_err!(
474                "Can not find compatible types to compare {expr_type:?} with {list_types:?}"
475            ),
476            Some(data_type) => {
477                // find the coerced type
478                let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
479                let cast_list_expr = list
480                    .into_iter()
481                    .map(|list_expr| {
482                        try_cast(list_expr, input_schema, data_type.clone()).unwrap()
483                    })
484                    .collect();
485                Ok((cast_expr, cast_list_expr))
486            }
487        }
488    }
489
490    // Attempts to coerce the types of `list_type` to be comparable with the
491    // `expr_type`
492    fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
493        list_type
494            .iter()
495            .try_fold(expr_type.clone(), |left_type, right_type| {
496                comparison_coercion(&left_type, right_type)
497            })
498    }
499
    // applies the in_list expr to an input batch and list
    //
    // The column and the list are first coerced to a common type with
    // `in_list_cast`; the check itself is delegated to `in_list_raw!`.
    // Uses `?`, so it must be expanded inside a function returning `Result`.
    macro_rules! in_list {
        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
            let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
            in_list_raw!(
                $BATCH,
                cast_list_exprs,
                $NEGATED,
                $EXPECTED,
                cast_expr,
                $SCHEMA
            );
        }};
    }
514
    // applies the in_list expr to an input batch and list without cast
    //
    // Builds the expression with `in_list`, evaluates it on `$BATCH` and
    // asserts that the boolean result equals `$EXPECTED`.
    macro_rules! in_list_raw {
        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
            let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
            let result = expr
                .evaluate(&$BATCH)?
                .into_array($BATCH.num_rows())
                .expect("Failed to convert to array");
            let result =
                as_boolean_array(&result).expect("failed to downcast to BooleanArray");
            let expected = &BooleanArray::from($EXPECTED);
            assert_eq!(expected, result);
        }};
    }
529
    // IN / NOT IN over a nullable Utf8 column, with and without a NULL
    // entry in the list.
    #[test]
    fn in_list_utf8() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let a = StringArray::from(vec![Some("a"), Some("d"), None]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in ("a", "b")"
        let list = vec![lit("a"), lit("b")];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in ("a", "b")"
        let list = vec![lit("a"), lit("b")];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in ("a", "b", null)"
        let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in ("a", "b", null)"
        let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None, None],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
583
    // IN / NOT IN over a nullable Binary column, with and without a NULL
    // entry in the list.
    #[test]
    fn in_list_binary() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
        let a = BinaryArray::from(vec![
            Some([1, 2, 3].as_slice()),
            Some([1, 2, 2].as_slice()),
            None,
        ]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in ([1, 2, 3], [4, 5, 6])"
        let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())];
        in_list!(
            batch,
            list.clone(),
            &false,
            vec![Some(true), Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in ([1, 2, 3], [4, 5, 6])"
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in ([1, 2, 3], [4, 5, 6], null)"
        let list = vec![
            lit([1, 2, 3].as_slice()),
            lit([4, 5, 6].as_slice()),
            lit(ScalarValue::Binary(None)),
        ];
        in_list!(
            batch,
            list.clone(),
            &false,
            vec![Some(true), None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in ([1, 2, 3], [4, 5, 6], null)"
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None, None],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
643
    // IN / NOT IN over a nullable Int64 column, with and without an untyped
    // NULL entry in the list.
    #[test]
    fn in_list_int64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let a = Int64Array::from(vec![Some(0), Some(2), None]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in (0, 1)"
        let list = vec![lit(0i64), lit(1i64)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0, 1)"
        let list = vec![lit(0i64), lit(1i64)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (0, 1, NULL)"
        let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0, 1, NULL)"
        let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None, None],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
697
    // IN / NOT IN over a nullable Float64 column that also holds NaN and
    // -NaN. The expected results show that a NaN list entry matches only
    // the column value with the same sign.
    #[test]
    fn in_list_float64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let a = Float64Array::from(vec![
            Some(0.0),
            Some(0.2),
            None,
            Some(f64::NAN),
            Some(-f64::NAN),
        ]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in (0.0, 0.1)"
        let list = vec![lit(0.0f64), lit(0.1f64)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None, Some(false), Some(false)],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0.0, 0.1)"
        let list = vec![lit(0.0f64), lit(0.1f64)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None, Some(true), Some(true)],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (0.0, 0.1, NULL)"
        let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None, None, None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0.0, 0.1, NULL)"
        let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None, None, None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (0.0, 0.1, NaN)"
        let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None, Some(true), Some(false)],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0.0, 0.1, NaN)"
        let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None, Some(false), Some(true)],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (0.0, 0.1, -NaN)"
        let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None, Some(false), Some(true)],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0.0, 0.1, -NaN)"
        let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None, Some(true), Some(false)],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
801
    // IN / NOT IN over a nullable Boolean column, with and without an
    // untyped NULL entry in the list.
    #[test]
    fn in_list_bool() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
        let a = BooleanArray::from(vec![Some(true), None]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in (true)"
        let list = vec![lit(true)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (true)"
        let list = vec![lit(true)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (true, NULL)"
        let list = vec![lit(true), lit(ScalarValue::Null)];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (true, NULL)"
        let list = vec![lit(true), lit(ScalarValue::Null)];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
855
    // IN / NOT IN over a nullable Date64 column, with and without an
    // untyped NULL entry in the list.
    #[test]
    fn in_list_date64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]);
        let a = Date64Array::from(vec![Some(0), Some(2), None]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in (0, 1)"
        let list = vec![
            lit(ScalarValue::Date64(Some(0))),
            lit(ScalarValue::Date64(Some(1))),
        ];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0, 1)"
        let list = vec![
            lit(ScalarValue::Date64(Some(0))),
            lit(ScalarValue::Date64(Some(1))),
        ];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Date64(Some(0))),
            lit(ScalarValue::Date64(Some(1))),
            lit(ScalarValue::Null),
        ];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Date64(Some(0))),
            lit(ScalarValue::Date64(Some(1))),
            lit(ScalarValue::Null),
        ];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None, None],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
923
    // IN / NOT IN over a nullable Date32 column, with and without an
    // untyped NULL entry in the list.
    #[test]
    fn in_list_date32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Date32, true)]);
        let a = Date32Array::from(vec![Some(0), Some(2), None]);
        let col_a = col("a", &schema)?;
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "a in (0, 1)"
        let list = vec![
            lit(ScalarValue::Date32(Some(0))),
            lit(ScalarValue::Date32(Some(1))),
        ];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), Some(false), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0, 1)"
        let list = vec![
            lit(ScalarValue::Date32(Some(0))),
            lit(ScalarValue::Date32(Some(1))),
        ];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), Some(true), None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Date32(Some(0))),
            lit(ScalarValue::Date32(Some(1))),
            lit(ScalarValue::Null),
        ];
        in_list!(
            batch,
            list,
            &false,
            vec![Some(true), None, None],
            Arc::clone(&col_a),
            &schema
        );

        // expression: "a not in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Date32(Some(0))),
            lit(ScalarValue::Date32(Some(1))),
            lit(ScalarValue::Null),
        ];
        in_list!(
            batch,
            list,
            &true,
            vec![Some(false), None, None],
            Arc::clone(&col_a),
            &schema
        );

        Ok(())
    }
991
992    #[test]
993    fn in_list_decimal() -> Result<()> {
994        // Now, we can check the NULL type
995        let schema =
996            Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
997        let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)]
998            .into_iter()
999            .collect::<Decimal128Array>();
1000        let array = array.with_precision_and_scale(13, 4).unwrap();
1001        let col_a = col("a", &schema)?;
1002        let batch =
1003            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;
1004
1005        // expression: "a in (100,200), the data type of list is INT32
1006        let list = vec![lit(100i32), lit(200i32)];
1007        in_list!(
1008            batch,
1009            list,
1010            &false,
1011            vec![Some(true), None, Some(false)],
1012            Arc::clone(&col_a),
1013            &schema
1014        );
1015        // expression: "a not in (100,200)
1016        let list = vec![lit(100i32), lit(200i32)];
1017        in_list!(
1018            batch,
1019            list,
1020            &true,
1021            vec![Some(false), None, Some(true)],
1022            Arc::clone(&col_a),
1023            &schema
1024        );
1025
1026        // expression: "a in (200,NULL), the data type of list is INT32 AND NULL
1027        let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)];
1028        in_list!(
1029            batch,
1030            list.clone(),
1031            &false,
1032            vec![Some(true), None, None],
1033            Arc::clone(&col_a),
1034            &schema
1035        );
1036        // expression: "a not in (200,NULL), the data type of list is INT32 AND NULL
1037        in_list!(
1038            batch,
1039            list,
1040            &true,
1041            vec![Some(false), None, None],
1042            Arc::clone(&col_a),
1043            &schema
1044        );
1045
1046        // expression: "a in (200.5, 100), the data type of list is FLOAT32 and INT32
1047        let list = vec![lit(200.50f32), lit(100i32)];
1048        in_list!(
1049            batch,
1050            list,
1051            &false,
1052            vec![Some(true), None, Some(true)],
1053            Arc::clone(&col_a),
1054            &schema
1055        );
1056
1057        // expression: "a not in (200.5, 100), the data type of list is FLOAT32 and INT32
1058        let list = vec![lit(200.50f32), lit(101i32)];
1059        in_list!(
1060            batch,
1061            list,
1062            &true,
1063            vec![Some(true), None, Some(false)],
1064            Arc::clone(&col_a),
1065            &schema
1066        );
1067
1068        // test the optimization: set
1069        // expression: "a in (99..300), the data type of list is INT32
1070        let list = (99i32..300).map(lit).collect::<Vec<_>>();
1071
1072        in_list!(
1073            batch,
1074            list.clone(),
1075            &false,
1076            vec![Some(true), None, Some(false)],
1077            Arc::clone(&col_a),
1078            &schema
1079        );
1080
1081        in_list!(
1082            batch,
1083            list,
1084            &true,
1085            vec![Some(false), None, Some(true)],
1086            Arc::clone(&col_a),
1087            &schema
1088        );
1089
1090        Ok(())
1091    }
1092
1093    #[test]
1094    fn test_cast_static_filter_to_set() -> Result<()> {
1095        // random schema
1096        let schema =
1097            Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
1098
1099        // list of phy expr
1100        let mut phy_exprs = vec![
1101            lit(1i64),
1102            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
1103            try_cast(lit(3.13f32), &schema, DataType::Int64)?,
1104        ];
1105        let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1106
1107        let array = Int64Array::from(vec![1, 2, 3, 4]);
1108        let r = result.contains(&array, false).unwrap();
1109        assert_eq!(r, BooleanArray::from(vec![true, true, true, false]));
1110
1111        try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1112        // cast(cast(lit())), but the cast to the same data type, one case will be ignored
1113        phy_exprs.push(expressions::cast(
1114            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
1115            &schema,
1116            DataType::Int64,
1117        )?);
1118        try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1119
1120        phy_exprs.clear();
1121
1122        // case(cast(lit())), the cast to the diff data type
1123        phy_exprs.push(expressions::cast(
1124            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
1125            &schema,
1126            DataType::Int32,
1127        )?);
1128        try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1129
1130        // column
1131        phy_exprs.push(col("a", &schema)?);
1132        assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err());
1133
1134        Ok(())
1135    }
1136
1137    #[test]
1138    fn in_list_timestamp() -> Result<()> {
1139        let schema = Schema::new(vec![Field::new(
1140            "a",
1141            DataType::Timestamp(TimeUnit::Microsecond, None),
1142            true,
1143        )]);
1144        let a = TimestampMicrosecondArray::from(vec![
1145            Some(1388588401000000000),
1146            Some(1288588501000000000),
1147            None,
1148        ]);
1149        let col_a = col("a", &schema)?;
1150        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1151
1152        let list = vec![
1153            lit(ScalarValue::TimestampMicrosecond(
1154                Some(1388588401000000000),
1155                None,
1156            )),
1157            lit(ScalarValue::TimestampMicrosecond(
1158                Some(1388588401000000001),
1159                None,
1160            )),
1161            lit(ScalarValue::TimestampMicrosecond(
1162                Some(1388588401000000002),
1163                None,
1164            )),
1165        ];
1166
1167        in_list!(
1168            batch,
1169            list.clone(),
1170            &false,
1171            vec![Some(true), Some(false), None],
1172            Arc::clone(&col_a),
1173            &schema
1174        );
1175
1176        in_list!(
1177            batch,
1178            list.clone(),
1179            &true,
1180            vec![Some(false), Some(true), None],
1181            Arc::clone(&col_a),
1182            &schema
1183        );
1184        Ok(())
1185    }
1186
1187    #[test]
1188    fn in_expr_with_multiple_element_in_list() -> Result<()> {
1189        let schema = Schema::new(vec![
1190            Field::new("a", DataType::Float64, true),
1191            Field::new("b", DataType::Float64, true),
1192            Field::new("c", DataType::Float64, true),
1193        ]);
1194        let a = Float64Array::from(vec![
1195            Some(0.0),
1196            Some(1.0),
1197            Some(2.0),
1198            Some(f64::NAN),
1199            Some(-f64::NAN),
1200        ]);
1201        let b = Float64Array::from(vec![
1202            Some(8.0),
1203            Some(1.0),
1204            Some(5.0),
1205            Some(f64::NAN),
1206            Some(3.0),
1207        ]);
1208        let c = Float64Array::from(vec![
1209            Some(6.0),
1210            Some(7.0),
1211            None,
1212            Some(5.0),
1213            Some(-f64::NAN),
1214        ]);
1215        let col_a = col("a", &schema)?;
1216        let col_b = col("b", &schema)?;
1217        let col_c = col("c", &schema)?;
1218        let batch = RecordBatch::try_new(
1219            Arc::new(schema.clone()),
1220            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1221        )?;
1222
1223        let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)];
1224        in_list!(
1225            batch,
1226            list.clone(),
1227            &false,
1228            vec![Some(false), Some(true), None, Some(true), Some(true)],
1229            Arc::clone(&col_a),
1230            &schema
1231        );
1232
1233        in_list!(
1234            batch,
1235            list,
1236            &true,
1237            vec![Some(true), Some(false), None, Some(false), Some(false)],
1238            Arc::clone(&col_a),
1239            &schema
1240        );
1241
1242        Ok(())
1243    }
1244
1245    macro_rules! test_nullable {
1246        ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{
1247            let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
1248            let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap();
1249            let result = expr.nullable($SCHEMA)?;
1250            assert_eq!($EXPECTED, result);
1251        }};
1252    }
1253
1254    #[test]
1255    fn in_list_nullable() -> Result<()> {
1256        let schema = Schema::new(vec![
1257            Field::new("c1_nullable", DataType::Int64, true),
1258            Field::new("c2_non_nullable", DataType::Int64, false),
1259        ]);
1260
1261        let c1_nullable = col("c1_nullable", &schema)?;
1262        let c2_non_nullable = col("c2_non_nullable", &schema)?;
1263
1264        // static_filter has no nulls
1265        let list = vec![lit(1_i64), lit(2_i64)];
1266        test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true);
1267        test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false);
1268
1269        // static_filter has nulls
1270        let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)];
1271        test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true);
1272        test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true);
1273
1274        let list = vec![Arc::clone(&c1_nullable)];
1275        test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true);
1276
1277        let list = vec![Arc::clone(&c2_non_nullable)];
1278        test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true);
1279
1280        let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)];
1281        test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false);
1282
1283        Ok(())
1284    }
1285
1286    #[test]
1287    fn in_list_no_cols() -> Result<()> {
1288        // test logic when the in_list expression doesn't have any columns
1289        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1290        let a = Int32Array::from(vec![Some(1), Some(2), None]);
1291        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1292
1293        let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))];
1294
1295        // 1 IN (1, 6)
1296        let expr = lit(ScalarValue::Int32(Some(1)));
1297        in_list!(
1298            batch,
1299            list.clone(),
1300            &false,
1301            // should have three outputs, as the input batch has three rows
1302            vec![Some(true), Some(true), Some(true)],
1303            expr,
1304            &schema
1305        );
1306
1307        // 2 IN (1, 6)
1308        let expr = lit(ScalarValue::Int32(Some(2)));
1309        in_list!(
1310            batch,
1311            list.clone(),
1312            &false,
1313            // should have three outputs, as the input batch has three rows
1314            vec![Some(false), Some(false), Some(false)],
1315            expr,
1316            &schema
1317        );
1318
1319        // NULL IN (1, 6)
1320        let expr = lit(ScalarValue::Int32(None));
1321        in_list!(
1322            batch,
1323            list.clone(),
1324            &false,
1325            // should have three outputs, as the input batch has three rows
1326            vec![None, None, None],
1327            expr,
1328            &schema
1329        );
1330
1331        Ok(())
1332    }
1333
1334    #[test]
1335    fn in_list_utf8_with_dict_types() -> Result<()> {
1336        fn dict_lit(key_type: DataType, value: &str) -> Arc<dyn PhysicalExpr> {
1337            lit(ScalarValue::Dictionary(
1338                Box::new(key_type),
1339                Box::new(ScalarValue::new_utf8(value.to_string())),
1340            ))
1341        }
1342
1343        fn null_dict_lit(key_type: DataType) -> Arc<dyn PhysicalExpr> {
1344            lit(ScalarValue::Dictionary(
1345                Box::new(key_type),
1346                Box::new(ScalarValue::Utf8(None)),
1347            ))
1348        }
1349
1350        let schema = Schema::new(vec![Field::new(
1351            "a",
1352            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
1353            true,
1354        )]);
1355        let a: UInt16DictionaryArray =
1356            vec![Some("a"), Some("d"), None].into_iter().collect();
1357        let col_a = col("a", &schema)?;
1358        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1359
1360        // expression: "a in ("a", "b")"
1361        let lists = [
1362            vec![lit("a"), lit("b")],
1363            vec![
1364                dict_lit(DataType::Int8, "a"),
1365                dict_lit(DataType::UInt16, "b"),
1366            ],
1367        ];
1368        for list in lists.iter() {
1369            in_list_raw!(
1370                batch,
1371                list.clone(),
1372                &false,
1373                vec![Some(true), Some(false), None],
1374                Arc::clone(&col_a),
1375                &schema
1376            );
1377        }
1378
1379        // expression: "a not in ("a", "b")"
1380        for list in lists.iter() {
1381            in_list_raw!(
1382                batch,
1383                list.clone(),
1384                &true,
1385                vec![Some(false), Some(true), None],
1386                Arc::clone(&col_a),
1387                &schema
1388            );
1389        }
1390
1391        // expression: "a in ("a", "b", null)"
1392        let lists = [
1393            vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
1394            vec![
1395                dict_lit(DataType::Int8, "a"),
1396                dict_lit(DataType::UInt16, "b"),
1397                null_dict_lit(DataType::UInt16),
1398            ],
1399        ];
1400        for list in lists.iter() {
1401            in_list_raw!(
1402                batch,
1403                list.clone(),
1404                &false,
1405                vec![Some(true), None, None],
1406                Arc::clone(&col_a),
1407                &schema
1408            );
1409        }
1410
1411        // expression: "a not in ("a", "b", null)"
1412        for list in lists.iter() {
1413            in_list_raw!(
1414                batch,
1415                list.clone(),
1416                &true,
1417                vec![Some(false), None, None],
1418                Arc::clone(&col_a),
1419                &schema
1420            );
1421        }
1422
1423        Ok(())
1424    }
1425}