lance_index/vector/v3/shuffler.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Shuffler is a component that takes a stream of record batches and shuffles them into
//! the corresponding IVF partitions.
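//!
//! A minimal end-to-end sketch; the output directory, partition count, and
//! `input_stream` (a boxed `RecordBatchStream` whose batches carry the
//! `PART_ID_COLUMN`) are hypothetical:
//!
//! ```ignore
//! use futures::StreamExt;
//! use object_store::path::Path;
//!
//! let shuffler = IvfShuffler::new(Path::from("/tmp/shuffle"), 256).with_buffer_size(1024);
//! let reader = shuffler.shuffle(input_stream).await?;
//! for part_id in 0..256 {
//!     // Empty partitions return no stream; check the size first.
//!     if reader.partition_size(part_id)? > 0 {
//!         let mut stream = reader.read_partition(part_id).await?.unwrap();
//!         while let Some(batch) = stream.next().await {
//!             let _batch = batch?; // rows whose PART_ID_COLUMN equals part_id
//!         }
//!     }
//! }
//! ```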

use std::sync::Arc;

use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array};
use futures::future::join_all;
use futures::prelude::*;
use lance_arrow::RecordBatchExt;
use lance_core::{
    cache::FileMetadataCache,
    utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
    Error, Result,
};
use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
use lance_file::v2::{
    reader::{FileReader, FileReaderOptions},
    writer::FileWriter,
};
use lance_io::{
    object_store::ObjectStore,
    scheduler::{ScanScheduler, SchedulerConfig},
    stream::{RecordBatchStream, RecordBatchStreamAdapter},
};
use object_store::path::Path;

use crate::vector::PART_ID_COLUMN;

/// A reader that can read the shuffled partitions.
#[async_trait::async_trait]
pub trait ShuffleReader: Send + Sync {
    /// Read a partition by `partition_id`.
    ///
    /// Returns `Ok(None)` if the partition is empty; check
    /// `partition_size(partition_id)` before calling this function.
    async fn read_partition(
        &self,
        partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>>;

    /// Get the size (in rows) of the partition by `partition_id`.
    fn partition_size(&self, partition_id: usize) -> Result<usize>;
}

/// A shuffler that can shuffle the incoming stream of record batches into IVF partitions.
#[async_trait::async_trait]
pub trait Shuffler: Send + Sync {
    /// Shuffle the incoming stream of record batches into IVF partitions.
    ///
    /// Returns a [`ShuffleReader`] that can be used to read the shuffled partitions.
    async fn shuffle(
        &self,
        data: Box<dyn RecordBatchStream + Unpin + 'static>,
    ) -> Result<Box<dyn ShuffleReader>>;
}

/// A [`Shuffler`] implementation that sorts batches by partition id and writes
/// each partition to its own Lance file under `output_dir` on the local filesystem.
pub struct IvfShuffler {
    object_store: Arc<ObjectStore>,
    output_dir: Path,
    num_partitions: usize,

    // options
    buffer_size: usize,
}

impl IvfShuffler {
    pub fn new(output_dir: Path, num_partitions: usize) -> Self {
        Self {
            object_store: Arc::new(ObjectStore::local()),
            output_dir,
            num_partitions,
            buffer_size: 4096,
        }
    }

    /// Set the number of input batches to accumulate before flushing the
    /// partition buffers to disk (default: 4096).
    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
        self.buffer_size = buffer_size;
        self
    }
}

#[async_trait::async_trait]
impl Shuffler for IvfShuffler {
    async fn shuffle(
        &self,
        data: Box<dyn RecordBatchStream + Unpin + 'static>,
    ) -> Result<Box<dyn ShuffleReader>> {
        let mut writers: Vec<FileWriter> = vec![];
        let mut partition_sizes = vec![0; self.num_partitions];
        let mut first_pass = true;

        let num_partitions = self.num_partitions;
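        // Sort each incoming batch by partition id on a compute thread, then
        // split it into contiguous per-partition runs; `buffered` keeps one
        // sort task per compute CPU in flight.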
        let mut parallel_sort_stream = data
            .map(|batch| {
                spawn_cpu(move || {
                    let batch = batch?;

                    let part_ids: &UInt32Array = batch
                        .column_by_name(PART_ID_COLUMN)
                        .expect("Partition ID column not found")
                        .as_primitive();

                    let indices = sort_to_indices(&part_ids, None, None)?;
                    let batch = batch.take(&indices)?;

                    let part_ids: &UInt32Array = batch
                        .column_by_name(PART_ID_COLUMN)
                        .expect("Partition ID column not found")
                        .as_primitive();

                    let mut partition_buffers =
                        (0..num_partitions).map(|_| Vec::new()).collect::<Vec<_>>();

                    // After the sort, equal partition ids form contiguous runs;
                    // slice each run out of the batch (slices are zero-copy).
                    let mut start = 0;
                    while start < batch.num_rows() {
                        let part_id: u32 = part_ids.value(start);
                        let mut end = start + 1;
                        while end < batch.num_rows() && part_ids.value(end) == part_id {
                            end += 1;
                        }

                        let part_batches = &mut partition_buffers[part_id as usize];
                        part_batches.push(batch.slice(start, end - start));
                        start = end;
                    }

                    Ok::<Vec<Vec<RecordBatch>>, Error>(partition_buffers)
                })
            })
            .buffered(get_num_compute_intensive_cpus());

        // part_id:           |       0        |       1        |       2        |
        // partition_buffers: |[batch,batch,..]|[batch,batch,..]|[batch,batch,..]|
        let mut partition_buffers = (0..self.num_partitions)
            .map(|_| Vec::new())
            .collect::<Vec<_>>();

        let mut counter = 0;
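
        // Drain the sorted stream, folding each task's per-partition slices
        // into the shared buffers.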
        while let Some(shuffled) = parallel_sort_stream.next().await {
            let shuffled = shuffled?;

            for (part_id, batches) in shuffled.into_iter().enumerate() {
                let part_batches = &mut partition_buffers[part_id];
                part_batches.extend(batches);
            }

            counter += 1;

            if first_pass {
                // Writers are created lazily: the schema is only known once
                // the first batch arrives.
                let schema = partition_buffers
                    .iter()
                    .flatten()
                    .next()
                    .map(|batch| batch.schema())
                    .expect("there should be at least one batch");
                writers = stream::iter(0..self.num_partitions)
                    .map(|partition_id| {
                        let part_path =
                            self.output_dir.child(format!("ivf_{}.lance", partition_id));
                        let object_store = self.object_store.clone();
                        let schema = schema.clone();
                        async move {
                            let writer = object_store.create(&part_path).await?;
                            FileWriter::try_new(
                                writer,
                                lance_core::datatypes::Schema::try_from(schema.as_ref())?,
                                Default::default(),
                            )
                        }
                    })
                    .buffered(10)
                    .try_collect::<Vec<_>>()
                    .await?;

                first_pass = false;
            }

            // Flush the accumulated buffers to the partition files every
            // `buffer_size` input batches to bound memory usage.
            if counter % self.buffer_size == 0 {
                log::info!("shuffled {} batches, flushing", counter);
                let mut futs = vec![];
                for (part_id, writer) in writers.iter_mut().enumerate() {
                    let batches = &partition_buffers[part_id];
                    partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
                    futs.push(writer.write_batches(batches.iter()));
                }
                join_all(futs)
                    .await
                    .into_iter()
                    .collect::<Result<Vec<_>>>()?;

                partition_buffers.iter_mut().for_each(|b| b.clear());
            }
        }

        // final flush
        for (part_id, batches) in partition_buffers.into_iter().enumerate() {
            let writer = &mut writers[part_id];
            partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
            for batch in batches.iter() {
                writer.write_batch(batch).await?;
            }
        }

        // finish all writers
        for writer in writers.iter_mut() {
            writer.finish().await?;
        }

        Ok(Box::new(IvfShufflerReader::new(
            self.object_store.clone(),
            self.output_dir.clone(),
            partition_sizes,
        )))
    }
}

/// Reads the per-partition files written by [`IvfShuffler`].
pub struct IvfShufflerReader {
    scheduler: Arc<ScanScheduler>,
    output_dir: Path,
    partition_sizes: Vec<usize>,
}

impl IvfShufflerReader {
    pub fn new(
        object_store: Arc<ObjectStore>,
        output_dir: Path,
        partition_sizes: Vec<usize>,
    ) -> Self {
        let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
        let scheduler = ScanScheduler::new(object_store, scheduler_config);
        Self {
            scheduler,
            output_dir,
            partition_sizes,
        }
    }
}

#[async_trait::async_trait]
impl ShuffleReader for IvfShufflerReader {
    async fn read_partition(
        &self,
        partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
        let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));

        let reader = FileReader::try_open(
            self.scheduler.open_file(&partition_path).await?,
            None,
            Arc::<DecoderPlugins>::default(),
            &FileMetadataCache::no_cache(),
            FileReaderOptions::default(),
        )
        .await?;
        let schema = reader.schema().as_ref().into();

        Ok(Some(Box::new(RecordBatchStreamAdapter::new(
            Arc::new(schema),
            reader.read_stream(
                lance_io::ReadBatchParams::RangeFull,
                4096, // batch size
                16,   // batch read-ahead
                FilterExpression::no_filter(),
            )?,
        ))))
    }

    fn partition_size(&self, partition_id: usize) -> Result<usize> {
        Ok(self.partition_sizes[partition_id])
    }
}