lance_file/
writer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

mod statistics;

use std::collections::HashMap;
use std::marker::PhantomData;

use arrow_array::builder::{ArrayBuilder, PrimitiveBuilder};
use arrow_array::cast::{as_large_list_array, as_list_array, as_struct_array};
use arrow_array::types::{Int32Type, Int64Type};
use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
use arrow_buffer::ArrowNativeType;
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_recursion::async_recursion;
use async_trait::async_trait;
use lance_arrow::*;
use lance_core::datatypes::{Encoding, Field, NullabilityComparison, Schema, SchemaCompareOptions};
use lance_core::{Error, Result};
use lance_io::encodings::{
    binary::BinaryEncoder, dictionary::DictionaryEncoder, plain::PlainEncoder, Encoder,
};
use lance_io::object_store::ObjectStore;
use lance_io::object_writer::ObjectWriter;
use lance_io::traits::{WriteExt, Writer};
use object_store::path::Path;
use snafu::location;
use tokio::io::AsyncWriteExt;

use crate::format::metadata::{Metadata, StatisticsMetadata};
use crate::format::{MAGIC, MAJOR_VERSION, MINOR_VERSION};
use crate::page_table::{PageInfo, PageTable};

/// The file format currently includes a "manifest" where it stores the schema for
/// self-describing files.  Historically this has been a table format manifest that
/// is empty except for the schema field.
///
/// Since this crate is not aware of the table format, we need this to be provided
/// externally.  You should always use lance_table::io::manifest::ManifestDescribing
/// for this today.
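///
/// A minimal sketch of a no-op implementation (the test-only `NotSelfDescribing`
/// below does exactly this); the `NoManifest` name is hypothetical, and real
/// files should use `lance_table::io::manifest::ManifestDescribing` instead:
///
/// ```ignore
/// struct NoManifest {}
///
/// #[async_trait]
/// impl ManifestProvider for NoManifest {
///     async fn store_schema(_: &mut ObjectWriter, _: &Schema) -> Result<Option<usize>> {
///         // Returning None means no manifest position is recorded in the footer.
///         Ok(None)
///     }
/// }
/// ```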
#[async_trait]
pub trait ManifestProvider {
    /// Store the schema in the file
    ///
    /// This should just require writing the schema (or a manifest wrapper) as a proto struct
    ///
    /// Note: the dictionaries have already been written by this point and the schema should
    /// be populated with the dictionary lengths/offsets
    async fn store_schema(
        object_writer: &mut ObjectWriter,
        schema: &Schema,
    ) -> Result<Option<usize>>;
}

/// Implementation of ManifestProvider that does not store the schema
#[cfg(test)]
pub(crate) struct NotSelfDescribing {}

#[cfg(test)]
#[async_trait]
impl ManifestProvider for NotSelfDescribing {
    async fn store_schema(_: &mut ObjectWriter, _: &Schema) -> Result<Option<usize>> {
        Ok(None)
    }
}

/// [FileWriter] writes Arrow [RecordBatch] to one Lance file.
///
/// ```ignore
/// use lance::io::FileWriter;
/// use futures::stream::Stream;
///
/// let mut file_writer = FileWriter::try_new(&object_store, &path, schema, &Default::default()).await?;
/// while let Ok(batch) = stream.next().await {
///     file_writer.write(&[batch]).await?;
/// }
/// // The file writer must be finished to flush the buffers and write the footer.
/// file_writer.finish().await?;
/// ```
pub struct FileWriter<M: ManifestProvider + Send + Sync> {
    pub object_writer: ObjectWriter,
    schema: Schema,
    batch_id: i32,
    page_table: PageTable,
    metadata: Metadata,
    stats_collector: Option<statistics::StatisticsCollector>,
    manifest_provider: PhantomData<M>,
}

#[derive(Debug, Clone, Default)]
pub struct FileWriterOptions {
    /// The field ids to collect statistics for.
    ///
    /// If None, will collect for all fields in the schema (that support stats).
    /// If an empty vector, will not collect any statistics.
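    ///
    /// A hypothetical example, collecting statistics only for field ids 0 and 2:
    ///
    /// ```ignore
    /// let options = FileWriterOptions {
    ///     collect_stats_for_fields: Some(vec![0, 2]),
    /// };
    /// ```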
    pub collect_stats_for_fields: Option<Vec<i32>>,
}

impl<M: ManifestProvider + Send + Sync> FileWriter<M> {
    pub async fn try_new(
        object_store: &ObjectStore,
        path: &Path,
        schema: Schema,
        options: &FileWriterOptions,
    ) -> Result<Self> {
        let object_writer = object_store.create(path).await?;
        Self::with_object_writer(object_writer, schema, options)
    }

    pub fn with_object_writer(
        object_writer: ObjectWriter,
        schema: Schema,
        options: &FileWriterOptions,
    ) -> Result<Self> {
        let collect_stats_for_fields = if let Some(stats_fields) = &options.collect_stats_for_fields
        {
            stats_fields.clone()
        } else {
            schema.field_ids()
        };

        let stats_collector = if !collect_stats_for_fields.is_empty() {
            let stats_schema = schema.project_by_ids(&collect_stats_for_fields, true);
            statistics::StatisticsCollector::try_new(&stats_schema)
        } else {
            None
        };

        Ok(Self {
            object_writer,
            schema,
            batch_id: 0,
            page_table: PageTable::default(),
            metadata: Metadata::default(),
            stats_collector,
            manifest_provider: PhantomData,
        })
    }

    /// Return the schema of the file writer.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }

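    /// Verify that a non-nullable field contains no null values, recursing into
    /// child arrays (e.g. struct children) to check nested fields as well.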
    fn verify_field_nullability(arr: &ArrayData, field: &Field) -> Result<()> {
        if !field.nullable && arr.null_count() > 0 {
            return Err(Error::invalid_input(format!("The field `{}` contained null values even though the field is marked non-null in the schema", field.name), location!()));
        }

        for (child_field, child_arr) in field.children.iter().zip(arr.child_data()) {
            Self::verify_field_nullability(child_arr, child_field)?;
        }

        Ok(())
    }

    fn verify_nullability_constraints(&self, batch: &RecordBatch) -> Result<()> {
        for (col, field) in batch.columns().iter().zip(self.schema.fields.iter()) {
            Self::verify_field_nullability(&col.to_data(), field)?;
        }
        Ok(())
    }

    /// Write a [RecordBatch] to the open file.
    /// All batches in a single call are written as one batch on disk.
    ///
    /// Returns [Err] if the schema does not match the batch.
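    ///
    /// A sketch of typical usage, assuming an open writer and two matching
    /// batches `batch1` and `batch2` are in scope:
    ///
    /// ```ignore
    /// // Both batches are concatenated into one on-disk batch.
    /// file_writer.write(&[batch1, batch2]).await?;
    /// ```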
    pub async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> {
        if batches.is_empty() {
            return Ok(());
        }

        for batch in batches {
            // Compare the schemas, ignoring metadata and dictionaries.
            // Dictionaries should have been checked earlier, and comparing them could be expensive.
            let schema = Schema::try_from(batch.schema().as_ref())?;
            schema.check_compatible(
                &self.schema,
                &SchemaCompareOptions {
                    compare_nullability: NullabilityComparison::Ignore,
                    ..Default::default()
                },
            )?;
            self.verify_nullability_constraints(batch)?;
        }

        // If we are collecting stats, collect them for each requested field.
        // Statistics need to traverse nested arrays, so this is a separate loop
        // from the write loop below, which operates on top-level arrays.
        if let Some(stats_collector) = &mut self.stats_collector {
            for (field, arrays) in fields_in_batches(batches, &self.schema) {
                if let Some(stats_builder) = stats_collector.get_builder(field.id) {
                    let stats_row = statistics::collect_statistics(&arrays);
                    stats_builder.append(stats_row);
                }
            }
        }

        // Copy a list of fields to avoid borrow checker error.
        let fields = self.schema.fields.clone();
        for field in fields.iter() {
            let arrs = batches
                .iter()
                .map(|batch| {
                    batch.column_by_name(&field.name).ok_or_else(|| {
                        Error::io(
                            format!("FileWriter::write: Field '{}' not found", field.name),
                            location!(),
                        )
                    })
                })
                .collect::<Result<Vec<_>>>()?;

            Self::write_array(
                &mut self.object_writer,
                field,
                &arrs,
                self.batch_id,
                &mut self.page_table,
            )
            .await?;
        }
        let batch_length = batches.iter().map(|b| b.num_rows() as i32).sum();
        self.metadata.push_batch_length(batch_length);

        // It's imperative we complete any in-flight requests, since we are
        // returning control to the caller. If the caller takes a long time to
        // write the next batch, the in-flight requests will not be polled and
        // may time out.
        self.object_writer.flush().await?;

        self.batch_id += 1;
        Ok(())
    }

    /// Add schema metadata, as a (key, value) pair, to the file.
    pub fn add_metadata(&mut self, key: &str, value: &str) {
        self.schema
            .metadata
            .insert(key.to_string(), value.to_string());
    }

    pub async fn finish_with_metadata(
        &mut self,
        metadata: &HashMap<String, String>,
    ) -> Result<usize> {
        self.schema
            .metadata
            .extend(metadata.iter().map(|(k, v)| (k.clone(), v.clone())));
        self.finish().await
    }

    pub async fn finish(&mut self) -> Result<usize> {
        self.write_footer().await?;
        self.object_writer.shutdown().await?;
        let num_rows = self
            .metadata
            .batch_offsets
            .last()
            .cloned()
            .unwrap_or_default();
        Ok(num_rows as usize)
    }

    /// Total records written in this file.
    pub fn len(&self) -> usize {
        self.metadata.len()
    }

    /// Total bytes written so far.
    pub async fn tell(&mut self) -> Result<usize> {
        self.object_writer.tell().await
    }

    /// Return the id of the next batch to be written.
    pub fn next_batch_id(&self) -> i32 {
        self.batch_id
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    #[async_recursion]
    async fn write_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&ArrayRef],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        assert!(!arrs.is_empty());
        let data_type = arrs[0].data_type();
        let arrs_ref = arrs.iter().map(|a| a.as_ref()).collect::<Vec<_>>();

        match data_type {
            DataType::Null => {
                Self::write_null_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            dt if dt.is_fixed_stride() => {
                Self::write_fixed_stride_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            dt if dt.is_binary_like() => {
                Self::write_binary_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::Dictionary(key_type, _) => {
                Self::write_dictionary_arr(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    key_type,
                    batch_id,
                    page_table,
                )
                .await
            }
            dt if dt.is_struct() => {
                let struct_arrays = arrs.iter().map(|a| as_struct_array(a)).collect::<Vec<_>>();
                Self::write_struct_array(
                    object_writer,
                    field,
                    struct_arrays.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::FixedSizeList(_, _) | DataType::FixedSizeBinary(_) => {
                Self::write_fixed_stride_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::List(_) => {
                Self::write_list_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::LargeList(_) => {
                Self::write_large_list_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            _ => Err(Error::Schema {
                message: format!("FileWriter::write: unsupported data type: {data_type}"),
                location: location!(),
            }),
        }
    }

    async fn write_null_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
        let page_info = PageInfo::new(object_writer.tell().await?, arrs_length as usize);
        page_table.set(field.id, batch_id, page_info);
        Ok(())
    }

    /// Write a fixed-stride array, including primitives, fixed-size binary, and fixed-size list.
    async fn write_fixed_stride_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        assert_eq!(field.encoding, Some(Encoding::Plain));
        assert!(!arrs.is_empty());
        let data_type = arrs[0].data_type();

        let mut encoder = PlainEncoder::new(object_writer, data_type);
        let pos = encoder.encode(arrs).await?;
        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
        let page_info = PageInfo::new(pos, arrs_length as usize);
        page_table.set(field.id, batch_id, page_info);
        Ok(())
    }

    /// Write var-length binary arrays.
    async fn write_binary_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        assert_eq!(field.encoding, Some(Encoding::VarBinary));
        let mut encoder = BinaryEncoder::new(object_writer);
        let pos = encoder.encode(arrs).await?;
        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
        let page_info = PageInfo::new(pos, arrs_length as usize);
        page_table.set(field.id, batch_id, page_info);
        Ok(())
    }

    async fn write_dictionary_arr(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&dyn Array],
        key_type: &DataType,
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        assert_eq!(field.encoding, Some(Encoding::Dictionary));

        // Write the dictionary keys.
        let mut encoder = DictionaryEncoder::new(object_writer, key_type);
        let pos = encoder.encode(arrs).await?;
        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
        let page_info = PageInfo::new(pos, arrs_length as usize);
        page_table.set(field.id, batch_id, page_info);
        Ok(())
    }

    #[async_recursion]
    async fn write_struct_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrays: &[&StructArray],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        arrays
            .iter()
            .for_each(|a| assert_eq!(a.num_columns(), field.children.len()));

        for child in &field.children {
            let mut arrs: Vec<&ArrayRef> = Vec::new();
            for struct_array in arrays {
                let arr = struct_array
                    .column_by_name(&child.name)
                    .ok_or(Error::Schema {
                        message: format!(
                            "FileWriter: schema mismatch: column {} does not exist in array: {:?}",
                            child.name,
                            struct_array.data_type()
                        ),
                        location: location!(),
                    })?;
                arrs.push(arr);
            }
            Self::write_array(object_writer, child, arrs.as_slice(), batch_id, page_table).await?;
        }
        Ok(())
    }

    async fn write_list_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        let capacity: usize = arrs.iter().map(|a| a.len()).sum();
        let mut list_arrs: Vec<ArrayRef> = Vec::new();
        let mut pos_builder: PrimitiveBuilder<Int32Type> =
            PrimitiveBuilder::with_capacity(capacity);

        let mut last_offset: usize = 0;
        pos_builder.append_value(last_offset as i32);
        for array in arrs.iter() {
            let list_arr = as_list_array(*array);
            let offsets = list_arr.value_offsets();

            assert!(!offsets.is_empty());
            let start_offset = offsets[0].as_usize();
            let end_offset = offsets[offsets.len() - 1].as_usize();

            let list_values = list_arr.values();
            let sliced_values = list_values.slice(start_offset, end_offset - start_offset);
            list_arrs.push(sliced_values);

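            // Rebase this array's offsets so they are relative to the values we
            // just sliced, then shift by the running total from earlier arrays.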
            offsets
                .iter()
                .skip(1)
                .map(|b| b.as_usize() - start_offset + last_offset)
                .for_each(|o| pos_builder.append_value(o as i32));
            last_offset = pos_builder.values_slice()[pos_builder.len() - 1_usize] as usize;
        }

        let positions: &dyn Array = &pos_builder.finish();
        Self::write_fixed_stride_array(object_writer, field, &[positions], batch_id, page_table)
            .await?;
        let arrs = list_arrs.iter().collect::<Vec<_>>();
        Self::write_array(
            object_writer,
            &field.children[0],
            arrs.as_slice(),
            batch_id,
            page_table,
        )
        .await
    }

    async fn write_large_list_array(
        object_writer: &mut ObjectWriter,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        let capacity: usize = arrs.iter().map(|a| a.len()).sum();
        let mut list_arrs: Vec<ArrayRef> = Vec::new();
        let mut pos_builder: PrimitiveBuilder<Int64Type> =
            PrimitiveBuilder::with_capacity(capacity);

        let mut last_offset: usize = 0;
        pos_builder.append_value(last_offset as i64);
        for array in arrs.iter() {
            let list_arr = as_large_list_array(*array);
            let offsets = list_arr.value_offsets();

            assert!(!offsets.is_empty());
            let start_offset = offsets[0].as_usize();
            let end_offset = offsets[offsets.len() - 1].as_usize();

            let sliced_values = list_arr
                .values()
                .slice(start_offset, end_offset - start_offset);
            list_arrs.push(sliced_values);

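            // Same offset rebasing as in write_list_array, but with i64 offsets.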
            offsets
                .iter()
                .skip(1)
                .map(|b| b.as_usize() - start_offset + last_offset)
                .for_each(|o| pos_builder.append_value(o as i64));
            last_offset = pos_builder.values_slice()[pos_builder.len() - 1_usize] as usize;
        }

        let positions: &dyn Array = &pos_builder.finish();
        Self::write_fixed_stride_array(object_writer, field, &[positions], batch_id, page_table)
            .await?;
        let arrs = list_arrs.iter().collect::<Vec<_>>();
        Self::write_array(
            object_writer,
            &field.children[0],
            arrs.as_slice(),
            batch_id,
            page_table,
        )
        .await
    }

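    /// Flush collected statistics as a single-batch table: each stats column is
    /// written with `write_array` (batch id 0) and tracked in its own page
    /// table, whose position is recorded in the returned metadata.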
    async fn write_statistics(&mut self) -> Result<Option<StatisticsMetadata>> {
        let statistics = self
            .stats_collector
            .as_mut()
            .map(|collector| collector.finish());

        match statistics {
            Some(Ok(stats_batch)) if stats_batch.num_rows() > 0 => {
                debug_assert_eq!(self.next_batch_id() as usize, stats_batch.num_rows());
                let schema = Schema::try_from(stats_batch.schema().as_ref())?;
                let leaf_field_ids = schema.field_ids();

                let mut stats_page_table = PageTable::default();
                for (i, field) in schema.fields.iter().enumerate() {
                    Self::write_array(
                        &mut self.object_writer,
                        field,
                        &[stats_batch.column(i)],
                        0, // Only one batch for statistics.
                        &mut stats_page_table,
                    )
                    .await?;
                }

                let page_table_position =
                    stats_page_table.write(&mut self.object_writer, 0).await?;

                Ok(Some(StatisticsMetadata {
                    schema,
                    leaf_field_ids,
                    page_table_position,
                }))
            }
            Some(Err(e)) => Err(e),
            _ => Ok(None),
        }
    }

    /// Writes the dictionaries (using plain/binary encoding) into the file.
    ///
    /// The offsets and lengths of the written buffers are stored in the given
    /// schema so that the dictionaries can be loaded in the future.
    async fn write_dictionaries(writer: &mut ObjectWriter, schema: &mut Schema) -> Result<()> {
        // Write dictionary values.
        let max_field_id = schema.max_field_id().unwrap_or(-1);
        for field_id in 0..max_field_id + 1 {
            if let Some(field) = schema.mut_field_by_id(field_id) {
                if field.data_type().is_dictionary() {
                    let dict_info = field.dictionary.as_mut().ok_or_else(|| {
                        Error::io(
                            format!("Lance field {} is missing dictionary info", field.name),
                            location!(),
                        )
                    })?;

                    let value_arr = dict_info.values.as_ref().ok_or_else(|| {
                        Error::io(
                            format!(
                                "Lance field {} is a dictionary type, but is missing the dictionary value array",
                                field.name
                            ),
                            location!(),
                        )
                    })?;

                    let data_type = value_arr.data_type();
                    let pos = match data_type {
                        dt if dt.is_numeric() => {
                            let mut encoder = PlainEncoder::new(writer, dt);
                            encoder.encode(&[value_arr]).await?
                        }
                        dt if dt.is_binary_like() => {
                            let mut encoder = BinaryEncoder::new(writer);
                            encoder.encode(&[value_arr]).await?
                        }
                        _ => {
                            return Err(Error::io(
                                format!(
                                    "Unsupported dictionary value type: {}",
                                    value_arr.data_type()
                                ),
                                location!(),
                            ));
                        }
                    };
                    dict_info.offset = pos;
                    dict_info.length = value_arr.len();
                }
            }
        }
        Ok(())
    }

    async fn write_footer(&mut self) -> Result<()> {
        // Step 1. Write page table.
        let field_id_offset = *self.schema.field_ids().iter().min().unwrap();
        let pos = self
            .page_table
            .write(&mut self.object_writer, field_id_offset)
            .await?;
        self.metadata.page_table_position = pos;

        // Step 2. Write statistics.
        self.metadata.stats_metadata = self.write_statistics().await?;

        // Step 3. Write manifest and dictionary values.
        Self::write_dictionaries(&mut self.object_writer, &mut self.schema).await?;
        let pos = M::store_schema(&mut self.object_writer, &self.schema).await?;

        // Step 4. Write metadata.
        self.metadata.manifest_position = pos;
        let pos = self.object_writer.write_struct(&self.metadata).await?;

        // Step 5. Write magics.
        self.object_writer
            .write_magics(pos, MAJOR_VERSION, MINOR_VERSION, MAGIC)
            .await
    }
}

/// Walk through the schema and return arrays with their Lance field.
///
/// This skips over nested arrays and fields within list arrays. It does walk
/// over the children of structs.
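///
/// A sketch of how this is consumed internally, assuming `batches` and `schema`
/// are in scope:
///
/// ```ignore
/// for (field, arrays) in fields_in_batches(&batches, &schema) {
///     // `field` is a top-level or struct-child field; `arrays` holds the
///     // corresponding array from each batch.
/// }
/// ```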
fn fields_in_batches<'a>(
    batches: &'a [RecordBatch],
    schema: &'a Schema,
) -> impl Iterator<Item = (&'a Field, Vec<&'a ArrayRef>)> {
    let num_columns = batches[0].num_columns();
    let array_iters = (0..num_columns).map(|col_i| {
        batches
            .iter()
            .map(|batch| batch.column(col_i))
            .collect::<Vec<_>>()
    });
    let mut to_visit: Vec<(&'a Field, Vec<&'a ArrayRef>)> =
        schema.fields.iter().zip(array_iters).collect();

    std::iter::from_fn(move || {
        loop {
            let (field, arrays): (_, Vec<&'a ArrayRef>) = to_visit.pop()?;
            match field.data_type() {
                DataType::Struct(_) => {
                    for (i, child_field) in field.children.iter().enumerate() {
                        let child_arrays = arrays
                            .iter()
                            .map(|arr| as_struct_array(*arr).column(i))
                            .collect::<Vec<&'a ArrayRef>>();
                        to_visit.push((child_field, child_arrays));
                    }
                    continue;
                }
                // We only walk structs right now.
                _ if field.data_type().is_nested() => continue,
                _ => return Some((field, arrays)),
            }
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::{
        types::UInt32Type, BooleanArray, Decimal128Array, Decimal256Array, DictionaryArray,
        DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray,
        DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, Int32Array,
        Int64Array, ListArray, NullArray, StringArray, TimestampMicrosecondArray,
        TimestampSecondArray, UInt8Array,
    };
    use arrow_buffer::i256;
    use arrow_schema::{
        Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, TimeUnit,
    };
    use arrow_select::concat::concat_batches;

    use crate::reader::FileReader;

    #[tokio::test]
    async fn test_write_file() {
        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("null", DataType::Null, true),
            ArrowField::new("bool", DataType::Boolean, true),
            ArrowField::new("i", DataType::Int64, true),
            ArrowField::new("f", DataType::Float32, false),
            ArrowField::new("b", DataType::Utf8, true),
            ArrowField::new("decimal128", DataType::Decimal128(7, 3), false),
            ArrowField::new("decimal256", DataType::Decimal256(7, 3), false),
            ArrowField::new("duration_sec", DataType::Duration(TimeUnit::Second), false),
            ArrowField::new(
                "duration_msec",
                DataType::Duration(TimeUnit::Millisecond),
                false,
            ),
            ArrowField::new(
                "duration_usec",
                DataType::Duration(TimeUnit::Microsecond),
                false,
            ),
            ArrowField::new(
                "duration_nsec",
                DataType::Duration(TimeUnit::Nanosecond),
                false,
            ),
            ArrowField::new(
                "d",
                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
                true,
            ),
            ArrowField::new(
                "fixed_size_list",
                DataType::FixedSizeList(
                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
                    16,
                ),
                true,
            ),
            ArrowField::new("fixed_size_binary", DataType::FixedSizeBinary(8), true),
            ArrowField::new(
                "l",
                DataType::List(Arc::new(ArrowField::new("item", DataType::Utf8, true))),
                true,
            ),
            ArrowField::new(
                "large_l",
                DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Utf8, true))),
                true,
            ),
            ArrowField::new(
                "l_dict",
                DataType::List(Arc::new(ArrowField::new(
                    "item",
                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
                    true,
                ))),
                true,
            ),
            ArrowField::new(
                "large_l_dict",
                DataType::LargeList(Arc::new(ArrowField::new(
                    "item",
                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
                    true,
                ))),
                true,
            ),
            ArrowField::new(
                "s",
                DataType::Struct(ArrowFields::from(vec![
                    ArrowField::new("si", DataType::Int64, true),
                    ArrowField::new("sb", DataType::Utf8, true),
                ])),
                true,
            ),
        ]);
        let mut schema = Schema::try_from(&arrow_schema).unwrap();

        let dict_vec = (0..100).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
        let dict_arr: DictionaryArray<UInt32Type> = dict_vec.into_iter().collect();

        let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(
            Float32Array::from_iter((0..1600).map(|n| n as f32).collect::<Vec<_>>()),
            16,
        )
        .unwrap();

        let binary_data: [u8; 800] = [123; 800];
        let fixed_size_binary_arr =
            FixedSizeBinaryArray::try_new_from_values(&UInt8Array::from_iter(binary_data), 8)
                .unwrap();

        let list_offsets = (0..202).step_by(2).collect();
        let list_values =
            StringArray::from((0..200).map(|n| format!("str-{}", n)).collect::<Vec<_>>());
        let list_arr: arrow_array::GenericListArray<i32> =
            try_new_generic_list_array(list_values, &list_offsets).unwrap();

        let large_list_offsets: Int64Array = (0..202).step_by(2).collect();
        let large_list_values =
            StringArray::from((0..200).map(|n| format!("str-{}", n)).collect::<Vec<_>>());
        let large_list_arr: arrow_array::GenericListArray<i64> =
            try_new_generic_list_array(large_list_values, &large_list_offsets).unwrap();

        let list_dict_offsets = (0..202).step_by(2).collect();
        let list_dict_vec = (0..200).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
        let list_dict_arr: DictionaryArray<UInt32Type> = list_dict_vec.into_iter().collect();
        let list_dict_arr: arrow_array::GenericListArray<i32> =
            try_new_generic_list_array(list_dict_arr, &list_dict_offsets).unwrap();

        let large_list_dict_offsets: Int64Array = (0..202).step_by(2).collect();
        let large_list_dict_vec = (0..200).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
        let large_list_dict_arr: DictionaryArray<UInt32Type> =
            large_list_dict_vec.into_iter().collect();
        let large_list_dict_arr: arrow_array::GenericListArray<i64> =
            try_new_generic_list_array(large_list_dict_arr, &large_list_dict_offsets).unwrap();

        let columns: Vec<ArrayRef> = vec![
            Arc::new(NullArray::new(100)),
            Arc::new(BooleanArray::from_iter(
                (0..100).map(|f| Some(f % 3 == 0)).collect::<Vec<_>>(),
            )),
            Arc::new(Int64Array::from_iter((0..100).collect::<Vec<_>>())),
            Arc::new(Float32Array::from_iter(
                (0..100).map(|n| n as f32).collect::<Vec<_>>(),
            )),
            Arc::new(StringArray::from(
                (0..100).map(|n| n.to_string()).collect::<Vec<_>>(),
            )),
            Arc::new(
                Decimal128Array::from_iter_values(0..100)
                    .with_precision_and_scale(7, 3)
                    .unwrap(),
            ),
            Arc::new(
                Decimal256Array::from_iter_values((0..100).map(|v| i256::from_i128(v as i128)))
                    .with_precision_and_scale(7, 3)
                    .unwrap(),
            ),
            Arc::new(DurationSecondArray::from_iter_values(0..100)),
            Arc::new(DurationMillisecondArray::from_iter_values(0..100)),
            Arc::new(DurationMicrosecondArray::from_iter_values(0..100)),
            Arc::new(DurationNanosecondArray::from_iter_values(0..100)),
            Arc::new(dict_arr),
            Arc::new(fixed_size_list_arr),
            Arc::new(fixed_size_binary_arr),
            Arc::new(list_arr),
            Arc::new(large_list_arr),
            Arc::new(list_dict_arr),
            Arc::new(large_list_dict_arr),
            Arc::new(StructArray::from(vec![
                (
                    Arc::new(ArrowField::new("si", DataType::Int64, true)),
                    Arc::new(Int64Array::from_iter((100..200).collect::<Vec<_>>())) as ArrayRef,
                ),
                (
                    Arc::new(ArrowField::new("sb", DataType::Utf8, true)),
                    Arc::new(StringArray::from(
                        (0..100).map(|n| n.to_string()).collect::<Vec<_>>(),
                    )) as ArrayRef,
                ),
            ])),
        ];
        let batch = RecordBatch::try_new(Arc::new(arrow_schema), columns).unwrap();
        schema.set_dictionary(&batch).unwrap();

        let store = ObjectStore::memory();
        let path = Path::from("/foo");
        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
            &store,
            &path,
            schema.clone(),
            &Default::default(),
        )
        .await
        .unwrap();
        file_writer.write(&[batch.clone()]).await.unwrap();
        file_writer.finish().await.unwrap();

        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();
        let actual = reader.read_batch(0, .., reader.schema()).await.unwrap();
        assert_eq!(actual, batch);
    }

    #[tokio::test]
    async fn test_dictionary_first_element_file() {
        let arrow_schema = ArrowSchema::new(vec![ArrowField::new(
            "d",
            DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
            true,
        )]);
        let mut schema = Schema::try_from(&arrow_schema).unwrap();

        let dict_vec = (0..100).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
        let dict_arr: DictionaryArray<UInt32Type> = dict_vec.into_iter().collect();

        let columns: Vec<ArrayRef> = vec![Arc::new(dict_arr)];
        let batch = RecordBatch::try_new(Arc::new(arrow_schema), columns).unwrap();
        schema.set_dictionary(&batch).unwrap();

        let store = ObjectStore::memory();
        let path = Path::from("/foo");
        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
            &store,
            &path,
            schema.clone(),
            &Default::default(),
        )
        .await
        .unwrap();
        file_writer.write(&[batch.clone()]).await.unwrap();
        file_writer.finish().await.unwrap();

        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();
        let actual = reader.read_batch(0, .., reader.schema()).await.unwrap();
        assert_eq!(actual, batch);
    }

    #[tokio::test]
    async fn test_write_temporal_types() {
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new(
                "ts_notz",
                DataType::Timestamp(TimeUnit::Second, None),
                false,
            ),
            ArrowField::new(
                "ts_tz",
                DataType::Timestamp(TimeUnit::Microsecond, Some("America/Los_Angeles".into())),
                false,
            ),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(TimestampSecondArray::from(vec![11111111, 22222222])),
            Arc::new(
                TimestampMicrosecondArray::from(vec![3333333, 4444444])
                    .with_timezone("America/Los_Angeles"),
            ),
        ];
        let batch = RecordBatch::try_new(arrow_schema.clone(), columns).unwrap();

        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
        let store = ObjectStore::memory();
        let path = Path::from("/foo");
        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
            &store,
            &path,
            schema.clone(),
            &Default::default(),
        )
        .await
        .unwrap();
        file_writer.write(&[batch.clone()]).await.unwrap();
        file_writer.finish().await.unwrap();

        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();
        let actual = reader.read_batch(0, .., reader.schema()).await.unwrap();
        assert_eq!(actual, batch);
    }

    #[tokio::test]
    async fn test_collect_stats() {
        // Validate:
        // Only collects stats for requested columns
        // Can collect stats in nested structs
        // Won't collect stats for list columns (for now)

        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("i", DataType::Int64, true),
            ArrowField::new("i2", DataType::Int64, true),
            ArrowField::new(
                "l",
                DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))),
                true,
            ),
            ArrowField::new(
                "s",
                DataType::Struct(ArrowFields::from(vec![
                    ArrowField::new("si", DataType::Int64, true),
                    ArrowField::new("sb", DataType::Utf8, true),
                ])),
                true,
            ),
        ]);

        let schema = Schema::try_from(&arrow_schema).unwrap();

        let store = ObjectStore::memory();
        let path = Path::from("/foo");

        let options = FileWriterOptions {
            collect_stats_for_fields: Some(vec![0, 1, 5, 6]),
        };
        let mut file_writer =
            FileWriter::<NotSelfDescribing>::try_new(&store, &path, schema.clone(), &options)
                .await
                .unwrap();

        let batch1 = RecordBatch::try_new(
            Arc::new(arrow_schema.clone()),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                    Some(vec![Some(1i32), Some(2), Some(3)]),
                    Some(vec![Some(4), Some(5)]),
                    Some(vec![]),
                ])),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(ArrowField::new("si", DataType::Int64, true)),
                        Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
                    ),
                    (
                        Arc::new(ArrowField::new("sb", DataType::Utf8, true)),
                        Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef,
                    ),
                ])),
            ],
        )
        .unwrap();
        file_writer.write(&[batch1]).await.unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(arrow_schema.clone()),
            vec![
                Arc::new(Int64Array::from(vec![5, 6])),
                Arc::new(Int64Array::from(vec![10, 11])),
                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                    Some(vec![Some(1i32), Some(2), Some(3)]),
                    Some(vec![]),
                ])),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(ArrowField::new("si", DataType::Int64, true)),
                        Arc::new(Int64Array::from(vec![4, 5])) as ArrayRef,
                    ),
                    (
                        Arc::new(ArrowField::new("sb", DataType::Utf8, true)),
                        Arc::new(StringArray::from(vec!["d", "e"])) as ArrayRef,
                    ),
                ])),
            ],
        )
        .unwrap();
        file_writer.write(&[batch2]).await.unwrap();

        file_writer.finish().await.unwrap();

        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();

        let read_stats = reader.read_page_stats(&[0, 1, 5, 6]).await.unwrap();
        assert!(read_stats.is_some());
        let read_stats = read_stats.unwrap();

        let expected_stats_schema = stats_schema([
            (0, DataType::Int64),
            (1, DataType::Int64),
            (5, DataType::Int64),
            (6, DataType::Utf8),
        ]);

        assert_eq!(read_stats.schema().as_ref(), &expected_stats_schema);

        let expected_stats = stats_batch(&[
            Stats {
                field_id: 0,
                null_counts: vec![0, 0],
                min_values: Arc::new(Int64Array::from(vec![1, 5])),
                max_values: Arc::new(Int64Array::from(vec![3, 6])),
            },
            Stats {
                field_id: 1,
                null_counts: vec![0, 0],
                min_values: Arc::new(Int64Array::from(vec![4, 10])),
                max_values: Arc::new(Int64Array::from(vec![6, 11])),
            },
            Stats {
                field_id: 5,
                null_counts: vec![0, 0],
                min_values: Arc::new(Int64Array::from(vec![1, 4])),
                max_values: Arc::new(Int64Array::from(vec![3, 5])),
            },
            // FIXME: these max values shouldn't be incremented
            // https://github.com/lancedb/lance/issues/1517
            Stats {
                field_id: 6,
                null_counts: vec![0, 0],
                min_values: Arc::new(StringArray::from(vec!["a", "d"])),
                max_values: Arc::new(StringArray::from(vec!["c", "e"])),
            },
        ]);

        assert_eq!(read_stats, expected_stats);
    }

    fn stats_schema(data_fields: impl IntoIterator<Item = (i32, DataType)>) -> ArrowSchema {
        let fields = data_fields
            .into_iter()
            .map(|(field_id, data_type)| {
                Arc::new(ArrowField::new(
                    format!("{}", field_id),
                    DataType::Struct(
                        vec![
                            Arc::new(ArrowField::new("null_count", DataType::Int64, false)),
                            Arc::new(ArrowField::new("min_value", data_type.clone(), true)),
                            Arc::new(ArrowField::new("max_value", data_type, true)),
                        ]
                        .into(),
                    ),
                    false,
                ))
            })
            .collect::<Vec<_>>();
        ArrowSchema::new(fields)
    }

    struct Stats {
        field_id: i32,
        null_counts: Vec<i64>,
        min_values: ArrayRef,
        max_values: ArrayRef,
    }

    fn stats_batch(stats: &[Stats]) -> RecordBatch {
        let schema = stats_schema(
            stats
                .iter()
                .map(|s| (s.field_id, s.min_values.data_type().clone())),
        );

        let columns = stats
            .iter()
            .map(|s| {
                let data_type = s.min_values.data_type().clone();
                let fields = vec![
                    Arc::new(ArrowField::new("null_count", DataType::Int64, false)),
                    Arc::new(ArrowField::new("min_value", data_type.clone(), true)),
                    Arc::new(ArrowField::new("max_value", data_type, true)),
                ];
                let arrays = vec![
                    Arc::new(Int64Array::from(s.null_counts.clone())),
                    s.min_values.clone(),
                    s.max_values.clone(),
                ];
                Arc::new(StructArray::new(fields.into(), arrays, None)) as ArrayRef
            })
            .collect();

        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
    }

    async fn read_file_as_one_batch(
        object_store: &ObjectStore,
        path: &Path,
        schema: Schema,
    ) -> RecordBatch {
        let reader = FileReader::try_new(object_store, path, schema)
            .await
            .unwrap();
        let mut batches = vec![];
        for i in 0..reader.num_batches() {
            batches.push(
                reader
                    .read_batch(i as i32, .., reader.schema())
                    .await
                    .unwrap(),
            );
        }
        let arrow_schema = Arc::new(reader.schema().into());
        concat_batches(&arrow_schema, &batches).unwrap()
    }

    /// Test encoding arrays that share the same underlying buffer.
    #[tokio::test]
    async fn test_encode_slice() {
        let store = ObjectStore::memory();
        let path = Path::from("/shared_slice");

        let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
            "i",
            DataType::Int32,
            false,
        )]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
            &store,
            &path,
            schema.clone(),
            &Default::default(),
        )
        .await
        .unwrap();

        let array = Int32Array::from_iter_values(0..1000);

        for i in (0..1000).step_by(4) {
            let data = array.slice(i, 4);
            file_writer
                .write(&[RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(data)]).unwrap()])
                .await
                .unwrap();
        }
        file_writer.finish().await.unwrap();
        assert!(store.size(&path).await.unwrap() < 2 * 8 * 1000);

        let batch = read_file_as_one_batch(&store, &path, schema).await;
        assert_eq!(batch.column_by_name("i").unwrap().as_ref(), &array);
    }

    #[tokio::test]
    async fn test_write_schema_with_holes() {
        let store = ObjectStore::memory();
        let path = Path::from("test");

        let mut field0 = Field::try_from(&ArrowField::new("a", DataType::Int32, true)).unwrap();
        field0.set_id(-1, &mut 0);
        assert_eq!(field0.id, 0);
        let mut field2 = Field::try_from(&ArrowField::new("b", DataType::Int32, true)).unwrap();
        field2.set_id(-1, &mut 2);
        assert_eq!(field2.id, 2);
        // There is a hole at field id 1.
        let schema = Schema {
            fields: vec![field0, field2],
            metadata: Default::default(),
        };

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("a", DataType::Int32, true),
            ArrowField::new("b", DataType::Int32, true),
        ]));
        let data = RecordBatch::try_new(
            arrow_schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)),
                Arc::new(Int32Array::from_iter_values(10..20)),
            ],
        )
        .unwrap();

        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
            &store,
            &path,
            schema.clone(),
            &Default::default(),
        )
        .await
        .unwrap();
        file_writer.write(&[data]).await.unwrap();
        file_writer.finish().await.unwrap();

        let page_table = file_writer.page_table;
        assert!(page_table.get(0, 0).is_some());
        assert!(page_table.get(2, 0).is_some());
    }
}