// datafusion_physical_expr/window/sliding_aggregate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Physical exec for sliding aggregate window function expressions.

20use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::window_expr::AggregateWindowExpr;
26use crate::window::{
27    PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr,
28};
29use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
30
31use arrow::array::{Array, ArrayRef};
32use arrow::datatypes::Field;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::{Result, ScalarValue};
35use datafusion_expr::{Accumulator, WindowFrame};
36use datafusion_physical_expr_common::sort_expr::LexOrdering;
37
38/// A window expr that takes the form of an aggregate function that
39/// can be incrementally computed over sliding windows.
40///
41/// See comments on [`WindowExpr`] for more details.
42#[derive(Debug)]
43pub struct SlidingAggregateWindowExpr {
44    aggregate: Arc<AggregateFunctionExpr>,
45    partition_by: Vec<Arc<dyn PhysicalExpr>>,
46    order_by: LexOrdering,
47    window_frame: Arc<WindowFrame>,
48}
49
50impl SlidingAggregateWindowExpr {
51    /// Create a new (sliding) aggregate window function expression.
52    pub fn new(
53        aggregate: Arc<AggregateFunctionExpr>,
54        partition_by: &[Arc<dyn PhysicalExpr>],
55        order_by: &LexOrdering,
56        window_frame: Arc<WindowFrame>,
57    ) -> Self {
58        Self {
59            aggregate,
60            partition_by: partition_by.to_vec(),
61            order_by: order_by.clone(),
62            window_frame,
63        }
64    }
65
66    /// Get the [AggregateFunctionExpr] of this object.
67    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
68        &self.aggregate
69    }
70}
71
72/// Incrementally update window function using the fact that batch is
73/// pre-sorted given the sort columns and then per partition point.
74///
75/// Evaluates the peer group (e.g. `SUM` or `MAX` gives the same results
76/// for peers) and concatenate the results.
77impl WindowExpr for SlidingAggregateWindowExpr {
78    /// Return a reference to Any that can be used for downcasting
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn field(&self) -> Result<Field> {
84        Ok(self.aggregate.field())
85    }
86
87    fn name(&self) -> &str {
88        self.aggregate.name()
89    }
90
91    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
92        self.aggregate.expressions()
93    }
94
95    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
96        self.aggregate_evaluate(batch)
97    }
98
99    fn evaluate_stateful(
100        &self,
101        partition_batches: &PartitionBatches,
102        window_agg_state: &mut PartitionWindowAggStates,
103    ) -> Result<()> {
104        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)
105    }
106
107    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
108        &self.partition_by
109    }
110
111    fn order_by(&self) -> &LexOrdering {
112        self.order_by.as_ref()
113    }
114
115    fn get_window_frame(&self) -> &Arc<WindowFrame> {
116        &self.window_frame
117    }
118
119    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
120        self.aggregate.reverse_expr().map(|reverse_expr| {
121            let reverse_window_frame = self.window_frame.reverse();
122            if reverse_window_frame.is_ever_expanding() {
123                Arc::new(PlainAggregateWindowExpr::new(
124                    Arc::new(reverse_expr),
125                    &self.partition_by.clone(),
126                    reverse_order_bys(self.order_by.as_ref()).as_ref(),
127                    Arc::new(self.window_frame.reverse()),
128                )) as _
129            } else {
130                Arc::new(SlidingAggregateWindowExpr::new(
131                    Arc::new(reverse_expr),
132                    &self.partition_by.clone(),
133                    reverse_order_bys(self.order_by.as_ref()).as_ref(),
134                    Arc::new(self.window_frame.reverse()),
135                )) as _
136            }
137        })
138    }
139
140    fn uses_bounded_memory(&self) -> bool {
141        !self.window_frame.end_bound.is_unbounded()
142    }
143
144    fn with_new_expressions(
145        &self,
146        args: Vec<Arc<dyn PhysicalExpr>>,
147        partition_bys: Vec<Arc<dyn PhysicalExpr>>,
148        order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
149    ) -> Option<Arc<dyn WindowExpr>> {
150        debug_assert_eq!(self.order_by.len(), order_by_exprs.len());
151
152        let new_order_by = self
153            .order_by
154            .iter()
155            .zip(order_by_exprs)
156            .map(|(req, new_expr)| PhysicalSortExpr {
157                expr: new_expr,
158                options: req.options,
159            })
160            .collect::<LexOrdering>();
161        Some(Arc::new(SlidingAggregateWindowExpr {
162            aggregate: self
163                .aggregate
164                .with_new_expressions(args, vec![])
165                .map(Arc::new)?,
166            partition_by: partition_bys,
167            order_by: new_order_by,
168            window_frame: Arc::clone(&self.window_frame),
169        }))
170    }
171}
172
173impl AggregateWindowExpr for SlidingAggregateWindowExpr {
174    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
175        self.aggregate.create_sliding_accumulator()
176    }
177
178    /// Given current range and the last range, calculates the accumulator
179    /// result for the range of interest.
180    fn get_aggregate_result_inside_range(
181        &self,
182        last_range: &Range<usize>,
183        cur_range: &Range<usize>,
184        value_slice: &[ArrayRef],
185        accumulator: &mut Box<dyn Accumulator>,
186    ) -> Result<ScalarValue> {
187        if cur_range.start == cur_range.end {
188            self.aggregate
189                .default_value(self.aggregate.field().data_type())
190        } else {
191            // Accumulate any new rows that have entered the window:
192            let update_bound = cur_range.end - last_range.end;
193            if update_bound > 0 {
194                let update: Vec<ArrayRef> = value_slice
195                    .iter()
196                    .map(|v| v.slice(last_range.end, update_bound))
197                    .collect();
198                accumulator.update_batch(&update)?
199            }
200
201            // Remove rows that have now left the window:
202            let retract_bound = cur_range.start - last_range.start;
203            if retract_bound > 0 {
204                let retract: Vec<ArrayRef> = value_slice
205                    .iter()
206                    .map(|v| v.slice(last_range.start, retract_bound))
207                    .collect();
208                accumulator.retract_batch(&retract)?
209            }
210            accumulator.evaluate()
211        }
212    }
213}