use std::collections::{BTreeSet, HashMap};
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::sync::Arc;
use crate::{OptimizerConfig, OptimizerRule};
use crate::optimizer::ApplyOrder;
use crate::utils::NamePreserver;
use datafusion_common::alias::AliasGenerator;
use datafusion_common::hash_utils::combine_hashes;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{
internal_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, Result,
};
use datafusion_expr::expr::{Alias, ScalarFunction};
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
use datafusion_expr::{col, BinaryExpr, Case, Expr, ExprSchemable, Operator};
use indexmap::IndexMap;
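/// Prefix of the auto-generated aliases assigned to extracted common subexpressions
/// (e.g. `__common_expr_1`).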
const CSE_PREFIX: &str = "__common_expr";
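/// Identifies an expression subtree by a precomputed, accumulated hash together with a
/// reference to the expression itself. The hash is built bottom-up (a node's own hash is
/// combined with the accumulated hashes of its children), so equal subtrees receive equal
/// identifiers and can be looked up cheaply in hash maps.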
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Identifier<'n> {
hash: u64,
expr: &'n Expr,
}
impl<'n> Identifier<'n> {
fn new(expr: &'n Expr, random_state: &RandomState) -> Self {
let mut hasher = random_state.build_hasher();
expr.hash_node(&mut hasher);
let hash = hasher.finish();
Self { hash, expr }
}
fn combine(mut self, other: Option<Self>) -> Self {
other.map_or(self, |other_id| {
self.hash = combine_hashes(self.hash, other_id.hash);
self
})
}
}
impl Hash for Identifier<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write_u64(self.hash);
}
}
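/// Bookkeeping collected for each visited expression node, indexed by the node's
/// pre-order ("down") position: the node's post-order ("up") index and, if the node is a
/// candidate for elimination, its identifier.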
type IdArray<'n> = Vec<(usize, Option<Identifier<'n>>)>;
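/// Occurrence statistics per identified subexpression: the number of unconditional
/// occurrences and the number of occurrences in conditionally evaluated contexts.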
type ExprStats<'n> = HashMap<Identifier<'n>, (usize, usize)>;
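/// Common subexpressions chosen for extraction, mapping each identifier to the original
/// expression and its generated alias. Insertion order is preserved so the generated
/// projection is deterministic.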
type CommonExprs<'n> = IndexMap<Identifier<'n>, (Expr, String)>;
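/// Optimizer rule that performs common subexpression elimination: subexpressions that are
/// evaluated more than once within a plan node are computed a single time in a projection
/// inserted below that node and referenced through an auto-generated alias.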
pub struct CommonSubexprEliminate {
random_state: RandomState,
}
impl CommonSubexprEliminate {
pub fn new() -> Self {
Self {
random_state: RandomState::new(),
}
}
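    /// Builds an [`IdArray`] for every expression in `exprs` and accumulates occurrence
    /// counts into `expr_stats`. Returns whether any common subexpression was found.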
fn to_arrays<'n>(
&self,
exprs: &'n [Expr],
expr_stats: &mut ExprStats<'n>,
expr_mask: ExprMask,
) -> Result<(bool, Vec<IdArray<'n>>)> {
let mut found_common = false;
exprs
.iter()
.map(|e| {
let mut id_array = vec![];
self.expr_to_identifier(e, expr_stats, &mut id_array, expr_mask)
.map(|fc| {
found_common |= fc;
id_array
})
})
.collect::<Result<Vec<_>>>()
.map(|id_arrays| (found_common, id_arrays))
}
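    /// Visits a single expression with [`ExprIdentifierVisitor`], filling `id_array` and
    /// `expr_stats`, and reports whether a repeated subexpression was detected.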
fn expr_to_identifier<'n>(
&self,
expr: &'n Expr,
expr_stats: &mut ExprStats<'n>,
id_array: &mut IdArray<'n>,
expr_mask: ExprMask,
) -> Result<bool> {
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
conditional: false,
};
expr.visit(&mut visitor)?;
Ok(visitor.found_common)
}
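    /// Rewrites each expression list, replacing recognized common subexpressions with
    /// column references to their generated aliases and registering newly extracted
    /// expressions in `common_exprs`.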
fn rewrite_exprs_list<'n>(
&self,
exprs_list: Vec<Vec<Expr>>,
arrays_list: Vec<Vec<IdArray<'n>>>,
expr_stats: &ExprStats<'n>,
common_exprs: &mut CommonExprs<'n>,
alias_generator: &AliasGenerator,
) -> Result<Transformed<Vec<Vec<Expr>>>> {
let mut transformed = false;
exprs_list
.into_iter()
.zip(arrays_list.iter())
.map(|(exprs, arrays)| {
exprs
.into_iter()
.zip(arrays.iter())
.map(|(expr, id_array)| {
let replaced = replace_common_expr(
expr,
id_array,
expr_stats,
common_exprs,
alias_generator,
)?;
transformed |= replaced.transformed;
Ok(replaced.data)
})
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()
.map(|rewritten_exprs_list| {
Transformed::new_transformed(rewritten_exprs_list, transformed)
})
}
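    /// Rewrites the given expression lists, recursively optimizes `input`, and, if any
    /// common subexpressions were extracted, wraps the new input in a projection that
    /// computes them (see [`build_common_expr_project_plan`]).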
fn rewrite_expr(
&self,
exprs_list: Vec<Vec<Expr>>,
arrays_list: Vec<Vec<IdArray>>,
input: LogicalPlan,
expr_stats: &ExprStats,
config: &dyn OptimizerConfig,
) -> Result<Transformed<(Vec<Vec<Expr>>, LogicalPlan)>> {
let mut transformed = false;
let mut common_exprs = CommonExprs::new();
let rewrite_exprs = self.rewrite_exprs_list(
exprs_list,
arrays_list,
expr_stats,
&mut common_exprs,
&config.alias_generator(),
)?;
transformed |= rewrite_exprs.transformed;
let new_input = self.rewrite(input, config)?;
transformed |= new_input.transformed;
let mut new_input = new_input.data;
if !common_exprs.is_empty() {
assert!(transformed);
new_input = build_common_expr_project_plan(new_input, common_exprs)?;
}
Ok(Transformed::new_transformed(
(rewrite_exprs.data, new_input),
transformed,
))
}
fn try_optimize_proj(
&self,
projection: Projection,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let Projection {
expr,
input,
schema,
..
} = projection;
let input = unwrap_arc(input);
self.try_unary_plan(expr, input, config)?
.map_data(|(new_expr, new_input)| {
Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
.map(LogicalPlan::Projection)
})
}
fn try_optimize_sort(
&self,
sort: Sort,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let Sort { expr, input, fetch } = sort;
let input = unwrap_arc(input);
let new_sort = self.try_unary_plan(expr, input, config)?.update_data(
|(new_expr, new_input)| {
LogicalPlan::Sort(Sort {
expr: new_expr,
input: Arc::new(new_input),
fetch,
})
},
);
Ok(new_sort)
}
fn try_optimize_filter(
&self,
filter: Filter,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let Filter {
predicate, input, ..
} = filter;
let input = unwrap_arc(input);
let expr = vec![predicate];
self.try_unary_plan(expr, input, config)?
.map_data(|(mut new_expr, new_input)| {
                assert_eq!(new_expr.len(), 1);
                let new_predicate = new_expr.pop().unwrap();
Filter::try_new(new_predicate, Arc::new(new_input))
.map(LogicalPlan::Filter)
})
}
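    /// Handles a chain of consecutive `Window` nodes as one unit: common subexpressions
    /// are searched across all window expression lists and, if any are found, the chain is
    /// rebuilt on top of a projection computing them; otherwise the original chain is
    /// restored unchanged.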
fn try_optimize_window(
&self,
window: Window,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let (mut window_exprs, mut window_schemas, mut plan) =
get_consecutive_window_exprs(window);
let mut found_common = false;
let mut expr_stats = ExprStats::new();
let arrays_per_window = window_exprs
.iter()
.map(|window_expr| {
self.to_arrays(window_expr, &mut expr_stats, ExprMask::Normal)
.map(|(fc, id_arrays)| {
found_common |= fc;
id_arrays
})
})
.collect::<Result<Vec<_>>>()?;
if found_common {
let name_preserver = NamePreserver::new(&plan);
let mut saved_names = window_exprs
.iter()
.map(|exprs| {
exprs
.iter()
.map(|expr| name_preserver.save(expr))
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()?;
assert_eq!(window_exprs.len(), arrays_per_window.len());
let num_window_exprs = window_exprs.len();
let rewritten_window_exprs = self.rewrite_expr(
window_exprs.clone(),
arrays_per_window,
plan,
&expr_stats,
config,
)?;
let transformed = rewritten_window_exprs.transformed;
assert!(transformed);
let (mut new_expr, new_input) = rewritten_window_exprs.data;
let mut plan = new_input;
assert_eq!(num_window_exprs, new_expr.len());
assert_eq!(num_window_exprs, saved_names.len());
while let (Some(new_window_expr), Some(saved_names)) =
(new_expr.pop(), saved_names.pop())
{
assert_eq!(new_window_expr.len(), saved_names.len());
let new_window_expr = new_window_expr
.into_iter()
.zip(saved_names.into_iter())
.map(|(new_window_expr, saved_name)| {
saved_name.restore(new_window_expr)
})
.collect::<Result<Vec<_>>>()?;
plan = LogicalPlan::Window(Window::try_new(
new_window_expr,
Arc::new(plan),
)?);
}
Ok(Transformed::new_transformed(plan, transformed))
} else {
while let (Some(window_expr), Some(schema)) =
(window_exprs.pop(), window_schemas.pop())
{
plan = LogicalPlan::Window(Window {
input: Arc::new(plan),
window_expr,
schema,
});
}
Ok(Transformed::no(plan))
}
}
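    /// Optimizes an `Aggregate` in two phases: first, common subexpressions inside the
    /// group-by and aggregate expressions are extracted into a projection below the
    /// aggregate; second, identical aggregate expressions themselves are computed once in
    /// the aggregate and re-exposed through a projection above it.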
fn try_optimize_aggregate(
&self,
aggregate: Aggregate,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let Aggregate {
group_expr,
aggr_expr,
input,
schema: orig_schema,
..
} = aggregate;
let mut transformed = false;
        let name_preserver = NamePreserver::new_for_projection();
        let saved_names = aggr_expr
            .iter()
            .map(|expr| name_preserver.save(expr))
            .collect::<Result<Vec<_>>>()?;
let mut expr_stats = ExprStats::new();
let (group_found_common, group_arrays) =
self.to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?;
let (aggr_found_common, aggr_arrays) =
self.to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?;
let (new_aggr_expr, new_group_expr, new_input) =
if group_found_common || aggr_found_common {
let rewritten = self.rewrite_expr(
vec![group_expr.clone(), aggr_expr.clone()],
vec![group_arrays, aggr_arrays],
unwrap_arc(input),
&expr_stats,
config,
)?;
assert!(rewritten.transformed);
transformed |= rewritten.transformed;
let (mut new_expr, new_input) = rewritten.data;
let new_aggr_expr = pop_expr(&mut new_expr)?;
let new_group_expr = pop_expr(&mut new_expr)?;
(new_aggr_expr, new_group_expr, Arc::new(new_input))
} else {
(aggr_expr, group_expr, input)
};
let mut expr_stats = ExprStats::new();
let (aggr_found_common, aggr_arrays) = self.to_arrays(
&new_aggr_expr,
&mut expr_stats,
ExprMask::NormalAndAggregates,
)?;
if aggr_found_common {
let mut common_exprs = CommonExprs::new();
let mut rewritten_exprs = self.rewrite_exprs_list(
vec![new_aggr_expr.clone()],
vec![aggr_arrays],
&expr_stats,
&mut common_exprs,
&config.alias_generator(),
)?;
assert!(rewritten_exprs.transformed);
let rewritten = pop_expr(&mut rewritten_exprs.data)?;
assert!(!common_exprs.is_empty());
let mut agg_exprs = common_exprs
.into_values()
.map(|(expr, expr_alias)| expr.alias(expr_alias))
.collect::<Vec<_>>();
let new_input_schema = Arc::clone(new_input.schema());
let mut proj_exprs = vec![];
for expr in &new_group_expr {
extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
}
for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) {
if expr_rewritten == expr_orig {
if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten {
agg_exprs.push(expr.alias(&name));
proj_exprs.push(Expr::Column(Column::from_name(name)));
} else {
let expr_alias = config.alias_generator().next(CSE_PREFIX);
let (qualifier, field) =
expr_rewritten.to_field(&new_input_schema)?;
let out_name = qualified_name(qualifier.as_ref(), field.name());
agg_exprs.push(expr_rewritten.alias(&expr_alias));
proj_exprs.push(
Expr::Column(Column::from_name(expr_alias)).alias(out_name),
);
}
} else {
proj_exprs.push(expr_rewritten);
}
}
let agg = LogicalPlan::Aggregate(Aggregate::try_new(
new_input,
new_group_expr,
agg_exprs,
)?);
Projection::try_new(proj_exprs, Arc::new(agg))
.map(LogicalPlan::Projection)
.map(Transformed::yes)
} else {
let new_aggr_expr = new_aggr_expr
.into_iter()
.zip(saved_names.into_iter())
.map(|(new_expr, saved_name)| saved_name.restore(new_expr))
.collect::<Result<Vec<Expr>>>()?;
let new_agg = if transformed {
Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)?
} else {
Aggregate::try_new_with_schema(
new_input,
new_group_expr,
new_aggr_expr,
orig_schema,
)?
};
let new_agg = LogicalPlan::Aggregate(new_agg);
Ok(Transformed::new_transformed(new_agg, transformed))
}
}
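    /// Shared handling for single-input nodes with one expression list (`Projection`,
    /// `Sort`, `Filter`): extracts common subexpressions from `expr` and, if any were
    /// found, rewrites the expressions against a new projection built over `input`.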
fn try_unary_plan(
&self,
expr: Vec<Expr>,
input: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
let mut expr_stats = ExprStats::new();
let (found_common, id_arrays) =
self.to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?;
if found_common {
let rewritten = self.rewrite_expr(
vec![expr.clone()],
vec![id_arrays],
input,
&expr_stats,
config,
)?;
assert!(rewritten.transformed);
rewritten.map_data(|(mut new_expr, new_input)| {
assert_eq!(new_expr.len(), 1);
Ok((new_expr.pop().unwrap(), new_input))
})
} else {
Ok(Transformed::no((expr, input)))
}
}
}
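/// Peels off consecutive `Window` nodes starting at `window`, returning their window
/// expression lists and schemas (outermost first) together with the first non-window
/// descendant plan.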
fn get_consecutive_window_exprs(
window: Window,
) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
let mut window_exprs = vec![];
let mut window_schemas = vec![];
let mut plan = LogicalPlan::Window(window);
while let LogicalPlan::Window(Window {
input,
window_expr,
schema,
}) = plan
{
window_exprs.push(window_expr);
window_schemas.push(schema);
plan = unwrap_arc(input);
}
(window_exprs, window_schemas, plan)
}
impl OptimizerRule for CommonSubexprEliminate {
fn supports_rewrite(&self) -> bool {
true
}
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
fn rewrite(
&self,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let original_schema = Arc::clone(plan.schema());
let optimized_plan = match plan {
LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Union(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Values(_)
| LogicalPlan::EmptyRelation(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Statement(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => {
Transformed::no(plan)
}
};
if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
{
optimized_plan.map_data(|optimized_plan| {
build_recover_project_plan(&original_schema, optimized_plan)
})
} else {
Ok(optimized_plan)
}
}
fn name(&self) -> &str {
"common_sub_expression_eliminate"
}
}
impl Default for CommonSubexprEliminate {
fn default() -> Self {
Self::new()
}
}
fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
new_expr
.pop()
.ok_or_else(|| internal_datafusion_err!("Failed to pop expression"))
}
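/// Builds a projection over `input` that computes every extracted common expression under
/// its generated alias and passes through all remaining input fields unchanged.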
fn build_common_expr_project_plan(
input: LogicalPlan,
common_exprs: CommonExprs,
) -> Result<LogicalPlan> {
let mut fields_set = BTreeSet::new();
let mut project_exprs = common_exprs
.into_values()
.map(|(expr, expr_alias)| {
fields_set.insert(expr_alias.clone());
Ok(expr.alias(expr_alias))
})
.collect::<Result<Vec<_>>>()?;
for (qualifier, field) in input.schema().iter() {
if fields_set.insert(qualified_name(qualifier, field.name())) {
project_exprs.push(Expr::from((qualifier, field)));
}
}
Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
}
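/// Adds a projection on top of `input` that restores the column names and qualifiers of
/// the original `schema`, so the optimized plan keeps the same output schema.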
fn build_recover_project_plan(
schema: &DFSchema,
input: LogicalPlan,
) -> Result<LogicalPlan> {
let col_exprs = schema.iter().map(Expr::from).collect();
Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
}
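/// Collects the column expressions produced by a group-by expression: a `GroupingSet`
/// contributes one column per distinct expression, any other expression contributes the
/// column it resolves to in `schema`.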
fn extract_expressions(
expr: &Expr,
schema: &DFSchema,
result: &mut Vec<Expr>,
) -> Result<()> {
if let Expr::GroupingSet(groupings) = expr {
for e in groupings.distinct_expr() {
let (qualifier, field) = e.to_field(schema)?;
let col = Column::new(qualifier, field.name());
result.push(Expr::Column(col))
}
} else {
let (qualifier, field) = expr.to_field(schema)?;
let col = Column::new(qualifier, field.name());
result.push(Expr::Column(col));
}
Ok(())
}
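/// Controls which expressions are considered for elimination. `Normal` skips trivial
/// expressions (literals, columns, aliases, sorts, wildcards, scalar variables) as well as
/// aggregate functions; `NormalAndAggregates` skips only the trivial ones, so aggregate
/// functions themselves can be extracted.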
#[derive(Debug, Clone, Copy)]
enum ExprMask {
Normal,
NormalAndAggregates,
}
impl ExprMask {
fn ignores(&self, expr: &Expr) -> bool {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
| Expr::Column(..)
| Expr::ScalarVariable(..)
| Expr::Alias(..)
| Expr::Sort { .. }
| Expr::Wildcard { .. }
);
let is_aggr = matches!(expr, Expr::AggregateFunction(..));
match self {
Self::Normal => is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_normal_minus_aggregates,
}
}
}
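/// Expression visitor that assigns an identifier to every subexpression and counts how
/// often each identifier occurs, distinguishing unconditional occurrences from those in
/// conditionally evaluated positions (short-circuited operands, CASE branches).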
struct ExprIdentifierVisitor<'a, 'n> {
expr_stats: &'a mut ExprStats<'n>,
id_array: &'a mut IdArray<'n>,
visit_stack: Vec<VisitRecord<'n>>,
down_index: usize,
up_index: usize,
expr_mask: ExprMask,
random_state: &'a RandomState,
found_common: bool,
conditional: bool,
}
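/// Records kept on the visit stack: `EnterMark` remembers the pre-order index at which a
/// node was entered, `ExprItem` carries a completed child's identifier and whether its
/// subtree is free of volatile expressions.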
enum VisitRecord<'n> {
EnterMark(usize),
ExprItem(Identifier<'n>, bool),
}
impl<'n> ExprIdentifierVisitor<'_, 'n> {
fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) {
let mut expr_id = None;
let mut is_valid = true;
while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(down_index) => {
return (down_index, expr_id, is_valid);
}
VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
expr_id = Some(sub_expr_id.combine(expr_id));
is_valid &= sub_expr_is_valid;
}
}
}
unreachable!("Enter mark should paired with node number");
}
fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
&mut self,
mut f: F,
) -> Result<()> {
let conditional = self.conditional;
self.conditional = true;
f(self)?;
self.conditional = conditional;
Ok(())
}
}
impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
type Node = Expr;
fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
self.id_array.push((0, None));
self.visit_stack
.push(VisitRecord::EnterMark(self.down_index));
self.down_index += 1;
Ok(match expr {
_ if self.conditional => TreeNodeRecursion::Continue,
Expr::ScalarFunction(ScalarFunction { func, args })
if func.short_circuits() =>
{
self.conditionally(|visitor| {
args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
})?;
TreeNodeRecursion::Jump
}
Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::And | Operator::Or,
right,
}) => {
left.visit(self)?;
self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?;
TreeNodeRecursion::Jump
}
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
}) => {
expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
when_then_expr.iter().take(1).try_for_each(|(when, then)| {
when.visit(self)?;
self.conditionally(|visitor| then.visit(visitor).map(|_| ()))
})?;
self.conditionally(|visitor| {
when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
when.visit(visitor)?;
then.visit(visitor).map(|_| ())
})?;
else_expr
.iter()
.try_for_each(|e| e.visit(visitor).map(|_| ()))
})?;
TreeNodeRecursion::Jump
}
_ => TreeNodeRecursion::Continue,
})
}
fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
let (down_index, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark();
let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id);
let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
self.id_array[down_index].0 = self.up_index;
if is_valid && !self.expr_mask.ignores(expr) {
self.id_array[down_index].1 = Some(expr_id);
let (count, conditional_count) =
self.expr_stats.entry(expr_id).or_insert((0, 0));
if self.conditional {
*conditional_count += 1;
} else {
*count += 1;
}
if *count > 1 || (*count == 1 && *conditional_count > 0) {
self.found_common = true;
}
}
self.visit_stack
.push(VisitRecord::ExprItem(expr_id, is_valid));
self.up_index += 1;
Ok(TreeNodeRecursion::Continue)
}
}
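/// Expression rewriter that replaces recognized common subexpressions with column
/// references to their generated aliases, re-aliasing the replacement with the original
/// display name when necessary to preserve the output schema.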
struct CommonSubexprRewriter<'a, 'n> {
expr_stats: &'a ExprStats<'n>,
id_array: &'a IdArray<'n>,
common_exprs: &'a mut CommonExprs<'n>,
down_index: usize,
alias_counter: usize,
alias_generator: &'a AliasGenerator,
}
impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
type Node = Expr;
fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter += 1;
}
let (up_index, expr_id) = self.id_array[self.down_index];
self.down_index += 1;
if let Some(expr_id) = expr_id {
let (count, conditional_count) = self.expr_stats.get(&expr_id).unwrap();
            if *count > 1 || (*count == 1 && *conditional_count > 0) {
while self.down_index < self.id_array.len()
&& self.id_array[self.down_index].0 < up_index
{
self.down_index += 1;
}
let expr_name = expr.display_name()?;
let (_, expr_alias) =
self.common_exprs.entry(expr_id).or_insert_with(|| {
let expr_alias = self.alias_generator.next(CSE_PREFIX);
(expr, expr_alias)
});
let rewritten = if self.alias_counter > 0 {
col(expr_alias.clone())
} else {
self.alias_counter += 1;
col(expr_alias.clone()).alias(expr_name)
};
return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
}
}
Ok(Transformed::no(expr))
}
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter -= 1
}
Ok(Transformed::no(expr))
}
}
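/// Rewrites `expr` using [`CommonSubexprRewriter`], replacing common subexpressions and
/// registering newly extracted ones in `common_exprs`.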
fn replace_common_expr<'n>(
expr: Expr,
id_array: &IdArray<'n>,
expr_stats: &ExprStats<'n>,
common_exprs: &mut CommonExprs<'n>,
alias_generator: &AliasGenerator,
) -> Result<Transformed<Expr>> {
if id_array.is_empty() {
Ok(Transformed::no(expr))
} else {
expr.rewrite(&mut CommonSubexprRewriter {
expr_stats,
id_array,
common_exprs,
down_index: 0,
alias_counter: 0,
alias_generator,
})
}
}
#[cfg(test)]
mod test {
use std::any::Any;
use std::collections::HashSet;
use std::iter;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
Volatility,
};
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
use crate::optimizer::OptimizerContext;
use crate::test::*;
use datafusion_expr::test::function_stub::{avg, sum};
use super::*;
fn assert_non_optimized_plan_eq(
expected: &str,
plan: LogicalPlan,
config: Option<&dyn OptimizerConfig>,
) {
assert_eq!(expected, format!("{plan}"), "Unexpected starting plan");
let optimizer = CommonSubexprEliminate::new();
let default_config = OptimizerContext::new();
let config = config.unwrap_or(&default_config);
let optimized_plan = optimizer.rewrite(plan, config).unwrap();
        assert!(!optimized_plan.transformed, "unexpectedly optimized plan");
let optimized_plan = optimized_plan.data;
assert_eq!(
expected,
format!("{optimized_plan}"),
"Unexpected optimized plan"
);
}
fn assert_optimized_plan_eq(
expected: &str,
plan: LogicalPlan,
config: Option<&dyn OptimizerConfig>,
) {
let optimizer = CommonSubexprEliminate::new();
let default_config = OptimizerContext::new();
let config = config.unwrap_or(&default_config);
let optimized_plan = optimizer.rewrite(plan, config).unwrap();
assert!(optimized_plan.transformed, "failed to optimize plan");
let optimized_plan = optimized_plan.data;
let formatted_plan = format!("{optimized_plan}");
assert_eq!(expected, formatted_plan);
}
#[test]
fn id_array_visitor() -> Result<()> {
let optimizer = CommonSubexprEliminate::new();
let a_plus_1 = col("a") + lit(1);
let avg_c = avg(col("c"));
let sum_a_plus_1 = sum(a_plus_1);
let sum_a_plus_1_minus_avg_c = sum_a_plus_1 - avg_c;
let expr = sum_a_plus_1_minus_avg_c * lit(2);
let Expr::BinaryExpr(BinaryExpr {
left: sum_a_plus_1_minus_avg_c,
..
}) = &expr
else {
panic!("Cannot extract subexpression reference")
};
let Expr::BinaryExpr(BinaryExpr {
left: sum_a_plus_1,
right: avg_c,
..
}) = sum_a_plus_1_minus_avg_c.as_ref()
else {
panic!("Cannot extract subexpression reference")
};
let Expr::AggregateFunction(AggregateFunction {
args: a_plus_1_vec, ..
}) = sum_a_plus_1.as_ref()
else {
panic!("Cannot extract subexpression reference")
};
let a_plus_1 = &a_plus_1_vec.as_slice()[0];
let mut id_array = vec![];
optimizer.expr_to_identifier(
&expr,
&mut ExprStats::new(),
&mut id_array,
ExprMask::Normal,
)?;
fn collect_hashes(id_array: &mut IdArray) -> HashSet<u64> {
id_array
.iter_mut()
.flat_map(|(_, expr_id_option)| {
expr_id_option.as_mut().map(|expr_id| {
let hash = expr_id.hash;
expr_id.hash = 0;
hash
})
})
.collect::<HashSet<_>>()
}
let hashes = collect_hashes(&mut id_array);
assert_eq!(hashes.len(), 3);
let expected = vec![
(
8,
Some(Identifier {
hash: 0,
expr: &expr,
}),
),
(
6,
Some(Identifier {
hash: 0,
expr: sum_a_plus_1_minus_avg_c,
}),
),
(3, None),
(
2,
Some(Identifier {
hash: 0,
expr: a_plus_1,
}),
),
(0, None),
(1, None),
(5, None),
(4, None),
(7, None),
];
assert_eq!(expected, id_array);
let mut id_array = vec![];
optimizer.expr_to_identifier(
&expr,
&mut ExprStats::new(),
&mut id_array,
ExprMask::NormalAndAggregates,
)?;
let hashes = collect_hashes(&mut id_array);
assert_eq!(hashes.len(), 5);
let expected = vec![
(
8,
Some(Identifier {
hash: 0,
expr: &expr,
}),
),
(
6,
Some(Identifier {
hash: 0,
expr: sum_a_plus_1_minus_avg_c,
}),
),
(
3,
Some(Identifier {
hash: 0,
expr: sum_a_plus_1,
}),
),
(
2,
Some(Identifier {
hash: 0,
expr: a_plus_1,
}),
),
(0, None),
(1, None),
(
5,
Some(Identifier {
hash: 0,
expr: avg_c,
}),
),
(4, None),
(7, None),
];
assert_eq!(expected, id_array);
Ok(())
}
#[test]
fn tpch_q1_simplified() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
iter::empty::<Expr>(),
vec![
sum(col("a") * (lit(1) - col("b"))),
sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
],
)?
.build()?;
let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
\n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn nested_aliases() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
(col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
col("a") + col("b"),
])?
.build()?;
let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\
\n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let return_type = DataType::UInt32;
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
let udf_agg = |inner: Expr| {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
"my_agg",
Signature::exact(vec![DataType::UInt32], Volatility::Stable),
return_type.clone(),
Arc::clone(&accumulator),
vec![Field::new("value", DataType::UInt32, true)],
))),
vec![inner],
false,
None,
None,
None,
))
};
let plan = LogicalPlanBuilder::from(table_scan.clone())
.aggregate(
iter::empty::<Expr>(),
vec![
avg(col("a")).alias("col1"),
avg(col("a")).alias("col2"),
avg(col("b")).alias("col3"),
avg(col("c")),
udf_agg(col("a")).alias("col4"),
udf_agg(col("a")).alias("col5"),
udf_agg(col("b")).alias("col6"),
udf_agg(col("c")),
],
)?
.build()?;
let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
let plan = LogicalPlanBuilder::from(table_scan.clone())
.aggregate(
iter::empty::<Expr>(),
vec![
lit(1) + avg(col("a")),
lit(1) - avg(col("a")),
lit(1) + udf_agg(col("a")),
lit(1) - udf_agg(col("a")),
],
)?
.build()?;
let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\
\n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
let plan = LogicalPlanBuilder::from(table_scan.clone())
.aggregate(
iter::empty::<Expr>(),
vec![
avg(lit(1u32) + col("a")).alias("col1"),
udf_agg(lit(1u32) + col("a")).alias("col2"),
],
)?
.build()?;
let expected ="Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\
\n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
let plan = LogicalPlanBuilder::from(table_scan.clone())
.aggregate(
vec![lit(1u32) + col("a")],
vec![
avg(lit(1u32) + col("a")).alias("col1"),
udf_agg(lit(1u32) + col("a")).alias("col2"),
],
)?
.build()?;
let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\
\n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![lit(1u32) + col("a")],
vec![
(lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
(lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
avg(lit(1u32) + col("a")),
(lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
(lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
udf_agg(lit(1u32) + col("a")),
],
)?
.build()?;
let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\
\n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn aggregate_with_relations_and_dots() -> Result<()> {
let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col_a.clone()],
vec![
(lit(1u32) + avg(lit(1u32) + col_a.clone())),
avg(lit(1u32) + col_a),
],
)?
.build()?;
let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\
\n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\
\n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\
\n TableScan: table.test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn subexpr_in_same_order() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
(lit(1) + col("a")).alias("first"),
(lit(1) + col("a")).alias("second"),
])?
.build()?;
let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\
\n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn subexpr_in_different_order() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
.build()?;
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
\n TableScan: test";
assert_non_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn cross_plans_subexpr() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a")])?
.project(vec![lit(1) + col("a")])?
.build()?;
let expected = "Projection: Int32(1) + test.a\
\n Projection: Int32(1) + test.a, test.a\
\n TableScan: test";
assert_non_optimized_plan_eq(expected, plan, None);
Ok(())
}
fn test_identifier(hash: u64, expr: &Expr) -> Identifier {
Identifier { hash, expr }
}
#[test]
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
let c_plus_a = col("c") + col("a");
let b_plus_a = col("b") + col("a");
let common_exprs_1 = CommonExprs::from([
(
test_identifier(0, &c_plus_a),
(c_plus_a.clone(), format!("{CSE_PREFIX}_1")),
),
(
test_identifier(1, &b_plus_a),
(b_plus_a.clone(), format!("{CSE_PREFIX}_2")),
),
]);
let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
let common_exprs_2 = CommonExprs::from([
(
test_identifier(3, &c_plus_a_2),
(c_plus_a_2.clone(), format!("{CSE_PREFIX}_3")),
),
(
test_identifier(4, &b_plus_a_2),
(b_plus_a_2.clone(), format!("{CSE_PREFIX}_4")),
),
]);
let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
let mut field_set = BTreeSet::new();
for name in project_2.schema().field_names() {
assert!(field_set.insert(name));
}
}
#[test]
fn redundant_project_fields_join_input() {
let table_scan_1 = test_table_scan_with_name("test1").unwrap();
let table_scan_2 = test_table_scan_with_name("test2").unwrap();
let join = LogicalPlanBuilder::from(table_scan_1)
.join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
.unwrap()
.build()
.unwrap();
let c_plus_a = col("test1.c") + col("test1.a");
let b_plus_a = col("test1.b") + col("test1.a");
let common_exprs_1 = CommonExprs::from([
(
test_identifier(0, &c_plus_a),
(c_plus_a.clone(), format!("{CSE_PREFIX}_1")),
),
(
test_identifier(1, &b_plus_a),
(b_plus_a.clone(), format!("{CSE_PREFIX}_2")),
),
]);
let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
let common_exprs_2 = CommonExprs::from([
(
test_identifier(3, &c_plus_a_2),
(c_plus_a_2.clone(), format!("{CSE_PREFIX}_3")),
),
(
test_identifier(4, &b_plus_a_2),
(b_plus_a_2.clone(), format!("{CSE_PREFIX}_4")),
),
]);
let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
let mut field_set = BTreeSet::new();
for name in project_2.schema().field_names() {
assert!(field_set.insert(name));
}
}
#[test]
fn eliminated_subexpr_datatype() {
use datafusion_expr::cast;
let schema = Schema::new(vec![
Field::new("a", DataType::UInt64, false),
Field::new("b", DataType::UInt64, false),
Field::new("c", DataType::UInt64, false),
]);
let plan = table_scan(Some("table"), &schema, None)
.unwrap()
.filter(
cast(col("a"), DataType::Int64)
.lt(lit(1_i64))
.and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
)
.unwrap()
.build()
.unwrap();
let rule = CommonSubexprEliminate::new();
let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
assert!(optimized_plan.transformed);
let optimized_plan = optimized_plan.data;
let schema = optimized_plan.schema();
let fields_with_datatypes: Vec<_> = schema
.fields()
.iter()
.map(|field| (field.name(), field.data_type()))
.collect();
let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
let expected = r#"[
(
"a",
UInt64,
),
(
"b",
UInt64,
),
(
"c",
UInt64,
),
]"#;
assert_eq!(expected, formatted_fields_with_datatype);
}
#[test]
fn filter_schema_changed() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
.build()?;
let expected = "Projection: test.a, test.b, test.c\
\n Filter: __common_expr_1 - Int32(10) > __common_expr_1\
\n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn test_extract_expressions_from_grouping_set() -> Result<()> {
let mut result = Vec::with_capacity(3);
let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
let schema = DFSchema::from_unqualified_fields(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]
.into(),
HashMap::default(),
)?;
extract_expressions(&grouping, &schema, &mut result)?;
assert!(result.len() == 3);
Ok(())
}
#[test]
fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
let mut result = Vec::with_capacity(2);
let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
let schema = DFSchema::from_unqualified_fields(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]
.into(),
HashMap::default(),
)?;
extract_expressions(&grouping, &schema, &mut result)?;
assert!(result.len() == 2);
Ok(())
}
#[test]
fn test_alias_collision() -> Result<()> {
let table_scan = test_table_scan()?;
let config = &OptimizerContext::new();
let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
(col("a") + col("b")).alias(common_expr_1.clone()),
col("c"),
])?
.project(vec![
col(common_expr_1.clone()).alias("c1"),
col(common_expr_1).alias("c2"),
(col("c") + lit(2)).alias("c3"),
(col("c") + lit(2)).alias("c4"),
])?
.build()?;
let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\
\n Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\
\n Projection: test.a + test.b AS __common_expr_1, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, Some(config));
let config = &OptimizerContext::new();
let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
(col("a") + col("b")).alias(common_expr_2.clone()),
col("c"),
])?
.project(vec![
col(common_expr_2.clone()).alias("c1"),
col(common_expr_2).alias("c2"),
(col("c") + lit(2)).alias("c3"),
(col("c") + lit(2)).alias("c4"),
])?
.build()?;
let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\
\n Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\
\n Projection: test.a + test.b AS __common_expr_2, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, Some(config));
Ok(())
}
#[test]
fn test_extract_expressions_from_col() -> Result<()> {
let mut result = Vec::with_capacity(1);
let schema = DFSchema::from_unqualified_fields(
vec![Field::new("a", DataType::Int32, false)].into(),
HashMap::default(),
)?;
extract_expressions(&col("a"), &schema, &mut result)?;
assert!(result.len() == 1);
Ok(())
}
#[test]
fn test_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;
let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
extracted_short_circuit.clone().alias("c1"),
extracted_short_circuit.alias("c2"),
extracted_short_circuit_leg_1
.clone()
.or(not_extracted_short_circuit_leg_2.clone())
.alias("c3"),
extracted_short_circuit_leg_1
.and(not_extracted_short_circuit_leg_2)
.alias("c4"),
extracted_short_circuit_leg_3
.clone()
.or(extracted_short_circuit_leg_3.clone())
.alias("c5"),
])?
.build()?;
let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\
\n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn test_volatile() -> Result<()> {
let table_scan = test_table_scan()?;
let extracted_child = col("a") + col("b");
let rand = rand_func().call(vec![]);
let not_extracted_volatile = extracted_child + rand;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
not_extracted_volatile.clone().alias("c1"),
not_extracted_volatile.alias("c2"),
])?
.build()?;
let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\
\n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
#[test]
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;
let rand = rand_func().call(vec![]);
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
let not_extracted_volatile_short_circuit_1 =
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
let not_extracted_volatile_short_circuit_2 =
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
not_extracted_volatile_short_circuit_1.alias("c2"),
not_extracted_volatile_short_circuit_2.clone().alias("c3"),
not_extracted_volatile_short_circuit_2.alias("c4"),
])?
.build()?;
let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
\n Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
fn rand_func() -> ScalarUDF {
ScalarUDF::new_from_impl(RandomStub::new())
}
#[derive(Debug)]
struct RandomStub {
signature: Signature,
}
impl RandomStub {
fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
}
}
}
impl ScalarUDFImpl for RandomStub {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"random"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!()
}
}
}