polars_arrow/array/binview/
mod.rs

1//! See thread: https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt
2mod ffi;
3pub(super) mod fmt;
4mod iterator;
5mod mutable;
6mod view;
7
8use std::any::Any;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13
14use polars_error::*;
15
16use crate::array::Array;
17use crate::bitmap::Bitmap;
18use crate::buffer::Buffer;
19use crate::datatypes::ArrowDataType;
20
/// Private module backing the sealed-trait pattern: `Sealed` is public inside a
/// private module, so it can only be implemented from within this file.
mod private {
    /// Marker supertrait restricting which types may act as view element types.
    pub trait Sealed: Sync + Send {}

    impl Sealed for [u8] {}
    impl Sealed for str {}
}
27pub use iterator::BinaryViewValueIter;
28pub use mutable::MutableBinaryViewArray;
29use polars_utils::aliases::{InitHashMaps, PlHashMap};
30use private::Sealed;
31
32use crate::array::binview::view::{validate_binary_view, validate_utf8_only};
33use crate::array::iterator::NonNullValuesIter;
34use crate::bitmap::utils::{BitmapIter, ZipValidity};
/// View array whose element type is a byte slice (`[u8]`).
pub type BinaryViewArray = BinaryViewArrayGeneric<[u8]>;
/// View array whose element type is a string slice (`str`).
pub type Utf8ViewArray = BinaryViewArrayGeneric<str>;
37pub use view::{validate_utf8_view, View};
38
39use super::Splitable;
40
/// Mutable view array with `str` elements.
pub type MutablePlString = MutableBinaryViewArray<str>;
/// Mutable view array with `[u8]` elements.
pub type MutablePlBinary = MutableBinaryViewArray<[u8]>;
43
// Static data type instances so that `ViewType::dtype` can hand out `&'static` references.
static BIN_VIEW_TYPE: ArrowDataType = ArrowDataType::BinaryView;
static UTF8_VIEW_TYPE: ArrowDataType = ArrowDataType::Utf8View;
46
47pub trait ViewType: Sealed + 'static + PartialEq + AsRef<Self> {
48    const IS_UTF8: bool;
49    const DATA_TYPE: ArrowDataType;
50    type Owned: Debug + Clone + Sync + Send + AsRef<Self>;
51
52    /// # Safety
53    /// The caller must ensure that `slice` is a valid view.
54    unsafe fn from_bytes_unchecked(slice: &[u8]) -> &Self;
55    fn from_bytes(slice: &[u8]) -> Option<&Self>;
56
57    fn to_bytes(&self) -> &[u8];
58
59    #[allow(clippy::wrong_self_convention)]
60    fn into_owned(&self) -> Self::Owned;
61
62    fn dtype() -> &'static ArrowDataType;
63}
64
65impl ViewType for str {
66    const IS_UTF8: bool = true;
67    const DATA_TYPE: ArrowDataType = ArrowDataType::Utf8View;
68    type Owned = String;
69
70    #[inline(always)]
71    unsafe fn from_bytes_unchecked(slice: &[u8]) -> &Self {
72        std::str::from_utf8_unchecked(slice)
73    }
74    #[inline(always)]
75    fn from_bytes(slice: &[u8]) -> Option<&Self> {
76        std::str::from_utf8(slice).ok()
77    }
78
79    #[inline(always)]
80    fn to_bytes(&self) -> &[u8] {
81        self.as_bytes()
82    }
83
84    fn into_owned(&self) -> Self::Owned {
85        self.to_string()
86    }
87    fn dtype() -> &'static ArrowDataType {
88        &UTF8_VIEW_TYPE
89    }
90}
91
92impl ViewType for [u8] {
93    const IS_UTF8: bool = false;
94    const DATA_TYPE: ArrowDataType = ArrowDataType::BinaryView;
95    type Owned = Vec<u8>;
96
97    #[inline(always)]
98    unsafe fn from_bytes_unchecked(slice: &[u8]) -> &Self {
99        slice
100    }
101    #[inline(always)]
102    fn from_bytes(slice: &[u8]) -> Option<&Self> {
103        Some(slice)
104    }
105
106    #[inline(always)]
107    fn to_bytes(&self) -> &[u8] {
108        self
109    }
110
111    fn into_owned(&self) -> Self::Owned {
112        self.to_vec()
113    }
114
115    fn dtype() -> &'static ArrowDataType {
116        &BIN_VIEW_TYPE
117    }
118}
119
120pub struct BinaryViewArrayGeneric<T: ViewType + ?Sized> {
121    dtype: ArrowDataType,
122    views: Buffer<View>,
123    buffers: Arc<[Buffer<u8>]>,
124    validity: Option<Bitmap>,
125    phantom: PhantomData<T>,
126    /// Total bytes length if we would concatenate them all.
127    total_bytes_len: AtomicU64,
128    /// Total bytes in the buffer (excluding remaining capacity)
129    total_buffer_len: usize,
130}
131
132impl<T: ViewType + ?Sized> PartialEq for BinaryViewArrayGeneric<T> {
133    fn eq(&self, other: &Self) -> bool {
134        self.len() == other.len() && self.into_iter().zip(other).all(|(l, r)| l == r)
135    }
136}
137
138impl<T: ViewType + ?Sized> Clone for BinaryViewArrayGeneric<T> {
139    fn clone(&self) -> Self {
140        Self {
141            dtype: self.dtype.clone(),
142            views: self.views.clone(),
143            buffers: self.buffers.clone(),
144            validity: self.validity.clone(),
145            phantom: Default::default(),
146            total_bytes_len: AtomicU64::new(self.total_bytes_len.load(Ordering::Relaxed)),
147            total_buffer_len: self.total_buffer_len,
148        }
149    }
150}
151
// NOTE(review): these impls assert thread-safety for every `T`. `T` only occurs inside
// `PhantomData` and `ViewType` requires `Sealed: Send + Sync`; the soundness otherwise
// rests on the remaining fields being `Send + Sync` -- confirm against `Buffer`/`Bitmap`.
unsafe impl<T: ViewType + ?Sized> Send for BinaryViewArrayGeneric<T> {}
unsafe impl<T: ViewType + ?Sized> Sync for BinaryViewArrayGeneric<T> {}
154
/// Sentinel stored in `total_bytes_len` while the total has not been computed yet.
const UNKNOWN_LEN: u64 = u64::MAX;
156
157impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
158    /// # Safety
159    /// The caller must ensure
160    /// - the data is valid utf8 (if required)
161    /// - The offsets match the buffers.
162    pub unsafe fn new_unchecked(
163        dtype: ArrowDataType,
164        views: Buffer<View>,
165        buffers: Arc<[Buffer<u8>]>,
166        validity: Option<Bitmap>,
167        total_bytes_len: usize,
168        total_buffer_len: usize,
169    ) -> Self {
170        // Verify the invariants
171        #[cfg(debug_assertions)]
172        {
173            if let Some(validity) = validity.as_ref() {
174                assert_eq!(validity.len(), views.len());
175            }
176
177            // @TODO: Enable this. This is currently bugged with concatenate.
178            // let mut actual_total_buffer_len = 0;
179            // let mut actual_total_bytes_len = 0;
180            //
181            // for buffer in buffers.iter() {
182            //     actual_total_buffer_len += buffer.len();
183            // }
184
185            for (i, view) in views.iter().enumerate() {
186                let is_valid = validity.as_ref().is_none_or(|v| v.get_bit(i));
187
188                if !is_valid {
189                    continue;
190                }
191
192                // actual_total_bytes_len += view.length as usize;
193                if view.length > View::MAX_INLINE_SIZE {
194                    assert!((view.buffer_idx as usize) < (buffers.len()));
195                    assert!(
196                        view.offset as usize + view.length as usize
197                            <= buffers[view.buffer_idx as usize].len()
198                    );
199                }
200            }
201
202            // assert_eq!(actual_total_buffer_len, total_buffer_len);
203            // if (total_bytes_len as u64) != UNKNOWN_LEN {
204            //     assert_eq!(actual_total_bytes_len, total_bytes_len);
205            // }
206        }
207
208        Self {
209            dtype,
210            views,
211            buffers,
212            validity,
213            phantom: Default::default(),
214            total_bytes_len: AtomicU64::new(total_bytes_len as u64),
215            total_buffer_len,
216        }
217    }
218
219    /// Create a new BinaryViewArray but initialize a statistics compute.
220    ///
221    /// # Safety
222    /// The caller must ensure the invariants
223    pub unsafe fn new_unchecked_unknown_md(
224        dtype: ArrowDataType,
225        views: Buffer<View>,
226        buffers: Arc<[Buffer<u8>]>,
227        validity: Option<Bitmap>,
228        total_buffer_len: Option<usize>,
229    ) -> Self {
230        let total_bytes_len = UNKNOWN_LEN as usize;
231        let total_buffer_len =
232            total_buffer_len.unwrap_or_else(|| buffers.iter().map(|b| b.len()).sum());
233        Self::new_unchecked(
234            dtype,
235            views,
236            buffers,
237            validity,
238            total_bytes_len,
239            total_buffer_len,
240        )
241    }
242
    /// Returns the shared list of data buffers.
    pub fn data_buffers(&self) -> &Arc<[Buffer<u8>]> {
        &self.buffers
    }
