lance_encoding/encodings/physical/
dictionary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5use std::vec;
6
7use arrow_array::builder::{ArrayBuilder, StringBuilder};
8use arrow_array::cast::AsArray;
9use arrow_array::types::UInt8Type;
10use arrow_array::{
11    make_array, new_null_array, Array, ArrayRef, DictionaryArray, StringArray, UInt8Array,
12};
13use arrow_schema::DataType;
14use futures::{future::BoxFuture, FutureExt};
15use lance_arrow::DataTypeExt;
16use lance_core::{Error, Result};
17use snafu::location;
18use std::collections::HashMap;
19
20use crate::buffer::LanceBuffer;
21use crate::data::{
22    BlockInfo, DataBlock, DictionaryDataBlock, FixedWidthDataBlock, NullableDataBlock,
23    VariableWidthBlock,
24};
25use crate::decoder::LogicalPageDecoder;
26use crate::encodings::logical::primitive::PrimitiveFieldDecoder;
27use crate::format::ProtobufUtils;
28use crate::{
29    decoder::{PageScheduler, PrimitivePageDecoder},
30    encoder::{ArrayEncoder, EncodedArray},
31    EncodingsIo,
32};
33
/// Page scheduler for dictionary encoded pages.
///
/// A dictionary page is made of two parts, each with its own child scheduler:
/// the per-row indices and the dictionary items those indices refer to.
#[derive(Debug)]
pub struct DictionaryPageScheduler {
    // Schedules the per-row dictionary indices
    indices_scheduler: Arc<dyn PageScheduler>,
    // Schedules the dictionary items (always read in full, see `schedule_ranges`)
    items_scheduler: Arc<dyn PageScheduler>,
    // The number of items in the dictionary
    num_dictionary_items: u32,
    // If true, decode the dictionary items.  If false, leave them dictionary encoded (e.g. the
    // output type is probably a dictionary type)
    should_decode_dict: bool,
}

