use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::expr::{Alias, Sort, WindowFunction};
use crate::expr_rewriter::strip_outer_reference;
use crate::signature::{Signature, TypeSignature};
use crate::{
and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::utils::get_at_indices;
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, Result,
TableReference,
};
use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions};
pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
/// Collect every column referenced by any expression in `expr` into `accum`.
///
/// Deprecated in favour of `Expr::add_column_refs`.
///
/// # Errors
/// Propagates the first error returned by [`expr_to_columns`].
#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")]
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
    expr.iter()
        .try_for_each(|single| expr_to_columns(single, accum))
}
/// Count the group-by expressions in `group_expr`.
///
/// When the list is a single `GroupingSet`, the count is the number of
/// distinct expressions inside that grouping set. Otherwise it is the length
/// of the list.
///
/// # Errors
/// Returns a plan error when a `GroupingSet` is the first element but is not
/// the only element.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
    match group_expr {
        [Expr::GroupingSet(grouping_set)] => Ok(grouping_set.distinct_expr().len()),
        [Expr::GroupingSet(_), ..] => plan_err!(
            "Invalid group by expressions, GroupingSet must be the only expression"
        ),
        _ => Ok(group_expr.len()),
    }
}
/// Enumerate every subset of `slice`.
///
/// Subsets are produced in bitmask order: subset number `m` contains element
/// `i` exactly when bit `i` of `m` is set. Elements within a subset keep their
/// original relative order. The first subset is always empty.
///
/// # Errors
/// Returns an error string when `slice` has 64 or more elements, because the
/// mask no longer fits in a `u64`.
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
    let len = slice.len();
    if len >= 64 {
        return Err("The size of the set must be less than 64.".into());
    }
    let subsets = (0..(1u64 << len))
        .map(|mask| {
            slice
                .iter()
                .enumerate()
                .filter(|&(position, _)| mask & (1u64 << position) != 0)
                .map(|(_, item)| item)
                .collect::<Vec<&T>>()
        })
        .collect::<Vec<_>>();
    Ok(subsets)
}
/// Reject a single grouping set that holds more than 65535 expressions.
///
/// # Errors
/// Returns a plan error naming the limit and the offending `size`.
fn check_grouping_set_size_limit(size: usize) -> Result<()> {
    let max_grouping_set_size = 65535;
    if size <= max_grouping_set_size {
        return Ok(());
    }
    plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}")
}

/// Reject a `GROUPING SETS` list that contains more than 4096 grouping sets.
///
/// # Errors
/// Returns a plan error naming the limit and the offending `size`.
fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
    let max_grouping_sets_size = 4096;
    if size <= max_grouping_sets_size {
        return Ok(());
    }
    plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}")
}
/// Concatenate two grouping sets (`left` followed by `right`) into a new one.
///
/// # Errors
/// Fails when the combined length exceeds the single-grouping-set limit.
fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
    check_grouping_set_size_limit(left.len() + right.len())?;
    let mut merged = Vec::with_capacity(left.len() + right.len());
    merged.extend_from_slice(left);
    merged.extend_from_slice(right);
    Ok(merged)
}

