use crate::expr::{
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
Placeholder, TryCast, Unnest, WindowFunction,
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr,
LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature,
Volatility,
};
use crate::{
AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
};
use arrow::compute::kernels::cast_utils::{
parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_err, Column, Result, ScalarValue};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::Debug;
use std::ops::Not;
use std::sync::Arc;
/// Builds a column-reference expression from anything convertible into a [`Column`].
pub fn col(ident: impl Into<Column>) -> Expr {
    let column: Column = ident.into();
    Expr::Column(column)
}

/// Builds an outer-reference column expression tagged with the data type `dt`.
pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
    let column: Column = ident.into();
    Expr::OuterReferenceColumn(dt, column)
}

/// Builds a column expression from a bare name via `Column::from_name`.
///
/// NOTE(review): unlike `col`, this goes through `from_name` rather than
/// `Into<Column>`, presumably so the name is not split into qualifier parts;
/// confirm against `Column::from_name`.
pub fn ident(name: impl Into<String>) -> Expr {
    let column = Column::from_name(name);
    Expr::Column(column)
}

/// Builds a placeholder (parameter) expression with identifier `id` and no
/// data type attached yet.
pub fn placeholder(id: impl Into<String>) -> Expr {
    let param = Placeholder {
        id: id.into(),
        data_type: None,
    };
    Expr::Placeholder(param)
}

/// Builds an unqualified wildcard expression.
pub fn wildcard() -> Expr {
    Expr::Wildcard { qualifier: None }
}
/// Builds the binary expression `left op right`.
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    Expr::BinaryExpr(BinaryExpr::new(lhs, op, rhs))
}

/// Builds `left AND right`.
pub fn and(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::And, right)
}

/// Builds `left OR right`.
pub fn or(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::Or, right)
}

/// Negates `expr` using the `Not` implementation on [`Expr`].
pub fn not(expr: Expr) -> Expr {
    Not::not(expr)
}

/// Builds `left & right` (bitwise AND).
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::BitwiseAnd, right)
}

/// Builds `left | right` (bitwise OR).
pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::BitwiseOr, right)
}

/// Builds `left ^ right` (bitwise XOR).
pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::BitwiseXor, right)
}

/// Builds `left >> right` (bitwise shift right).
pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::BitwiseShiftRight, right)
}

/// Builds `left << right` (bitwise shift left).
pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
    binary_expr(left, Operator::BitwiseShiftLeft, right)
}
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
Expr::InList(InList::new(Box::new(expr), list, negated))
}
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
let outer_ref_columns = subquery.all_out_ref_exprs();
Expr::Exists(Exists {
subquery: Subquery {
subquery,
outer_ref_columns,
},
negated: false,
})
}
pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
let outer_ref_columns = subquery.all_out_ref_exprs();
Expr::Exists(Exists {
subquery: Subquery {
subquery,
outer_ref_columns,
},
negated: true,
})
}
pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
let outer_ref_columns = subquery.all_out_ref_exprs();
Expr::InSubquery(InSubquery::new(
Box::new(expr),
Subquery {
subquery,
outer_ref_columns,
},
false,
))
}
pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
let outer_ref_columns = subquery.all_out_ref_exprs();
Expr::InSubquery(InSubquery::new(
Box::new(expr),
Subquery {
subquery,
outer_ref_columns,
},
true,
))
}
pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
let outer_ref_columns = subquery.all_out_ref_exprs();
Expr::ScalarSubquery(Subquery {
subquery,
outer_ref_columns,
})
}
/// Builds a `GROUPING SETS (...)` expression, one inner vector per set.
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
    let sets = GroupingSet::GroupingSets(exprs);
    Expr::GroupingSet(sets)
}

/// Builds a `CUBE (...)` grouping expression.
pub fn cube(exprs: Vec<Expr>) -> Expr {
    let sets = GroupingSet::Cube(exprs);
    Expr::GroupingSet(sets)
}

/// Builds a `ROLLUP (...)` grouping expression.
pub fn rollup(exprs: Vec<Expr>) -> Expr {
    let sets = GroupingSet::Rollup(exprs);
    Expr::GroupingSet(sets)
}
/// Builds `CAST(expr AS data_type)`.
pub fn cast(expr: Expr, data_type: DataType) -> Expr {
    let operand = Box::new(expr);
    Expr::Cast(Cast::new(operand, data_type))
}

