use crate::expr::{Alias, Sort, WindowFunction};
use crate::logical_plan::Aggregate;
use crate::signature::{Signature, TypeSignature};
use crate::{Cast, Expr, ExprSchemable, GroupingSet, LogicalPlan, TryCast};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{
internal_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError,
Result, ScalarValue, TableReference,
};
use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions};
use std::cmp::Ordering;
use std::collections::HashSet;
pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1));
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
for e in expr {
expr_to_columns(e, accum)?;
}
Ok(())
}
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
if group_expr.len() > 1 {
return plan_err!(
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
Ok(grouping_set.distinct_expr().len())
} else {
Ok(group_expr.len())
}
}
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
if slice.len() >= 64 {
return Err("The size of the set must be less than 64.".into());
}
let mut v = Vec::new();
for mask in 0..(1 << slice.len()) {
let mut ss = vec![];
let mut bitset = mask;
while bitset > 0 {
let rightmost: u64 = bitset & !(bitset - 1);
let idx = rightmost.trailing_zeros();
let item = slice.get(idx as usize).unwrap();
ss.push(item);
bitset &= bitset - 1;
}
v.push(ss);
}
Ok(v)
}
fn check_grouping_set_size_limit(size: usize) -> Result<()> {
let max_grouping_set_size = 65535;
if size > max_grouping_set_size {
return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
}
Ok(())
}
fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
let max_grouping_sets_size = 4096;
if size > max_grouping_sets_size {
return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
}
Ok(())
}
fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
check_grouping_set_size_limit(left.len() + right.len())?;
Ok(left.iter().chain(right.iter()).cloned().collect())
}
fn cross_join_grouping_sets<T: Clone>(
left: &[Vec<T>],
right: &[Vec<T>],
) -> Result<Vec<Vec<T>>> {
let grouping_sets_size = left.len() * right.len();
check_grouping_sets_size_limit(grouping_sets_size)?;
let mut result = Vec::with_capacity(grouping_sets_size);
for le in left {
for re in right {
result.push(merge_grouping_set(le, re)?);
}
}
Ok(result)
}
pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
let has_grouping_set = group_expr
.iter()
.any(|expr| matches!(expr, Expr::GroupingSet(_)));
if !has_grouping_set || group_expr.len() == 1 {
return Ok(group_expr);
}
let partial_sets = group_expr
.iter()
.map(|expr| {
let exprs = match expr {
Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
check_grouping_sets_size_limit(grouping_sets.len())?;
grouping_sets.iter().map(|e| e.iter().collect()).collect()
}
Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
let grouping_sets =
powerset(group_exprs).map_err(DataFusionError::Plan)?;
check_grouping_sets_size_limit(grouping_sets.len())?;
grouping_sets
}
Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
let size = group_exprs.len();
let slice = group_exprs.as_slice();
check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
(0..(size + 1))
.map(|i| slice[0..i].iter().collect())
.collect()
}
expr => vec![vec![expr]],
};
Ok(exprs)
})
.collect::<Result<Vec<_>>>()?;
let grouping_sets = partial_sets
.into_iter()
.map(Ok)
.reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
.transpose()?
.map(|e| {
e.into_iter()
.map(|e| e.into_iter().cloned().collect())
.collect()
})
.unwrap_or_default();
Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
grouping_sets,
))])
}
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
if group_expr.len() > 1 {
return plan_err!(
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
Ok(grouping_set.distinct_expr())
} else {
Ok(group_expr.to_vec())
}
}
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
inspect_expr_pre(expr, |expr| {
match expr {
Expr::Column(qc) => {
accum.insert(qc.clone());
}
Expr::ScalarVariable(_, _)
| Expr::Alias(_)
| Expr::Literal(_)
| Expr::BinaryExpr { .. }
| Expr::Like { .. }
| Expr::SimilarTo { .. }
| Expr::Not(_)
| Expr::IsNotNull(_)
| Expr::IsNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Negative(_)
| Expr::Between { .. }
| Expr::Case { .. }
| Expr::Cast { .. }
| Expr::TryCast { .. }
| Expr::Sort { .. }
| Expr::ScalarFunction(..)
| Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
| Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::Wildcard
| Expr::QualifiedWildcard { .. }
| Expr::GetIndexedField { .. }
| Expr::Placeholder(_)
| Expr::OuterReferenceColumn { .. } => {}
}
Ok(())
})
}
fn get_excluded_columns(
opt_exclude: Option<&ExcludeSelectItem>,
opt_except: Option<&ExceptSelectItem>,
schema: &DFSchema,
qualifier: &Option<TableReference>,
) -> Result<Vec<Column>> {
let mut idents = vec![];
if let Some(excepts) = opt_except {
idents.push(&excepts.first_element);
idents.extend(&excepts.additional_elements);
}
if let Some(exclude) = opt_exclude {
match exclude {
ExcludeSelectItem::Single(ident) => idents.push(ident),
ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
}
}
let n_elem = idents.len();
let unique_idents = idents.into_iter().collect::<HashSet<_>>();
if n_elem != unique_idents.len() {
return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
}
let mut result = vec![];
for ident in unique_idents.into_iter() {
let col_name = ident.value.as_str();
let field = if let Some(qualifier) = qualifier {
schema.field_with_qualified_name(qualifier, col_name)?
} else {
schema.field_with_unqualified_name(col_name)?
};
result.push(field.qualified_column())
}
Ok(result)
}
fn get_exprs_except_skipped(
schema: &DFSchema,
columns_to_skip: HashSet<Column>,
) -> Vec<Expr> {
if columns_to_skip.is_empty() {
schema
.fields()
.iter()
.map(|f| Expr::Column(f.qualified_column()))
.collect::<Vec<Expr>>()
} else {
schema
.fields()
.iter()
.filter_map(|f| {
let col = f.qualified_column();
if !columns_to_skip.contains(&col) {
Some(Expr::Column(col))
} else {
None
}
})
.collect::<Vec<Expr>>()
}
}
pub fn expand_wildcard(
schema: &DFSchema,
plan: &LogicalPlan,
wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
let using_columns = plan.using_columns()?;
let mut columns_to_skip = using_columns
.into_iter()
.flat_map(|cols| {
let mut cols = cols.into_iter().collect::<Vec<_>>();
cols.sort();
let mut out_column_names: HashSet<String> = HashSet::new();
cols.into_iter()
.filter_map(|c| {
if out_column_names.contains(&c.name) {
Some(c)
} else {
out_column_names.insert(c.name);
None
}
})
.collect::<Vec<_>>()
})
.collect::<HashSet<_>>();
let excluded_columns = if let Some(WildcardAdditionalOptions {
opt_exclude,
opt_except,
..
}) = wildcard_options
{
get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, &None)?
} else {
vec![]
};
columns_to_skip.extend(excluded_columns);
Ok(get_exprs_except_skipped(schema, columns_to_skip))
}
pub fn expand_qualified_wildcard(
qualifier: &str,
schema: &DFSchema,
wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
let qualifier = TableReference::from(qualifier);
let qualified_fields: Vec<DFField> = schema
.fields_with_qualified(&qualifier)
.into_iter()
.cloned()
.collect();
if qualified_fields.is_empty() {
return plan_err!("Invalid qualifier {qualifier}");
}
let qualified_schema =
DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?
.with_functional_dependencies(schema.functional_dependencies().clone());
let excluded_columns = if let Some(WildcardAdditionalOptions {
opt_exclude,
opt_except,
..
}) = wildcard_options
{
get_excluded_columns(
opt_exclude.as_ref(),
opt_except.as_ref(),
schema,
&Some(qualifier),
)?
} else {
vec![]
};
let mut columns_to_skip = HashSet::new();
columns_to_skip.extend(excluded_columns);
Ok(get_exprs_except_skipped(&qualified_schema, columns_to_skip))
}
type WindowSortKey = Vec<(Expr, bool)>;
pub fn generate_sort_key(
partition_by: &[Expr],
order_by: &[Expr],
) -> Result<WindowSortKey> {
let normalized_order_by_keys = order_by
.iter()
.map(|e| match e {
Expr::Sort(Sort { expr, .. }) => {
Ok(Expr::Sort(Sort::new(expr.clone(), true, false)))
}
_ => plan_err!("Order by only accepts sort expressions"),
})
.collect::<Result<Vec<_>>>()?;
let mut final_sort_keys = vec![];
let mut is_partition_flag = vec![];
partition_by.iter().for_each(|e| {
let e = e.clone().sort(true, false);
if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
let order_by_key = &order_by[pos];
if !final_sort_keys.contains(order_by_key) {
final_sort_keys.push(order_by_key.clone());
is_partition_flag.push(true);
}
} else if !final_sort_keys.contains(&e) {
final_sort_keys.push(e);
is_partition_flag.push(true);
}
});
order_by.iter().for_each(|e| {
if !final_sort_keys.contains(e) {
final_sort_keys.push(e.clone());
is_partition_flag.push(false);
}
});
let res = final_sort_keys
.into_iter()
.zip(is_partition_flag)
.map(|(lhs, rhs)| (lhs, rhs))
.collect::<Vec<_>>();
Ok(res)
}
pub fn compare_sort_expr(
sort_expr_a: &Expr,
sort_expr_b: &Expr,
schema: &DFSchemaRef,
) -> Ordering {
match (sort_expr_a, sort_expr_b) {
(
Expr::Sort(Sort {
expr: expr_a,
asc: asc_a,
nulls_first: nulls_first_a,
}),
Expr::Sort(Sort {
expr: expr_b,
asc: asc_b,
nulls_first: nulls_first_b,
}),
) => {
let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
match idx_a.cmp(idx_b) {
Ordering::Less => {
return Ordering::Less;
}
Ordering::Greater => {
return Ordering::Greater;
}
Ordering::Equal => {}
}
}
match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
Ordering::Less => return Ordering::Greater,
Ordering::Greater => {
return Ordering::Less;
}
Ordering::Equal => {}
}
match (asc_a, asc_b) {
(true, false) => {
return Ordering::Greater;
}
(false, true) => {
return Ordering::Less;
}
_ => {}
}
match (nulls_first_a, nulls_first_b) {
(true, false) => {
return Ordering::Less;
}
(false, true) => {
return Ordering::Greater;
}
_ => {}
}
Ordering::Equal
}
_ => Ordering::Equal,
}
}
pub fn group_window_expr_by_sort_keys(
window_expr: &[Expr],
) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
let mut result = vec![];
window_expr.iter().try_for_each(|expr| match expr {
Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => {
let sort_key = generate_sort_key(partition_by, order_by)?;
if let Some((_, values)) = result.iter_mut().find(
|group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key),
) {
values.push(expr);
} else {
result.push((sort_key, vec![expr]))
}
Ok(())
}
other => internal_err!(
"Impossibly got non-window expr {other:?}"
),
})?;
Ok(result)
}
pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
find_exprs_in_exprs(exprs, &|nested_expr| {
matches!(
nested_expr,
Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
)
})
}
pub fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
find_exprs_in_exprs(exprs, &|nested_expr| {
matches!(nested_expr, Expr::Sort { .. })
})
}
pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
find_exprs_in_exprs(exprs, &|nested_expr| {
matches!(nested_expr, Expr::WindowFunction { .. })
})
}
pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
find_exprs_in_expr(expr, &|nested_expr| {
matches!(nested_expr, Expr::OuterReferenceColumn { .. })
})
}
fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
where
F: Fn(&Expr) -> bool,
{
exprs
.iter()
.flat_map(|expr| find_exprs_in_expr(expr, test_fn))
.fold(vec![], |mut acc, expr| {
if !acc.contains(&expr) {
acc.push(expr)
}
acc
})
}
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
F: Fn(&Expr) -> bool,
{
let mut exprs = vec![];
expr.apply(&mut |expr| {
if test_fn(expr) {
if !(exprs.contains(expr)) {
exprs.push(expr.clone())
}
return Ok(VisitRecursion::Skip);
}
Ok(VisitRecursion::Continue)
})
.expect("no way to return error during recursion");
exprs
}
pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
where
F: FnMut(&Expr) -> Result<(), E>,
{
let mut err = Ok(());
expr.apply(&mut |expr| {
if let Err(e) = f(expr) {
err = Err(e);
Ok(VisitRecursion::Stop)
} else {
Ok(VisitRecursion::Continue)
}
})
.expect("no way to return error during recursion");
err
}
#[deprecated(since = "31.0.0", note = "use LogicalPlan::with_new_exprs instead")]
pub fn from_plan(
plan: &LogicalPlan,
expr: &[Expr],
inputs: &[LogicalPlan],
) -> Result<LogicalPlan> {
plan.with_new_exprs(expr.to_vec(), inputs)
}
fn agg_cols(agg: &Aggregate) -> Vec<Column> {
agg.aggr_expr
.iter()
.chain(&agg.group_expr)
.flat_map(find_columns_referenced_by_expr)
.collect()
}
fn exprlist_to_fields_aggregate(
exprs: &[Expr],
plan: &LogicalPlan,
agg: &Aggregate,
) -> Result<Vec<DFField>> {
let agg_cols = agg_cols(agg);
let mut fields = vec![];
for expr in exprs {
match expr {
Expr::Column(c) if agg_cols.iter().any(|x| x == c) => {
fields.push(expr.to_field(agg.input.schema())?);
}
_ => fields.push(expr.to_field(plan.schema())?),
}
}
Ok(fields)
}
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
let fields = match plan {
LogicalPlan::Aggregate(agg) => {
Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
}
LogicalPlan::Window(window) => match window.input.as_ref() {
LogicalPlan::Aggregate(agg) => {
Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
}
_ => None,
},
_ => None,
};
if let Some(fields) = fields {
fields
} else {
let input_schema = &plan.schema();
exprs.iter().map(|e| e.to_field(input_schema)).collect()
}
}
pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
match e {
Expr::Column(_) => e,
Expr::OuterReferenceColumn(_, _) => e,
Expr::Alias(Alias { expr, name, .. }) => {
columnize_expr(*expr, input_schema).alias(name)
}
Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast {
expr: Box::new(columnize_expr(*expr, input_schema)),
data_type,
}),
Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new(
Box::new(columnize_expr(*expr, input_schema)),
data_type,
)),
Expr::ScalarSubquery(_) => e.clone(),
_ => match e.display_name() {
Ok(name) => match input_schema.field_with_unqualified_name(&name) {
Ok(field) => Expr::Column(field.qualified_column()),
Err(_) => e,
},
Err(_) => e,
},
}
}
pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
exprs
.iter()
.flat_map(find_columns_referenced_by_expr)
.map(Expr::Column)
.collect()
}
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
let mut exprs = vec![];
inspect_expr_pre(e, |expr| {
if let Expr::Column(c) = expr {
exprs.push(c.clone())
}
Ok(()) as Result<()>
})
.expect("Unexpected error");
exprs
}
pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
match expr {
Expr::Column(col) => {
let field = plan.schema().field_from_column(col)?;
Ok(Expr::Column(field.qualified_column()))
}
_ => Ok(Expr::Column(Column::from_name(expr.display_name()?))),
}
}
pub(crate) fn find_column_indexes_referenced_by_expr(
e: &Expr,
schema: &DFSchemaRef,
) -> Vec<usize> {
let mut indexes = vec![];
inspect_expr_pre(e, |expr| {
match expr {
Expr::Column(qc) => {
if let Ok(idx) = schema.index_of_column(qc) {
indexes.push(idx);
}
}
Expr::Literal(_) => {
indexes.push(std::usize::MAX);
}
_ => {}
}
Ok(()) as Result<()>
})
.unwrap();
indexes
}
pub fn can_hash(data_type: &DataType) -> bool {
match data_type {
DataType::Null => true,
DataType::Boolean => true,
DataType::Int8 => true,
DataType::Int16 => true,
DataType::Int32 => true,
DataType::Int64 => true,
DataType::UInt8 => true,
DataType::UInt16 => true,
DataType::UInt32 => true,
DataType::UInt64 => true,
DataType::Float32 => true,
DataType::Float64 => true,
DataType::Timestamp(time_unit, None) => match time_unit {
TimeUnit::Second => true,
TimeUnit::Millisecond => true,
TimeUnit::Microsecond => true,
TimeUnit::Nanosecond => true,
},
DataType::Utf8 => true,
DataType::LargeUtf8 => true,
DataType::Decimal128(_, _) => true,
DataType::Date32 => true,
DataType::Date64 => true,
DataType::FixedSizeBinary(_) => true,
DataType::Dictionary(key_type, value_type)
if *value_type.as_ref() == DataType::Utf8 =>
{
DataType::is_dictionary_key_type(key_type)
}
_ => false,
}
}
pub fn check_all_columns_from_schema(
columns: &HashSet<Column>,
schema: DFSchemaRef,
) -> Result<bool> {
for col in columns.iter() {
let exist = schema.is_column_from_schema(col)?;
if !exist {
return Ok(false);
}
}
Ok(true)
}
pub fn find_valid_equijoin_key_pair(
left_key: &Expr,
right_key: &Expr,
left_schema: DFSchemaRef,
right_schema: DFSchemaRef,
) -> Result<Option<(Expr, Expr)>> {
let left_using_columns = left_key.to_columns()?;
let right_using_columns = right_key.to_columns()?;
if left_using_columns.is_empty() || right_using_columns.is_empty() {
return Ok(None);
}
let l_is_left =
check_all_columns_from_schema(&left_using_columns, left_schema.clone())?;
let r_is_right =
check_all_columns_from_schema(&right_using_columns, right_schema.clone())?;
let r_is_left_and_l_is_right = || {
let result =
check_all_columns_from_schema(&right_using_columns, left_schema.clone())?
&& check_all_columns_from_schema(
&left_using_columns,
right_schema.clone(),
)?;
Result::<_>::Ok(result)
};
let join_key_pair = match (l_is_left, r_is_right) {
(true, true) => Some((left_key.clone(), right_key.clone())),
(_, _) if r_is_left_and_l_is_right()? => {
Some((right_key.clone(), left_key.clone()))
}
_ => None,
};
Ok(join_key_pair)
}
pub fn generate_signature_error_msg(
func_name: &str,
func_signature: Signature,
input_expr_types: &[DataType],
) -> String {
let candidate_signatures = func_signature
.type_signature
.to_string_repr()
.iter()
.map(|args_str| format!("\t{func_name}({args_str})"))
.collect::<Vec<String>>()
.join("\n");
format!(
"No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expr_vec_fmt;
use crate::{
col, cube, expr, grouping_set, rollup, AggregateFunction, WindowFrame,
WindowFunction,
};
#[test]
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
let result = group_window_expr_by_sort_keys(&[])?;
let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
WindowFrame::new(false),
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
let key = vec![];
let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
vec![(key, vec![&max1, &max2, &min3, &sum4])];
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_group_window_expr_by_sort_keys() -> Result<()> {
let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true));
let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true));
let created_at_desc =
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true));
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(true),
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
let key2 = vec![];
let key3 = vec![
(name_desc, false),
(age_asc, false),
(created_at_desc, false),
];
let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
(key1, vec![&max1, &min3]),
(key2, vec![&max2]),
(key3, vec![&sum4]),
];
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_find_sort_exprs() -> Result<()> {
let exprs = &[
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(true),
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(true),
)),
];
let expected = vec![
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
];
let result = find_sort_exprs(exprs);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn avoid_generate_duplicate_sort_keys() -> Result<()> {
let asc_or_desc = [true, false];
let nulls_first_or_last = [true, false];
let partition_by = &[col("age"), col("name"), col("created_at")];
for asc_ in asc_or_desc {
for nulls_first_ in nulls_first_or_last {
let order_by = &[
Expr::Sort(Sort {
expr: Box::new(col("age")),
asc: asc_,
nulls_first: nulls_first_,
}),
Expr::Sort(Sort {
expr: Box::new(col("name")),
asc: asc_,
nulls_first: nulls_first_,
}),
];
let expected = vec![
(
Expr::Sort(Sort {
expr: Box::new(col("age")),
asc: asc_,
nulls_first: nulls_first_,
}),
true,
),
(
Expr::Sort(Sort {
expr: Box::new(col("name")),
asc: asc_,
nulls_first: nulls_first_,
}),
true,
),
(
Expr::Sort(Sort {
expr: Box::new(col("created_at")),
asc: true,
nulls_first: false,
}),
true,
),
];
let result = generate_sort_key(partition_by, order_by)?;
assert_eq!(expected, result);
}
}
Ok(())
}
#[test]
fn test_enumerate_grouping_sets() -> Result<()> {
let multi_cols = vec![col("col1"), col("col2"), col("col3")];
let simple_col = col("simple_col");
let cube = cube(multi_cols.clone());
let rollup = rollup(multi_cols.clone());
let grouping_set = grouping_set(vec![multi_cols]);
let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!("[simple_col]", &result);
let sets = enumerate_grouping_sets(vec![cube.clone()])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!("[CUBE (col1, col2, col3)]", &result);
let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!(
"[GROUPING SETS (\
(simple_col), \
(simple_col, col1), \
(simple_col, col2), \
(simple_col, col1, col2), \
(simple_col, col3), \
(simple_col, col1, col3), \
(simple_col, col2, col3), \
(simple_col, col1, col2, col3))]",
&result
);
let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!(
"[GROUPING SETS (\
(simple_col), \
(simple_col, col1), \
(simple_col, col1, col2), \
(simple_col, col1, col2, col3))]",
&result
);
let sets =
enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!(
"[GROUPING SETS (\
(simple_col, col1, col2, col3))]",
&result
);
let sets = enumerate_grouping_sets(vec![
simple_col.clone(),
grouping_set,
rollup.clone(),
])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!(
"[GROUPING SETS (\
(simple_col, col1, col2, col3), \
(simple_col, col1, col2, col3, col1), \
(simple_col, col1, col2, col3, col1, col2), \
(simple_col, col1, col2, col3, col1, col2, col3))]",
&result
);
let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
let result = format!("[{}]", expr_vec_fmt!(sets));
assert_eq!(
"[GROUPING SETS (\
(simple_col), \
(simple_col, col1), \
(simple_col, col1, col2), \
(simple_col, col1, col2, col3), \
(simple_col, col1), \
(simple_col, col1, col1), \
(simple_col, col1, col1, col2), \
(simple_col, col1, col1, col2, col3), \
(simple_col, col2), \
(simple_col, col2, col1), \
(simple_col, col2, col1, col2), \
(simple_col, col2, col1, col2, col3), \
(simple_col, col1, col2), \
(simple_col, col1, col2, col1), \
(simple_col, col1, col2, col1, col2), \
(simple_col, col1, col2, col1, col2, col3), \
(simple_col, col3), \
(simple_col, col3, col1), \
(simple_col, col3, col1, col2), \
(simple_col, col3, col1, col2, col3), \
(simple_col, col1, col3), \
(simple_col, col1, col3, col1), \
(simple_col, col1, col3, col1, col2), \
(simple_col, col1, col3, col1, col2, col3), \
(simple_col, col2, col3), \
(simple_col, col2, col3, col1), \
(simple_col, col2, col3, col1, col2), \
(simple_col, col2, col3, col1, col2, col3), \
(simple_col, col1, col2, col3), \
(simple_col, col1, col2, col3, col1), \
(simple_col, col1, col2, col3, col1, col2), \
(simple_col, col1, col2, col3, col1, col2, col3))]",
&result
);
Ok(())
}
}