1use crate::builder::{ArrayBuilder, PrimitiveBuilder};
19use crate::types::ArrowDictionaryKeyType;
20use crate::{
21 Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, PrimitiveArray, TypedDictionaryArray,
22};
23use arrow_buffer::{ArrowNativeType, ToByteSlice};
24use arrow_schema::{ArrowError, DataType};
25use std::any::Any;
26use std::collections::HashMap;
27use std::sync::Arc;
28
/// Newtype around a native value so it can serve as a `HashMap` key.
///
/// `Hash`, `PartialEq` and `Eq` are implemented for this wrapper in terms of
/// the byte representation returned by `ToByteSlice::to_byte_slice`.
#[derive(Debug)]
struct Value<T>(T);
34
35impl<T: ToByteSlice> std::hash::Hash for Value<T> {
36 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
37 self.0.to_byte_slice().hash(state)
38 }
39}
40
41impl<T: ToByteSlice> PartialEq for Value<T> {
42 fn eq(&self, other: &Self) -> bool {
43 self.0.to_byte_slice().eq(other.0.to_byte_slice())
44 }
45}
46
// Equality is a byte-wise comparison (see the `PartialEq` impl), which is
// reflexive, so `Value` can be marked as `Eq` and used as a hash-map key.
impl<T: ToByteSlice> Eq for Value<T> {}
48
/// Builder for a dictionary-encoded array whose values are primitives.
///
/// Keys of type `K` are accumulated in `keys_builder` and the dictionary
/// values of type `V` in `values_builder`; `map` is the lookup used to find
/// the key of a value that was already added.
#[derive(Debug)]
pub struct PrimitiveDictionaryBuilder<K, V>
where
    K: ArrowPrimitiveType,
    V: ArrowPrimitiveType,
{
    /// Accumulates one key (or null) per appended element.
    keys_builder: PrimitiveBuilder<K>,
    /// Accumulates the dictionary values; a value is appended here the first
    /// time it is seen by the builder.
    values_builder: PrimitiveBuilder<V>,
    /// Lookup from a value to the dictionary index used for it.
    map: HashMap<Value<V::Native>, usize>,
}
91
92impl<K, V> Default for PrimitiveDictionaryBuilder<K, V>
93where
94 K: ArrowPrimitiveType,
95 V: ArrowPrimitiveType,
96{
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102impl<K, V> PrimitiveDictionaryBuilder<K, V>
103where
104 K: ArrowPrimitiveType,
105 V: ArrowPrimitiveType,
106{
107 pub fn new() -> Self {
109 Self {
110 keys_builder: PrimitiveBuilder::new(),
111 values_builder: PrimitiveBuilder::new(),
112 map: HashMap::new(),
113 }
114 }
115
116 pub fn new_from_empty_builders(
122 keys_builder: PrimitiveBuilder<K>,
123 values_builder: PrimitiveBuilder<V>,
124 ) -> Self {
125 assert!(
126 keys_builder.is_empty() && values_builder.is_empty(),
127 "keys and values builders must be empty"
128 );
129 let values_capacity = values_builder.capacity();
130 Self {
131 keys_builder,
132 values_builder,
133 map: HashMap::with_capacity(values_capacity),
134 }
135 }
136
137 pub unsafe fn new_from_builders(
143 keys_builder: PrimitiveBuilder<K>,
144 values_builder: PrimitiveBuilder<V>,
145 ) -> Self {
146 let keys = keys_builder.values_slice();
147 let values = values_builder.values_slice();
148 let mut map = HashMap::with_capacity(values.len());
149
150 keys.iter().zip(values.iter()).for_each(|(key, value)| {
151 map.insert(Value(*value), K::Native::to_usize(*key).unwrap());
152 });
153
154 Self {
155 keys_builder,
156 values_builder,
157 map,
158 }
159 }
160
161 pub fn with_capacity(keys_capacity: usize, values_capacity: usize) -> Self {
166 Self {
167 keys_builder: PrimitiveBuilder::with_capacity(keys_capacity),
168 values_builder: PrimitiveBuilder::with_capacity(values_capacity),
169 map: HashMap::with_capacity(values_capacity),
170 }
171 }
172}
173
174impl<K, V> ArrayBuilder for PrimitiveDictionaryBuilder<K, V>
175where
176 K: ArrowDictionaryKeyType,
177 V: ArrowPrimitiveType,
178{
179 fn as_any(&self) -> &dyn Any {
181 self
182 }
183
184 fn as_any_mut(&mut self) -> &mut dyn Any {
186 self
187 }
188
189 fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
191 self
192 }
193
194 fn len(&self) -> usize {
196 self.keys_builder.len()
197 }
198
199 fn finish(&mut self) -> ArrayRef {
201 Arc::new(self.finish())
202 }
203
204 fn finish_cloned(&self) -> ArrayRef {
206 Arc::new(self.finish_cloned())
207 }
208}
209
210impl<K, V> PrimitiveDictionaryBuilder<K, V>
211where
212 K: ArrowDictionaryKeyType,
213 V: ArrowPrimitiveType,
214{
215 #[inline]
216 fn get_or_insert_key(&mut self, value: V::Native) -> Result<K::Native, ArrowError> {
217 match self.map.get(&Value(value)) {
218 Some(&key) => {
219 Ok(K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)?)
220 }
221 None => {
222 let key = self.values_builder.len();
223 self.values_builder.append_value(value);
224 self.map.insert(Value(value), key);
225 Ok(K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)?)
226 }
227 }
228 }
229
230 #[inline]
234 pub fn append(&mut self, value: V::Native) -> Result<K::Native, ArrowError> {
235 let key = self.get_or_insert_key(value)?;
236 self.keys_builder.append_value(key);
237 Ok(key)
238 }
239
240 pub fn append_n(&mut self, value: V::Native, count: usize) -> Result<K::Native, ArrowError> {
245 let key = self.get_or_insert_key(value)?;
246 self.keys_builder.append_value_n(key, count);
247 Ok(key)
248 }
249
250 #[inline]
256 pub fn append_value(&mut self, value: V::Native) {
257 self.append(value).expect("dictionary key overflow");
258 }
259
260 pub fn append_values(&mut self, value: V::Native, count: usize) {
267 self.append_n(value, count)
268 .expect("dictionary key overflow");
269 }
270
271 #[inline]
273 pub fn append_null(&mut self) {
274 self.keys_builder.append_null()
275 }
276
277 #[inline]
279 pub fn append_nulls(&mut self, n: usize) {
280 self.keys_builder.append_nulls(n)
281 }
282
283 #[inline]
289 pub fn append_option(&mut self, value: Option<V::Native>) {
290 match value {
291 None => self.append_null(),
292 Some(v) => self.append_value(v),
293 };
294 }
295
296 pub fn append_options(&mut self, value: Option<V::Native>, count: usize) {
303 match value {
304 None => self.keys_builder.append_nulls(count),
305 Some(v) => self.append_values(v, count),
306 };
307 }
308
309 pub fn extend_dictionary(
317 &mut self,
318 dictionary: &TypedDictionaryArray<K, PrimitiveArray<V>>,
319 ) -> Result<(), ArrowError> {
320 let values = dictionary.values();
321
322 let v_len = values.len();
323 let k_len = dictionary.keys().len();
324 if v_len == 0 && k_len == 0 {
325 return Ok(());
326 }
327
328 if v_len == 0 {
330 self.append_nulls(k_len);
331 return Ok(());
332 }
333
334 if k_len == 0 {
335 return Err(ArrowError::InvalidArgumentError(
336 "Dictionary keys should not be empty when values are not empty".to_string(),
337 ));
338 }
339
340 let mapped_values = values
342 .iter()
343 .map(|dict_value| {
345 dict_value
346 .map(|dict_value| self.get_or_insert_key(dict_value))
347 .transpose()
348 })
349 .collect::<Result<Vec<_>, _>>()?;
350
351 dictionary.keys().iter().for_each(|key| match key {
353 None => self.append_null(),
354 Some(original_dict_index) => {
355 let index = original_dict_index.as_usize().min(v_len - 1);
356 match mapped_values[index] {
357 None => self.append_null(),
358 Some(mapped_value) => self.keys_builder.append_value(mapped_value),
359 }
360 }
361 });
362
363 Ok(())
364 }
365
366 pub fn finish(&mut self) -> DictionaryArray<K> {
368 self.map.clear();
369 let values = self.values_builder.finish();
370 let keys = self.keys_builder.finish();
371
372 let data_type =
373 DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone()));
374
375 let builder = keys
376 .into_data()
377 .into_builder()
378 .data_type(data_type)
379 .child_data(vec![values.into_data()]);
380
381 DictionaryArray::from(unsafe { builder.build_unchecked() })
382 }
383
384 pub fn finish_cloned(&self) -> DictionaryArray<K> {
386 let values = self.values_builder.finish_cloned();
387 let keys = self.keys_builder.finish_cloned();
388
389 let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE));
390
391 let builder = keys
392 .into_data()
393 .into_builder()
394 .data_type(data_type)
395 .child_data(vec![values.into_data()]);
396
397 DictionaryArray::from(unsafe { builder.build_unchecked() })
398 }
399
400 pub fn values_slice(&self) -> &[V::Native] {
402 self.values_builder.values_slice()
403 }
404
405 pub fn values_slice_mut(&mut self) -> &mut [V::Native] {
407 self.values_builder.values_slice_mut()
408 }
409
410 pub fn validity_slice(&self) -> Option<&[u8]> {
412 self.keys_builder.validity_slice()
413 }
414}
415
416impl<K: ArrowDictionaryKeyType, P: ArrowPrimitiveType> Extend<Option<P::Native>>
417 for PrimitiveDictionaryBuilder<K, P>
418{
419 #[inline]
420 fn extend<T: IntoIterator<Item = Option<P::Native>>>(&mut self, iter: T) {
421 for v in iter {
422 self.append_option(v)
423 }
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::{Int32Array, UInt32Array, UInt8Array};
    use crate::builder::Decimal128Builder;
    use crate::cast::AsArray;
    use crate::types::{Decimal128Type, Int32Type, UInt32Type, UInt8Type};

    /// The builder flavour used by most tests below.
    type Int32DictBuilder = PrimitiveDictionaryBuilder<Int32Type, Int32Type>;

    /// Decodes a dictionary array into the logical values it represents.
    fn logical_values(dict: &DictionaryArray<Int32Type>) -> Vec<Option<i32>> {
        dict.downcast_dict::<Int32Array>()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>()
    }

    /// Keys are assigned in first-seen order and nulls only touch the keys.
    #[test]
    fn test_primitive_dictionary_builder() {
        let mut builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::with_capacity(3, 2);
        builder.append(12345678).unwrap();
        builder.append_null();
        builder.append(22345678).unwrap();
        let array = builder.finish();

        let expected_keys = UInt8Array::from(vec![Some(0), None, Some(1)]);
        assert_eq!(array.keys(), &expected_keys);

        let av = array.values();
        let typed_values: &UInt32Array = av.as_any().downcast_ref::<UInt32Array>().unwrap();
        let raw_values: &[u32] = typed_values.values();

        for (row, expect_null) in [false, true, false].into_iter().enumerate() {
            assert_eq!(array.is_null(row), expect_null);
        }

        assert_eq!(raw_values, &[12345678, 22345678]);
    }

    /// `Extend` reuses the keys of values that were already seen.
    #[test]
    fn test_extend() {
        let batches: [&[i32]; 2] = [&[1, 2, 3, 1, 2, 3, 1, 2, 3], &[4, 5, 1, 3, 1]];

        let mut builder = Int32DictBuilder::new();
        for batch in batches {
            builder.extend(batch.iter().copied().map(Some));
        }
        let dict = builder.finish();

        assert_eq!(
            dict.keys().values(),
            &[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 0, 2, 0]
        );
        assert_eq!(dict.values().len(), 5);
    }

    /// The 257th distinct value cannot be keyed by a `u8`.
    #[test]
    #[should_panic(expected = "DictionaryKeyOverflowError")]
    fn test_primitive_dictionary_overflow() {
        let mut builder =
            PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::with_capacity(257, 257);
        // 256 distinct values fill the whole key space.
        (1000..1256).for_each(|value| {
            builder.append(value).unwrap();
        });
        // One more distinct value overflows the key type.
        builder.append(1257).unwrap();
    }

    /// The data type configured on the values builder is kept by `finish`.
    #[test]
    fn test_primitive_dictionary_with_builders() {
        let value_type = DataType::Decimal128(1, 2);

        let mut builder =
            PrimitiveDictionaryBuilder::<Int32Type, Decimal128Type>::new_from_empty_builders(
                PrimitiveBuilder::<Int32Type>::new(),
                Decimal128Builder::new().with_data_type(value_type.clone()),
            );
        let dict_array = builder.finish();

        assert_eq!(dict_array.value_type(), value_type);

        let expected_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type));
        assert_eq!(dict_array.data_type(), &expected_type);
    }

    /// Extending with another dictionary merges its values into the existing ones.
    #[test]
    fn test_extend_dictionary() {
        let mut source = Int32DictBuilder::new();
        source.extend([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some));
        source.extend([None::<i32>]);
        source.extend([4, 5, 1, 3, 1].into_iter().map(Some));
        source.append_null();
        let some_dict = source.finish();

        let mut builder = Int32DictBuilder::new();
        builder.extend([6, 6, 7, 6, 5].into_iter().map(Some));
        builder
            .extend_dictionary(&some_dict.downcast_dict().unwrap())
            .unwrap();
        let dict = builder.finish();

        assert_eq!(dict.values().len(), 7);

        #[rustfmt::skip]
        let expected = [
            // Appended directly.
            Some(6), Some(6), Some(7), Some(6), Some(5),
            // Carried over from the source dictionary.
            Some(1), Some(2), Some(3),
            Some(1), Some(2), Some(3),
            Some(1), Some(2), Some(3),
            None,
            Some(4), Some(5), Some(1), Some(3), Some(1),
            None,
        ];
        assert_eq!(logical_values(&dict), expected);
    }

    /// A null stored in the source dictionary *values* becomes a null key.
    #[test]
    fn test_extend_dictionary_with_null_in_mapped_value() {
        let some_dict = {
            let mut keys_builder = PrimitiveBuilder::<Int32Type>::new();
            let mut values_builder = PrimitiveBuilder::<Int32Type>::new();

            // Key 0 points at a null value, key 1 at 42.
            for (key, value) in [(0, None), (1, Some(42))] {
                values_builder.append_option(value);
                keys_builder.append_value(key);
            }

            let keys = keys_builder.finish();
            let values = values_builder.finish();

            let data_type = DataType::Dictionary(
                Box::new(Int32Type::DATA_TYPE),
                Box::new(values.data_type().clone()),
            );

            let data = keys
                .into_data()
                .into_builder()
                .data_type(data_type)
                .child_data(vec![values.into_data()]);

            DictionaryArray::from(unsafe { data.build_unchecked() })
        };

        let some_dict_values = some_dict.values().as_primitive::<Int32Type>();
        assert_eq!(
            some_dict_values.into_iter().collect::<Vec<_>>(),
            &[None, Some(42)]
        );

        let mut builder = Int32DictBuilder::new();
        builder
            .extend_dictionary(&some_dict.downcast_dict().unwrap())
            .unwrap();
        let dict = builder.finish();

        assert_eq!(dict.values().len(), 1);
        assert_eq!(logical_values(&dict), [None, Some(42)]);
    }

    /// A source dictionary holding only nulls adds nulls and no values.
    #[test]
    fn test_extend_all_null_dictionary() {
        let mut source = Int32DictBuilder::new();
        source.append_nulls(2);
        let some_dict = source.finish();

        let mut builder = Int32DictBuilder::new();
        builder
            .extend_dictionary(&some_dict.downcast_dict().unwrap())
            .unwrap();
        let dict = builder.finish();

        assert_eq!(dict.values().len(), 0);
        assert_eq!(logical_values(&dict), [None, None]);
    }

    /// The lookup map is pre-sized from the values builder's capacity.
    #[test]
    fn creating_dictionary_from_builders_should_use_values_capacity_for_the_map() {
        let keys_builder =
            PrimitiveBuilder::<Int32Type>::with_capacity(1).with_data_type(DataType::Int32);
        let values_builder =
            PrimitiveBuilder::<crate::types::TimestampMicrosecondType>::with_capacity(2)
                .with_data_type(DataType::Timestamp(
                    arrow_schema::TimeUnit::Microsecond,
                    Some("+08:00".into()),
                ));

        let builder = PrimitiveDictionaryBuilder::<
            Int32Type,
            crate::types::TimestampMicrosecondType,
        >::new_from_empty_builders(keys_builder, values_builder);

        assert!(
            builder.map.capacity() >= builder.values_builder.capacity(),
            "map capacity {} should be at least the values capacity {}",
            builder.map.capacity(),
            builder.values_builder.capacity()
        )
    }
}
652}