/// Cartesian product of two lists of grouping sets.
///
/// Every set of `left` is merged with every set of `right`. The outer loop
/// runs over `left`, so results are ordered left-major.
///
/// # Errors
/// Fails when the product would exceed the grouping-sets limit, or when any
/// merged set exceeds the single-grouping-set limit.
fn cross_join_grouping_sets<T: Clone>(
    left: &[Vec<T>],
    right: &[Vec<T>],
) -> Result<Vec<Vec<T>>> {
    let grouping_sets_size = left.len() * right.len();
    check_grouping_sets_size_limit(grouping_sets_size)?;
    left.iter()
        .flat_map(|outer| right.iter().map(move |inner| merge_grouping_set(outer, inner)))
        .collect()
}
/// Expand a GROUP BY list that mixes plain expressions with `GROUPING SETS`,
/// `CUBE` or `ROLLUP` into a single equivalent `GROUPING SETS` expression.
///
/// Each item of `group_expr` is first expanded into its own list of partial
/// grouping sets:
/// - a plain expression `e` becomes `[[e]]`;
/// - `GROUPING SETS (...)` keeps its sets;
/// - `CUBE (e1..en)` becomes all 2^n subsets;
/// - `ROLLUP (e1..en)` becomes the n + 1 prefixes `()`, `(e1)`, ..., `(e1..en)`.
///
/// The per-item lists are then cross-joined left to right.
///
/// `group_expr` is returned unchanged when it contains no grouping-set
/// expression, or when it has exactly one element.
///
/// # Errors
/// Returns a plan error when any intermediate or final grouping-set count
/// exceeds 4096, or when a merged grouping set exceeds 65535 expressions.
pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
    let has_grouping_set = group_expr
        .iter()
        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
    if !has_grouping_set || group_expr.len() == 1 {
        return Ok(group_expr);
    }
    let partial_sets = group_expr
        .iter()
        .map(|expr| {
            let exprs = match expr {
                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
                    check_grouping_sets_size_limit(grouping_sets.len())?;
                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
                }
                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
                    // A CUBE over n expressions yields 2^n grouping sets. Check
                    // the limit *before* materializing the power set, so an
                    // oversized CUBE fails fast instead of first allocating up
                    // to 2^63 vectors. For n >= 64, `powerset` reports the error.
                    let n = group_exprs.len();
                    if n < 64 {
                        let num_sets = usize::try_from(1u64 << n).unwrap_or(usize::MAX);
                        check_grouping_sets_size_limit(num_sets)?;
                    }
                    powerset(group_exprs).map_err(|e| plan_datafusion_err!("{}", e))?
                }
                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
                    // ROLLUP (e1, ..., en) yields exactly n + 1 grouping sets:
                    // the prefixes (), (e1), (e1, e2), ..., (e1, ..., en).
                    let size = group_exprs.len();
                    check_grouping_sets_size_limit(size + 1)?;
                    (0..=size)
                        .map(|i| group_exprs[..i].iter().collect())
                        .collect()
                }
                expr => vec![vec![expr]],
            };
            Ok(exprs)
        })
        .collect::<Result<Vec<_>>>()?;
    // Fold the per-item lists together with a cartesian product, then clone the
    // borrowed expressions into owned grouping sets.
    let grouping_sets = partial_sets
        .into_iter()
        .map(Ok)
        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
        .transpose()?
        .map(|e| {
            e.into_iter()
                .map(|e| e.into_iter().cloned().collect())
                .collect()
        })
        .unwrap_or_default();
    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
        grouping_sets,
    ))])
}
/// Flatten a GROUP BY list into references to its expressions.
///
/// When the list is a single `GroupingSet`, this returns the distinct
/// expressions of that grouping set. Otherwise it returns every element of
/// `group_expr` as-is.
///
/// # Errors
/// Returns a plan error when a `GroupingSet` is the first element but is not
/// the only element.
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
    match group_expr {
        [Expr::GroupingSet(grouping_set)] => Ok(grouping_set.distinct_expr()),
        [Expr::GroupingSet(_), ..] => plan_err!(
            "Invalid group by expressions, GroupingSet must be the only expression"
        ),
        _ => Ok(group_expr.iter().collect()),
    }
}
/// Walk `expr` and insert every `Expr::Column` it contains into `accum`.
///
/// # Errors
/// Propagates any error from the tree traversal.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    expr.apply(|node| {
        // This match has no `_` arm. Adding a new `Expr` variant therefore
        // causes a compile error here, until someone decides whether that
        // variant should contribute columns.
        match node {
            Expr::Column(column) => {
                accum.insert(column.clone());
            }
            Expr::Alias(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Literal(_)
            | Expr::Placeholder(_)
            | Expr::OuterReferenceColumn { .. }
            | Expr::Wildcard { .. }
            | Expr::Unnest(_)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::Negative(_)
            | Expr::IsNull(_)
            | Expr::IsNotNull(_)
            | Expr::IsTrue(_)
            | Expr::IsNotTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsNotFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotUnknown(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::Sort { .. }
            | Expr::ScalarFunction(..)
            | Expr::AggregateFunction { .. }
            | Expr::WindowFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_) => {}
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(())
}
/// Resolve the identifiers named in `EXCEPT` and `EXCLUDE` wildcard options
/// into columns of `schema`.
///
/// `qualifier`, when given, restricts name resolution to that relation.
///
/// # Errors
/// - Returns a plan error if the same identifier appears more than once
///   across both lists.
/// - Propagates the schema lookup error for a name that cannot be resolved.
fn get_excluded_columns(
    opt_exclude: Option<&ExcludeSelectItem>,
    opt_except: Option<&ExceptSelectItem>,
    schema: &DFSchema,
    qualifier: Option<&TableReference>,
) -> Result<Vec<Column>> {
    let mut idents = Vec::new();
    if let Some(except_item) = opt_except {
        idents.push(&except_item.first_element);
        idents.extend(except_item.additional_elements.iter());
    }
    match opt_exclude {
        Some(ExcludeSelectItem::Single(ident)) => idents.push(ident),
        Some(ExcludeSelectItem::Multiple(list)) => idents.extend(list.iter()),
        None => {}
    }
    let total = idents.len();
    let unique_idents: HashSet<_> = idents.into_iter().collect();
    if unique_idents.len() != total {
        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
    }
    unique_idents
        .into_iter()
        .map(|ident| -> Result<Column> {
            let (field_qualifier, field) =
                schema.qualified_field_with_name(qualifier, ident.value.as_str())?;
            Ok(Column::from((field_qualifier, field)))
        })
        .collect()
}
/// Build a column expression for every field of `schema`, omitting any column
/// contained in `columns_to_skip`.
fn get_exprs_except_skipped(
    schema: &DFSchema,
    columns_to_skip: HashSet<Column>,
) -> Vec<Expr> {
    if !columns_to_skip.is_empty() {
        return schema
            .columns()
            .into_iter()
            .filter(|column| !columns_to_skip.contains(column))
            .map(Expr::Column)
            .collect();
    }
    // Nothing to skip: convert each (qualifier, field) pair directly.
    schema.iter().map(Expr::from).collect()
}
pub fn expand_wildcard(
schema: &DFSchema,
plan: &LogicalPlan,
wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
let using_columns = plan.using_columns()?;
let mut columns_to_skip = using_columns
.into_iter()
.flat_map(|cols| {
let mut cols = cols.into_iter().collect::<Vec<_>>();
cols.sort();
let mut out_column_names: HashSet<String> = HashSet::new();
cols.into_iter()
.filter_map(|c| {
if out_column_names.contains(&c.name) {
Some(c)
} else {
out_column_names.insert(c.name);
None
}
})
.collect::<Vec<_>>()
})
.collect::<HashSet<_>>();
let excluded_columns = if let Some(WildcardAdditionalOptions {
opt_exclude,
opt_except,
..
}) = wildcard_options
{
get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
} else {
vec![]
};
columns_to_skip.extend(excluded_columns);
Ok(get_exprs_except_skipped(schema, columns_to_skip))
}
/// Expand a qualified `qualifier.*` into column expressions for the fields of
/// `schema` that carry `qualifier`.
///
/// Functional dependencies are projected onto those fields. Columns named by
/// the `EXCLUDE` / `EXCEPT` options in `wildcard_options` are skipped.
///
/// # Errors
/// - Returns a plan error when no field carries `qualifier`.
/// - Propagates errors from schema construction and from resolving excluded
///   column names.
pub fn expand_qualified_wildcard(
    qualifier: &TableReference,
    schema: &DFSchema,
    wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
    let indices = schema.fields_indices_with_qualified(qualifier);
    let dependencies = schema
        .functional_dependencies()
        .project_functional_dependencies(&indices, indices.len());
    let fields = get_at_indices(schema.fields(), &indices)?;
    if fields.is_empty() {
        return plan_err!("Invalid qualifier {qualifier}");
    }
    let arrow_schema = Arc::new(Schema::new(fields));
    let qualified_dfschema =
        DFSchema::try_from_qualified_schema(qualifier.clone(), &arrow_schema)?
            .with_functional_dependencies(dependencies)?;
    let mut columns_to_skip = HashSet::new();
    if let Some(options) = wildcard_options {
        columns_to_skip.extend(get_excluded_columns(
            options.opt_exclude.as_ref(),
            options.opt_except.as_ref(),
            schema,
            Some(qualifier),
        )?);
    }
    Ok(get_exprs_except_skipped(&qualified_dfschema, columns_to_skip))
}
/// Sort key for a window: each sort expression is paired with `true` when it
/// comes from PARTITION BY, or `false` when it comes only from ORDER BY.
type WindowSortKey = Vec<(Expr, bool)>;

/// Build the combined sort key of a window from its PARTITION BY and ORDER BY
/// clauses.
///
/// Each PARTITION BY expression becomes an ascending, nulls-last sort. If
/// ORDER BY contains the same expression (compared after normalizing to
/// ascending/nulls-last), the ORDER BY entry is used instead. Remaining ORDER
/// BY entries are appended afterwards. Duplicate sort expressions are dropped.
///
/// # Errors
/// Returns a plan error when an `order_by` element is not an `Expr::Sort`.
pub fn generate_sort_key(
    partition_by: &[Expr],
    order_by: &[Expr],
) -> Result<WindowSortKey> {
    let normalized_order_by_keys = order_by
        .iter()
        .map(|e| {
            if let Expr::Sort(Sort { expr, .. }) = e {
                Ok(Expr::Sort(Sort::new(expr.clone(), true, false)))
            } else {
                plan_err!("Order by only accepts sort expressions")
            }
        })
        .collect::<Result<Vec<_>>>()?;
    let mut sort_key: WindowSortKey = Vec::new();
    for partition_expr in partition_by {
        let normalized = partition_expr.clone().sort(true, false);
        let candidate = match normalized_order_by_keys
            .iter()
            .position(|key| *key == normalized)
        {
            Some(pos) => order_by[pos].clone(),
            None => normalized,
        };
        if !sort_key.iter().any(|(existing, _)| *existing == candidate) {
            sort_key.push((candidate, true));
        }
    }
    for order_expr in order_by {
        if !sort_key.iter().any(|(existing, _)| existing == order_expr) {
            sort_key.push((order_expr.clone(), false));
        }
    }
    Ok(sort_key)
}
/// Order two sort expressions for grouping window functions.
///
/// Comparison steps, applied in order:
/// 1. The indexes (in `schema`) of the columns each expression references,
///    compared lexicographically.
/// 2. The expression referencing more columns sorts first.
/// 3. Descending sorts before ascending.
/// 4. Nulls-first sorts before nulls-last.
///
/// Returns `Ordering::Equal` if either argument is not an `Expr::Sort`.
pub fn compare_sort_expr(
    sort_expr_a: &Expr,
    sort_expr_b: &Expr,
    schema: &DFSchemaRef,
) -> Ordering {
    let (sort_a, sort_b) = match (sort_expr_a, sort_expr_b) {
        (Expr::Sort(a), Expr::Sort(b)) => (a, b),
        _ => return Ordering::Equal,
    };
    let indexes_a = find_column_indexes_referenced_by_expr(&sort_a.expr, schema);
    let indexes_b = find_column_indexes_referenced_by_expr(&sort_b.expr, schema);
    indexes_a
        .iter()
        .zip(indexes_b.iter())
        .map(|(x, y)| x.cmp(y))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
        // Longer reference list first.
        .then_with(|| indexes_b.len().cmp(&indexes_a.len()))
        // `false` (descending) before `true` (ascending).
        .then_with(|| sort_a.asc.cmp(&sort_b.asc))
        // `true` (nulls first) before `false` (nulls last).
        .then_with(|| sort_b.nulls_first.cmp(&sort_a.nulls_first))
}
/// Group window function expressions by their sort key.
///
/// Groups appear in order of first occurrence. Within a group, expressions
/// keep their input order.
///
/// # Errors
/// - Returns an internal error if any input is not an `Expr::WindowFunction`.
/// - Propagates errors from [`generate_sort_key`].
pub fn group_window_expr_by_sort_keys(
    window_expr: Vec<Expr>,
) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
    let mut groups: Vec<(WindowSortKey, Vec<Expr>)> = Vec::new();
    for expr in window_expr {
        let sort_key = match &expr {
            Expr::WindowFunction(WindowFunction {
                partition_by,
                order_by,
                ..
            }) => generate_sort_key(partition_by, order_by)?,
            other => return internal_err!("Impossibly got non-window expr {other:?}"),
        };
        match groups.iter_mut().find(|(key, _)| *key == sort_key) {
            Some((_, members)) => members.push(expr),
            None => groups.push((sort_key, vec![expr])),
        }
    }
    Ok(groups)
}
/// Collect the distinct aggregate-function expressions that appear in `exprs`.
pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
    let is_aggregate = |e: &Expr| matches!(e, Expr::AggregateFunction { .. });
    find_exprs_in_exprs(exprs, &is_aggregate)
}

