use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess,
GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort,
TryCast, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
use crate::{LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
internal_err, plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema,
Result,
};
use std::collections::HashMap;
use std::sync::Arc;
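
/// Trait that resolves an [`Expr`] against a schema: its [`DataType`],
/// nullability, metadata, and the [`DFField`] it produces.
///
/// A minimal usage sketch (assuming the usual `datafusion_expr` re-exports of
/// `col` and this trait), resolving a column's type against a [`DFSchema`]:
///
/// ```
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DFField, DFSchema};
/// # use datafusion_expr::{col, ExprSchemable};
/// # use std::collections::HashMap;
/// let schema = DFSchema::new_with_metadata(
///     vec![DFField::new_unqualified("c1", DataType::Utf8, true)],
///     HashMap::new(),
/// )
/// .unwrap();
/// assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
/// ```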
pub trait ExprSchemable {
    /// Given a schema, return the type of the expr
    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
    /// Given a schema, return the nullability of the expr
    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
    /// Given a schema, return the expr's optional metadata
    fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>>;
    /// Convert to a field with respect to a schema
    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
    /// Cast to a type with respect to a schema
    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
}
impl ExprSchemable for Expr {
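    /// Returns the [arrow::datatypes::DataType] of the expression based on
    /// [ExprSchema].
    ///
    /// Note: [DFSchema] implements [ExprSchema].
    ///
    /// # Errors
    ///
    /// Errors when the type cannot be computed, e.g. when the expression
    /// refers to a column that does not exist in the schema or is incorrectly
    /// typed (such as `[utf8] + [bool]`).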
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
None => schema.data_type(&Column::from_name(name)).cloned(),
Some(dt) => Ok(dt.clone()),
},
_ => expr.get_type(schema),
},
Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema),
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
}
Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
}
Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
Expr::Not(_)
| Expr::IsNull(_)
| Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::Between { .. }
| Expr::InList { .. }
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).data_type().clone())
}
Expr::BinaryExpr(BinaryExpr {
ref left,
ref right,
ref op,
}) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?),
Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
Expr::Placeholder(Placeholder { data_type, .. }) => {
data_type.clone().ok_or_else(|| {
DataFusionError::Plan(
"Placeholder type could not be resolved".to_owned(),
)
})
}
Expr::Wildcard => {
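                // Wildcard does not really have a type and does not appear in projections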
Ok(DataType::Null)
}
Expr::QualifiedWildcard { .. } => internal_err!(
"QualifiedWildcard expressions are not valid in a logical query plan"
),
Expr::GroupingSet(_) => {
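                // Grouping sets do not really have a type and do not appear in projections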
Ok(DataType::Null)
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
field_for_index(expr, field, schema).map(|x| x.data_type().clone())
}
}
}
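    /// Returns the nullability of the expression based on [ExprSchema].
    ///
    /// # Errors
    ///
    /// Errors when nullability cannot be computed, e.g. when the expression
    /// refers to a column that does not exist in the schema.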
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
match self {
Expr::Alias(Alias { expr, .. })
| Expr::Not(expr)
| Expr::Negative(expr)
| Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema),
            Expr::InList(InList { expr, list, .. }) => {
                // Avoid inspecting too many expressions.
                const MAX_INSPECT_LIMIT: usize = 6;
                // Stop if a nullable expression is found or an error occurs.
                let has_nullable = std::iter::once(expr.as_ref())
                    .chain(list)
                    .take(MAX_INSPECT_LIMIT)
                    .find_map(|e| {
                        e.nullable(input_schema)
                            .map(|nullable| if nullable { Some(()) } else { None })
                            .transpose()
                    })
                    .transpose()?;
                Ok(match has_nullable {
                    // If a nullable subexpression is found, the expression may be nullable.
                    Some(_) => true,
                    // If the list is too long to inspect fully, conservatively assume nullable.
                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
                    // All inspected subexpressions are non-nullable.
                    _ => false,
                })
            }
Expr::Between(Between {
expr, low, high, ..
}) => Ok(expr.nullable(input_schema)?
|| low.nullable(input_schema)?
|| high.nullable(input_schema)?),
Expr::Column(c) => input_schema.nullable(c),
Expr::OuterReferenceColumn(_, _) => Ok(true),
Expr::Literal(value) => Ok(value.is_null()),
Expr::Case(case) => {
let then_nullable = case
.when_then_expr
.iter()
.map(|(_, t)| t.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
if then_nullable.contains(&true) {
Ok(true)
} else if let Some(e) = &case.else_expr {
e.nullable(input_schema)
} else {
Ok(true)
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
| Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::AggregateUDF { .. }
| Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Exists { .. } => Ok(false),
Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).is_nullable())
}
Expr::BinaryExpr(BinaryExpr {
ref left,
ref right,
..
}) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
Expr::Like(Like { expr, pattern, .. })
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
}
Expr::Wildcard => internal_err!(
"Wildcard expressions are not valid in a logical query plan"
),
Expr::QualifiedWildcard { .. } => internal_err!(
"QualifiedWildcard expressions are not valid in a logical query plan"
),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
field_for_index(expr, field, input_schema).map(|x| x.is_nullable())
}
Expr::GroupingSet(_) => {
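                // Grouping sets do not really have the concept of nullability and
                // do not appear in projections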
Ok(true)
}
}
}
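    /// Returns the metadata of the expression. Columns (and aliases of
    /// columns) resolve their metadata through the given [ExprSchema]; all
    /// other expressions return empty metadata.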
fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>> {
match self {
Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
_ => Ok(HashMap::new()),
}
}
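    /// Returns a [DFField] describing this expression (name, data type,
    /// nullability and metadata) resolved against `input_schema`.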
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
match self {
Expr::Column(c) => Ok(DFField::new(
c.relation.clone(),
&c.name,
self.get_type(input_schema)?,
self.nullable(input_schema)?,
)
.with_metadata(self.metadata(input_schema)?)),
_ => Ok(DFField::new_unqualified(
&self.display_name()?,
self.get_type(input_schema)?,
self.nullable(input_schema)?,
)
.with_metadata(self.metadata(input_schema)?)),
}
}
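    /// Wraps this expression in a cast to the target [DataType], returning the
    /// expression unchanged if it already has that type.
    ///
    /// # Errors
    ///
    /// Errors when [arrow::compute::can_cast_types] reports that the cast is
    /// not supported.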
fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
return Ok(self);
}
if can_cast_types(&this_type, cast_to_type) {
match self {
Expr::ScalarSubquery(subquery) => {
Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
}
_ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
}
} else {
plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
}
}
}
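/// Returns the [Field] produced by applying `field` (struct field access, list
/// index, or list range) to the type of `expr`.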
fn field_for_index<S: ExprSchema>(
expr: &Expr,
field: &GetFieldAccess,
schema: &S,
) -> Result<Field> {
let expr_dt = expr.get_type(schema)?;
match field {
GetFieldAccess::NamedStructField { name } => {
GetFieldAccessSchema::NamedStructField { name: name.clone() }
}
GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
key_dt: key.get_type(schema)?,
},
GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange {
start_dt: start.get_type(schema)?,
stop_dt: stop.get_type(schema)?,
},
}
.get_accessed_field(&expr_dt)
}
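/// Casts the output of a scalar [Subquery] to the given [DataType], rewriting
/// or adding a [Projection] on top of the subquery plan as needed.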
pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
if subquery.subquery.schema().field(0).data_type() == cast_to_type {
return Ok(subquery);
}
let plan = subquery.subquery.as_ref();
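    // If the subquery already ends in a projection, cast its (single)
    // expression in place; otherwise wrap the plan in a new casting projection.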
let new_plan = match plan {
LogicalPlan::Projection(projection) => {
let cast_expr = projection.expr[0]
.clone()
.cast_to(cast_to_type, projection.input.schema())?;
LogicalPlan::Projection(Projection::try_new(
vec![cast_expr],
projection.input.clone(),
)?)
}
_ => {
let cast_expr = Expr::Column(plan.schema().field(0).qualified_column())
.cast_to(cast_to_type, subquery.subquery.schema())?;
LogicalPlan::Projection(Projection::try_new(
vec![cast_expr],
subquery.subquery,
)?)
}
};
Ok(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{col, lit};
use arrow::datatypes::DataType;
use datafusion_common::{Column, ScalarValue};
macro_rules! test_is_expr_nullable {
($EXPR_TYPE:ident) => {{
let expr = lit(ScalarValue::Null).$EXPR_TYPE();
assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
}};
}
#[test]
fn expr_schema_nullability() {
let expr = col("foo").eq(lit(1));
assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
assert!(expr
.nullable(&MockExprSchema::new().with_nullable(true))
.unwrap());
test_is_expr_nullable!(is_null);
test_is_expr_nullable!(is_not_null);
test_is_expr_nullable!(is_true);
test_is_expr_nullable!(is_not_true);
test_is_expr_nullable!(is_false);
test_is_expr_nullable!(is_not_false);
test_is_expr_nullable!(is_unknown);
test_is_expr_nullable!(is_not_unknown);
}
#[test]
fn test_between_nullability() {
let get_schema = |nullable| {
MockExprSchema::new()
.with_data_type(DataType::Int32)
.with_nullable(nullable)
};
let expr = col("foo").between(lit(1), lit(2));
assert!(!expr.nullable(&get_schema(false)).unwrap());
assert!(expr.nullable(&get_schema(true)).unwrap());
let null = lit(ScalarValue::Int32(None));
let expr = col("foo").between(null.clone(), lit(2));
assert!(expr.nullable(&get_schema(false)).unwrap());
let expr = col("foo").between(lit(1), null.clone());
assert!(expr.nullable(&get_schema(false)).unwrap());
let expr = col("foo").between(null.clone(), null);
assert!(expr.nullable(&get_schema(false)).unwrap());
}
#[test]
fn test_inlist_nullability() {
let get_schema = |nullable| {
MockExprSchema::new()
.with_data_type(DataType::Int32)
.with_nullable(nullable)
};
let expr = col("foo").in_list(vec![lit(1); 5], false);
assert!(!expr.nullable(&get_schema(false)).unwrap());
assert!(expr.nullable(&get_schema(true)).unwrap());
assert!(expr
.nullable(&get_schema(false).with_error_on_nullable(true))
.is_err());
let null = lit(ScalarValue::Int32(None));
let expr = col("foo").in_list(vec![null, lit(1)], false);
assert!(expr.nullable(&get_schema(false)).unwrap());
let expr = col("foo").in_list(vec![lit(1); 6], false);
assert!(expr.nullable(&get_schema(false)).unwrap());
}
#[test]
fn test_like_nullability() {
let get_schema = |nullable| {
MockExprSchema::new()
.with_data_type(DataType::Utf8)
.with_nullable(nullable)
};
let expr = col("foo").like(lit("bar"));
assert!(!expr.nullable(&get_schema(false)).unwrap());
assert!(expr.nullable(&get_schema(true)).unwrap());
let expr = col("foo").like(lit(ScalarValue::Utf8(None)));
assert!(expr.nullable(&get_schema(false)).unwrap());
}
#[test]
fn expr_schema_data_type() {
let expr = col("foo");
assert_eq!(
DataType::Utf8,
expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
.unwrap()
);
}
#[test]
fn test_expr_metadata() {
let mut meta = HashMap::new();
meta.insert("bar".to_string(), "buzz".to_string());
let expr = col("foo");
let schema = MockExprSchema::new()
.with_data_type(DataType::Int32)
.with_metadata(meta.clone());
assert_eq!(meta, expr.metadata(&schema).unwrap());
assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap());
assert_eq!(
HashMap::new(),
expr.clone()
.cast_to(&DataType::Int64, &schema)
.unwrap()
.metadata(&schema)
.unwrap()
);
let schema = DFSchema::new_with_metadata(
vec![DFField::new_unqualified("foo", DataType::Int32, true)
.with_metadata(meta.clone())],
HashMap::new(),
)
.unwrap();
assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata());
}
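    /// Schema stub used by these tests: every column resolves to the
    /// configured data type, nullability and metadata.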
#[derive(Debug)]
struct MockExprSchema {
nullable: bool,
data_type: DataType,
error_on_nullable: bool,
metadata: HashMap<String, String>,
}
impl MockExprSchema {
fn new() -> Self {
Self {
nullable: false,
data_type: DataType::Null,
error_on_nullable: false,
metadata: HashMap::new(),
}
}
fn with_nullable(mut self, nullable: bool) -> Self {
self.nullable = nullable;
self
}
fn with_data_type(mut self, data_type: DataType) -> Self {
self.data_type = data_type;
self
}
fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self {
self.error_on_nullable = error_on_nullable;
self
}
fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
self.metadata = metadata;
self
}
}
impl ExprSchema for MockExprSchema {
fn nullable(&self, _col: &Column) -> Result<bool> {
if self.error_on_nullable {
internal_err!("nullable error")
} else {
Ok(self.nullable)
}
}
fn data_type(&self, _col: &Column) -> Result<&DataType> {
Ok(&self.data_type)
}
fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
Ok(&self.metadata)
}
}
}