datafusion_physical_expr/window/
window_expr.rs1use std::any::Any;
19use std::fmt::Debug;
20use std::ops::Range;
21use std::sync::Arc;
22
23use crate::{LexOrdering, PhysicalExpr};
24
25use arrow::array::{new_empty_array, Array, ArrayRef};
26use arrow::compute::kernels::sort::SortColumn;
27use arrow::compute::SortOptions;
28use arrow::datatypes::Field;
29use arrow::record_batch::RecordBatch;
30use datafusion_common::utils::compare_rows;
31use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
32use datafusion_expr::window_state::{
33 PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
34};
35use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
36
37use indexmap::IndexMap;
38
39pub trait WindowExpr: Send + Sync + Debug {
65 fn as_any(&self) -> &dyn Any;
68
69 fn field(&self) -> Result<Field>;
71
72 fn name(&self) -> &str {
75 "WindowExpr: default name"
76 }
77
78 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
82
83 fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
86 self.expressions()
87 .iter()
88 .map(|e| {
89 e.evaluate(batch)
90 .and_then(|v| v.into_array(batch.num_rows()))
91 })
92 .collect()
93 }
94
95 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
97
98 fn evaluate_stateful(
101 &self,
102 _partition_batches: &PartitionBatches,
103 _window_agg_state: &mut PartitionWindowAggStates,
104 ) -> Result<()> {
105 internal_err!("evaluate_stateful is not implemented for {}", self.name())
106 }
107
108 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
110
111 fn order_by(&self) -> &LexOrdering;
113
114 fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
116 self.order_by()
117 .iter()
118 .map(|e| e.evaluate_to_sort_column(batch))
119 .collect::<Result<Vec<SortColumn>>>()
120 }
121
122 fn get_window_frame(&self) -> &Arc<WindowFrame>;
124
125 fn uses_bounded_memory(&self) -> bool;
128
129 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
131
132 fn all_expressions(&self) -> WindowPhysicalExpressions {
135 let args = self.expressions();
136 let partition_by_exprs = self.partition_by().to_vec();
137 let order_by_exprs = self
138 .order_by()
139 .iter()
140 .map(|sort_expr| Arc::clone(&sort_expr.expr))
141 .collect::<Vec<_>>();
142 WindowPhysicalExpressions {
143 args,
144 partition_by_exprs,
145 order_by_exprs,
146 }
147 }
148
149 fn with_new_expressions(
153 &self,
154 _args: Vec<Arc<dyn PhysicalExpr>>,
155 _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
156 _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
157 ) -> Option<Arc<dyn WindowExpr>> {
158 None
159 }
160}
161
162pub struct WindowPhysicalExpressions {
164 pub args: Vec<Arc<dyn PhysicalExpr>>,
166 pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
168 pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
170}
171
172pub trait AggregateWindowExpr: WindowExpr {
174 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
178
179 fn get_aggregate_result_inside_range(
182 &self,
183 last_range: &Range<usize>,
184 cur_range: &Range<usize>,
185 value_slice: &[ArrayRef],
186 accumulator: &mut Box<dyn Accumulator>,
187 ) -> Result<ScalarValue>;
188
189 fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
191 let mut accumulator = self.get_accumulator()?;
192 let mut last_range = Range { start: 0, end: 0 };
193 let sort_options: Vec<SortOptions> =
194 self.order_by().iter().map(|o| o.options).collect();
195 let mut window_frame_ctx =
196 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
197 self.get_result_column(
198 &mut accumulator,
199 batch,
200 None,
201 &mut last_range,
202 &mut window_frame_ctx,
203 0,
204 false,
205 )
206 }
207
208 fn aggregate_evaluate_stateful(
211 &self,
212 partition_batches: &PartitionBatches,
213 window_agg_state: &mut PartitionWindowAggStates,
214 ) -> Result<()> {
215 let field = self.field()?;
216 let out_type = field.data_type();
217 for (partition_row, partition_batch_state) in partition_batches.iter() {
218 if !window_agg_state.contains_key(partition_row) {
219 let accumulator = self.get_accumulator()?;
220 window_agg_state.insert(
221 partition_row.clone(),
222 WindowState {
223 state: WindowAggState::new(out_type)?,
224 window_fn: WindowFn::Aggregate(accumulator),
225 },
226 );
227 };
228 let window_state =
229 window_agg_state.get_mut(partition_row).ok_or_else(|| {
230 DataFusionError::Execution("Cannot find state".to_string())
231 })?;
232 let accumulator = match &mut window_state.window_fn {
233 WindowFn::Aggregate(accumulator) => accumulator,
234 _ => unreachable!(),
235 };
236 let state = &mut window_state.state;
237 let record_batch = &partition_batch_state.record_batch;
238 let most_recent_row = partition_batch_state.most_recent_row.as_ref();
239
240 let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
242 let sort_options: Vec<SortOptions> =
243 self.order_by().iter().map(|o| o.options).collect();
244 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
245 });
246 let out_col = self.get_result_column(
247 accumulator,
248 record_batch,
249 most_recent_row,
250 &mut state.window_frame_range,
252 window_frame_ctx,
253 state.last_calculated_index,
254 !partition_batch_state.is_end,
255 )?;
256 state.update(&out_col, partition_batch_state)?;
257 }
258 Ok(())
259 }
260
261 #[allow(clippy::too_many_arguments)]
264 fn get_result_column(
265 &self,
266 accumulator: &mut Box<dyn Accumulator>,
267 record_batch: &RecordBatch,
268 most_recent_row: Option<&RecordBatch>,
269 last_range: &mut Range<usize>,
270 window_frame_ctx: &mut WindowFrameContext,
271 mut idx: usize,
272 not_end: bool,
273 ) -> Result<ArrayRef> {
274 let values = self.evaluate_args(record_batch)?;
275 let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
276
277 let most_recent_row_order_bys = most_recent_row
278 .map(|batch| self.order_by_columns(batch))
279 .transpose()?
280 .map(get_orderby_values);
281
282 let length = values[0].len();
284 let mut row_wise_results: Vec<ScalarValue> = vec![];
285 let is_causal = self.get_window_frame().is_causal();
286 while idx < length {
287 let cur_range =
289 window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
290 if cur_range.end == length
292 && !is_causal
293 && not_end
294 && !is_end_bound_safe(
295 window_frame_ctx,
296 &order_bys,
297 most_recent_row_order_bys.as_deref(),
298 self.order_by(),
299 idx,
300 )?
301 {
302 break;
303 }
304 let value = self.get_aggregate_result_inside_range(
305 last_range,
306 &cur_range,
307 &values,
308 accumulator,
309 )?;
310 *last_range = cur_range;
312 row_wise_results.push(value);
313 idx += 1;
314 }
315
316 if row_wise_results.is_empty() {
317 let field = self.field()?;
318 let out_type = field.data_type();
319 Ok(new_empty_array(out_type))
320 } else {
321 ScalarValue::iter_to_array(row_wise_results)
322 }
323 }
324}
325
326pub(crate) fn is_end_bound_safe(
344 window_frame_ctx: &WindowFrameContext,
345 order_bys: &[ArrayRef],
346 most_recent_order_bys: Option<&[ArrayRef]>,
347 sort_exprs: &LexOrdering,
348 idx: usize,
349) -> Result<bool> {
350 if sort_exprs.is_empty() {
351 return Ok(false);
353 }
354
355 match window_frame_ctx {
356 WindowFrameContext::Rows(window_frame) => {
357 is_end_bound_safe_for_rows(&window_frame.end_bound)
358 }
359 WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range(
360 &window_frame.end_bound,
361 &order_bys[0],
362 most_recent_order_bys.map(|items| &items[0]),
363 &sort_exprs[0].options,
364 idx,
365 ),
366 WindowFrameContext::Groups {
367 window_frame,
368 state,
369 } => is_end_bound_safe_for_groups(
370 &window_frame.end_bound,
371 state,
372 &order_bys[0],
373 most_recent_order_bys.map(|items| &items[0]),
374 &sort_exprs[0].options,
375 ),
376 }
377}
378
379fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
392 if let WindowFrameBound::Following(value) = end_bound {
393 let zero = ScalarValue::new_zero(&value.data_type());
394 Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false))
395 } else {
396 Ok(true)
397 }
398}
399
400fn is_end_bound_safe_for_range(
417 end_bound: &WindowFrameBound,
418 orderby_col: &ArrayRef,
419 most_recent_ob_col: Option<&ArrayRef>,
420 sort_options: &SortOptions,
421 idx: usize,
422) -> Result<bool> {
423 match end_bound {
424 WindowFrameBound::Preceding(value) => {
425 let zero = ScalarValue::new_zero(&value.data_type())?;
426 if value.eq(&zero) {
427 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
428 } else {
429 Ok(true)
430 }
431 }
432 WindowFrameBound::CurrentRow => {
433 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
434 }
435 WindowFrameBound::Following(delta) => {
436 let Some(most_recent_ob_col) = most_recent_ob_col else {
437 return Ok(false);
438 };
439 let most_recent_row_value =
440 ScalarValue::try_from_array(most_recent_ob_col, 0)?;
441 let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?;
442
443 if sort_options.descending {
444 current_row_value
445 .sub(delta)
446 .map(|value| value > most_recent_row_value)
447 } else {
448 current_row_value
449 .add(delta)
450 .map(|value| most_recent_row_value > value)
451 }
452 }
453 }
454}
455
456fn is_end_bound_safe_for_groups(
473 end_bound: &WindowFrameBound,
474 state: &WindowFrameStateGroups,
475 orderby_col: &ArrayRef,
476 most_recent_ob_col: Option<&ArrayRef>,
477 sort_options: &SortOptions,
478) -> Result<bool> {
479 match end_bound {
480 WindowFrameBound::Preceding(value) => {
481 let zero = ScalarValue::new_zero(&value.data_type())?;
482 if value.eq(&zero) {
483 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
484 } else {
485 Ok(true)
486 }
487 }
488 WindowFrameBound::CurrentRow => {
489 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
490 }
491 WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
492 let delta = state.group_end_indices.len() - state.current_group_idx;
493 if delta == (*offset as usize) + 1 {
494 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
495 } else {
496 Ok(false)
497 }
498 }
499 _ => Ok(false),
500 }
501}
502
503fn is_row_ahead(
506 old_col: &ArrayRef,
507 current_col: Option<&ArrayRef>,
508 sort_options: &SortOptions,
509) -> Result<bool> {
510 let Some(current_col) = current_col else {
511 return Ok(false);
512 };
513 if old_col.is_empty() || current_col.is_empty() {
514 return Ok(false);
515 }
516 let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
517 let current_value = ScalarValue::try_from_array(current_col, 0)?;
518 let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?;
519 Ok(cmp.is_gt())
520}
521
522pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
524 order_by_columns.into_iter().map(|s| s.values).collect()
525}
526
527#[derive(Debug)]
528pub enum WindowFn {
529 Builtin(Box<dyn PartitionEvaluator>),
530 Aggregate(Box<dyn Accumulator>),
531}
532
533pub type PartitionKey = Vec<ScalarValue>;
538
539#[derive(Debug)]
540pub struct WindowState {
541 pub state: WindowAggState,
542 pub window_fn: WindowFn,
543}
544pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
545
546pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
548
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    /// With ascending order, a new row counts as "ahead" only when its value
    /// is strictly greater than the last buffered value; a tie does not.
    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let ascending = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let buffered: ArrayRef =
            Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]));
        let strictly_greater: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        let tied_with_last: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));

        let ahead_when_greater =
            is_row_ahead(&buffered, Some(&strictly_greater), &ascending)?;
        assert!(ahead_when_greater);

        let ahead_when_tied = is_row_ahead(&buffered, Some(&tied_with_last), &ascending)?;
        assert!(!ahead_when_tied);

        Ok(())
    }
}