/// Builds `TRY_CAST(expr AS data_type)`.
pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
    let operand = Box::new(expr);
    Expr::TryCast(TryCast::new(operand, data_type))
}
/// Builds `expr IS NULL`.
pub fn is_null(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsNull(operand)
}

/// Builds `expr IS TRUE`.
pub fn is_true(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsTrue(operand)
}

/// Builds `expr IS NOT TRUE`.
pub fn is_not_true(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsNotTrue(operand)
}

/// Builds `expr IS FALSE`.
pub fn is_false(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsFalse(operand)
}

/// Builds `expr IS NOT FALSE`.
pub fn is_not_false(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsNotFalse(operand)
}

/// Builds `expr IS UNKNOWN`.
pub fn is_unknown(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsUnknown(operand)
}

/// Builds `expr IS NOT UNKNOWN`.
pub fn is_not_unknown(expr: Expr) -> Expr {
    let operand: Box<Expr> = expr.into();
    Expr::IsNotUnknown(operand)
}
/// Starts a `CASE expr WHEN ...` builder that compares against `expr`.
pub fn case(expr: Expr) -> CaseBuilder {
    let base = Some(Box::new(expr));
    CaseBuilder::new(base, Vec::new(), Vec::new(), None)
}

/// Starts a searched `CASE WHEN when THEN then ...` builder (no base expression).
pub fn when(when: Expr, then: Expr) -> CaseBuilder {
    let (whens, thens) = (vec![when], vec![then]);
    CaseBuilder::new(None, whens, thens, None)
}

/// Builds an `UNNEST(expr)` expression.
pub fn unnest(expr: Expr) -> Expr {
    let inner = Box::new(expr);
    Expr::Unnest(Unnest { expr: inner })
}
/// Creates a [`ScalarUDF`] backed by a [`SimpleScalarUDF`].
///
/// The signature is exactly `input_types`, and every call reports `return_type`
/// as its result type. If `return_type` holds the only strong reference to its
/// `DataType`, the value is moved out of the `Arc`; otherwise it is cloned.
pub fn create_udf(
    name: &str,
    input_types: Vec<DataType>,
    return_type: Arc<DataType>,
    volatility: Volatility,
    fun: ScalarFunctionImplementation,
) -> ScalarUDF {
    let owned_return_type = match Arc::try_unwrap(return_type) {
        Ok(data_type) => data_type,
        Err(shared) => (*shared).clone(),
    };
    let udf =
        SimpleScalarUDF::new(name, input_types, owned_return_type, volatility, fun);
    ScalarUDF::from(udf)
}
/// A [`ScalarUDFImpl`] with an exact signature and a fixed return type that
/// forwards evaluation to a user-supplied closure.
pub struct SimpleScalarUDF {
    /// Function name reported by [`ScalarUDFImpl::name`].
    name: String,
    /// Exact signature built from the input types and volatility.
    signature: Signature,
    /// Result type returned regardless of the argument types.
    return_type: DataType,
    /// Closure invoked for every evaluation.
    fun: ScalarFunctionImplementation,
}

impl Debug for SimpleScalarUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The closure has no Debug representation, so a placeholder is shown.
        let mut out = f.debug_struct("ScalarUDF");
        out.field("name", &self.name);
        out.field("signature", &self.signature);
        out.field("fun", &"<FUNC>");
        out.finish()
    }
}

impl SimpleScalarUDF {
    /// Creates a UDF named `name` with the exact signature `input_types` and
    /// the given `volatility`, returning `return_type` and evaluated by `fun`.
    pub fn new(
        name: impl Into<String>,
        input_types: Vec<DataType>,
        return_type: DataType,
        volatility: Volatility,
        fun: ScalarFunctionImplementation,
    ) -> Self {
        Self {
            name: name.into(),
            signature: Signature::exact(input_types, volatility),
            return_type,
            fun,
        }
    }
}

