1use std::borrow::Cow;
19use std::hash::Hash;
20use std::{any::Any, sync::Arc};
21
22use crate::expressions::try_cast;
23use crate::PhysicalExpr;
24
25use arrow::array::*;
26use arrow::compute::kernels::zip::zip;
27use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
28use arrow::datatypes::{DataType, Schema};
29use datafusion_common::cast::as_boolean_array;
30use datafusion_common::{
31 exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
32};
33use datafusion_expr::ColumnarValue;
34
35use super::{Column, Literal};
36use datafusion_physical_expr_common::datum::compare_with_eq;
37use itertools::Itertools;
38
/// One `WHEN ... THEN ...` branch of a CASE expression: element `0` is the WHEN
/// expression and element `1` is the THEN expression.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
40
/// Evaluation strategy for a `CaseExpr`. It is chosen once in `CaseExpr::try_new`
/// and dispatched on in `PhysicalExpr::evaluate`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression
    /// (`CASE WHEN cond THEN result ... [ELSE result] END`).
    /// Used when no base expression is present and none of the single-branch
    /// specializations below apply.
    NoExpression,
    /// Form with a base expression
    /// (`CASE expr WHEN value THEN result ... [ELSE result] END`).
    /// Selected whenever a base expression is present.
    WithExpression,
    /// Specialization for exactly one WHEN/THEN branch whose THEN expression is a
    /// plain `Column` and which has no ELSE expression, i.e.
    /// `CASE WHEN cond THEN column END`.
    InfallibleExprOrNull,
    /// Specialization for exactly one WHEN/THEN branch where both the THEN and the
    /// ELSE expressions are `Literal`s.
    ScalarOrScalar,
    /// Specialization for exactly one WHEN/THEN branch with an ELSE expression
    /// that did not qualify for `ScalarOrScalar`.
    ExpressionOrExpression,
}
71
/// Physical expression for SQL `CASE`.
///
/// Rendered by its `Display` implementation as
/// `CASE [expr] WHEN w THEN t ... [ELSE e] END`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// Optional base expression that each WHEN value is compared against.
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// The WHEN/THEN branches, in evaluation order. `try_new` rejects an empty list.
    when_then_expr: Vec<WhenThen>,
    /// Optional ELSE expression. `try_new` stores `None` for a literal NULL.
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    /// Evaluation strategy selected by `try_new`.
    eval_method: EvalMethod,
}
100
101impl std::fmt::Display for CaseExpr {
102 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
103 write!(f, "CASE ")?;
104 if let Some(e) = &self.expr {
105 write!(f, "{e} ")?;
106 }
107 for (w, t) in &self.when_then_expr {
108 write!(f, "WHEN {w} THEN {t} ")?;
109 }
110 if let Some(e) = &self.else_expr {
111 write!(f, "ELSE {e} ")?;
112 }
113 write!(f, "END")
114 }
115}
116
117fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
123 expr.as_any().is::<Column>()
124}
125
126impl CaseExpr {
127 pub fn try_new(
129 expr: Option<Arc<dyn PhysicalExpr>>,
130 when_then_expr: Vec<WhenThen>,
131 else_expr: Option<Arc<dyn PhysicalExpr>>,
132 ) -> Result<Self> {
133 let else_expr = match &else_expr {
136 Some(e) => match e.as_any().downcast_ref::<Literal>() {
137 Some(lit) if lit.value().is_null() => None,
138 _ => else_expr,
139 },
140 _ => else_expr,
141 };
142
143 if when_then_expr.is_empty() {
144 exec_err!("There must be at least one WHEN clause")
145 } else {
146 let eval_method = if expr.is_some() {
147 EvalMethod::WithExpression
148 } else if when_then_expr.len() == 1
149 && is_cheap_and_infallible(&(when_then_expr[0].1))
150 && else_expr.is_none()
151 {
152 EvalMethod::InfallibleExprOrNull
153 } else if when_then_expr.len() == 1
154 && when_then_expr[0].1.as_any().is::<Literal>()
155 && else_expr.is_some()
156 && else_expr.as_ref().unwrap().as_any().is::<Literal>()
157 {
158 EvalMethod::ScalarOrScalar
159 } else if when_then_expr.len() == 1 && else_expr.is_some() {
160 EvalMethod::ExpressionOrExpression
161 } else {
162 EvalMethod::NoExpression
163 };
164
165 Ok(Self {
166 expr,
167 when_then_expr,
168 else_expr,
169 eval_method,
170 })
171 }
172 }
173
174 pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
176 self.expr.as_ref()
177 }
178
179 pub fn when_then_expr(&self) -> &[WhenThen] {
181 &self.when_then_expr
182 }
183
184 pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
186 self.else_expr.as_ref()
187 }
188}
189
190impl CaseExpr {
191 fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
199 let return_type = self.data_type(&batch.schema())?;
200 let expr = self.expr.as_ref().unwrap();
201 let base_value = expr.evaluate(batch)?;
202 let base_value = base_value.into_array(batch.num_rows())?;
203 let base_nulls = is_null(base_value.as_ref())?;
204
205 let mut current_value = new_null_array(&return_type, batch.num_rows());
207 let mut remainder = not(&base_nulls)?;
209 for i in 0..self.when_then_expr.len() {
210 let when_value = self.when_then_expr[i]
211 .0
212 .evaluate_selection(batch, &remainder)?;
213 let when_value = when_value.into_array(batch.num_rows())?;
214 let when_match = compare_with_eq(
216 &when_value,
217 &base_value,
218 base_value.data_type().is_nested(),
221 )?;
222 let when_match = match when_match.null_count() {
224 0 => Cow::Borrowed(&when_match),
225 _ => Cow::Owned(prep_null_mask_filter(&when_match)),
226 };
227 let when_match = and(&when_match, &remainder)?;
229
230 if when_match.true_count() == 0 {
232 continue;
233 }
234
235 let then_value = self.when_then_expr[i]
236 .1
237 .evaluate_selection(batch, &when_match)?;
238
239 current_value = match then_value {
240 ColumnarValue::Scalar(ScalarValue::Null) => {
241 nullif(current_value.as_ref(), &when_match)?
242 }
243 ColumnarValue::Scalar(then_value) => {
244 zip(&when_match, &then_value.to_scalar()?, ¤t_value)?
245 }
246 ColumnarValue::Array(then_value) => {
247 zip(&when_match, &then_value, ¤t_value)?
248 }
249 };
250
251 remainder = and_not(&remainder, &when_match)?;
252 }
253
254 if let Some(e) = self.else_expr() {
255 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
257 remainder = or(&base_nulls, &remainder)?;
259 let else_ = expr
260 .evaluate_selection(batch, &remainder)?
261 .into_array(batch.num_rows())?;
262 current_value = zip(&remainder, &else_, ¤t_value)?;
263 }
264
265 Ok(ColumnarValue::Array(current_value))
266 }
267
268 fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
276 let return_type = self.data_type(&batch.schema())?;
277
278 let mut current_value = new_null_array(&return_type, batch.num_rows());
280 let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
281 for i in 0..self.when_then_expr.len() {
282 let when_value = self.when_then_expr[i]
283 .0
284 .evaluate_selection(batch, &remainder)?;
285 let when_value = when_value.into_array(batch.num_rows())?;
286 let when_value = as_boolean_array(&when_value).map_err(|_| {
287 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
288 })?;
289 let when_value = match when_value.null_count() {
291 0 => Cow::Borrowed(when_value),
292 _ => Cow::Owned(prep_null_mask_filter(when_value)),
293 };
294 let when_value = and(&when_value, &remainder)?;
296
297 if when_value.true_count() == 0 {
299 continue;
300 }
301
302 let then_value = self.when_then_expr[i]
303 .1
304 .evaluate_selection(batch, &when_value)?;
305
306 current_value = match then_value {
307 ColumnarValue::Scalar(ScalarValue::Null) => {
308 nullif(current_value.as_ref(), &when_value)?
309 }
310 ColumnarValue::Scalar(then_value) => {
311 zip(&when_value, &then_value.to_scalar()?, ¤t_value)?
312 }
313 ColumnarValue::Array(then_value) => {
314 zip(&when_value, &then_value, ¤t_value)?
315 }
316 };
317
318 remainder = and_not(&remainder, &when_value)?;
321 }
322
323 if let Some(e) = self.else_expr() {
324 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
326 let else_ = expr
327 .evaluate_selection(batch, &remainder)?
328 .into_array(batch.num_rows())?;
329 current_value = zip(&remainder, &else_, ¤t_value)?;
330 }
331
332 Ok(ColumnarValue::Array(current_value))
333 }
334
335 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
345 let when_expr = &self.when_then_expr[0].0;
346 let then_expr = &self.when_then_expr[0].1;
347
348 match when_expr.evaluate(batch)? {
349 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
351 then_expr.evaluate(batch)
352 }
353 ColumnarValue::Scalar(_) => {
355 ScalarValue::try_from(self.data_type(&batch.schema())?)
357 .map(ColumnarValue::Scalar)
358 }
359 ColumnarValue::Array(bit_mask) => {
361 let bit_mask = bit_mask
362 .as_any()
363 .downcast_ref::<BooleanArray>()
364 .expect("predicate should evaluate to a boolean array");
365 let bit_mask = match bit_mask.null_count() {
367 0 => not(bit_mask)?,
368 _ => not(&prep_null_mask_filter(bit_mask))?,
369 };
370 match then_expr.evaluate(batch)? {
371 ColumnarValue::Array(array) => {
372 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
373 }
374 ColumnarValue::Scalar(_) => {
375 internal_err!("expression did not evaluate to an array")
376 }
377 }
378 }
379 }
380 }
381
382 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
383 let return_type = self.data_type(&batch.schema())?;
384
385 let when_value = self.when_then_expr[0].0.evaluate(batch)?;
387 let when_value = when_value.into_array(batch.num_rows())?;
388 let when_value = as_boolean_array(&when_value).map_err(|_| {
389 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
390 })?;
391
392 let when_value = match when_value.null_count() {
394 0 => Cow::Borrowed(when_value),
395 _ => Cow::Owned(prep_null_mask_filter(when_value)),
396 };
397
398 let then_value = self.when_then_expr[0].1.evaluate(batch)?;
400 let then_value = Scalar::new(then_value.into_array(1)?);
401
402 let Some(e) = self.else_expr() else {
403 return internal_err!("expression did not evaluate to an array");
404 };
405 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
407 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
408 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
409 }
410
411 fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
412 let return_type = self.data_type(&batch.schema())?;
413
414 let when_value = self.when_then_expr[0].0.evaluate(batch)?;
416 let when_value = when_value.into_array(batch.num_rows())?;
417 let when_value = as_boolean_array(&when_value).map_err(|e| {
418 DataFusionError::Context(
419 "WHEN expression did not return a BooleanArray".to_string(),
420 Box::new(e),
421 )
422 })?;
423
424 let when_value = match when_value.null_count() {
426 0 => Cow::Borrowed(when_value),
427 _ => Cow::Owned(prep_null_mask_filter(when_value)),
428 };
429
430 let then_value = self.when_then_expr[0]
431 .1
432 .evaluate_selection(batch, &when_value)?
433 .into_array(batch.num_rows())?;
434
435 let remainder = not(&when_value)?;
437 let e = self.else_expr.as_ref().unwrap();
438 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
440 .unwrap_or_else(|_| Arc::clone(e));
441 let else_ = expr
442 .evaluate_selection(batch, &remainder)?
443 .into_array(batch.num_rows())?;
444
445 Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
446 }
447}
448
449impl PhysicalExpr for CaseExpr {
450 fn as_any(&self) -> &dyn Any {
452 self
453 }
454
455 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
456 let mut data_type = DataType::Null;
459 for i in 0..self.when_then_expr.len() {
460 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
461 if !data_type.equals_datatype(&DataType::Null) {
462 break;
463 }
464 }
465 if data_type.equals_datatype(&DataType::Null) {
467 if let Some(e) = &self.else_expr {
468 data_type = e.data_type(input_schema)?;
469 }
470 }
471
472 Ok(data_type)
473 }
474
475 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
476 let then_nullable = self
478 .when_then_expr
479 .iter()
480 .map(|(_, t)| t.nullable(input_schema))
481 .collect::<Result<Vec<_>>>()?;
482 if then_nullable.contains(&true) {
483 Ok(true)
484 } else if let Some(e) = &self.else_expr {
485 e.nullable(input_schema)
486 } else {
487 Ok(true)
490 }
491 }
492
493 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
494 match self.eval_method {
495 EvalMethod::WithExpression => {
496 self.case_when_with_expr(batch)
499 }
500 EvalMethod::NoExpression => {
501 self.case_when_no_expr(batch)
504 }
505 EvalMethod::InfallibleExprOrNull => {
506 self.case_column_or_null(batch)
508 }
509 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
510 EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
511 }
512 }
513
514 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
515 let mut children = vec![];
516 if let Some(expr) = &self.expr {
517 children.push(expr)
518 }
519 self.when_then_expr.iter().for_each(|(cond, value)| {
520 children.push(cond);
521 children.push(value);
522 });
523
524 if let Some(else_expr) = &self.else_expr {
525 children.push(else_expr)
526 }
527 children
528 }
529
530 fn with_new_children(
532 self: Arc<Self>,
533 children: Vec<Arc<dyn PhysicalExpr>>,
534 ) -> Result<Arc<dyn PhysicalExpr>> {
535 if children.len() != self.children().len() {
536 internal_err!("CaseExpr: Wrong number of children")
537 } else {
538 let (expr, when_then_expr, else_expr) =
539 match (self.expr().is_some(), self.else_expr().is_some()) {
540 (true, true) => (
541 Some(&children[0]),
542 &children[1..children.len() - 1],
543 Some(&children[children.len() - 1]),
544 ),
545 (true, false) => {
546 (Some(&children[0]), &children[1..children.len()], None)
547 }
548 (false, true) => (
549 None,
550 &children[0..children.len() - 1],
551 Some(&children[children.len() - 1]),
552 ),
553 (false, false) => (None, &children[0..children.len()], None),
554 };
555 Ok(Arc::new(CaseExpr::try_new(
556 expr.cloned(),
557 when_then_expr.iter().cloned().tuples().collect(),
558 else_expr.cloned(),
559 )?))
560 }
561 }
562}
563
564pub fn case(
566 expr: Option<Arc<dyn PhysicalExpr>>,
567 when_thens: Vec<WhenThen>,
568 else_expr: Option<Arc<dyn PhysicalExpr>>,
569) -> Result<Arc<dyn PhysicalExpr>> {
570 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
571}
572
573#[cfg(test)]
574mod tests {
575 use super::*;
576
577 use crate::expressions::{binary, cast, col, lit, BinaryExpr};
578 use arrow::buffer::Buffer;
579 use arrow::datatypes::DataType::Float64;
580 use arrow::datatypes::*;
581 use datafusion_common::cast::{as_float64_array, as_int32_array};
582 use datafusion_common::plan_err;
583 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
584 use datafusion_expr::type_coercion::binary::comparison_coercion;
585 use datafusion_expr::Operator;
586
#[test]
fn case_with_expr() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array)?;

    let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