/// Collect the distinct sort expressions that appear in `exprs`.
pub fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
    let is_sort = |e: &Expr| matches!(e, Expr::Sort { .. });
    find_exprs_in_exprs(exprs, &is_sort)
}

/// Collect the distinct window-function expressions that appear in `exprs`.
pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
    let is_window = |e: &Expr| matches!(e, Expr::WindowFunction { .. });
    find_exprs_in_exprs(exprs, &is_window)
}

/// Collect the distinct outer-reference column expressions inside `expr`.
pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
    let is_outer_ref = |e: &Expr| matches!(e, Expr::OuterReferenceColumn { .. });
    find_exprs_in_expr(expr, &is_outer_ref)
}
/// Search every expression in `exprs` for sub-expressions matching `test_fn`.
///
/// Returns the matches deduplicated, in order of first appearance.
fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut found: Vec<Expr> = Vec::new();
    for expr in exprs {
        for candidate in find_exprs_in_expr(expr, test_fn) {
            if !found.contains(&candidate) {
                found.push(candidate);
            }
        }
    }
    found
}

/// Walk `expr` top-down and collect the distinct sub-expressions that match
/// `test_fn`.
///
/// The children of a matching node are not visited (`TreeNodeRecursion::Jump`).
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut found: Vec<Expr> = Vec::new();
    expr.apply(|node| {
        if !test_fn(node) {
            return Ok(TreeNodeRecursion::Continue);
        }
        if !found.contains(node) {
            found.push(node.clone());
        }
        Ok(TreeNodeRecursion::Jump)
    })
    .expect("no way to return error during recursion");
    found
}
/// Visit `expr` and its sub-expressions in pre-order, calling `f` on each.
///
/// Traversal stops at the first `Err` returned by `f`, and that error is
/// returned. Otherwise returns `Ok(())`.
pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
where
    F: FnMut(&Expr) -> Result<(), E>,
{
    let mut outcome: Result<(), E> = Ok(());
    expr.apply(|node| match f(node) {
        Ok(()) => Ok(TreeNodeRecursion::Continue),
        Err(e) => {
            outcome = Err(e);
            Ok(TreeNodeRecursion::Stop)
        }
    })
    .expect("no way to return error during recursion");
    outcome
}
/// Rebuild `plan` with new expressions `expr` and new child plans `inputs`.
///
/// Deprecated: call `LogicalPlan::with_new_exprs` directly.
#[deprecated(since = "31.0.0", note = "use LogicalPlan::with_new_exprs instead")]
pub fn from_plan(
    plan: &LogicalPlan,
    expr: &[Expr],
    inputs: &[LogicalPlan],
) -> Result<LogicalPlan> {
    let new_exprs = expr.to_vec();
    let new_inputs = inputs.to_vec();
    plan.with_new_exprs(new_exprs, new_inputs)
}
/// Resolve the output field, plus its optional qualifier, of each expression
/// against the schema of `plan`.
///
/// # Errors
/// Propagates the first error from `to_field`.
pub fn exprlist_to_fields<'a>(
    exprs: impl IntoIterator<Item = &'a Expr>,
    plan: &LogicalPlan,
) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
    let schema_of_input = plan.schema();
    let mut fields = Vec::new();
    for e in exprs {
        fields.push(e.to_field(schema_of_input)?);
    }
    Ok(fields)
}
/// Replace sub-expressions of `e` that `input` already computes with
/// references to `input`'s output columns.
///
/// If `input` reports no columnized output expressions, or reporting them
/// fails, `e` is returned unchanged. A replaced node's children are not
/// visited.
pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
    let output_exprs = input.columnized_output_exprs().unwrap_or_default();
    if output_exprs.is_empty() {
        return Ok(e);
    }
    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
    e.transform_down(|node: Expr| {
        let rewritten = if let Some(column) = exprs_map.get(&node) {
            Transformed::new(
                Expr::Column(column.clone()),
                true,
                TreeNodeRecursion::Jump,
            )
        } else {
            Transformed::no(node)
        };
        Ok(rewritten)
    })
    .data()
}
/// Collect, as `Expr::Column`, every column referenced by each expression in
/// `exprs`. Duplicates are kept.
pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
    let mut columns = Vec::new();
    for expr in exprs {
        columns.extend(
            find_columns_referenced_by_expr(expr)
                .into_iter()
                .map(Expr::Column),
        );
    }
    columns
}

