datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::{
29    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
30    AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
31    ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
32};
33use crate::{
34    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
35};
36use arrow::compute::kernels::cast_utils::{
37    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
38};
39use arrow::datatypes::{DataType, Field};
40use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
41use datafusion_functions_window_common::field::WindowUDFFieldArgs;
42use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
43use sqlparser::ast::NullTreatment;
44use std::any::Any;
45use std::fmt::Debug;
46use std::ops::Not;
47use std::sync::Arc;
48
49/// Create a column expression based on a qualified or unqualified column name. Will
50/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
51///
52/// For example:
53///
54/// ```rust
55/// # use datafusion_expr::col;
56/// let c1 = col("a");
57/// let c2 = col("A");
58/// assert_eq!(c1, c2);
59///
60/// // note how quoting with double quotes preserves the case
61/// let c3 = col(r#""A""#);
62/// assert_ne!(c1, c3);
63/// ```
64pub fn col(ident: impl Into<Column>) -> Expr {
65    Expr::Column(ident.into())
66}
67
68/// Create an out reference column which hold a reference that has been resolved to a field
69/// outside of the current plan.
70pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
71    Expr::OuterReferenceColumn(dt, ident.into())
72}
73
74/// Create an unqualified column expression from the provided name, without normalizing
75/// the column.
76///
77/// For example:
78///
79/// ```rust
80/// # use datafusion_expr::{col, ident};
81/// let c1 = ident("A"); // not normalized staying as column 'A'
82/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
83/// assert_ne!(c1, c2);
84///
85/// let c3 = col(r#""A""#);
86/// assert_eq!(c1, c3);
87///
88/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
89/// let c5 = ident("t1.a"); // parses as column 't1.a'
90/// assert_ne!(c4, c5);
91/// ```
92pub fn ident(name: impl Into<String>) -> Expr {
93    Expr::Column(Column::from_name(name))
94}
95
96/// Create placeholder value that will be filled in (such as `$1`)
97///
98/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
99///
100/// # Example
101///
102/// ```rust
103/// # use datafusion_expr::{placeholder};
104/// let p = placeholder("$0"); // $0, refers to parameter 1
105/// assert_eq!(p.to_string(), "$0")
106/// ```
107pub fn placeholder(id: impl Into<String>) -> Expr {
108    Expr::Placeholder(Placeholder {
109        id: id.into(),
110        data_type: None,
111    })
112}
113
114/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
115///
116/// # Example
117///
118/// ```rust
119/// # use datafusion_expr::{wildcard};
120/// let p = wildcard();
121/// assert_eq!(p.to_string(), "*")
122/// ```
123pub fn wildcard() -> Expr {
124    #[expect(deprecated)]
125    Expr::Wildcard {
126        qualifier: None,
127        options: Box::new(WildcardOptions::default()),
128    }
129}
130
131/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
132pub fn wildcard_with_options(options: WildcardOptions) -> Expr {
133    #[expect(deprecated)]
134    Expr::Wildcard {
135        qualifier: None,
136        options: Box::new(options),
137    }
138}
139
140/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
141///
142/// # Example
143///
144/// ```rust
145/// # use datafusion_common::TableReference;
146/// # use datafusion_expr::{qualified_wildcard};
147/// let p = qualified_wildcard(TableReference::bare("t"));
148/// assert_eq!(p.to_string(), "t.*")
149/// ```
150pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> Expr {
151    #[expect(deprecated)]
152    Expr::Wildcard {
153        qualifier: Some(qualifier.into()),
154        options: Box::new(WildcardOptions::default()),
155    }
156}
157
158/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
159pub fn qualified_wildcard_with_options(
160    qualifier: impl Into<TableReference>,
161    options: WildcardOptions,
162) -> Expr {
163    #[expect(deprecated)]
164    Expr::Wildcard {
165        qualifier: Some(qualifier.into()),
166        options: Box::new(options),
167    }
168}
169
170/// Return a new expression `left <op> right`
171pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
172    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
173}
174
175/// Return a new expression with a logical AND
176pub fn and(left: Expr, right: Expr) -> Expr {
177    Expr::BinaryExpr(BinaryExpr::new(
178        Box::new(left),
179        Operator::And,
180        Box::new(right),
181    ))
182}
183
184/// Return a new expression with a logical OR
185pub fn or(left: Expr, right: Expr) -> Expr {
186    Expr::BinaryExpr(BinaryExpr::new(
187        Box::new(left),
188        Operator::Or,
189        Box::new(right),
190    ))
191}
192
193/// Return a new expression with a logical NOT
194pub fn not(expr: Expr) -> Expr {
195    expr.not()
196}
197
198/// Return a new expression with bitwise AND
199pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
200    Expr::BinaryExpr(BinaryExpr::new(
201        Box::new(left),
202        Operator::BitwiseAnd,
203        Box::new(right),
204    ))
205}
206
207/// Return a new expression with bitwise OR
208pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
209    Expr::BinaryExpr(BinaryExpr::new(
210        Box::new(left),
211        Operator::BitwiseOr,
212        Box::new(right),
213    ))
214}
215
216/// Return a new expression with bitwise XOR
217pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
218    Expr::BinaryExpr(BinaryExpr::new(
219        Box::new(left),
220        Operator::BitwiseXor,
221        Box::new(right),
222    ))
223}
224
225/// Return a new expression with bitwise SHIFT RIGHT
226pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
227    Expr::BinaryExpr(BinaryExpr::new(
228        Box::new(left),
229        Operator::BitwiseShiftRight,
230        Box::new(right),
231    ))
232}
233
234/// Return a new expression with bitwise SHIFT LEFT
235pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
236    Expr::BinaryExpr(BinaryExpr::new(
237        Box::new(left),
238        Operator::BitwiseShiftLeft,
239        Box::new(right),
240    ))
241}
242
243/// Create an in_list expression
244pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
245    Expr::InList(InList::new(Box::new(expr), list, negated))
246}
247
248/// Create an EXISTS subquery expression
249pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
250    let outer_ref_columns = subquery.all_out_ref_exprs();
251    Expr::Exists(Exists {
252        subquery: Subquery {
253            subquery,
254            outer_ref_columns,
255        },
256        negated: false,
257    })
258}
259
260/// Create a NOT EXISTS subquery expression
261pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
262    let outer_ref_columns = subquery.all_out_ref_exprs();
263    Expr::Exists(Exists {
264        subquery: Subquery {
265            subquery,
266            outer_ref_columns,
267        },
268        negated: true,
269    })
270}
271
272/// Create an IN subquery expression
273pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
274    let outer_ref_columns = subquery.all_out_ref_exprs();
275    Expr::InSubquery(InSubquery::new(
276        Box::new(expr),
277        Subquery {
278            subquery,
279            outer_ref_columns,
280        },
281        false,
282    ))
283}
284
285/// Create a NOT IN subquery expression
286pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
287    let outer_ref_columns = subquery.all_out_ref_exprs();
288    Expr::InSubquery(InSubquery::new(
289        Box::new(expr),
290        Subquery {
291            subquery,
292            outer_ref_columns,
293        },
294        true,
295    ))
296}
297
298/// Create a scalar subquery expression
299pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
300    let outer_ref_columns = subquery.all_out_ref_exprs();
301    Expr::ScalarSubquery(Subquery {
302        subquery,
303        outer_ref_columns,
304    })
305}
306
307/// Create a grouping set
308pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
309    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
310}
311
312/// Create a grouping set for all combination of `exprs`
313pub fn cube(exprs: Vec<Expr>) -> Expr {
314    Expr::GroupingSet(GroupingSet::Cube(exprs))
315}
316
317/// Create a grouping set for rollup
318pub fn rollup(exprs: Vec<Expr>) -> Expr {
319    Expr::GroupingSet(GroupingSet::Rollup(exprs))
320}
321
322/// Create a cast expression
323pub fn cast(expr: Expr, data_type: DataType) -> Expr {
324    Expr::Cast(Cast::new(Box::new(expr), data_type))
325}
326
327/// Create a try cast expression
328pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
329    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
330}
331
332/// Create is null expression
333pub fn is_null(expr: Expr) -> Expr {
334    Expr::IsNull(Box::new(expr))
335}
336
337/// Create is true expression
338pub fn is_true(expr: Expr) -> Expr {
339    Expr::IsTrue(Box::new(expr))
340}
341
342/// Create is not true expression
343pub fn is_not_true(expr: Expr) -> Expr {
344    Expr::IsNotTrue(Box::new(expr))
345}
346
347/// Create is false expression
348pub fn is_false(expr: Expr) -> Expr {
349    Expr::IsFalse(Box::new(expr))
350}
351
352/// Create is not false expression
353pub fn is_not_false(expr: Expr) -> Expr {
354    Expr::IsNotFalse(Box::new(expr))
355}
356
357/// Create is unknown expression
358pub fn is_unknown(expr: Expr) -> Expr {
359    Expr::IsUnknown(Box::new(expr))
360}
361
362/// Create is not unknown expression
363pub fn is_not_unknown(expr: Expr) -> Expr {
364    Expr::IsNotUnknown(Box::new(expr))
365}
366
367/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
368pub fn case(expr: Expr) -> CaseBuilder {
369    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
370}
371
372/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
373pub fn when(when: Expr, then: Expr) -> CaseBuilder {
374    CaseBuilder::new(None, vec![when], vec![then], None)
375}
376
377/// Create a Unnest expression
378pub fn unnest(expr: Expr) -> Expr {
379    Expr::Unnest(Unnest {
380        expr: Box::new(expr),
381    })
382}
383
384/// Convenience method to create a new user defined scalar function (UDF) with a
385/// specific signature and specific return type.
386///
387/// Note this function does not expose all available features of [`ScalarUDF`],
388/// such as
389///
390/// * computing return types based on input types
391/// * multiple [`Signature`]s
392/// * aliases
393///
394/// See [`ScalarUDF`] for details and examples on how to use the full
395/// functionality.
396pub fn create_udf(
397    name: &str,
398    input_types: Vec<DataType>,
399    return_type: DataType,
400    volatility: Volatility,
401    fun: ScalarFunctionImplementation,
402) -> ScalarUDF {
403    ScalarUDF::from(SimpleScalarUDF::new(
404        name,
405        input_types,
406        return_type,
407        volatility,
408        fun,
409    ))
410}
411
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
/// return type.
pub struct SimpleScalarUDF {
    /// Function name, as reported by `name()`.
    name: String,
    /// The single signature reported by `signature()`.
    signature: Signature,
    /// Fixed return type, reported regardless of the argument types.
    return_type: DataType,
    /// Implementation called with the invocation's arguments.
    fun: ScalarFunctionImplementation,
}
420
421impl Debug for SimpleScalarUDF {
422    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
423        f.debug_struct("SimpleScalarUDF")
424            .field("name", &self.name)
425            .field("signature", &self.signature)
426            .field("return_type", &self.return_type)
427            .field("fun", &"<FUNC>")
428            .finish()
429    }
430}
431
432impl SimpleScalarUDF {
433    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
434    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
435    pub fn new(
436        name: impl Into<String>,
437        input_types: Vec<DataType>,
438        return_type: DataType,
439        volatility: Volatility,
440        fun: ScalarFunctionImplementation,
441    ) -> Self {
442        Self::new_with_signature(
443            name,
444            Signature::exact(input_types, volatility),
445            return_type,
446            fun,
447        )
448    }
449
450    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
451    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
452    pub fn new_with_signature(
453        name: impl Into<String>,
454        signature: Signature,
455        return_type: DataType,
456        fun: ScalarFunctionImplementation,
457    ) -> Self {
458        Self {
459            name: name.into(),
460            signature,
461            return_type,
462            fun,
463        }
464    }
465}
466
467impl ScalarUDFImpl for SimpleScalarUDF {
468    fn as_any(&self) -> &dyn Any {
469        self
470    }
471
472    fn name(&self) -> &str {
473        &self.name
474    }
475
476    fn signature(&self) -> &Signature {
477        &self.signature
478    }
479
480    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
481        Ok(self.return_type.clone())
482    }
483
484    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
485        (self.fun)(&args.args)
486    }
487}
488
489/// Creates a new UDAF with a specific signature, state type and return type.
490/// The signature and state type must match the `Accumulator's implementation`.
491pub fn create_udaf(
492    name: &str,
493    input_type: Vec<DataType>,
494    return_type: Arc<DataType>,
495    volatility: Volatility,
496    accumulator: AccumulatorFactoryFunction,
497    state_type: Arc<Vec<DataType>>,
498) -> AggregateUDF {
499    let return_type = Arc::unwrap_or_clone(return_type);
500    let state_type = Arc::unwrap_or_clone(state_type);
501    let state_fields = state_type
502        .into_iter()
503        .enumerate()
504        .map(|(i, t)| Field::new(format!("{i}"), t, true))
505        .collect::<Vec<_>>();
506    AggregateUDF::from(SimpleAggregateUDF::new(
507        name,
508        input_type,
509        return_type,
510        volatility,
511        accumulator,
512        state_fields,
513    ))
514}
515
/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
/// return type.
pub struct SimpleAggregateUDF {
    /// Function name, as reported by `name()`.
    name: String,
    /// The single signature reported by `signature()`.
    signature: Signature,
    /// Fixed return type, reported regardless of the argument types.
    return_type: DataType,
    /// Factory called with the accumulator arguments to create an accumulator.
    accumulator: AccumulatorFactoryFunction,
    /// State fields reported by `state_fields()`.
    state_fields: Vec<Field>,
}
525
526impl Debug for SimpleAggregateUDF {
527    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
528        f.debug_struct("SimpleAggregateUDF")
529            .field("name", &self.name)
530            .field("signature", &self.signature)
531            .field("return_type", &self.return_type)
532            .field("fun", &"<FUNC>")
533            .finish()
534    }
535}
536
537impl SimpleAggregateUDF {
538    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
539    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
540    pub fn new(
541        name: impl Into<String>,
542        input_type: Vec<DataType>,
543        return_type: DataType,
544        volatility: Volatility,
545        accumulator: AccumulatorFactoryFunction,
546        state_fields: Vec<Field>,
547    ) -> Self {
548        let name = name.into();
549        let signature = Signature::exact(input_type, volatility);
550        Self {
551            name,
552            signature,
553            return_type,
554            accumulator,
555            state_fields,
556        }
557    }
558
559    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
560    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
561    pub fn new_with_signature(
562        name: impl Into<String>,
563        signature: Signature,
564        return_type: DataType,
565        accumulator: AccumulatorFactoryFunction,
566        state_fields: Vec<Field>,
567    ) -> Self {
568        let name = name.into();
569        Self {
570            name,
571            signature,
572            return_type,
573            accumulator,
574            state_fields,
575        }
576    }
577}
578
579impl AggregateUDFImpl for SimpleAggregateUDF {
580    fn as_any(&self) -> &dyn Any {
581        self
582    }
583
584    fn name(&self) -> &str {
585        &self.name
586    }
587
588    fn signature(&self) -> &Signature {
589        &self.signature
590    }
591
592    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
593        Ok(self.return_type.clone())
594    }
595
596    fn accumulator(
597        &self,
598        acc_args: AccumulatorArgs,
599    ) -> Result<Box<dyn crate::Accumulator>> {
600        (self.accumulator)(acc_args)
601    }
602
603    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
604        Ok(self.state_fields.clone())
605    }
606}
607
608/// Creates a new UDWF with a specific signature, state type and return type.
609///
610/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
611///
612/// [`PartitionEvaluator`]: crate::PartitionEvaluator
613pub fn create_udwf(
614    name: &str,
615    input_type: DataType,
616    return_type: Arc<DataType>,
617    volatility: Volatility,
618    partition_evaluator_factory: PartitionEvaluatorFactory,
619) -> WindowUDF {
620    let return_type = Arc::unwrap_or_clone(return_type);
621    WindowUDF::from(SimpleWindowUDF::new(
622        name,
623        input_type,
624        return_type,
625        volatility,
626        partition_evaluator_factory,
627    ))
628}
629
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
/// return type.
pub struct SimpleWindowUDF {
    /// Function name, as reported by `name()`.
    name: String,
    /// The single signature reported by `signature()`.
    signature: Signature,
    /// Data type of the output field produced by `field()`.
    return_type: DataType,
    /// Factory called (without arguments) to create a partition evaluator.
    partition_evaluator_factory: PartitionEvaluatorFactory,
}
638
639impl Debug for SimpleWindowUDF {
640    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
641        f.debug_struct("WindowUDF")
642            .field("name", &self.name)
643            .field("signature", &self.signature)
644            .field("return_type", &"<func>")
645            .field("partition_evaluator_factory", &"<FUNC>")
646            .finish()
647    }
648}
649
650impl SimpleWindowUDF {
651    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
652    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
653    pub fn new(
654        name: impl Into<String>,
655        input_type: DataType,
656        return_type: DataType,
657        volatility: Volatility,
658        partition_evaluator_factory: PartitionEvaluatorFactory,
659    ) -> Self {
660        let name = name.into();
661        let signature = Signature::exact([input_type].to_vec(), volatility);
662        Self {
663            name,
664            signature,
665            return_type,
666            partition_evaluator_factory,
667        }
668    }
669}
670
671impl WindowUDFImpl for SimpleWindowUDF {
672    fn as_any(&self) -> &dyn Any {
673        self
674    }
675
676    fn name(&self) -> &str {
677        &self.name
678    }
679
680    fn signature(&self) -> &Signature {
681        &self.signature
682    }
683
684    fn partition_evaluator(
685        &self,
686        _partition_evaluator_args: PartitionEvaluatorArgs,
687    ) -> Result<Box<dyn PartitionEvaluator>> {
688        (self.partition_evaluator_factory)()
689    }
690
691    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
692        Ok(Field::new(
693            field_args.name(),
694            self.return_type.clone(),
695            true,
696        ))
697    }
698}
699
700pub fn interval_year_month_lit(value: &str) -> Expr {
701    let interval = parse_interval_year_month(value).ok();
702    Expr::Literal(ScalarValue::IntervalYearMonth(interval))
703}
704
705pub fn interval_datetime_lit(value: &str) -> Expr {
706    let interval = parse_interval_day_time(value).ok();
707    Expr::Literal(ScalarValue::IntervalDayTime(interval))
708}
709
710pub fn interval_month_day_nano_lit(value: &str) -> Expr {
711    let interval = parse_interval_month_day_nano(value).ok();
712    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval))
713}
714
/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
///
/// Adds methods to [`Expr`] that make it easy to set optional options
/// such as `ORDER BY`, `FILTER` and `DISTINCT`
///
/// Every method returns an [`ExprFuncBuilder`]; call
/// [`ExprFuncBuilder::build`] to obtain the configured expression.
///
/// # Example
/// ```no_run
/// # use datafusion_common::Result;
/// # use datafusion_expr::test::function_stub::count;
/// # use sqlparser::ast::NullTreatment;
/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
/// # // first_value is an aggregate function in another crate
/// # fn first_value(_arg: Expr) -> Expr { unimplemented!() }
/// # fn main() -> Result<()> {
/// // Create an aggregate count, filtering on column y > 5
/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
///
/// // Find the first value in an aggregate sorted by column y
/// // equivalent to:
/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
/// let sort_expr = col("y").sort(true, true);
/// let agg = first_value(col("x"))
///     .order_by(vec![sort_expr])
///     .null_treatment(NullTreatment::IgnoreNulls)
///     .build()?;
///
/// // Create a window expression for percent rank partitioned on column a
/// // equivalent to:
/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
/// // percent_rank is an udwf function in another crate
/// # fn percent_rank() -> Expr { unimplemented!() }
/// let window = percent_rank()
///     .partition_by(vec![col("a")])
///     .order_by(vec![col("b").sort(true, true)])
///     .null_treatment(NullTreatment::IgnoreNulls)
///     .build()?;
/// #     Ok(())
/// # }
/// ```
pub trait ExprFunctionExt {
    /// Add `ORDER BY <order_by>`
    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
    /// Add `FILTER <filter>`
    fn filter(self, filter: Expr) -> ExprFuncBuilder;
    /// Add `DISTINCT`
    fn distinct(self) -> ExprFuncBuilder;
    /// Add `RESPECT NULLS` or `IGNORE NULLS`
    fn null_treatment(
        self,
        null_treatment: impl Into<Option<NullTreatment>>,
    ) -> ExprFuncBuilder;
    /// Add `PARTITION BY`
    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
    /// Add appropriate window frame conditions
    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
}
773
/// The kind of function expression an [`ExprFuncBuilder`] is configuring.
#[derive(Debug, Clone)]
pub enum ExprFuncKind {
    /// An aggregate function, rebuilt as [`Expr::AggregateFunction`].
    Aggregate(AggregateFunction),
    /// A window function, rebuilt as [`Expr::WindowFunction`].
    Window(WindowFunction),
}
779
/// Implementation of [`ExprFunctionExt`].
///
/// See [`ExprFunctionExt`] for usage and examples
#[derive(Debug, Clone)]
pub struct ExprFuncBuilder {
    /// Function being configured; `None` makes `build` return an error.
    fun: Option<ExprFuncKind>,
    /// `ORDER BY` expressions, if any were set.
    order_by: Option<Vec<Sort>>,
    /// `FILTER` expression; applied only when building an aggregate function.
    filter: Option<Expr>,
    /// `DISTINCT` flag; applied only when building an aggregate function.
    distinct: bool,
    /// `RESPECT NULLS` / `IGNORE NULLS` setting, if any was set.
    null_treatment: Option<NullTreatment>,
    /// `PARTITION BY` expressions; applied only when building a window function.
    partition_by: Option<Vec<Expr>>,
    /// Window frame; applied only when building a window function.
    window_frame: Option<WindowFrame>,
}
793
794impl ExprFuncBuilder {
795    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
796    fn new(fun: Option<ExprFuncKind>) -> Self {
797        Self {
798            fun,
799            order_by: None,
800            filter: None,
801            distinct: false,
802            null_treatment: None,
803            partition_by: None,
804            window_frame: None,
805        }
806    }
807
808    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
809    ///
810    /// # Errors:
811    ///
812    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
813    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
814    pub fn build(self) -> Result<Expr> {
815        let Self {
816            fun,
817            order_by,
818            filter,
819            distinct,
820            null_treatment,
821            partition_by,
822            window_frame,
823        } = self;
824
825        let Some(fun) = fun else {
826            return plan_err!(
827                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
828            );
829        };
830
831        let fun_expr = match fun {
832            ExprFuncKind::Aggregate(mut udaf) => {
833                udaf.params.order_by = order_by;
834                udaf.params.filter = filter.map(Box::new);
835                udaf.params.distinct = distinct;
836                udaf.params.null_treatment = null_treatment;
837                Expr::AggregateFunction(udaf)
838            }
839            ExprFuncKind::Window(WindowFunction {
840                fun,
841                params: WindowFunctionParams { args, .. },
842            }) => {
843                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
844                Expr::WindowFunction(WindowFunction {
845                    fun,
846                    params: WindowFunctionParams {
847                        args,
848                        partition_by: partition_by.unwrap_or_default(),
849                        order_by: order_by.unwrap_or_default(),
850                        window_frame: window_frame
851                            .unwrap_or(WindowFrame::new(has_order_by)),
852                        null_treatment,
853                    },
854                })
855            }
856        };
857
858        Ok(fun_expr)
859    }
860}
861
862impl ExprFunctionExt for ExprFuncBuilder {
863    /// Add `ORDER BY <order_by>`
864    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
865        self.order_by = Some(order_by);
866        self
867    }
868
869    /// Add `FILTER <filter>`
870    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
871        self.filter = Some(filter);
872        self
873    }
874
875    /// Add `DISTINCT`
876    fn distinct(mut self) -> ExprFuncBuilder {
877        self.distinct = true;
878        self
879    }
880
881    /// Add `RESPECT NULLS` or `IGNORE NULLS`
882    fn null_treatment(
883        mut self,
884        null_treatment: impl Into<Option<NullTreatment>>,
885    ) -> ExprFuncBuilder {
886        self.null_treatment = null_treatment.into();
887        self
888    }
889
890    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
891        self.partition_by = Some(partition_by);
892        self
893    }
894
895    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
896        self.window_frame = Some(window_frame);
897        self
898    }
899}
900
901impl ExprFunctionExt for Expr {
902    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
903        let mut builder = match self {
904            Expr::AggregateFunction(udaf) => {
905                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
906            }
907            Expr::WindowFunction(udwf) => {
908                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
909            }
910            _ => ExprFuncBuilder::new(None),
911        };
912        if builder.fun.is_some() {
913            builder.order_by = Some(order_by);
914        }
915        builder
916    }
917    fn filter(self, filter: Expr) -> ExprFuncBuilder {
918        match self {
919            Expr::AggregateFunction(udaf) => {
920                let mut builder =
921                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
922                builder.filter = Some(filter);
923                builder
924            }
925            _ => ExprFuncBuilder::new(None),
926        }
927    }
928    fn distinct(self) -> ExprFuncBuilder {
929        match self {
930            Expr::AggregateFunction(udaf) => {
931                let mut builder =
932                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
933                builder.distinct = true;
934                builder
935            }
936            _ => ExprFuncBuilder::new(None),
937        }
938    }
939    fn null_treatment(
940        self,
941        null_treatment: impl Into<Option<NullTreatment>>,
942    ) -> ExprFuncBuilder {
943        let mut builder = match self {
944            Expr::AggregateFunction(udaf) => {
945                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
946            }
947            Expr::WindowFunction(udwf) => {
948                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
949            }
950            _ => ExprFuncBuilder::new(None),
951        };
952        if builder.fun.is_some() {
953            builder.null_treatment = null_treatment.into();
954        }
955        builder
956    }
957
958    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
959        match self {
960            Expr::WindowFunction(udwf) => {
961                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
962                builder.partition_by = Some(partition_by);
963                builder
964            }
965            _ => ExprFuncBuilder::new(None),
966        }
967    }
968
969    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
970        match self {
971            Expr::WindowFunction(udwf) => {
972                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
973                builder.window_frame = Some(window_frame);
974                builder
975            }
976            _ => ExprFuncBuilder::new(None),
977        }
978    }
979}
980
#[cfg(test)]
mod test {
    use super::*;

    /// `IS NULL` and `IS NOT NULL` expressions display in their SQL form.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();

        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}