246
247    pub fn variadic_buffer_lengths(&self) -> Vec<i64> {
248        self.buffers.iter().map(|buf| buf.len() as i64).collect()
249    }
250
    /// Returns the views of this array, one per element.
    pub fn views(&self) -> &Buffer<View> {
        &self.views
    }
254
    /// Consumes the array and returns its views as a `Vec`, dropping everything else.
    pub fn into_views(self) -> Vec<View> {
        self.views.make_mut()
    }
258
259    pub fn into_inner(
260        self,
261    ) -> (
262        Buffer<View>,
263        Arc<[Buffer<u8>]>,
264        Option<Bitmap>,
265        usize,
266        usize,
267    ) {
268        let views = self.views;
269        let buffers = self.buffers;
270        let validity = self.validity;
271
272        (
273            views,
274            buffers,
275            validity,
276            self.total_bytes_len.load(Ordering::Relaxed) as usize,
277            self.total_buffer_len,
278        )
279    }
280
281    /// Apply a function over the views. This can be used to update views in operations like slicing.
282    ///
283    /// # Safety
284    /// Update the views. All invariants of the views apply.
285    pub unsafe fn apply_views<F: FnMut(View, &T) -> View>(&self, mut update_view: F) -> Self {
286        let arr = self.clone();
287        let (views, buffers, validity, total_bytes_len, total_buffer_len) = arr.into_inner();
288
289        let mut views = views.make_mut();
290        for v in views.iter_mut() {
291            let str_slice = T::from_bytes_unchecked(v.get_slice_unchecked(&buffers));
292            *v = update_view(*v, str_slice);
293        }
294        Self::new_unchecked(
295            self.dtype.clone(),
296            views.into(),
297            buffers,
298            validity,
299            total_bytes_len,
300            total_buffer_len,
301        )
302    }
303
304    pub fn try_new(
305        dtype: ArrowDataType,
306        views: Buffer<View>,
307        buffers: Arc<[Buffer<u8>]>,
308        validity: Option<Bitmap>,
309    ) -> PolarsResult<Self> {
310        if T::IS_UTF8 {
311            validate_utf8_view(views.as_ref(), buffers.as_ref())?;
312        } else {
313            validate_binary_view(views.as_ref(), buffers.as_ref())?;
314        }
315
316        if let Some(validity) = &validity {
317            polars_ensure!(validity.len()== views.len(), ComputeError: "validity mask length must match the number of values" )
318        }
319
320        unsafe {
321            Ok(Self::new_unchecked_unknown_md(
322                dtype, views, buffers, validity, None,
323            ))
324        }
325    }
326
327    /// Creates an empty [`BinaryViewArrayGeneric`], i.e. whose `.len` is zero.
328    #[inline]
329    pub fn new_empty(dtype: ArrowDataType) -> Self {
330        unsafe { Self::new_unchecked(dtype, Buffer::new(), Arc::from([]), None, 0, 0) }
331    }
332
333    /// Returns a new null [`BinaryViewArrayGeneric`] of `length`.
334    #[inline]
335    pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
336        let validity = Some(Bitmap::new_zeroed(length));
337        unsafe { Self::new_unchecked(dtype, Buffer::zeroed(length), Arc::from([]), validity, 0, 0) }
338    }
339
    /// Returns the element at index `i`, ignoring the validity.
    /// # Panics
    /// iff `i >= self.len()`
    #[inline]
    pub fn value(&self, i: usize) -> &T {
        assert!(i < self.len());
        // SAFETY: `i` was bounds-checked above.
        unsafe { self.value_unchecked(i) }
    }