impl DictionaryPageScheduler {
    /// Creates a scheduler from the two child schedulers.
    ///
    /// * `indices_scheduler` - scheduler for the per-row indices
    /// * `items_scheduler` - scheduler for the dictionary items
    /// * `num_dictionary_items` - the number of items stored in the dictionary
    /// * `should_decode_dict` - if true the scheduled decoder expands the dictionary into
    ///   plain string data, if false it emits a `DataBlock::Dictionary`
    pub fn new(
        indices_scheduler: Arc<dyn PageScheduler>,
        items_scheduler: Arc<dyn PageScheduler>,
        num_dictionary_items: u32,
        should_decode_dict: bool,
    ) -> Self {
        Self {
            indices_scheduler,
            items_scheduler,
            num_dictionary_items,
            should_decode_dict,
        }
    }
}
60
61impl PageScheduler for DictionaryPageScheduler {
62    fn schedule_ranges(
63        &self,
64        ranges: &[std::ops::Range<u64>],
65        scheduler: &Arc<dyn EncodingsIo>,
66        top_level_row: u64,
67    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
68        // We want to decode indices and items
69        // e.g. indices [0, 1, 2, 0, 1, 0]
70        // items (dictionary) ["abcd", "hello", "apple"]
71        // This will map to ["abcd", "hello", "apple", "abcd", "hello", "abcd"]
72        // We decode all the items during scheduling itself
73        // These are used to rebuild the string later
74
75        // Schedule indices for decoding
76        let indices_page_decoder =
77            self.indices_scheduler
78                .schedule_ranges(ranges, scheduler, top_level_row);
79
80        // Schedule items for decoding
81        let items_range = 0..(self.num_dictionary_items as u64);
82        let items_page_decoder = self.items_scheduler.schedule_ranges(
83            std::slice::from_ref(&items_range),
84            scheduler,
85            top_level_row,
86        );
87
88        let copy_size = self.num_dictionary_items as u64;
89
90        if self.should_decode_dict {
91            tokio::spawn(async move {
92                let items_decoder: Arc<dyn PrimitivePageDecoder> =
93                    Arc::from(items_page_decoder.await?);
94
95                let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data(
96                    items_decoder.clone(),
97                    DataType::Utf8,
98                    copy_size,
99                    false,
100                );
101
102                // Decode all items
103                let drained_task = primitive_wrapper.drain(copy_size)?;
104                let items_decode_task = drained_task.task;
105                let decoded_dict = items_decode_task.decode()?;
106
107                let indices_decoder: Box<dyn PrimitivePageDecoder> = indices_page_decoder.await?;
108
109                Ok(Box::new(DictionaryPageDecoder {
110                    decoded_dict,
111                    indices_decoder,
112                }) as Box<dyn PrimitivePageDecoder>)
113            })
114            .map(|join_handle| join_handle.unwrap())
115            .boxed()
116        } else {
117            let num_dictionary_items = self.num_dictionary_items;
118            tokio::spawn(async move {
119                let items_decoder: Arc<dyn PrimitivePageDecoder> =
120                    Arc::from(items_page_decoder.await?);
121
122                let decoded_dict = items_decoder
123                    .decode(0, num_dictionary_items as u64)?
124                    .borrow_and_clone();
125
126                let indices_decoder = indices_page_decoder.await?;
127
128                Ok(Box::new(DirectDictionaryPageDecoder {
129                    decoded_dict,
130                    indices_decoder,
131                }) as Box<dyn PrimitivePageDecoder>)
132            })
133            .map(|join_handle| join_handle.unwrap())
134            .boxed()
135        }
136    }
137}
138
139struct DirectDictionaryPageDecoder {
140    decoded_dict: DataBlock,
141    indices_decoder: Box<dyn PrimitivePageDecoder>,
142}
143
144impl PrimitivePageDecoder for DirectDictionaryPageDecoder {
145    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
146        let indices = self
147            .indices_decoder
148            .decode(rows_to_skip, num_rows)?
149            .as_fixed_width()
150            .unwrap();
151        let dict = self.decoded_dict.try_clone()?;
152        Ok(DataBlock::Dictionary(DictionaryDataBlock {
153            indices,
154            dictionary: Box::new(dict),
155        }))
156    }
157}
158
159struct DictionaryPageDecoder {
160    decoded_dict: Arc<dyn Array>,
161    indices_decoder: Box<dyn PrimitivePageDecoder>,
162}
163
164impl PrimitivePageDecoder for DictionaryPageDecoder {
165    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
166        // Decode the indices
167        let indices_data = self.indices_decoder.decode(rows_to_skip, num_rows)?;
168
169        let indices_array = make_array(indices_data.into_arrow(DataType::UInt8, false)?);
170        let indices_array = indices_array.as_primitive::<UInt8Type>();
171
172        let dictionary = self.decoded_dict.clone();
173
174        let adjusted_indices: UInt8Array = indices_array
175            .iter()
176            .map(|x| match x {
177                Some(0) => None,
178                Some(x) => Some(x - 1),
179                None => None,
180            })
181            .collect();
182
183        // Build dictionary array using indices and items
184        let dict_array =
185            DictionaryArray::<UInt8Type>::try_new(adjusted_indices, dictionary).unwrap();
186        let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap();
187        let string_array = string_array.as_any().downcast_ref::<StringArray>().unwrap();
188
189        let null_buffer = string_array.nulls().map(|n| n.buffer().clone());
190        let offsets_buffer = string_array.offsets().inner().inner().clone();
191        let bytes_buffer = string_array.values().clone();
192
193        let string_data = DataBlock::VariableWidth(VariableWidthBlock {
194            bits_per_offset: 32,
195            data: LanceBuffer::from(bytes_buffer),
196            offsets: LanceBuffer::from(offsets_buffer),
197            num_values: num_rows,
198            block_info: BlockInfo::new(),
199        });
200        if let Some(nulls) = null_buffer {
201            Ok(DataBlock::Nullable(NullableDataBlock {
202                data: Box::new(string_data),
203                nulls: LanceBuffer::from(nulls),
204                block_info: BlockInfo::new(),
205            }))
206        } else {
207            Ok(string_data)
208        }
209    }
210}
211
/// An encoder for data that is already dictionary encoded.  Stores the
/// data as a dictionary encoding.
///
/// The indices and the dictionary items are each handed to their own child encoder.
#[derive(Debug)]
pub struct AlreadyDictionaryEncoder {
    // Encodes the dictionary indices (the keys of the dictionary type)
    indices_encoder: Box<dyn ArrayEncoder>,
    // Encodes the dictionary items (the values of the dictionary type)
    items_encoder: Box<dyn ArrayEncoder>,
}

