lance_datagen/
generator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, iter, marker::PhantomData, sync::Arc};
5
6use arrow::{
7    array::{ArrayData, AsArray},
8    buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer},
9    datatypes::{ArrowPrimitiveType, Int32Type, Int64Type, IntervalDayTime, IntervalMonthDayNano},
10};
11use arrow_array::{
12    make_array,
13    types::{ArrowDictionaryKeyType, BinaryType, ByteArrayType, Utf8Type},
14    Array, BinaryArray, FixedSizeBinaryArray, FixedSizeListArray, LargeListArray, ListArray,
15    NullArray, PrimitiveArray, RecordBatch, RecordBatchOptions, RecordBatchReader, StringArray,
16    StructArray,
17};
18use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef};
19use futures::{stream::BoxStream, StreamExt};
20use rand::{distributions::Uniform, Rng, RngCore, SeedableRng};
21
22use self::array::rand_with_distribution;
23
/// Newtype over `u64` expressing a number of rows.
#[derive(Clone, Copy, Debug, Default)]
pub struct RowCount(u64);

/// Newtype over `u32` expressing a number of batches.
#[derive(Clone, Copy, Debug, Default)]
pub struct BatchCount(u32);

/// Newtype over `u64` expressing a number of bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct ByteCount(u64);

/// Newtype over `u32` expressing the width of a fixed-size vector.
#[derive(Clone, Copy, Debug, Default)]
pub struct Dimension(u32);

impl From<u32> for BatchCount {
    fn from(batches: u32) -> Self {
        BatchCount(batches)
    }
}

impl From<u64> for RowCount {
    fn from(rows: u64) -> Self {
        RowCount(rows)
    }
}

impl From<u64> for ByteCount {
    fn from(bytes: u64) -> Self {
        ByteCount(bytes)
    }
}

impl From<u32> for Dimension {
    fn from(width: u32) -> Self {
        Dimension(width)
    }
}
56
57/// A trait for anything that can generate arrays of data
58pub trait ArrayGenerator: Send + Sync + std::fmt::Debug {
59    /// Generate an array of the given length
60    ///
61    /// # Arguments
62    ///
63    /// * `length` - The number of elements to generate
64    /// * `rng` - The random number generator to use
65    ///
66    /// # Returns
67    ///
68    /// An array of the given length
69    ///
70    /// Note: Not every generator needs an rng.  However, it is passed here because many do and this
71    /// lets us manage RNGs at the batch level instead of the array level.
72    fn generate(
73        &mut self,
74        length: RowCount,
75        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
76    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError>;
77    /// Get the data type of the array that this generator produces
78    ///
79    /// # Returns
80    ///
81    /// The data type of the array that this generator produces
82    fn data_type(&self) -> &DataType;
83    /// Gets metadata that should be associated with the field generated by this generator
84    fn metadata(&self) -> Option<HashMap<String, String>> {
85        None
86    }
87    /// Get the size of each element in bytes
88    ///
89    /// # Returns
90    ///
91    /// The size of each element in bytes.  Will be None if the size varies by element.
92    fn element_size_bytes(&self) -> Option<ByteCount>;
93}
94
95#[derive(Debug)]
96pub struct CycleNullGenerator {
97    generator: Box<dyn ArrayGenerator>,
98    validity: Vec<bool>,
99    idx: usize,
100}
101
102impl ArrayGenerator for CycleNullGenerator {
103    fn generate(
104        &mut self,
105        length: RowCount,
106        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
107    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
108        let array = self.generator.generate(length, rng)?;
109        let data = array.to_data();
110        let validity_itr = self
111            .validity
112            .iter()
113            .cycle()
114            .skip(self.idx)
115            .take(length.0 as usize)
116            .copied();
117        let validity_bitmap = BooleanBuffer::from_iter(validity_itr);
118
119        self.idx = (self.idx + (length.0 as usize)) % self.validity.len();
120        unsafe {
121            let new_data = ArrayData::new_unchecked(
122                data.data_type().clone(),
123                data.len(),
124                None,
125                Some(validity_bitmap.into_inner()),
126                data.offset(),
127                data.buffers().to_vec(),
128                data.child_data().into(),
129            );
130            Ok(make_array(new_data))
131        }
132    }
133
134    fn data_type(&self) -> &DataType {
135        self.generator.data_type()
136    }
137
138    fn element_size_bytes(&self) -> Option<ByteCount> {
139        self.generator.element_size_bytes()
140    }
141}
142
143#[derive(Debug)]
144pub struct MetadataGenerator {
145    generator: Box<dyn ArrayGenerator>,
146    metadata: HashMap<String, String>,
147}
148
149impl ArrayGenerator for MetadataGenerator {
150    fn generate(
151        &mut self,
152        length: RowCount,
153        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
154    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
155        self.generator.generate(length, rng)
156    }
157
158    fn metadata(&self) -> Option<HashMap<String, String>> {
159        Some(self.metadata.clone())
160    }
161
162    fn data_type(&self) -> &DataType {
163        self.generator.data_type()
164    }
165
166    fn element_size_bytes(&self) -> Option<ByteCount> {
167        self.generator.element_size_bytes()
168    }
169}
170
171#[derive(Debug)]
172pub struct NullGenerator {
173    generator: Box<dyn ArrayGenerator>,
174    null_probability: f64,
175}
176
177impl ArrayGenerator for NullGenerator {
178    fn generate(
179        &mut self,
180        length: RowCount,
181        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
182    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
183        let array = self.generator.generate(length, rng)?;
184        let data = array.to_data();
185
186        if self.null_probability < 0.0 || self.null_probability > 1.0 {
187            return Err(ArrowError::InvalidArgumentError(format!(
188                "null_probability must be between 0 and 1, got {}",
189                self.null_probability
190            )));
191        }
192
193        let (null_count, new_validity) = if self.null_probability == 0.0 {
194            if data.null_count() == 0 {
195                return Ok(array);
196            } else {
197                (0_usize, None)
198            }
199        } else if self.null_probability == 1.0 {
200            if data.null_count() == data.len() {
201                return Ok(array);
202            } else {
203                let all_nulls = BooleanBuffer::new_unset(array.len());
204                (array.len(), Some(all_nulls.into_inner()))
205            }
206        } else {
207            let array_len = array.len();
208            let num_validity_bytes = (array_len + 7) / 8;
209            let mut null_count = 0;
210            // Sampling the RNG once per bit is kind of slow so we do this to sample once
211            // per byte.  We only get 8 bits of RNG resolution but that should be good enough.
212            let threshold = (self.null_probability * u8::MAX as f64) as u8;
213            let bytes = (0..num_validity_bytes)
214                .map(|byte_idx| {
215                    let mut sample = rng.gen::<u64>();
216                    let mut byte: u8 = 0;
217                    for bit_idx in 0..8 {
218                        // We could probably overshoot and fill in extra bits with random data but
219                        // this is cleaner and that would mess up the null count
220                        byte <<= 1;
221                        let pos = byte_idx * 8 + (7 - bit_idx);
222                        if pos < array_len {
223                            let sample_piece = sample & 0xFF;
224                            let is_null = (sample_piece as u8) < threshold;
225                            byte |= (!is_null) as u8;
226                            null_count += is_null as usize;
227                        }
228                        sample >>= 8;
229                    }
230                    byte
231                })
232                .collect::<Vec<_>>();
233            let new_validity = Buffer::from_iter(bytes);
234            (null_count, Some(new_validity))
235        };
236
237        unsafe {
238            let new_data = ArrayData::new_unchecked(
239                data.data_type().clone(),
240                data.len(),
241                Some(null_count),
242                new_validity,
243                data.offset(),
244                data.buffers().to_vec(),
245                data.child_data().into(),
246            );
247            Ok(make_array(new_data))
248        }
249    }
250
251    fn metadata(&self) -> Option<HashMap<String, String>> {
252        self.generator.metadata()
253    }
254
255    fn data_type(&self) -> &DataType {
256        self.generator.data_type()
257    }
258
259    fn element_size_bytes(&self) -> Option<ByteCount> {
260        self.generator.element_size_bytes()
261    }
262}
263
264pub trait ArrayGeneratorExt {
265    /// Replaces the validity bitmap of generated arrays, inserting nulls with a given probability
266    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator>;
267    /// Replaces the validity bitmap of generated arrays with the inverse of `nulls`, cycling if needed
268    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
269    /// Replaces the validity bitmap of generated arrays with `validity`, cycling if needed
270    fn with_validity(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
271    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator>;
272}
273
274impl ArrayGeneratorExt for Box<dyn ArrayGenerator> {
275    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator> {
276        Box::new(NullGenerator {
277            generator: self,
278            null_probability,
279        })
280    }
281
282    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator> {
283        Box::new(CycleNullGenerator {
284            generator: self,
285            validity: nulls.iter().map(|v| !*v).collect(),
286            idx: 0,
287        })
288    }
289
290    fn with_validity(self, validity: &[bool]) -> Box<dyn ArrayGenerator> {
291        Box::new(CycleNullGenerator {
292            generator: self,
293            validity: validity.to_vec(),
294            idx: 0,
295        })
296    }
297
298    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator> {
299        Box::new(MetadataGenerator {
300            generator: self,
301            metadata,
302        })
303    }
304}
305
/// Iterator adaptor that yields every item of `iter` `n` times in a row.
///
/// `cur` / `count` describe a partially emitted item: `cur` is yielded `count` more times
/// before the next item is pulled from `iter`.  This lets a caller resume a repetition
/// that was cut short by a previous batch.  An `n` of 0 is treated like 1.
pub struct NTimesIter<I: Iterator>
where
    I::Item: Copy,
{
    iter: I,
    n: u32,
    cur: I::Item,
    count: u32,
}

// Note: if this is used then there is a performance hit as the
// inner loop cannot experience vectorization
//
// TODO: maybe faster to build the vec and then repeat it into
// the destination array?
impl<I: Iterator> Iterator for NTimesIter<I>
where
    I::Item: Copy,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.count == 0 {
            // Pull the next item *before* touching `count` so that hitting the end of
            // `iter` leaves the state untouched and later calls keep returning `None`
            // instead of replaying the stale `cur`.
            self.cur = self.iter.next()?;
            self.count = self.n.saturating_sub(1);
        } else {
            self.count -= 1;
        }
        Some(self.cur)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let per_item = self.n.max(1) as usize;
        // Repeats of `cur` still owed from before the first pull from `iter`.
        let pending = self.count as usize;
        let (lower, upper) = self.iter.size_hint();
        // A lower bound may saturate; an upper bound that overflows is simply unknown.
        let lower = lower.saturating_mul(per_item).saturating_add(pending);
        let upper = upper
            .and_then(|u| u.checked_mul(per_item))
            .and_then(|u| u.checked_add(pending));
        (lower, upper)
    }
}
344
345pub struct FnGen<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T>
346where
347    T: Copy + Default,
348    ArrayType: arrow_array::Array + From<Vec<T>>,
349{
350    data_type: DataType,
351    generator: F,
352    array_type: PhantomData<ArrayType>,
353    repeat: u32,
354    leftover: T,
355    leftover_count: u32,
356    element_size_bytes: Option<ByteCount>,
357}
358
359impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> std::fmt::Debug
360    for FnGen<T, ArrayType, F>
361where
362    T: Copy + Default,
363    ArrayType: arrow_array::Array + From<Vec<T>>,
364{
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        f.debug_struct("FnGen")
367            .field("data_type", &self.data_type)
368            .field("array_type", &self.array_type)
369            .field("repeat", &self.repeat)
370            .field("leftover_count", &self.leftover_count)
371            .field("element_size_bytes", &self.element_size_bytes)
372            .finish()
373    }
374}
375
376impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> FnGen<T, ArrayType, F>
377where
378    T: Copy + Default,
379    ArrayType: arrow_array::Array + From<Vec<T>>,
380{
381    fn new_known_size(
382        data_type: DataType,
383        generator: F,
384        repeat: u32,
385        element_size_bytes: ByteCount,
386    ) -> Self {
387        Self {
388            data_type,
389            generator,
390            array_type: PhantomData,
391            repeat,
392            leftover: T::default(),
393            leftover_count: 0,
394            element_size_bytes: Some(element_size_bytes),
395        }
396    }
397}
398
399impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> ArrayGenerator
400    for FnGen<T, ArrayType, F>
401where
402    T: Copy + Default + Send + Sync,
403    ArrayType: arrow_array::Array + From<Vec<T>> + 'static,
404    F: Send + Sync,
405{
406    fn generate(
407        &mut self,
408        length: RowCount,
409        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
410    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
411        let iter = (0..length.0).map(|_| (self.generator)(rng));
412        let values = if self.repeat > 1 {
413            Vec::from_iter(
414                NTimesIter {
415                    iter,
416                    n: self.repeat,
417                    cur: self.leftover,
418                    count: self.leftover_count,
419                }
420                .take(length.0 as usize),
421            )
422        } else {
423            Vec::from_iter(iter)
424        };
425        self.leftover_count = ((self.leftover_count as u64 + length.0) % self.repeat as u64) as u32;
426        self.leftover = values.last().copied().unwrap_or(T::default());
427        Ok(Arc::new(ArrayType::from(values)))
428    }
429
430    fn data_type(&self) -> &DataType {
431        &self.data_type
432    }
433
434    fn element_size_bytes(&self) -> Option<ByteCount> {
435        self.element_size_bytes
436    }
437}
438
/// Seed value for the random number generator.
#[derive(Clone, Copy, Debug)]
pub struct Seed(pub u64);

