lance_encoding/encodings/physical/bitpack.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::sync::Arc;

use arrow::datatypes::{
    ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
    UInt64Type, UInt8Type,
};
use arrow::util::bit_util::ceil;
use arrow_array::{cast::AsArray, Array, PrimitiveArray};
use arrow_schema::DataType;
use bytes::Bytes;
use futures::future::{BoxFuture, FutureExt};
use log::trace;
use num_traits::{AsPrimitive, PrimInt, ToPrimitive};
use snafu::location;

use lance_arrow::DataTypeExt;
use lance_core::{Error, Result};

use crate::buffer::LanceBuffer;
use crate::data::{BlockInfo, DataBlock, FixedWidthDataBlock};
use crate::decoder::{PageScheduler, PrimitivePageDecoder};
use crate::encoder::{ArrayEncoder, EncodedArray};
use crate::format::ProtobufUtils;

#[derive(Debug)]
pub struct BitpackParams {
    pub num_bits: u64,

    pub signed: bool,
}

// Compute the number of bits to use for each item, if this array can be encoded using
// bitpacking encoding. Returns `None` if the type or array data is not supported.
pub fn bitpack_params(arr: &dyn Array) -> Option<BitpackParams> {
    match arr.data_type() {
        DataType::UInt8 => bitpack_params_for_type::<UInt8Type>(arr.as_primitive()),
        DataType::UInt16 => bitpack_params_for_type::<UInt16Type>(arr.as_primitive()),
        DataType::UInt32 => bitpack_params_for_type::<UInt32Type>(arr.as_primitive()),
        DataType::UInt64 => bitpack_params_for_type::<UInt64Type>(arr.as_primitive()),
        DataType::Int8 => bitpack_params_for_signed_type::<Int8Type>(arr.as_primitive()),
        DataType::Int16 => bitpack_params_for_signed_type::<Int16Type>(arr.as_primitive()),
        DataType::Int32 => bitpack_params_for_signed_type::<Int32Type>(arr.as_primitive()),
        DataType::Int64 => bitpack_params_for_signed_type::<Int64Type>(arr.as_primitive()),
        // TODO -- eventually we could support temporal types as well
        _ => None,
    }
}
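
// Illustrative usage (a sketch mirroring the tests at the bottom of this file; not
// additional API surface):
//
//     let arr = UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 5]);
//     let params = bitpack_params(&arr).unwrap();
//     assert_eq!(params.num_bits, 3); // the max value 5 (0b101) needs 3 bits
//     assert!(!params.signed);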

// Compute the number of bits to use for bitpacking generically.
// Returns None if the array is empty or all nulls.
fn bitpack_params_for_type<T>(arr: &PrimitiveArray<T>) -> Option<BitpackParams>
where
    T: ArrowPrimitiveType,
    T::Native: PrimInt + AsPrimitive<u64>,
{
    let max = arrow::compute::bit_or(arr);
    let num_bits =
        max.map(|max| arr.data_type().byte_width() as u64 * 8 - max.leading_zeros() as u64);

    // we can't bitpack into 0 bits, so the minimum is 1
    num_bits
        .map(|num_bits| num_bits.max(1))
        .map(|bits| BitpackParams {
            num_bits: bits,
            signed: false,
        })
}
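
// Worked example (illustrative): for a UInt32 array [1, 2, 7], `bit_or` returns 7
// (0b111), which has 29 leading zeros, so num_bits = 4 * 8 - 29 = 3.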

/// Determine the minimum number of bits that can be used to represent
/// an array of signed values. It includes all the significant bits of
/// the value plus 1 bit to represent the sign. If there are no negative
/// values then no sign bit is added.
fn bitpack_params_for_signed_type<T>(arr: &PrimitiveArray<T>) -> Option<BitpackParams>
where
    T: ArrowPrimitiveType,
    T::Native: PrimInt + AsPrimitive<i64>,
{
    let mut add_signed_bit = false;
    let mut min_leading_bits: Option<u64> = None;
    for val in arr.iter() {
        if val.is_none() {
            continue;
        }
        let val = val.unwrap();
        if min_leading_bits.is_none() {
            min_leading_bits = Some(u64::MAX);
        }

        if val.to_i64().unwrap() < 0i64 {
            min_leading_bits = min_leading_bits.map(|bits| bits.min(val.leading_ones() as u64));
            add_signed_bit = true;
        } else {
            min_leading_bits = min_leading_bits.map(|bits| bits.min(val.leading_zeros() as u64));
        }
    }

    let mut min_leading_bits = arr.data_type().byte_width() as u64 * 8 - min_leading_bits?;
    if add_signed_bit {
        // Need extra sign bit
        min_leading_bits += 1;
    }
    // cannot bitpack into <1 bit
    let num_bits = min_leading_bits.max(1);
    Some(BitpackParams {
        num_bits,
        signed: add_signed_bit,
    })
}
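
// Worked example (illustrative): for an Int32 array [1, 2, -7], the smallest count of
// leading sign bits is 29 (from -7 = 0b...1111_1001), so the value needs 32 - 29 = 3
// significant bits plus 1 sign bit, giving num_bits = 4 and signed = true.
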
#[derive(Debug)]
pub struct BitpackedArrayEncoder {
    num_bits: u64,
    signed_type: bool,
}

impl BitpackedArrayEncoder {
    pub fn new(num_bits: u64, signed_type: bool) -> Self {
        Self {
            num_bits,
            signed_type,
        }
    }
}

impl ArrayEncoder for BitpackedArrayEncoder {
    fn encode(
        &self,
        data: DataBlock,
        _data_type: &DataType,
        buffer_index: &mut u32,
    ) -> Result<EncodedArray> {
        // calculate the total number of bytes we need to allocate for the destination.
        // this is the number of items in the source array times the number of bits per
        // item, rounded up to a whole number of bytes.
        let dst_bytes_total = ceil(data.num_values() as usize * self.num_bits as usize, 8);

        let mut dst_buffer = vec![0u8; dst_bytes_total];
        let mut dst_idx = 0;
        let mut dst_offset = 0;

        let DataBlock::FixedWidth(unpacked) = data else {
            return Err(Error::InvalidInput {
                source: "Bitpacking only supports fixed width data blocks".into(),
                location: location!(),
            });
        };

        pack_bits(
            &unpacked.data,
            self.num_bits,
            &mut dst_buffer,
            &mut dst_idx,
            &mut dst_offset,
        );

        let packed = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: self.num_bits,
            data: LanceBuffer::Owned(dst_buffer),
            num_values: unpacked.num_values,
            block_info: BlockInfo::new(),
        });

        let bitpacked_buffer_index = *buffer_index;
        *buffer_index += 1;

        let encoding = ProtobufUtils::bitpacked_encoding(
            self.num_bits,
            unpacked.bits_per_value,
            bitpacked_buffer_index,
            self.signed_type,
        );

        Ok(EncodedArray {
            data: packed,
            encoding,
        })
    }
}
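
// Illustrative usage (a sketch mirroring `test_will_bitpack_allowed_types_when_possible`
// below; not additional API surface):
//
//     let arr: ArrayRef = Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 5]));
//     let params = bitpack_params(arr.as_ref()).unwrap();
//     let encoder = BitpackedArrayEncoder::new(params.num_bits, params.signed);
//     let mut buffer_index = 0;
//     let encoded = encoder.encode(DataBlock::from_array(arr), &DataType::UInt8, &mut buffer_index)?;
//     // encoded.data is a FixedWidthDataBlock with bits_per_value == 3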

fn pack_bits(
    src: &LanceBuffer,
    num_bits: u64,
    dst: &mut [u8],
    dst_idx: &mut usize,
    dst_offset: &mut u8,
) {
    let bit_len = src.len() as u64 * 8;

    let mask = u64::MAX >> (64 - num_bits);

    let mut src_idx = 0;
    while src_idx < src.len() {
        let mut curr_mask = mask;
        let mut curr_src = src[src_idx] & curr_mask as u8;
        let mut src_offset = 0;
        let mut src_bits_written = 0;

        while src_bits_written < num_bits {
            dst[*dst_idx] += (curr_src >> src_offset) << *dst_offset as u64;
            let bits_written = (num_bits - src_bits_written)
                .min(8 - src_offset)
                .min(8 - *dst_offset as u64);
            src_bits_written += bits_written;
            *dst_offset += bits_written as u8;
            src_offset += bits_written;

            if *dst_offset == 8 {
                *dst_idx += 1;
                *dst_offset = 0;
            }

            if src_offset == 8 {
                src_idx += 1;
                src_offset = 0;
                curr_mask >>= 8;
                if src_idx == src.len() {
                    break;
                }
                curr_src = src[src_idx] & curr_mask as u8;
            }
        }

        // advance the source offset to the next byte if we're not at the end.
        // note that we don't need to do this if we wrote the full number of bits
        // because the source index would have been advanced by the inner loop above
        if bit_len != num_bits {
            let partial_bytes_written = ceil(num_bits as usize, 8);

            // we also want to advance to the next location in src, unless we wrote
            // something byte-aligned in which case the logic above would have already
            // advanced
            let mut to_next_byte = 1;
            if num_bits % 8 == 0 {
                to_next_byte = 0;
            }

            src_idx += src.len() - partial_bytes_written + to_next_byte;
        }
    }
}
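
// Illustrative bit layout (a sketch of the LSB-first packing the loop above and the
// decoder below agree on): with num_bits = 3, the values 5 (0b101) and 3 (0b011) occupy
// the low six bits of the first output byte, i.e. 0b00_011_101 = 0x1D.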

// A physical scheduler for bitpacked buffers
#[derive(Debug, Clone, Copy)]
pub struct BitpackedScheduler {
    bits_per_value: u64,
    uncompressed_bits_per_value: u64,
    buffer_offset: u64,
    signed: bool,
}

impl BitpackedScheduler {
    pub fn new(
        bits_per_value: u64,
        uncompressed_bits_per_value: u64,
        buffer_offset: u64,
        signed: bool,
    ) -> Self {
        Self {
            bits_per_value,
            uncompressed_bits_per_value,
            buffer_offset,
            signed,
        }
    }
}

impl PageScheduler for BitpackedScheduler {
    fn schedule_ranges(
        &self,
        ranges: &[std::ops::Range<u64>],
        scheduler: &Arc<dyn crate::EncodingsIo>,
        top_level_row: u64,
    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
        let mut min = u64::MAX;
        let mut max = 0;

        let mut buffer_bit_start_offsets: Vec<u8> = vec![];
        let mut buffer_bit_end_offsets: Vec<Option<u8>> = vec![];
        let byte_ranges = ranges
            .iter()
            .map(|range| {
                let start_byte_offset = range.start * self.bits_per_value / 8;
                let mut end_byte_offset = range.end * self.bits_per_value / 8;
                if range.end * self.bits_per_value % 8 != 0 {
                    // If the end of the range is not byte-aligned, we need to read one more byte
                    end_byte_offset += 1;

                    let end_bit_offset = range.end * self.bits_per_value % 8;
                    buffer_bit_end_offsets.push(Some(end_bit_offset as u8));
                } else {
                    buffer_bit_end_offsets.push(None);
                }

                let start_bit_offset = range.start * self.bits_per_value % 8;
                buffer_bit_start_offsets.push(start_bit_offset as u8);

                let start = self.buffer_offset + start_byte_offset;
                let end = self.buffer_offset + end_byte_offset;
                min = min.min(start);
                max = max.max(end);

                start..end
            })
            .collect::<Vec<_>>();

        trace!(
            "Scheduling I/O for {} ranges spread across byte range {}..{}",
            byte_ranges.len(),
            min,
            max
        );

        let bytes = scheduler.submit_request(byte_ranges, top_level_row);

        let bits_per_value = self.bits_per_value;
        let uncompressed_bits_per_value = self.uncompressed_bits_per_value;
        let signed = self.signed;
        async move {
            let bytes = bytes.await?;
            Ok(Box::new(BitpackedPageDecoder {
                buffer_bit_start_offsets,
                buffer_bit_end_offsets,
                bits_per_value,
                uncompressed_bits_per_value,
                signed,
                data: bytes,
            }) as Box<dyn PrimitivePageDecoder>)
        }
        .boxed()
    }
}
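
// Worked example (illustrative): with bits_per_value = 3, buffer_offset = 0, and the row
// range 2..5, the start bit is 6 (byte 0, bit offset 6) and the end bit is 15, which is
// not byte aligned, so the byte range 0..2 is requested with buffer_bit_start_offsets =
// [6] and buffer_bit_end_offsets = [Some(7)].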

#[derive(Debug)]
struct BitpackedPageDecoder {
    // bit offsets of the first value within each buffer
    buffer_bit_start_offsets: Vec<u8>,

    // bit offsets of the last value within each buffer. e.g. if there was a buffer
    // with 2 values, packed into 5 bits, this would be [Some(2)], indicating that
    // bits 2..8 (0-based) of the last byte shouldn't be decoded.
    buffer_bit_end_offsets: Vec<Option<u8>>,

    // the number of bits used to represent a compressed value. E.g. if the max value
    // in the page was 7 (0b111), then this will be 3
    bits_per_value: u64,

    // number of bits in the uncompressed value. E.g. this will be 32 for u32
    uncompressed_bits_per_value: u64,

    // whether or not to use the msb as a sign bit during decoding
    signed: bool,

    data: Vec<Bytes>,
}

impl PrimitivePageDecoder for BitpackedPageDecoder {
    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
        let num_bytes = self.uncompressed_bits_per_value / 8 * num_rows;
        let mut dest = vec![0; num_bytes as usize];

        // current maximum supported bits per value = 64
        debug_assert!(self.bits_per_value <= 64);

        let mut rows_to_skip = rows_to_skip;
        let mut rows_taken = 0;
        let byte_len = self.uncompressed_bits_per_value / 8;
        let mut dst_idx = 0; // index for current byte being written to destination buffer

        // create bit mask for source bits
        let mask = u64::MAX >> (64 - self.bits_per_value);

        for i in 0..self.data.len() {
            let src = &self.data[i];
            let (mut src_idx, mut src_offset) = match compute_start_offset(
                rows_to_skip,
                src.len(),
                self.bits_per_value,
                self.buffer_bit_start_offsets[i],
                self.buffer_bit_end_offsets[i],
            ) {
                StartOffset::SkipFull(rows_to_skip_here) => {
                    rows_to_skip -= rows_to_skip_here;
                    continue;
                }
                StartOffset::SkipSome(buffer_start_offset) => (
                    buffer_start_offset.index,
                    buffer_start_offset.bit_offset as u64,
                ),
            };

            while src_idx < src.len() && rows_taken < num_rows {
                rows_taken += 1;
                let mut curr_mask = mask; // copy mask

                // current source byte being written to destination
                let mut curr_src = src[src_idx] & (curr_mask << src_offset) as u8;

                // how many bits from the current source value have been written to destination
                let mut src_bits_written = 0;

                // the offset within the current destination byte to write to
                let mut dst_offset = 0;

                let is_negative = is_encoded_item_negative(
                    src,
                    src_idx,
                    src_offset,
                    self.bits_per_value as usize,
                );

                while src_bits_written < self.bits_per_value {
                    // write bits from current source byte into destination
                    dest[dst_idx] += (curr_src >> src_offset) << dst_offset;
                    let bits_written = (self.bits_per_value - src_bits_written)
                        .min(8 - src_offset)
                        .min(8 - dst_offset);
                    src_bits_written += bits_written;
                    dst_offset += bits_written;
                    src_offset += bits_written;
                    curr_mask >>= bits_written;

                    if dst_offset == 8 {
                        dst_idx += 1;
                        dst_offset = 0;
                    }

                    if src_offset == 8 {
                        src_idx += 1;
                        src_offset = 0;
                        if src_idx == src.len() {
                            break;
                        }
                        curr_src = src[src_idx] & curr_mask as u8;
                    }
                }

                // if the type is signed, need to pad out the rest of the byte with 1s
                let mut negative_padded_current_byte = false;
                if self.signed && is_negative && dst_offset > 0 {
                    negative_padded_current_byte = true;
                    while dst_offset < 8 {
                        dest[dst_idx] |= 1 << dst_offset;
                        dst_offset += 1;
                    }
                }

                // advance destination offset to the next location
                // note that we don't need to do this if we wrote the full number of bits
                // because the destination index would have been advanced by the inner loop above
                if self.uncompressed_bits_per_value != self.bits_per_value {
                    let partial_bytes_written = ceil(self.bits_per_value as usize, 8);

                    // we also want to advance to the next location in the destination,
                    // unless we wrote something byte-aligned in which case the logic above
                    // would have already advanced dst_idx
                    let mut to_next_byte = 1;
                    if self.bits_per_value % 8 == 0 {
                        to_next_byte = 0;
                    }
                    let next_dst_idx =
                        dst_idx + byte_len as usize - partial_bytes_written + to_next_byte;

                    // pad remaining bytes with 1 for negative signed numbers
                    if self.signed && is_negative {
                        if !negative_padded_current_byte {
                            dest[dst_idx] = 0xFF;
                        }
                        for i in dest.iter_mut().take(next_dst_idx).skip(dst_idx + 1) {
                            *i = 0xFF;
                        }
                    }

                    dst_idx = next_dst_idx;
                }

                // If we've reached the last byte, there may be some extra bits from the
                // next value outside the range. We don't want to be taking those.
                if let Some(buffer_bit_end_offset) = self.buffer_bit_end_offsets[i] {
                    if src_idx == src.len() - 1 && src_offset >= buffer_bit_end_offset as u64 {
                        break;
                    }
                }
            }
        }

        Ok(DataBlock::FixedWidth(FixedWidthDataBlock {
            data: LanceBuffer::from(dest),
            bits_per_value: self.uncompressed_bits_per_value,
            num_values: num_rows,
            block_info: BlockInfo::new(),
        }))
    }
}
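
// Worked example of the sign handling above (illustrative): a value packed into 4 bits as
// 0b1001 (i.e. -7) is copied into the low 4 bits of the destination byte, the remaining
// bits of that byte are padded with 1s, and the higher bytes of the i32 slot are set to
// 0xFF, yielding the little-endian bytes F9 FF FF FF, i.e. -7.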

fn is_encoded_item_negative(src: &Bytes, src_idx: usize, src_offset: u64, num_bits: usize) -> bool {
    let mut last_byte_idx = src_idx + ((src_offset as usize + num_bits) / 8);
    let shift_amount = (src_offset as usize + num_bits) % 8;
    let shift_amount = if shift_amount == 0 {
        last_byte_idx -= 1;
        7
    } else {
        shift_amount - 1
    };
    let last_byte = src[last_byte_idx];
    let sign_bit_mask = 1 << shift_amount;
    let sign_bit = last_byte & sign_bit_mask;

    sign_bit > 0
}
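
// Illustrative: with src_offset = 0 and num_bits = 4, the sign bit of the value starting
// at src[src_idx] is bit 3 of that byte (mask 0b1000), so a packed nibble of 0b1001 is
// reported as negative.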

#[derive(Debug, PartialEq)]
struct BufferStartOffset {
    index: usize,
    bit_offset: u8,
}

#[derive(Debug, PartialEq)]
enum StartOffset {
    // skip the full buffer. The value is how many rows are skipped
    // by skipping the full buffer (e.g., # rows in buffer)
    SkipFull(u64),

    // skip to some start offset in the buffer
    SkipSome(BufferStartOffset),
}

/// compute how far into this buffer we should skip before starting to read
///
/// * `rows_to_skip` - how many rows to skip
/// * `buffer_len` - length of the buffer (in bytes)
/// * `bits_per_value` - number of bits used to represent a single bitpacked value
/// * `buffer_start_bit_offset` - offset of the start of the first value within the
///     buffer's first byte
/// * `buffer_end_bit_offset` - end bit of the last value within the buffer. Can be
///     `None` if the end of the last value is byte aligned with end of buffer.
fn compute_start_offset(
    rows_to_skip: u64,
    buffer_len: usize,
    bits_per_value: u64,
    buffer_start_bit_offset: u8,
    buffer_end_bit_offset: Option<u8>,
) -> StartOffset {
    let rows_in_buffer = rows_in_buffer(
        buffer_len,
        bits_per_value,
        buffer_start_bit_offset,
        buffer_end_bit_offset,
    );
    if rows_to_skip >= rows_in_buffer {
        return StartOffset::SkipFull(rows_in_buffer);
    }

    let start_bit = rows_to_skip * bits_per_value + buffer_start_bit_offset as u64;
    let start_byte = start_bit / 8;

    StartOffset::SkipSome(BufferStartOffset {
        index: start_byte as usize,
        bit_offset: (start_bit % 8) as u8,
    })
}
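
// Worked example (illustrative, complementing `test_compute_start_offset` below): skipping
// 1 row in a 5-byte buffer of 5-bit values with no start offset lands at bit 5, i.e.
// SkipSome(BufferStartOffset { index: 0, bit_offset: 5 }), while skipping 10 rows exceeds
// the 8 rows in the buffer and returns SkipFull(8).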

/// calculates the number of rows in a buffer
fn rows_in_buffer(
    buffer_len: usize,
    bits_per_value: u64,
    buffer_start_bit_offset: u8,
    buffer_end_bit_offset: Option<u8>,
) -> u64 {
    let mut bits_in_buffer = (buffer_len * 8) as u64 - buffer_start_bit_offset as u64;

    // if the end of the last value of the buffer isn't byte aligned, subtract the
    // end offset from the total number of bits in buffer
    if let Some(buffer_end_bit_offset) = buffer_end_bit_offset {
        bits_in_buffer -= (8 - buffer_end_bit_offset) as u64;
    }

    bits_in_buffer / bits_per_value
}
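
// Worked example (illustrative, matching a case in `test_rows_in_buffer` below): a 2-byte
// buffer of 3-bit values with a start bit offset of 7 and an end bit offset of Some(6)
// has 16 - 7 - (8 - 6) = 7 usable bits, i.e. 7 / 3 = 2 rows.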

#[cfg(test)]
pub mod test {
    use crate::{
        format::pb,
        testing::{check_round_trip_encoding_generated, ArrayGeneratorProvider},
        version::LanceFileVersion,
    };

    use super::*;
    use std::{marker::PhantomData, sync::Arc};

    use arrow_array::{
        types::{UInt16Type, UInt8Type},
        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
    };

    use arrow_schema::Field;
    use lance_datagen::{
        array::{fill, rand_with_distribution},
        gen, ArrayGenerator, ArrayGeneratorExt, RowCount,
    };
    use rand::distributions::Uniform;

    #[test]
    fn test_bitpack_params() {
        fn gen_array(generator: Box<dyn ArrayGenerator>) -> ArrayRef {
            let arr = gen()
                .anon_col(generator)
                .into_batch_rows(RowCount::from(10000))
                .unwrap()
                .column(0)
                .clone();

            arr
        }

        macro_rules! do_test {
            ($num_bits:expr, $data_type:ident, $null_probability:expr) => {
                let max = 1 << $num_bits - 1;
                let mut arr =
                    gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));

                // ensure we don't randomly generate all nulls, that won't work
                while arr.null_count() == arr.len() {
                    arr = gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));
                }
                let result = bitpack_params(arr.as_ref());
                assert!(result.is_some());
                assert_eq!($num_bits, result.unwrap().num_bits);
            };
        }

        let test_cases = vec![
            (5u64, 0.0f64),
            (5u64, 0.9f64),
            (1u64, 0.0f64),
            (1u64, 0.5f64),
            (8u64, 0.0f64),
            (8u64, 0.5f64),
        ];

        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt8Type, *null_probability);
            do_test!(*num_bits, UInt16Type, *null_probability);
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // do some test cases that will only work on larger types
        let test_cases = vec![
            (13u64, 0.0f64),
            (13u64, 0.5f64),
            (16u64, 0.0f64),
            (16u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt16Type, *null_probability);
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }
        let test_cases = vec![
            (25u64, 0.0f64),
            (25u64, 0.5f64),
            (32u64, 0.0f64),
            (32u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }
        let test_cases = vec![
            (48u64, 0.0f64),
            (48u64, 0.5f64),
            (64u64, 0.0f64),
            (64u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // test that it returns None for datatypes that don't support bitpacking
        let arr = Float64Array::from_iter_values(vec![0.1, 0.2, 0.3]);
        let result = bitpack_params(&arr);
        assert!(result.is_none());
    }

    #[test]
    fn test_num_compressed_bits_signed_types() {
        let values = Int32Array::from(vec![1, 2, -7]);
        let arr = values;
        let result = bitpack_params(&arr);
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(4, result.num_bits);
        assert!(result.signed);

        // check that it doesn't add a sign bit if it doesn't need to
        let values = Int32Array::from(vec![1, 2, 7]);
        let arr = values;
        let result = bitpack_params(&arr);
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(3, result.num_bits);
        assert!(!result.signed);
    }

    #[test]
    fn test_rows_in_buffer() {
        let test_cases = vec![
            (5usize, 5u64, 0u8, None, 8u64),
            (2, 3, 0, Some(5), 4),
            (2, 3, 7, Some(6), 2),
        ];

        for (
            buffer_len,
            bits_per_value,
            buffer_start_bit_offset,
            buffer_end_bit_offset,
            expected,
        ) in test_cases
        {
            let result = rows_in_buffer(
                buffer_len,
                bits_per_value,
                buffer_start_bit_offset,
                buffer_end_bit_offset,
            );
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn test_compute_start_offset() {
        let result = compute_start_offset(0, 5, 5, 0, None);
        assert_eq!(
            StartOffset::SkipSome(BufferStartOffset {
                index: 0,
                bit_offset: 0
            }),
            result
        );

        let result = compute_start_offset(10, 5, 5, 0, None);
        assert_eq!(StartOffset::SkipFull(8), result);
    }

    #[test_log::test(test)]
    fn test_will_bitpack_allowed_types_when_possible() {
        let test_cases: Vec<(DataType, ArrayRef, u64)> = vec![
            (
                DataType::UInt8,
                Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 5])),
                3, // bits per value
            ),
            (
                DataType::UInt16,
                Arc::new(UInt16Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 8])),
                11,
            ),
            (
                DataType::UInt32,
                Arc::new(UInt32Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 16])),
                19,
            ),
            (
                DataType::UInt64,
                Arc::new(UInt64Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 32])),
                35,
            ),
            (
                DataType::Int8,
                Arc::new(Int8Array::from_iter_values(vec![0, 2, 3, 4, -5])),
                4,
            ),
            (
                // check it will not pack with signed bit if all values of signed type are positive
                DataType::Int8,
                Arc::new(Int8Array::from_iter_values(vec![0, 2, 3, 4, 5])),
                3,
            ),
            (
                DataType::Int16,
                Arc::new(Int16Array::from_iter_values(vec![0, 1, 2, 3, -4, 5 << 8])),
                12,
            ),
            (
                DataType::Int32,
                Arc::new(Int32Array::from_iter_values(vec![0, 1, 2, 3, 4, -5 << 16])),
                20,
            ),
            (
                DataType::Int64,
                Arc::new(Int64Array::from_iter_values(vec![
                    0,
                    1,
                    2,
                    -3,
                    -4,
                    -5 << 32,
                ])),
                36,
            ),
        ];

        for (data_type, arr, bits_per_value) in test_cases {
            let mut buffed_index = 1;
            let params = bitpack_params(arr.as_ref()).unwrap();
            let encoder = BitpackedArrayEncoder {
                num_bits: params.num_bits,
                signed_type: params.signed,
            };
            let data = DataBlock::from_array(arr);
            let result = encoder.encode(data, &data_type, &mut buffed_index).unwrap();

            let data = result.data.as_fixed_width().unwrap();
            assert_eq!(bits_per_value, data.bits_per_value);

            let array_encoding = result.encoding.array_encoding.unwrap();

            match array_encoding {
                pb::array_encoding::ArrayEncoding::Bitpacked(bitpacked) => {
                    assert_eq!(bits_per_value, bitpacked.compressed_bits_per_value);
                    assert_eq!(
                        (data_type.byte_width() * 8) as u64,
                        bitpacked.uncompressed_bits_per_value
                    );
                }
                _ => {
                    panic!("Array did not use bitpacking encoding")
                }
            }
        }

        // check it will otherwise use flat encoding
        let test_cases: Vec<(DataType, ArrayRef)> = vec![
            // it should use flat encoding for datatypes that don't support bitpacking
            (
                DataType::Float32,
                Arc::new(Float32Array::from_iter_values(vec![0.1, 0.2, 0.3])),
            ),
            // it should still use flat encoding if bitpacked encoding would be packed
            // into the full byte range
            (
                DataType::UInt8,
                Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 250])),
            ),
            (
                DataType::UInt16,
                Arc::new(UInt16Array::from_iter_values(vec![0, 1, 2, 3, 4, 250 << 8])),
            ),
            (
                DataType::UInt32,
                Arc::new(UInt32Array::from_iter_values(vec![
                    0,
                    1,
                    2,
                    3,
                    4,
                    250 << 24,
                ])),
            ),
            (
                DataType::UInt64,
                Arc::new(UInt64Array::from_iter_values(vec![
                    0,
                    1,
                    2,
                    3,
                    4,
                    250 << 56,
                ])),
            ),
            (
                DataType::Int8,
                Arc::new(Int8Array::from_iter_values(vec![-100])),
            ),
            (
                DataType::Int16,
                Arc::new(Int16Array::from_iter_values(vec![-100 << 8])),
            ),
            (
                DataType::Int32,
                Arc::new(Int32Array::from_iter_values(vec![-100 << 24])),
            ),
            (
                DataType::Int64,
                Arc::new(Int64Array::from_iter_values(vec![-100 << 56])),
            ),
        ];

        for (data_type, arr) in test_cases {
            if let Some(params) = bitpack_params(arr.as_ref()) {
                assert_eq!(params.num_bits, data_type.byte_width() as u64 * 8);
            }
        }
    }

    struct DistributionArrayGeneratorProvider<
        DataType,
        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
    >
    where
        DataType::Native: Copy + 'static,
        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
        DataType: ArrowPrimitiveType,
    {
        phantom: PhantomData<DataType>,
        distribution: Dist,
    }

    impl<DataType, Dist> DistributionArrayGeneratorProvider<DataType, Dist>
    where
        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
        DataType::Native: Copy + 'static,
        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
        DataType: ArrowPrimitiveType,
    {
        fn new(dist: Dist) -> Self {
            Self {
                distribution: dist,
                phantom: Default::default(),
            }
        }
    }

    impl<DataType, Dist> ArrayGeneratorProvider for DistributionArrayGeneratorProvider<DataType, Dist>
    where
        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
        DataType::Native: Copy + 'static,
        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
        DataType: ArrowPrimitiveType,
    {
        fn provide(&self) -> Box<dyn ArrayGenerator> {
            rand_with_distribution::<DataType, Dist>(self.distribution.clone())
        }

        fn copy(&self) -> Box<dyn ArrayGeneratorProvider> {
            Box::new(Self {
                phantom: self.phantom,
                distribution: self.distribution.clone(),
            })
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_bitpack_primitive() {
        let bitpacked_test_cases: &Vec<(DataType, Box<dyn ArrayGeneratorProvider>)> = &vec![
            // check less than one byte for multi-byte type
            (
                DataType::UInt32,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
                        Uniform::new(0, 19),
                    ),
                ),
            ),
            // check more than one byte for multi-byte type
            (
                DataType::UInt32,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
                        Uniform::new(5 << 7, 6 << 7),
                    ),
                ),
            ),
            (
                DataType::UInt64,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
                        Uniform::new(5 << 42, 6 << 42),
                    ),
                ),
            ),
            // check less than one byte for single-byte type
            (
                DataType::UInt8,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt8Type, Uniform<u8>>::new(
                        Uniform::new(0, 19),
                    ),
                ),
            ),
            // check just over one byte for multi-byte type
            (
                DataType::UInt64,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
                        Uniform::new(129, 259),
                    ),
                ),
            ),
            // check byte aligned for single byte
            (
                DataType::UInt32,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
                        // this range should always give 8 bits
                        Uniform::new(200, 250),
                    ),
                ),
            ),
            // check where the num_bits divides evenly into the bit length of the type
            (
                DataType::UInt64,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
                        Uniform::new(1, 3), // 2 bits
                    ),
                ),
            ),
            // check byte aligned for multiple bytes
            (
                DataType::UInt32,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
                        // this range should always give 16 bits
                        Uniform::new(200 << 8, 250 << 8),
                    ),
                ),
            ),
            // check byte aligned where the num bits doesn't divide evenly into the byte length
            (
                DataType::UInt64,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
                        // this range should always give 24 bits
                        Uniform::new(200 << 16, 250 << 16),
                    ),
                ),
            ),
            // check that we can still encode an all-0 array
            (
                DataType::UInt32,
                Box::new(
                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
                        Uniform::new(0, 1),
                    ),
                ),
            ),
            // check for signed types
            (
                DataType::Int16,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int16Type, Uniform<i16>>::new(
                        Uniform::new(-5, 5),
                    ),
                ),
            ),
            (
                DataType::Int64,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int64Type, Uniform<i64>>::new(
                        Uniform::new(-(5 << 42), 6 << 42),
                    ),
                ),
            ),
            (
                DataType::Int32,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
                        Uniform::new(-(5 << 7), 6 << 7),
                    ),
                ),
            ),
            // check signed where packed to < 1 byte for multi-byte type
            (
                DataType::Int32,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
                        Uniform::new(-19, 19),
                    ),
                ),
            ),
            // check signed byte aligned to single byte
            (
                DataType::Int32,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
                        // this range should always give 8 bits
                        Uniform::new(-120, 120),
                    ),
                ),
            ),
            // check signed byte aligned to multiple bytes
            (
                DataType::Int32,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
                        // this range should always give 16 bits
                        Uniform::new(-120 << 8, 120 << 8),
                    ),
                ),
            ),
            // check that it works for all positive integers even if type is signed
            (
                DataType::Int32,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
                        Uniform::new(10, 20),
                    ),
                ),
            ),
            // check that all 0 works for signed type
            (
                DataType::Int32,
                Box::new(
                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
                        Uniform::new(0, 1),
                    ),
                ),
            ),
        ];

        for (data_type, array_gen_provider) in bitpacked_test_cases {
            let field = Field::new("", data_type.clone(), false);
            check_round_trip_encoding_generated(
                field,
                array_gen_provider.copy(),
                LanceFileVersion::V2_1,
            )
            .await;
        }
    }
}