datafusion_physical_expr/window/aggregate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Physical exec for aggregate window function expressions.
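//!
//! [`PlainAggregateWindowExpr`] evaluates such expressions over frames that
//! only ever grow, either for a whole partition at once or incrementally in a
//! stateful, bounded-memory fashion.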

use std::any::Any;
use std::ops::Range;
use std::sync::Arc;

use crate::aggregate::AggregateFunctionExpr;
use crate::window::standard::add_new_ordering_expr_with_partition_by;
use crate::window::window_expr::AggregateWindowExpr;
use crate::window::{
    PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
};
use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr};

use arrow::array::Array;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

/// A window expr that takes the form of an aggregate function.
///
/// See comments on [`WindowExpr`] for more details.
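///
/// This variant is used when the window frame never shrinks (its start point
/// stays fixed, e.g. with UNBOUNDED PRECEDING), so the underlying accumulator
/// only receives new rows and never has to call `retract_batch`. Frames that
/// can shrink are handled by [`SlidingAggregateWindowExpr`].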
#[derive(Debug)]
pub struct PlainAggregateWindowExpr {
    aggregate: Arc<AggregateFunctionExpr>,
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    order_by: LexOrdering,
    window_frame: Arc<WindowFrame>,
}

impl PlainAggregateWindowExpr {
    /// Create a new aggregate window function expression
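    ///
    /// # Example
    ///
    /// An illustrative sketch (not compiled here); the aggregate, partition
    /// columns, ordering, and frame are assumed to have been built elsewhere
    /// for the input schema:
    ///
    /// ```ignore
    /// let window_expr = PlainAggregateWindowExpr::new(
    ///     aggregate,              // Arc<AggregateFunctionExpr>
    ///     &partition_cols,        // Vec<Arc<dyn PhysicalExpr>>
    ///     &ordering,              // LexOrdering over the ORDER BY columns
    ///     Arc::new(window_frame), // WindowFrame describing the frame bounds
    /// );
    /// ```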
    pub fn new(
        aggregate: Arc<AggregateFunctionExpr>,
        partition_by: &[Arc<dyn PhysicalExpr>],
        order_by: &LexOrdering,
        window_frame: Arc<WindowFrame>,
    ) -> Self {
        Self {
            aggregate,
            partition_by: partition_by.to_vec(),
            order_by: order_by.clone(),
            window_frame,
        }
    }

    /// Returns a reference to the underlying [`AggregateFunctionExpr`].
    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
        &self.aggregate
    }

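    /// If the result of the underlying aggregate has a known output ordering
    /// (e.g. it is monotonically increasing), register that ordering for the
    /// output column at `window_expr_index` in `eq_properties`, taking the
    /// partition by columns into account.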
    pub fn add_equal_orderings(
        &self,
        eq_properties: &mut EquivalenceProperties,
        window_expr_index: usize,
    ) {
        if let Some(expr) = self
            .get_aggregate_expr()
            .get_result_ordering(window_expr_index)
        {
            add_new_ordering_expr_with_partition_by(
                eq_properties,
                expr,
                &self.partition_by,
            );
        }
    }
}

/// Peer-based evaluation relies on the input batch being pre-sorted on the
/// sort columns; for each partition point we evaluate the peer group (e.g.
/// SUM or MAX yields the same result for all peers) and concatenate the
/// results.
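///
/// For example, with `ORDER BY ts` and a RANGE frame, all rows that share the
/// same `ts` value are peers and therefore receive identical aggregate values.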
impl WindowExpr for PlainAggregateWindowExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn field(&self) -> Result<Field> {
        Ok(self.aggregate.field())
    }

    fn name(&self) -> &str {
        self.aggregate.name()
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.aggregate.expressions()
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.aggregate_evaluate(batch)
    }

    fn evaluate_stateful(
        &self,
        partition_batches: &PartitionBatches,
        window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;

        // Update window frame range for each partition. As we know that
        // non-sliding aggregations will never call `retract_batch`, this value
        // can safely increase, and we can remove "old" parts of the state.
        // This enables us to run queries involving UNBOUNDED PRECEDING frames
        // using bounded memory for suitable aggregations.
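        // For example, once rows `0..end` of a partition have been fed to the
        // accumulator for an UNBOUNDED PRECEDING frame, the retained range is
        // shrunk to `end - 1..end` below; earlier rows will never be needed
        // again because nothing is ever retracted.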
        for partition_row in partition_batches.keys() {
            let window_state =
                window_agg_state.get_mut(partition_row).ok_or_else(|| {
                    DataFusionError::Execution("Cannot find state".to_string())
                })?;
            let state = &mut window_state.state;
            if self.window_frame.start_bound.is_unbounded() {
                state.window_frame_range.start =
                    state.window_frame_range.end.saturating_sub(1);
            }
        }
        Ok(())
    }

    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.partition_by
    }

    fn order_by(&self) -> &LexOrdering {
        self.order_by.as_ref()
    }

    fn get_window_frame(&self) -> &Arc<WindowFrame> {
        &self.window_frame
    }

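    /// Reverses the expression if the underlying aggregate supports it: the
    /// ORDER BY is reversed and the frame is mirrored. If the reversed frame
    /// still only ever grows, another [`PlainAggregateWindowExpr`] is
    /// produced; otherwise the sliding variant, which can retract rows, is
    /// used.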
    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
        self.aggregate.reverse_expr().map(|reverse_expr| {
            let reverse_window_frame = self.window_frame.reverse();
            if reverse_window_frame.is_ever_expanding() {
                Arc::new(PlainAggregateWindowExpr::new(
                    Arc::new(reverse_expr),
                    &self.partition_by,
                    reverse_order_bys(self.order_by.as_ref()).as_ref(),
                    Arc::new(reverse_window_frame),
                )) as _
            } else {
                Arc::new(SlidingAggregateWindowExpr::new(
                    Arc::new(reverse_expr),
                    &self.partition_by,
                    reverse_order_bys(self.order_by.as_ref()).as_ref(),
                    Arc::new(reverse_window_frame),
                )) as _
            }
        })
    }

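    /// The frame start never moves for this expression, so memory usage stays
    /// bounded as long as the frame end is not UNBOUNDED FOLLOWING.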
    fn uses_bounded_memory(&self) -> bool {
        !self.window_frame.end_bound.is_unbounded()
    }
}

impl AggregateWindowExpr for PlainAggregateWindowExpr {
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        self.aggregate.create_accumulator()
    }

    /// For a given range, calculate the accumulation result inside the range
    /// over `value_slice` and update the accumulator state.
    // We assume that `cur_range` contains `last_range` and that their start
    // points are the same; i.e. if `last_range` is `Range{start: a, end: b}` and
    // `cur_range` is `Range{start: a1, end: b1}`, it is guaranteed that a1 = a and b1 >= b.
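    // For example, if `last_range` is `0..3` and `cur_range` is `0..5`, only
    // rows 3 and 4 are new; they are appended to the accumulator below, while
    // rows 0..3 were already accumulated by a previous call.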
    fn get_aggregate_result_inside_range(
        &self,
        last_range: &Range<usize>,
        cur_range: &Range<usize>,
        value_slice: &[ArrayRef],
        accumulator: &mut Box<dyn Accumulator>,
    ) -> Result<ScalarValue> {
        if cur_range.start == cur_range.end {
            self.aggregate
                .default_value(self.aggregate.field().data_type())
        } else {
            // Accumulate any new rows that have entered the window:
            let update_bound = cur_range.end - last_range.end;
            // A non-sliding aggregation only processes new data; it never
            // deals with expiring data because its starting point is always
            // the same (i.e. the beginning of the table/frame). Hence, we
            // do not call `retract_batch`.
            if update_bound > 0 {
                let update: Vec<ArrayRef> = value_slice
                    .iter()
                    .map(|v| v.slice(last_range.end, update_bound))
                    .collect();
                accumulator.update_batch(&update)?
            }
            accumulator.evaluate()
        }
    }
}