/// The seed used when none is chosen explicitly.
pub const DEFAULT_SEED: Seed = Seed(42);

impl From<u64> for Seed {
    fn from(value: u64) -> Self {
        Seed(value)
    }
}
448
449#[derive(Debug)]
450pub struct CycleVectorGenerator {
451    underlying_gen: Box<dyn ArrayGenerator>,
452    dimension: Dimension,
453    data_type: DataType,
454}
455
456impl CycleVectorGenerator {
457    pub fn new(underlying_gen: Box<dyn ArrayGenerator>, dimension: Dimension) -> Self {
458        let data_type = DataType::FixedSizeList(
459            Arc::new(Field::new("item", underlying_gen.data_type().clone(), true)),
460            dimension.0 as i32,
461        );
462        Self {
463            underlying_gen,
464            dimension,
465            data_type,
466        }
467    }
468}
469
470impl ArrayGenerator for CycleVectorGenerator {
471    fn generate(
472        &mut self,
473        length: RowCount,
474        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
475    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
476        let values = self
477            .underlying_gen
478            .generate(RowCount::from(length.0 * self.dimension.0 as u64), rng)?;
479        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
480        let values = Arc::new(values);
481
482        let array = FixedSizeListArray::try_new(field, self.dimension.0 as i32, values, None)?;
483
484        Ok(Arc::new(array))
485    }
486
487    fn data_type(&self) -> &DataType {
488        &self.data_type
489    }
490
491    fn element_size_bytes(&self) -> Option<ByteCount> {
492        self.underlying_gen
493            .element_size_bytes()
494            .map(|byte_count| ByteCount::from(byte_count.0 * self.dimension.0 as u64))
495    }
496}
497
498#[derive(Debug, Default)]
499pub struct PseudoUuidGenerator {}
500
501impl ArrayGenerator for PseudoUuidGenerator {
502    fn generate(
503        &mut self,
504        length: RowCount,
505        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
506    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
507        Ok(Arc::new(FixedSizeBinaryArray::try_from_iter(
508            (0..length.0).map(|_| {
509                let mut data = vec![0; 16];
510                rng.fill_bytes(&mut data);
511                data
512            }),
513        )?))
514    }
515
516    fn data_type(&self) -> &DataType {
517        &DataType::FixedSizeBinary(16)
518    }
519
520    fn element_size_bytes(&self) -> Option<ByteCount> {
521        Some(ByteCount::from(16))
522    }
523}
524
525#[derive(Debug, Default)]
526pub struct PseudoUuidHexGenerator {}
527
528impl ArrayGenerator for PseudoUuidHexGenerator {
529    fn generate(
530        &mut self,
531        length: RowCount,
532        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
533    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
534        let mut data = vec![0; 16 * length.0 as usize];
535        rng.fill_bytes(&mut data);
536        let data_hex = hex::encode(data);
537
538        Ok(Arc::new(StringArray::from_iter_values(
539            (0..length.0 as usize).map(|i| data_hex.get(i * 32..(i + 1) * 32).unwrap()),
540        )))
541    }
542
543    fn data_type(&self) -> &DataType {
544        &DataType::Utf8
545    }
546
547    fn element_size_bytes(&self) -> Option<ByteCount> {
548        Some(ByteCount::from(16))
549    }
550}
551
552#[derive(Debug, Default)]
553pub struct RandomBooleanGenerator {}
554
555impl ArrayGenerator for RandomBooleanGenerator {
556    fn generate(
557        &mut self,
558        length: RowCount,
559        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
560    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
561        let num_bytes = (length.0 + 7) / 8;
562        let mut bytes = vec![0; num_bytes as usize];
563        rng.fill_bytes(&mut bytes);
564        let bytes = BooleanBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
565        Ok(Arc::new(arrow_array::BooleanArray::new(bytes, None)))
566    }
567
568    fn data_type(&self) -> &DataType {
569        &DataType::Boolean
570    }
571
572    fn element_size_bytes(&self) -> Option<ByteCount> {
573        // We can't say 1/8th of a byte and 1 byte would be a pretty extreme over-count so let's leave
574        // it at None until someone needs this.  Then we can probably special case this (e.g. make a ByteCount::ONE_BIT)
575        None
576    }
577}
578
579// Instead of using the "standard distribution" and generating values there are some cases (e.g. f16 / decimal)
580// where we just generate random bytes because there is no rand support
581pub struct RandomBytesGenerator<T: ArrowPrimitiveType + Send + Sync> {
582    phantom: PhantomData<T>,
583    data_type: DataType,
584}
585
586impl<T: ArrowPrimitiveType + Send + Sync> std::fmt::Debug for RandomBytesGenerator<T> {
587    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        f.debug_struct("RandomBytesGenerator")
589            .field("data_type", &self.data_type)
590            .finish()
591    }
592}
593
594impl<T: ArrowPrimitiveType + Send + Sync> RandomBytesGenerator<T> {
595    fn new(data_type: DataType) -> Self {
596        Self {
597            phantom: Default::default(),
598            data_type,
599        }
600    }
601
602    fn byte_width() -> Result<u64, ArrowError> {
603        T::DATA_TYPE.primitive_width().ok_or_else(|| ArrowError::InvalidArgumentError(format!("Cannot generate the data type {} with the RandomBytesGenerator because it is not a fixed-width bytes type", T::DATA_TYPE))).map(|val| val as u64)
604    }
605}
606
607impl<T: ArrowPrimitiveType + Send + Sync> ArrayGenerator for RandomBytesGenerator<T> {
608    fn generate(
609        &mut self,
610        length: RowCount,
611        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
612    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
613        let num_bytes = length.0 * Self::byte_width()?;
614        let mut bytes = vec![0; num_bytes as usize];
615        rng.fill_bytes(&mut bytes);
616        let bytes = ScalarBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
617        Ok(Arc::new(
618            PrimitiveArray::<T>::new(bytes, None).with_data_type(self.data_type.clone()),
619        ))
620    }
621
622    fn data_type(&self) -> &DataType {
623        &self.data_type
624    }
625
626    fn element_size_bytes(&self) -> Option<ByteCount> {
627        Self::byte_width().map(ByteCount::from).ok()
628    }
629}
630
631// This is pretty much the same thing as RandomBinaryGenerator but we can't use that
632// because there is no ArrowPrimitiveType for FixedSizeBinary
633#[derive(Debug)]
634pub struct RandomFixedSizeBinaryGenerator {
635    data_type: DataType,
636    size: i32,
637}
638
639impl RandomFixedSizeBinaryGenerator {
640    fn new(size: i32) -> Self {
641        Self {
642            size,
643            data_type: DataType::FixedSizeBinary(size),
644        }
645    }
646}
647
648impl ArrayGenerator for RandomFixedSizeBinaryGenerator {
649    fn generate(
650        &mut self,
651        length: RowCount,
652        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
653    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
654        let num_bytes = length.0 * self.size as u64;
655        let mut bytes = vec![0; num_bytes as usize];
656        rng.fill_bytes(&mut bytes);
657        Ok(Arc::new(FixedSizeBinaryArray::new(
658            self.size,
659            Buffer::from(bytes),
660            None,
661        )))
662    }
663
664    fn data_type(&self) -> &DataType {
665        &self.data_type
666    }
667
668    fn element_size_bytes(&self) -> Option<ByteCount> {
669        Some(ByteCount::from(self.size as u64))
670    }
671}
672
673#[derive(Debug)]
674pub struct RandomIntervalGenerator {
675    unit: IntervalUnit,
676    data_type: DataType,
677}
678
679impl RandomIntervalGenerator {
680    pub fn new(unit: IntervalUnit) -> Self {
681        Self {
682            unit,
683            data_type: DataType::Interval(unit),
684        }
685    }
686}
687
688impl ArrayGenerator for RandomIntervalGenerator {
689    fn generate(
690        &mut self,
691        length: RowCount,
692        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
693    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
694        match self.unit {
695            IntervalUnit::YearMonth => {
696                let months = (0..length.0).map(|_| rng.gen::<i32>()).collect::<Vec<_>>();
697                Ok(Arc::new(arrow_array::IntervalYearMonthArray::from(months)))
698            }
699            IntervalUnit::MonthDayNano => {
700                let day_time_array = (0..length.0)
701                    .map(|_| IntervalMonthDayNano::new(rng.gen(), rng.gen(), rng.gen()))
702                    .collect::<Vec<_>>();
703                Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
704                    day_time_array,
705                )))
706            }
707            IntervalUnit::DayTime => {
708                let day_time_array = (0..length.0)
709                    .map(|_| IntervalDayTime::new(rng.gen(), rng.gen()))
710                    .collect::<Vec<_>>();
711                Ok(Arc::new(arrow_array::IntervalDayTimeArray::from(
712                    day_time_array,
713                )))
714            }
715        }
716    }
717
718    fn data_type(&self) -> &DataType {
719        &self.data_type
720    }
721
722    fn element_size_bytes(&self) -> Option<ByteCount> {
723        Some(ByteCount::from(12))
724    }
725}
726#[derive(Debug)]
727pub struct RandomBinaryGenerator {
728    bytes_per_element: ByteCount,
729    scale_to_utf8: bool,
730    is_large: bool,
731    data_type: DataType,
732}
733
734impl RandomBinaryGenerator {
735    pub fn new(bytes_per_element: ByteCount, scale_to_utf8: bool, is_large: bool) -> Self {
736        Self {
737            bytes_per_element,
738            scale_to_utf8,
739            is_large,
740            data_type: match (scale_to_utf8, is_large) {
741                (false, false) => DataType::Binary,
742                (false, true) => DataType::LargeBinary,
743                (true, false) => DataType::Utf8,
744                (true, true) => DataType::LargeUtf8,
745            },
746        }
747    }
748}
749
750impl ArrayGenerator for RandomBinaryGenerator {
751    fn generate(
752        &mut self,
753        length: RowCount,
754        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
755    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
756        let mut bytes = vec![0; (self.bytes_per_element.0 * length.0) as usize];
757        rng.fill_bytes(&mut bytes);
758        if self.scale_to_utf8 {
759            // This doesn't give us the full UTF-8 range and it isn't statistically correct but
760            // it's fast and probably good enough for most cases
761            bytes = bytes.into_iter().map(|val| (val % 95) + 32).collect();
762        }
763        let bytes = Buffer::from(bytes);
764        if self.is_large {
765            let offsets = OffsetBuffer::from_lengths(
766                iter::repeat(self.bytes_per_element.0 as usize).take(length.0 as usize),
767            );
768            if self.scale_to_utf8 {
769                // This is safe because we are only using printable characters
770                unsafe {
771                    Ok(Arc::new(arrow_array::LargeStringArray::new_unchecked(
772                        offsets, bytes, None,
773                    )))
774                }
775            } else {
776                unsafe {
777                    Ok(Arc::new(arrow_array::LargeBinaryArray::new_unchecked(
778                        offsets, bytes, None,
779                    )))
780                }
781            }
782        } else {
783            let offsets = OffsetBuffer::from_lengths(
784                iter::repeat(self.bytes_per_element.0 as usize).take(length.0 as usize),
785            );
786            if self.scale_to_utf8 {
787                // This is safe because we are only using printable characters
788                unsafe {
789                    Ok(Arc::new(arrow_array::StringArray::new_unchecked(
790                        offsets, bytes, None,
791                    )))
792                }
793            } else {
794                unsafe {
795                    Ok(Arc::new(arrow_array::BinaryArray::new_unchecked(
796                        offsets, bytes, None,
797                    )))
798                }
799            }
800        }
801    }
802
803    fn data_type(&self) -> &DataType {
804        &self.data_type
805    }
806
807    fn element_size_bytes(&self) -> Option<ByteCount> {
808        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
809        Some(ByteCount::from(
810            self.bytes_per_element.0 + std::mem::size_of::<i32>() as u64,
811        ))
812    }
813}
814
815#[derive(Debug)]
816pub struct VariableRandomBinaryGenerator {
817    lengths_gen: Box<dyn ArrayGenerator>,
818    data_type: DataType,
819}
820
821impl VariableRandomBinaryGenerator {
822    pub fn new(min_bytes_per_element: ByteCount, max_bytes_per_element: ByteCount) -> Self {
823        let lengths_dist = Uniform::new_inclusive(
824            min_bytes_per_element.0 as i32,
825            max_bytes_per_element.0 as i32,
826        );
827        let lengths_gen = rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist);
828
829        Self {
830            lengths_gen,
831            data_type: DataType::Binary,
832        }
833    }
834}
835
836impl ArrayGenerator for VariableRandomBinaryGenerator {
837    fn generate(
838        &mut self,
839        length: RowCount,
840        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
841    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
842        let lengths = self.lengths_gen.generate(length, rng)?;
843        let lengths = lengths.as_primitive::<Int32Type>();
844        let total_length = lengths.values().iter().map(|i| *i as usize).sum::<usize>();
845        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
846        let mut bytes = vec![0; total_length];
847        rng.fill_bytes(&mut bytes);
848        let bytes = Buffer::from(bytes);
849        Ok(Arc::new(BinaryArray::try_new(offsets, bytes, None)?))
850    }
851
852    fn data_type(&self) -> &DataType {
853        &self.data_type
854    }
855
856    fn element_size_bytes(&self) -> Option<ByteCount> {
857        None
858    }
859}
860
861pub struct CycleBinaryGenerator<T: ByteArrayType> {
862    values: Vec<u8>,
863    lengths: Vec<usize>,
864    data_type: DataType,
865    array_type: PhantomData<T>,
866    width: Option<ByteCount>,
867    idx: usize,
868}
869
870impl<T: ByteArrayType> std::fmt::Debug for CycleBinaryGenerator<T> {
871    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
872        f.debug_struct("CycleBinaryGenerator")
873            .field("values", &self.values)
874            .field("lengths", &self.lengths)
875            .field("data_type", &self.data_type)
876            .field("width", &self.width)
877            .field("idx", &self.idx)
878            .finish()
879    }
880}
881
882impl<T: ByteArrayType> CycleBinaryGenerator<T> {
883    pub fn from_strings(values: &[&str]) -> Self {
884        if values.is_empty() {
885            panic!("Attempt to create a cycle generator with no values");
886        }
887        let lengths = values.iter().map(|s| s.len()).collect::<Vec<_>>();
888        let typical_length = lengths[0];
889        let width = if lengths.iter().all(|item| *item == typical_length) {
890            Some(ByteCount::from(
891                typical_length as u64 + std::mem::size_of::<i32>() as u64,
892            ))
893        } else {
894            None
895        };
896        let values = values
897            .iter()
898            .flat_map(|s| s.as_bytes().iter().copied())
899            .collect::<Vec<_>>();
900        Self {
901            values,
902            lengths,
903            data_type: T::DATA_TYPE,
904            array_type: PhantomData,
905            width,
906            idx: 0,
907        }
908    }
909}
910
911impl<T: ByteArrayType> ArrayGenerator for CycleBinaryGenerator<T> {
912    fn generate(
913        &mut self,
914        length: RowCount,
915        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
916    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
917        let lengths = self
918            .lengths
919            .iter()
920            .copied()
921            .cycle()
922            .skip(self.idx)
923            .take(length.0 as usize);
924        let num_bytes = lengths.clone().sum();
925        let byte_offset = self.lengths[0..self.idx].iter().sum();
926        let bytes = self
927            .values
928            .iter()
929            .cycle()
930            .skip(byte_offset)
931            .copied()
932            .take(num_bytes)
933            .collect::<Vec<_>>();
934        let bytes = Buffer::from(bytes);
935        let offsets = OffsetBuffer::from_lengths(lengths);
936        self.idx = (self.idx + length.0 as usize) % self.lengths.len();
937        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
938            offsets, bytes, None,
939        )))
940    }
941
942    fn data_type(&self) -> &DataType {
943        &self.data_type
944    }
945
946    fn element_size_bytes(&self) -> Option<ByteCount> {
947        self.width
948    }
949}
950
/// A generator of variable-width byte arrays in which every element holds the same value.
pub struct FixedBinaryGenerator<T: ByteArrayType> {
    /// The bytes repeated for every generated element.
    value: Vec<u8>,
    /// The Arrow data type reported for the generated arrays (`T::DATA_TYPE`).
    data_type: DataType,
    /// Marker tying the generator to the byte array type `T`.
    array_type: PhantomData<T>,
}
956
957impl<T: ByteArrayType> std::fmt::Debug for FixedBinaryGenerator<T> {
958    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
959        f.debug_struct("FixedBinaryGenerator")
960            .field("value", &self.value)
961            .field("data_type", &self.data_type)
962            .finish()
963    }
964}
965
966impl<T: ByteArrayType> FixedBinaryGenerator<T> {
967    pub fn new(value: Vec<u8>) -> Self {
968        Self {
969            value,
970            data_type: T::DATA_TYPE,
971            array_type: PhantomData,
972        }
973    }
974}
975
976impl<T: ByteArrayType> ArrayGenerator for FixedBinaryGenerator<T> {
977    fn generate(
978        &mut self,
979        length: RowCount,
980        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
981    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
982        let bytes = Buffer::from(Vec::from_iter(
983            self.value
984                .iter()
985                .cycle()
986                .take((length.0 * self.value.len() as u64) as usize)
987                .copied(),
988        ));
989        let offsets =
990            OffsetBuffer::from_lengths(iter::repeat(self.value.len()).take(length.0 as usize));
991        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
992            offsets, bytes, None,
993        )))
994    }
995
996    fn data_type(&self) -> &DataType {
997        &self.data_type
998    }
999
1000    fn element_size_bytes(&self) -> Option<ByteCount> {
1001        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
1002        Some(ByteCount::from(
1003            self.value.len() as u64 + std::mem::size_of::<i32>() as u64,
1004        ))
1005    }
1006}
1007
/// A generator that produces dictionary arrays by casting the output of an
/// underlying generator to a dictionary type keyed by `K`.
pub struct DictionaryGenerator<K: ArrowDictionaryKeyType> {
    /// The generator producing the values that are cast to a dictionary.
    generator: Box<dyn ArrayGenerator>,
    /// The dictionary data type (`K` keys, underlying generator's type as values).
    data_type: DataType,
    /// Marker tying the generator to the dictionary key type `K`.
    key_type: PhantomData<K>,
    /// Width, in bytes, of one dictionary key.
    key_width: u64,
}
1014
1015impl<K: ArrowDictionaryKeyType> std::fmt::Debug for DictionaryGenerator<K> {
1016    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017        f.debug_struct("DictionaryGenerator")
1018            .field("generator", &self.generator)
1019            .field("data_type", &self.data_type)
1020            .field("key_width", &self.key_width)
1021            .finish()
1022    }
1023}
1024
1025impl<K: ArrowDictionaryKeyType> DictionaryGenerator<K> {
1026    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
1027        let key_type = Box::new(K::DATA_TYPE);
1028        let key_width = key_type
1029            .primitive_width()
1030            .expect("dictionary key types should have a known width")
1031            as u64;
1032        let val_type = Box::new(generator.data_type().clone());
1033        let dict_type = DataType::Dictionary(key_type, val_type);
1034        Self {
1035            generator,
1036            data_type: dict_type,
1037            key_type: PhantomData,
1038            key_width,
1039        }
1040    }
1041}
1042
1043impl<K: ArrowDictionaryKeyType + Send + Sync> ArrayGenerator for DictionaryGenerator<K> {
1044    fn generate(
1045        &mut self,
1046        length: RowCount,
1047        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1048    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1049        let underlying = self.generator.generate(length, rng)?;
1050        arrow_cast::cast::cast(&underlying, &self.data_type)
1051    }
1052
1053    fn data_type(&self) -> &DataType {
1054        &self.data_type
1055    }
1056
1057    fn element_size_bytes(&self) -> Option<ByteCount> {
1058        self.generator
1059            .element_size_bytes()
1060            .map(|size_bytes| ByteCount::from(size_bytes.0 + self.key_width))
1061    }
1062}
1063
/// A generator of list (or large list) arrays with randomly chosen list lengths.
#[derive(Debug)]
struct RandomListGenerator {
    /// Field describing the list itself; its data type is what `data_type()` reports.
    field: Arc<Field>,
    /// Field describing the list items.
    child_field: Arc<Field>,
    /// Generator producing the flattened list items.
    items_gen: Box<dyn ArrayGenerator>,
    /// Generator producing the length of each list.
    lengths_gen: Box<dyn ArrayGenerator>,
    /// Whether to produce `LargeList` (64-bit offsets) instead of `List` (32-bit offsets).
    is_large: bool,
}
1072
1073impl RandomListGenerator {
1074    // Creates a list generator that generates random lists with lengths between 0 and 10 (inclusive)
1075    fn new(items_gen: Box<dyn ArrayGenerator>, is_large: bool) -> Self {
1076        let child_field = Arc::new(Field::new("item", items_gen.data_type().clone(), true));
1077        let list_type = if is_large {
1078            DataType::LargeList(child_field.clone())
1079        } else {
1080            DataType::List(child_field.clone())
1081        };
1082        let field = Field::new("", list_type, true);
1083        let lengths_gen = if is_large {
1084            let lengths_dist = Uniform::new_inclusive(0, 10);
1085            rand_with_distribution::<Int64Type, Uniform<i64>>(lengths_dist)
1086        } else {
1087            let lengths_dist = Uniform::new_inclusive(0, 10);
1088            rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist)
1089        };
1090        Self {
1091            field: Arc::new(field),
1092            child_field,
1093            items_gen,
1094            lengths_gen,
1095            is_large,
1096        }
1097    }
1098}
1099
1100impl ArrayGenerator for RandomListGenerator {
1101    fn generate(
1102        &mut self,
1103        length: RowCount,
1104        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1105    ) -> Result<Arc<dyn Array>, ArrowError> {
1106        let lengths = self.lengths_gen.generate(length, rng)?;
1107        if self.is_large {
1108            let lengths = lengths.as_primitive::<Int64Type>();
1109            let total_length = lengths.values().iter().sum::<i64>() as u64;
1110            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1111            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1112            Ok(Arc::new(LargeListArray::try_new(
1113                self.child_field.clone(),
1114                offsets,
1115                items,
1116                None,
1117            )?))
1118        } else {
1119            let lengths = lengths.as_primitive::<Int32Type>();
1120            let total_length = lengths.values().iter().sum::<i32>() as u64;
1121            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1122            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1123            Ok(Arc::new(ListArray::try_new(
1124                self.child_field.clone(),
1125                offsets,
1126                items,
1127                None,
1128            )?))
1129        }
1130    }
1131
1132    fn data_type(&self) -> &DataType {
1133        self.field.data_type()
1134    }
1135
1136    fn element_size_bytes(&self) -> Option<ByteCount> {
1137        None
1138    }
1139}
1140
/// A generator of `NullArray`s; it carries no state.
#[derive(Debug)]
struct NullArrayGenerator {}
1143
1144impl ArrayGenerator for NullArrayGenerator {
1145    fn generate(
1146        &mut self,
1147        length: RowCount,
1148        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1149    ) -> Result<Arc<dyn Array>, ArrowError> {
1150        Ok(Arc::new(NullArray::new(length.0 as usize)))
1151    }
1152
1153    fn data_type(&self) -> &DataType {
1154        &DataType::Null
1155    }
1156
1157    fn element_size_bytes(&self) -> Option<ByteCount> {
1158        None
1159    }
1160}
1161
/// A generator of struct arrays whose children come from independent child generators.
#[derive(Debug)]
struct RandomStructGenerator {
    /// The fields of the generated struct.
    fields: Fields,
    /// The struct data type built from `fields`.
    data_type: DataType,
    /// One generator per child column of the struct.
    child_gens: Vec<Box<dyn ArrayGenerator>>,
}
1168
1169impl RandomStructGenerator {
1170    fn new(fields: Fields, child_gens: Vec<Box<dyn ArrayGenerator>>) -> Self {
1171        let data_type = DataType::Struct(fields.clone());
1172        Self {
1173            fields,
1174            data_type,
1175            child_gens,
1176        }
1177    }
1178}
1179
1180impl ArrayGenerator for RandomStructGenerator {
1181    fn generate(
1182        &mut self,
1183        length: RowCount,
1184        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1185    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1186        if self.child_gens.is_empty() {
1187            // Have to create empty struct arrays specially to ensure they have the correct
1188            // row count
1189            let struct_arr = StructArray::new_empty_fields(length.0 as usize, None);
1190            return Ok(Arc::new(struct_arr));
1191        }
1192        let child_arrays = self
1193            .child_gens
1194            .iter_mut()
1195            .map(|gen| gen.generate(length, rng))
1196            .collect::<Result<Vec<_>, ArrowError>>()?;
1197        let struct_arr = StructArray::new(self.fields.clone(), child_arrays, None);
1198        Ok(Arc::new(struct_arr))
1199    }
1200
1201    fn data_type(&self) -> &DataType {
1202        &self.data_type
1203    }
1204
1205    fn element_size_bytes(&self) -> Option<ByteCount> {
1206        let mut sum = 0;
1207        for child_gen in &self.child_gens {
1208            sum += child_gen.element_size_bytes()?.0;
1209        }
1210        Some(ByteCount::from(sum))
1211    }
1212}
1213
/// A RecordBatchReader that generates batches of the given size from the given array generators
pub struct FixedSizeBatchGenerator {
    /// Random number generator shared by all column generators.
    rng: rand_xoshiro::Xoshiro256PlusPlus,
    /// One generator per column, in schema order.
    generators: Vec<Box<dyn ArrayGenerator>>,
    /// Number of rows in every generated batch.
    batch_size: RowCount,
    /// Number of batches still to be produced; decremented as batches are generated.
    num_batches: BatchCount,
    /// Schema of the generated batches.
    schema: SchemaRef,
}
1222
1223impl FixedSizeBatchGenerator {
1224    fn new(
1225        generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
1226        batch_size: RowCount,
1227        num_batches: BatchCount,
1228        seed: Option<Seed>,
1229        default_null_probability: Option<f64>,
1230    ) -> Self {
1231        let mut fields = Vec::with_capacity(generators.len());
1232        for (field_index, field_gen) in generators.iter().enumerate() {
1233            let (name, gen) = field_gen;
1234            let default_name = format!("field_{}", field_index);
1235            let name = name.clone().unwrap_or(default_name);
1236            let mut field = Field::new(name, gen.data_type().clone(), true);
1237            if let Some(metadata) = gen.metadata() {
1238                field = field.with_metadata(metadata);
1239            }
1240            fields.push(field);
1241        }
1242        let mut generators = generators
1243            .into_iter()
1244            .map(|(_, gen)| gen)
1245            .collect::<Vec<_>>();
1246        if let Some(null_probability) = default_null_probability {
1247            generators = generators
1248                .into_iter()
1249                .map(|gen| gen.with_random_nulls(null_probability))
1250                .collect();
1251        }
1252        let schema = Arc::new(Schema::new(fields));
1253        Self {
1254            rng: rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
1255                seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
1256            ),
1257            generators,
1258            batch_size,
1259            num_batches,
1260            schema,
1261        }
1262    }
1263
1264    fn gen_next(&mut self) -> Result<RecordBatch, ArrowError> {
1265        let mut arrays = Vec::with_capacity(self.generators.len());
1266        for gen in self.generators.iter_mut() {
1267            let arr = gen.generate(self.batch_size, &mut self.rng)?;
1268            arrays.push(arr);
1269        }
1270        self.num_batches.0 -= 1;
1271        Ok(RecordBatch::try_new_with_options(
1272            self.schema.clone(),
1273            arrays,
1274            &RecordBatchOptions::new().with_row_count(Some(self.batch_size.0 as usize)),
1275        )
1276        .unwrap())
1277    }
1278}
1279
1280impl Iterator for FixedSizeBatchGenerator {
1281    type Item = Result<RecordBatch, ArrowError>;
1282
1283    fn next(&mut self) -> Option<Self::Item> {
1284        if self.num_batches.0 == 0 {
1285            return None;
1286        }
1287        Some(self.gen_next())
1288    }
1289}
1290
1291impl RecordBatchReader for FixedSizeBatchGenerator {
1292    fn schema(&self) -> SchemaRef {
1293        self.schema.clone()
1294    }
1295}
1296
/// A builder to create a record batch reader with generated data
///
/// This type is meant to be used in a fluent builder style to define the schema and generators
/// for a record batch reader.
#[derive(Default)]
pub struct BatchGeneratorBuilder {
    /// The columns added so far: an optional name and the generator for each column.
    generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
    /// When set, the null probability applied to every column.
    default_null_probability: Option<f64>,
    /// The seed for the random number generator, if one was given.
    seed: Option<Seed>,
}
1307
/// How to pick the number of rows when a requested batch size in bytes is not
/// an exact multiple of the size of one row.
pub enum RoundingBehavior {
    /// Fail with an error if the byte size is not a multiple of the row size.
    ExactOrErr,
    /// Use one more row, so the batch is larger than requested.
    RoundUp,
    /// Drop the partial row, so the batch is smaller than requested.
    RoundDown,
}
1313
1314impl BatchGeneratorBuilder {
1315    /// Create a new BatchGeneratorBuilder with a default random seed
1316    pub fn new() -> Self {
1317        Default::default()
1318    }
1319
1320    /// Create a new BatchGeneratorBuilder with the given seed
1321    pub fn new_with_seed(seed: Seed) -> Self {
1322        Self {
1323            seed: Some(seed),
1324            ..Default::default()
1325        }
1326    }
1327
1328    /// Adds a new column to the generator
1329    ///
1330    /// See [`crate::generator::array`] for methods to create generators
1331    pub fn col(mut self, name: impl Into<String>, gen: Box<dyn ArrayGenerator>) -> Self {
1332        self.generators.push((Some(name.into()), gen));
1333        self
1334    }
1335
1336    /// Adds a new column to the generator with a generated unique name
1337    ///
1338    /// See [`crate::generator::array`] for methods to create generators
1339    pub fn anon_col(mut self, gen: Box<dyn ArrayGenerator>) -> Self {
1340        self.generators.push((None, gen));
1341        self
1342    }
1343
1344    pub fn into_batch_rows(self, batch_size: RowCount) -> Result<RecordBatch, ArrowError> {
1345        let mut reader = self.into_reader_rows(batch_size, BatchCount::from(1));
1346        reader
1347            .next()
1348            .expect("Asked for 1 batch but reader was empty")
1349    }
1350
1351    pub fn into_batch_bytes(
1352        self,
1353        batch_size: ByteCount,
1354        rounding: RoundingBehavior,
1355    ) -> Result<RecordBatch, ArrowError> {
1356        let mut reader = self.into_reader_bytes(batch_size, BatchCount::from(1), rounding)?;
1357        reader
1358            .next()
1359            .expect("Asked for 1 batch but reader was empty")
1360    }
1361
1362    /// Create a RecordBatchReader that generates batches of the given size (in rows)
1363    pub fn into_reader_rows(
1364        self,
1365        batch_size: RowCount,
1366        num_batches: BatchCount,
1367    ) -> impl RecordBatchReader {
1368        FixedSizeBatchGenerator::new(
1369            self.generators,
1370            batch_size,
1371            num_batches,
1372            self.seed,
1373            self.default_null_probability,
1374        )
1375    }
1376
1377    pub fn into_reader_stream(
1378        self,
1379        batch_size: RowCount,
1380        num_batches: BatchCount,
1381    ) -> BoxStream<'static, Result<RecordBatch, ArrowError>> {
1382        // TODO: this is pretty lazy and could be optimized
1383        let batches = self
1384            .into_reader_rows(batch_size, num_batches)
1385            .collect::<Vec<_>>();
1386        futures::stream::iter(batches).boxed()
1387    }
1388
1389    /// Create a RecordBatchReader that generates batches of the given size (in bytes)
1390    pub fn into_reader_bytes(
1391        self,
1392        batch_size_bytes: ByteCount,
1393        num_batches: BatchCount,
1394        rounding: RoundingBehavior,
1395    ) -> Result<impl RecordBatchReader, ArrowError> {
1396        let bytes_per_row = self
1397            .generators
1398            .iter()
1399            .map(|gen| gen.1.element_size_bytes().map(|byte_count| byte_count.0).ok_or(
1400                        ArrowError::NotYetImplemented("The function into_reader_bytes currently requires each array generator to have a fixed element size".to_string())
1401                )
1402            )
1403            .sum::<Result<u64, ArrowError>>()?;
1404        let mut num_rows = RowCount::from(batch_size_bytes.0 / bytes_per_row);
1405        if batch_size_bytes.0 % bytes_per_row != 0 {
1406            match rounding {
1407                RoundingBehavior::ExactOrErr => {
1408                    return Err(ArrowError::NotYetImplemented(
1409                        format!("Exact rounding requested but not possible.  Batch size requested {}, row size: {}", batch_size_bytes.0, bytes_per_row))
1410                    );
1411                }
1412                RoundingBehavior::RoundUp => {
1413                    num_rows = RowCount::from(num_rows.0 + 1);
1414                }
1415                RoundingBehavior::RoundDown => (),
1416            }
1417        }
1418        Ok(self.into_reader_rows(num_rows, num_batches))
1419    }
1420
1421    /// Set the seed for the generator
1422    pub fn with_seed(mut self, seed: Seed) -> Self {
1423        self.seed = Some(seed);
1424        self
1425    }
1426
1427    /// Adds nulls (with the given probability) to all columns
1428    pub fn with_random_nulls(&mut self, default_null_probability: f64) {
1429        self.default_null_probability = Some(default_null_probability);
1430    }
1431}
1432
/// Factory for creating a single random array
pub struct ArrayGeneratorBuilder {
    /// The generator used to produce the array.
    generator: Box<dyn ArrayGenerator>,
    /// The seed for the random number generator, if one was given.
    seed: Option<Seed>,
}
1438
1439impl ArrayGeneratorBuilder {
1440    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
1441        Self {
1442            generator,
1443            seed: None,
1444        }
1445    }
1446
1447    /// Use the given seed for the generator
1448    pub fn with_seed(mut self, seed: Seed) -> Self {
1449        self.seed = Some(seed);
1450        self
1451    }
1452
1453    /// Generate a single array with the given length
1454    pub fn into_array_rows(
1455        mut self,
1456        length: RowCount,
1457    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1458        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
1459            self.seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
1460        );
1461        self.generator.generate(length, &mut rng)
1462    }
1463}
1464
/// Number of milliseconds in one day (24 * 60 * 60 * 1000).
const MS_PER_DAY: i64 = 86400000;
1466
1467pub mod array {
1468
1469    use arrow::datatypes::{Int16Type, Int64Type, Int8Type};
1470    use arrow_array::types::{
1471        Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
1472        DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type,
1473        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
1474    };
1475    use arrow_array::{
1476        ArrowNativeTypeOp, Date32Array, Date64Array, Time32MillisecondArray, Time32SecondArray,
1477        Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
1478        TimestampNanosecondArray, TimestampSecondArray,
1479    };
1480    use arrow_schema::{IntervalUnit, TimeUnit};
1481    use chrono::Utc;
1482    use rand::prelude::Distribution;
1483
1484    use super::*;
1485
1486    /// Create a generator of vectors by continuously calling the given generator
1487    ///
1488    /// For example, given a step generator and a dimension of 3 this will generate vectors like
1489    /// [0, 1, 2], [3, 4, 5], [6, 7, 8], ...
1490    pub fn cycle_vec(
1491        generator: Box<dyn ArrayGenerator>,
1492        dimension: Dimension,
1493    ) -> Box<dyn ArrayGenerator> {
1494        Box::new(CycleVectorGenerator::new(generator, dimension))
1495    }
1496
1497    /// Create a generator from a vector of values
1498    ///
1499    /// If more rows are requested than the length of values then it will restart
1500    /// from the beginning of the vector.
1501    pub fn cycle<DataType>(values: Vec<DataType::Native>) -> Box<dyn ArrayGenerator>
1502    where
1503        DataType::Native: Copy + 'static,
1504        DataType: ArrowPrimitiveType,
1505        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1506    {
1507        let mut values_idx = 0;
1508        Box::new(
1509            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1510                DataType::DATA_TYPE,
1511                move |_| {
1512                    let y = values[values_idx];
1513                    values_idx = (values_idx + 1) % values.len();
1514                    y
1515                },
1516                1,
1517                DataType::DATA_TYPE
1518                    .primitive_width()
1519                    .map(|width| ByteCount::from(width as u64))
1520                    .expect("Primitive types should have a fixed width"),
1521            ),
1522        )
1523    }
1524
1525    /// Create a generator that starts at 0 and increments by 1 for each element
1526    pub fn step<DataType>() -> Box<dyn ArrayGenerator>
1527    where
1528        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
1529        DataType: ArrowPrimitiveType,
1530        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1531    {
1532        let mut x = DataType::Native::default();
1533        Box::new(
1534            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1535                DataType::DATA_TYPE,
1536                move |_| {
1537                    let y = x;
1538                    x += DataType::Native::ONE;
1539                    y
1540                },
1541                1,
1542                DataType::DATA_TYPE
1543                    .primitive_width()
1544                    .map(|width| ByteCount::from(width as u64))
1545                    .expect("Primitive types should have a fixed width"),
1546            ),
1547        )
1548    }
1549
1550    pub fn blob() -> Box<dyn ArrayGenerator> {
1551        let mut blob_meta = HashMap::new();
1552        blob_meta.insert("lance-encoding:blob".to_string(), "true".to_string());
1553        rand_fixedbin(ByteCount::from(4 * 1024 * 1024), true).with_metadata(blob_meta)
1554    }
1555
1556    /// Create a generator that starts at a given value and increments by a given step for each element
1557    pub fn step_custom<DataType>(
1558        start: DataType::Native,
1559        step: DataType::Native,
1560    ) -> Box<dyn ArrayGenerator>
1561    where
1562        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
1563        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1564        DataType: ArrowPrimitiveType,
1565    {
1566        let mut x = start;
1567        Box::new(
1568            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1569                DataType::DATA_TYPE,
1570                move |_| {
1571                    let y = x;
1572                    x += step;
1573                    y
1574                },
1575                1,
1576                DataType::DATA_TYPE
1577                    .primitive_width()
1578                    .map(|width| ByteCount::from(width as u64))
1579                    .expect("Primitive types should have a fixed width"),
1580            ),
1581        )
1582    }
1583
1584    /// Create a generator that fills each element with the given primitive value
1585    pub fn fill<DataType>(value: DataType::Native) -> Box<dyn ArrayGenerator>
1586    where
1587        DataType::Native: Copy + 'static,
1588        DataType: ArrowPrimitiveType,
1589        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1590    {
1591        Box::new(
1592            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1593                DataType::DATA_TYPE,
1594                move |_| value,
1595                1,
1596                DataType::DATA_TYPE
1597                    .primitive_width()
1598                    .map(|width| ByteCount::from(width as u64))
1599                    .expect("Primitive types should have a fixed width"),
1600            ),
1601        )
1602    }
1603
1604    /// Create a generator that fills each element with the given binary value
1605    pub fn fill_varbin(value: Vec<u8>) -> Box<dyn ArrayGenerator> {
1606        Box::new(FixedBinaryGenerator::<BinaryType>::new(value))
1607    }
1608
1609    /// Create a generator that fills each element with the given string value
1610    pub fn fill_utf8(value: String) -> Box<dyn ArrayGenerator> {
1611        Box::new(FixedBinaryGenerator::<Utf8Type>::new(value.into_bytes()))
1612    }
1613
1614    pub fn cycle_utf8_literals(values: &[&'static str]) -> Box<dyn ArrayGenerator> {
1615        Box::new(CycleBinaryGenerator::<Utf8Type>::from_strings(values))
1616    }
1617
1618    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
1619    pub fn rand<DataType>() -> Box<dyn ArrayGenerator>
1620    where
1621        DataType::Native: Copy + 'static,
1622        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1623        DataType: ArrowPrimitiveType,
1624        rand::distributions::Standard: rand::distributions::Distribution<DataType::Native>,
1625    {
1626        Box::new(
1627            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1628                DataType::DATA_TYPE,
1629                move |rng| rng.gen(),
1630                1,
1631                DataType::DATA_TYPE
1632                    .primitive_width()
1633                    .map(|width| ByteCount::from(width as u64))
1634                    .expect("Primitive types should have a fixed width"),
1635            ),
1636        )
1637    }
1638
1639    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
1640    pub fn rand_with_distribution<
1641        DataType,
1642        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
1643    >(
1644        dist: Dist,
1645    ) -> Box<dyn ArrayGenerator>
1646    where
1647        DataType::Native: Copy + 'static,
1648        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1649        DataType: ArrowPrimitiveType,
1650    {
1651        Box::new(
1652            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1653                DataType::DATA_TYPE,
1654                move |rng| rng.sample(dist.clone()),
1655                1,
1656                DataType::DATA_TYPE
1657                    .primitive_width()
1658                    .map(|width| ByteCount::from(width as u64))
1659                    .expect("Primitive types should have a fixed width"),
1660            ),
1661        )
1662    }
1663
1664    /// Create a generator of 1d vectors (of a primitive type) consisting of randomly sampled primitive values
1665    pub fn rand_vec<DataType>(dimension: Dimension) -> Box<dyn ArrayGenerator>
1666    where
1667        DataType::Native: Copy + 'static,
1668        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1669        DataType: ArrowPrimitiveType,
1670        rand::distributions::Standard: rand::distributions::Distribution<DataType::Native>,
1671    {
1672        let underlying = rand::<DataType>();
1673        cycle_vec(underlying, dimension)
1674    }
1675
1676    /// Create a generator of randomly sampled time32 values covering the entire
1677    /// range of 1 day
1678    pub fn rand_time32(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
1679        let start = 0;
1680        let end = match resolution {
1681            TimeUnit::Second => 86_400,
1682            TimeUnit::Millisecond => 86_400_000,
1683            _ => panic!(),
1684        };
1685
1686        let data_type = DataType::Time32(*resolution);
1687        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
1688        let dist = Uniform::new(start, end);
1689        let sample_fn = move |rng: &mut _| dist.sample(rng);
1690
1691        match resolution {
1692            TimeUnit::Second => Box::new(FnGen::<i32, Time32SecondArray, _>::new_known_size(
1693                data_type, sample_fn, 1, size,
1694            )),
1695            TimeUnit::Millisecond => {
1696                Box::new(FnGen::<i32, Time32MillisecondArray, _>::new_known_size(
1697                    data_type, sample_fn, 1, size,
1698                ))
1699            }
1700            _ => panic!(),
1701        }
1702    }
1703
1704    /// Create a generator of randomly sampled time64 values covering the entire
1705    /// range of 1 day
1706    pub fn rand_time64(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
1707        let start = 0_i64;
1708        let end: i64 = match resolution {
1709            TimeUnit::Microsecond => 86_400_000,
1710            TimeUnit::Nanosecond => 86_400_000_000,
1711            _ => panic!(),
1712        };
1713
1714        let data_type = DataType::Time64(*resolution);
1715        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
1716        let dist = Uniform::new(start, end);
1717        let sample_fn = move |rng: &mut _| dist.sample(rng);
1718
1719        match resolution {
1720            TimeUnit::Microsecond => {
1721                Box::new(FnGen::<i64, Time64MicrosecondArray, _>::new_known_size(
1722                    data_type, sample_fn, 1, size,
1723                ))
1724            }
1725            TimeUnit::Nanosecond => {
1726                Box::new(FnGen::<i64, Time64NanosecondArray, _>::new_known_size(
1727                    data_type, sample_fn, 1, size,
1728                ))
1729            }
1730            _ => panic!(),
1731        }
1732    }
1733
1734    /// Create a generator of random UUIDs, stored as fixed size binary values
1735    ///
1736    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
1737    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
1738    /// for speed.
1739    pub fn rand_pseudo_uuid() -> Box<dyn ArrayGenerator> {
1740        Box::<PseudoUuidGenerator>::default()
1741    }
1742
1743    /// Create a generator of random UUIDs, stored as 32-character strings (hex encoding
1744    /// of the 16-byte binary value)
1745    ///
1746    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
1747    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
1748    /// for speed.
1749    pub fn rand_pseudo_uuid_hex() -> Box<dyn ArrayGenerator> {
1750        Box::<PseudoUuidHexGenerator>::default()
1751    }
1752
1753    pub fn rand_primitive<T: ArrowPrimitiveType + Send + Sync>(
1754        data_type: DataType,
1755    ) -> Box<dyn ArrayGenerator> {
1756        Box::new(RandomBytesGenerator::<T>::new(data_type))
1757    }
1758
1759    pub fn rand_fsb(size: i32) -> Box<dyn ArrayGenerator> {
1760        Box::new(RandomFixedSizeBinaryGenerator::new(size))
1761    }
1762
1763    pub fn rand_interval(unit: IntervalUnit) -> Box<dyn ArrayGenerator> {
1764        Box::new(RandomIntervalGenerator::new(unit))
1765    }
1766
1767    /// Create a generator of randomly sampled date32 values
1768    ///
1769    /// Instead of sampling the entire range, all values will be drawn from the last year as this
1770    /// is a more common use pattern
1771    pub fn rand_date32() -> Box<dyn ArrayGenerator> {
1772        let now = chrono::Utc::now();
1773        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try days");
1774        rand_date32_in_range(one_year_ago, now)
1775    }
1776
1777    /// Create a generator of randomly sampled date32 values in the given range
1778    pub fn rand_date32_in_range(
1779        start: chrono::DateTime<Utc>,
1780        end: chrono::DateTime<Utc>,
1781    ) -> Box<dyn ArrayGenerator> {
1782        let data_type = DataType::Date32;
1783        let end_ms = end.timestamp_millis();
1784        let end_days = (end_ms / MS_PER_DAY) as i32;
1785        let start_ms = start.timestamp_millis();
1786        let start_days = (start_ms / MS_PER_DAY) as i32;
1787        let dist = Uniform::new(start_days, end_days);
1788
1789        Box::new(FnGen::<i32, Date32Array, _>::new_known_size(
1790            data_type,
1791            move |rng| dist.sample(rng),
1792            1,
1793            DataType::Date32
1794                .primitive_width()
1795                .map(|width| ByteCount::from(width as u64))
1796                .expect("Date32 should have a fixed width"),
1797        ))
1798    }
1799
1800    /// Create a generator of randomly sampled date64 values
1801    ///
1802    /// Instead of sampling the entire range, all values will be drawn from the last year as this
1803    /// is a more common use pattern
1804    pub fn rand_date64() -> Box<dyn ArrayGenerator> {
1805        let now = chrono::Utc::now();
1806        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try_days");
1807        rand_date64_in_range(one_year_ago, now)
1808    }
1809
1810    /// Create a generator of randomly sampled timestamp values in the given range
1811    ///
1812    /// Currently just samples the entire range of u64 values and casts to timestamp
1813    pub fn rand_timestamp_in_range(
1814        start: chrono::DateTime<Utc>,
1815        end: chrono::DateTime<Utc>,
1816        data_type: &DataType,
1817    ) -> Box<dyn ArrayGenerator> {
1818        let end_ms = end.timestamp_millis();
1819        let start_ms = start.timestamp_millis();
1820        let (start_ticks, end_ticks) = match data_type {
1821            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1822                (start_ms * 1000 * 1000, end_ms * 1000 * 1000)
1823            }
1824            DataType::Timestamp(TimeUnit::Microsecond, _) => (start_ms * 1000, end_ms * 1000),
1825            DataType::Timestamp(TimeUnit::Millisecond, _) => (start_ms, end_ms),
1826            DataType::Timestamp(TimeUnit::Second, _) => (start.timestamp(), end.timestamp()),
1827            _ => panic!(),
1828        };
1829        let dist = Uniform::new(start_ticks, end_ticks);
1830
1831        let data_type = data_type.clone();
1832        let sample_fn = move |rng: &mut _| (dist.sample(rng));
1833        let width = data_type
1834            .primitive_width()
1835            .map(|width| ByteCount::from(width as u64))
1836            .unwrap();
1837
1838        match data_type {
1839            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1840                Box::new(FnGen::<i64, TimestampNanosecondArray, _>::new_known_size(
1841                    data_type, sample_fn, 1, width,
1842                ))
1843            }
1844            DataType::Timestamp(TimeUnit::Microsecond, _) => {
1845                Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
1846                    data_type, sample_fn, 1, width,
1847                ))
1848            }
1849            DataType::Timestamp(TimeUnit::Millisecond, _) => {
1850                Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
1851                    data_type, sample_fn, 1, width,
1852                ))
1853            }
1854            DataType::Timestamp(TimeUnit::Second, _) => {
1855                Box::new(FnGen::<i64, TimestampSecondArray, _>::new_known_size(
1856                    data_type, sample_fn, 1, width,
1857                ))
1858            }
1859            _ => panic!(),
1860        }
1861    }
1862
1863    pub fn rand_timestamp(data_type: &DataType) -> Box<dyn ArrayGenerator> {
1864        let now = chrono::Utc::now();
1865        let one_year_ago = now - chrono::Duration::try_days(365).unwrap();
1866        rand_timestamp_in_range(one_year_ago, now, data_type)
1867    }
1868
1869    /// Create a generator of randomly sampled date64 values
1870    ///
1871    /// Instead of sampling the entire range, all values will be drawn from the last year as this
1872    /// is a more common use pattern
1873    pub fn rand_date64_in_range(
1874        start: chrono::DateTime<Utc>,
1875        end: chrono::DateTime<Utc>,
1876    ) -> Box<dyn ArrayGenerator> {
1877        let data_type = DataType::Date64;
1878        let end_ms = end.timestamp_millis();
1879        let end_days = end_ms / MS_PER_DAY;
1880        let start_ms = start.timestamp_millis();
1881        let start_days = start_ms / MS_PER_DAY;
1882        let dist = Uniform::new(start_days, end_days);
1883
1884        Box::new(FnGen::<i64, Date64Array, _>::new_known_size(
1885            data_type,
1886            move |rng| (dist.sample(rng)) * MS_PER_DAY,
1887            1,
1888            DataType::Date64
1889                .primitive_width()
1890                .map(|width| ByteCount::from(width as u64))
1891                .expect("Date64 should have a fixed width"),
1892        ))
1893    }
1894
1895    /// Create a generator of random binary values where each value has a fixed number of bytes
1896    pub fn rand_fixedbin(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
1897        Box::new(RandomBinaryGenerator::new(
1898            bytes_per_element,
1899            false,
1900            is_large,
1901        ))
1902    }
1903
1904    /// Create a generator of random binary values where each value has a variable number of bytes
1905    ///
1906    /// The number of bytes per element will be randomly sampled from the given (inclusive) range
1907    pub fn rand_varbin(
1908        min_bytes_per_element: ByteCount,
1909        max_bytes_per_element: ByteCount,
1910    ) -> Box<dyn ArrayGenerator> {
1911        Box::new(VariableRandomBinaryGenerator::new(
1912            min_bytes_per_element,
1913            max_bytes_per_element,
1914        ))
1915    }
1916
1917    /// Create a generator of random strings
1918    ///
1919    /// All strings will consist entirely of printable ASCII characters
1920    pub fn rand_utf8(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
1921        Box::new(RandomBinaryGenerator::new(
1922            bytes_per_element,
1923            true,
1924            is_large,
1925        ))
1926    }
1927
1928    /// Create a random generator of boolean values
1929    pub fn rand_boolean() -> Box<dyn ArrayGenerator> {
1930        Box::<RandomBooleanGenerator>::default()
1931    }
1932
1933    pub fn rand_list(item_type: &DataType, is_large: bool) -> Box<dyn ArrayGenerator> {
1934        let child_gen = rand_type(item_type);
1935        Box::new(RandomListGenerator::new(child_gen, is_large))
1936    }
1937
1938    pub fn rand_list_any(
1939        item_gen: Box<dyn ArrayGenerator>,
1940        is_large: bool,
1941    ) -> Box<dyn ArrayGenerator> {
1942        Box::new(RandomListGenerator::new(item_gen, is_large))
1943    }
1944
1945    pub fn rand_struct(fields: Fields) -> Box<dyn ArrayGenerator> {
1946        let child_gens = fields
1947            .iter()
1948            .map(|f| rand_type(f.data_type()))
1949            .collect::<Vec<_>>();
1950        Box::new(RandomStructGenerator::new(fields, child_gens))
1951    }
1952
1953    pub fn null_type() -> Box<dyn ArrayGenerator> {
1954        Box::new(NullArrayGenerator {})
1955    }
1956
1957    /// Create a generator of random values
1958    pub fn rand_type(data_type: &DataType) -> Box<dyn ArrayGenerator> {
1959        match data_type {
1960            DataType::Boolean => rand_boolean(),
1961            DataType::Int8 => rand::<Int8Type>(),
1962            DataType::Int16 => rand::<Int16Type>(),
1963            DataType::Int32 => rand::<Int32Type>(),
1964            DataType::Int64 => rand::<Int64Type>(),
1965            DataType::UInt8 => rand::<UInt8Type>(),
1966            DataType::UInt16 => rand::<UInt16Type>(),
1967            DataType::UInt32 => rand::<UInt32Type>(),
1968            DataType::UInt64 => rand::<UInt64Type>(),
1969            DataType::Float16 => rand_primitive::<Float16Type>(data_type.clone()),
1970            DataType::Float32 => rand::<Float32Type>(),
1971            DataType::Float64 => rand::<Float64Type>(),
1972            DataType::Decimal128(_, _) => rand_primitive::<Decimal128Type>(data_type.clone()),
1973            DataType::Decimal256(_, _) => rand_primitive::<Decimal256Type>(data_type.clone()),
1974            DataType::Utf8 => rand_utf8(ByteCount::from(12), false),
1975            DataType::LargeUtf8 => rand_utf8(ByteCount::from(12), true),
1976            DataType::Binary => rand_fixedbin(ByteCount::from(12), false),
1977            DataType::LargeBinary => rand_fixedbin(ByteCount::from(12), true),
1978            DataType::Dictionary(key_type, value_type) => {
1979                dict_type(rand_type(value_type), key_type)
1980            }
1981            DataType::FixedSizeList(child, dimension) => cycle_vec(
1982                rand_type(child.data_type()),
1983                Dimension::from(*dimension as u32),
1984            ),
1985            DataType::FixedSizeBinary(size) => rand_fsb(*size),
1986            DataType::List(child) => rand_list(child.data_type(), false),
1987            DataType::LargeList(child) => rand_list(child.data_type(), true),
1988            DataType::Duration(unit) => match unit {
1989                TimeUnit::Second => rand::<DurationSecondType>(),
1990                TimeUnit::Millisecond => rand::<DurationMillisecondType>(),
1991                TimeUnit::Microsecond => rand::<DurationMicrosecondType>(),
1992                TimeUnit::Nanosecond => rand::<DurationNanosecondType>(),
1993            },
1994            DataType::Interval(unit) => rand_interval(*unit),
1995            DataType::Date32 => rand_date32(),
1996            DataType::Date64 => rand_date64(),
1997            DataType::Time32(resolution) => rand_time32(resolution),
1998            DataType::Time64(resolution) => rand_time64(resolution),
1999            DataType::Timestamp(_, _) => rand_timestamp(data_type),
2000            DataType::Struct(fields) => rand_struct(fields.clone()),
2001            DataType::Null => null_type(),
2002            _ => unimplemented!("random generation of {}", data_type),
2003        }
2004    }
2005
2006    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2007    ///
2008    /// Note that this may not be very realistic if the underlying generator is something like a random
2009    /// generator since most of the underlying values will be unique and the common case for dictionary
2010    /// encoding is when there is a small set of possible values.
2011    pub fn dict<K: ArrowDictionaryKeyType + Send + Sync>(
2012        generator: Box<dyn ArrayGenerator>,
2013    ) -> Box<dyn ArrayGenerator> {
2014        Box::new(DictionaryGenerator::<K>::new(generator))
2015    }
2016
2017    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2018    pub fn dict_type(
2019        generator: Box<dyn ArrayGenerator>,
2020        key_type: &DataType,
2021    ) -> Box<dyn ArrayGenerator> {
2022        match key_type {
2023            DataType::Int8 => dict::<Int8Type>(generator),
2024            DataType::Int16 => dict::<Int16Type>(generator),
2025            DataType::Int32 => dict::<Int32Type>(generator),
2026            DataType::Int64 => dict::<Int64Type>(generator),
2027            DataType::UInt8 => dict::<UInt8Type>(generator),
2028            DataType::UInt16 => dict::<UInt16Type>(generator),
2029            DataType::UInt32 => dict::<UInt32Type>(generator),
2030            DataType::UInt64 => dict::<UInt64Type>(generator),
2031            _ => unimplemented!(),
2032        }
2033    }
2034}
2035
2036/// Create a BatchGeneratorBuilder to start generating batch data
2037pub fn gen() -> BatchGeneratorBuilder {
2038    BatchGeneratorBuilder::default()
2039}
2040
2041/// Create an ArrayGeneratorBuilder to start generating array data
2042pub fn gen_array(gen: Box<dyn ArrayGenerator>) -> ArrayGeneratorBuilder {
2043    ArrayGeneratorBuilder::new(gen)
2044}
2045
2046/// Create a BatchGeneratorBuilder with the given schema
2047///
2048/// You can add more columns or convert this into a reader immediately
2049pub fn rand(schema: &Schema) -> BatchGeneratorBuilder {
2050    let mut builder = BatchGeneratorBuilder::default();
2051    for field in schema.fields() {
2052        builder = builder.col(field.name(), array::rand_type(field.data_type()));
2053    }
2054    builder
2055}
2056
#[cfg(test)]
mod tests {
    //! Unit tests for the array generators and the schema-driven batch builder.
    //!
    //! Several tests compare against exact values and therefore depend on the
    //! generators being driven by an RNG seeded with `DEFAULT_SEED`, and on the
    //! order in which the generators consume that RNG.