/// Return every `Column` referenced anywhere inside `e`, in pre-order.
/// Duplicates are kept.
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
    let mut columns = Vec::new();
    e.apply(|node| {
        if let Expr::Column(column) = node {
            columns.push(column.clone());
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("Unexpected error");
    columns
}
/// Turn `expr` into a column reference.
///
/// A column is re-resolved against `plan`'s schema so that it carries its
/// full qualifier. Any other expression becomes an unqualified column named
/// after the expression's display name.
///
/// # Errors
/// Propagates schema-resolution and display-name errors.
pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    if let Expr::Column(col) = expr {
        let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
        return Ok(Expr::from(Column::from((qualifier, field))));
    }
    let name = expr.display_name()?;
    Ok(Expr::Column(Column::from_name(name)))
}
/// Return, in pre-order, the index in `schema` of each column referenced by
/// `e`.
///
/// Columns not found in `schema` are skipped. Each literal contributes
/// `usize::MAX`.
pub(crate) fn find_column_indexes_referenced_by_expr(
    e: &Expr,
    schema: &DFSchemaRef,
) -> Vec<usize> {
    let mut indexes = Vec::new();
    e.apply(|node| {
        if let Expr::Column(column) = node {
            if let Ok(position) = schema.index_of_column(column) {
                indexes.push(position);
            }
        } else if let Expr::Literal(_) = node {
            indexes.push(usize::MAX);
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .unwrap();
    indexes
}
/// Return whether values of `data_type` can be hashed (for example, as
/// hash-join keys).
///
/// Structs are hashable when all of their fields are. Dictionaries are
/// hashable only when they have `Utf8` values and a valid dictionary key type.
pub fn can_hash(data_type: &DataType) -> bool {
    match data_type {
        DataType::Null
        | DataType::Boolean
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64
        | DataType::Float32
        | DataType::Float64
        | DataType::Timestamp(
            TimeUnit::Second
            | TimeUnit::Millisecond
            | TimeUnit::Microsecond
            | TimeUnit::Nanosecond,
            _,
        )
        | DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Decimal128(_, _)
        | DataType::Date32
        | DataType::Date64
        | DataType::FixedSizeBinary(_)
        | DataType::List(_)
        | DataType::LargeList(_)
        | DataType::FixedSizeList(_, _) => true,
        DataType::Dictionary(key_type, value_type)
            if *value_type.as_ref() == DataType::Utf8 =>
        {
            DataType::is_dictionary_key_type(key_type)
        }
        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
        _ => false,
    }
}
/// Return `Ok(true)` when every column in `columns` belongs to `schema`.
/// Returns `Ok(true)` for an empty set.
pub fn check_all_columns_from_schema(
    columns: &HashSet<&Column>,
    schema: &DFSchema,
) -> Result<bool> {
    Ok(columns
        .iter()
        .all(|column| schema.is_column_from_schema(column)))
}

/// Try to orient `left_key = right_key` as an equi-join key pair.
///
/// Returns `(left_key, right_key)` when the keys' columns come from the left
/// and right schemas respectively, and the swapped pair when they come from
/// the opposite sides. Returns `None` when either key references no columns,
/// or when neither orientation fits.
pub fn find_valid_equijoin_key_pair(
    left_key: &Expr,
    right_key: &Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
) -> Result<Option<(Expr, Expr)>> {
    let left_cols = left_key.column_refs();
    let right_cols = right_key.column_refs();
    if left_cols.is_empty() || right_cols.is_empty() {
        return Ok(None);
    }
    let pair = if check_all_columns_from_schema(&left_cols, left_schema)?
        && check_all_columns_from_schema(&right_cols, right_schema)?
    {
        Some((left_key.clone(), right_key.clone()))
    } else if check_all_columns_from_schema(&right_cols, left_schema)?
        && check_all_columns_from_schema(&left_cols, right_schema)?
    {
        Some((right_key.clone(), left_key.clone()))
    } else {
        None
    };
    Ok(pair)
}
/// Build the user-facing "no matching function" message.
///
/// The message lists the argument types actually supplied, followed by each
/// candidate signature of `func_signature`, one per line.
pub fn generate_signature_error_msg(
    func_name: &str,
    func_signature: Signature,
    input_expr_types: &[DataType],
) -> String {
    let candidate_lines: Vec<String> = func_signature
        .type_signature
        .to_string_repr()
        .iter()
        .map(|args_str| format!("\t{func_name}({args_str})"))
        .collect();
    let candidate_signatures = candidate_lines.join("\n");
    let supplied_types = TypeSignature::join_types(input_expr_types, ", ");
    format!(
        "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
        func_name, supplied_types, candidate_signatures
    )
}
/// Split a chain of `AND`s into its operands, in left-to-right order.
///
/// Aliases are looked through. For example, `a AND (b AND c)` yields
/// `[a, b, c]`.
pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
    split_conjunction_impl(expr, vec![])
}

/// Append the `AND` operands of `expr` to `exprs` and return it.
///
/// An explicit work stack is used instead of recursion. The right child is
/// pushed before the left one, so the left child is popped first and the
/// output stays in left-to-right order.
fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And,
                right,
            }) => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Owned variant of `split_conjunction`: split an `AND` chain into owned
/// operands.
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
    split_binary_owned(expr, Operator::And)
}

