lance_arrow/
bfloat16.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! bfloat16 support for Apache Arrow.
5
6use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{
10    builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
11    FixedSizeBinaryArray,
12};
13use arrow_buffer::MutableBuffer;
14use arrow_data::ArrayData;
15use arrow_schema::{ArrowError, DataType, Field as ArrowField};
16use half::bf16;
17
18use crate::FloatArray;
19
/// Field-metadata key under which an Arrow extension type's name is stored.
pub const ARROW_EXT_NAME_KEY: &str = "ARROW:extension:name";
/// Field-metadata key under which an Arrow extension type's metadata is stored.
pub const ARROW_EXT_META_KEY: &str = "ARROW:extension:metadata";
/// Extension name that marks a field as bfloat16 (checked by [`is_bfloat16_field`]).
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
23
24/// Check whether the given field is a bfloat16 field.
25pub fn is_bfloat16_field(field: &ArrowField) -> bool {
26    field.data_type() == &DataType::FixedSizeBinary(2)
27        && field
28            .metadata()
29            .get(ARROW_EXT_NAME_KEY)
30            .map(|name| name == BFLOAT16_EXT_NAME)
31            .unwrap_or_default()
32}
33
/// Marker type naming the bfloat16 element type; used in this file as the
/// type parameter of [`FloatArray`].
#[derive(Debug)]
pub struct BFloat16Type {}
36
/// An array of [`bf16`] values backed by a [`FixedSizeBinaryArray`].
///
/// Each element is read from two bytes interpreted as a little-endian `u16`
/// bit pattern.
#[derive(Clone)]
pub struct BFloat16Array {
    // Underlying storage; the `TryFrom<FixedSizeBinaryArray>` conversion only
    // accepts arrays whose value length is 2.
    inner: FixedSizeBinaryArray,
}
41
42impl std::fmt::Debug for BFloat16Array {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        write!(f, "BFloat16Array\n[\n")?;
45        from_arrow::print_long_array(&self.inner, f, |array, i, f| {
46            if array.is_null(i) {
47                write!(f, "null")
48            } else {
49                let binary_values = array.value(i);
50                let value =
51                    bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
52                write!(f, "{:?}", value)
53            }
54        })?;
55        write!(f, "]")
56    }
57}
58
59impl BFloat16Array {
60    pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
61        let values: Vec<bf16> = iter.into_iter().collect();
62        values.into()
63    }
64
65    pub fn iter(&self) -> BFloat16Iter {
66        BFloat16Iter::new(self)
67    }
68
69    pub fn value(&self, i: usize) -> bf16 {
70        assert!(
71            i < self.len(),
72            "Trying to access an element at index {} from a BFloat16Array of length {}",
73            i,
74            self.len()
75        );
76        // Safety:
77        // `i < self.len()
78        unsafe { self.value_unchecked(i) }
79    }
80
81    /// # Safety
82    /// Caller must ensure that `i < self.len()`
83    pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
84        let binary_value = self.inner.value_unchecked(i);
85        bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
86    }
87
88    pub fn into_inner(self) -> FixedSizeBinaryArray {
89        self.inner
90    }
91}
92
93impl ArrayAccessor for &BFloat16Array {
94    type Item = bf16;
95
96    fn value(&self, index: usize) -> Self::Item {
97        BFloat16Array::value(self, index)
98    }
99
100    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
101        BFloat16Array::value_unchecked(self, index)
102    }
103}
104
105impl Array for BFloat16Array {
106    fn as_any(&self) -> &dyn std::any::Any {
107        self.inner.as_any()
108    }
109
110    fn to_data(&self) -> arrow_data::ArrayData {
111        self.inner.to_data()
112    }
113
114    fn into_data(self) -> arrow_data::ArrayData {
115        self.inner.into_data()
116    }
117
118    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
119        let inner_array: &dyn Array = &self.inner;
120        inner_array.slice(offset, length)
121    }
122
123    fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
124        self.inner.nulls()
125    }
126
127    fn data_type(&self) -> &DataType {
128        self.inner.data_type()
129    }
130
131    fn len(&self) -> usize {
132        self.inner.len()
133    }
134
135    fn is_empty(&self) -> bool {
136        self.inner.is_empty()
137    }
138
139    fn offset(&self) -> usize {
140        self.inner.offset()
141    }
142
143    fn get_array_memory_size(&self) -> usize {
144        self.inner.get_array_memory_size()
145    }
146
147    fn get_buffer_memory_size(&self) -> usize {
148        self.inner.get_buffer_memory_size()
149    }
150}
151
152impl FromIterator<Option<bf16>> for BFloat16Array {
153    fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
154        let mut buffer = MutableBuffer::new(10);
155        // No null buffer builder :(
156        let mut nulls = BooleanBufferBuilder::new(10);
157        let mut len = 0;
158
159        for maybe_value in iter {
160            if let Some(value) = maybe_value {
161                let bytes = value.to_le_bytes();
162                buffer.extend(bytes);
163            } else {
164                buffer.extend([0u8, 0u8]);
165            }
166            nulls.append(maybe_value.is_some());
167            len += 1;
168        }
169
170        let null_buffer = nulls.finish();
171        let num_valid = null_buffer.count_set_bits();
172        let null_buffer = if num_valid == len {
173            None
174        } else {
175            Some(null_buffer.into_inner())
176        };
177
178        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
179            .len(len)
180            .add_buffer(buffer.into())
181            .null_bit_buffer(null_buffer);
182        let array_data = unsafe { array_data.build_unchecked() };
183        Self {
184            inner: FixedSizeBinaryArray::from(array_data),
185        }
186    }
187}
188
189impl FromIterator<bf16> for BFloat16Array {
190    fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
191        Self::from_iter_values(iter)
192    }
193}
194
195impl From<Vec<bf16>> for BFloat16Array {
196    fn from(data: Vec<bf16>) -> Self {
197        let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
198
199        let bytes = data.iter().flat_map(|val| {
200            let bytes = val.to_bits().to_le_bytes();
201            bytes.to_vec()
202        });
203
204        buffer.extend(bytes);
205        let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
206            .len(data.len())
207            .add_buffer(buffer.into());
208        let array_data = unsafe { array_data.build_unchecked() };
209        Self {
210            inner: FixedSizeBinaryArray::from(array_data),
211        }
212    }
213}
214
215impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
216    type Error = ArrowError;
217
218    fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
219        if value.value_length() == 2 {
220            Ok(Self { inner: value })
221        } else {
222            Err(ArrowError::InvalidArgumentError(
223                "FixedSizeBinaryArray must have a value length of 2".to_string(),
224            ))
225        }
226    }
227}
228
229impl PartialEq<Self> for BFloat16Array {
230    fn eq(&self, other: &Self) -> bool {
231        self.inner.eq(&other.inner)
232    }
233}
234
/// Iterator type returned by [`BFloat16Array::iter`].
type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
236
237/// Methods that are lifted from arrow-rs temporarily until they are made public.
238mod from_arrow {
239    use arrow_array::Array;
240
241    /// Helper function for printing potentially long arrays.
242    pub(super) fn print_long_array<A, F>(
243        array: &A,
244        f: &mut std::fmt::Formatter,
245        print_item: F,
246    ) -> std::fmt::Result
247    where
248        A: Array,
249        F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
250    {
251        let head = std::cmp::min(10, array.len());
252
253        for i in 0..head {
254            if array.is_null(i) {
255                writeln!(f, "  null,")?;
256            } else {
257                write!(f, "  ")?;
258                print_item(array, i, f)?;
259                writeln!(f, ",")?;
260            }
261        }
262        if array.len() > 10 {
263            if array.len() > 20 {
264                writeln!(f, "  ...{} elements...,", array.len() - 20)?;
265            }
266
267            let tail = std::cmp::max(head, array.len() - 10);
268
269            for i in tail..array.len() {
270                if array.is_null(i) {
271                    writeln!(f, "  null,")?;
272                } else {
273                    write!(f, "  ")?;
274                    print_item(array, i, f)?;
275                    writeln!(f, ",")?;
276                }
277            }
278        }
279        Ok(())
280    }
281}
282
283impl FloatArray<BFloat16Type> for BFloat16Array {
284    type FloatType = BFloat16Type;
285
286    fn as_slice(&self) -> &[bf16] {
287        unsafe {
288            slice::from_raw_parts(
289                self.inner.value_data().as_ptr() as *const bf16,
290                self.inner.value_data().len() / 2,
291            )
292        }
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Both value-only constructors agree, format as expected, and yield the
    /// original values back through `iter`.
    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0_f32, 2.0, 3.0]
            .iter()
            .copied()
            .map(bf16::from_f32)
            .collect();

        let from_values = BFloat16Array::from_iter_values(values.clone());
        let from_vec = BFloat16Array::from(values.clone());
        assert_eq!(from_values, from_vec);
        assert_eq!(from_values.len(), 3);

        assert_eq!(
            "BFloat16Array\n[\n  1.0,\n  2.0,\n  3.0,\n]",
            format!("{:?}", from_values)
        );

        for array in [&from_values, &from_vec] {
            for (expected, actual) in values.iter().zip(array.iter()) {
                assert_eq!(Some(*expected), actual);
            }
        }
    }

    /// `None` inputs become null slots that are counted, printed as `null`,
    /// and yielded as `None` by `iter`.
    #[test]
    fn test_nulls() {
        let values = vec![Some(bf16::from_f32(1.0)), None, Some(bf16::from_f32(3.0))];
        let array: BFloat16Array = values.iter().copied().collect();
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        assert_eq!(
            "BFloat16Array\n[\n  1.0,\n  null,\n  3.0,\n]",
            format!("{:?}", array)
        );

        for (expected, actual) in values.iter().zip(array.iter()) {
            assert_eq!(*expected, actual);
        }
    }
}
337}