616
#[test]
fn case_with_expr_else() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let fallback = lit(999i32);
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array)?;

    let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
648
#[test]
fn case_with_expr_divide_by_zero() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE a WHEN 0 THEN NULL ELSE 25.0 / cast(a, float64) END
    let zero = lit(0i32);
    let null_f64 = lit(ScalarValue::Float64(None));
    let a_as_f64 = cast(col("a", &schema)?, &schema, Float64)?;
    let quotient = binary(lit(25.0f64), Operator::Divide, a_as_f64, &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![(zero, null_f64)],
        Some(quotient),
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
683
#[test]
fn case_without_expr() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let branches = vec![(is_foo, lit(123i32)), (is_bar, lit(456i32))];

    let case_expr =
        generate_case_when_with_type_coercion(None, branches, None, schema.as_ref())?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array)?;

    let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
723
#[test]
fn case_with_expr_when_null() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
    let null_utf8 = lit(ScalarValue::Utf8(None));
    let branches = vec![(null_utf8, lit(0i32)), (col("a", &schema)?, lit(123i32))];
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array)?;

    let expected = Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
    assert_eq!(&expected, actual);

    Ok(())
}
755
#[test]
fn case_without_expr_divide_by_zero() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END
    let is_positive = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    let a_as_f64 = cast(col("a", &schema)?, &schema, Float64)?;
    let quotient = binary(lit(25.0f64), Operator::Divide, a_as_f64, &schema)?;
    let null_f64 = lit(ScalarValue::Float64(None));

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_positive, quotient)],
        Some(null_f64),
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
790
791 fn case_test_batch1() -> Result<RecordBatch> {
792 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
793 let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
794 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
795 Ok(batch)
796 }
797
#[test]
fn case_without_expr_else() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let branches = vec![(is_foo, lit(123i32)), (is_bar, lit(456i32))];
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        None,
        branches,
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array)?;

    let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