    use arrow::datatypes::{Float32Type, Int16Type, Int8Type, UInt32Type};
    use arrow_array::{BooleanArray, Float32Array, Int16Array, Int32Array, Int8Array, UInt32Array};

    use super::*;

    /// Step generators continue where the previous batch stopped
    #[test]
    fn test_step() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::step::<Int32Type>();
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([0, 1, 2, 3, 4])
        );
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([5, 6, 7, 8, 9])
        );

        let mut gen = array::step::<Int8Type>();
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int8Array::from_iter([0, 1, 2])
        );

        let mut gen = array::step::<Float32Type>();
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Float32Array::from_iter([0.0, 1.0, 2.0])
        );

        // Custom start (4) and step (8)
        let mut gen = array::step_custom::<Int16Type>(4, 8);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int16Array::from_iter([4, 12, 20])
        );
        assert_eq!(
            *gen.generate(RowCount::from(2), &mut rng).unwrap(),
            Int16Array::from_iter([28, 36])
        );
    }

    /// Cycle generators wrap around and keep their position across batches
    #[test]
    fn test_cycle() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::cycle::<Int32Type>(vec![1, 2, 3]);
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([1, 2, 3, 1, 2])
        );

        let mut gen = array::cycle_utf8_literals(&["abc", "def", "xyz"]);
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            StringArray::from_iter_values(["abc", "def", "xyz", "abc", "def"])
        );
        // The second batch resumes the cycle instead of restarting it
        assert_eq!(
            *gen.generate(RowCount::from(1), &mut rng).unwrap(),
            StringArray::from_iter_values(["xyz"])
        );
    }

    /// Fill generators repeat a single value in every batch
    #[test]
    fn test_fill() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::fill::<Int32Type>(42);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([42, 42, 42])
        );
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([42, 42, 42])
        );

        let mut gen = array::fill_varbin(vec![0, 1, 2]);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                "\x00\x01\x02",
                "\x00\x01\x02",
                "\x00\x01\x02"
            ])
        );

        let mut gen = array::fill_utf8("xyz".to_string());
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values(["xyz", "xyz", "xyz"])
        );
    }

    /// Random generators produce the expected values for the default seed
    #[test]
    fn test_rng() {
        // Note: these tests are heavily dependent on the default seed.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand::<Int32Type>();
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([-797553329, 1369325940, -69174021])
        );

        let mut gen = array::rand_fixedbin(ByteCount::from(3), false);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                [184, 53, 216],
                [12, 96, 159],
                [125, 179, 56]
            ])
        );

        let mut gen = array::rand_utf8(ByteCount::from(3), false);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values([">@p", "n `", "NWa"])
        );

        // Date generators depend on the current time so only the type is checked
        let mut gen = array::rand_date32();
        let days_32 = gen.generate(RowCount::from(3), &mut rng).unwrap();
        assert_eq!(days_32.data_type(), &DataType::Date32);

        let mut gen = array::rand_date64();
        let days_64 = gen.generate(RowCount::from(3), &mut rng).unwrap();
        assert_eq!(days_64.data_type(), &DataType::Date64);

        let mut gen = array::rand_boolean();
        let bools = gen.generate(RowCount::from(1024), &mut rng).unwrap();
        assert_eq!(bools.data_type(), &DataType::Boolean);
        let bools = bools.as_any().downcast_ref::<BooleanArray>().unwrap();
        // Sanity check to ensure we're getting at least some rng
        assert!(bools.false_count() > 100);
        assert!(bools.true_count() > 100);

        let mut gen = array::rand_varbin(ByteCount::from(2), ByteCount::from(4));
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                vec![56, 122, 157, 34],
                vec![58, 51],
                vec![41, 184, 125]
            ])
        );
    }

    /// Random lists include empty lists and never become large
    #[test]
    fn test_rng_list() {
        // Note: these tests are heavily dependent on the default seed.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand_list(&DataType::Int32, false);
        let arr = gen.generate(RowCount::from(100), &mut rng).unwrap();
        // Make sure we can generate empty lists (note, test is dependent on seed)
        let arr = arr.as_list::<i32>();
        assert!(arr.iter().any(|l| l.unwrap().is_empty()));
        // Shouldn't generate any giant lists (don't kill performance in normal datagen)
        //
        // This has to hold for every list.  Checking with `any` would already be
        // satisfied by the empty list found above and so would verify nothing.
        assert!(arr.iter().all(|l| l.unwrap().len() < 11));
    }

    /// Random integers are spread reasonably evenly over their range
    #[test]
    fn test_rng_distribution() {
        // Sanity test to make sure our RNG is giving us well distributed values.
        // We generate some 4-byte integers, histogram them into 256 buckets (keyed
        // by the most significant byte), and make sure each bucket has a good # of
        // values
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand::<UInt32Type>();
        for _ in 0..10 {
            let arr = gen.generate(RowCount::from(10000), &mut rng).unwrap();
            let int_arr = arr.as_any().downcast_ref::<UInt32Array>().unwrap();
            let mut buckets = vec![0_u32; 256];
            for val in int_arr.values() {
                buckets[(*val >> 24) as usize] += 1;
            }
            for bucket in buckets {
                // Perfectly even distribution would have 10000 / 256 values (~40) per bucket
                // We test for 15 which should be "good enough" and statistically unlikely to fail
                assert!(bucket > 15);
            }
        }
    }

    /// Null injection, both random and pattern based
    #[test]
    fn test_nulls() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand::<Int32Type>().with_random_nulls(0.3);

        let arr = gen.generate(RowCount::from(1000), &mut rng).unwrap();

        // This assert depends on the default seed
        assert_eq!(arr.null_count(), 297);

        for len in 0..100 {
            let arr = gen.generate(RowCount::from(len), &mut rng).unwrap();
            // Make sure the null count we came up with matches the actual # of unset bits
            assert_eq!(
                arr.null_count(),
                arr.nulls()
                    .map(|nulls| (len as usize)
                        - nulls.buffer().count_set_bits_offset(0, len as usize))
                    .unwrap_or(0)
            );
        }

        // A null probability of 0.0 never produces nulls
        let mut gen = array::rand::<Int32Type>().with_random_nulls(0.0);
        let arr = gen.generate(RowCount::from(10), &mut rng).unwrap();

        assert_eq!(arr.null_count(), 0);

        // A null probability of 1.0 produces only nulls
        let mut gen = array::rand::<Int32Type>().with_random_nulls(1.0);
        let arr = gen.generate(RowCount::from(10), &mut rng).unwrap();

        assert_eq!(arr.null_count(), 10);
        assert!((0..10).all(|idx| arr.is_null(idx)));

        // The null pattern repeats: every third value is null
        let mut gen = array::rand::<Int32Type>().with_nulls(&[false, false, true]);
        let arr = gen.generate(RowCount::from(7), &mut rng).unwrap();
        assert!((0..2).all(|idx| arr.is_valid(idx)));
        assert!(arr.is_null(2));
        assert!((3..5).all(|idx| arr.is_valid(idx)));
        assert!(arr.is_null(5));
        assert!(arr.is_valid(6));
    }

    /// A schema-driven builder yields batches matching the schema and the byte budget
    #[test]
    fn test_rand_schema() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
        ]);
        let rbr = rand(&schema)
            .into_reader_bytes(
                ByteCount::from(1024 * 1024),
                BatchCount::from(8),
                RoundingBehavior::ExactOrErr,
            )
            .unwrap();
        assert_eq!(*rbr.schema(), schema);

        let batches = rbr.map(|val| val.unwrap()).collect::<Vec<_>>();
        assert_eq!(batches.len(), 8);

        for batch in batches {
            // 1 MiB per batch at 32 bytes per row
            assert_eq!(batch.num_rows(), 1024 * 1024 / 32);
            assert_eq!(batch.num_columns(), 5);
        }
    }
}