datafusion_physical_plan/topk/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! TopK: Combination of Sort / LIMIT

use arrow::{
    compute::interleave,
    row::{RowConverter, Rows, SortField},
};
use std::mem::size_of;
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};

use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
use crate::spill::get_record_batch_memory_size;
use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use arrow::array::{Array, ArrayRef, RecordBatch};
use arrow::datatypes::SchemaRef;
use datafusion_common::HashMap;
use datafusion_common::Result;
use datafusion_execution::{
    memory_pool::{MemoryConsumer, MemoryReservation},
    runtime_env::RuntimeEnv,
};
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

/// Global TopK
///
/// # Background
///
/// "Top K" is a common query optimization used for queries such as
/// "find the top 3 customers by revenue". The (simplified) SQL for
/// such a query might be:
///
/// ```sql
/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
/// ```
///
/// The simple plan would be:
///
/// ```sql
/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
/// +--------------+----------------------------------------+
/// | plan_type    | plan                                   |
/// +--------------+----------------------------------------+
/// | logical_plan | Limit: 3                               |
/// |              |   Sort: revenue DESC NULLS FIRST       |
/// |              |     Projection: customer_id, revenue   |
/// |              |       TableScan: sales                 |
/// +--------------+----------------------------------------+
/// ```
///
/// While this plan produces the correct answer, it fully sorts the
/// input before discarding everything other than the top 3 elements.
///
/// The same answer can be produced by simply keeping track of the top
/// K=3 elements, reducing the total amount of required buffer memory.
///
/// # Structure
///
/// This operator tracks the top K items using a `TopKHeap`.
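///
/// # Usage sketch
///
/// Callers feed the operator input batches and then emit the results
/// once the input is exhausted. A minimal sketch, assuming the caller
/// has already built `schema`, a `LexOrdering` over the sort columns,
/// a `RuntimeEnv`, and an `ExecutionPlanMetricsSet` (not a compiled
/// doctest):
///
/// ```rust,ignore
/// let mut topk = TopK::try_new(
///     0,           // partition_id (labels the memory consumer)
///     schema,      // SchemaRef of the input (and output)
///     sort_exprs,  // LexOrdering, e.g. `revenue DESC`
///     3,           // k: the number of rows to keep
///     8192,        // batch_size for emitted batches
///     runtime,     // Arc<RuntimeEnv> supplying the memory pool
///     &metrics,    // ExecutionPlanMetricsSet
/// )?;
/// for batch in input_batches {
///     topk.insert_batch(batch)?; // buffers at most k rows
/// }
/// let stream = topk.emit()?; // SendableRecordBatchStream of results
/// ```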
pub struct TopK {
    /// schema of the output (and the input)
    schema: SchemaRef,
    /// Runtime metrics
    metrics: TopKMetrics,
    /// Memory reservation tracking this operator's memory use
    reservation: MemoryReservation,
    /// The target number of rows for output batches
    batch_size: usize,
    /// sort expressions
    expr: Arc<[PhysicalSortExpr]>,
    /// row converter, for sort keys
    row_converter: RowConverter,
    /// scratch space for converting rows
    scratch_rows: Rows,
    /// stores the top k values and their sort key values, in order
    heap: TopKHeap,
}

impl TopK {
    /// Create a new [`TopK`] that stores the top `k` values, as
    /// defined by the sort expressions in `expr`.
    // TODO: make a builder or some other nicer API
    pub fn try_new(
        partition_id: usize,
        schema: SchemaRef,
        expr: LexOrdering,
        k: usize,
        batch_size: usize,
        runtime: Arc<RuntimeEnv>,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Result<Self> {
        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
            .register(&runtime.memory_pool);

        let expr: Arc<[PhysicalSortExpr]> = expr.into();

        let sort_fields: Vec<_> = expr
            .iter()
            .map(|e| {
                Ok(SortField::new_with_options(
                    e.expr.data_type(&schema)?,
                    e.options,
                ))
            })
            .collect::<Result<_>>()?;

        // TODO there is potential to add special cases for single column sort fields
        // to improve performance
        let row_converter = RowConverter::new(sort_fields)?;
        let scratch_rows = row_converter.empty_rows(
            batch_size,
            20 * batch_size, // guesstimate 20 bytes per row
        );

        Ok(Self {
            schema: Arc::clone(&schema),
            metrics: TopKMetrics::new(metrics, partition_id),
            reservation,
            batch_size,
            expr,
            row_converter,
            scratch_rows,
            heap: TopKHeap::new(k, batch_size, schema),
        })
    }

    /// Insert `batch`, remembering if any of its values are among
    /// the top k seen so far.
    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
        // Updates on drop
        let _timer = self.metrics.baseline.elapsed_compute().timer();

        let sort_keys: Vec<ArrayRef> = self
            .expr
            .iter()
            .map(|expr| {
                let value = expr.expr.evaluate(&batch)?;
                value.into_array(batch.num_rows())
            })
            .collect::<Result<Vec<_>>>()?;

        // reuse existing `Rows` to avoid reallocations
        let rows = &mut self.scratch_rows;
        rows.clear();
        self.row_converter.append(rows, &sort_keys)?;

        // TODO make this algorithmically better?:
        // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`)
        //       this avoids some work and also might be better vectorizable.
        let mut batch_entry = self.heap.register_batch(batch);
        for (index, row) in rows.iter().enumerate() {
            match self.heap.max() {
                // heap has k items, and the new row is greater than or
                // equal to the current max in the heap ==> it is not a
                // new topk (ties keep the existing row)
                Some(max_row) if row.as_ref() >= max_row.row() => {}
                // don't yet have k items, or the new row is smaller than
                // the largest of the k smallest values seen so far
                None | Some(_) => {
                    self.heap.add(&mut batch_entry, row, index);
                    self.metrics.row_replacements.add(1);
                }
            }
        }
        self.heap.insert_batch_entry(batch_entry);

        // conserve memory
        self.heap.maybe_compact()?;

        // update memory reservation
        self.reservation.try_resize(self.size())?;
        Ok(())
    }

    /// Returns the top k results, broken into [`RecordBatch`]es of at
    /// most `batch_size` rows, consuming the heap
    pub fn emit(self) -> Result<SendableRecordBatchStream> {
        let Self {
            schema,
            metrics,
            reservation: _,
            batch_size,
            expr: _,
            row_converter: _,
            scratch_rows: _,
            mut heap,
        } = self;
        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop

        // break into record batches as needed
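        // (e.g. with batch_size = 10, a 25 row result is emitted as
        // batches of 10, 10 and 5 rows)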
        let mut batches = vec![];
        if let Some(mut batch) = heap.emit()? {
            metrics.baseline.output_rows().add(batch.num_rows());

            loop {
                if batch.num_rows() <= batch_size {
                    batches.push(Ok(batch));
                    break;
                } else {
                    batches.push(Ok(batch.slice(0, batch_size)));
                    let remaining_length = batch.num_rows() - batch_size;
                    batch = batch.slice(batch_size, remaining_length);
                }
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(batches),
        )))
    }

    /// return the size of memory used by this operator, in bytes
    fn size(&self) -> usize {
        size_of::<Self>()
            + self.row_converter.size()
            + self.scratch_rows.size()
            + self.heap.size()
    }
}

struct TopKMetrics {
    /// metrics
    pub baseline: BaselineMetrics,

    /// count of how many rows were replaced in the heap
    pub row_replacements: Count,
}

impl TopKMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            baseline: BaselineMetrics::new(metrics, partition),
            row_replacements: MetricBuilder::new(metrics)
                .counter("row_replacements", partition),
        }
    }
}

/// This structure keeps at most the *smallest* k items, using the
/// [arrow::row] format for sort keys. While it is called "topK", for
/// values like `1, 2, 3, 4, 5` the "top 3" really means the
/// *smallest* 3, `1, 2, 3`, not the *largest* 3, `3, 4, 5`.
///
/// Using the `Row` format handles things such as ascending vs
/// descending and nulls first vs nulls last.
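///
/// For example, with `k = 3` and an ascending sort over the input
/// `5, 2, 8, 1, 9`, the heap evolves as:
///
/// ```text
/// insert 5  => {5}
/// insert 2  => {2, 5}
/// insert 8  => {2, 5, 8}    (heap now full, max = 8)
/// insert 1  => {1, 2, 5}    (1 < 8, so 8 is evicted)
/// insert 9  => {1, 2, 5}    (9 >= 5, ignored)
/// ```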
struct TopKHeap {
    /// The maximum number of elements to store in this heap.
    k: usize,
    /// The target number of rows for output batches
    batch_size: usize,
    /// Storage for at most `k` items using a BinaryHeap. Since this is
    /// a max-heap on the sort key, the largest of the k smallest values
    /// seen so far is on top, where it can be compared and evicted cheaply
    inner: BinaryHeap<TopKRow>,
    /// Stores the original row values (`TopKRow` only stores the sort key)
    store: RecordBatchStore,
    /// The size of all owned data held by this heap
    owned_bytes: usize,
}

impl TopKHeap {
    fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self {
        assert!(k > 0);
        Self {
            k,
            batch_size,
            inner: BinaryHeap::new(),
            store: RecordBatchStore::new(schema),
            owned_bytes: 0,
        }
    }

    /// Register a [`RecordBatch`] with the heap, returning the
    /// appropriate entry
    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        self.store.register(batch)
    }

    /// Insert a [`RecordBatchEntry`] created by a previous call to
    /// [`Self::register_batch`] into storage.
    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
        self.store.insert(entry)
    }

    /// Returns the largest value stored by the heap if there are k
    /// items, otherwise returns None. Remember this structure is
    /// keeping the "smallest" k values
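    /// (e.g. with `k = 3` and stored sort keys `{1, 2, 5}`, this
    /// returns the entry for `5`)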
    fn max(&self) -> Option<&TopKRow> {
        if self.inner.len() < self.k {
            None
        } else {
            self.inner.peek()
        }
    }

    /// Adds `row` to this heap. If inserting this new item would
    /// increase the size past `k`, removes the current largest item
    /// (the max of the k smallest values seen so far).
    fn add(
        &mut self,
        batch_entry: &mut RecordBatchEntry,
        row: impl AsRef<[u8]>,
        index: usize,
    ) {
        let batch_id = batch_entry.id;
        batch_entry.uses += 1;

        assert!(self.inner.len() <= self.k);
        let row = row.as_ref();

        // Reuse storage for evicted item if possible
        let new_top_k = if self.inner.len() == self.k {
            let prev_max = self.inner.pop().unwrap();

            // Update batch use
            if prev_max.batch_id == batch_entry.id {
                batch_entry.uses -= 1;
            } else {
                self.store.unuse(prev_max.batch_id);
            }

            // update memory accounting
            self.owned_bytes -= prev_max.owned_size();
            prev_max.with_new_row(row, batch_id, index)
        } else {
            TopKRow::new(row, batch_id, index)
        };

        self.owned_bytes += new_top_k.owned_size();

        // put the new row into the heap
        self.inner.push(new_top_k)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], resetting the inner heap
    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
        Ok(self.emit_with_state()?.0)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], and a sorted vec of the
    /// current heap's contents
    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
        let schema = Arc::clone(self.store.schema());

        // generate sorted rows
        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();

        if self.store.is_empty() {
            return Ok((None, topk_rows));
        }

        // Indices for each row within its respective RecordBatch
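        // (`indices[i]` is `(i, row_index)`: row `i` of the output is taken
        // from row `row_index` of the `i`-th entry of `input_arrays` below)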
        let indices: Vec<_> = topk_rows
            .iter()
            .enumerate()
            .map(|(i, k)| (i, k.index))
            .collect();

        let num_columns = schema.fields().len();

        // build the output columns one at a time, using the
        // `interleave` kernel to pick rows from different arrays
        let output_columns: Vec<_> = (0..num_columns)
            .map(|col| {
                let input_arrays: Vec<_> = topk_rows
                    .iter()
                    .map(|k| {
                        let entry =
                            self.store.get(k.batch_id).expect("invalid stored batch id");
                        entry.batch.column(col) as &dyn Array
                    })
                    .collect();

                // at this point `indices` contains indexes within the
                // rows and `input_arrays` contains a reference to the
                // relevant Array for that index. `interleave` pulls
                // them together into a single new array
                Ok(interleave(&input_arrays, &indices)?)
            })
            .collect::<Result<_>>()?;

        let new_batch = RecordBatch::try_new(schema, output_columns)?;
        Ok((Some(new_batch), topk_rows))
    }

    /// Compact this heap, rewriting all stored batches into a single
    /// new batch
    pub fn maybe_compact(&mut self) -> Result<()> {
        // we compact if the number of "unused" rows in the store is
        // past some pre-defined threshold. Target holding up to
        // around 20 batches, but handle cases of large k where some
        // batches might be partially full
        let max_unused_rows = (20 * self.batch_size) + self.k;
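        // (e.g. with batch_size = 8192 and k = 100, compaction is deferred
        // until 20 * 8192 + 100 = 163,940 rows in the store are unused)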
        let unused_rows = self.store.unused_rows();

        // don't compact if the store has at most one extra batch or
        // the number of unused rows is under the threshold
        if self.store.len() <= 2 || unused_rows < max_unused_rows {
            return Ok(());
        }
        // for now, always compact the entire thing into a new batch
        // (maybe we can get fancier in the future about ignoring
        // batches that already have a high usage ratio)

        // Note: new batch is in the same order as inner
        let num_rows = self.inner.len();
        let (new_batch, mut topk_rows) = self.emit_with_state()?;
        let Some(new_batch) = new_batch else {
            return Ok(());
        };

        // clear all old entries in store (this invalidates all
        // batch ids in `inner`)
        self.store.clear();

        let mut batch_entry = self.register_batch(new_batch);
        batch_entry.uses = num_rows;

        // rewrite all existing entries to use the new batch, and
        // remove old entries. The sortedness and their relative
        // position do not change
        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
            topk_row.batch_id = batch_entry.id;
            topk_row.index = i;
        }
        self.insert_batch_entry(batch_entry);
        // restore the heap
        self.inner = BinaryHeap::from(topk_rows);

        Ok(())
    }

    /// return the size of memory used by this heap, in bytes
    fn size(&self) -> usize {
        size_of::<Self>()
            + (self.inner.capacity() * size_of::<TopKRow>())
            + self.store.size()
            + self.owned_bytes
    }
}

/// Represents one of the top K rows held in this heap. Orders
/// according to memcmp of row (e.g. the arrow Row format, but could
/// also be primitive values)
///
/// Reuses allocations to minimize runtime overhead of creating new Vecs
#[derive(Debug, PartialEq)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Create a new TopKRow with new allocation
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        Self {
            row: row.as_ref().to_vec(),
            batch_id,
            index,
        }
    }

    /// Create a new TopKRow reusing the existing allocation
    fn with_new_row(
        self,
        new_row: impl AsRef<[u8]>,
        batch_id: u32,
        index: usize,
    ) -> Self {
        let Self {
            mut row,
            batch_id: _,
            index: _,
        } = self;
        row.clear();
        row.extend_from_slice(new_row.as_ref());

        Self {
            row,
            batch_id,
            index,
        }
    }

    /// Returns the number of bytes owned by this row in the heap (not
    /// including itself)
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Returns a slice to the owned row value
    fn row(&self) -> &[u8] {
        self.row.as_slice()
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    fn cmp(&self, other: &Self) -> Ordering {
        self.row.cmp(&other.row)
    }
}

#[derive(Debug)]
struct RecordBatchEntry {
    id: u32,
    batch: RecordBatch,
    // for this batch, how many times has it been used
    uses: usize,
}

/// This structure tracks [`RecordBatch`] by an id so that:
///
/// 1. The batches can be tracked via an id that can be copied cheaply
/// 2. The total memory held by all batches is tracked
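///
/// The intended lifecycle of an entry is: [`RecordBatchStore::register`]
/// creates it, each [`TopKRow`] referencing the batch increments `uses`,
/// [`RecordBatchStore::insert`] stores it (dropping it immediately if it
/// has no uses), and [`RecordBatchStore::unuse`] decrements the count,
/// removing the entry once it reaches zero.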
#[derive(Debug)]
struct RecordBatchStore {
    /// id generator
    next_id: u32,
    /// storage
    batches: HashMap<u32, RecordBatchEntry>,
    /// total size of all record batches tracked by this store
    batches_size: usize,
    /// schema of the batches
    schema: SchemaRef,
}

impl RecordBatchStore {
    fn new(schema: SchemaRef) -> Self {
        Self {
            next_id: 0,
            batches: HashMap::new(),
            batches_size: 0,
            schema,
        }
    }

    /// Register this batch with the store and assign an ID. No
    /// attempt is made to compare this batch to other batches
    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        let id = self.next_id;
        self.next_id += 1;
        RecordBatchEntry { id, batch, uses: 0 }
    }

    /// Insert a record batch entry into this store, tracking its
    /// memory use, if it has any uses
    pub fn insert(&mut self, entry: RecordBatchEntry) {
        // uses of 0 means that none of the rows in the batch were stored in the topk
        if entry.uses > 0 {
            self.batches_size += get_record_batch_memory_size(&entry.batch);
            self.batches.insert(entry.id, entry);
        }
    }

    /// Clear all values in this store, invalidating all previous batch ids
    fn clear(&mut self) {
        self.batches.clear();
        self.batches_size = 0;
    }

    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
        self.batches.get(&id)
    }

    /// returns the total number of batches stored in this store
    fn len(&self) -> usize {
        self.batches.len()
    }

    /// Returns the total number of rows in batches minus the number
    /// which are in use
    fn unused_rows(&self) -> usize {
        self.batches
            .values()
            .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
            .sum()
    }

    /// returns true if the store has nothing stored
    fn is_empty(&self) -> bool {
        self.batches.is_empty()
    }

    /// return the schema of batches stored
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// remove a use from the specified batch id. If the use count
    /// reaches zero the batch entry is removed from the store
    ///
    /// Panics if there are no remaining uses of `id`
    pub fn unuse(&mut self, id: u32) {
        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
            batch_entry.uses == 0
        } else {
            panic!("No entry for id {id}");
        };

        if remove {
            let old_entry = self.batches.remove(&id).unwrap();
            self.batches_size = self
                .batches_size
                .checked_sub(get_record_batch_memory_size(&old_entry.batch))
                .unwrap();
        }
    }

    /// returns the size of memory used by this store, including all
    /// referenced `RecordBatch`es, in bytes
    pub fn size(&self) -> usize {
        size_of::<Self>()
            + self.batches.capacity() * (size_of::<u32>() + size_of::<RecordBatchEntry>())
            + self.batches_size
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};

    /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
    #[test]
    fn test_record_batch_store_size() {
        // given
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let mut record_batch_store = RecordBatchStore::new(Arc::clone(&schema));
        let int_array =
            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40

        let record_batch_entry = RecordBatchEntry {
            id: 0,
            batch: RecordBatch::try_new(
                schema,
                vec![Arc::new(int_array), Arc::new(float64_array)],
            )
            .unwrap(),
            uses: 1,
        };

        // when insert record batch entry
        record_batch_store.insert(record_batch_entry);
        assert_eq!(record_batch_store.batches_size, 60);

        // when unuse record batch entry
        record_batch_store.unuse(0);
        assert_eq!(record_batch_store.batches_size, 0);
    }
}