arrow_buffer/util/
bit_util.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Utilities for working with bits and bit-packed byte buffers.
19
/// Rounds `num` up to the smallest multiple of 64 that is greater than or equal to it.
///
/// # Panics
///
/// Panics if the rounded value does not fit in a `usize`.
#[inline]
pub fn round_upto_multiple_of_64(num: usize) -> usize {
    // 64 is a power of two, so `64 - 1` masks exactly the low bits that must be cleared.
    const LOW_BITS: usize = 64 - 1;
    let bumped = num
        .checked_add(LOW_BITS)
        .expect("failed to round to next highest power of 2");
    bumped & !LOW_BITS
}
25
/// Rounds `num` up to the nearest multiple of `factor`, i.e. returns the smallest
/// multiple of `factor` that is greater than or equal to `num`.
///
/// `factor` must be a power of two. This precondition is only checked in debug builds;
/// in release builds a `factor` that is not a power of two is not detected, the result
/// is unspecified, and the call may panic.
///
/// # Panics
///
/// Panics if the rounded value would overflow `usize`, and (debug builds only) if
/// `factor` is not a power of two.
#[inline]
pub fn round_upto_power_of_2(num: usize, factor: usize) -> usize {
    // `is_power_of_two` is false for 0, so this also rejects a zero factor before the
    // `factor - 1` below could underflow.
    debug_assert!(
        factor.is_power_of_two(),
        "factor must be a power of 2, got {}",
        factor
    );
    // For a power of two, `factor - 1` masks the low bits: adding it and then clearing
    // those bits rounds up to the next multiple of `factor`.
    let mask = factor - 1;
    num.checked_add(mask)
        .expect("failed to round to next highest power of 2")
        & !mask
}
34
/// Returns `true` if the bit at position `i` of the bit-packed slice `data` is set.
///
/// Bits are numbered least-significant first within each byte, so bit `i` lives in
/// byte `i / 8` at bit offset `i % 8`.
///
/// # Panics
///
/// Panics if byte `i / 8` is out of bounds for `data`.
#[inline]
pub fn get_bit(data: &[u8], i: usize) -> bool {
    let byte = data[i >> 3];
    ((byte >> (i & 7)) & 1) == 1
}
40
/// Returns `true` if the bit at position `i` of the bit-packed buffer starting at
/// `data` is set, without any bounds checking.
///
/// Bits are numbered least-significant first within each byte.
///
/// # Safety
///
/// No bounds checking is performed, for performance reasons. The caller must
/// guarantee that `data` points to a readable allocation holding at least
/// `i / 8 + 1` bytes.
#[inline]
pub unsafe fn get_bit_raw(data: *const u8, i: usize) -> bool {
    // SAFETY: the caller guarantees byte `i / 8` of `data` is in bounds and readable.
    let byte = *data.add(i >> 3);
    byte & (1u8 << (i & 7)) != 0
}
51
/// Sets the bit at position `i` of the bit-packed slice `data` to 1, leaving all other
/// bits unchanged.
///
/// # Panics
///
/// Panics if byte `i / 8` is out of bounds for `data`.
#[inline]
pub fn set_bit(data: &mut [u8], i: usize) {
    let mask = 1u8 << (i & 7);
    data[i >> 3] |= mask;
}
57
/// Sets the bit at position `i` of the bit-packed buffer starting at `data` to 1,
/// without any bounds checking.
///
/// # Safety
///
/// No bounds checking is performed, for performance reasons. The caller must
/// guarantee that `data` points to a writable allocation holding at least
/// `i / 8 + 1` bytes.
#[inline]
pub unsafe fn set_bit_raw(data: *mut u8, i: usize) {
    // SAFETY: the caller guarantees byte `i / 8` of `data` is in bounds and writable.
    let target = data.add(i >> 3);
    *target |= 1u8 << (i & 7);
}
68
/// Clears the bit at position `i` of the bit-packed slice `data` (sets it to 0),
/// leaving all other bits unchanged.
///
/// # Panics
///
/// Panics if byte `i / 8` is out of bounds for `data`.
#[inline]
pub fn unset_bit(data: &mut [u8], i: usize) {
    let keep = !(1u8 << (i & 7));
    data[i >> 3] &= keep;
}
74
/// Clears the bit at position `i` of the bit-packed buffer starting at `data` (sets it
/// to 0), without any bounds checking.
///
/// # Safety
///
/// No bounds checking is performed, for performance reasons. The caller must
/// guarantee that `data` points to a writable allocation holding at least
/// `i / 8 + 1` bytes.
#[inline]
pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) {
    // SAFETY: the caller guarantees byte `i / 8` of `data` is in bounds and writable.
    let target = data.add(i >> 3);
    *target &= !(1u8 << (i & 7));
}
85
/// Returns `value / divisor` rounded up to the nearest integer.
///
/// # Panics
///
/// Panics if `divisor` is 0.
#[inline]
pub fn ceil(value: usize, divisor: usize) -> usize {
    // NOTE: equivalent to `value.div_ceil(divisor)`, which is stable since Rust 1.73
    // (https://github.com/rust-lang/rust/issues/88581); switch once the crate's minimum
    // supported Rust version allows it.
    let quotient = value / divisor;
    if value % divisor == 0 {
        quotient
    } else {
        // A non-zero remainder implies `divisor >= 2`, so this cannot overflow.
        quotient + 1
    }
}
93
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Returns an RNG with a fixed seed so the randomized tests below are reproducible.
    fn seedable_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_round_upto_multiple_of_64() {
        assert_eq!(0, round_upto_multiple_of_64(0));
        assert_eq!(64, round_upto_multiple_of_64(1));
        assert_eq!(64, round_upto_multiple_of_64(63));
        assert_eq!(64, round_upto_multiple_of_64(64));
        assert_eq!(128, round_upto_multiple_of_64(65));
        assert_eq!(192, round_upto_multiple_of_64(129));
    }

    #[test]
    fn test_round_upto_power_of_2() {
        assert_eq!(0, round_upto_power_of_2(0, 1));
        assert_eq!(7, round_upto_power_of_2(7, 1));
        assert_eq!(8, round_upto_power_of_2(5, 8));
        assert_eq!(8, round_upto_power_of_2(8, 8));
        assert_eq!(16, round_upto_power_of_2(9, 8));
        assert_eq!(1024, round_upto_power_of_2(1000, 512));
    }

    #[test]
    #[should_panic(expected = "failed to round to next highest power of 2")]
    fn test_round_upto_panic() {
        let _ = round_upto_power_of_2(usize::MAX, 2);
    }

    #[test]
    fn test_get_bit() {
        // 00001101
        assert!(get_bit(&[0b00001101], 0));
        assert!(!get_bit(&[0b00001101], 1));
        assert!(get_bit(&[0b00001101], 2));
        assert!(get_bit(&[0b00001101], 3));

        // 01001001 01010010
        assert!(get_bit(&[0b01001001, 0b01010010], 0));
        assert!(!get_bit(&[0b01001001, 0b01010010], 1));
        assert!(!get_bit(&[0b01001001, 0b01010010], 2));
        assert!(get_bit(&[0b01001001, 0b01010010], 3));
        assert!(!get_bit(&[0b01001001, 0b01010010], 4));
        assert!(!get_bit(&[0b01001001, 0b01010010], 5));
        assert!(get_bit(&[0b01001001, 0b01010010], 6));
        assert!(!get_bit(&[0b01001001, 0b01010010], 7));
        assert!(!get_bit(&[0b01001001, 0b01010010], 8));
        assert!(get_bit(&[0b01001001, 0b01010010], 9));
        assert!(!get_bit(&[0b01001001, 0b01010010], 10));
        assert!(!get_bit(&[0b01001001, 0b01010010], 11));
        assert!(get_bit(&[0b01001001, 0b01010010], 12));
        assert!(!get_bit(&[0b01001001, 0b01010010], 13));
        assert!(get_bit(&[0b01001001, 0b01010010], 14));
        assert!(!get_bit(&[0b01001001, 0b01010010], 15));
    }

    #[test]
    fn test_get_bit_raw() {
        const NUM_BYTE: usize = 10;
        let mut buf = [0; NUM_BYTE];
        let mut expected = vec![];
        let mut rng = seedable_rng();
        for i in 0..8 * NUM_BYTE {
            let b = rng.gen_bool(0.5);
            expected.push(b);
            if b {
                set_bit(&mut buf[..], i);
            }
        }

        let raw_ptr = buf.as_ptr();
        for (i, b) in expected.iter().enumerate() {
            // SAFETY: `i < 8 * NUM_BYTE`, so byte `i / 8` lies within `buf`.
            unsafe {
                assert_eq!(*b, get_bit_raw(raw_ptr, i));
            }
        }
    }

    #[test]
    fn test_set_bit() {
        let mut b = [0b00000010];
        set_bit(&mut b, 0);
        assert_eq!([0b00000011], b);
        set_bit(&mut b, 1);
        assert_eq!([0b00000011], b);
        set_bit(&mut b, 7);
        assert_eq!([0b10000011], b);
    }

    #[test]
    fn test_unset_bit() {
        let mut b = [0b11111101];
        unset_bit(&mut b, 0);
        assert_eq!([0b11111100], b);
        unset_bit(&mut b, 1);
        assert_eq!([0b11111100], b);
        unset_bit(&mut b, 7);
        assert_eq!([0b01111100], b);
    }

    #[test]
    fn test_set_bit_raw() {
        const NUM_BYTE: usize = 10;
        let mut buf = vec![0; NUM_BYTE];
        let mut expected = vec![];
        let mut rng = seedable_rng();
        for i in 0..8 * NUM_BYTE {
            let b = rng.gen_bool(0.5);
            expected.push(b);
            if b {
                // SAFETY: `i < 8 * NUM_BYTE`, so byte `i / 8` lies within `buf`.
                unsafe {
                    set_bit_raw(buf.as_mut_ptr(), i);
                }
            }
        }

        let raw_ptr = buf.as_ptr();
        for (i, b) in expected.iter().enumerate() {
            // SAFETY: `i < 8 * NUM_BYTE`, so byte `i / 8` lies within `buf`.
            unsafe {
                assert_eq!(*b, get_bit_raw(raw_ptr, i));
            }
        }
    }

    #[test]
    fn test_unset_bit_raw() {
        const NUM_BYTE: usize = 10;
        let mut buf = vec![255; NUM_BYTE];
        let mut expected = vec![];
        let mut rng = seedable_rng();
        for i in 0..8 * NUM_BYTE {
            let b = rng.gen_bool(0.5);
            expected.push(b);
            if !b {
                // SAFETY: `i < 8 * NUM_BYTE`, so byte `i / 8` lies within `buf`.
                unsafe {
                    unset_bit_raw(buf.as_mut_ptr(), i);
                }
            }
        }

        let raw_ptr = buf.as_ptr();
        for (i, b) in expected.iter().enumerate() {
            // SAFETY: `i < 8 * NUM_BYTE`, so byte `i / 8` lies within `buf`.
            unsafe {
                assert_eq!(*b, get_bit_raw(raw_ptr, i));
            }
        }
    }

    #[test]
    fn test_get_set_bit_roundtrip() {
        const NUM_BYTES: usize = 10;
        const NUM_SETS: usize = 10;
        // Every bit position in `0..NUM_BITS` fits in exactly `NUM_BYTES` bytes.
        const NUM_BITS: usize = NUM_BYTES * 8;

        let mut buffer = [0u8; NUM_BYTES];
        let mut v = HashSet::new();
        let mut rng = seedable_rng();
        for _ in 0..NUM_SETS {
            let offset = rng.gen_range(0..NUM_BITS);
            v.insert(offset);
            set_bit(&mut buffer[..], offset);
        }
        for i in 0..NUM_BITS {
            assert_eq!(v.contains(&i), get_bit(&buffer[..], i));
        }
    }

    #[test]
    fn test_ceil() {
        assert_eq!(ceil(0, 1), 0);
        assert_eq!(ceil(1, 1), 1);
        assert_eq!(ceil(1, 2), 1);
        assert_eq!(ceil(1, 8), 1);
        assert_eq!(ceil(7, 8), 1);
        assert_eq!(ceil(8, 8), 1);
        assert_eq!(ceil(9, 8), 2);
        assert_eq!(ceil(9, 9), 1);
        assert_eq!(ceil(usize::MAX, 1), usize::MAX);
        assert_eq!(ceil(usize::MAX, 2), usize::MAX / 2 + 1);
        assert_eq!(ceil(usize::MAX, usize::MAX), 1);
    }

    // These literals exceed `u32::MAX`; on targets with a 32-bit `usize` they would be a
    // hard compile error (`overflowing_literals`), so only build this test on 64-bit.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_ceil_64bit() {
        assert_eq!(ceil(10000000000, 10), 1000000000);
        assert_eq!(ceil(10, 10000000000), 1);
        assert_eq!(ceil(10000000000, 1000000000), 10);
    }
}