use crate::expr::{
AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList,
InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
};
use crate::{Expr, ExprFunctionExt};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
};
use datafusion_common::Result;
/// Tree traversal support for [`Expr`].
///
/// Both methods operate on the *direct* child expressions of a node only;
/// recursion over the whole tree is driven by the caller through the
/// `TreeNode` API.
impl TreeNode for Expr {
    /// Applies `f` to a reference of every direct child expression of `self`.
    ///
    /// Variants that carry no child expressions in this implementation
    /// (columns, literals, placeholders, wildcards, `Exists` and
    /// `ScalarSubquery`) never invoke `f` and simply return
    /// `TreeNodeRecursion::Continue`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the element traversal (and therefore `f`)
    /// returns.
    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
        &'n self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        match self {
            // Exactly one boxed child expression. For `InSubquery` only the
            // left-hand `expr` is visited, not the subquery itself.
            Expr::Alias(Alias { expr, .. })
            | Expr::Unnest(Unnest { expr })
            | Expr::Not(expr)
            | Expr::IsNotNull(expr)
            | Expr::IsTrue(expr)
            | Expr::IsFalse(expr)
            | Expr::IsUnknown(expr)
            | Expr::IsNotTrue(expr)
            | Expr::IsNotFalse(expr)
            | Expr::IsNotUnknown(expr)
            | Expr::IsNull(expr)
            | Expr::Negative(expr)
            | Expr::Cast(Cast { expr, .. })
            | Expr::TryCast(TryCast { expr, .. })
            | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
            // A flat list of child expressions.
            Expr::GroupingSet(GroupingSet::Rollup(exprs))
            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
            Expr::ScalarFunction(ScalarFunction { args, .. }) => args.apply_elements(f),
            // A list of lists of child expressions.
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                lists_of_exprs.apply_elements(f)
            }
            // Leaves: nothing to visit, keep traversing.
            Expr::Column(_)
            | Expr::OuterReferenceColumn(_, _)
            | Expr::ScalarVariable(_, _)
            | Expr::Literal(_)
            | Expr::Exists { .. }
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
            // Several children: visited in the order they appear in the tuple.
            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
                (left, right).apply_ref_elements(f)
            }
            Expr::Like(Like { expr, pattern, .. })
            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
                (expr, pattern).apply_ref_elements(f)
            }
            Expr::Between(Between {
                expr, low, high, ..
            }) => (expr, low, high).apply_ref_elements(f),
            Expr::Case(Case {
                expr,
                when_then_expr,
                else_expr,
            }) => (expr, when_then_expr, else_expr).apply_ref_elements(f),
            Expr::AggregateFunction(AggregateFunction {
                args,
                filter,
                order_by,
                ..
            }) => (args, filter, order_by).apply_ref_elements(f),
            Expr::WindowFunction(WindowFunction {
                args,
                partition_by,
                order_by,
                ..
            }) => (args, partition_by, order_by).apply_ref_elements(f),
            Expr::InList(InList { expr, list, .. }) => (expr, list).apply_ref_elements(f),
        }
    }

    /// Consumes `self`, passes every direct child expression through `f`, and
    /// rebuilds the same variant around the results.
    ///
    /// Non-child fields (operators, flags, data types, subqueries, window
    /// frames, ...) are moved into the rebuilt node unchanged. Leaf variants
    /// are returned as `Transformed::no(self)` without calling `f`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `f`, and any error reported while
    /// rebuilding a window function through its builder.
    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        Ok(match self {
            // Leaves: there is nothing to rewrite.
            Expr::Column(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(Placeholder { .. })
            | Expr::OuterReferenceColumn(_, _)
            | Expr::Exists { .. }
            | Expr::ScalarSubquery(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Literal(_) => Transformed::no(self),
            Expr::Unnest(Unnest { expr, .. }) => expr
                .map_elements(f)?
                .update_data(|expr| Expr::Unnest(Unnest { expr })),
            // The aliased expression is unboxed and handed to `f` directly;
            // the original qualifier and name are then re-applied to the result.
            Expr::Alias(Alias {
                expr,
                relation,
                name,
            }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)),
            // Only the left-hand expression is rewritten; the subquery is
            // carried over as-is.
            Expr::InSubquery(InSubquery {
                expr,
                subquery,
                negated,
            }) => expr.map_elements(f)?.update_data(|be| {
                Expr::InSubquery(InSubquery::new(be, subquery, negated))
            }),
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
                .map_elements(f)?
                .update_data(|(new_left, new_right)| {
                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
                }),
            Expr::Like(Like {
                negated,
                expr,
                pattern,
                escape_char,
                case_insensitive,
            }) => (expr, pattern)
                .map_elements(f)?
                .update_data(|(new_expr, new_pattern)| {
                    Expr::Like(Like::new(
                        negated,
                        new_expr,
                        new_pattern,
                        escape_char,
                        case_insensitive,
                    ))
                }),
            Expr::SimilarTo(Like {
                negated,
                expr,
                pattern,
                escape_char,
                case_insensitive,
            }) => (expr, pattern)
                .map_elements(f)?
                .update_data(|(new_expr, new_pattern)| {
                    Expr::SimilarTo(Like::new(
                        negated,
                        new_expr,
                        new_pattern,
                        escape_char,
                        case_insensitive,
                    ))
                }),
            // Single-child wrappers: rewrite the child, re-wrap in the same variant.
            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
            Expr::IsNotUnknown(expr) => {
                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
            }
            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
            Expr::Between(Between {
                expr,
                negated,
                low,
                high,
            }) => (expr, low, high).map_elements(f)?.update_data(
                |(new_expr, new_low, new_high)| {
                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
                },
            ),
            Expr::Case(Case {
                expr,
                when_then_expr,
                else_expr,
            }) => (expr, when_then_expr, else_expr)
                .map_elements(f)?
                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
                }),
            Expr::Cast(Cast { expr, data_type }) => expr
                .map_elements(f)?
                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
            Expr::TryCast(TryCast { expr, data_type }) => expr
                .map_elements(f)?
                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
            // Rebuilding cannot fail here, so the infallible `update_data` is enough.
            Expr::ScalarFunction(ScalarFunction { func, args }) => {
                args.map_elements(f)?.update_data(|new_args| {
                    Expr::ScalarFunction(ScalarFunction::new_udf(func, new_args))
                })
            }
            // The window function is reassembled through the builder, whose
            // `build()` is fallible: surface its error to the caller instead
            // of panicking.
            Expr::WindowFunction(WindowFunction {
                args,
                fun,
                partition_by,
                order_by,
                window_frame,
                null_treatment,
            }) => (args, partition_by, order_by).map_elements(f)?.map_data(
                |(new_args, new_partition_by, new_order_by)| {
                    Expr::WindowFunction(WindowFunction::new(fun, new_args))
                        .partition_by(new_partition_by)
                        .order_by(new_order_by)
                        .window_frame(window_frame)
                        .null_treatment(null_treatment)
                        .build()
                },
            )?,
            Expr::AggregateFunction(AggregateFunction {
                args,
                func,
                distinct,
                filter,
                order_by,
                null_treatment,
            }) => (args, filter, order_by).map_elements(f)?.update_data(
                |(new_args, new_filter, new_order_by)| {
                    Expr::AggregateFunction(AggregateFunction::new_udf(
                        func,
                        new_args,
                        distinct,
                        new_filter,
                        new_order_by,
                        null_treatment,
                    ))
                },
            ),
            Expr::GroupingSet(grouping_set) => match grouping_set {
                GroupingSet::Rollup(exprs) => exprs
                    .map_elements(f)?
                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
                GroupingSet::Cube(exprs) => exprs
                    .map_elements(f)?
                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
                    .map_elements(f)?
                    .update_data(|new_lists_of_exprs| {
                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
                    }),
            },
            Expr::InList(InList {
                expr,
                list,
                negated,
            }) => (expr, list)
                .map_elements(f)?
                .update_data(|(new_expr, new_list)| {
                    Expr::InList(InList::new(new_expr, new_list, negated))
                }),
        })
    }
}