use crate::expr::Case;
use crate::{expr_schema::ExprSchemable, Expr};
use arrow::datatypes::DataType;
use datafusion_common::{plan_err, DFSchema, Result};
use std::collections::HashSet;
/// Incrementally builds an [`Expr::Case`] expression.
///
/// `when_expr[i]` and `then_expr[i]` together form the i-th
/// `WHEN ... THEN ...` branch; `expr` is the optional `CASE` operand and
/// `else_expr` the optional `ELSE` value. The final expression is produced by
/// [`CaseBuilder::end`] or [`CaseBuilder::otherwise`].
#[derive(Debug, Clone)]
pub struct CaseBuilder {
    /// Optional `CASE` operand, forwarded unchanged to the built expression.
    expr: Option<Box<Expr>>,
    /// `WHEN` expressions, positionally paired with `then_expr`.
    when_expr: Vec<Expr>,
    /// `THEN` expressions, positionally paired with `when_expr`.
    then_expr: Vec<Expr>,
    /// Optional `ELSE` value.
    else_expr: Option<Box<Expr>>,
}
impl CaseBuilder {
    /// Creates a builder from its raw parts.
    ///
    /// * `expr` - optional `CASE` operand, forwarded unchanged to [`Case::new`].
    /// * `when_expr` / `then_expr` - `WHEN` and `THEN` expressions, paired by
    ///   position when the expression is built. Both vectors must have the same
    ///   length; this is validated by [`Self::end`] / [`Self::otherwise`].
    /// * `else_expr` - optional `ELSE` value.
    pub fn new(
        expr: Option<Box<Expr>>,
        when_expr: Vec<Expr>,
        then_expr: Vec<Expr>,
        else_expr: Option<Box<Expr>>,
    ) -> Self {
        Self {
            expr,
            when_expr,
            then_expr,
            else_expr,
        }
    }

    /// Appends a `WHEN when THEN then` branch.
    ///
    /// The branch is recorded on `self`, and a copy of the updated builder is
    /// returned so calls can be chained.
    pub fn when(&mut self, when: Expr, then: Expr) -> CaseBuilder {
        self.when_expr.push(when);
        self.then_expr.push(then);
        CaseBuilder {
            expr: self.expr.clone(),
            when_expr: self.when_expr.clone(),
            then_expr: self.then_expr.clone(),
            else_expr: self.else_expr.clone(),
        }
    }

    /// Sets the `ELSE` value (replacing any previous one) and builds the
    /// `CASE` expression.
    ///
    /// # Errors
    /// Same conditions as [`Self::end`].
    pub fn otherwise(&mut self, else_expr: Expr) -> Result<Expr> {
        self.else_expr = Some(Box::new(else_expr));
        self.build()
    }

    /// Builds the `CASE` expression from the branches collected so far, keeping
    /// whatever `ELSE` value (if any) is already set.
    ///
    /// # Errors
    /// Returns a planning error if there are no `WHEN` branches, if the number
    /// of `WHEN` and `THEN` expressions differs, or if every `THEN`/`ELSE`
    /// value is a non-null literal and those literals do not all share one
    /// data type.
    pub fn end(&self) -> Result<Expr> {
        self.build()
    }

    /// Validates the collected parts and assembles the [`Expr::Case`].
    fn build(&self) -> Result<Expr> {
        self.check_arity()?;
        self.check_literal_result_types()?;
        Ok(Expr::Case(Case::new(
            self.expr.clone(),
            self.when_expr
                .iter()
                .zip(self.then_expr.iter())
                .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
                .collect(),
            self.else_expr.clone(),
        )))
    }

    /// Ensures every `WHEN` has a matching `THEN` and that there is at least
    /// one branch.
    ///
    /// Without the length check, the positional `zip` in [`Self::build`]
    /// would silently drop unmatched expressions.
    fn check_arity(&self) -> Result<()> {
        let when_count = self.when_expr.len();
        let then_count = self.then_expr.len();
        if when_count != then_count {
            return plan_err!(
                "CASE expression has {when_count} WHEN expressions but {then_count} THEN expressions"
            );
        }
        if when_count == 0 {
            return plan_err!("CASE expression must have at least one WHEN clause");
        }
        Ok(())
    }

    /// Rejects `CASE` expressions whose `THEN`/`ELSE` values are all literals
    /// of differing data types.
    ///
    /// Only literals can be typed here, because they are evaluated against an
    /// empty schema. The check is therefore skipped entirely in two cases:
    /// some result value is not a literal, or some literal has type
    /// [`DataType::Null`].
    fn check_literal_result_types(&self) -> Result<()> {
        let results = self.then_expr.iter().chain(self.else_expr.as_deref());
        let mut literal_types: HashSet<DataType> = HashSet::new();
        for result in results {
            if !matches!(result, Expr::Literal(_)) {
                // Type depends on an input schema that is not available here.
                return Ok(());
            }
            let data_type = result.get_type(&DFSchema::empty())?;
            if data_type == DataType::Null {
                return Ok(());
            }
            literal_types.insert(data_type);
        }
        if literal_types.len() > 1 {
            return plan_err!(
                "CASE expression 'then' values had multiple data types: {literal_types:?}"
            );
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit, when};

    /// Same-typed literal `THEN` values build a two-branch CASE without ELSE.
    #[test]
    fn case_when_same_literal_then_types() -> Result<()> {
        let expr = when(col("state").eq(lit("CO")), lit(303))
            .when(col("state").eq(lit("NY")), lit(212))
            .end()?;
        match expr {
            Expr::Case(case) => {
                assert!(case.expr.is_none());
                assert_eq!(case.when_then_expr.len(), 2);
                assert!(case.else_expr.is_none());
            }
            other => panic!("expected a CASE expression, got {other:?}"),
        }
        Ok(())
    }

    /// Literal `THEN` values of different types are rejected.
    #[test]
    fn case_when_different_literal_then_types() {
        let maybe_expr = when(col("state").eq(lit("CO")), lit(303))
            .when(col("state").eq(lit("NY")), lit("212"))
            .end();
        assert!(maybe_expr.is_err());
    }

    /// An `ELSE` literal of the same type as the `THEN` literals is accepted
    /// and recorded on the built expression.
    #[test]
    fn case_otherwise_same_literal_type() -> Result<()> {
        let expr = when(col("state").eq(lit("CO")), lit(303)).otherwise(lit(0))?;
        match expr {
            Expr::Case(case) => {
                assert_eq!(case.when_then_expr.len(), 1);
                assert_eq!(case.else_expr, Some(Box::new(lit(0))));
            }
            other => panic!("expected a CASE expression, got {other:?}"),
        }
        Ok(())
    }

    /// An `ELSE` literal whose type differs from the `THEN` literals is rejected.
    #[test]
    fn case_otherwise_different_literal_type() {
        let maybe_expr =
            when(col("state").eq(lit("CO")), lit(303)).otherwise(lit("unknown"));
        assert!(maybe_expr.is_err());
    }

    /// A non-literal `THEN` value disables the literal type check.
    #[test]
    fn case_non_literal_then_skips_type_check() -> Result<()> {
        let expr = when(col("state").eq(lit("CO")), col("area_code"))
            .when(col("state").eq(lit("NY")), lit("212"))
            .end()?;
        assert!(matches!(expr, Expr::Case(_)));
        Ok(())
    }
}