1use std::{collections::VecDeque, ops::Range, sync::Arc};
21
22use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
23
24use arrow::{
25 array::ArrayRef,
26 compute::{concat, concat_batches, SortOptions},
27 datatypes::{DataType, SchemaRef},
28 record_batch::RecordBatch,
29};
30use datafusion_common::{
31 internal_err,
32 utils::{compare_rows, get_row_at_idx, search_in_slice},
33 DataFusionError, Result, ScalarValue,
34};
35
36#[derive(Debug)]
38pub struct WindowAggState {
39 pub window_frame_range: Range<usize>,
41 pub window_frame_ctx: Option<WindowFrameContext>,
42 pub last_calculated_index: usize,
44 pub offset_pruned_rows: usize,
46 pub out_col: ArrayRef,
48 pub n_row_result_missing: usize,
51 pub is_end: bool,
53}
54
55impl WindowAggState {
56 pub fn prune_state(&mut self, n_prune: usize) {
57 self.window_frame_range = Range {
58 start: self.window_frame_range.start - n_prune,
59 end: self.window_frame_range.end - n_prune,
60 };
61 self.last_calculated_index -= n_prune;
62 self.offset_pruned_rows += n_prune;
63
64 match self.window_frame_ctx.as_mut() {
65 Some(WindowFrameContext::Rows(_)) => {}
67 Some(WindowFrameContext::Range { .. }) => {}
68 Some(WindowFrameContext::Groups { state, .. }) => {
69 let mut n_group_to_del = 0;
70 for (_, end_idx) in &state.group_end_indices {
71 if n_prune < *end_idx {
72 break;
73 }
74 n_group_to_del += 1;
75 }
76 state.group_end_indices.drain(0..n_group_to_del);
77 state
78 .group_end_indices
79 .iter_mut()
80 .for_each(|(_, start_idx)| *start_idx -= n_prune);
81 state.current_group_idx -= n_group_to_del;
82 }
83 None => {}
84 };
85 }
86
87 pub fn update(
88 &mut self,
89 out_col: &ArrayRef,
90 partition_batch_state: &PartitionBatchState,
91 ) -> Result<()> {
92 self.last_calculated_index += out_col.len();
93 self.out_col = concat(&[&self.out_col, &out_col])?;
94 self.n_row_result_missing =
95 partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
96 self.is_end = partition_batch_state.is_end;
97 Ok(())
98 }
99
100 pub fn new(out_type: &DataType) -> Result<Self> {
101 let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?;
102 Ok(Self {
103 window_frame_range: Range { start: 0, end: 0 },
104 window_frame_ctx: None,
105 last_calculated_index: 0,
106 offset_pruned_rows: 0,
107 out_col: empty_out_col,
108 n_row_result_missing: 0,
109 is_end: false,
110 })
111 }
112}
113
114#[derive(Debug)]
116pub enum WindowFrameContext {
117 Rows(Arc<WindowFrame>),
119 Range {
123 window_frame: Arc<WindowFrame>,
124 state: WindowFrameStateRange,
125 },
126 Groups {
130 window_frame: Arc<WindowFrame>,
131 state: WindowFrameStateGroups,
132 },
133}
134
135impl WindowFrameContext {
136 pub fn new(window_frame: Arc<WindowFrame>, sort_options: Vec<SortOptions>) -> Self {
138 match window_frame.units {
139 WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame),
140 WindowFrameUnits::Range => WindowFrameContext::Range {
141 window_frame,
142 state: WindowFrameStateRange::new(sort_options),
143 },
144 WindowFrameUnits::Groups => WindowFrameContext::Groups {
145 window_frame,
146 state: WindowFrameStateGroups::default(),
147 },
148 }
149 }
150
151 pub fn calculate_range(
153 &mut self,
154 range_columns: &[ArrayRef],
155 last_range: &Range<usize>,
156 length: usize,
157 idx: usize,
158 ) -> Result<Range<usize>> {
159 match self {
160 WindowFrameContext::Rows(window_frame) => {
161 Self::calculate_range_rows(window_frame, length, idx)
162 }
163 WindowFrameContext::Range {
167 window_frame,
168 ref mut state,
169 } => state.calculate_range(
170 window_frame,
171 last_range,
172 range_columns,
173 length,
174 idx,
175 ),
176 WindowFrameContext::Groups {
180 window_frame,
181 ref mut state,
182 } => state.calculate_range(window_frame, range_columns, length, idx),
183 }
184 }
185
186 fn calculate_range_rows(
188 window_frame: &Arc<WindowFrame>,
189 length: usize,
190 idx: usize,
191 ) -> Result<Range<usize>> {
192 let start = match window_frame.start_bound {
193 WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
195 WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
196 if idx >= n as usize {
197 idx - n as usize
198 } else {
199 0
200 }
201 }
202 WindowFrameBound::CurrentRow => idx,
203 WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
205 return internal_err!(
206 "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'"
207 )
208 }
209 WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
210 std::cmp::min(idx + n as usize, length)
211 }
212 WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
214 return internal_err!("Rows should be Uint")
215 }
216 };
217 let end = match window_frame.end_bound {
218 WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
220 return internal_err!(
221 "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'"
222 )
223 }
224 WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
225 if idx >= n as usize {
226 idx - n as usize + 1
227 } else {
228 0
229 }
230 }
231 WindowFrameBound::CurrentRow => idx + 1,
232 WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
234 WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
235 std::cmp::min(idx + n as usize + 1, length)
236 }
237 WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
239 return internal_err!("Rows should be Uint")
240 }
241 };
242 Ok(Range { start, end })
243 }
244}
245
246#[derive(Debug)]
248pub struct PartitionBatchState {
249 pub record_batch: RecordBatch,
251 pub most_recent_row: Option<RecordBatch>,
256 pub is_end: bool,
258 pub n_out_row: usize,
260}
261
262impl PartitionBatchState {
263 pub fn new(schema: SchemaRef) -> Self {
264 Self {
265 record_batch: RecordBatch::new_empty(schema),
266 most_recent_row: None,
267 is_end: false,
268 n_out_row: 0,
269 }
270 }
271
272 pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> {
273 self.record_batch =
274 concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?;
275 Ok(())
276 }
277
278 pub fn set_most_recent_row(&mut self, batch: RecordBatch) {
279 self.most_recent_row = Some(batch);
282 }
283}
284
285#[derive(Debug, Default)]
290pub struct WindowFrameStateRange {
291 sort_options: Vec<SortOptions>,
292}
293
294impl WindowFrameStateRange {
295 fn new(sort_options: Vec<SortOptions>) -> Self {
297 Self { sort_options }
298 }
299
300 fn calculate_range(
306 &mut self,
307 window_frame: &Arc<WindowFrame>,
308 last_range: &Range<usize>,
309 range_columns: &[ArrayRef],
310 length: usize,
311 idx: usize,
312 ) -> Result<Range<usize>> {
313 let start = match window_frame.start_bound {
314 WindowFrameBound::Preceding(ref n) => {
315 if n.is_null() {
316 0
318 } else {
319 self.calculate_index_of_row::<true, true>(
320 range_columns,
321 last_range,
322 idx,
323 Some(n),
324 length,
325 )?
326 }
327 }
328 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
329 range_columns,
330 last_range,
331 idx,
332 None,
333 length,
334 )?,
335 WindowFrameBound::Following(ref n) => self
336 .calculate_index_of_row::<true, false>(
337 range_columns,
338 last_range,
339 idx,
340 Some(n),
341 length,
342 )?,
343 };
344 let end = match window_frame.end_bound {
345 WindowFrameBound::Preceding(ref n) => self
346 .calculate_index_of_row::<false, true>(
347 range_columns,
348 last_range,
349 idx,
350 Some(n),
351 length,
352 )?,
353 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
354 range_columns,
355 last_range,
356 idx,
357 None,
358 length,
359 )?,
360 WindowFrameBound::Following(ref n) => {
361 if n.is_null() {
362 length
364 } else {
365 self.calculate_index_of_row::<false, false>(
366 range_columns,
367 last_range,
368 idx,
369 Some(n),
370 length,
371 )?
372 }
373 }
374 };
375 Ok(Range { start, end })
376 }
377
378 fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
382 &mut self,
383 range_columns: &[ArrayRef],
384 last_range: &Range<usize>,
385 idx: usize,
386 delta: Option<&ScalarValue>,
387 length: usize,
388 ) -> Result<usize> {
389 let current_row_values = get_row_at_idx(range_columns, idx)?;
390 let end_range = if let Some(delta) = delta {
391 let is_descending: bool = self
392 .sort_options
393 .first()
394 .ok_or_else(|| {
395 DataFusionError::Internal(
396 "Sort options unexpectedly absent in a window frame".to_string(),
397 )
398 })?
399 .descending;
400
401 current_row_values
402 .iter()
403 .map(|value| {
404 if value.is_null() {
405 return Ok(value.clone());
406 }
407 if SEARCH_SIDE == is_descending {
408 value.add(delta)
410 } else if value.is_unsigned() && value < delta {
411 value.sub(value)
415 } else {
416 value.sub(delta)
418 }
419 })
420 .collect::<Result<Vec<ScalarValue>>>()?
421 } else {
422 current_row_values
423 };
424 let search_start = if SIDE {
425 last_range.start
426 } else {
427 last_range.end
428 };
429 let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
430 let cmp = compare_rows(current, target, &self.sort_options)?;
431 Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
432 };
433 search_in_slice(range_columns, &end_range, compare_fn, search_start, length)
434 }
435}
436
437#[derive(Debug, Default)]
462pub struct WindowFrameStateGroups {
463 pub group_end_indices: VecDeque<(Vec<ScalarValue>, usize)>,
467 pub current_group_idx: usize,
469}
470
471impl WindowFrameStateGroups {
472 fn calculate_range(
473 &mut self,
474 window_frame: &Arc<WindowFrame>,
475 range_columns: &[ArrayRef],
476 length: usize,
477 idx: usize,
478 ) -> Result<Range<usize>> {
479 let start = match window_frame.start_bound {
480 WindowFrameBound::Preceding(ref n) => {
481 if n.is_null() {
482 0
484 } else {
485 self.calculate_index_of_row::<true, true>(
486 range_columns,
487 idx,
488 Some(n),
489 length,
490 )?
491 }
492 }
493 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
494 range_columns,
495 idx,
496 None,
497 length,
498 )?,
499 WindowFrameBound::Following(ref n) => self
500 .calculate_index_of_row::<true, false>(
501 range_columns,
502 idx,
503 Some(n),
504 length,
505 )?,
506 };
507 let end = match window_frame.end_bound {
508 WindowFrameBound::Preceding(ref n) => self
509 .calculate_index_of_row::<false, true>(
510 range_columns,
511 idx,
512 Some(n),
513 length,
514 )?,
515 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
516 range_columns,
517 idx,
518 None,
519 length,
520 )?,
521 WindowFrameBound::Following(ref n) => {
522 if n.is_null() {
523 length
525 } else {
526 self.calculate_index_of_row::<false, false>(
527 range_columns,
528 idx,
529 Some(n),
530 length,
531 )?
532 }
533 }
534 };
535 Ok(Range { start, end })
536 }
537
538 fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
543 &mut self,
544 range_columns: &[ArrayRef],
545 idx: usize,
546 delta: Option<&ScalarValue>,
547 length: usize,
548 ) -> Result<usize> {
549 let delta = if let Some(delta) = delta {
550 if let ScalarValue::UInt64(Some(value)) = delta {
551 *value as usize
552 } else {
553 return internal_err!(
554 "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame"
555 );
556 }
557 } else {
558 0
559 };
560 let mut group_start = 0;
561 let last_group = self.group_end_indices.back_mut();
562 if let Some((group_row, group_end)) = last_group {
563 if *group_end < length {
564 let new_group_row = get_row_at_idx(range_columns, *group_end)?;
565 if new_group_row.eq(group_row) {
567 *group_end = search_in_slice(
569 range_columns,
570 group_row,
571 check_equality,
572 *group_end,
573 length,
574 )?;
575 }
576 }
577 group_start = *group_end;
579 }
580
581 while idx >= group_start {
583 let group_row = get_row_at_idx(range_columns, group_start)?;
584 let group_end = search_in_slice(
586 range_columns,
587 &group_row,
588 check_equality,
589 group_start,
590 length,
591 )?;
592 self.group_end_indices.push_back((group_row, group_end));
593 group_start = group_end;
594 }
595
596 while self.current_group_idx < self.group_end_indices.len()
598 && idx >= self.group_end_indices[self.current_group_idx].1
599 {
600 self.current_group_idx += 1;
601 }
602
603 let group_idx = if SEARCH_SIDE {
605 if self.current_group_idx > delta {
606 self.current_group_idx - delta
607 } else {
608 0
609 }
610 } else {
611 self.current_group_idx + delta
612 };
613
614 while self.group_end_indices.len() <= group_idx && group_start < length {
616 let group_row = get_row_at_idx(range_columns, group_start)?;
617 let group_end = search_in_slice(
619 range_columns,
620 &group_row,
621 check_equality,
622 group_start,
623 length,
624 )?;
625 self.group_end_indices.push_back((group_row, group_end));
626 group_start = group_end;
627 }
628
629 Ok(match (SIDE, SEARCH_SIDE) {
631 (true, _) => {
633 let group_idx = std::cmp::min(group_idx, self.group_end_indices.len());
634 if group_idx > 0 {
635 self.group_end_indices[group_idx - 1].1
637 } else {
638 0
640 }
641 }
642 (false, true) => {
644 if self.current_group_idx >= delta {
645 let group_idx = self.current_group_idx - delta;
646 self.group_end_indices[group_idx].1
647 } else {
648 0
650 }
651 }
652 (false, false) => {
654 let group_idx = std::cmp::min(
655 self.current_group_idx + delta,
656 self.group_end_indices.len() - 1,
657 );
658 self.group_end_indices[group_idx].1
659 }
660 })
661 }
662}
663
664fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<bool> {
665 Ok(current == target)
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Float64Array;

    /// Sorted single-column input with groups [5], [7], [8, 8], [9],
    /// [10, 10, 10], [11] (exclusive group ends 1, 2, 4, 5, 8, 9), plus
    /// ascending sort options.
    fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) {
        let values = Float64Array::from(vec![5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11.]);
        let range_columns: Vec<ArrayRef> = vec![Arc::new(values)];
        let ascending = SortOptions {
            descending: false,
            nulls_first: false,
        };
        (range_columns, vec![ascending])
    }

    /// `n PRECEDING` bound.
    fn preceding(n: u64) -> WindowFrameBound {
        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n)))
    }

    /// `n FOLLOWING` bound.
    fn following(n: u64) -> WindowFrameBound {
        WindowFrameBound::Following(ScalarValue::UInt64(Some(n)))
    }

    /// GROUPS frame between `start` and `end`.
    fn groups_frame(start: WindowFrameBound, end: WindowFrameBound) -> Arc<WindowFrame> {
        Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Groups, start, end))
    }

    /// Evaluates `window_frame` for every row of the test data in order. For
    /// row `i`, it checks the frame and the resulting `current_group_idx`
    /// against `expected_results[i]`.
    fn assert_expected(
        expected_results: Vec<(Range<usize>, usize)>,
        window_frame: &Arc<WindowFrame>,
    ) -> Result<()> {
        let (range_columns, _) = get_test_data();
        let n_row = range_columns[0].len();
        let mut state = WindowFrameStateGroups::default();
        for (idx, (expected_range, expected_group_idx)) in
            expected_results.into_iter().enumerate()
        {
            let range = state.calculate_range(window_frame, &range_columns, n_row, idx)?;
            assert_eq!(range, expected_range);
            assert_eq!(state.current_group_idx, expected_group_idx);
        }
        Ok(())
    }

    #[test]
    fn test_window_frame_group_boundaries() -> Result<()> {
        let window_frame = groups_frame(preceding(1), following(1));
        let expected_results = vec![
            (0..2, 0),
            (0..4, 1),
            (1..5, 2),
            (1..5, 2),
            (2..8, 3),
            (4..9, 4),
            (4..9, 4),
            (4..9, 4),
            (5..9, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }

    #[test]
    fn test_window_frame_group_boundaries_both_following() -> Result<()> {
        let window_frame = groups_frame(following(1), following(2));
        let expected_results = vec![
            (1..4, 0),
            (2..5, 1),
            (4..8, 2),
            (4..8, 2),
            (5..9, 3),
            (8..9, 4),
            (8..9, 4),
            (8..9, 4),
            (9..9, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }

    #[test]
    fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
        let window_frame = groups_frame(preceding(2), preceding(1));
        let expected_results = vec![
            (0..0, 0),
            (0..1, 1),
            (0..2, 2),
            (0..2, 2),
            (1..4, 3),
            (2..5, 4),
            (2..5, 4),
            (2..5, 4),
            (4..8, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }
}