impl AlreadyDictionaryEncoder {
    /// Creates an encoder from the child encoders used for the indices and the items.
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}
231
232impl ArrayEncoder for AlreadyDictionaryEncoder {
233    fn encode(
234        &self,
235        data: DataBlock,
236        data_type: &DataType,
237        buffer_index: &mut u32,
238    ) -> Result<EncodedArray> {
239        let DataType::Dictionary(key_type, value_type) = data_type else {
240            panic!("Expected dictionary type");
241        };
242
243        let dict_data = match data {
244            DataBlock::Dictionary(dict_data) => dict_data,
245            DataBlock::AllNull(all_null) => {
246                // In 2.1 this won't happen, kind of annoying to materialize a bunch of nulls
247                let indices = UInt8Array::from(vec![0; all_null.num_values as usize]);
248                let indices = arrow_cast::cast(&indices, key_type.as_ref()).unwrap();
249                let indices = indices.into_data();
250                let values = new_null_array(value_type, 1);
251                DictionaryDataBlock {
252                    indices: FixedWidthDataBlock {
253                        bits_per_value: key_type.byte_width() as u64 * 8,
254                        data: LanceBuffer::Borrowed(indices.buffers()[0].clone()),
255                        num_values: all_null.num_values,
256                        block_info: BlockInfo::new(),
257                    },
258                    dictionary: Box::new(DataBlock::from_array(values)),
259                }
260            }
261            _ => panic!("Expected dictionary data"),
262        };
263        let num_dictionary_items = dict_data.dictionary.num_values() as u32;
264
265        let encoded_indices = self.indices_encoder.encode(
266            DataBlock::FixedWidth(dict_data.indices),
267            key_type,
268            buffer_index,
269        )?;
270        let encoded_items =
271            self.items_encoder
272                .encode(*dict_data.dictionary, value_type, buffer_index)?;
273
274        let encoded = DataBlock::Dictionary(DictionaryDataBlock {
275            dictionary: Box::new(encoded_items.data),
276            indices: encoded_indices.data.as_fixed_width().unwrap(),
277        });
278
279        let encoding = ProtobufUtils::dict_encoding(
280            encoded_indices.encoding,
281            encoded_items.encoding,
282            num_dictionary_items,
283        );
284
285        Ok(EncodedArray {
286            data: encoded,
287            encoding,
288        })
289    }
290}
291
/// An encoder that dictionary encodes plain string data.
///
/// The input strings are split into `u8` indices and a dictionary of distinct strings
/// which are then handed to the two child encoders.
#[derive(Debug)]
pub struct DictionaryEncoder {
    // Encodes the generated dictionary indices
    indices_encoder: Box<dyn ArrayEncoder>,
    // Encodes the generated dictionary items (the distinct strings)
    items_encoder: Box<dyn ArrayEncoder>,
}