impl ScalarUDFImpl for SimpleScalarUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports the fixed return type; argument types are not consulted.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let fixed = self.return_type.clone();
        Ok(fixed)
    }

    /// Delegates directly to the stored closure.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        (self.fun)(args)
    }
}
pub fn create_udaf(
name: &str,
input_type: Vec<DataType>,
return_type: Arc<DataType>,
volatility: Volatility,
accumulator: AccumulatorFactoryFunction,
state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone());
let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone());
let state_fields = state_type
.into_iter()
.enumerate()
.map(|(i, t)| Field::new(format!("{i}"), t, true))
.collect::<Vec<_>>();
AggregateUDF::from(SimpleAggregateUDF::new(
name,
input_type,
return_type,
volatility,
accumulator,
state_fields,
))
}
/// An [`AggregateUDFImpl`] with a fixed return type and fixed state fields,
/// whose accumulators come from a user-supplied factory.
pub struct SimpleAggregateUDF {
    /// Function name reported by [`AggregateUDFImpl::name`].
    name: String,
    /// Signature used for argument checking.
    signature: Signature,
    /// Result type returned regardless of the argument types.
    return_type: DataType,
    /// Factory invoked to create each accumulator.
    accumulator: AccumulatorFactoryFunction,
    /// Intermediate state fields, returned as-is from `state_fields`.
    state_fields: Vec<Field>,
}

impl Debug for SimpleAggregateUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The accumulator factory has no Debug representation; show a placeholder.
        let mut out = f.debug_struct("AggregateUDF");
        out.field("name", &self.name);
        out.field("signature", &self.signature);
        out.field("fun", &"<FUNC>");
        out.finish()
    }
}

impl SimpleAggregateUDF {
    /// Creates an aggregate UDF whose signature is exactly `input_type`
    /// with the given `volatility`.
    pub fn new(
        name: impl Into<String>,
        input_type: Vec<DataType>,
        return_type: DataType,
        volatility: Volatility,
        accumulator: AccumulatorFactoryFunction,
        state_fields: Vec<Field>,
    ) -> Self {
        Self::new_with_signature(
            name,
            Signature::exact(input_type, volatility),
            return_type,
            accumulator,
            state_fields,
        )
    }

    /// Creates an aggregate UDF with a caller-provided `signature`.
    pub fn new_with_signature(
        name: impl Into<String>,
        signature: Signature,
        return_type: DataType,
        accumulator: AccumulatorFactoryFunction,
        state_fields: Vec<Field>,
    ) -> Self {
        Self {
            name: name.into(),
            signature,
            return_type,
            accumulator,
            state_fields,
        }
    }
}

impl AggregateUDFImpl for SimpleAggregateUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports the fixed return type; argument types are not consulted.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let fixed = self.return_type.clone();
        Ok(fixed)
    }

    /// Builds a fresh accumulator through the stored factory.
    fn accumulator(
        &self,
        acc_args: AccumulatorArgs,
    ) -> Result<Box<dyn crate::Accumulator>> {
        (self.accumulator)(acc_args)
    }

    /// Returns a copy of the fixed state fields; the arguments are ignored.
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
        let fields = self.state_fields.clone();
        Ok(fields)
    }
}
/// Creates a [`WindowUDF`] backed by a [`SimpleWindowUDF`].
///
/// The signature accepts exactly one argument of type `input_type`. If
/// `return_type` holds the only strong reference, its value is moved out of
/// the `Arc`; otherwise it is cloned.
pub fn create_udwf(
    name: &str,
    input_type: DataType,
    return_type: Arc<DataType>,
    volatility: Volatility,
    partition_evaluator_factory: PartitionEvaluatorFactory,
) -> WindowUDF {
    let owned_return_type = match Arc::try_unwrap(return_type) {
        Ok(data_type) => data_type,
        Err(shared) => (*shared).clone(),
    };
    let udwf = SimpleWindowUDF::new(
        name,
        input_type,
        owned_return_type,
        volatility,
        partition_evaluator_factory,
    );
    WindowUDF::from(udwf)
}
/// A [`WindowUDFImpl`] with a single-argument exact signature and a fixed
/// return type, whose partition evaluators come from a user-supplied factory.
pub struct SimpleWindowUDF {
    /// Function name reported by [`WindowUDFImpl::name`].
    name: String,
    /// Exact one-argument signature.
    signature: Signature,
    /// Result type returned regardless of the argument types.
    return_type: DataType,
    /// Factory invoked to create each partition evaluator.
    partition_evaluator_factory: PartitionEvaluatorFactory,
}

