datafusion_physical_expr/expressions/
case.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::borrow::Cow;
19use std::hash::Hash;
20use std::{any::Any, sync::Arc};
21
22use crate::expressions::try_cast;
23use crate::PhysicalExpr;
24
25use arrow::array::*;
26use arrow::compute::kernels::zip::zip;
27use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
28use arrow::datatypes::{DataType, Schema};
29use datafusion_common::cast::as_boolean_array;
30use datafusion_common::{
31    exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
32};
33use datafusion_expr::ColumnarValue;
34
35use super::{Column, Literal};
36use datafusion_physical_expr_common::datum::compare_with_eq;
37use itertools::Itertools;
38
39type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
40
/// Evaluation strategy for a `CaseExpr`. It is selected once, when the
/// expression is constructed, and dispatched on in `evaluate`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression; every WHEN is a condition:
    ///
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    NoExpression,
    /// General form with a base expression that is compared against each
    /// WHEN value:
    ///
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    WithExpression,
    /// Fast path for a single WHEN/THEN pair whose THEN expression is
    /// infallible and cheap enough to compute for the whole record batch,
    /// rather than only for the rows where the predicate is true, and where
    /// there is no ELSE (or the ELSE is NULL):
    ///
    /// CASE WHEN condition THEN column [ELSE NULL] END
    InfallibleExprOrNull,
    /// Fast path for a single WHEN/THEN pair where both the `then` and the
    /// `else` expressions are literal values:
    ///
    /// CASE WHEN condition THEN literal ELSE literal END
    ScalarOrScalar,
    /// Fast path for a single WHEN/THEN pair with an ELSE, where `then` and
    /// `else` are arbitrary expressions:
    ///
    /// CASE WHEN condition THEN expression ELSE expression END
    ExpressionOrExpression,
}
71
72/// The CASE expression is similar to a series of nested if/else and there are two forms that
73/// can be used. The first form consists of a series of boolean "when" expressions with
74/// corresponding "then" expressions, and an optional "else" expression.
75///
76/// CASE WHEN condition THEN result
77///      [WHEN ...]
78///      [ELSE result]
79/// END
80///
81/// The second form uses a base expression and then a series of "when" clauses that match on a
82/// literal value.
83///
84/// CASE expression
85///     WHEN value THEN result
86///     [WHEN ...]
87///     [ELSE result]
88/// END
89#[derive(Debug, Hash, PartialEq, Eq)]
90pub struct CaseExpr {
91    /// Optional base expression that can be compared to literal values in the "when" expressions
92    expr: Option<Arc<dyn PhysicalExpr>>,
93    /// One or more when/then expressions
94    when_then_expr: Vec<WhenThen>,
95    /// Optional "else" expression
96    else_expr: Option<Arc<dyn PhysicalExpr>>,
97    /// Evaluation method to use
98    eval_method: EvalMethod,
99}
100
101impl std::fmt::Display for CaseExpr {
102    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
103        write!(f, "CASE ")?;
104        if let Some(e) = &self.expr {
105            write!(f, "{e} ")?;
106        }
107        for (w, t) in &self.when_then_expr {
108            write!(f, "WHEN {w} THEN {t} ")?;
109        }
110        if let Some(e) = &self.else_expr {
111            write!(f, "ELSE {e} ")?;
112        }
113        write!(f, "END")
114    }
115}
116
117/// This is a specialization for a specific use case where we can take a fast path
118/// for expressions that are infallible and can be cheaply computed for the entire
119/// record batch rather than just for the rows where the predicate is true. For now,
120/// this is limited to use with Column expressions but could potentially be used for other
121/// expressions in the future
122fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
123    expr.as_any().is::<Column>()
124}
125
126impl CaseExpr {
127    /// Create a new CASE WHEN expression
128    pub fn try_new(
129        expr: Option<Arc<dyn PhysicalExpr>>,
130        when_then_expr: Vec<WhenThen>,
131        else_expr: Option<Arc<dyn PhysicalExpr>>,
132    ) -> Result<Self> {
133        // normalize null literals to None in the else_expr (this already happens
134        // during SQL planning, but not necessarily for other use cases)
135        let else_expr = match &else_expr {
136            Some(e) => match e.as_any().downcast_ref::<Literal>() {
137                Some(lit) if lit.value().is_null() => None,
138                _ => else_expr,
139            },
140            _ => else_expr,
141        };
142
143        if when_then_expr.is_empty() {
144            exec_err!("There must be at least one WHEN clause")
145        } else {
146            let eval_method = if expr.is_some() {
147                EvalMethod::WithExpression
148            } else if when_then_expr.len() == 1
149                && is_cheap_and_infallible(&(when_then_expr[0].1))
150                && else_expr.is_none()
151            {
152                EvalMethod::InfallibleExprOrNull
153            } else if when_then_expr.len() == 1
154                && when_then_expr[0].1.as_any().is::<Literal>()
155                && else_expr.is_some()
156                && else_expr.as_ref().unwrap().as_any().is::<Literal>()
157            {
158                EvalMethod::ScalarOrScalar
159            } else if when_then_expr.len() == 1 && else_expr.is_some() {
160                EvalMethod::ExpressionOrExpression
161            } else {
162                EvalMethod::NoExpression
163            };
164
165            Ok(Self {
166                expr,
167                when_then_expr,
168                else_expr,
169                eval_method,
170            })
171        }
172    }
173
174    /// Optional base expression that can be compared to literal values in the "when" expressions
175    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
176        self.expr.as_ref()
177    }
178
179    /// One or more when/then expressions
180    pub fn when_then_expr(&self) -> &[WhenThen] {
181        &self.when_then_expr
182    }
183
184    /// Optional "else" expression
185    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
186        self.else_expr.as_ref()
187    }
188}
189
190impl CaseExpr {
191    /// This function evaluates the form of CASE that matches an expression to fixed values.
192    ///
193    /// CASE expression
194    ///     WHEN value THEN result
195    ///     [WHEN ...]
196    ///     [ELSE result]
197    /// END
198    fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
199        let return_type = self.data_type(&batch.schema())?;
200        let expr = self.expr.as_ref().unwrap();
201        let base_value = expr.evaluate(batch)?;
202        let base_value = base_value.into_array(batch.num_rows())?;
203        let base_nulls = is_null(base_value.as_ref())?;
204
205        // start with nulls as default output
206        let mut current_value = new_null_array(&return_type, batch.num_rows());
207        // We only consider non-null values while comparing with whens
208        let mut remainder = not(&base_nulls)?;
209        for i in 0..self.when_then_expr.len() {
210            let when_value = self.when_then_expr[i]
211                .0
212                .evaluate_selection(batch, &remainder)?;
213            let when_value = when_value.into_array(batch.num_rows())?;
214            // build boolean array representing which rows match the "when" value
215            let when_match = compare_with_eq(
216                &when_value,
217                &base_value,
218                // The types of case and when expressions will be coerced to match.
219                // We only need to check if the base_value is nested.
220                base_value.data_type().is_nested(),
221            )?;
222            // Treat nulls as false
223            let when_match = match when_match.null_count() {
224                0 => Cow::Borrowed(&when_match),
225                _ => Cow::Owned(prep_null_mask_filter(&when_match)),
226            };
227            // Make sure we only consider rows that have not been matched yet
228            let when_match = and(&when_match, &remainder)?;
229
230            // When no rows available for when clause, skip then clause
231            if when_match.true_count() == 0 {
232                continue;
233            }
234
235            let then_value = self.when_then_expr[i]
236                .1
237                .evaluate_selection(batch, &when_match)?;
238
239            current_value = match then_value {
240                ColumnarValue::Scalar(ScalarValue::Null) => {
241                    nullif(current_value.as_ref(), &when_match)?
242                }
243                ColumnarValue::Scalar(then_value) => {
244                    zip(&when_match, &then_value.to_scalar()?, &current_value)?
245                }
246                ColumnarValue::Array(then_value) => {
247                    zip(&when_match, &then_value, &current_value)?
248                }
249            };
250
251            remainder = and_not(&remainder, &when_match)?;
252        }
253
254        if let Some(e) = self.else_expr() {
255            // keep `else_expr`'s data type and return type consistent
256            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
257            // null and unmatched tuples should be assigned else value
258            remainder = or(&base_nulls, &remainder)?;
259            let else_ = expr
260                .evaluate_selection(batch, &remainder)?
261                .into_array(batch.num_rows())?;
262            current_value = zip(&remainder, &else_, &current_value)?;
263        }
264
265        Ok(ColumnarValue::Array(current_value))
266    }
267
268    /// This function evaluates the form of CASE where each WHEN expression is a boolean
269    /// expression.
270    ///
271    /// CASE WHEN condition THEN result
272    ///      [WHEN ...]
273    ///      [ELSE result]
274    /// END
275    fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
276        let return_type = self.data_type(&batch.schema())?;
277
278        // start with nulls as default output
279        let mut current_value = new_null_array(&return_type, batch.num_rows());
280        let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
281        for i in 0..self.when_then_expr.len() {
282            let when_value = self.when_then_expr[i]
283                .0
284                .evaluate_selection(batch, &remainder)?;
285            let when_value = when_value.into_array(batch.num_rows())?;
286            let when_value = as_boolean_array(&when_value).map_err(|_| {
287                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
288            })?;
289            // Treat 'NULL' as false value
290            let when_value = match when_value.null_count() {
291                0 => Cow::Borrowed(when_value),
292                _ => Cow::Owned(prep_null_mask_filter(when_value)),
293            };
294            // Make sure we only consider rows that have not been matched yet
295            let when_value = and(&when_value, &remainder)?;
296
297            // When no rows available for when clause, skip then clause
298            if when_value.true_count() == 0 {
299                continue;
300            }
301
302            let then_value = self.when_then_expr[i]
303                .1
304                .evaluate_selection(batch, &when_value)?;
305
306            current_value = match then_value {
307                ColumnarValue::Scalar(ScalarValue::Null) => {
308                    nullif(current_value.as_ref(), &when_value)?
309                }
310                ColumnarValue::Scalar(then_value) => {
311                    zip(&when_value, &then_value.to_scalar()?, &current_value)?
312                }
313                ColumnarValue::Array(then_value) => {
314                    zip(&when_value, &then_value, &current_value)?
315                }
316            };
317
318            // Succeed tuples should be filtered out for short-circuit evaluation,
319            // null values for the current when expr should be kept
320            remainder = and_not(&remainder, &when_value)?;
321        }
322
323        if let Some(e) = self.else_expr() {
324            // keep `else_expr`'s data type and return type consistent
325            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
326            let else_ = expr
327                .evaluate_selection(batch, &remainder)?
328                .into_array(batch.num_rows())?;
329            current_value = zip(&remainder, &else_, &current_value)?;
330        }
331
332        Ok(ColumnarValue::Array(current_value))
333    }
334
335    /// This function evaluates the specialized case of:
336    ///
337    /// CASE WHEN condition THEN column
338    ///      [ELSE NULL]
339    /// END
340    ///
341    /// Note that this function is only safe to use for "then" expressions
342    /// that are infallible because the expression will be evaluated for all
343    /// rows in the input batch.
344    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
345        let when_expr = &self.when_then_expr[0].0;
346        let then_expr = &self.when_then_expr[0].1;
347
348        match when_expr.evaluate(batch)? {
349            // WHEN true --> column
350            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
351                then_expr.evaluate(batch)
352            }
353            // WHEN [false | null] --> NULL
354            ColumnarValue::Scalar(_) => {
355                // return scalar NULL value
356                ScalarValue::try_from(self.data_type(&batch.schema())?)
357                    .map(ColumnarValue::Scalar)
358            }
359            // WHEN column --> column
360            ColumnarValue::Array(bit_mask) => {
361                let bit_mask = bit_mask
362                    .as_any()
363                    .downcast_ref::<BooleanArray>()
364                    .expect("predicate should evaluate to a boolean array");
365                // invert the bitmask
366                let bit_mask = match bit_mask.null_count() {
367                    0 => not(bit_mask)?,
368                    _ => not(&prep_null_mask_filter(bit_mask))?,
369                };
370                match then_expr.evaluate(batch)? {
371                    ColumnarValue::Array(array) => {
372                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
373                    }
374                    ColumnarValue::Scalar(_) => {
375                        internal_err!("expression did not evaluate to an array")
376                    }
377                }
378            }
379        }
380    }
381
382    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
383        let return_type = self.data_type(&batch.schema())?;
384
385        // evaluate when expression
386        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
387        let when_value = when_value.into_array(batch.num_rows())?;
388        let when_value = as_boolean_array(&when_value).map_err(|_| {
389            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
390        })?;
391
392        // Treat 'NULL' as false value
393        let when_value = match when_value.null_count() {
394            0 => Cow::Borrowed(when_value),
395            _ => Cow::Owned(prep_null_mask_filter(when_value)),
396        };
397
398        // evaluate then_value
399        let then_value = self.when_then_expr[0].1.evaluate(batch)?;
400        let then_value = Scalar::new(then_value.into_array(1)?);
401
402        let Some(e) = self.else_expr() else {
403            return internal_err!("expression did not evaluate to an array");
404        };
405        // keep `else_expr`'s data type and return type consistent
406        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
407        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
408        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
409    }
410
411    fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
412        let return_type = self.data_type(&batch.schema())?;
413
414        // evalute when condition on batch
415        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
416        let when_value = when_value.into_array(batch.num_rows())?;
417        let when_value = as_boolean_array(&when_value).map_err(|e| {
418            DataFusionError::Context(
419                "WHEN expression did not return a BooleanArray".to_string(),
420                Box::new(e),
421            )
422        })?;
423
424        // Treat 'NULL' as false value
425        let when_value = match when_value.null_count() {
426            0 => Cow::Borrowed(when_value),
427            _ => Cow::Owned(prep_null_mask_filter(when_value)),
428        };
429
430        let then_value = self.when_then_expr[0]
431            .1
432            .evaluate_selection(batch, &when_value)?
433            .into_array(batch.num_rows())?;
434
435        // evaluate else expression on the values not covered by when_value
436        let remainder = not(&when_value)?;
437        let e = self.else_expr.as_ref().unwrap();
438        // keep `else_expr`'s data type and return type consistent
439        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
440            .unwrap_or_else(|_| Arc::clone(e));
441        let else_ = expr
442            .evaluate_selection(batch, &remainder)?
443            .into_array(batch.num_rows())?;
444
445        Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
446    }
447}
448
449impl PhysicalExpr for CaseExpr {
450    /// Return a reference to Any that can be used for down-casting
451    fn as_any(&self) -> &dyn Any {
452        self
453    }
454
455    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
456        // since all then results have the same data type, we can choose any one as the
457        // return data type except for the null.
458        let mut data_type = DataType::Null;
459        for i in 0..self.when_then_expr.len() {
460            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
461            if !data_type.equals_datatype(&DataType::Null) {
462                break;
463            }
464        }
465        // if all then results are null, we use data type of else expr instead if possible.
466        if data_type.equals_datatype(&DataType::Null) {
467            if let Some(e) = &self.else_expr {
468                data_type = e.data_type(input_schema)?;
469            }
470        }
471
472        Ok(data_type)
473    }
474
475    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
476        // this expression is nullable if any of the input expressions are nullable
477        let then_nullable = self
478            .when_then_expr
479            .iter()
480            .map(|(_, t)| t.nullable(input_schema))
481            .collect::<Result<Vec<_>>>()?;
482        if then_nullable.contains(&true) {
483            Ok(true)
484        } else if let Some(e) = &self.else_expr {
485            e.nullable(input_schema)
486        } else {
487            // CASE produces NULL if there is no `else` expr
488            // (aka when none of the `when_then_exprs` match)
489            Ok(true)
490        }
491    }
492
493    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
494        match self.eval_method {
495            EvalMethod::WithExpression => {
496                // this use case evaluates "expr" and then compares the values with the "when"
497                // values
498                self.case_when_with_expr(batch)
499            }
500            EvalMethod::NoExpression => {
501                // The "when" conditions all evaluate to boolean in this use case and can be
502                // arbitrary expressions
503                self.case_when_no_expr(batch)
504            }
505            EvalMethod::InfallibleExprOrNull => {
506                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
507                self.case_column_or_null(batch)
508            }
509            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
510            EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
511        }
512    }
513
514    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
515        let mut children = vec![];
516        if let Some(expr) = &self.expr {
517            children.push(expr)
518        }
519        self.when_then_expr.iter().for_each(|(cond, value)| {
520            children.push(cond);
521            children.push(value);
522        });
523
524        if let Some(else_expr) = &self.else_expr {
525            children.push(else_expr)
526        }
527        children
528    }
529
530    // For physical CaseExpr, we do not allow modifying children size
531    fn with_new_children(
532        self: Arc<Self>,
533        children: Vec<Arc<dyn PhysicalExpr>>,
534    ) -> Result<Arc<dyn PhysicalExpr>> {
535        if children.len() != self.children().len() {
536            internal_err!("CaseExpr: Wrong number of children")
537        } else {
538            let (expr, when_then_expr, else_expr) =
539                match (self.expr().is_some(), self.else_expr().is_some()) {
540                    (true, true) => (
541                        Some(&children[0]),
542                        &children[1..children.len() - 1],
543                        Some(&children[children.len() - 1]),
544                    ),
545                    (true, false) => {
546                        (Some(&children[0]), &children[1..children.len()], None)
547                    }
548                    (false, true) => (
549                        None,
550                        &children[0..children.len() - 1],
551                        Some(&children[children.len() - 1]),
552                    ),
553                    (false, false) => (None, &children[0..children.len()], None),
554                };
555            Ok(Arc::new(CaseExpr::try_new(
556                expr.cloned(),
557                when_then_expr.iter().cloned().tuples().collect(),
558                else_expr.cloned(),
559            )?))
560        }
561    }
562}
563
564/// Create a CASE expression
565pub fn case(
566    expr: Option<Arc<dyn PhysicalExpr>>,
567    when_thens: Vec<WhenThen>,
568    else_expr: Option<Arc<dyn PhysicalExpr>>,
569) -> Result<Arc<dyn PhysicalExpr>> {
570    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
571}
572
573#[cfg(test)]
574mod tests {
575    use super::*;
576
577    use crate::expressions::{binary, cast, col, lit, BinaryExpr};
578    use arrow::buffer::Buffer;
579    use arrow::datatypes::DataType::Float64;
580    use arrow::datatypes::*;
581    use datafusion_common::cast::{as_float64_array, as_int32_array};
582    use datafusion_common::plan_err;
583    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
584    use datafusion_expr::type_coercion::binary::comparison_coercion;
585    use datafusion_expr::Operator;
586
587    #[test]
588    fn case_with_expr() -> Result<()> {
589        let batch = case_test_batch()?;
590        let schema = batch.schema();
591
592        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
593        let when1 = lit("foo");
594        let then1 = lit(123i32);
595        let when2 = lit("bar");
596        let then2 = lit(456i32);
597
598        let expr = generate_case_when_with_type_coercion(
599            Some(col("a", &schema)?),
600            vec![(when1, then1), (when2, then2)],
601            None,
602            schema.as_ref(),
603        )?;
604        let result = expr
605            .evaluate(&batch)?
606            .into_array(batch.num_rows())
607            .expect("Failed to convert to array");
608        let result = as_int32_array(&result)?;
609
610        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
611
612        assert_eq!(expected, result);
613
614        Ok(())
615    }
616
617    #[test]
618    fn case_with_expr_else() -> Result<()> {
619        let batch = case_test_batch()?;
620        let schema = batch.schema();
621
622        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
623        let when1 = lit("foo");
624        let then1 = lit(123i32);
625        let when2 = lit("bar");
626        let then2 = lit(456i32);
627        let else_value = lit(999i32);
628
629        let expr = generate_case_when_with_type_coercion(
630            Some(col("a", &schema)?),
631            vec![(when1, then1), (when2, then2)],
632            Some(else_value),
633            schema.as_ref(),
634        )?;
635        let result = expr
636            .evaluate(&batch)?
637            .into_array(batch.num_rows())
638            .expect("Failed to convert to array");
639        let result = as_int32_array(&result)?;
640
641        let expected =
642            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
643
644        assert_eq!(expected, result);
645
646        Ok(())
647    }
648
649    #[test]
650    fn case_with_expr_divide_by_zero() -> Result<()> {
651        let batch = case_test_batch1()?;
652        let schema = batch.schema();
653
654        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
655        let when1 = lit(0i32);
656        let then1 = lit(ScalarValue::Float64(None));
657        let else_value = binary(
658            lit(25.0f64),
659            Operator::Divide,
660            cast(col("a", &schema)?, &batch.schema(), Float64)?,
661            &batch.schema(),
662        )?;
663
664        let expr = generate_case_when_with_type_coercion(
665            Some(col("a", &schema)?),
666            vec![(when1, then1)],
667            Some(else_value),
668            schema.as_ref(),
669        )?;
670        let result = expr
671            .evaluate(&batch)?
672            .into_array(batch.num_rows())
673            .expect("Failed to convert to array");
674        let result =
675            as_float64_array(&result).expect("failed to downcast to Float64Array");
676
677        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
678
679        assert_eq!(expected, result);
680
681        Ok(())
682    }
683
684    #[test]
685    fn case_without_expr() -> Result<()> {
686        let batch = case_test_batch()?;
687        let schema = batch.schema();
688
689        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
690        let when1 = binary(
691            col("a", &schema)?,
692            Operator::Eq,
693            lit("foo"),
694            &batch.schema(),
695        )?;
696        let then1 = lit(123i32);
697        let when2 = binary(
698            col("a", &schema)?,
699            Operator::Eq,
700            lit("bar"),
701            &batch.schema(),
702        )?;
703        let then2 = lit(456i32);
704
705        let expr = generate_case_when_with_type_coercion(
706            None,
707            vec![(when1, then1), (when2, then2)],
708            None,
709            schema.as_ref(),
710        )?;
711        let result = expr
712            .evaluate(&batch)?
713            .into_array(batch.num_rows())
714            .expect("Failed to convert to array");
715        let result = as_int32_array(&result)?;
716
717        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
718
719        assert_eq!(expected, result);
720
721        Ok(())
722    }
723
724    #[test]
725    fn case_with_expr_when_null() -> Result<()> {
726        let batch = case_test_batch()?;
727        let schema = batch.schema();
728
729        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
730        let when1 = lit(ScalarValue::Utf8(None));
731        let then1 = lit(0i32);
732        let when2 = col("a", &schema)?;
733        let then2 = lit(123i32);
734        let else_value = lit(999i32);
735
736        let expr = generate_case_when_with_type_coercion(
737            Some(col("a", &schema)?),
738            vec![(when1, then1), (when2, then2)],
739            Some(else_value),
740            schema.as_ref(),
741        )?;
742        let result = expr
743            .evaluate(&batch)?
744            .into_array(batch.num_rows())
745            .expect("Failed to convert to array");
746        let result = as_int32_array(&result)?;
747
748        let expected =
749            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
750
751        assert_eq!(expected, result);
752
753        Ok(())
754    }
755
756    #[test]
757    fn case_without_expr_divide_by_zero() -> Result<()> {
758        let batch = case_test_batch1()?;
759        let schema = batch.schema();
760
761        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
762        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
763        let then1 = binary(
764            lit(25.0f64),
765            Operator::Divide,
766            cast(col("a", &schema)?, &batch.schema(), Float64)?,
767            &batch.schema(),
768        )?;
769        let x = lit(ScalarValue::Float64(None));
770
771        let expr = generate_case_when_with_type_coercion(
772            None,
773            vec![(when1, then1)],
774            Some(x),
775            schema.as_ref(),
776        )?;
777        let result = expr
778            .evaluate(&batch)?
779            .into_array(batch.num_rows())
780            .expect("Failed to convert to array");
781        let result =
782            as_float64_array(&result).expect("failed to downcast to Float64Array");
783
784        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
785
786        assert_eq!(expected, result);
787
788        Ok(())
789    }
790
791    fn case_test_batch1() -> Result<RecordBatch> {
792        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
793        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
794        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
795        Ok(batch)
796    }
797
798    #[test]
799    fn case_without_expr_else() -> Result<()> {
800        let batch = case_test_batch()?;
801        let schema = batch.schema();
802
803        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
804        let when1 = binary(
805            col("a", &schema)?,
806            Operator::Eq,
807            lit("foo"),
808            &batch.schema(),
809        )?;
810        let then1 = lit(123i32);
811        let when2 = binary(
812            col("a", &schema)?,
813            Operator::Eq,
814            lit("bar"),
815            &batch.schema(),
816        )?;
817        let then2 = lit(456i32);
818        let else_value = lit(999i32);
819
820        let expr = generate_case_when_with_type_coercion(
821            None,
822            vec![(when1, then1), (when2, then2)],
823            Some(else_value),
824            schema.as_ref(),
825        )?;
826        let result = expr
827            .evaluate(&batch)?
828            .into_array(batch.num_rows())
829            .expect("Failed to convert to array");
830        let result = as_int32_array(&result)?;
831
832        let expected =
833            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
834
835        assert_eq!(expected, result);
836
837        Ok(())
838    }
839
840    #[test]
841    fn case_with_type_cast() -> Result<()> {
842        let batch = case_test_batch()?;
843        let schema = batch.schema();
844
845        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
846        let when = binary(
847            col("a", &schema)?,
848            Operator::Eq,
849            lit("foo"),
850            &batch.schema(),
851        )?;
852        let then = lit(123.3f64);
853        let else_value = lit(999i32);
854
855        let expr = generate_case_when_with_type_coercion(
856            None,
857            vec![(when, then)],
858            Some(else_value),
859            schema.as_ref(),
860        )?;
861        let result = expr
862            .evaluate(&batch)?
863            .into_array(batch.num_rows())
864            .expect("Failed to convert to array");
865        let result =
866            as_float64_array(&result).expect("failed to downcast to Float64Array");
867
868        let expected =
869            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
870
871        assert_eq!(expected, result);
872
873        Ok(())
874    }
875
876    #[test]
877    fn case_with_matches_and_nulls() -> Result<()> {
878        let batch = case_test_batch_nulls()?;
879        let schema = batch.schema();
880
881        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
882        let when = binary(
883            col("load4", &schema)?,
884            Operator::Eq,
885            lit(1.77f64),
886            &batch.schema(),
887        )?;
888        let then = col("load4", &schema)?;
889
890        let expr = generate_case_when_with_type_coercion(
891            None,
892            vec![(when, then)],
893            None,
894            schema.as_ref(),
895        )?;
896        let result = expr
897            .evaluate(&batch)?
898            .into_array(batch.num_rows())
899            .expect("Failed to convert to array");
900        let result =
901            as_float64_array(&result).expect("failed to downcast to Float64Array");
902
903        let expected =
904            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
905
906        assert_eq!(expected, result);
907
908        Ok(())
909    }
910
911    #[test]
912    fn case_with_scalar_predicate() -> Result<()> {
913        let batch = case_test_batch_nulls()?;
914        let schema = batch.schema();
915
916        // SELECT CASE WHEN TRUE THEN load4 END
917        let when = lit(true);
918        let then = col("load4", &schema)?;
919        let expr = generate_case_when_with_type_coercion(
920            None,
921            vec![(when, then)],
922            None,
923            schema.as_ref(),
924        )?;
925
926        // many rows
927        let result = expr
928            .evaluate(&batch)?
929            .into_array(batch.num_rows())
930            .expect("Failed to convert to array");
931        let result =
932            as_float64_array(&result).expect("failed to downcast to Float64Array");
933        let expected = &Float64Array::from(vec![
934            Some(1.77),
935            None,
936            None,
937            Some(1.78),
938            None,
939            Some(1.77),
940        ]);
941        assert_eq!(expected, result);
942
943        // one row
944        let expected = Float64Array::from(vec![Some(1.1)]);
945        let batch =
946            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
947        let result = expr
948            .evaluate(&batch)?
949            .into_array(batch.num_rows())
950            .expect("Failed to convert to array");
951        let result =
952            as_float64_array(&result).expect("failed to downcast to Float64Array");
953        assert_eq!(&expected, result);
954
955        Ok(())
956    }
957
958    #[test]
959    fn case_expr_matches_and_nulls() -> Result<()> {
960        let batch = case_test_batch_nulls()?;
961        let schema = batch.schema();
962
963        // SELECT CASE load4 WHEN 1.77 THEN load4 END
964        let expr = col("load4", &schema)?;
965        let when = lit(1.77f64);
966        let then = col("load4", &schema)?;
967
968        let expr = generate_case_when_with_type_coercion(
969            Some(expr),
970            vec![(when, then)],
971            None,
972            schema.as_ref(),
973        )?;
974        let result = expr
975            .evaluate(&batch)?
976            .into_array(batch.num_rows())
977            .expect("Failed to convert to array");
978        let result =
979            as_float64_array(&result).expect("failed to downcast to Float64Array");
980
981        let expected =
982            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
983
984        assert_eq!(expected, result);
985
986        Ok(())
987    }
988
989    #[test]
990    fn test_when_null_and_some_cond_else_null() -> Result<()> {
991        let batch = case_test_batch()?;
992        let schema = batch.schema();
993
994        let when = binary(
995            Arc::new(Literal::new(ScalarValue::Boolean(None))),
996            Operator::And,
997            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
998            &schema,
999        )?;
1000        let then = col("a", &schema)?;
1001
1002        // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END
1003        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
1004        let result = expr
1005            .evaluate(&batch)?
1006            .into_array(batch.num_rows())
1007            .expect("Failed to convert to array");
1008        let result = as_string_array(&result);
1009
1010        // all result values should be null
1011        assert_eq!(result.logical_null_count(), batch.num_rows());
1012        Ok(())
1013    }
1014
1015    fn case_test_batch() -> Result<RecordBatch> {
1016        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1017        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1018        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1019        Ok(batch)
1020    }
1021
1022    // Construct an array that has several NULL values whose
1023    // underlying buffer actually matches the where expr predicate
1024    fn case_test_batch_nulls() -> Result<RecordBatch> {
1025        let load4: Float64Array = vec![
1026            Some(1.77), // 1.77
1027            Some(1.77), // null <-- same value, but will be set to null
1028            Some(1.77), // null <-- same value, but will be set to null
1029            Some(1.78), // 1.78
1030            None,       // null
1031            Some(1.77), // 1.77
1032        ]
1033        .into_iter()
1034        .collect();
1035
1036        //let valid_array = vec![true, false, false, true, false, tru
1037        let null_buffer = Buffer::from([0b00101001u8]);
1038        let load4 = load4
1039            .into_data()
1040            .into_builder()
1041            .null_bit_buffer(Some(null_buffer))
1042            .build()
1043            .unwrap();
1044        let load4: Float64Array = load4.into();
1045
1046        let batch =
1047            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1048        Ok(batch)
1049    }
1050
1051    #[test]
1052    fn case_test_incompatible() -> Result<()> {
1053        // 1 then is int64
1054        // 2 then is boolean
1055        let batch = case_test_batch()?;
1056        let schema = batch.schema();
1057
1058        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
1059        let when1 = binary(
1060            col("a", &schema)?,
1061            Operator::Eq,
1062            lit("foo"),
1063            &batch.schema(),
1064        )?;
1065        let then1 = lit(123i32);
1066        let when2 = binary(
1067            col("a", &schema)?,
1068            Operator::Eq,
1069            lit("bar"),
1070            &batch.schema(),
1071        )?;
1072        let then2 = lit(true);
1073
1074        let expr = generate_case_when_with_type_coercion(
1075            None,
1076            vec![(when1, then1), (when2, then2)],
1077            None,
1078            schema.as_ref(),
1079        );
1080        assert!(expr.is_err());
1081
1082        // then 1 is int32
1083        // then 2 is int64
1084        // else is float
1085        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
1086        let when1 = binary(
1087            col("a", &schema)?,
1088            Operator::Eq,
1089            lit("foo"),
1090            &batch.schema(),
1091        )?;
1092        let then1 = lit(123i32);
1093        let when2 = binary(
1094            col("a", &schema)?,
1095            Operator::Eq,
1096            lit("bar"),
1097            &batch.schema(),
1098        )?;
1099        let then2 = lit(456i64);
1100        let else_expr = lit(1.23f64);
1101
1102        let expr = generate_case_when_with_type_coercion(
1103            None,
1104            vec![(when1, then1), (when2, then2)],
1105            Some(else_expr),
1106            schema.as_ref(),
1107        );
1108        assert!(expr.is_ok());
1109        let result_type = expr.unwrap().data_type(schema.as_ref())?;
1110        assert_eq!(Float64, result_type);
1111        Ok(())
1112    }
1113
1114    #[test]
1115    fn case_eq() -> Result<()> {
1116        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1117
1118        let when1 = lit("foo");
1119        let then1 = lit(123i32);
1120        let when2 = lit("bar");
1121        let then2 = lit(456i32);
1122        let else_value = lit(999i32);
1123
1124        let expr1 = generate_case_when_with_type_coercion(
1125            Some(col("a", &schema)?),
1126            vec![
1127                (Arc::clone(&when1), Arc::clone(&then1)),
1128                (Arc::clone(&when2), Arc::clone(&then2)),
1129            ],
1130            Some(Arc::clone(&else_value)),
1131            &schema,
1132        )?;
1133
1134        let expr2 = generate_case_when_with_type_coercion(
1135            Some(col("a", &schema)?),
1136            vec![
1137                (Arc::clone(&when1), Arc::clone(&then1)),
1138                (Arc::clone(&when2), Arc::clone(&then2)),
1139            ],
1140            Some(Arc::clone(&else_value)),
1141            &schema,
1142        )?;
1143
1144        let expr3 = generate_case_when_with_type_coercion(
1145            Some(col("a", &schema)?),
1146            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
1147            None,
1148            &schema,
1149        )?;
1150
1151        let expr4 = generate_case_when_with_type_coercion(
1152            Some(col("a", &schema)?),
1153            vec![(when1, then1)],
1154            Some(else_value),
1155            &schema,
1156        )?;
1157
1158        assert!(expr1.eq(&expr2));
1159        assert!(expr2.eq(&expr1));
1160
1161        assert!(expr2.ne(&expr3));
1162        assert!(expr3.ne(&expr2));
1163
1164        assert!(expr1.ne(&expr4));
1165        assert!(expr4.ne(&expr1));
1166
1167        Ok(())
1168    }
1169
1170    #[test]
1171    fn case_transform() -> Result<()> {
1172        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1173
1174        let when1 = lit("foo");
1175        let then1 = lit(123i32);
1176        let when2 = lit("bar");
1177        let then2 = lit(456i32);
1178        let else_value = lit(999i32);
1179
1180        let expr = generate_case_when_with_type_coercion(
1181            Some(col("a", &schema)?),
1182            vec![
1183                (Arc::clone(&when1), Arc::clone(&then1)),
1184                (Arc::clone(&when2), Arc::clone(&then2)),
1185            ],
1186            Some(Arc::clone(&else_value)),
1187            &schema,
1188        )?;
1189
1190        let expr2 = Arc::clone(&expr)
1191            .transform(|e| {
1192                let transformed = match e.as_any().downcast_ref::<Literal>() {
1193                    Some(lit_value) => match lit_value.value() {
1194                        ScalarValue::Utf8(Some(str_value)) => {
1195                            Some(lit(str_value.to_uppercase()))
1196                        }
1197                        _ => None,
1198                    },
1199                    _ => None,
1200                };
1201                Ok(if let Some(transformed) = transformed {
1202                    Transformed::yes(transformed)
1203                } else {
1204                    Transformed::no(e)
1205                })
1206            })
1207            .data()
1208            .unwrap();
1209
1210        let expr3 = Arc::clone(&expr)
1211            .transform_down(|e| {
1212                let transformed = match e.as_any().downcast_ref::<Literal>() {
1213                    Some(lit_value) => match lit_value.value() {
1214                        ScalarValue::Utf8(Some(str_value)) => {
1215                            Some(lit(str_value.to_uppercase()))
1216                        }
1217                        _ => None,
1218                    },
1219                    _ => None,
1220                };
1221                Ok(if let Some(transformed) = transformed {
1222                    Transformed::yes(transformed)
1223                } else {
1224                    Transformed::no(e)
1225                })
1226            })
1227            .data()
1228            .unwrap();
1229
1230        assert!(expr.ne(&expr2));
1231        assert!(expr2.eq(&expr3));
1232
1233        Ok(())
1234    }
1235
1236    #[test]
1237    fn test_column_or_null_specialization() -> Result<()> {
1238        // create input data
1239        let mut c1 = Int32Builder::new();
1240        let mut c2 = StringBuilder::new();
1241        for i in 0..1000 {
1242            c1.append_value(i);
1243            if i % 7 == 0 {
1244                c2.append_null();
1245            } else {
1246                c2.append_value(format!("string {i}"));
1247            }
1248        }
1249        let c1 = Arc::new(c1.finish());
1250        let c2 = Arc::new(c2.finish());
1251        let schema = Schema::new(vec![
1252            Field::new("c1", DataType::Int32, true),
1253            Field::new("c2", DataType::Utf8, true),
1254        ]);
1255        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
1256
1257        // CaseWhenExprOrNull should produce same results as CaseExpr
1258        let predicate = Arc::new(BinaryExpr::new(
1259            make_col("c1", 0),
1260            Operator::LtEq,
1261            make_lit_i32(250),
1262        ));
1263        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
1264        assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
1265        match expr.evaluate(&batch)? {
1266            ColumnarValue::Array(array) => {
1267                assert_eq!(1000, array.len());
1268                assert_eq!(785, array.null_count());
1269            }
1270            _ => unreachable!(),
1271        }
1272        Ok(())
1273    }
1274
1275    #[test]
1276    fn test_expr_or_expr_specialization() -> Result<()> {
1277        let batch = case_test_batch1()?;
1278        let schema = batch.schema();
1279        let when = binary(
1280            col("a", &schema)?,
1281            Operator::LtEq,
1282            lit(2i32),
1283            &batch.schema(),
1284        )?;
1285        let then = binary(
1286            col("a", &schema)?,
1287            Operator::Plus,
1288            lit(1i32),
1289            &batch.schema(),
1290        )?;
1291        let else_expr = binary(
1292            col("a", &schema)?,
1293            Operator::Minus,
1294            lit(1i32),
1295            &batch.schema(),
1296        )?;
1297        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
1298        assert!(matches!(
1299            expr.eval_method,
1300            EvalMethod::ExpressionOrExpression
1301        ));
1302        let result = expr
1303            .evaluate(&batch)?
1304            .into_array(batch.num_rows())
1305            .expect("Failed to convert to array");
1306        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
1307
1308        let expected = &Int32Array::from(vec![Some(2), Some(1), None, Some(4)]);
1309
1310        assert_eq!(expected, result);
1311        Ok(())
1312    }
1313
1314    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1315        Arc::new(Column::new(name, index))
1316    }
1317
1318    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1319        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1320    }
1321
1322    fn generate_case_when_with_type_coercion(
1323        expr: Option<Arc<dyn PhysicalExpr>>,
1324        when_thens: Vec<WhenThen>,
1325        else_expr: Option<Arc<dyn PhysicalExpr>>,
1326        input_schema: &Schema,
1327    ) -> Result<Arc<dyn PhysicalExpr>> {
1328        let coerce_type =
1329            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
1330        let (when_thens, else_expr) = match coerce_type {
1331            None => plan_err!(
1332                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
1333            ),
1334            Some(data_type) => {
1335                // cast then expr
1336                let left = when_thens
1337                    .into_iter()
1338                    .map(|(when, then)| {
1339                        let then = try_cast(then, input_schema, data_type.clone())?;
1340                        Ok((when, then))
1341                    })
1342                    .collect::<Result<Vec<_>>>()?;
1343                let right = match else_expr {
1344                    None => None,
1345                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
1346                };
1347
1348                Ok((left, right))
1349            }
1350        }?;
1351        case(expr, when_thens, else_expr)
1352    }
1353
1354    fn get_case_common_type(
1355        when_thens: &[WhenThen],
1356        else_expr: Option<Arc<dyn PhysicalExpr>>,
1357        input_schema: &Schema,
1358    ) -> Option<DataType> {
1359        let thens_type = when_thens
1360            .iter()
1361            .map(|when_then| {
1362                let data_type = &when_then.1.data_type(input_schema).unwrap();
1363                data_type.clone()
1364            })
1365            .collect::<Vec<_>>();
1366        let else_type = match else_expr {
1367            None => {
1368                // case when then exprs must have one then value
1369                thens_type[0].clone()
1370            }
1371            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
1372        };
1373        thens_type
1374            .iter()
1375            .try_fold(else_type, |left_type, right_type| {
1376                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
1377                // refactor again.
1378                comparison_coercion(&left_type, right_type)
1379            })
1380    }
1381}