use std::collections::HashMap;
use arrow_schema::{
DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
exec_err, internal_err, plan_err, Column, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::builder::get_unnested_columns;
use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{expr_vec_fmt, Expr, ExprSchemable, LogicalPlan};
use sqlparser::ast::{Ident, Value};
/// Rewrites every `Expr::Column` inside `expr` using the schema of `plan`.
///
/// The expression tree is cloned and walked bottom-up. Each column node is
/// looked up with `qualified_field_from_column` on the plan schema and
/// replaced by a column built from the returned `(qualifier, field)` pair.
/// All other nodes are passed through untouched.
///
/// # Errors
/// Propagates the error from the schema lookup when a column cannot be found.
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    let schema = plan.schema();
    let rewritten = expr.clone().transform_up(|node| {
        if let Expr::Column(column) = &node {
            let (qualifier, field) = schema.qualified_field_from_column(column)?;
            let resolved = Column::from((qualifier, field));
            return Ok(Transformed::yes(Expr::Column(resolved)));
        }
        // Not a column reference: keep the node exactly as it is.
        Ok(Transformed::no(node))
    });
    rewritten.data()
}
/// Rebuilds `expr` so that any sub-expression equal to one of `base_exprs`
/// is replaced by the result of `expr_as_column_expr(sub_expr, plan)`.
///
/// The tree is cloned and walked top-down; nodes that do not appear in
/// `base_exprs` are left unchanged.
///
/// # Errors
/// Propagates any error returned by `expr_as_column_expr`.
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    let rebased = expr.clone().transform_down(|node| {
        let is_base_expr = base_exprs.iter().any(|base| base == &node);
        if !is_base_expr {
            return Ok(Transformed::no(node));
        }
        let as_column = expr_as_column_expr(&node, plan)?;
        Ok(Transformed::yes(as_column))
    });
    rebased.data()
}
/// Verifies that every column expression found in `exprs` is present in
/// `columns`.
///
/// `columns` must consist solely of `Expr::Column` values; anything else
/// yields an internal error. The expressions returned by
/// `find_column_exprs(exprs)` are then checked one by one. Grouping sets
/// (rollup, cube, and explicit grouping sets) are checked member by member
/// rather than as a whole.
///
/// # Errors
/// Returns a plan error, prefixed with `message_prefix`, for the first
/// expression that is not contained in `columns`.
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    message_prefix: &str,
) -> Result<()> {
    for column in columns {
        if !matches!(column, Expr::Column(_)) {
            return internal_err!("Expr::Column are required");
        }
    }

    let referenced = find_column_exprs(exprs);
    for found in &referenced {
        match found {
            Expr::GroupingSet(GroupingSet::Rollup(members))
            | Expr::GroupingSet(GroupingSet::Cube(members)) => {
                members.iter().try_for_each(|member| {
                    check_column_satisfies_expr(columns, member, message_prefix)
                })?;
            }
            Expr::GroupingSet(GroupingSet::GroupingSets(groups)) => {
                // Same check, applied to every member of every set in order.
                groups.iter().flatten().try_for_each(|member| {
                    check_column_satisfies_expr(columns, member, message_prefix)
                })?;
            }
            other => check_column_satisfies_expr(columns, other, message_prefix)?,
        }
    }
    Ok(())
}
/// Succeeds when `expr` is one of `columns`.
///
/// # Errors
/// Returns a plan error naming `message_prefix`, the offending expression,
/// and the available columns when `expr` is not found.
fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    message_prefix: &str,
) -> Result<()> {
    if columns.contains(expr) {
        return Ok(());
    }
    plan_err!(
        "{}: Expression {} could not be resolved from available columns: {}",
        message_prefix,
        expr,
        expr_vec_fmt!(columns)
    )
}
/// Collects the top-level aliases in `exprs` into a map from alias name to a
/// clone of the aliased expression.
///
/// Expressions that are not `Expr::Alias` are skipped. When the same alias
/// name occurs more than once, the later entry replaces the earlier one.
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    let mut aliases = HashMap::new();
    for candidate in exprs {
        if let Expr::Alias(Alias {
            expr: target, name, ..
        }) = candidate
        {
            aliases.insert(name.clone(), (**target).clone());
        }
    }
    aliases
}
/// Resolves a positional reference (such as `2` in `ORDER BY 2`) to the
/// matching entry of `select_exprs`.
///
/// * A non-null `Int64` literal is treated as a 1-based position. The
///   selected expression is cloned, with a top-level alias stripped.
/// * Any other expression is returned unchanged.
///
/// # Errors
/// Returns a plan error when the literal position is outside
/// `1..=select_exprs.len()`.
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    let position = match expr {
        Expr::Literal(ScalarValue::Int64(Some(position))) => position,
        other => return Ok(other),
    };

    let in_range = position > 0_i64 && position <= select_exprs.len() as i64;
    if !in_range {
        return plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position, select_exprs.len()
        );
    }

    // Positions are 1-based; the range check above makes the index valid.
    let selected = &select_exprs[(position - 1) as usize];
    let resolved = if let Expr::Alias(Alias { expr: aliased, .. }) = selected {
        (**aliased).clone()
    } else {
        selected.clone()
    };
    Ok(resolved)
}
/// Replaces unqualified column references in `expr` whose name is a key of
/// `aliases` with a clone of the mapped expression.
///
/// The tree is walked bottom-up. Columns that carry a relation qualifier,
/// columns with no matching alias, and all non-column nodes are left as
/// they are.
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    expr.transform_up(|node| {
        let replacement = match &node {
            Expr::Column(column) if column.relation.is_none() => {
                aliases.get(&column.name).cloned()
            }
            _ => None,
        };
        Ok(match replacement {
            Some(aliased) => Transformed::yes(aliased),
            None => Transformed::no(node),
        })
    })
    .data()
}
/// Returns the shortest `PARTITION BY` list among `window_exprs`.
///
/// Each input must be a window function, optionally wrapped in a single
/// alias. When several lists share the minimum length, the first one in
/// input order is returned.
///
/// # Errors
/// Returns an execution error when an input is not a (possibly aliased)
/// window function, or when `window_exprs` is empty.
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    /// Extracts the partition keys of one window expression, looking through
    /// one level of alias.
    fn partition_keys_of(expr: &Expr) -> Result<&Vec<Expr>> {
        let candidate: &Expr = match expr {
            Expr::Alias(Alias { expr: inner, .. }) => inner.as_ref(),
            other => other,
        };
        match candidate {
            Expr::WindowFunction(WindowFunction { partition_by, .. }) => Ok(partition_by),
            other => exec_err!("Impossibly got non-window expr {other:?}"),
        }
    }

    let mut per_expr_keys = Vec::with_capacity(window_exprs.len());
    for window_expr in window_exprs {
        per_expr_keys.push(partition_keys_of(window_expr)?);
    }

    match per_expr_keys.into_iter().min_by_key(|keys| keys.len()) {
        Some(keys) => Ok(keys.as_slice()),
        None => Err(DataFusionError::Execution(
            "No window expressions found".to_owned(),
        )),
    }
}
/// Builds the Arrow decimal [`DataType`] for a SQL `DECIMAL(precision, scale)`
/// declaration.
///
/// * `(Some(p), Some(s))` uses both values as given.
/// * `(Some(p), None)` uses a scale of `0`.
/// * `(None, None)` falls back to `DECIMAL128_MAX_PRECISION` and
///   `DECIMAL_DEFAULT_SCALE`.
///
/// Precisions up to `DECIMAL128_MAX_PRECISION` produce `Decimal128`; larger
/// precisions up to `DECIMAL256_MAX_PRECISION` produce `Decimal256`.
///
/// # Errors
/// Returns a plan error when only a scale is given, when the precision is `0`
/// or exceeds `DECIMAL256_MAX_PRECISION`, when the scale exceeds the
/// precision, or when either value does not fit the `u8` / `i8` used by the
/// Arrow decimal types.
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    let (precision, scale) = match (precision, scale) {
        (Some(p), s) => {
            let s = s.unwrap_or(0);
            // Narrow with checked conversions. A plain `as` cast wraps around,
            // which would turn e.g. precision 257 into 1 and accept it.
            match (u8::try_from(p), i8::try_from(s)) {
                (Ok(p), Ok(s)) => (p, s),
                _ => {
                    return plan_err!(
                        "Decimal(precision = {p}, scale = {s}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
                    )
                }
            }
        }
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type")
        }
        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
    };

    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION
        || scale.unsigned_abs() > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION {
        // The upper bound was already enforced by the branch above.
        Ok(DataType::Decimal256(precision, scale))
    } else {
        Ok(DataType::Decimal128(precision, scale))
    }
}
/// Normalizes a SQL identifier: a quoted identifier keeps its exact text,
/// an unquoted one is converted to ASCII lowercase.
pub(crate) fn normalize_ident(id: Ident) -> String {
    if id.quote_style.is_some() {
        id.value
    } else {
        id.value.to_ascii_lowercase()
    }
}
/// Renders a SQL literal [`Value`] as a `String` where a conversion is
/// defined.
///
/// Single-quoted and dollar-quoted strings yield their text, and numbers and
/// booleans yield their `Display` form. Every other variant yields `None`.
/// The variants are listed explicitly, with no wildcard arm, so the match
/// stays exhaustive.
pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    let rendered = match value {
        Value::DoubleQuotedString(_)
        | Value::EscapedStringLiteral(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => return None,
        Value::SingleQuotedString(text) => text.clone(),
        Value::DollarQuotedString(dollar_quoted) => dollar_quoted.to_string(),
        Value::Number(_, _) | Value::Boolean(_) => value.to_string(),
    };
    Some(rendered)
}
/// Applies `transform_bottom_unnest` to each of `original_exprs` in order and
/// concatenates the resulting expression lists.
///
/// `unnest_placeholder_columns` and `inner_projection_exprs` are passed to
/// every call and accumulate entries across expressions.
///
/// # Errors
/// Stops at the first expression whose transformation fails and returns that
/// error. Entries already pushed to the two accumulators are not rolled back.
pub(crate) fn transform_bottom_unnests(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut Vec<String>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    let mut flattened = Vec::new();
    for original_expr in original_exprs {
        let rewritten = transform_bottom_unnest(
            input,
            unnest_placeholder_columns,
            inner_projection_exprs,
            original_expr,
        )?;
        flattened.extend(rewritten);
    }
    Ok(flattened)
}
/// Rewrites one select expression so that an `unnest(...)` it contains is
/// replaced by a reference to a placeholder column.
///
/// Side effects on the accumulators:
/// * `unnest_placeholder_columns` receives the display name of each unnest
///   expression that was rewritten.
/// * `inner_projection_exprs` receives the unnest argument aliased to that
///   placeholder name or, when nothing was rewritten, the expression itself.
///
/// Returns the expression(s) that take the place of `original_expr`. A
/// root-level unnest may yield several columns, one per entry returned by
/// `get_unnested_columns`.
///
/// # Errors
/// Returns an internal error when an unnest over a struct-typed argument is
/// found below the root, and propagates errors from name and type resolution.
pub(crate) fn transform_bottom_unnest(
input: &LogicalPlan,
unnest_placeholder_columns: &mut Vec<String>,
inner_projection_exprs: &mut Vec<Expr>,
original_expr: &Expr,
) -> Result<Vec<Expr>> {
// Registers one unnest: records the placeholder name, pushes the aliased
// argument to the inner projection, and returns the columns that refer to
// the unnested output.
let mut transform =
|unnest_expr: &Expr, expr_in_unnest: &Expr| -> Result<Vec<Expr>> {
// The placeholder column is named after the unnest expression's display name.
let placeholder_name = unnest_expr.display_name()?;
unnest_placeholder_columns.push(placeholder_name.clone());
inner_projection_exprs
.push(expr_in_unnest.clone().alias(placeholder_name.clone()));
let schema = input.schema();
let (data_type, _) = expr_in_unnest.data_type_and_nullable(schema)?;
let outer_projection_columns =
get_unnested_columns(&placeholder_name, &data_type)?;
let expr = outer_projection_columns
.iter()
.map(|col| Expr::Column(col.0.clone()))
.collect::<Vec<_>>();
Ok(expr)
};
// Walk the expression bottom-up looking for an unnest below the root.
let Transformed {
data: transformed_expr,
transformed,
tnr: _,
} = original_expr.clone().transform_up(|expr: Expr| {
// A node equal to the original expression is left alone here; a root-level
// unnest is handled after the traversal.
let is_root_expr = &expr == original_expr;
if is_root_expr {
return Ok(Transformed::no(expr));
}
if let Expr::Unnest(Unnest { expr: ref arg }) = expr {
let (data_type, _) = arg.data_type_and_nullable(input.schema())?;
// Struct unnest is rejected below the root.
if let DataType::Struct(_) = data_type {
return internal_err!("unnest on struct can only be applied at the root level of select expression");
}
let mut transformed_exprs = transform(&expr, arg)?;
// Only the first returned column is used as the replacement.
// NOTE(review): `swap_remove(0)` panics on an empty vector; this assumes
// `get_unnested_columns` returns at least one column for non-struct types
// - confirm.
Ok(Transformed::new(
transformed_exprs.swap_remove(0),
true,
// Stop the traversal once an unnest has been rewritten.
TreeNodeRecursion::Stop,
))
} else {
Ok(Transformed::no(expr))
}
})?;
if !transformed {
// Nothing was rewritten during the traversal.
if let Expr::Unnest(Unnest { expr: ref arg }) = transformed_expr {
// Root-level unnest: return every column produced for it.
return transform(&transformed_expr, arg);
}
if matches!(&transformed_expr, Expr::Column(_)) {
// A plain column is projected as-is by both projections.
inner_projection_exprs.push(transformed_expr.clone());
Ok(vec![transformed_expr])
} else {
// Any other expression is pushed to the inner projection and referenced
// afterwards by its display name.
let column_name = transformed_expr.display_name()?;
inner_projection_exprs.push(transformed_expr);
Ok(vec![Expr::Column(Column::from_name(column_name))])
}
} else {
// An unnest was replaced by its placeholder column; return the rewritten
// expression.
Ok(vec![transformed_expr])
}
}
// Unit tests for `transform_bottom_unnest` and `resolve_positions_to_exprs`.
#[cfg(test)]
mod tests {
use std::{ops::Add, sync::Arc};
use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
use arrow_schema::Fields;
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{col, lit, unnest, EmptyRelation, LogicalPlan};
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_aggregate::expr_fn::count;
use crate::utils::{resolve_positions_to_exprs, transform_bottom_unnest};
// Covers three cases: a root-level unnest of a struct column, an unnest of
// a list column nested inside an addition, and a doubly nested unnest over
// a struct field.
#[test]
fn test_transform_bottom_unnest() -> Result<()> {
// Schema with a two-field struct column, a list column and an int column.
let schema = Schema::new(vec![
Field::new(
"struct_col",
ArrowDataType::Struct(Fields::from(vec![
Field::new("field1", ArrowDataType::Int32, false),
Field::new("field2", ArrowDataType::Int32, false),
])),
false,
),
Field::new(
"array_col",
ArrowDataType::List(Arc::new(Field::new(
"item",
ArrowDataType::Int64,
true,
))),
true,
),
Field::new("int_col", ArrowDataType::Int32, false),
]);
let dfschema = DFSchema::try_from(schema)?;
let input = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(dfschema),
});
// These accumulators are shared by the first two cases below.
let mut unnest_placeholder_columns = vec![];
let mut inner_projection_exprs = vec![];
// Case 1: unnest(struct_col) at the root expands into one column per struct field.
let original_expr = unnest(col("struct_col"));
let transformed_exprs = transform_bottom_unnest(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr,
)?;
assert_eq!(
transformed_exprs,
vec![
col("unnest(struct_col).field1"),
col("unnest(struct_col).field2"),
]
);
assert_eq!(unnest_placeholder_columns, vec!["unnest(struct_col)"]);
assert_eq!(
inner_projection_exprs,
vec![col("struct_col").alias("unnest(struct_col)"),]
);
// Case 2: unnest(array_col) + 1 - the nested unnest becomes a placeholder
// column and the addition is kept around it.
let original_expr = unnest(col("array_col")).add(lit(1i64));
let transformed_exprs = transform_bottom_unnest(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr,
)?;
// The accumulators still hold the entries from case 1.
assert_eq!(
unnest_placeholder_columns,
vec!["unnest(struct_col)", "unnest(array_col)"]
);
assert_eq!(
transformed_exprs,
vec![col("unnest(array_col)").add(lit(1i64))]
);
assert_eq!(
inner_projection_exprs,
vec![
col("struct_col").alias("unnest(struct_col)"),
col("array_col").alias("unnest(array_col)")
]
);
// Case 3 uses a new schema: a struct column whose `matrix` field is a list of lists.
let schema = Schema::new(vec![
Field::new(
"struct_col", ArrowDataType::Struct(Fields::from(vec![Field::new(
"matrix",
ArrowDataType::List(Arc::new(Field::new(
"matrix_row",
ArrowDataType::List(Arc::new(Field::new(
"item",
ArrowDataType::Int64,
true,
))),
true,
))),
true,
)])),
false,
),
Field::new("int_col", ArrowDataType::Int32, false),
]);
let dfschema = DFSchema::try_from(schema)?;
let input = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(dfschema),
});
// Fresh accumulators for the nested-unnest case.
let mut unnest_placeholder_columns = vec![];
let mut inner_projection_exprs = vec![];
// Only the innermost unnest is rewritten in a single call; the outer unnest
// stays and wraps the placeholder column.
let original_expr = unnest(unnest(col("struct_col").field("matrix")));
let transformed_exprs = transform_bottom_unnest(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr,
)?;
assert_eq!(
transformed_exprs,
vec![unnest(col("unnest(struct_col[matrix])"))]
);
assert_eq!(
unnest_placeholder_columns,
vec!["unnest(struct_col[matrix])"]
);
assert_eq!(
inner_projection_exprs,
vec![col("struct_col")
.field("matrix")
.alias("unnest(struct_col[matrix])"),]
);
Ok(())
}
// Checks positional resolution: a valid position, out-of-range positions
// (negative and too large), and expressions that are not Int64 literals.
#[test]
fn test_resolve_positions_to_exprs() -> Result<()> {
let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
// Position 1 maps to the first select expression.
let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
assert_eq!(resolved, col("c1"));
// A negative position is rejected.
let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
assert!(resolved.is_err_and(|e| e.message().contains(
"Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
)));
// A position past the end of the select list is rejected.
let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
assert!(resolved.is_err_and(|e| e.message().contains(
"Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
)));
// A string literal is returned unchanged.
let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
assert_eq!(resolved, lit("text"));
// A column reference is returned unchanged.
let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
assert_eq!(resolved, col("fake"));
Ok(())
}
}