1use std::fmt::Formatter;
7use std::slice;
8
9use arrow_array::{
10 builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
11 FixedSizeBinaryArray,
12};
13use arrow_buffer::MutableBuffer;
14use arrow_data::ArrayData;
15use arrow_schema::{ArrowError, DataType, Field as ArrowField};
16use half::bf16;
17
18use crate::FloatArray;
19
/// Field-metadata key that [`is_bfloat16_field`] reads to find a field's
/// extension name.
pub const ARROW_EXT_NAME_KEY: &str = "ARROW:extension:name";
/// Field-metadata key for extension metadata.
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// companion of [`ARROW_EXT_NAME_KEY`] — confirm against callers.
pub const ARROW_EXT_META_KEY: &str = "ARROW:extension:metadata";
/// Extension name that marks a `FixedSizeBinary(2)` field as holding bfloat16
/// values (see [`is_bfloat16_field`]).
pub const BFLOAT16_EXT_NAME: &str = "lance.bfloat16";
23
24pub fn is_bfloat16_field(field: &ArrowField) -> bool {
26 field.data_type() == &DataType::FixedSizeBinary(2)
27 && field
28 .metadata()
29 .get(ARROW_EXT_NAME_KEY)
30 .map(|name| name == BFLOAT16_EXT_NAME)
31 .unwrap_or_default()
32}
33
/// Marker type for the bfloat16 element type.
///
/// It carries no data; in this file it is used as the type parameter of the
/// `FloatArray<BFloat16Type>` implementation for `BFloat16Array`.
#[derive(Debug)]
pub struct BFloat16Type {}
36
/// An array of `bf16` values stored in a [`FixedSizeBinaryArray`].
///
/// Each element occupies two bytes holding the little-endian bit pattern of
/// the `bf16` value (see `value_unchecked`). Null handling is delegated to the
/// wrapped array.
#[derive(Clone)]
pub struct BFloat16Array {
    // Backing storage; the constructors in this file build it as
    // `FixedSizeBinary(2)` or check that the value length is 2.
    inner: FixedSizeBinaryArray,
}
41
42impl std::fmt::Debug for BFloat16Array {
43 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44 write!(f, "BFloat16Array\n[\n")?;
45 from_arrow::print_long_array(&self.inner, f, |array, i, f| {
46 if array.is_null(i) {
47 write!(f, "null")
48 } else {
49 let binary_values = array.value(i);
50 let value =
51 bf16::from_bits(u16::from_le_bytes([binary_values[0], binary_values[1]]));
52 write!(f, "{:?}", value)
53 }
54 })?;
55 write!(f, "]")
56 }
57}
58
59impl BFloat16Array {
60 pub fn from_iter_values(iter: impl IntoIterator<Item = bf16>) -> Self {
61 let values: Vec<bf16> = iter.into_iter().collect();
62 values.into()
63 }
64
65 pub fn iter(&self) -> BFloat16Iter {
66 BFloat16Iter::new(self)
67 }
68
69 pub fn value(&self, i: usize) -> bf16 {
70 assert!(
71 i < self.len(),
72 "Trying to access an element at index {} from a BFloat16Array of length {}",
73 i,
74 self.len()
75 );
76 unsafe { self.value_unchecked(i) }
79 }
80
81 pub unsafe fn value_unchecked(&self, i: usize) -> bf16 {
84 let binary_value = self.inner.value_unchecked(i);
85 bf16::from_bits(u16::from_le_bytes([binary_value[0], binary_value[1]]))
86 }
87
88 pub fn into_inner(self) -> FixedSizeBinaryArray {
89 self.inner
90 }
91}
92
/// Lets `&BFloat16Array` be used with arrow's generic accessors (for example
/// `ArrayIter`, which backs `BFloat16Iter`).
impl ArrayAccessor for &BFloat16Array {
    type Item = bf16;

    /// Delegates to the inherent `BFloat16Array::value`.
    fn value(&self, index: usize) -> Self::Item {
        // Fully-qualified call so the inherent method is selected rather than
        // this trait method.
        BFloat16Array::value(self, index)
    }