impl Debug for SimpleWindowUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("WindowUDF")
            .field("name", &self.name)
            .field("signature", &self.signature)
            // `return_type` is a plain `DataType`; print its real value
            // instead of the old "<func>" placeholder, which hid it.
            .field("return_type", &self.return_type)
            // The factory closure has no Debug representation.
            .field("partition_evaluator_factory", &"<FUNC>")
            .finish()
    }
}
impl SimpleWindowUDF {
    /// Creates a window UDF named `name` that accepts exactly one argument of
    /// type `input_type`, reports `return_type`, and builds evaluators with
    /// `partition_evaluator_factory`.
    pub fn new(
        name: impl Into<String>,
        input_type: DataType,
        return_type: DataType,
        volatility: Volatility,
        partition_evaluator_factory: PartitionEvaluatorFactory,
    ) -> Self {
        Self {
            name: name.into(),
            signature: Signature::exact(vec![input_type], volatility),
            return_type,
            partition_evaluator_factory,
        }
    }
}

impl WindowUDFImpl for SimpleWindowUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports the fixed return type; argument types are not consulted.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let fixed = self.return_type.clone();
        Ok(fixed)
    }

    /// Builds a fresh partition evaluator through the stored factory.
    fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
        (self.partition_evaluator_factory)()
    }
}
/// Builds a year-month interval literal by parsing `value`.
///
/// A parse failure is not reported: the literal's payload becomes `None`
/// (presumably a NULL interval).
pub fn interval_year_month_lit(value: &str) -> Expr {
    let parsed = match parse_interval_year_month(value) {
        Ok(interval) => Some(interval),
        Err(_) => None,
    };
    Expr::Literal(ScalarValue::IntervalYearMonth(parsed))
}

/// Builds a day-time interval literal by parsing `value`.
///
/// A parse failure is not reported: the literal's payload becomes `None`
/// (presumably a NULL interval).
pub fn interval_datetime_lit(value: &str) -> Expr {
    let parsed = match parse_interval_day_time(value) {
        Ok(interval) => Some(interval),
        Err(_) => None,
    };
    Expr::Literal(ScalarValue::IntervalDayTime(parsed))
}

/// Builds a month-day-nanosecond interval literal by parsing `value`.
///
/// A parse failure is not reported: the literal's payload becomes `None`
/// (presumably a NULL interval).
pub fn interval_month_day_nano_lit(value: &str) -> Expr {
    let parsed = match parse_interval_month_day_nano(value) {
        Ok(interval) => Some(interval),
        Err(_) => None,
    };
    Expr::Literal(ScalarValue::IntervalMonthDayNano(parsed))
}
/// Fluent configuration for aggregate and window function expressions.
///
/// Every method returns an [`ExprFuncBuilder`]; call
/// [`ExprFuncBuilder::build`] to obtain the final [`Expr`].
pub trait ExprFunctionExt {
    /// Sets the `ORDER BY` expressions (each must be an `Expr::Sort`).
    fn order_by(self, order_by: Vec<Expr>) -> ExprFuncBuilder;
    /// Sets the `FILTER (WHERE ...)` predicate.
    fn filter(self, filter: Expr) -> ExprFuncBuilder;
    /// Marks the function as `DISTINCT`.
    fn distinct(self) -> ExprFuncBuilder;
    /// Sets (or clears, with `None`) the null-treatment clause.
    fn null_treatment(self, null_treatment: impl Into<Option<NullTreatment>>) -> ExprFuncBuilder;
    /// Sets the `PARTITION BY` expressions.
    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
    /// Sets an explicit window frame.
    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
}

/// The function expression an [`ExprFuncBuilder`] is configuring.
#[derive(Clone, Debug)]
pub enum ExprFuncKind {
    Aggregate(AggregateFunction),
    Window(WindowFunction),
}

/// Accumulates clause settings for an aggregate or window function expression.
#[derive(Clone, Debug)]
pub struct ExprFuncBuilder {
    /// The wrapped function; `None` if the source expression was neither kind.
    fun: Option<ExprFuncKind>,
    order_by: Option<Vec<Expr>>,
    filter: Option<Expr>,
    distinct: bool,
    null_treatment: Option<NullTreatment>,
    partition_by: Option<Vec<Expr>>,
    window_frame: Option<WindowFrame>,
}
impl ExprFuncBuilder {
    /// Creates a builder around `fun` with every optional clause unset.
    ///
    /// `fun` is `None` when the originating expression was neither an
    /// aggregate nor a window function; `build` then reports an error.
    fn new(fun: Option<ExprFuncKind>) -> Self {
        Self {
            fun,
            order_by: None,
            filter: None,
            distinct: false,
            null_treatment: None,
            partition_by: None,
            window_frame: None,
        }
    }