839
#[test]
fn case_with_type_cast() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let float_result = lit(123.3f64);
    let int_fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, float_result)],
        Some(int_fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

    // The Int32 ELSE value is coerced to Float64 to match the THEN value.
    let expected =
        Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
875
#[test]
fn case_with_matches_and_nulls() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    // CASE WHEN load4 = 1.77 THEN load4 END
    let matches_value =
        binary(col("load4", &schema)?, Operator::Eq, lit(1.77f64), &schema)?;
    let load4 = col("load4", &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(matches_value, load4)],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
910
#[test]
fn case_with_scalar_predicate() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    // CASE WHEN TRUE THEN load4 END
    let always_true = lit(true);
    let load4 = col("load4", &schema)?;
    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(always_true, load4)],
        None,
        schema.as_ref(),
    )?;

    // Many-row batch: the column passes through unchanged.
    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, Some(1.78), None, Some(1.77)]);
    assert_eq!(&expected, actual);

    // Single-row batch with the same schema.
    let expected = Float64Array::from(vec![Some(1.1)]);
    let single_row =
        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
    let evaluated = case_expr.evaluate(&single_row)?;
    let array = evaluated
        .into_array(single_row.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");
    assert_eq!(&expected, actual);

    Ok(())
}
957
#[test]
fn case_expr_matches_and_nulls() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    // CASE load4 WHEN 1.77 THEN load4 END
    let base = col("load4", &schema)?;
    let branches = vec![(lit(1.77f64), col("load4", &schema)?)];

    let case_expr = generate_case_when_with_type_coercion(
        Some(base),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array).expect("failed to downcast to Float64Array");

    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
