sample_arrow2/
array.rs

//! Chained samplers for generating arbitrary `Box<dyn Array>` arrow arrays.
2
3use std::ops::Range;
4
5use arrow2::{
6    array::{Array, FixedSizeListArray, ListArray},
7    bitmap::Bitmap,
8    datatypes::DataType,
9};
10use sample_std::{valid_f32, valid_f64, Always, Chained, Sample};
11
12use crate::{
13    datatypes::DataTypeSampler,
14    fixed_size_list::FixedSizeListWithLen,
15    list::{ListSampler, ListWithLen},
16    primitive::{
17        arbitrary_boxed_primitive, arbitrary_len_sampler, boxed_primitive, primitive_len_sampler,
18    },
19    struct_::StructSampler,
20    AlwaysValid, ArrowLenSampler, SetLen,
21};
22
23pub fn sampler_from_example(array: &dyn Array) -> ArrowLenSampler {
24    match array.data_type() {
25        DataType::Float32 => primitive_len_sampler(valid_f32(), AlwaysValid),
26        DataType::Float64 => primitive_len_sampler(valid_f64(), AlwaysValid),
27        DataType::Int8 => arbitrary_len_sampler::<i8, _>(AlwaysValid),
28        DataType::Int16 => arbitrary_len_sampler::<i16, _>(AlwaysValid),
29        DataType::Int32 => arbitrary_len_sampler::<i32, _>(AlwaysValid),
30        DataType::Int64 => arbitrary_len_sampler::<i64, _>(AlwaysValid),
31        DataType::UInt8 => arbitrary_len_sampler::<u8, _>(AlwaysValid),
32        DataType::UInt16 => arbitrary_len_sampler::<u16, _>(AlwaysValid),
33        DataType::UInt32 => arbitrary_len_sampler::<u32, _>(AlwaysValid),
34        DataType::UInt64 => arbitrary_len_sampler::<u64, _>(AlwaysValid),
35        DataType::List(f) => {
36            let list = array.as_any().downcast_ref::<ListArray<i32>>().unwrap();
37            let min = list.offsets().lengths().min().unwrap_or(0) as i32;
38            let max = list.offsets().lengths().max().unwrap_or(0) as i32 + 1;
39            Box::new(ListWithLen {
40                len: array.len(),
41                validity: AlwaysValid,
42                count: min..max,
43                inner_name: Always(f.name.clone()),
44                inner: sampler_from_example(list.values().as_ref()),
45            })
46        }
47        DataType::FixedSizeList(f, count) => Box::new(FixedSizeListWithLen {
48            len: array.len(),
49            validity: AlwaysValid,
50            count: Always(*count),
51            inner_name: Always(f.name.clone()),
52            inner: sampler_from_example(
53                array
54                    .as_any()
55                    .downcast_ref::<FixedSizeListArray>()
56                    .unwrap()
57                    .values()
58                    .as_ref(),
59            ),
60        }),
61        dt => panic!("not implemented: {:?}", dt),
62    }
63}
64
/// Settings used to build an [`ArrowLenSampler`] from a [`DataType`] alone,
/// without an example array.
pub struct FromDataType<V, B> {
    /// Validity sampler; each generated level receives its own clone.
    pub validity: V,
    /// Sampler for the per-slot element count of `List` levels.
    pub branch: B,
}
70
71impl<V, B> FromDataType<V, B>
72where
73    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + Send + Sync + 'static,
74    B: Sample<Output = i32> + Clone + Send + Sync + 'static,
75{
76    pub fn from_data_type(&self, data_type: &DataType) -> ArrowLenSampler {
77        match data_type {
78            DataType::Float32 => primitive_len_sampler(valid_f32(), self.validity.clone()),
79            DataType::Float64 => primitive_len_sampler(valid_f64(), self.validity.clone()),
80            DataType::Int8 => arbitrary_len_sampler::<i8, _>(self.validity.clone()),
81            DataType::Int16 => arbitrary_len_sampler::<i16, _>(self.validity.clone()),
82            DataType::Int32 => arbitrary_len_sampler::<i32, _>(self.validity.clone()),
83            DataType::Int64 => arbitrary_len_sampler::<i64, _>(self.validity.clone()),
84            DataType::UInt8 => arbitrary_len_sampler::<u8, _>(self.validity.clone()),
85            DataType::UInt16 => arbitrary_len_sampler::<u16, _>(self.validity.clone()),
86            DataType::UInt32 => arbitrary_len_sampler::<u32, _>(self.validity.clone()),
87            DataType::UInt64 => arbitrary_len_sampler::<u64, _>(self.validity.clone()),
88            DataType::List(f) => Box::new(ListWithLen {
89                len: 0,
90                validity: self.validity.clone(),
91                count: self.branch.clone(),
92                inner_name: Always(f.name.clone()),
93                inner: self.from_data_type(f.data_type()),
94            }),
95            DataType::FixedSizeList(f, count) => Box::new(FixedSizeListWithLen {
96                len: 0,
97                validity: self.validity.clone(),
98                count: Always(*count),
99                inner_name: Always(f.name.clone()),
100                inner: self.from_data_type(f.data_type()),
101            }),
102            dt => panic!("not implemented: {:?}", dt),
103        }
104    }
105}
106
107pub type ArraySampler = Box<dyn Sample<Output = Box<dyn Array>> + Send + Sync>;
108
109pub type ChainedArraySampler =
110    Box<dyn Sample<Output = Chained<DataType, Box<dyn Array>>> + Send + Sync>;
111
/// Settings for sampling arbitrary arrow arrays of a given [`DataType`].
#[derive(Debug, Clone)]
pub struct ArbitraryArray<N, V> {
    /// Sampler of `String` names (bounded by `N: Sample<Output = String>` on the impl).
    pub names: N,
    /// Range whose bounds are multiplied with `len` when building `List` children.
    pub branch: Range<usize>,
    /// Half-open range of array lengths to draw from.
    pub len: Range<usize>,
    /// `bool` sampler handed down as the null sampler when `is_nullable` is set.
    pub null: V,
    /// Whether this level may contain nulls; child fields use their own
    /// `is_nullable` flag instead.
    pub is_nullable: bool,
}
120
121impl<N, V> ArbitraryArray<N, V>
122where
123    N: Sample<Output = String> + Send + Sync + Clone + 'static,
124    V: Sample<Output = bool> + Send + Sync + Clone + 'static,
125{
126    pub fn with_len(&self, len: usize) -> Self {
127        Self {
128            len: len..(len + 1),
129            ..self.clone()
130        }
131    }
132
133    pub fn arbitrary_array(self, data_type_sampler: DataTypeSampler) -> ChainedArraySampler {
134        Box::new(data_type_sampler.chain_resample(
135            move |data_type| self.sampler_from_data_type(&data_type),
136            100,
137        ))
138    }
139
140    pub fn sampler_from_data_type(&self, data_type: &DataType) -> ArraySampler {
141        let current_null = if self.is_nullable {
142            Some(self.null.clone())
143        } else {
144            None
145        };
146        let len = self.len.clone();
147
148        match data_type {
149            DataType::Float32 => boxed_primitive(valid_f32(), len, current_null),
150            DataType::Float64 => boxed_primitive(valid_f64(), len, current_null),
151            DataType::Int8 => arbitrary_boxed_primitive::<i8, _>(len, current_null),
152            DataType::Int16 => arbitrary_boxed_primitive::<i16, _>(len, current_null),
153            DataType::Int32 => arbitrary_boxed_primitive::<i32, _>(len, current_null),
154            DataType::Int64 => arbitrary_boxed_primitive::<i64, _>(len, current_null),
155            DataType::UInt8 => arbitrary_boxed_primitive::<u8, _>(len, current_null),
156            DataType::UInt16 => arbitrary_boxed_primitive::<u16, _>(len, current_null),
157            DataType::UInt32 => arbitrary_boxed_primitive::<u32, _>(len, current_null),
158            DataType::UInt64 => arbitrary_boxed_primitive::<u64, _>(len, current_null),
159            DataType::Struct(fields) => Box::new(StructSampler {
160                data_type: data_type.clone(),
161                null: current_null,
162                values: fields
163                    .iter()
164                    .map(|f| {
165                        ArbitraryArray {
166                            len: (len.end.saturating_sub(1))..len.end,
167                            is_nullable: f.is_nullable,
168                            ..self.clone()
169                        }
170                        .sampler_from_data_type(f.data_type())
171                    })
172                    .collect(),
173            }),
174            DataType::List(field) => Box::new(ListSampler {
175                data_type: data_type.clone(),
176                len: len.clone(),
177                null: current_null,
178                inner: ArbitraryArray {
179                    branch: (self.branch.start * self.len.start)..(self.branch.end * self.len.end),
180                    is_nullable: field.is_nullable,
181                    ..self.clone()
182                }
183                .sampler_from_data_type(field.data_type()),
184            }),
185            dt => panic!("not implemented: {:?}", dt),
186        }
187    }
188}