    /// Delegates to the inherent `BFloat16Array::value_unchecked`; the same
    /// in-bounds requirement on `index` applies.
    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
        BFloat16Array::value_unchecked(self, index)
    }
}
104
105impl Array for BFloat16Array {
106 fn as_any(&self) -> &dyn std::any::Any {
107 self.inner.as_any()
108 }
109
110 fn to_data(&self) -> arrow_data::ArrayData {
111 self.inner.to_data()
112 }
113
114 fn into_data(self) -> arrow_data::ArrayData {
115 self.inner.into_data()
116 }
117
118 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
119 let inner_array: &dyn Array = &self.inner;
120 inner_array.slice(offset, length)
121 }
122
123 fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
124 self.inner.nulls()
125 }
126
127 fn data_type(&self) -> &DataType {
128 self.inner.data_type()
129 }
130
131 fn len(&self) -> usize {
132 self.inner.len()
133 }
134
135 fn is_empty(&self) -> bool {
136 self.inner.is_empty()
137 }
138
139 fn offset(&self) -> usize {
140 self.inner.offset()
141 }
142
143 fn get_array_memory_size(&self) -> usize {
144 self.inner.get_array_memory_size()
145 }
146
147 fn get_buffer_memory_size(&self) -> usize {
148 self.inner.get_buffer_memory_size()
149 }
150}
151
152impl FromIterator<Option<bf16>> for BFloat16Array {
153 fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
154 let mut buffer = MutableBuffer::new(10);
155 let mut nulls = BooleanBufferBuilder::new(10);
157 let mut len = 0;
158
159 for maybe_value in iter {
160 if let Some(value) = maybe_value {
161 let bytes = value.to_le_bytes();
162 buffer.extend(bytes);
163 } else {
164 buffer.extend([0u8, 0u8]);
165 }
166 nulls.append(maybe_value.is_some());
167 len += 1;
168 }
169
170 let null_buffer = nulls.finish();
171 let num_valid = null_buffer.count_set_bits();
172 let null_buffer = if num_valid == len {
173 None
174 } else {
175 Some(null_buffer.into_inner())
176 };
177
178 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
179 .len(len)
180 .add_buffer(buffer.into())
181 .null_bit_buffer(null_buffer);
182 let array_data = unsafe { array_data.build_unchecked() };
183 Self {
184 inner: FixedSizeBinaryArray::from(array_data),
185 }
186 }
187}
188
189impl FromIterator<bf16> for BFloat16Array {
190 fn from_iter<I: IntoIterator<Item = bf16>>(iter: I) -> Self {
191 Self::from_iter_values(iter)
192 }
193}
194
195impl From<Vec<bf16>> for BFloat16Array {
196 fn from(data: Vec<bf16>) -> Self {
197 let mut buffer = MutableBuffer::with_capacity(data.len() * 2);
198
199 let bytes = data.iter().flat_map(|val| {
200 let bytes = val.to_bits().to_le_bytes();
201 bytes.to_vec()
202 });
203
204 buffer.extend(bytes);
205 let array_data = ArrayData::builder(DataType::FixedSizeBinary(2))
206 .len(data.len())
207 .add_buffer(buffer.into());
208 let array_data = unsafe { array_data.build_unchecked() };
209 Self {
210 inner: FixedSizeBinaryArray::from(array_data),
211 }
212 }
213}
214
215impl TryFrom<FixedSizeBinaryArray> for BFloat16Array {
216 type Error = ArrowError;
217
218 fn try_from(value: FixedSizeBinaryArray) -> Result<Self, Self::Error> {
219 if value.value_length() == 2 {
220 Ok(Self { inner: value })
221 } else {
222 Err(ArrowError::InvalidArgumentError(
223 "FixedSizeBinaryArray must have a value length of 2".to_string(),
224 ))
225 }
226 }
227}
228
229impl PartialEq<Self> for BFloat16Array {
230 fn eq(&self, other: &Self) -> bool {
231 self.inner.eq(&other.inner)
232 }
233}
234
/// Iterator over a borrowed [`BFloat16Array`], yielding `Option<bf16>` per
/// element (returned by `BFloat16Array::iter`).
type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
236
237mod from_arrow {
239 use arrow_array::Array;
240
241 pub(super) fn print_long_array<A, F>(
243 array: &A,
244 f: &mut std::fmt::Formatter,
245 print_item: F,
246 ) -> std::fmt::Result
247 where
248 A: Array,
249 F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result,
250 {
251 let head = std::cmp::min(10, array.len());
252
253 for i in 0..head {
254 if array.is_null(i) {
255 writeln!(f, " null,")?;
256 } else {
257 write!(f, " ")?;
258 print_item(array, i, f)?;
259 writeln!(f, ",")?;
260 }
261 }
262 if array.len() > 10 {
263 if array.len() > 20 {
264 writeln!(f, " ...{} elements...,", array.len() - 20)?;
265 }
266
267 let tail = std::cmp::max(head, array.len() - 10);
268
269 for i in tail..array.len() {
270 if array.is_null(i) {
271 writeln!(f, " null,")?;
272 } else {
273 write!(f, " ")?;
274 print_item(array, i, f)?;
275 writeln!(f, ",")?;
276 }
277 }
278 }
279 Ok(())
280 }
281}
282
283impl FloatArray<BFloat16Type> for BFloat16Array {
284 type FloatType = BFloat16Type;
285
286 fn as_slice(&self) -> &[bf16] {
287 unsafe {
288 slice::from_raw_parts(
289 self.inner.value_data().as_ptr() as *const bf16,
290 self.inner.value_data().len() / 2,
291 )
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Both constructors agree, Debug output is as expected, and iteration
    /// yields the original values.
    #[test]
    fn test_basics() {
        let values: Vec<bf16> = [1.0_f32, 2.0, 3.0]
            .iter()
            .map(|v| bf16::from_f32(*v))
            .collect();

        let array = BFloat16Array::from_iter_values(values.clone());
        let array2 = BFloat16Array::from(values.clone());
        assert_eq!(array, array2);
        assert_eq!(array.len(), 3);

        let expected_fmt = "BFloat16Array\n[\n 1.0,\n 2.0,\n 3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        // Check each array against the source values in turn.
        for candidate in [&array, &array2] {
            for (expected, value) in values.iter().zip(candidate.iter()) {
                assert_eq!(Some(*expected), value);
            }
        }
    }

    /// Nulls are counted, printed as `null`, and surfaced as `None` by the
    /// iterator.
    #[test]
    fn test_nulls() {
        let one = bf16::from_f32(1.0);
        let three = bf16::from_f32(3.0);
        let values: Vec<Option<bf16>> = vec![Some(one), None, Some(three)];

        let array = BFloat16Array::from_iter(values.clone());
        assert_eq!(array.len(), 3);
        assert_eq!(array.null_count(), 1);

        let expected_fmt = "BFloat16Array\n[\n 1.0,\n null,\n 3.0,\n]";
        assert_eq!(expected_fmt, format!("{:?}", array));

        for (expected, value) in values.iter().zip(array.iter()) {
            assert_eq!(*expected, value);
        }
    }
}