impl DictionaryEncoder {
    /// Creates an encoder from the child encoders used for the indices and the items.
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}
309
310fn encode_dict_indices_and_items(string_array: &StringArray) -> (ArrayRef, ArrayRef) {
311    let mut arr_hashmap: HashMap<&str, u8> = HashMap::new();
312    // We start with a dict index of 1 because the value 0 is reserved for nulls
313    // The dict indices are adjusted by subtracting 1 later during decode
314    let mut curr_dict_index = 1;
315    let total_capacity = string_array.len();
316
317    let mut dict_indices = Vec::with_capacity(total_capacity);
318    let mut dict_builder = StringBuilder::new();
319
320    for i in 0..string_array.len() {
321        if !string_array.is_valid(i) {
322            // null value
323            dict_indices.push(0);
324            continue;
325        }
326
327        let st = string_array.value(i);
328
329        let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index);
330        dict_indices.push(hashmap_entry);
331
332        // if item didn't exist in the hashmap, add it to the dictionary
333        // and increment the dictionary index
334        if hashmap_entry == curr_dict_index {
335            dict_builder.append_value(st);
336            curr_dict_index += 1;
337        }
338    }
339
340    let array_dict_indices = Arc::new(UInt8Array::from(dict_indices)) as ArrayRef;
341
342    // If there is an empty dictionary:
343    // Either there is an array of nulls or an empty array altogether
344    // In this case create the dictionary with a single null element
345    // Because decoding [] is not currently supported by the binary decoder
346    if dict_builder.is_empty() {
347        dict_builder.append_option(Option::<&str>::None);
348    }
349
350    let dict_elements = dict_builder.finish();
351    let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap();
352
353    (array_dict_indices, array_dict_elements)
354}
355
356impl ArrayEncoder for DictionaryEncoder {
357    fn encode(
358        &self,
359        data: DataBlock,
360        data_type: &DataType,
361        buffer_index: &mut u32,
362    ) -> Result<EncodedArray> {
363        if !matches!(data_type, DataType::Utf8) {
364            return Err(Error::InvalidInput {
365                source: format!(
366                    "DictionaryEncoder only supports string arrays but got {}",
367                    data_type
368                )
369                .into(),
370                location: location!(),
371            });
372        }
373        // We only support string arrays for now
374        let str_data = make_array(data.into_arrow(DataType::Utf8, false)?);
375
376        let (index_array, items_array) = encode_dict_indices_and_items(str_data.as_string());
377        let dict_size = items_array.len() as u32;
378        let index_data = DataBlock::from(index_array);
379        let items_data = DataBlock::from(items_array);
380
381        let encoded_indices =
382            self.indices_encoder
383                .encode(index_data, &DataType::UInt8, buffer_index)?;
384
385        let encoded_items = self
386            .items_encoder
387            .encode(items_data, &DataType::Utf8, buffer_index)?;
388
389        let encoded_data = DataBlock::Dictionary(DictionaryDataBlock {
390            indices: encoded_indices.data.as_fixed_width().unwrap(),
391            dictionary: Box::new(encoded_items.data),
392        });
393
394        let encoding = ProtobufUtils::dict_encoding(
395            encoded_indices.encoding,
396            encoded_items.encoding,
397            dict_size,
398        );
399
400        Ok(EncodedArray {
401            data: encoded_data,
402            encoding,
403        })
404    }
405}
406
#[cfg(test)]
pub mod tests {

    use arrow_array::{
        builder::{LargeStringBuilder, StringBuilder},
        ArrayRef, StringArray, UInt8Array,
    };
    use arrow_schema::{DataType, Field};
    use std::{collections::HashMap, sync::Arc, vec};