348
349    /// Returns the element at index `i`
350    ///
351    /// # Safety
352    /// Assumes that the `i < self.len`.
353    #[inline]
354    pub unsafe fn value_unchecked(&self, i: usize) -> &T {
355        let v = self.views.get_unchecked(i);
356        T::from_bytes_unchecked(v.get_slice_unchecked(&self.buffers))
357    }
358
    /// Returns an iterator of `Option<&T>` over every element of this array.
    ///
    /// The values iterator is zipped with the validity bitmap, if there is one.
    pub fn iter(&self) -> ZipValidity<&T, BinaryViewValueIter<T>, BitmapIter> {
        ZipValidity::new_with_validity(self.values_iter(), self.validity.as_ref())
    }
363
    /// Returns an iterator of `&T` over every element of this array, ignoring the validity
    pub fn values_iter(&self) -> BinaryViewValueIter<T> {
        BinaryViewValueIter::new(self)
    }
368
    /// Returns an iterator over the byte length of every view, ignoring the validity.
    pub fn len_iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.views.iter().map(|v| v.length)
    }
372
    /// Returns an iterator of the non-null values.
    pub fn non_null_values_iter(&self) -> NonNullValuesIter<'_, BinaryViewArrayGeneric<T>> {
        NonNullValuesIter::new(self, self.validity())
    }
