polars_arrow/array/struct_/
mod.rs1use super::{new_empty_array, new_null_array, Array, Splitable};
2use crate::bitmap::Bitmap;
3use crate::datatypes::{ArrowDataType, Field};
4
5mod ffi;
6pub(super) mod fmt;
7mod iterator;
8use polars_error::{polars_bail, polars_ensure, PolarsResult};
9
10use crate::compute::utils::combine_validities_and;
11
12#[derive(Clone)]
29pub struct StructArray {
30 dtype: ArrowDataType,
31 values: Vec<Box<dyn Array>>,
33 length: usize,
35 validity: Option<Bitmap>,
36}
37
38impl StructArray {
39 pub fn try_new(
49 dtype: ArrowDataType,
50 length: usize,
51 values: Vec<Box<dyn Array>>,
52 validity: Option<Bitmap>,
53 ) -> PolarsResult<Self> {
54 let fields = Self::try_get_fields(&dtype)?;
55
56 polars_ensure!(
57 fields.len() == values.len(),
58 ComputeError:
59 "a StructArray must have a number of fields in its DataType equal to the number of child values"
60 );
61
62 fields
63 .iter().map(|a| &a.dtype)
64 .zip(values.iter().map(|a| a.dtype()))
65 .enumerate()
66 .try_for_each(|(index, (dtype, child))| {
67 if dtype != child {
68 polars_bail!(ComputeError:
69 "The children DataTypes of a StructArray must equal the children data types.
70 However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
71 )
72 } else {
73 Ok(())
74 }
75 })?;
76
77 values
78 .iter()
79 .map(|f| f.len())
80 .enumerate()
81 .try_for_each(|(index, f_length)| {
82 if f_length != length {
83 polars_bail!(ComputeError: "The children must have the given number of values.
84 However, the values at index {index} have a length of {f_length}, which is different from given length {length}.")
85 } else {
86 Ok(())
87 }
88 })?;
89
90 if validity
91 .as_ref()
92 .is_some_and(|validity| validity.len() != length)
93 {
94 polars_bail!(ComputeError:"The validity length of a StructArray must match its number of elements")
95 }
96
97 Ok(Self {
98 dtype,
99 length,
100 values,
101 validity,
102 })
103 }
104
105 pub fn new(
115 dtype: ArrowDataType,
116 length: usize,
117 values: Vec<Box<dyn Array>>,
118 validity: Option<Bitmap>,
119 ) -> Self {
120 Self::try_new(dtype, length, values, validity).unwrap()
121 }
122
123 pub fn new_empty(dtype: ArrowDataType) -> Self {
125 if let ArrowDataType::Struct(fields) = &dtype.to_logical_type() {
126 let values = fields
127 .iter()
128 .map(|field| new_empty_array(field.dtype().clone()))
129 .collect();
130 Self::new(dtype, 0, values, None)
131 } else {
132 panic!("StructArray must be initialized with DataType::Struct");
133 }
134 }
135
136 pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
138 if let ArrowDataType::Struct(fields) = &dtype {
139 let values = fields
140 .iter()
141 .map(|field| new_null_array(field.dtype().clone(), length))
142 .collect();
143 Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length)))
144 } else {
145 panic!("StructArray must be initialized with DataType::Struct");
146 }
147 }
148}
149
150impl StructArray {
152 #[must_use]
154 pub fn into_data(self) -> (Vec<Field>, usize, Vec<Box<dyn Array>>, Option<Bitmap>) {
155 let Self {
156 dtype,
157 length,
158 values,
159 validity,
160 } = self;
161 let fields = if let ArrowDataType::Struct(fields) = dtype {
162 fields
163 } else {
164 unreachable!()
165 };
166 (fields, length, values, validity)
167 }
168
169 pub fn slice(&mut self, offset: usize, length: usize) {
175 assert!(
176 offset + length <= self.len(),
177 "offset + length may not exceed length of array"
178 );
179 unsafe { self.slice_unchecked(offset, length) }
180 }
181
182 pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
189 self.validity = self
190 .validity
191 .take()
192 .map(|bitmap| bitmap.sliced_unchecked(offset, length))
193 .filter(|bitmap| bitmap.unset_bits() > 0);
194 self.values
195 .iter_mut()
196 .for_each(|x| x.slice_unchecked(offset, length));
197 self.length = length;
198 }
199
200 pub fn propagate_nulls(&self) -> StructArray {
202 let has_nulls = self.null_count() > 0;
203 let mut out = self.clone();
204 if !has_nulls {
205 return out;
206 };
207
208 for value_arr in &mut out.values {
209 let new_validity = combine_validities_and(self.validity(), value_arr.validity());
210 *value_arr = value_arr.with_validity(new_validity);
211 }
212 out
213 }
214
215 impl_sliced!();
216
217 impl_mut_validity!();
218
219 impl_into_array!();
220}
221
222impl StructArray {
224 #[inline]
225 fn len(&self) -> usize {
226 if cfg!(debug_assertions) {
227 for arr in self.values.iter() {
228 assert_eq!(
229 arr.len(),
230 self.length,
231 "StructArray invariant: each array has same length"
232 );
233 }
234 }
235
236 self.length
237 }
238
239 #[inline]
241 pub fn validity(&self) -> Option<&Bitmap> {
242 self.validity.as_ref()
243 }
244
245 pub fn values(&self) -> &[Box<dyn Array>] {
247 &self.values
248 }
249
250 pub fn fields(&self) -> &[Field] {
252 let fields = Self::get_fields(&self.dtype);
253 debug_assert_eq!(self.values().len(), fields.len());
254 fields
255 }
256}
257
258impl StructArray {
259 pub(crate) fn try_get_fields(dtype: &ArrowDataType) -> PolarsResult<&[Field]> {
261 match dtype.to_logical_type() {
262 ArrowDataType::Struct(fields) => Ok(fields),
263 _ => {
264 polars_bail!(ComputeError: "Struct array must be created with a DataType whose physical type is Struct")
265 },
266 }
267 }
268
269 pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
271 Self::try_get_fields(dtype).unwrap()
272 }
273}
274
275impl Array for StructArray {
276 impl_common_array!();
277
278 fn validity(&self) -> Option<&Bitmap> {
279 self.validity.as_ref()
280 }
281
282 #[inline]
283 fn with_validity(&self, validity: Option<Bitmap>) -> Box<dyn Array> {
284 Box::new(self.clone().with_validity(validity))
285 }
286}
287
288impl Splitable for StructArray {
289 fn check_bound(&self, offset: usize) -> bool {
290 offset <= self.len()
291 }
292
293 unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
294 let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) };
295
296 let mut lhs_values = Vec::with_capacity(self.values.len());
297 let mut rhs_values = Vec::with_capacity(self.values.len());
298
299 for v in self.values.iter() {
300 let (lhs, rhs) = unsafe { v.split_at_boxed_unchecked(offset) };
301 lhs_values.push(lhs);
302 rhs_values.push(rhs);
303 }
304
305 (
306 Self {
307 dtype: self.dtype.clone(),
308 length: offset,
309 values: lhs_values,
310 validity: lhs_validity,
311 },
312 Self {
313 dtype: self.dtype.clone(),
314 length: self.length - offset,
315 values: rhs_values,
316 validity: rhs_validity,
317 },
318 )
319 }
320}