datafusion_physical_expr/window/
window_expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Debug;
20use std::ops::Range;
21use std::sync::Arc;
22
23use crate::{LexOrdering, PhysicalExpr};
24
25use arrow::array::{new_empty_array, Array, ArrayRef};
26use arrow::compute::kernels::sort::SortColumn;
27use arrow::compute::SortOptions;
28use arrow::datatypes::Field;
29use arrow::record_batch::RecordBatch;
30use datafusion_common::utils::compare_rows;
31use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
32use datafusion_expr::window_state::{
33    PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
34};
35use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
36
37use indexmap::IndexMap;
38
39/// Common trait for [window function] implementations
40///
41/// # Aggregate Window Expressions
42///
43/// These expressions take the form
44///
45/// ```text
46/// OVER({ROWS | RANGE| GROUPS} BETWEEN UNBOUNDED PRECEDING AND ...)
47/// ```
48///
49/// For example, cumulative window frames uses `PlainAggregateWindowExpr`.
50///
51/// # Non Aggregate Window Expressions
52///
53/// The expressions have the form
54///
55/// ```text
56/// OVER({ROWS | RANGE| GROUPS} BETWEEN M {PRECEDING| FOLLOWING} AND ...)
57/// ```
58///
59/// For example, sliding window frames use [`SlidingAggregateWindowExpr`].
60///
61/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
62/// [`PlainAggregateWindowExpr`]: crate::window::PlainAggregateWindowExpr
63/// [`SlidingAggregateWindowExpr`]: crate::window::SlidingAggregateWindowExpr
64pub trait WindowExpr: Send + Sync + Debug {
65    /// Returns the window expression as [`Any`] so that it can be
66    /// downcast to a specific implementation.
67    fn as_any(&self) -> &dyn Any;
68
69    /// The field of the final result of this window function.
70    fn field(&self) -> Result<Field>;
71
72    /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default
73    /// implementation returns placeholder text.
74    fn name(&self) -> &str {
75        "WindowExpr: default name"
76    }
77
78    /// Expressions that are passed to the WindowAccumulator.
79    /// Functions which take a single input argument, such as `sum`, return a single [`datafusion_expr::expr::Expr`],
80    /// others (e.g. `cov`) return many.
81    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
82
83    /// Evaluate the window function arguments against the batch and return
84    /// array ref, normally the resulting `Vec` is a single element one.
85    fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
86        self.expressions()
87            .iter()
88            .map(|e| {
89                e.evaluate(batch)
90                    .and_then(|v| v.into_array(batch.num_rows()))
91            })
92            .collect()
93    }
94
95    /// Evaluate the window function values against the batch
96    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
97
98    /// Evaluate the window function against the batch. This function facilitates
99    /// stateful, bounded-memory implementations.
100    fn evaluate_stateful(
101        &self,
102        _partition_batches: &PartitionBatches,
103        _window_agg_state: &mut PartitionWindowAggStates,
104    ) -> Result<()> {
105        internal_err!("evaluate_stateful is not implemented for {}", self.name())
106    }
107
108    /// Expressions that's from the window function's partition by clause, empty if absent
109    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
110
111    /// Expressions that's from the window function's order by clause, empty if absent
112    fn order_by(&self) -> &LexOrdering;
113
114    /// Get order by columns, empty if absent
115    fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
116        self.order_by()
117            .iter()
118            .map(|e| e.evaluate_to_sort_column(batch))
119            .collect::<Result<Vec<SortColumn>>>()
120    }
121
122    /// Get the window frame of this [WindowExpr].
123    fn get_window_frame(&self) -> &Arc<WindowFrame>;
124
125    /// Return a flag indicating whether this [WindowExpr] can run with
126    /// bounded memory.
127    fn uses_bounded_memory(&self) -> bool;
128
129    /// Get the reverse expression of this [WindowExpr].
130    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
131
132    /// Returns all expressions used in the [`WindowExpr`].
133    /// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions.
134    fn all_expressions(&self) -> WindowPhysicalExpressions {
135        let args = self.expressions();
136        let partition_by_exprs = self.partition_by().to_vec();
137        let order_by_exprs = self
138            .order_by()
139            .iter()
140            .map(|sort_expr| Arc::clone(&sort_expr.expr))
141            .collect::<Vec<_>>();
142        WindowPhysicalExpressions {
143            args,
144            partition_by_exprs,
145            order_by_exprs,
146        }
147    }
148
149    /// Rewrites [`WindowExpr`], with new expressions given. The argument should be consistent
150    /// with the return value of the [`WindowExpr::all_expressions`] method.
151    /// Returns `Some(Arc<dyn WindowExpr>)` if re-write is supported, otherwise returns `None`.
152    fn with_new_expressions(
153        &self,
154        _args: Vec<Arc<dyn PhysicalExpr>>,
155        _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
156        _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
157    ) -> Option<Arc<dyn WindowExpr>> {
158        None
159    }
160}
161
162/// Stores the physical expressions used inside the `WindowExpr`.
163pub struct WindowPhysicalExpressions {
164    /// Window function arguments
165    pub args: Vec<Arc<dyn PhysicalExpr>>,
166    /// PARTITION BY expressions
167    pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
168    /// ORDER BY expressions
169    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
170}
171
172/// Extension trait that adds common functionality to [`AggregateWindowExpr`]s
173pub trait AggregateWindowExpr: WindowExpr {
174    /// Get the accumulator for the window expression. Note that distinct
175    /// window expressions may return distinct accumulators; e.g. sliding
176    /// (non-sliding) expressions will return sliding (normal) accumulators.
177    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
178
179    /// Given current range and the last range, calculates the accumulator
180    /// result for the range of interest.
181    fn get_aggregate_result_inside_range(
182        &self,
183        last_range: &Range<usize>,
184        cur_range: &Range<usize>,
185        value_slice: &[ArrayRef],
186        accumulator: &mut Box<dyn Accumulator>,
187    ) -> Result<ScalarValue>;
188
189    /// Evaluates the window function against the batch.
190    fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
191        let mut accumulator = self.get_accumulator()?;
192        let mut last_range = Range { start: 0, end: 0 };
193        let sort_options: Vec<SortOptions> =
194            self.order_by().iter().map(|o| o.options).collect();
195        let mut window_frame_ctx =
196            WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
197        self.get_result_column(
198            &mut accumulator,
199            batch,
200            None,
201            &mut last_range,
202            &mut window_frame_ctx,
203            0,
204            false,
205        )
206    }
207
208    /// Statefully evaluates the window function against the batch. Maintains
209    /// state so that it can work incrementally over multiple chunks.
210    fn aggregate_evaluate_stateful(
211        &self,
212        partition_batches: &PartitionBatches,
213        window_agg_state: &mut PartitionWindowAggStates,
214    ) -> Result<()> {
215        let field = self.field()?;
216        let out_type = field.data_type();
217        for (partition_row, partition_batch_state) in partition_batches.iter() {
218            if !window_agg_state.contains_key(partition_row) {
219                let accumulator = self.get_accumulator()?;
220                window_agg_state.insert(
221                    partition_row.clone(),
222                    WindowState {
223                        state: WindowAggState::new(out_type)?,
224                        window_fn: WindowFn::Aggregate(accumulator),
225                    },
226                );
227            };
228            let window_state =
229                window_agg_state.get_mut(partition_row).ok_or_else(|| {
230                    DataFusionError::Execution("Cannot find state".to_string())
231                })?;
232            let accumulator = match &mut window_state.window_fn {
233                WindowFn::Aggregate(accumulator) => accumulator,
234                _ => unreachable!(),
235            };
236            let state = &mut window_state.state;
237            let record_batch = &partition_batch_state.record_batch;
238            let most_recent_row = partition_batch_state.most_recent_row.as_ref();
239
240            // If there is no window state context, initialize it.
241            let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
242                let sort_options: Vec<SortOptions> =
243                    self.order_by().iter().map(|o| o.options).collect();
244                WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
245            });
246            let out_col = self.get_result_column(
247                accumulator,
248                record_batch,
249                most_recent_row,
250                // Start search from the last range
251                &mut state.window_frame_range,
252                window_frame_ctx,
253                state.last_calculated_index,
254                !partition_batch_state.is_end,
255            )?;
256            state.update(&out_col, partition_batch_state)?;
257        }
258        Ok(())
259    }
260
261    /// Calculates the window expression result for the given record batch.
262    /// Assumes that `record_batch` belongs to a single partition.
263    #[allow(clippy::too_many_arguments)]
264    fn get_result_column(
265        &self,
266        accumulator: &mut Box<dyn Accumulator>,
267        record_batch: &RecordBatch,
268        most_recent_row: Option<&RecordBatch>,
269        last_range: &mut Range<usize>,
270        window_frame_ctx: &mut WindowFrameContext,
271        mut idx: usize,
272        not_end: bool,
273    ) -> Result<ArrayRef> {
274        let values = self.evaluate_args(record_batch)?;
275        let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
276
277        let most_recent_row_order_bys = most_recent_row
278            .map(|batch| self.order_by_columns(batch))
279            .transpose()?
280            .map(get_orderby_values);
281
282        // We iterate on each row to perform a running calculation.
283        let length = values[0].len();
284        let mut row_wise_results: Vec<ScalarValue> = vec![];
285        let is_causal = self.get_window_frame().is_causal();
286        while idx < length {
287            // Start search from the last_range. This squeezes searched range.
288            let cur_range =
289                window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
290            // Exit if the range is non-causal and extends all the way:
291            if cur_range.end == length
292                && !is_causal
293                && not_end
294                && !is_end_bound_safe(
295                    window_frame_ctx,
296                    &order_bys,
297                    most_recent_row_order_bys.as_deref(),
298                    self.order_by(),
299                    idx,
300                )?
301            {
302                break;
303            }
304            let value = self.get_aggregate_result_inside_range(
305                last_range,
306                &cur_range,
307                &values,
308                accumulator,
309            )?;
310            // Update last range
311            *last_range = cur_range;
312            row_wise_results.push(value);
313            idx += 1;
314        }
315
316        if row_wise_results.is_empty() {
317            let field = self.field()?;
318            let out_type = field.data_type();
319            Ok(new_empty_array(out_type))
320        } else {
321            ScalarValue::iter_to_array(row_wise_results)
322        }
323    }
324}
325
326/// Determines whether the end bound calculation for a window frame context is
327/// safe, meaning that the end bound stays the same, regardless of future data,
328/// based on the current sort expressions and ORDER BY columns. This function
329/// delegates work to specific functions for each frame type.
330///
331/// # Parameters
332///
333/// * `window_frame_ctx`: The context of the window frame being evaluated.
334/// * `order_bys`: A slice of `ArrayRef` representing the ORDER BY columns.
335/// * `most_recent_order_bys`: An optional reference to the most recent ORDER BY
336///   columns.
337/// * `sort_exprs`: Defines the lexicographical ordering in question.
338/// * `idx`: The current index in the window frame.
339///
340/// # Returns
341///
342/// A `Result` which is `Ok(true)` if the end bound is safe, `Ok(false)` otherwise.
343pub(crate) fn is_end_bound_safe(
344    window_frame_ctx: &WindowFrameContext,
345    order_bys: &[ArrayRef],
346    most_recent_order_bys: Option<&[ArrayRef]>,
347    sort_exprs: &LexOrdering,
348    idx: usize,
349) -> Result<bool> {
350    if sort_exprs.is_empty() {
351        // Early return if no sort expressions are present:
352        return Ok(false);
353    }
354
355    match window_frame_ctx {
356        WindowFrameContext::Rows(window_frame) => {
357            is_end_bound_safe_for_rows(&window_frame.end_bound)
358        }
359        WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range(
360            &window_frame.end_bound,
361            &order_bys[0],
362            most_recent_order_bys.map(|items| &items[0]),
363            &sort_exprs[0].options,
364            idx,
365        ),
366        WindowFrameContext::Groups {
367            window_frame,
368            state,
369        } => is_end_bound_safe_for_groups(
370            &window_frame.end_bound,
371            state,
372            &order_bys[0],
373            most_recent_order_bys.map(|items| &items[0]),
374            &sort_exprs[0].options,
375        ),
376    }
377}
378
379/// For row-based window frames, determines whether the end bound calculation
380/// is safe, which is trivially the case for `Preceding` and `CurrentRow` bounds.
381/// For 'Following' bounds, it compares the bound value to zero to ensure that
382/// it doesn't extend beyond the current row.
383///
384/// # Parameters
385///
386/// * `end_bound`: Reference to the window frame bound in question.
387///
388/// # Returns
389///
390/// A `Result` indicating whether the end bound is safe for row-based window frames.
391fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
392    if let WindowFrameBound::Following(value) = end_bound {
393        let zero = ScalarValue::new_zero(&value.data_type());
394        Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false))
395    } else {
396        Ok(true)
397    }
398}
399
400/// For row-based window frames, determines whether the end bound calculation
401/// is safe by comparing it against specific values (zero, current row). It uses
402/// the `is_row_ahead` helper function to determine if the current row is ahead
403/// of the most recent row based on the ORDER BY column and sorting options.
404///
405/// # Parameters
406///
407/// * `end_bound`: Reference to the window frame bound in question.
408/// * `orderby_col`: Reference to the column used for ordering.
409/// * `most_recent_ob_col`: Optional reference to the most recent order-by column.
410/// * `sort_options`: The sorting options used in the window frame.
411/// * `idx`: The current index in the window frame.
412///
413/// # Returns
414///
415/// A `Result` indicating whether the end bound is safe for range-based window frames.
416fn is_end_bound_safe_for_range(
417    end_bound: &WindowFrameBound,
418    orderby_col: &ArrayRef,
419    most_recent_ob_col: Option<&ArrayRef>,
420    sort_options: &SortOptions,
421    idx: usize,
422) -> Result<bool> {
423    match end_bound {
424        WindowFrameBound::Preceding(value) => {
425            let zero = ScalarValue::new_zero(&value.data_type())?;
426            if value.eq(&zero) {
427                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
428            } else {
429                Ok(true)
430            }
431        }
432        WindowFrameBound::CurrentRow => {
433            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
434        }
435        WindowFrameBound::Following(delta) => {
436            let Some(most_recent_ob_col) = most_recent_ob_col else {
437                return Ok(false);
438            };
439            let most_recent_row_value =
440                ScalarValue::try_from_array(most_recent_ob_col, 0)?;
441            let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?;
442
443            if sort_options.descending {
444                current_row_value
445                    .sub(delta)
446                    .map(|value| value > most_recent_row_value)
447            } else {
448                current_row_value
449                    .add(delta)
450                    .map(|value| most_recent_row_value > value)
451            }
452        }
453    }
454}
455
456/// For group-based window frames, determines whether the end bound calculation
457/// is safe by considering the group offset and whether the current row is ahead
458/// of the most recent row in terms of sorting. It checks if the end bound is
459/// within the bounds of the current group based on group end indices.
460///
461/// # Parameters
462///
463/// * `end_bound`: Reference to the window frame bound in question.
464/// * `state`: The state of the window frame for group calculations.
465/// * `orderby_col`: Reference to the column used for ordering.
466/// * `most_recent_ob_col`: Optional reference to the most recent order-by column.
467/// * `sort_options`: The sorting options used in the window frame.
468///
469/// # Returns
470///
471/// A `Result` indicating whether the end bound is safe for group-based window frames.
472fn is_end_bound_safe_for_groups(
473    end_bound: &WindowFrameBound,
474    state: &WindowFrameStateGroups,
475    orderby_col: &ArrayRef,
476    most_recent_ob_col: Option<&ArrayRef>,
477    sort_options: &SortOptions,
478) -> Result<bool> {
479    match end_bound {
480        WindowFrameBound::Preceding(value) => {
481            let zero = ScalarValue::new_zero(&value.data_type())?;
482            if value.eq(&zero) {
483                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
484            } else {
485                Ok(true)
486            }
487        }
488        WindowFrameBound::CurrentRow => {
489            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
490        }
491        WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
492            let delta = state.group_end_indices.len() - state.current_group_idx;
493            if delta == (*offset as usize) + 1 {
494                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
495            } else {
496                Ok(false)
497            }
498        }
499        _ => Ok(false),
500    }
501}
502
503/// This utility function checks whether `current_cols` is ahead of the `old_cols`
504/// in terms of `sort_options`.
505fn is_row_ahead(
506    old_col: &ArrayRef,
507    current_col: Option<&ArrayRef>,
508    sort_options: &SortOptions,
509) -> Result<bool> {
510    let Some(current_col) = current_col else {
511        return Ok(false);
512    };
513    if old_col.is_empty() || current_col.is_empty() {
514        return Ok(false);
515    }
516    let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
517    let current_value = ScalarValue::try_from_array(current_col, 0)?;
518    let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?;
519    Ok(cmp.is_gt())
520}
521
522/// Get order by expression results inside `order_by_columns`.
523pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
524    order_by_columns.into_iter().map(|s| s.values).collect()
525}
526
527#[derive(Debug)]
528pub enum WindowFn {
529    Builtin(Box<dyn PartitionEvaluator>),
530    Aggregate(Box<dyn Accumulator>),
531}
532
533/// Key for IndexMap for each unique partition
534///
535/// For instance, if window frame is `OVER(PARTITION BY a,b)`,
536/// PartitionKey would consist of unique `[a,b]` pairs
537pub type PartitionKey = Vec<ScalarValue>;
538
539#[derive(Debug)]
540pub struct WindowState {
541    pub state: WindowAggState,
542    pub window_fn: WindowFn,
543}
544pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
545
546/// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition.
547pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
548
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    const ASCENDING: SortOptions = SortOptions {
        descending: false,
        nulls_first: false,
    };

    const DESCENDING: SortOptions = SortOptions {
        descending: true,
        nulls_first: false,
    };

    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let old_values: ArrayRef =
            Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]));

        let new_values1: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        let new_values2: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));

        assert!(is_row_ahead(&old_values, Some(&new_values1), &ASCENDING)?);
        assert!(!is_row_ahead(&old_values, Some(&new_values2), &ASCENDING)?);

        Ok(())
    }

    #[test]
    fn test_is_row_ahead_descending() -> Result<()> {
        let old_values: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 9.0, 8.0]));

        let smaller: ArrayRef = Arc::new(Float64Array::from(vec![7.0]));
        let equal: ArrayRef = Arc::new(Float64Array::from(vec![8.0]));
        let larger: ArrayRef = Arc::new(Float64Array::from(vec![9.5]));

        // In descending order, a smaller value comes later.
        assert!(is_row_ahead(&old_values, Some(&smaller), &DESCENDING)?);
        assert!(!is_row_ahead(&old_values, Some(&equal), &DESCENDING)?);
        assert!(!is_row_ahead(&old_values, Some(&larger), &DESCENDING)?);

        Ok(())
    }

    #[test]
    fn test_is_row_ahead_missing_or_empty_input() -> Result<()> {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0]));
        let empty: ArrayRef = Arc::new(Float64Array::from(Vec::<f64>::new()));

        assert!(!is_row_ahead(&values, None, &ASCENDING)?);
        assert!(!is_row_ahead(&empty, Some(&values), &ASCENDING)?);
        assert!(!is_row_ahead(&values, Some(&empty), &ASCENDING)?);

        Ok(())
    }
}