1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//! Chained samplers for generating arbitrary `Box<dyn Array>` arrow arrays.

use std::ops::Range;

use arrow2::{array::Array, datatypes::DataType};
use sample_std::{valid_f32, valid_f64, Chained, Sample};

use crate::{
    datatypes::DataTypeSampler,
    list::ListSampler,
    primitive::{arbitrary_boxed_primitive, boxed_primitive},
    struct_::StructSampler,
};

/// Type-erased, thread-safe sampler whose output is a boxed arrow2 array.
pub type ArraySampler = Box<dyn Sample<Output = Box<dyn Array>> + Sync + Send>;

/// Type-erased, thread-safe sampler whose output pairs a sampled [`DataType`] with an
/// array sampled for it (see `ArbitraryArray::arbitrary_array`).
pub type ChainedArraySampler =
    Box<dyn Sample<Output = Chained<DataType, Box<dyn Array>>> + Sync + Send>;

/// Configuration for sampling arbitrary arrow2 arrays.
///
/// `N` and `V` are samplers; the `impl` block below requires `N` to yield `String`s and
/// `V` to yield `bool`s.
#[derive(Debug, Clone)]
pub struct ArbitraryArray<N, V> {
    /// Name sampler. It is not read by the samplers in this module; presumably it is
    /// consumed by callers (e.g. for sampling field names) — TODO confirm.
    pub names: N,
    /// Range that is scaled by `len` each time a `List` level is entered in
    /// `sampler_from_data_type`. It is not otherwise read in this module.
    pub branch: Range<usize>,
    /// Length range handed to the primitive and list samplers.
    pub len: Range<usize>,
    /// Boolean sampler forwarded to child samplers as the null source when
    /// `is_nullable` is set.
    pub null: V,
    /// Whether arrays sampled at this level may contain nulls (gates `null`).
    pub is_nullable: bool,
}

impl<N, V> ArbitraryArray<N, V>
where
    N: Sample<Output = String> + Send + Sync + Clone + 'static,
    V: Sample<Output = bool> + Send + Sync + Clone + 'static,
{
    pub fn with_len(&self, len: usize) -> Self {
        Self {
            len: len..(len + 1),
            ..self.clone()
        }
    }

    pub fn arbitrary_array(self, data_type_sampler: DataTypeSampler) -> ChainedArraySampler {
        Box::new(data_type_sampler.chain_resample(
            move |data_type| self.sampler_from_data_type(&data_type),
            100,
        ))
    }

    pub fn sampler_from_data_type(&self, data_type: &DataType) -> ArraySampler {
        let current_null = if self.is_nullable {
            Some(self.null.clone())
        } else {
            None
        };
        let len = self.len.clone();

        match data_type {
            DataType::Float32 => boxed_primitive(valid_f32(), len, current_null),
            DataType::Float64 => boxed_primitive(valid_f64(), len, current_null),
            DataType::Int8 => arbitrary_boxed_primitive::<i8, _>(len, current_null),
            DataType::Int16 => arbitrary_boxed_primitive::<i16, _>(len, current_null),
            DataType::Int32 => arbitrary_boxed_primitive::<i32, _>(len, current_null),
            DataType::Int64 => arbitrary_boxed_primitive::<i64, _>(len, current_null),
            DataType::UInt8 => arbitrary_boxed_primitive::<u8, _>(len, current_null),
            DataType::UInt16 => arbitrary_boxed_primitive::<u16, _>(len, current_null),
            DataType::UInt32 => arbitrary_boxed_primitive::<u32, _>(len, current_null),
            DataType::UInt64 => arbitrary_boxed_primitive::<u64, _>(len, current_null),
            DataType::Struct(fields) => Box::new(StructSampler {
                data_type: data_type.clone(),
                null: current_null,
                values: fields
                    .iter()
                    .map(|f| {
                        ArbitraryArray {
                            len: (len.end.saturating_sub(1))..len.end,
                            is_nullable: f.is_nullable,
                            ..self.clone()
                        }
                        .sampler_from_data_type(f.data_type())
                    })
                    .collect(),
            }),
            DataType::List(field) => Box::new(ListSampler {
                data_type: data_type.clone(),
                len: len.clone(),
                null: current_null,
                inner: ArbitraryArray {
                    branch: (self.branch.start * self.len.start)..(self.branch.end * self.len.end),
                    is_nullable: field.is_nullable,
                    ..self.clone()
                }
                .sampler_from_data_type(field.data_type()),
            }),
            dt => panic!("not implemented: {:?}", dt),
        }
    }
}