use std::sync::Arc;
use std::vec;

use arrow_array::builder::{ArrayBuilder, StringBuilder};
use arrow_array::cast::AsArray;
use arrow_array::types::UInt8Type;
use arrow_array::{
    make_array, new_null_array, Array, ArrayRef, DictionaryArray, StringArray, UInt8Array,
};
use arrow_schema::DataType;
use futures::{future::BoxFuture, FutureExt};
use lance_arrow::DataTypeExt;
use lance_core::{Error, Result};
use snafu::location;
use std::collections::HashMap;

use crate::buffer::LanceBuffer;
use crate::data::{
    BlockInfo, DataBlock, DictionaryDataBlock, FixedWidthDataBlock, NullableDataBlock,
    VariableWidthBlock,
};
use crate::decoder::LogicalPageDecoder;
use crate::encodings::logical::primitive::PrimitiveFieldDecoder;
use crate::format::ProtobufUtils;
use crate::{
    decoder::{PageScheduler, PrimitivePageDecoder},
    encoder::{ArrayEncoder, EncodedArray},
    EncodingsIo,
};

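/// A page scheduler for dictionary encoded data
///
/// Delegates to an inner scheduler for the indices and another for the
/// dictionary items.  The requested ranges are scheduled from the indices
/// scheduler while the entire dictionary (`num_dictionary_items` values) is
/// always scheduled from the items scheduler.
///
/// If `should_decode_dict` is true the dictionary is eagerly decoded into an
/// Arrow string array and the returned decoder re-materializes plain values;
/// otherwise the dictionary block is passed through and the returned decoder
/// emits dictionary-encoded blocks directly.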
#[derive(Debug)]
pub struct DictionaryPageScheduler {
    indices_scheduler: Arc<dyn PageScheduler>,
    items_scheduler: Arc<dyn PageScheduler>,
    num_dictionary_items: u32,
    should_decode_dict: bool,
}

impl DictionaryPageScheduler {
    pub fn new(
        indices_scheduler: Arc<dyn PageScheduler>,
        items_scheduler: Arc<dyn PageScheduler>,
        num_dictionary_items: u32,
        should_decode_dict: bool,
    ) -> Self {
        Self {
            indices_scheduler,
            items_scheduler,
            num_dictionary_items,
            should_decode_dict,
        }
    }
}

impl PageScheduler for DictionaryPageScheduler {
    fn schedule_ranges(
        &self,
        ranges: &[std::ops::Range<u64>],
        scheduler: &Arc<dyn EncodingsIo>,
        top_level_row: u64,
    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
        // Schedule the requested ranges from the indices page
        let indices_page_decoder =
            self.indices_scheduler
                .schedule_ranges(ranges, scheduler, top_level_row);

        // The entire dictionary is always scheduled
        let items_range = 0..(self.num_dictionary_items as u64);
        let items_page_decoder = self.items_scheduler.schedule_ranges(
            std::slice::from_ref(&items_range),
            scheduler,
            top_level_row,
        );

        let copy_size = self.num_dictionary_items as u64;

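        // Two paths: either eagerly decode the dictionary into an Arrow array so
        // that plain values can be re-materialized, or pass the dictionary block
        // through and emit dictionary-encoded blocks directly.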
        if self.should_decode_dict {
            tokio::spawn(async move {
                let items_decoder: Arc<dyn PrimitivePageDecoder> =
                    Arc::from(items_page_decoder.await?);

                let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data(
                    items_decoder.clone(),
                    DataType::Utf8,
                    copy_size,
                    false,
                );

                let drained_task = primitive_wrapper.drain(copy_size)?;
                let items_decode_task = drained_task.task;
                let decoded_dict = items_decode_task.decode()?;

                let indices_decoder: Box<dyn PrimitivePageDecoder> = indices_page_decoder.await?;

                Ok(Box::new(DictionaryPageDecoder {
                    decoded_dict,
                    indices_decoder,
                }) as Box<dyn PrimitivePageDecoder>)
            })
            .map(|join_handle| join_handle.unwrap())
            .boxed()
        } else {
            let num_dictionary_items = self.num_dictionary_items;
            tokio::spawn(async move {
                let items_decoder: Arc<dyn PrimitivePageDecoder> =
                    Arc::from(items_page_decoder.await?);

                let decoded_dict = items_decoder
                    .decode(0, num_dictionary_items as u64)?
                    .borrow_and_clone();

                let indices_decoder = indices_page_decoder.await?;

                Ok(Box::new(DirectDictionaryPageDecoder {
                    decoded_dict,
                    indices_decoder,
                }) as Box<dyn PrimitivePageDecoder>)
            })
            .map(|join_handle| join_handle.unwrap())
            .boxed()
        }
    }
}

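/// A decoder that keeps the data dictionary encoded
///
/// The decoded dictionary block is cloned and combined with the decoded indices
/// into a `DataBlock::Dictionary` without materializing the values.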
struct DirectDictionaryPageDecoder {
    decoded_dict: DataBlock,
    indices_decoder: Box<dyn PrimitivePageDecoder>,
}

impl PrimitivePageDecoder for DirectDictionaryPageDecoder {
    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
        let indices = self
            .indices_decoder
            .decode(rows_to_skip, num_rows)?
            .as_fixed_width()
            .unwrap();
        let dict = self.decoded_dict.try_clone()?;
        Ok(DataBlock::Dictionary(DictionaryDataBlock {
            indices,
            dictionary: Box::new(dict),
        }))
    }
}

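/// A decoder that re-materializes plain string values from dictionary data
///
/// Indices are stored with 0 reserved for null (real values start at index 1).
/// The indices and the decoded dictionary are combined into a `DictionaryArray`,
/// cast back to a plain string array, and returned as a variable-width block
/// (wrapped in a nullable block if any nulls are present).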
struct DictionaryPageDecoder {
    decoded_dict: Arc<dyn Array>,
    indices_decoder: Box<dyn PrimitivePageDecoder>,
}

impl PrimitivePageDecoder for DictionaryPageDecoder {
    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
        let indices_data = self.indices_decoder.decode(rows_to_skip, num_rows)?;

        let indices_array = make_array(indices_data.into_arrow(DataType::UInt8, false)?);
        let indices_array = indices_array.as_primitive::<UInt8Type>();

        let dictionary = self.decoded_dict.clone();

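        // Index 0 is reserved for null, so shift the remaining indices down by one
        // and turn zeros into validity nulls before building the dictionary array.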
        let adjusted_indices: UInt8Array = indices_array
            .iter()
            .map(|x| match x {
                Some(0) => None,
                Some(x) => Some(x - 1),
                None => None,
            })
            .collect();

        // Build the dictionary array and cast it back to a plain string array
        let dict_array =
            DictionaryArray::<UInt8Type>::try_new(adjusted_indices, dictionary).unwrap();
        let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap();
        let string_array = string_array.as_any().downcast_ref::<StringArray>().unwrap();

        let null_buffer = string_array.nulls().map(|n| n.buffer().clone());
        let offsets_buffer = string_array.offsets().inner().inner().clone();
        let bytes_buffer = string_array.values().clone();

        let string_data = DataBlock::VariableWidth(VariableWidthBlock {
            bits_per_offset: 32,
            data: LanceBuffer::from(bytes_buffer),
            offsets: LanceBuffer::from(offsets_buffer),
            num_values: num_rows,
            block_info: BlockInfo::new(),
        });
        if let Some(nulls) = null_buffer {
            Ok(DataBlock::Nullable(NullableDataBlock {
                data: Box::new(string_data),
                nulls: LanceBuffer::from(nulls),
                block_info: BlockInfo::new(),
            }))
        } else {
            Ok(string_data)
        }
    }
}

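/// An encoder for data that is already dictionary encoded in Arrow
///
/// The input must be a `DataBlock::Dictionary` (or an `AllNull` block, which is
/// converted into a dictionary with a single null item).  The indices and the
/// dictionary items are encoded separately by the wrapped encoders and the
/// resulting encoding records the number of dictionary items.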
#[derive(Debug)]
pub struct AlreadyDictionaryEncoder {
    indices_encoder: Box<dyn ArrayEncoder>,
    items_encoder: Box<dyn ArrayEncoder>,
}

impl AlreadyDictionaryEncoder {
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}

impl ArrayEncoder for AlreadyDictionaryEncoder {
    fn encode(
        &self,
        data: DataBlock,
        data_type: &DataType,
        buffer_index: &mut u32,
    ) -> Result<EncodedArray> {
        let DataType::Dictionary(key_type, value_type) = data_type else {
            panic!("Expected dictionary type");
        };

        let dict_data = match data {
            DataBlock::Dictionary(dict_data) => dict_data,
            DataBlock::AllNull(all_null) => {
                // An all-null block becomes a dictionary with a single null item
                // that every index points at
                let indices = UInt8Array::from(vec![0; all_null.num_values as usize]);
                let indices = arrow_cast::cast(&indices, key_type.as_ref()).unwrap();
                let indices = indices.into_data();
                let values = new_null_array(value_type, 1);
                DictionaryDataBlock {
                    indices: FixedWidthDataBlock {
                        bits_per_value: key_type.byte_width() as u64 * 8,
                        data: LanceBuffer::Borrowed(indices.buffers()[0].clone()),
                        num_values: all_null.num_values,
                        block_info: BlockInfo::new(),
                    },
                    dictionary: Box::new(DataBlock::from_array(values)),
                }
            }
            _ => panic!("Expected dictionary data"),
        };
        let num_dictionary_items = dict_data.dictionary.num_values() as u32;

        let encoded_indices = self.indices_encoder.encode(
            DataBlock::FixedWidth(dict_data.indices),
            key_type,
            buffer_index,
        )?;
        let encoded_items =
            self.items_encoder
                .encode(*dict_data.dictionary, value_type, buffer_index)?;

        let encoded = DataBlock::Dictionary(DictionaryDataBlock {
            dictionary: Box::new(encoded_items.data),
            indices: encoded_indices.data.as_fixed_width().unwrap(),
        });

        let encoding = ProtobufUtils::dict_encoding(
            encoded_indices.encoding,
            encoded_items.encoding,
            num_dictionary_items,
        );

        Ok(EncodedArray {
            data: encoded,
            encoding,
        })
    }
}

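/// An encoder that dictionary encodes plain string data
///
/// A dictionary of distinct values is built by `encode_dict_indices_and_items`
/// (index 0 is reserved for null) and the indices and dictionary items are then
/// encoded separately by the wrapped encoders.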
#[derive(Debug)]
pub struct DictionaryEncoder {
    indices_encoder: Box<dyn ArrayEncoder>,
    items_encoder: Box<dyn ArrayEncoder>,
}

impl DictionaryEncoder {
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}

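/// Builds dictionary indices and items from a string array
///
/// Index 0 is reserved for null; distinct values are assigned indices starting
/// at 1 in order of first appearance.  If the input contains no non-null values
/// the dictionary would be empty, so a single null placeholder item is appended
/// to keep the dictionary non-empty.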
fn encode_dict_indices_and_items(string_array: &StringArray) -> (ArrayRef, ArrayRef) {
    let mut arr_hashmap: HashMap<&str, u8> = HashMap::new();
    // Index 0 is reserved for null, so the first dictionary entry gets index 1
    let mut curr_dict_index = 1;
    let total_capacity = string_array.len();

    let mut dict_indices = Vec::with_capacity(total_capacity);
    let mut dict_builder = StringBuilder::new();

    for i in 0..string_array.len() {
        if !string_array.is_valid(i) {
            // Null values are represented by index 0
            dict_indices.push(0);
            continue;
        }

        let st = string_array.value(i);

        let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index);
        dict_indices.push(hashmap_entry);

        // A value seen for the first time is appended to the dictionary
        if hashmap_entry == curr_dict_index {
            dict_builder.append_value(st);
            curr_dict_index += 1;
        }
    }

    let array_dict_indices = Arc::new(UInt8Array::from(dict_indices)) as ArrayRef;

    if dict_builder.is_empty() {
        // The input was entirely null; append a single placeholder null item so
        // the dictionary is never empty
        dict_builder.append_option(Option::<&str>::None);
    }

    let dict_elements = dict_builder.finish();
    let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap();

    (array_dict_indices, array_dict_elements)
}

impl ArrayEncoder for DictionaryEncoder {
    fn encode(
        &self,
        data: DataBlock,
        data_type: &DataType,
        buffer_index: &mut u32,
    ) -> Result<EncodedArray> {
        if !matches!(data_type, DataType::Utf8) {
            return Err(Error::InvalidInput {
                source: format!(
                    "DictionaryEncoder only supports string arrays but got {}",
                    data_type
                )
                .into(),
                location: location!(),
            });
        }
        let str_data = make_array(data.into_arrow(DataType::Utf8, false)?);

        let (index_array, items_array) = encode_dict_indices_and_items(str_data.as_string());
        let dict_size = items_array.len() as u32;
        let index_data = DataBlock::from(index_array);
        let items_data = DataBlock::from(items_array);

        let encoded_indices =
            self.indices_encoder
                .encode(index_data, &DataType::UInt8, buffer_index)?;

        let encoded_items = self
            .items_encoder
            .encode(items_data, &DataType::Utf8, buffer_index)?;

        let encoded_data = DataBlock::Dictionary(DictionaryDataBlock {
            indices: encoded_indices.data.as_fixed_width().unwrap(),
            dictionary: Box::new(encoded_items.data),
        });

        let encoding = ProtobufUtils::dict_encoding(
            encoded_indices.encoding,
            encoded_items.encoding,
            dict_size,
        );

        Ok(EncodedArray {
            data: encoded_data,
            encoding,
        })
    }
}

#[cfg(test)]
pub mod tests {

    use arrow_array::{
        builder::{LargeStringBuilder, StringBuilder},
        ArrayRef, StringArray, UInt8Array,
    };
    use arrow_schema::{DataType, Field};
    use std::{collections::HashMap, sync::Arc, vec};

    use crate::{
        testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases},
        version::LanceFileVersion,
    };

    use super::encode_dict_indices_and_items;

    #[test]
    fn test_encode_dict_nulls() {
        let string_array = Arc::new(StringArray::from(vec![
            None,
            Some("foo"),
            Some("bar"),
            Some("bar"),
            None,
            Some("foo"),
            None,
            None,
        ]));
        let (dict_indices, dict_items) = encode_dict_indices_and_items(&string_array);

        let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef;
        let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
        assert_eq!(&dict_indices, &expected_indices);
        assert_eq!(&dict_items, &expected_items);
    }
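
    // An illustrative check of the index convention: even with no nulls present,
    // values start at index 1 because index 0 is reserved for null.
    #[test]
    fn test_encode_dict_no_nulls() {
        let string_array = StringArray::from(vec![Some("a"), Some("b"), Some("a")]);
        let (dict_indices, dict_items) = encode_dict_indices_and_items(&string_array);

        let expected_indices = Arc::new(UInt8Array::from(vec![1, 2, 1])) as ArrayRef;
        let expected_items = Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef;
        assert_eq!(&dict_indices, &expected_indices);
        assert_eq!(&dict_items, &expected_items);
    }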

    #[test_log::test(tokio::test)]
    async fn test_utf8() {
        let field = Field::new("", DataType::Utf8, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_binary() {
        let field = Field::new("", DataType::Binary, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_large_binary() {
        let field = Field::new("", DataType::LargeBinary, true);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_large_utf8() {
        let field = Field::new("", DataType::LargeUtf8, true);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_simple_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(0..3)
            .with_range(1..3)
            .with_indices(vec![1, 3]);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    #[test_log::test(tokio::test)]
    async fn test_sliced_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
        let string_array = string_array.slice(1, 3);

        let test_cases = TestCases::default()
            .with_range(0..1)
            .with_range(0..2)
            .with_range(1..2);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    #[test_log::test(tokio::test)]
    async fn test_empty_strings() {
        let values = [Some("abc"), Some(""), None];
        // Place the empty string at the start, middle, and end of the array
        for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
            let mut string_builder = StringBuilder::new();
            for idx in order {
                string_builder.append_option(values[idx]);
            }
            let string_array = Arc::new(string_builder.finish());
            let test_cases = TestCases::default()
                .with_indices(vec![1])
                .with_indices(vec![0])
                .with_indices(vec![2]);
            check_round_trip_encoding_of_data(
                vec![string_array.clone()],
                &test_cases,
                HashMap::new(),
            )
            .await;
            let test_cases = test_cases.with_batch_size(1);
            check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new())
                .await;
        }

        // An array where every value is either empty or null
        let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")]));

        let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]);
        check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases, HashMap::new())
            .await;
        let test_cases = test_cases.with_batch_size(1);
        check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new()).await;
    }

    #[test_log::test(tokio::test)]
    #[ignore]
    async fn test_jumbo_string() {
        // 5000 strings of 1 MiB each; the total size exceeds what 32-bit offsets
        // can address, so 64-bit offsets (LargeStringBuilder) are required
        let mut string_builder = LargeStringBuilder::new();
        let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0'));
        for _ in 0..5000 {
            string_builder.append_option(Some(&giant_string));
        }
        let giant_array = Arc::new(string_builder.finish()) as ArrayRef;
        let arrs = vec![giant_array];

        let test_cases = TestCases::default().without_validation();
        check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_random_dictionary_input() {
        let dict_field = Field::new(
            "",
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
            false,
        );
        check_round_trip_encoding_random(dict_field, LanceFileVersion::V2_0).await;
    }
}