datafusion_physical_expr/window/
aggregate.rs1use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::standard::add_new_ordering_expr_with_partition_by;
26use crate::window::window_expr::AggregateWindowExpr;
27use crate::window::{
28 PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
29};
30use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr};
31
32use arrow::array::Array;
33use arrow::record_batch::RecordBatch;
34use arrow::{array::ArrayRef, datatypes::Field};
35use datafusion_common::{DataFusionError, Result, ScalarValue};
36use datafusion_expr::{Accumulator, WindowFrame};
37use datafusion_physical_expr_common::sort_expr::LexOrdering;
38
39#[derive(Debug)]
43pub struct PlainAggregateWindowExpr {
44 aggregate: Arc<AggregateFunctionExpr>,
45 partition_by: Vec<Arc<dyn PhysicalExpr>>,
46 order_by: LexOrdering,
47 window_frame: Arc<WindowFrame>,
48}
49
50impl PlainAggregateWindowExpr {
51 pub fn new(
53 aggregate: Arc<AggregateFunctionExpr>,
54 partition_by: &[Arc<dyn PhysicalExpr>],
55 order_by: &LexOrdering,
56 window_frame: Arc<WindowFrame>,
57 ) -> Self {
58 Self {
59 aggregate,
60 partition_by: partition_by.to_vec(),
61 order_by: order_by.clone(),
62 window_frame,
63 }
64 }
65
66 pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
68 &self.aggregate
69 }
70
71 pub fn add_equal_orderings(
72 &self,
73 eq_properties: &mut EquivalenceProperties,
74 window_expr_index: usize,
75 ) {
76 if let Some(expr) = self
77 .get_aggregate_expr()
78 .get_result_ordering(window_expr_index)
79 {
80 add_new_ordering_expr_with_partition_by(
81 eq_properties,
82 expr,
83 &self.partition_by,
84 );
85 }
86 }
87}
88
89impl WindowExpr for PlainAggregateWindowExpr {
93 fn as_any(&self) -> &dyn Any {
95 self
96 }
97
98 fn field(&self) -> Result<Field> {
99 Ok(self.aggregate.field())
100 }
101
102 fn name(&self) -> &str {
103 self.aggregate.name()
104 }
105
106 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
107 self.aggregate.expressions()
108 }
109
110 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
111 self.aggregate_evaluate(batch)
112 }
113
114 fn evaluate_stateful(
115 &self,
116 partition_batches: &PartitionBatches,
117 window_agg_state: &mut PartitionWindowAggStates,
118 ) -> Result<()> {
119 self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;
120
121 for partition_row in partition_batches.keys() {
127 let window_state =
128 window_agg_state.get_mut(partition_row).ok_or_else(|| {
129 DataFusionError::Execution("Cannot find state".to_string())
130 })?;
131 let state = &mut window_state.state;
132 if self.window_frame.start_bound.is_unbounded() {
133 state.window_frame_range.start =
134 state.window_frame_range.end.saturating_sub(1);
135 }
136 }
137 Ok(())
138 }
139
140 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
141 &self.partition_by
142 }
143
144 fn order_by(&self) -> &LexOrdering {
145 self.order_by.as_ref()
146 }
147
148 fn get_window_frame(&self) -> &Arc<WindowFrame> {
149 &self.window_frame
150 }
151
152 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
153 self.aggregate.reverse_expr().map(|reverse_expr| {
154 let reverse_window_frame = self.window_frame.reverse();
155 if reverse_window_frame.is_ever_expanding() {
156 Arc::new(PlainAggregateWindowExpr::new(
157 Arc::new(reverse_expr),
158 &self.partition_by.clone(),
159 reverse_order_bys(self.order_by.as_ref()).as_ref(),
160 Arc::new(self.window_frame.reverse()),
161 )) as _
162 } else {
163 Arc::new(SlidingAggregateWindowExpr::new(
164 Arc::new(reverse_expr),
165 &self.partition_by.clone(),
166 reverse_order_bys(self.order_by.as_ref()).as_ref(),
167 Arc::new(self.window_frame.reverse()),
168 )) as _
169 }
170 })
171 }
172
173 fn uses_bounded_memory(&self) -> bool {
174 !self.window_frame.end_bound.is_unbounded()
175 }
176}
177
178impl AggregateWindowExpr for PlainAggregateWindowExpr {
179 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
180 self.aggregate.create_accumulator()
181 }
182
183 fn get_aggregate_result_inside_range(
189 &self,
190 last_range: &Range<usize>,
191 cur_range: &Range<usize>,
192 value_slice: &[ArrayRef],
193 accumulator: &mut Box<dyn Accumulator>,
194 ) -> Result<ScalarValue> {
195 if cur_range.start == cur_range.end {
196 self.aggregate
197 .default_value(self.aggregate.field().data_type())
198 } else {
199 let update_bound = cur_range.end - last_range.end;
201 if update_bound > 0 {
206 let update: Vec<ArrayRef> = value_slice
207 .iter()
208 .map(|v| v.slice(last_range.end, update_bound))
209 .collect();
210 accumulator.update_batch(&update)?
211 }
212 accumulator.evaluate()
213 }
214 }
215}