datafusion_expr/
conditional_expressions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Conditional expressions
19use crate::expr::Case;
20use crate::{expr_schema::ExprSchemable, Expr};
21use arrow::datatypes::DataType;
22use datafusion_common::{plan_err, DFSchema, HashSet, Result};
23
24/// Helper struct for building [Expr::Case]
25pub struct CaseBuilder {
26    expr: Option<Box<Expr>>,
27    when_expr: Vec<Expr>,
28    then_expr: Vec<Expr>,
29    else_expr: Option<Box<Expr>>,
30}
31
32impl CaseBuilder {
33    pub fn new(
34        expr: Option<Box<Expr>>,
35        when_expr: Vec<Expr>,
36        then_expr: Vec<Expr>,
37        else_expr: Option<Box<Expr>>,
38    ) -> Self {
39        Self {
40            expr,
41            when_expr,
42            then_expr,
43            else_expr,
44        }
45    }
46    pub fn when(&mut self, when: Expr, then: Expr) -> CaseBuilder {
47        self.when_expr.push(when);
48        self.then_expr.push(then);
49        CaseBuilder {
50            expr: self.expr.clone(),
51            when_expr: self.when_expr.clone(),
52            then_expr: self.then_expr.clone(),
53            else_expr: self.else_expr.clone(),
54        }
55    }
56    pub fn otherwise(&mut self, else_expr: Expr) -> Result<Expr> {
57        self.else_expr = Some(Box::new(else_expr));
58        self.build()
59    }
60
61    pub fn end(&self) -> Result<Expr> {
62        self.build()
63    }
64
65    fn build(&self) -> Result<Expr> {
66        // Collect all "then" expressions
67        let mut then_expr = self.then_expr.clone();
68        if let Some(e) = &self.else_expr {
69            then_expr.push(e.as_ref().to_owned());
70        }
71
72        let then_types: Vec<DataType> = then_expr
73            .iter()
74            .map(|e| match e {
75                Expr::Literal(_) => e.get_type(&DFSchema::empty()),
76                _ => Ok(DataType::Null),
77            })
78            .collect::<Result<Vec<_>>>()?;
79
80        if then_types.contains(&DataType::Null) {
81            // Cannot verify types until execution type
82        } else {
83            let unique_types: HashSet<&DataType> = then_types.iter().collect();
84            if unique_types.len() != 1 {
85                return plan_err!(
86                    "CASE expression 'then' values had multiple data types: {unique_types:?}"
87                );
88            }
89        }
90
91        Ok(Expr::Case(Case::new(
92            self.expr.clone(),
93            self.when_expr
94                .iter()
95                .zip(self.then_expr.iter())
96                .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
97                .collect(),
98            self.else_expr.clone(),
99        )))
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit, when};

    /// Two integer literals as `then` values share one type, so building the
    /// expression must succeed.
    #[test]
    fn case_when_same_literal_then_types() -> Result<()> {
        let is_colorado = col("state").eq(lit("CO"));
        let is_new_york = col("state").eq(lit("NY"));

        when(is_colorado, lit(303))
            .when(is_new_york, lit(212))
            .end()
            .map(|_| ())
    }

    /// An integer literal and a string literal as `then` values have
    /// different types, so building the expression must fail.
    #[test]
    fn case_when_different_literal_then_types() {
        let is_colorado = col("state").eq(lit("CO"));
        let is_new_york = col("state").eq(lit("NY"));

        let outcome = when(is_colorado, lit(303))
            .when(is_new_york, lit("212"))
            .end();

        assert!(outcome.is_err());
    }
}