lance_io/encodings/
dictionary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Dictionary encoding.
5//!
6
7use std::fmt;
8use std::sync::Arc;
9
10use arrow_array::cast::{as_dictionary_array, as_primitive_array};
11use arrow_array::types::{
12    ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
13    UInt64Type, UInt8Type,
14};
15use arrow_array::{Array, ArrayRef, DictionaryArray, PrimitiveArray, UInt32Array};
16use arrow_schema::DataType;
17use async_trait::async_trait;
18use snafu::location;
19
20use crate::{
21    traits::{Reader, Writer},
22    ReadBatchParams,
23};
24use lance_core::{Error, Result};
25
26use super::plain::PlainEncoder;
27use super::AsyncIndex;
28use crate::encodings::plain::PlainDecoder;
29use crate::encodings::{Decoder, Encoder};
30
31/// Encoder for Dictionary encoding.
32pub struct DictionaryEncoder<'a> {
33    writer: &'a mut dyn Writer,
34    key_type: &'a DataType,
35}
36
37impl<'a> DictionaryEncoder<'a> {
38    pub fn new(writer: &'a mut dyn Writer, key_type: &'a DataType) -> Self {
39        Self { writer, key_type }
40    }
41
42    async fn write_typed_array<T: ArrowDictionaryKeyType>(
43        &mut self,
44        arrs: &[&dyn Array],
45    ) -> Result<usize> {
46        assert!(!arrs.is_empty());
47        let data_type = arrs[0].data_type();
48        let pos = self.writer.tell().await?;
49        let mut plain_encoder = PlainEncoder::new(self.writer, data_type);
50
51        let keys = arrs
52            .iter()
53            .map(|a| {
54                let dict_arr = as_dictionary_array::<T>(*a);
55                dict_arr.keys() as &dyn Array
56            })
57            .collect::<Vec<_>>();
58
59        plain_encoder.encode(keys.as_slice()).await?;
60        Ok(pos)
61    }
62}
63
64#[async_trait]
65impl Encoder for DictionaryEncoder<'_> {
66    async fn encode(&mut self, array: &[&dyn Array]) -> Result<usize> {
67        use DataType::*;
68
69        match self.key_type {
70            UInt8 => self.write_typed_array::<UInt8Type>(array).await,
71            UInt16 => self.write_typed_array::<UInt16Type>(array).await,
72            UInt32 => self.write_typed_array::<UInt32Type>(array).await,
73            UInt64 => self.write_typed_array::<UInt64Type>(array).await,
74            Int8 => self.write_typed_array::<Int8Type>(array).await,
75            Int16 => self.write_typed_array::<Int16Type>(array).await,
76            Int32 => self.write_typed_array::<Int32Type>(array).await,
77            Int64 => self.write_typed_array::<Int64Type>(array).await,
78            _ => Err(Error::Schema {
79                message: format!(
80                    "DictionaryEncoder: unsupported key type: {:?}",
81                    self.key_type
82                ),
83                location: location!(),
84            }),
85        }
86    }
87}
88
89/// Decoder for Dictionary encoding.
90pub struct DictionaryDecoder<'a> {
91    reader: &'a dyn Reader,
92    /// The start position of the key array in the file.
93    position: usize,
94    /// Number of the rows in this batch.
95    length: usize,
96    /// The dictionary data type
97    data_type: &'a DataType,
98    /// Value array,
99    value_arr: ArrayRef,
100}
101
102impl fmt::Debug for DictionaryDecoder<'_> {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        f.debug_struct("DictionaryDecoder")
105            .field("position", &self.position)
106            .field("length", &self.length)
107            .field("data_type", &self.data_type)
108            .field("value_arr", &self.value_arr)
109            .finish()
110    }
111}
112
113impl<'a> DictionaryDecoder<'a> {
114    pub fn new(
115        reader: &'a dyn Reader,
116        position: usize,
117        length: usize,
118        data_type: &'a DataType,
119        value_arr: ArrayRef,
120    ) -> Self {
121        assert!(matches!(data_type, DataType::Dictionary(_, _)));
122        Self {
123            reader,
124            position,
125            length,
126            data_type,
127            value_arr,
128        }
129    }
130
131    async fn decode_impl(&self, params: impl Into<ReadBatchParams>) -> Result<ArrayRef> {
132        let index_type = if let DataType::Dictionary(key_type, _) = &self.data_type {
133            assert!(key_type.as_ref().is_dictionary_key_type());
134            key_type.as_ref()
135        } else {
136            return Err(Error::Arrow {
137                message: format!("Not a dictionary type: {}", self.data_type),
138                location: location!(),
139            });
140        };
141
142        let decoder = PlainDecoder::new(self.reader, index_type, self.position, self.length)?;
143        let keys = decoder.get(params.into()).await?;
144
145        match index_type {
146            DataType::Int8 => self.make_dict_array::<Int8Type>(keys).await,
147            DataType::Int16 => self.make_dict_array::<Int16Type>(keys).await,
148            DataType::Int32 => self.make_dict_array::<Int32Type>(keys).await,
149            DataType::Int64 => self.make_dict_array::<Int64Type>(keys).await,
150            DataType::UInt8 => self.make_dict_array::<UInt8Type>(keys).await,
151            DataType::UInt16 => self.make_dict_array::<UInt16Type>(keys).await,
152            DataType::UInt32 => self.make_dict_array::<UInt32Type>(keys).await,
153            DataType::UInt64 => self.make_dict_array::<UInt64Type>(keys).await,
154            _ => Err(Error::Arrow {
155                message: format!("Dictionary encoding does not support index type: {index_type}",),
156                location: location!(),
157            }),
158        }
159    }
160
161    async fn make_dict_array<T: ArrowDictionaryKeyType + Sync + Send>(
162        &self,
163        index_array: ArrayRef,
164    ) -> Result<ArrayRef> {
165        let keys: PrimitiveArray<T> = as_primitive_array(index_array.as_ref()).clone();
166        Ok(Arc::new(DictionaryArray::try_new(
167            keys,
168            self.value_arr.clone(),
169        )?))
170    }
171}
172
173#[async_trait]
174impl Decoder for DictionaryDecoder<'_> {
175    async fn decode(&self) -> Result<ArrayRef> {
176        self.decode_impl(..).await
177    }
178
179    async fn take(&self, indices: &UInt32Array) -> Result<ArrayRef> {
180        self.decode_impl(indices.clone()).await
181    }
182}
183
184#[async_trait]
185impl AsyncIndex<usize> for DictionaryDecoder<'_> {
186    type Output = Result<ArrayRef>;
187
188    async fn get(&self, _index: usize) -> Self::Output {
189        Err(Error::NotSupported {
190            source: "DictionaryDecoder does not support get()"
191                .to_string()
192                .into(),
193            location: location!(),
194        })
195    }
196}
197
198#[async_trait]
199impl AsyncIndex<ReadBatchParams> for DictionaryDecoder<'_> {
200    type Output = Result<ArrayRef>;
201
202    async fn get(&self, params: ReadBatchParams) -> Self::Output {
203        self.decode_impl(params.clone()).await
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    use crate::local::LocalObjectReader;
    use arrow_array::StringArray;
    use arrow_buffer::ArrowNativeType;
    use tokio::io::AsyncWriteExt;

    /// Builds a dictionary array over `values` whose keys are `indices`
    /// converted to the native type of `T`.
    fn dict_from_indices<T: ArrowDictionaryKeyType>(
        indices: &[usize],
        values: &ArrayRef,
    ) -> DictionaryArray<T> {
        let keys: PrimitiveArray<T> = indices
            .iter()
            .map(|&idx| T::Native::from_usize(idx))
            .collect();
        DictionaryArray::try_new(keys, values.clone()).unwrap()
    }

    /// Writes the keys of two dictionary arrays with `PlainEncoder`, reads
    /// them back through `DictionaryDecoder`, and compares the result with
    /// the expected dictionary array.
    async fn test_dict_decoder_for_type<T: ArrowDictionaryKeyType>() {
        let strings: StringArray = vec![Some("a"), Some("b"), Some("c"), Some("d")]
            .into_iter()
            .collect();
        let values: ArrayRef = Arc::new(strings);

        let first = dict_from_indices::<T>(&[0, 1], &values);
        let second = dict_from_indices::<T>(&[1, 3], &values);

        let key_arrays: Vec<&dyn Array> =
            vec![first.keys() as &dyn Array, second.keys() as &dyn Array];

        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("foo");

        // Write the keys and remember where they start in the file.
        let pos = {
            let mut file = tokio::fs::File::create(&path).await.unwrap();
            let mut encoder = PlainEncoder::new(&mut file, first.keys().data_type());
            let offset = encoder.encode(key_arrays.as_slice()).await.unwrap();
            file.shutdown().await.unwrap();
            offset
        };

        let reader = LocalObjectReader::open_local_path(&path, 2048, None)
            .await
            .unwrap();
        let total_rows = first.len() + second.len();
        let decoder = DictionaryDecoder::new(
            reader.as_ref(),
            pos,
            total_rows,
            first.data_type(),
            values.clone(),
        );

        let decoded = decoder.decode().await.unwrap();
        let expected: DictionaryArray<T> = vec!["a", "b", "b", "d"].into_iter().collect();
        let actual = decoded
            .as_any()
            .downcast_ref::<DictionaryArray<T>>()
            .unwrap();
        assert_eq!(&expected, actual);
    }

    /// Runs the round-trip check once per listed key type, in order.
    macro_rules! check_key_types {
        ($($key:ty),+ $(,)?) => {
            $( test_dict_decoder_for_type::<$key>().await; )+
        };
    }

    #[tokio::test]
    async fn test_dict_decoder() {
        // Signed key types.
        check_key_types!(Int8Type, Int16Type, Int32Type, Int64Type);

        // Unsigned key types.
        check_key_types!(UInt8Type, UInt16Type, UInt32Type, UInt64Type);
    }
}