377
    /// Returns an iterator of the views of the non-null elements.
    pub fn non_null_views_iter(&self) -> NonNullValuesIter<'_, Buffer<View>> {
        NonNullValuesIter::new(self.views(), self.validity())
    }
382
383    impl_sliced!();
384    impl_mut_validity!();
385    impl_into_array!();
386
387    pub fn from_slice<S: AsRef<T>, P: AsRef<[Option<S>]>>(slice: P) -> Self {
388        let mutable = MutableBinaryViewArray::from_iterator(
389            slice.as_ref().iter().map(|opt_v| opt_v.as_ref()),
390        );
391        mutable.into()
392    }
393
394    pub fn from_slice_values<S: AsRef<T>, P: AsRef<[S]>>(slice: P) -> Self {
395        let mutable =
396            MutableBinaryViewArray::from_values_iter(slice.as_ref().iter().map(|v| v.as_ref()));
397        mutable.into()
398    }
399
400    /// Get the total length of bytes that it would take to concatenate all binary/str values in this array.
401    pub fn total_bytes_len(&self) -> usize {
402        let total = self.total_bytes_len.load(Ordering::Relaxed);
403        if total == UNKNOWN_LEN {
404            let total = self.len_iter().map(|v| v as usize).sum::<usize>();
405            self.total_bytes_len.store(total as u64, Ordering::Relaxed);
406            total
407        } else {
408            total as usize
409        }
410    }
411
    /// Get the length of bytes that are stored in the variadic buffers.
    ///
    /// This returns the stored `total_buffer_len` field; nothing is recomputed.
    pub fn total_buffer_len(&self) -> usize {
        self.total_buffer_len
    }
416
417    fn total_unshared_buffer_len(&self) -> usize {
418        // XXX: it is O(n), not O(1).
419        // Given this function is only called in `maybe_gc()`,
420        // it may not be worthy to add an extra field for this.
421        self.buffers
422            .iter()
423            .map(|buf| {
424                if buf.storage_refcount() > 1 {
425                    0
426                } else {
427                    buf.len()
428                }
429            })
430            .sum()
431    }
432
    /// Returns the number of elements, which is the number of views.
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.views.len()
    }