988
#[test]
fn test_when_null_and_some_cond_else_null() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE WHEN (NULL AND a = 'foo') THEN a END
    let null_bool = Arc::new(Literal::new(ScalarValue::Boolean(None)));
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let condition = binary(null_bool, Operator::And, is_foo, &schema)?;
    let a = col("a", &schema)?;

    let case_expr = Arc::new(CaseExpr::try_new(None, vec![(condition, a)], None)?);

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let strings = as_string_array(&array);

    // The condition is never true, so every output row must be NULL.
    assert_eq!(strings.logical_null_count(), batch.num_rows());
    Ok(())
}
1014
1015 fn case_test_batch() -> Result<RecordBatch> {
1016 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1017 let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1018 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1019 Ok(batch)
1020 }
1021
1022 fn case_test_batch_nulls() -> Result<RecordBatch> {
1025 let load4: Float64Array = vec![
1026 Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
1033 .into_iter()
1034 .collect();
1035
1036 let null_buffer = Buffer::from([0b00101001u8]);
1038 let load4 = load4
1039 .into_data()
1040 .into_builder()
1041 .null_bit_buffer(Some(null_buffer))
1042 .build()
1043 .unwrap();
1044 let load4: Float64Array = load4.into();
1045
1046 let batch =
1047 RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1048 Ok(batch)
1049 }
1050
#[test]
fn case_test_incompatible() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
    // Int32 and Boolean THEN values: building the expression must fail.
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let incompatible = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, lit(123i32)), (is_bar, lit(true))],
        None,
        schema.as_ref(),
    );
    assert!(incompatible.is_err());

    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
    // Int32, Int64 and Float64 results: the expression builds and yields Float64.
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let compatible = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, lit(123i32)), (is_bar, lit(456i64))],
        Some(lit(1.23f64)),
        schema.as_ref(),
    );
    assert!(compatible.is_ok());
    let result_type = compatible.unwrap().data_type(schema.as_ref())?;
    assert_eq!(Float64, result_type);
    Ok(())
}
1113
#[test]
fn case_eq() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let (when_foo, then_123) = (lit("foo"), lit(123i32));
    let (when_bar, then_456) = (lit("bar"), lit(456i32));
    let fallback = lit(999i32);

    // expr1 and expr2: identical two-branch CASE with ELSE.
    let expr1 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![
            (Arc::clone(&when_foo), Arc::clone(&then_123)),
            (Arc::clone(&when_bar), Arc::clone(&then_456)),
        ],
        Some(Arc::clone(&fallback)),
        &schema,
    )?;
    let expr2 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![
            (Arc::clone(&when_foo), Arc::clone(&then_123)),
            (Arc::clone(&when_bar), Arc::clone(&then_456)),
        ],
        Some(Arc::clone(&fallback)),
        &schema,
    )?;

    // expr3: same branches, but no ELSE.
    let expr3 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![
            (Arc::clone(&when_foo), Arc::clone(&then_123)),
            (when_bar, then_456),
        ],
        None,
        &schema,
    )?;

    // expr4: only the first branch, with ELSE.
    let expr4 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![(when_foo, then_123)],
        Some(fallback),
        &schema,
    )?;

    // Equality is checked in both directions.
    assert!(expr1.eq(&expr2));
    assert!(expr2.eq(&expr1));

    assert!(expr2.ne(&expr3));
    assert!(expr3.ne(&expr2));

    assert!(expr1.ne(&expr4));
    assert!(expr4.ne(&expr1));

    Ok(())
}
1169
#[test]
fn case_transform() -> Result<()> {
    /// Replaces a non-null Utf8 literal with its upper-cased copy and leaves
    /// every other expression untouched.
    fn uppercase_utf8_literal(
        e: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        let replacement =
            e.as_any()
                .downcast_ref::<Literal>()
                .and_then(|lit_value| match lit_value.value() {
                    ScalarValue::Utf8(Some(str_value)) => {
                        Some(lit(str_value.to_uppercase()))
                    }
                    _ => None,
                });
        Ok(match replacement {
            Some(transformed) => Transformed::yes(transformed),
            None => Transformed::no(e),
        })
    }

    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(lit(999i32)),
        &schema,
    )?;

    let expr2 = Arc::clone(&expr)
        .transform(uppercase_utf8_literal)
        .data()
        .unwrap();

    let expr3 = Arc::clone(&expr)
        .transform_down(uppercase_utf8_literal)
        .data()
        .unwrap();

    // The rewrite changes the expression, and both traversals agree on the result.
    assert!(expr.ne(&expr2));
    assert!(expr2.eq(&expr3));

    Ok(())
}
1235
#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // c1 holds 0..1000; c2 holds one string per row, with NULL on every 7th row.
    let mut int_builder = Int32Builder::new();
    let mut str_builder = StringBuilder::new();
    for value in 0..1000 {
        int_builder.append_value(value);
        if value % 7 == 0 {
            str_builder.append_null();
        } else {
            str_builder.append_value(format!("string {value}"));
        }
    }
    let columns: Vec<ArrayRef> = vec![
        Arc::new(int_builder.finish()),
        Arc::new(str_builder.finish()),
    ];
    let schema = Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ]);
    let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();

    // CASE WHEN c1 <= 250 THEN c2 END
    let predicate = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
    assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));

    let ColumnarValue::Array(array) = expr.evaluate(&batch)? else {
        unreachable!()
    };
    assert_eq!(1000, array.len());
    assert_eq!(785, array.null_count());
    Ok(())
}
1274
#[test]
fn test_expr_or_expr_specialization() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE WHEN a <= 2 THEN a + 1 ELSE a - 1 END
    let at_most_two = binary(col("a", &schema)?, Operator::LtEq, lit(2i32), &schema)?;
    let plus_one = binary(col("a", &schema)?, Operator::Plus, lit(1i32), &schema)?;
    let minus_one = binary(col("a", &schema)?, Operator::Minus, lit(1i32), &schema)?;

    let expr =
        CaseExpr::try_new(None, vec![(at_most_two, plus_one)], Some(minus_one))?;
    assert!(matches!(
        expr.eval_method,
        EvalMethod::ExpressionOrExpression
    ));

    let evaluated = expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&array).expect("failed to downcast to Int32Array");

    let expected = Int32Array::from(vec![Some(2), Some(1), None, Some(4)]);
    assert_eq!(&expected, actual);
    Ok(())
}
1313
1314 fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1315 Arc::new(Column::new(name, index))
1316 }
1317
1318 fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1319 Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1320 }
1321
1322 fn generate_case_when_with_type_coercion(
1323 expr: Option<Arc<dyn PhysicalExpr>>,
1324 when_thens: Vec<WhenThen>,
1325 else_expr: Option<Arc<dyn PhysicalExpr>>,
1326 input_schema: &Schema,
1327 ) -> Result<Arc<dyn PhysicalExpr>> {
1328 let coerce_type =
1329 get_case_common_type(&when_thens, else_expr.clone(), input_schema);
1330 let (when_thens, else_expr) = match coerce_type {
1331 None => plan_err!(
1332 "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
1333 ),
1334 Some(data_type) => {
1335 let left = when_thens
1337 .into_iter()
1338 .map(|(when, then)| {
1339 let then = try_cast(then, input_schema, data_type.clone())?;
1340 Ok((when, then))
1341 })
1342 .collect::<Result<Vec<_>>>()?;
1343 let right = match else_expr {
1344 None => None,
1345 Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
1346 };
1347
1348 Ok((left, right))
1349 }
1350 }?;
1351 case(expr, when_thens, else_expr)
1352 }
1353
1354 fn get_case_common_type(
1355 when_thens: &[WhenThen],
1356 else_expr: Option<Arc<dyn PhysicalExpr>>,
1357 input_schema: &Schema,
1358 ) -> Option<DataType> {
1359 let thens_type = when_thens
1360 .iter()
1361 .map(|when_then| {
1362 let data_type = &when_then.1.data_type(input_schema).unwrap();
1363 data_type.clone()
1364 })
1365 .collect::<Vec<_>>();
1366 let else_type = match else_expr {
1367 None => {
1368 thens_type[0].clone()
1370 }
1371 Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
1372 };
1373 thens_type
1374 .iter()
1375 .try_fold(else_type, |left_type, right_type| {
1376 comparison_coercion(&left_type, right_type)
1379 })
1380 }
1381}