datafusion_physical_plan/topk/mod.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! TopK: Combination of Sort / LIMIT

use arrow::{
    compute::interleave,
    row::{RowConverter, Rows, SortField},
};
use std::mem::size_of;
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};

use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
use crate::spill::get_record_batch_memory_size;
use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use arrow::array::{Array, ArrayRef, RecordBatch};
use arrow::datatypes::SchemaRef;
use datafusion_common::HashMap;
use datafusion_common::Result;
use datafusion_execution::{
    memory_pool::{MemoryConsumer, MemoryReservation},
    runtime_env::RuntimeEnv,
};
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

/// Global TopK
///
/// # Background
///
/// "Top K" is a common query optimization used for queries such as
/// "find the top 3 customers by revenue". The (simplified) SQL for
/// such a query might be:
///
/// ```sql
/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
/// ```
///
/// The simple plan would be:
///
/// ```sql
/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
/// +--------------+----------------------------------------+
/// | plan_type    | plan                                   |
/// +--------------+----------------------------------------+
/// | logical_plan | Limit: 3                               |
/// |              |   Sort: revenue DESC NULLS FIRST       |
/// |              |     Projection: customer_id, revenue   |
/// |              |       TableScan: sales                 |
/// +--------------+----------------------------------------+
/// ```
///
/// While this plan produces the correct answer, it fully sorts the
/// input before discarding everything other than the top 3 elements.
///
/// The same answer can be produced by simply keeping track of the top
/// K=3 elements, reducing the total amount of required buffer memory.
/// # Structure
///
/// This operator tracks the top K items using a `TopKHeap`.
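///
/// A rough usage sketch (not compiled here; `partition_id`, `sort_exprs`,
/// `runtime`, `metrics`, and `input` stand in for values supplied by the
/// surrounding `ExecutionPlan`):
///
/// ```ignore
/// let mut topk =
///     TopK::try_new(partition_id, schema, sort_exprs, k, batch_size, runtime, &metrics)?;
/// while let Some(batch) = input.next().await {
///     topk.insert_batch(batch?)?;
/// }
/// let stream = topk.emit()?; // the top k rows, in `batch_size` chunks
/// ```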
pub struct TopK {
    /// schema of the output (and the input)
    schema: SchemaRef,
    /// Runtime metrics
    metrics: TopKMetrics,
    /// Reservation
    reservation: MemoryReservation,
    /// The target number of rows for output batches
    batch_size: usize,
    /// sort expressions
    expr: Arc<[PhysicalSortExpr]>,
    /// row converter, for sort keys
    row_converter: RowConverter,
    /// scratch space for converting rows
    scratch_rows: Rows,
    /// stores the top k values and their sort key values, in order
    heap: TopKHeap,
}

impl TopK {
    /// Create a new [`TopK`] that stores the top `k` values, as
    /// defined by the sort expressions in `expr`.
    // TODO: make a builder or some other nicer API
    pub fn try_new(
        partition_id: usize,
        schema: SchemaRef,
        expr: LexOrdering,
        k: usize,
        batch_size: usize,
        runtime: Arc<RuntimeEnv>,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Result<Self> {
        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
            .register(&runtime.memory_pool);

        let expr: Arc<[PhysicalSortExpr]> = expr.into();

        let sort_fields: Vec<_> = expr
            .iter()
            .map(|e| {
                Ok(SortField::new_with_options(
                    e.expr.data_type(&schema)?,
                    e.options,
                ))
            })
            .collect::<Result<_>>()?;

        // TODO there is potential to add special cases for single column sort fields
        // to improve performance
        let row_converter = RowConverter::new(sort_fields)?;
        let scratch_rows = row_converter.empty_rows(
            batch_size,
            20 * batch_size, // guesstimate 20 bytes per row
        );
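        // e.g. with a typical batch_size of 8192, the guess above reserves
        // roughly 160 KiB of encoded-row storage up front; the buffer still
        // grows as needed if rows turn out to be wider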

        Ok(Self {
            schema: Arc::clone(&schema),
            metrics: TopKMetrics::new(metrics, partition_id),
            reservation,
            batch_size,
            expr,
            row_converter,
            scratch_rows,
            heap: TopKHeap::new(k, batch_size, schema),
        })
    }

    /// Insert `batch`, remembering if any of its values are among
    /// the top k seen so far.
    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
        // Updates on drop
        let _timer = self.metrics.baseline.elapsed_compute().timer();

        let sort_keys: Vec<ArrayRef> = self
            .expr
            .iter()
            .map(|expr| {
                let value = expr.expr.evaluate(&batch)?;
                value.into_array(batch.num_rows())
            })
            .collect::<Result<Vec<_>>>()?;

        // reuse existing `Rows` to avoid reallocations
        let rows = &mut self.scratch_rows;
        rows.clear();
        self.row_converter.append(rows, &sort_keys)?;

        // TODO make this algorithmically better?:
        // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`)
        // this avoids some work and also might be better vectorizable.
        let mut batch_entry = self.heap.register_batch(batch);
        for (index, row) in rows.iter().enumerate() {
            match self.heap.max() {
                // heap has k items, and the new row is greater than or
                // equal to the current max in the heap ==> it is not a new topk
                Some(max_row) if row.as_ref() >= max_row.row() => {}
                // don't yet have k items, or the new item is smaller than
                // the largest of the k values seen so far
                None | Some(_) => {
                    self.heap.add(&mut batch_entry, row, index);
                    self.metrics.row_replacements.add(1);
                }
            }
        }
        self.heap.insert_batch_entry(batch_entry);

        // conserve memory
        self.heap.maybe_compact()?;

        // update memory reservation
        self.reservation.try_resize(self.size())?;
        Ok(())
    }

    /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
    pub fn emit(self) -> Result<SendableRecordBatchStream> {
        let Self {
            schema,
            metrics,
            reservation: _,
            batch_size,
            expr: _,
            row_converter: _,
            scratch_rows: _,
            mut heap,
        } = self;
        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop

        // break into record batches as needed
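        // e.g. k = 7 rows with batch_size = 3 are emitted as batches of 3, 3 and 1 rows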
        let mut batches = vec![];
        if let Some(mut batch) = heap.emit()? {
            metrics.baseline.output_rows().add(batch.num_rows());

            loop {
                if batch.num_rows() <= batch_size {
                    batches.push(Ok(batch));
                    break;
                } else {
                    batches.push(Ok(batch.slice(0, batch_size)));
                    let remaining_length = batch.num_rows() - batch_size;
                    batch = batch.slice(batch_size, remaining_length);
                }
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(batches),
        )))
    }

    /// return the size of memory used by this operator, in bytes
    fn size(&self) -> usize {
        size_of::<Self>()
            + self.row_converter.size()
            + self.scratch_rows.size()
            + self.heap.size()
    }
}

struct TopKMetrics {
    /// metrics
    pub baseline: BaselineMetrics,

    /// count of how many rows were replaced in the heap
    pub row_replacements: Count,
}

impl TopKMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            baseline: BaselineMetrics::new(metrics, partition),
            row_replacements: MetricBuilder::new(metrics)
                .counter("row_replacements", partition),
        }
    }
}

/// This structure keeps at most the *smallest* k items, using the
/// [arrow::row] format for sort keys. While it is called "topK", for
/// values like `1, 2, 3, 4, 5` the "top 3" really means the
/// *smallest* 3, `1, 2, 3`, not the *largest* 3, `3, 4, 5`.
///
/// Using the `Row` format handles things such as ascending vs
/// descending and nulls first vs nulls last.
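///
/// For example, under `ORDER BY revenue DESC` the `Row` encoding inverts
/// the key bytes so that byte-wise ascending order corresponds to revenue
/// descending; this heap can therefore keep the byte-wise smallest k rows
/// and always evict its byte-wise maximum.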
struct TopKHeap {
    /// The maximum number of elements to store in this heap.
    k: usize,
    /// The target number of rows for output batches
    batch_size: usize,
    /// Storage for at most `k` items in a max `BinaryHeap`, so the
    /// largest of the smallest k rows seen so far sits on top and is
    /// the next eviction candidate
    inner: BinaryHeap<TopKRow>,
    /// Stores the original row values (TopKRow only has the sort key)
    store: RecordBatchStore,
    /// The size of all owned data held by this heap
    owned_bytes: usize,
}

impl TopKHeap {
    fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self {
        assert!(k > 0);
        Self {
            k,
            batch_size,
            inner: BinaryHeap::new(),
            store: RecordBatchStore::new(schema),
            owned_bytes: 0,
        }
    }

    /// Register a [`RecordBatch`] with the heap, returning the
    /// appropriate entry
    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        self.store.register(batch)
    }

    /// Insert a [`RecordBatchEntry`] created by a previous call to
    /// [`Self::register_batch`] into storage.
    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
        self.store.insert(entry)
    }

    /// Returns the largest value stored by the heap if there are k
    /// items, otherwise returns None. Remember this structure is
    /// keeping the "smallest" k values
    fn max(&self) -> Option<&TopKRow> {
        if self.inner.len() < self.k {
            None
        } else {
            self.inner.peek()
        }
    }

    /// Adds `row` to this heap. If inserting this new item would
    /// increase the size past `k`, removes the previously largest
    /// item.
    fn add(
        &mut self,
        batch_entry: &mut RecordBatchEntry,
        row: impl AsRef<[u8]>,
        index: usize,
    ) {
        let batch_id = batch_entry.id;
        batch_entry.uses += 1;

        assert!(self.inner.len() <= self.k);
        let row = row.as_ref();

        // Reuse storage for evicted item if possible
        let new_top_k = if self.inner.len() == self.k {
            let prev_max = self.inner.pop().unwrap();

            // Update batch use
            if prev_max.batch_id == batch_entry.id {
                batch_entry.uses -= 1;
            } else {
                self.store.unuse(prev_max.batch_id);
            }

            // update memory accounting
            self.owned_bytes -= prev_max.owned_size();
            prev_max.with_new_row(row, batch_id, index)
        } else {
            TopKRow::new(row, batch_id, index)
        };

        self.owned_bytes += new_top_k.owned_size();

        // put the new row into the heap
        self.inner.push(new_top_k)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], resetting the inner heap
    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
        Ok(self.emit_with_state()?.0)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], and a sorted vec of the
    /// current heap's contents
    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
        let schema = Arc::clone(self.store.schema());

        // generate sorted rows
        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();

        if self.store.is_empty() {
            return Ok((None, topk_rows));
        }

        // Indices for each row within its respective RecordBatch
        let indices: Vec<_> = topk_rows
            .iter()
            .enumerate()
            .map(|(i, k)| (i, k.index))
            .collect();

        let num_columns = schema.fields().len();

        // build the output columns one at a time, using the
        // `interleave` kernel to pick rows from different arrays
        let output_columns: Vec<_> = (0..num_columns)
            .map(|col| {
                let input_arrays: Vec<_> = topk_rows
                    .iter()
                    .map(|k| {
                        let entry =
                            self.store.get(k.batch_id).expect("invalid stored batch id");
                        entry.batch.column(col) as &dyn Array
                    })
                    .collect();

                // at this point `indices` contains indexes within the
                // rows and `input_arrays` contains a reference to the
                // relevant Array for that index. `interleave` pulls
                // them together into a single new array
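                // e.g. indices = [(0, 5), (1, 2)] selects input_arrays[0][5]
                // followed by input_arrays[1][2]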
                Ok(interleave(&input_arrays, &indices)?)
            })
            .collect::<Result<_>>()?;

        let new_batch = RecordBatch::try_new(schema, output_columns)?;
        Ok((Some(new_batch), topk_rows))
    }

    /// Compact this heap, rewriting all stored batches into a single
    /// new batch
    pub fn maybe_compact(&mut self) -> Result<()> {
        // we compact if the number of "unused" rows in the store is
        // past some pre-defined threshold. Target holding up to
        // around 20 batches, but handle cases of large k where some
        // batches might be partially full
        let max_unused_rows = (20 * self.batch_size) + self.k;
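        // e.g. with batch_size = 8192 and k = 100, compaction is considered
        // once the store holds 163,940 or more unused rows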
        let unused_rows = self.store.unused_rows();

        // don't compact if the store holds at most two batches, or if the
        // number of unused rows is under the threshold
        if self.store.len() <= 2 || unused_rows < max_unused_rows {
            return Ok(());
        }
        // For now, always compact everything into a single new batch
        // (maybe we can get fancier in the future about ignoring
        // batches that already have a high usage ratio)

        // Note: new batch is in the same order as inner
        let num_rows = self.inner.len();
        let (new_batch, mut topk_rows) = self.emit_with_state()?;
        let Some(new_batch) = new_batch else {
            return Ok(());
        };

        // clear all old entries in store (this invalidates all
        // store_ids in `inner`)
        self.store.clear();

        let mut batch_entry = self.register_batch(new_batch);
        batch_entry.uses = num_rows;

        // rewrite all existing entries to use the new batch, and
        // remove old entries. The sortedness and their relative
        // position do not change
        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
            topk_row.batch_id = batch_entry.id;
            topk_row.index = i;
        }
        self.insert_batch_entry(batch_entry);
        // restore the heap
        self.inner = BinaryHeap::from(topk_rows);

        Ok(())
    }

    /// return the size of memory used by this heap, in bytes
    fn size(&self) -> usize {
        size_of::<Self>()
            + (self.inner.capacity() * size_of::<TopKRow>())
            + self.store.size()
            + self.owned_bytes
    }
}

/// Represents one of the top K rows held in this heap. Orders
/// according to memcmp of row (e.g. the arrow Row format, but could
/// also be primitive values)
///
/// Reuses allocations to minimize runtime overhead of creating new Vecs
#[derive(Debug, PartialEq)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Create a new TopKRow with new allocation
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        Self {
            row: row.as_ref().to_vec(),
            batch_id,
            index,
        }
    }

    /// Create a new TopKRow reusing the existing allocation
    fn with_new_row(
        self,
        new_row: impl AsRef<[u8]>,
        batch_id: u32,
        index: usize,
    ) -> Self {
        let Self {
            mut row,
            batch_id: _,
            index: _,
        } = self;
        row.clear();
        row.extend_from_slice(new_row.as_ref());

        Self {
            row,
            batch_id,
            index,
        }
    }

    /// Returns the number of bytes owned by this row in the heap (not
    /// including itself)
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Returns a slice to the owned row value
    fn row(&self) -> &[u8] {
        self.row.as_slice()
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    fn cmp(&self, other: &Self) -> Ordering {
        self.row.cmp(&other.row)
    }
}

#[derive(Debug)]
struct RecordBatchEntry {
    id: u32,
    batch: RecordBatch,
    // for this batch, how many times has it been used
    uses: usize,
}

/// This structure tracks [`RecordBatch`] by an id so that:
///
/// 1. The batches can be tracked via an id that can be copied cheaply
/// 2. The total memory held by all batches is tracked
#[derive(Debug)]
struct RecordBatchStore {
    /// id generator
    next_id: u32,
    /// storage
    batches: HashMap<u32, RecordBatchEntry>,
    /// total size of all record batches tracked by this store
    batches_size: usize,
    /// schema of the batches
    schema: SchemaRef,
}

impl RecordBatchStore {
    fn new(schema: SchemaRef) -> Self {
        Self {
            next_id: 0,
            batches: HashMap::new(),
            batches_size: 0,
            schema,
        }
    }

    /// Register this batch with the store and assign an ID. No
    /// attempt is made to compare this batch to other batches
    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        let id = self.next_id;
        self.next_id += 1;
        RecordBatchEntry { id, batch, uses: 0 }
    }

    /// Insert a record batch entry into this store, tracking its
    /// memory use, if it has any uses
    pub fn insert(&mut self, entry: RecordBatchEntry) {
        // uses of 0 means that none of the rows in the batch were stored in the topk
        if entry.uses > 0 {
            self.batches_size += get_record_batch_memory_size(&entry.batch);
            self.batches.insert(entry.id, entry);
        }
    }

    /// Clear all values in this store, invalidating all previous batch ids
    fn clear(&mut self) {
        self.batches.clear();
        self.batches_size = 0;
    }

    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
        self.batches.get(&id)
    }

    /// returns the total number of batches stored in this store
    fn len(&self) -> usize {
        self.batches.len()
    }

    /// Returns the total number of rows in batches minus the number
    /// which are in use
    fn unused_rows(&self) -> usize {
        self.batches
            .values()
            .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
            .sum()
    }

    /// returns true if the store has nothing stored
    fn is_empty(&self) -> bool {
        self.batches.is_empty()
    }

    /// return the schema of batches stored
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// remove a use from the specified batch id. If the use count
    /// reaches zero the batch entry is removed from the store
    ///
    /// panics if there were no remaining uses of id
    pub fn unuse(&mut self, id: u32) {
        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
            batch_entry.uses == 0
        } else {
            panic!("No entry for id {id}");
        };

        if remove {
            let old_entry = self.batches.remove(&id).unwrap();
            self.batches_size = self
                .batches_size
                .checked_sub(get_record_batch_memory_size(&old_entry.batch))
                .unwrap();
        }
    }

    /// returns the size of memory used by this store, including all
    /// referenced `RecordBatch`es, in bytes
    pub fn size(&self) -> usize {
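        // Note: the capacity-based term below is only an estimate; a
        // HashMap's true footprint also includes its own bookkeeping
        // (e.g. hashbrown's per-slot control bytes)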
        size_of::<Self>()
            + self.batches.capacity() * (size_of::<u32>() + size_of::<RecordBatchEntry>())
            + self.batches_size
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};

    /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
    #[test]
    fn test_record_batch_store_size() {
        // given
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let mut record_batch_store = RecordBatchStore::new(Arc::clone(&schema));
        let int_array =
            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20 bytes
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40 bytes

        let record_batch_entry = RecordBatchEntry {
            id: 0,
            batch: RecordBatch::try_new(
                schema,
                vec![Arc::new(int_array), Arc::new(float64_array)],
            )
            .unwrap(),
            uses: 1,
        };

        // when insert record batch entry
        record_batch_store.insert(record_batch_entry);
        assert_eq!(record_batch_store.batches_size, 60);

        // when unuse record batch entry
        record_batch_store.unuse(0);
        assert_eq!(record_batch_store.batches_size, 0);
    }
}