437
438    /// Garbage collect
439    pub fn gc(self) -> Self {
440        if self.buffers.is_empty() {
441            return self;
442        }
443        let mut mutable = MutableBinaryViewArray::with_capacity(self.len());
444        let buffers = self.buffers.as_ref();
445
446        for view in self.views.as_ref() {
447            unsafe { mutable.push_view_unchecked(*view, buffers) }
448        }
449        mutable.freeze().with_validity(self.validity)
450    }
451
    /// Returns `true` when the views do not start at the beginning of their storage,
    /// i.e. when the views pointer differs from the storage pointer.
    pub fn is_sliced(&self) -> bool {
        self.views.as_ptr() != self.views.storage_ptr()
    }
455
456    pub fn maybe_gc(self) -> Self {
457        const GC_MINIMUM_SAVINGS: usize = 16 * 1024; // At least 16 KiB.
458
459        if self.total_buffer_len <= GC_MINIMUM_SAVINGS {
460            return self;
461        }
462
463        if Arc::strong_count(&self.buffers) != 1 {
464            // There are multiple holders of this `buffers`.
465            // If we allow gc in this case,
466            // it may end up copying the same content multiple times.
467            return self;
468        }
469
470        // Subtract the maximum amount of inlined strings to get a lower bound
471        // on the number of buffer bytes needed (assuming no dedup).
472        let total_bytes_len = self.total_bytes_len();
473        let buffer_req_lower_bound = total_bytes_len.saturating_sub(self.len() * 12);
474
475        let lower_bound_mem_usage_post_gc = self.len() * 16 + buffer_req_lower_bound;
476        // Use unshared buffer len. Shared buffer won't be freed; no savings.
477        let cur_mem_usage = self.len() * 16 + self.total_unshared_buffer_len();
478        let savings_upper_bound = cur_mem_usage.saturating_sub(lower_bound_mem_usage_post_gc);
479
480        if savings_upper_bound >= GC_MINIMUM_SAVINGS
481            && cur_mem_usage >= 4 * lower_bound_mem_usage_post_gc
482        {
483            self.gc()
484        } else {
485            self
486        }
487    }
488
489    pub fn make_mut(self) -> MutableBinaryViewArray<T> {
490        let views = self.views.make_mut();
491        let completed_buffers = self.buffers.to_vec();
492        let validity = self.validity.map(|bitmap| bitmap.make_mut());
493
494        // We need to know the total_bytes_len if we are going to mutate it.
495        let mut total_bytes_len = self.total_bytes_len.load(Ordering::Relaxed);
496        if total_bytes_len == UNKNOWN_LEN {
497            total_bytes_len = views.iter().map(|view| view.length as u64).sum();
498        }
499        let total_bytes_len = total_bytes_len as usize;
500
501        MutableBinaryViewArray {
502            views,
503            completed_buffers,
504            in_progress_buffer: vec![],
505            validity,
506            phantom: Default::default(),
507            total_bytes_len,
508            total_buffer_len: self.total_buffer_len,
509            stolen_buffers: PlHashMap::new(),
510        }
511    }
512}
513
514impl BinaryViewArray {
515    /// Validate the underlying bytes on UTF-8.
516    pub fn validate_utf8(&self) -> PolarsResult<()> {
517        // SAFETY: views are correct
518        unsafe { validate_utf8_only(&self.views, &self.buffers, &self.buffers) }
519    }
520
521    /// Convert [`BinaryViewArray`] to [`Utf8ViewArray`].
522    pub fn to_utf8view(&self) -> PolarsResult<Utf8ViewArray> {
523        self.validate_utf8()?;
524        unsafe { Ok(self.to_utf8view_unchecked()) }
525    }
526
527    /// Convert [`BinaryViewArray`] to [`Utf8ViewArray`] without checking UTF-8.
528    ///
529    /// # Safety
530    /// The caller must ensure the underlying data is valid UTF-8.
531    pub unsafe fn to_utf8view_unchecked(&self) -> Utf8ViewArray {
532        Utf8ViewArray::new_unchecked(
533            ArrowDataType::Utf8View,
534            self.views.clone(),
535            self.buffers.clone(),
536            self.validity.clone(),
537            self.total_bytes_len.load(Ordering::Relaxed) as usize,
538            self.total_buffer_len,
539        )
540    }
541}
542
543impl Utf8ViewArray {
544    pub fn to_binview(&self) -> BinaryViewArray {
545        // SAFETY: same invariants.
546        unsafe {
547            BinaryViewArray::new_unchecked(
548                ArrowDataType::BinaryView,
549                self.views.clone(),
550                self.buffers.clone(),
551                self.validity.clone(),
552                self.total_bytes_len.load(Ordering::Relaxed) as usize,
553                self.total_buffer_len,
554            )
555        }
556    }
557}
558
559impl<T: ViewType + ?Sized> Array for BinaryViewArrayGeneric<T> {
560    fn as_any(&self) -> &dyn Any {
561        self
562    }
563
564    fn as_any_mut(&mut self) -> &mut dyn Any {
565        self
566    }
567
568    #[inline(always)]
569    fn len(&self) -> usize {
570        BinaryViewArrayGeneric::len(self)
571    }
572
573    fn dtype(&self) -> &ArrowDataType {
574        T::dtype()
575    }
576
577    fn validity(&self) -> Option<&Bitmap> {
578        self.validity.as_ref()
579    }
580
581    fn split_at_boxed(&self, offset: usize) -> (Box<dyn Array>, Box<dyn Array>) {
582        let (lhs, rhs) = Splitable::split_at(self, offset);
583        (Box::new(lhs), Box::new(rhs))
584    }
585
586    unsafe fn split_at_boxed_unchecked(&self, offset: usize) -> (Box<dyn Array>, Box<dyn Array>) {
587        let (lhs, rhs) = unsafe { Splitable::split_at_unchecked(self, offset) };
588        (Box::new(lhs), Box::new(rhs))
589    }
590
591    fn slice(&mut self, offset: usize, length: usize) {
592        assert!(
593            offset + length <= self.len(),
594            "the offset of the new Buffer cannot exceed the existing length"
595        );
596        unsafe { self.slice_unchecked(offset, length) }
597    }
598
599    unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
600        debug_assert!(offset + length <= self.len());
601        self.validity = self
602            .validity
603            .take()
604            .map(|bitmap| bitmap.sliced_unchecked(offset, length))
605            .filter(|bitmap| bitmap.unset_bits() > 0);
606        self.views.slice_unchecked(offset, length);
607        self.total_bytes_len.store(UNKNOWN_LEN, Ordering::Relaxed)
608    }
609
610    fn with_validity(&self, validity: Option<Bitmap>) -> Box<dyn Array> {
611        debug_assert!(
612            validity.as_ref().is_none_or(|v| v.len() == self.len()),
613            "{} != {}",
614            validity.as_ref().unwrap().len(),
615            self.len()
616        );
617
618        let mut new = self.clone();
619        new.validity = validity;
620        Box::new(new)
621    }
622
623    fn to_boxed(&self) -> Box<dyn Array> {
624        Box::new(self.clone())
625    }
626}
627
628impl<T: ViewType + ?Sized> Splitable for BinaryViewArrayGeneric<T> {
629    fn check_bound(&self, offset: usize) -> bool {
630        offset <= self.len()
631    }
632
633    unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
634        let (lhs_views, rhs_views) = unsafe { self.views.split_at_unchecked(offset) };
635        let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) };
636
637        unsafe {
638            (
639                Self::new_unchecked(
640                    self.dtype.clone(),
641                    lhs_views,
642                    self.buffers.clone(),
643                    lhs_validity,
644                    if offset == 0 { 0 } else { UNKNOWN_LEN as _ },
645                    self.total_buffer_len(),
646                ),
647                Self::new_unchecked(
648                    self.dtype.clone(),
649                    rhs_views,
650                    self.buffers.clone(),
651                    rhs_validity,
652                    if offset == self.len() {
653                        0
654                    } else {
655                        UNKNOWN_LEN as _
656                    },
657                    self.total_buffer_len(),
658                ),
659            )
660        }
661    }
662}