/// Split a chain of `op` binary expressions into owned operands, in
/// left-to-right order, looking through aliases.
pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
    split_binary_owned_impl(expr, op, vec![])
}

/// Append the `operator` operands of `expr` to `exprs` and return it.
///
/// Uses an explicit work stack. The right child is pushed before the left
/// one, so operands come out left to right.
fn split_binary_owned_impl(
    expr: Expr,
    operator: Operator,
    mut exprs: Vec<Expr>,
) -> Vec<Expr> {
    let mut pending: Vec<Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == operator => {
                pending.push(*right);
                pending.push(*left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(*inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Split a chain of `op` binary expressions into borrowed operands, in
/// left-to-right order, looking through aliases.
pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
    split_binary_impl(expr, op, vec![])
}

/// Append the `operator` operands of `expr` to `exprs` and return it.
///
/// Uses an explicit work stack. The right child is pushed before the left
/// one, so operands come out left to right.
fn split_binary_impl<'a>(
    expr: &'a Expr,
    operator: Operator,
    mut exprs: Vec<&'a Expr>,
) -> Vec<&'a Expr> {
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == operator => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
filters.into_iter().reduce(Expr::and)
}
pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
filters.into_iter().reduce(Expr::or)
}
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
let predicate = predicates
.iter()
.skip(1)
.fold(predicates[0].clone(), |acc, predicate| {
and(acc, (*predicate).to_owned())
});
Ok(LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(plan),
)?))
}
/// Partition `exprs` into correlated join predicates and all other predicates.
///
/// Predicates containing outer references have those references stripped and
/// go to the first list. The exception is an `=` whose two sides are
/// identical; it is dropped. Predicates without outer references are cloned
/// into the second list.
pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
    let mut joins = vec![];
    let mut others = vec![];
    for filter in exprs {
        if !filter.contains_outer() {
            others.push(filter.clone());
            continue;
        }
        let is_self_equality = matches!(
            filter,
            Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) if left == right
        );
        if !is_self_equality {
            joins.push(strip_outer_reference(filter.clone()));
        }
    }
    Ok((joins, others))
}
/// Return the sole element of `slice`.
///
/// # Errors
/// Returns a plan error when `slice` is empty or has more than one element.
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
    match slice.len() {
        0 => plan_err!("No items found!"),
        1 => Ok(&slice[0]),
        _ => plan_err!("More than one item found!"),
    }
}
/// Merge the schemas of all `inputs` into one `DFSchema`.
///
/// With exactly one input, its schema is cloned unchanged. Otherwise, each
/// input's schema is merged, in order, into an empty schema.
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
    if let [single] = inputs.as_slice() {
        return single.schema().as_ref().clone();
    }
    let mut merged = DFSchema::empty();
    for input in &inputs {
        merged.merge(input.schema());
    }
    merged
}
/// Build an aggregate state field name of the form `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    [name, "[", state_name, "]"].concat()
}
/// How an aggregate's relationship to input ordering is classified.
///
/// The exact semantics of each variant are defined by the aggregate
/// implementations that return it.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum AggregateOrderSensitivity {
    /// Named as not depending on input order.
    Insensitive,
    /// Named as requiring a specific input order.
    HardRequirement,
    /// Named as benefiting from, but not requiring, a specific input order.
    Beneficial,
}