    /// Consumes the builder and produces the configured function expression.
    ///
    /// # Errors
    /// Returns a plan error when the builder does not wrap an aggregate or
    /// window function, or when any `ORDER BY` entry is not an `Expr::Sort`.
    ///
    /// NOTE: the builder's values are assigned unconditionally, so settings
    /// already present on the wrapped function are replaced by the builder's
    /// values, including its defaults. Options that do not apply to the
    /// wrapped kind are ignored: `filter`/`distinct` for window functions and
    /// `partition_by`/`window_frame` for aggregates.
    pub fn build(self) -> Result<Expr> {
        let Self {
            fun,
            order_by,
            filter,
            distinct,
            null_treatment,
            partition_by,
            window_frame,
        } = self;

        let Some(fun) = fun else {
            return plan_err!(
                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
            );
        };

        // Reject the first ORDER BY entry that is not a sort expression.
        if let Some(expr) = order_by
            .iter()
            .flatten()
            .find(|candidate| !matches!(candidate, Expr::Sort(_)))
        {
            return plan_err!(
                "ORDER BY expressions must be Expr::Sort, found {expr:?}"
            );
        }

        Ok(match fun {
            ExprFuncKind::Aggregate(mut udaf) => {
                udaf.order_by = order_by;
                udaf.filter = filter.map(Box::new);
                udaf.distinct = distinct;
                udaf.null_treatment = null_treatment;
                Expr::AggregateFunction(udaf)
            }
            ExprFuncKind::Window(mut udwf) => {
                // `None` = no ORDER BY supplied; `Some(false)` = an empty one.
                let has_order_by = order_by.as_ref().map(|exprs| !exprs.is_empty());
                udwf.order_by = order_by.unwrap_or_default();
                udwf.partition_by = partition_by.unwrap_or_default();
                udwf.window_frame =
                    window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
                udwf.null_treatment = null_treatment;
                Expr::WindowFunction(udwf)
            }
        })
    }
}
impl ExprFunctionExt for ExprFuncBuilder {
fn order_by(mut self, order_by: Vec<Expr>) -> ExprFuncBuilder {
self.order_by = Some(order_by);
self
}
fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
self.filter = Some(filter);
self
}
fn distinct(mut self) -> ExprFuncBuilder {
self.distinct = true;
self
}
fn null_treatment(
mut self,
null_treatment: impl Into<Option<NullTreatment>>,
) -> ExprFuncBuilder {
self.null_treatment = null_treatment.into();
self
}
fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
self.partition_by = Some(partition_by);
self
}
fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
self.window_frame = Some(window_frame);
self
}
}
impl ExprFunctionExt for Expr {
fn order_by(self, order_by: Vec<Expr>) -> ExprFuncBuilder {
let mut builder = match self {
Expr::AggregateFunction(udaf) => {
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
}
Expr::WindowFunction(udwf) => {
ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
}
_ => ExprFuncBuilder::new(None),
};
if builder.fun.is_some() {
builder.order_by = Some(order_by);
}
builder
}
fn filter(self, filter: Expr) -> ExprFuncBuilder {
match self {
Expr::AggregateFunction(udaf) => {
let mut builder =
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
builder.filter = Some(filter);
builder
}
_ => ExprFuncBuilder::new(None),
}
}
fn distinct(self) -> ExprFuncBuilder {
match self {
Expr::AggregateFunction(udaf) => {
let mut builder =
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
builder.distinct = true;
builder
}
_ => ExprFuncBuilder::new(None),
}
}
fn null_treatment(
self,
null_treatment: impl Into<Option<NullTreatment>>,
) -> ExprFuncBuilder {
let mut builder = match self {
Expr::AggregateFunction(udaf) => {
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
}
Expr::WindowFunction(udwf) => {
ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
}
_ => ExprFuncBuilder::new(None),
};
if builder.fun.is_some() {
builder.null_treatment = null_treatment.into();
}
builder
}
fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
match self {
Expr::WindowFunction(udwf) => {
let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
builder.partition_by = Some(partition_by);
builder
}
_ => ExprFuncBuilder::new(None),
}
}
fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
match self {
Expr::WindowFunction(udwf) => {
let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
builder.window_frame = Some(window_frame);
builder
}
_ => ExprFuncBuilder::new(None),
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// `IS NULL` / `IS NOT NULL` render with the column name on the left.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();

        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}