1use std::ops::Range;
4
5use arrow2::{
6 array::{Array, FixedSizeListArray, ListArray},
7 bitmap::Bitmap,
8 datatypes::DataType,
9};
10use sample_std::{valid_f32, valid_f64, Always, Chained, Sample};
11
12use crate::{
13 datatypes::DataTypeSampler,
14 fixed_size_list::FixedSizeListWithLen,
15 list::{ListSampler, ListWithLen},
16 primitive::{
17 arbitrary_boxed_primitive, arbitrary_len_sampler, boxed_primitive, primitive_len_sampler,
18 },
19 struct_::StructSampler,
20 AlwaysValid, ArrowLenSampler, SetLen,
21};
22
23pub fn sampler_from_example(array: &dyn Array) -> ArrowLenSampler {
24 match array.data_type() {
25 DataType::Float32 => primitive_len_sampler(valid_f32(), AlwaysValid),
26 DataType::Float64 => primitive_len_sampler(valid_f64(), AlwaysValid),
27 DataType::Int8 => arbitrary_len_sampler::<i8, _>(AlwaysValid),
28 DataType::Int16 => arbitrary_len_sampler::<i16, _>(AlwaysValid),
29 DataType::Int32 => arbitrary_len_sampler::<i32, _>(AlwaysValid),
30 DataType::Int64 => arbitrary_len_sampler::<i64, _>(AlwaysValid),
31 DataType::UInt8 => arbitrary_len_sampler::<u8, _>(AlwaysValid),
32 DataType::UInt16 => arbitrary_len_sampler::<u16, _>(AlwaysValid),
33 DataType::UInt32 => arbitrary_len_sampler::<u32, _>(AlwaysValid),
34 DataType::UInt64 => arbitrary_len_sampler::<u64, _>(AlwaysValid),
35 DataType::List(f) => {
36 let list = array.as_any().downcast_ref::<ListArray<i32>>().unwrap();
37 let min = list.offsets().lengths().min().unwrap_or(0) as i32;
38 let max = list.offsets().lengths().max().unwrap_or(0) as i32 + 1;
39 Box::new(ListWithLen {
40 len: array.len(),
41 validity: AlwaysValid,
42 count: min..max,
43 inner_name: Always(f.name.clone()),
44 inner: sampler_from_example(list.values().as_ref()),
45 })
46 }
47 DataType::FixedSizeList(f, count) => Box::new(FixedSizeListWithLen {
48 len: array.len(),
49 validity: AlwaysValid,
50 count: Always(*count),
51 inner_name: Always(f.name.clone()),
52 inner: sampler_from_example(
53 array
54 .as_any()
55 .downcast_ref::<FixedSizeListArray>()
56 .unwrap()
57 .values()
58 .as_ref(),
59 ),
60 }),
61 dt => panic!("not implemented: {:?}", dt),
62 }
63}
64
65pub struct FromDataType<V, B> {
66 pub validity: V,
67
68 pub branch: B,
69}
70
71impl<V, B> FromDataType<V, B>
72where
73 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + Send + Sync + 'static,
74 B: Sample<Output = i32> + Clone + Send + Sync + 'static,
75{
76 pub fn from_data_type(&self, data_type: &DataType) -> ArrowLenSampler {
77 match data_type {
78 DataType::Float32 => primitive_len_sampler(valid_f32(), self.validity.clone()),
79 DataType::Float64 => primitive_len_sampler(valid_f64(), self.validity.clone()),
80 DataType::Int8 => arbitrary_len_sampler::<i8, _>(self.validity.clone()),
81 DataType::Int16 => arbitrary_len_sampler::<i16, _>(self.validity.clone()),
82 DataType::Int32 => arbitrary_len_sampler::<i32, _>(self.validity.clone()),
83 DataType::Int64 => arbitrary_len_sampler::<i64, _>(self.validity.clone()),
84 DataType::UInt8 => arbitrary_len_sampler::<u8, _>(self.validity.clone()),
85 DataType::UInt16 => arbitrary_len_sampler::<u16, _>(self.validity.clone()),
86 DataType::UInt32 => arbitrary_len_sampler::<u32, _>(self.validity.clone()),
87 DataType::UInt64 => arbitrary_len_sampler::<u64, _>(self.validity.clone()),
88 DataType::List(f) => Box::new(ListWithLen {
89 len: 0,
90 validity: self.validity.clone(),
91 count: self.branch.clone(),
92 inner_name: Always(f.name.clone()),
93 inner: self.from_data_type(f.data_type()),
94 }),
95 DataType::FixedSizeList(f, count) => Box::new(FixedSizeListWithLen {
96 len: 0,
97 validity: self.validity.clone(),
98 count: Always(*count),
99 inner_name: Always(f.name.clone()),
100 inner: self.from_data_type(f.data_type()),
101 }),
102 dt => panic!("not implemented: {:?}", dt),
103 }
104 }
105}
106
107pub type ArraySampler = Box<dyn Sample<Output = Box<dyn Array>> + Send + Sync>;
108
109pub type ChainedArraySampler =
110 Box<dyn Sample<Output = Chained<DataType, Box<dyn Array>>> + Send + Sync>;
111
112#[derive(Clone, Debug)]
113pub struct ArbitraryArray<N, V> {
114 pub names: N,
115 pub branch: Range<usize>,
116 pub len: Range<usize>,
117 pub null: V,
118 pub is_nullable: bool,
119}
120
121impl<N, V> ArbitraryArray<N, V>
122where
123 N: Sample<Output = String> + Send + Sync + Clone + 'static,
124 V: Sample<Output = bool> + Send + Sync + Clone + 'static,
125{
126 pub fn with_len(&self, len: usize) -> Self {
127 Self {
128 len: len..(len + 1),
129 ..self.clone()
130 }
131 }
132
133 pub fn arbitrary_array(self, data_type_sampler: DataTypeSampler) -> ChainedArraySampler {
134 Box::new(data_type_sampler.chain_resample(
135 move |data_type| self.sampler_from_data_type(&data_type),
136 100,
137 ))
138 }
139
140 pub fn sampler_from_data_type(&self, data_type: &DataType) -> ArraySampler {
141 let current_null = if self.is_nullable {
142 Some(self.null.clone())
143 } else {
144 None
145 };
146 let len = self.len.clone();
147
148 match data_type {
149 DataType::Float32 => boxed_primitive(valid_f32(), len, current_null),
150 DataType::Float64 => boxed_primitive(valid_f64(), len, current_null),
151 DataType::Int8 => arbitrary_boxed_primitive::<i8, _>(len, current_null),
152 DataType::Int16 => arbitrary_boxed_primitive::<i16, _>(len, current_null),
153 DataType::Int32 => arbitrary_boxed_primitive::<i32, _>(len, current_null),
154 DataType::Int64 => arbitrary_boxed_primitive::<i64, _>(len, current_null),
155 DataType::UInt8 => arbitrary_boxed_primitive::<u8, _>(len, current_null),
156 DataType::UInt16 => arbitrary_boxed_primitive::<u16, _>(len, current_null),
157 DataType::UInt32 => arbitrary_boxed_primitive::<u32, _>(len, current_null),
158 DataType::UInt64 => arbitrary_boxed_primitive::<u64, _>(len, current_null),
159 DataType::Struct(fields) => Box::new(StructSampler {
160 data_type: data_type.clone(),
161 null: current_null,
162 values: fields
163 .iter()
164 .map(|f| {
165 ArbitraryArray {
166 len: (len.end.saturating_sub(1))..len.end,
167 is_nullable: f.is_nullable,
168 ..self.clone()
169 }
170 .sampler_from_data_type(f.data_type())
171 })
172 .collect(),
173 }),
174 DataType::List(field) => Box::new(ListSampler {
175 data_type: data_type.clone(),
176 len: len.clone(),
177 null: current_null,
178 inner: ArbitraryArray {
179 branch: (self.branch.start * self.len.start)..(self.branch.end * self.len.end),
180 is_nullable: field.is_nullable,
181 ..self.clone()
182 }
183 .sampler_from_data_type(field.data_type()),
184 }),
185 dt => panic!("not implemented: {:?}", dt),
186 }
187 }
188}