impl AggregateOrderSensitivity {
    /// `true` only for [`AggregateOrderSensitivity::Insensitive`].
    pub fn is_insensitive(&self) -> bool {
        matches!(self, Self::Insensitive)
    }

    /// `true` only for [`AggregateOrderSensitivity::Beneficial`].
    pub fn is_beneficial(&self) -> bool {
        matches!(self, Self::Beneficial)
    }

    /// `true` only for [`AggregateOrderSensitivity::HardRequirement`].
    pub fn hard_requires(&self) -> bool {
        matches!(self, Self::HardRequirement)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup,
        test::function_stub::max_udaf, test::function_stub::min_udaf,
        test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFrame,
        WindowFunctionDefinition,
    };

    #[test]
    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
        let result = group_window_expr_by_sort_keys(vec![])?;
        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
        let max1 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ));
        let max2 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ));
        let min3 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(min_udaf()),
            vec![col("name")],
        ));
        let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
            vec![col("age")],
        ));
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
        let key = vec![];
        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
            vec![(key, vec![max1, max2, min3, sum4])];
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn test_group_window_expr_by_sort_keys() -> Result<()> {
        let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true));
        let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true));
        let created_at_desc =
            Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true));
        let max1 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ))
        .order_by(vec![age_asc.clone(), name_desc.clone()])
        .build()
        .unwrap();
        let max2 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ));
        let min3 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(min_udaf()),
            vec![col("name")],
        ))
        .order_by(vec![age_asc.clone(), name_desc.clone()])
        .build()
        .unwrap();
        let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
            vec![col("age")],
        ))
        .order_by(vec![
            name_desc.clone(),
            age_asc.clone(),
            created_at_desc.clone(),
        ])
        .build()
        .unwrap();
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
        let key2 = vec![];
        let key3 = vec![
            (name_desc, false),
            (age_asc, false),
            (created_at_desc, false),
        ];
        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
            (key1, vec![max1, min3]),
            (key2, vec![max2]),
            (key3, vec![sum4]),
        ];
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn test_find_sort_exprs() -> Result<()> {
        let exprs = &[
            Expr::WindowFunction(expr::WindowFunction::new(
                WindowFunctionDefinition::AggregateUDF(max_udaf()),
                vec![col("name")],
            ))
            .order_by(vec![
                Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
                Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
            ])
            .window_frame(WindowFrame::new(Some(false)))
            .build()
            .unwrap(),
            Expr::WindowFunction(expr::WindowFunction::new(
                WindowFunctionDefinition::AggregateUDF(sum_udaf()),
                vec![col("age")],
            ))
            .order_by(vec![
                Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
                Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
                Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
            ])
            .window_frame(WindowFrame::new(Some(false)))
            .build()
            .unwrap(),
        ];
        let expected = vec![
            Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
            Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
            Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
        ];
        let result = find_sort_exprs(exprs);
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
        let asc_or_desc = [true, false];
        let nulls_first_or_last = [true, false];
        let partition_by = &[col("age"), col("name"), col("created_at")];
        for asc_ in asc_or_desc {
            for nulls_first_ in nulls_first_or_last {
                let order_by = &[
                    Expr::Sort(Sort {
                        expr: Box::new(col("age")),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    }),
                    Expr::Sort(Sort {
                        expr: Box::new(col("name")),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    }),
                ];
                let expected = vec![
                    (
                        Expr::Sort(Sort {
                            expr: Box::new(col("age")),
                            asc: asc_,
                            nulls_first: nulls_first_,
                        }),
                        true,
                    ),
                    (
                        Expr::Sort(Sort {
                            expr: Box::new(col("name")),
                            asc: asc_,
                            nulls_first: nulls_first_,
                        }),
                        true,
                    ),
                    (
                        Expr::Sort(Sort {
                            expr: Box::new(col("created_at")),
                            asc: true,
                            nulls_first: false,
                        }),
                        true,
                    ),
                ];
                let result = generate_sort_key(partition_by, order_by)?;
                assert_eq!(expected, result);
            }
        }
        Ok(())
    }

    #[test]
    fn test_enumerate_grouping_sets() -> Result<()> {
        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
        let simple_col = col("simple_col");
        let cube = cube(multi_cols.clone());
        let rollup = rollup(multi_cols.clone());
        let grouping_set = grouping_set(vec![multi_cols]);
        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[simple_col]", &result);
        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[CUBE (col1, col2, col3)]", &result);
        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col2), \
            (simple_col, col1, col2), \
            (simple_col, col3), \
            (simple_col, col1, col3), \
            (simple_col, col2, col3), \
            (simple_col, col1, col2, col3))]",
            &result
        );
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3))]",
            &result
        );
        let sets =
            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col, col1, col2, col3))]",
            &result
        );
        let sets = enumerate_grouping_sets(vec![
            simple_col.clone(),
            grouping_set,
            rollup.clone(),
        ])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );
        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1), \
            (simple_col, col1, col1), \
            (simple_col, col1, col1, col2), \
            (simple_col, col1, col1, col2, col3), \
            (simple_col, col2), \
            (simple_col, col2, col1), \
            (simple_col, col2, col1, col2), \
            (simple_col, col2, col1, col2, col3), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col1), \
            (simple_col, col1, col2, col1, col2), \
            (simple_col, col1, col2, col1, col2, col3), \
            (simple_col, col3), \
            (simple_col, col3, col1), \
            (simple_col, col3, col1, col2), \
            (simple_col, col3, col1, col2, col3), \
            (simple_col, col1, col3), \
            (simple_col, col1, col3, col1), \
            (simple_col, col1, col3, col1, col2), \
            (simple_col, col1, col3, col1, col2, col3), \
            (simple_col, col2, col3), \
            (simple_col, col2, col3, col1), \
            (simple_col, col2, col3, col1, col2), \
            (simple_col, col2, col3, col1, col2, col3), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );
        Ok(())
    }

    #[test]
    fn test_split_conjunction() {
        let expr = col("a");
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr]);
    }

    #[test]
    fn test_split_conjunction_two() {
        let expr = col("a").eq(lit(5)).and(col("b"));
        let expr1 = col("a").eq(lit(5));
        let expr2 = col("b");
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr1, &expr2]);
    }

    #[test]
    fn test_split_conjunction_alias() {
        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
        let expr1 = col("a").eq(lit(5));
        // The alias is looked through, so the bare column is expected.
        let expr2 = col("b");
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr1, &expr2]);
    }

    #[test]
    fn test_split_conjunction_or() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr]);
    }

    #[test]
    fn test_split_binary_owned() {
        let expr = col("a");
        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
    }

    #[test]
    fn test_split_binary_owned_two() {
        assert_eq!(
            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
            vec![col("a").eq(lit(5)), col("b")]
        );
    }

    #[test]
    fn test_split_binary_owned_different_op() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(
            split_binary_owned(expr.clone(), Operator::And),
            vec![expr]
        );
    }

    #[test]
    fn test_split_conjunction_owned() {
        let expr = col("a");
        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
    }

    #[test]
    fn test_split_conjunction_owned_two() {
        assert_eq!(
            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
            vec![col("a").eq(lit(5)), col("b")]
        );
    }

    #[test]
    fn test_split_conjunction_owned_alias() {
        assert_eq!(
            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
            vec![
                col("a").eq(lit(5)),
                col("b"),
            ]
        );
    }

    #[test]
    fn test_conjunction_empty() {
        assert_eq!(conjunction(vec![]), None);
    }

    #[test]
    fn test_conjunction() {
        let expr = conjunction(vec![col("a"), col("b"), col("c")]);
        // `conjunction` builds a left-associated tree.
        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
    }

    #[test]
    fn test_disjunction_empty() {
        assert_eq!(disjunction(vec![]), None);
    }

    #[test]
    fn test_disjunction() {
        let expr = disjunction(vec![col("a"), col("b"), col("c")]);
        // `disjunction` builds a left-associated tree.
        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
    }

    #[test]
    fn test_split_conjunction_owned_or() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
    }

    #[test]
    fn test_collect_expr() -> Result<()> {
        let mut accum: HashSet<Column> = HashSet::new();
        expr_to_columns(
            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
            &mut accum,
        )?;
        expr_to_columns(
            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
            &mut accum,
        )?;
        assert_eq!(1, accum.len());
        assert!(accum.contains(&Column::from_name("a")));
        Ok(())
    }

    /// `powerset` enumerates subsets in bitmask order (bit `i` selects
    /// element `i`) and rejects inputs of 64 or more elements.
    #[test]
    fn test_powerset_order() {
        let items = [1, 2, 3];
        let subsets: Vec<Vec<i32>> = powerset(&items)
            .unwrap()
            .into_iter()
            .map(|s| s.into_iter().copied().collect())
            .collect();
        assert_eq!(
            subsets,
            vec![
                vec![],
                vec![1],
                vec![2],
                vec![1, 2],
                vec![3],
                vec![1, 3],
                vec![2, 3],
                vec![1, 2, 3],
            ]
        );
        assert!(powerset(&[0u8; 64]).is_err());
    }

    #[test]
    fn test_only_or_err() {
        assert_eq!(only_or_err(&[7]).unwrap(), &7);
        assert!(only_or_err::<i32>(&[]).is_err());
        assert!(only_or_err(&[1, 2]).is_err());
    }

    #[test]
    fn test_format_state_name() {
        assert_eq!(format_state_name("sum", "count"), "sum[count]");
    }

    #[test]
    fn test_aggregate_order_sensitivity_predicates() {
        assert!(AggregateOrderSensitivity::Insensitive.is_insensitive());
        assert!(AggregateOrderSensitivity::HardRequirement.hard_requires());
        assert!(AggregateOrderSensitivity::Beneficial.is_beneficial());
        assert!(!AggregateOrderSensitivity::Beneficial.hard_requires());
        assert!(!AggregateOrderSensitivity::HardRequirement.is_insensitive());
    }
}