    use crate::{
        testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases},
        version::LanceFileVersion,
    };

    use super::encode_dict_indices_and_items;

    // These tests cover the case where we opportunistically convert some (or all) pages of
    // a string column into dictionaries (and decode on read)

    // Checks the raw indices / items produced for an input that contains nulls
    #[test]
    fn test_encode_dict_nulls() {
        // Null entries must map to index 0 while non-null entries get 1-based indices
        // into the dictionary items
        let string_array = Arc::new(StringArray::from(vec![
            None,
            Some("foo"),
            Some("bar"),
            Some("bar"),
            None,
            Some("foo"),
            None,
            None,
        ]));
        let (dict_indices, dict_items) = encode_dict_indices_and_items(&string_array);

        let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef;
        let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
        assert_eq!(&dict_indices, &expected_indices);
        assert_eq!(&dict_items, &expected_items);
    }

    // Round trip of random data for a non-nullable Utf8 field
    #[test_log::test(tokio::test)]
    async fn test_utf8() {
        let field = Field::new("", DataType::Utf8, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    // Round trip of random data for a non-nullable Binary field
    #[test_log::test(tokio::test)]
    async fn test_binary() {
        let field = Field::new("", DataType::Binary, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    // Round trip of random data for a nullable LargeBinary field
    #[test_log::test(tokio::test)]
    async fn test_large_binary() {
        let field = Field::new("", DataType::LargeBinary, true);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    // Round trip of random data for a nullable LargeUtf8 field
    #[test_log::test(tokio::test)]
    async fn test_large_utf8() {
        let field = Field::new("", DataType::LargeUtf8, true);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    // Round trip of a small hand written string array (with one null) read back through
    // several ranges and an index selection
    #[test_log::test(tokio::test)]
    async fn test_simple_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(0..3)
            .with_range(1..3)
            .with_indices(vec![1, 3]);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    // Same as test_simple_utf8 but the input is a slice of a larger array (it starts at
    // element 1 of the original array)
    #[test_log::test(tokio::test)]
    async fn test_sliced_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
        let string_array = string_array.slice(1, 3);

        let test_cases = TestCases::default()
            .with_range(0..1)
            .with_range(0..2)
            .with_range(1..2);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    // Round trip of arrays that contain empty strings
    #[test_log::test(tokio::test)]
    async fn test_empty_strings() {
        // Scenario 1: Some strings are empty

        let values = [Some("abc"), Some(""), None];
        // Test the empty string in the middle, at the beginning, and at the end
        for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
            let mut string_builder = StringBuilder::new();
            for idx in order {
                string_builder.append_option(values[idx]);
            }
            let string_array = Arc::new(string_builder.finish());
            let test_cases = TestCases::default()
                .with_indices(vec![1])
                .with_indices(vec![0])
                .with_indices(vec![2]);
            check_round_trip_encoding_of_data(
                vec![string_array.clone()],
                &test_cases,
                HashMap::new(),
            )
            .await;
            // Run the same cases again with a batch size of 1
            let test_cases = test_cases.with_batch_size(1);
            check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new())
                .await;
        }

        // Scenario 2: All strings are empty

        // When encoding an array of empty strings there are no bytes to encode
        // which is strange and we want to ensure we handle it
        let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")]));

        let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]);
        check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases, HashMap::new())
            .await;
        let test_cases = test_cases.with_batch_size(1);
        check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new()).await;
    }

    #[test_log::test(tokio::test)]
    #[ignore] // This test is quite slow in debug mode
    async fn test_jumbo_string() {
        // This is an overflow test.  Each string is 1 MiB and we encode 5000 of them
        // so the string data spans more than 4Gi bytes, which is more than a 32-bit
        // offset can address
        let mut string_builder = LargeStringBuilder::new();
        // a 1 MiB string
        let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0'));
        for _ in 0..5000 {
            string_builder.append_option(Some(&giant_string));
        }
        let giant_array = Arc::new(string_builder.finish()) as ArrayRef;
        let arrs = vec![giant_array];

        // We can't validate because our validation relies on concatenating all input arrays
        let test_cases = TestCases::default().without_validation();
        check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await;
    }

    // These tests cover the case where the input is already dictionary encoded

    // Round trip of random data for a non-nullable Dictionary(UInt16, Utf8) field
    #[test_log::test(tokio::test)]
    async fn test_random_dictionary_input() {
        let dict_field = Field::new(
            "",
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
            false,
        );
        check_round_trip_encoding_random(dict_field, LanceFileVersion::V2_0).await;
    }
}