1use std::fmt;
8use std::sync::Arc;
9
10use arrow_array::cast::{as_dictionary_array, as_primitive_array};
11use arrow_array::types::{
12 ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
13 UInt64Type, UInt8Type,
14};
15use arrow_array::{Array, ArrayRef, DictionaryArray, PrimitiveArray, UInt32Array};
16use arrow_schema::DataType;
17use async_trait::async_trait;
18use snafu::location;
19
20use crate::{
21 traits::{Reader, Writer},
22 ReadBatchParams,
23};
24use lance_core::{Error, Result};
25
26use super::plain::PlainEncoder;
27use super::AsyncIndex;
28use crate::encodings::plain::PlainDecoder;
29use crate::encodings::{Decoder, Encoder};
30
31pub struct DictionaryEncoder<'a> {
33 writer: &'a mut dyn Writer,
34 key_type: &'a DataType,
35}
36
37impl<'a> DictionaryEncoder<'a> {
38 pub fn new(writer: &'a mut dyn Writer, key_type: &'a DataType) -> Self {
39 Self { writer, key_type }
40 }
41
42 async fn write_typed_array<T: ArrowDictionaryKeyType>(
43 &mut self,
44 arrs: &[&dyn Array],
45 ) -> Result<usize> {
46 assert!(!arrs.is_empty());
47 let data_type = arrs[0].data_type();
48 let pos = self.writer.tell().await?;
49 let mut plain_encoder = PlainEncoder::new(self.writer, data_type);
50
51 let keys = arrs
52 .iter()
53 .map(|a| {
54 let dict_arr = as_dictionary_array::<T>(*a);
55 dict_arr.keys() as &dyn Array
56 })
57 .collect::<Vec<_>>();
58
59 plain_encoder.encode(keys.as_slice()).await?;
60 Ok(pos)
61 }
62}
63
64#[async_trait]
65impl Encoder for DictionaryEncoder<'_> {
66 async fn encode(&mut self, array: &[&dyn Array]) -> Result<usize> {
67 use DataType::*;
68
69 match self.key_type {
70 UInt8 => self.write_typed_array::<UInt8Type>(array).await,
71 UInt16 => self.write_typed_array::<UInt16Type>(array).await,
72 UInt32 => self.write_typed_array::<UInt32Type>(array).await,
73 UInt64 => self.write_typed_array::<UInt64Type>(array).await,
74 Int8 => self.write_typed_array::<Int8Type>(array).await,
75 Int16 => self.write_typed_array::<Int16Type>(array).await,
76 Int32 => self.write_typed_array::<Int32Type>(array).await,
77 Int64 => self.write_typed_array::<Int64Type>(array).await,
78 _ => Err(Error::Schema {
79 message: format!(
80 "DictionaryEncoder: unsupported key type: {:?}",
81 self.key_type
82 ),
83 location: location!(),
84 }),
85 }
86 }
87}
88
89pub struct DictionaryDecoder<'a> {
91 reader: &'a dyn Reader,
92 position: usize,
94 length: usize,
96 data_type: &'a DataType,
98 value_arr: ArrayRef,
100}
101
102impl fmt::Debug for DictionaryDecoder<'_> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 f.debug_struct("DictionaryDecoder")
105 .field("position", &self.position)
106 .field("length", &self.length)
107 .field("data_type", &self.data_type)
108 .field("value_arr", &self.value_arr)
109 .finish()
110 }
111}
112
113impl<'a> DictionaryDecoder<'a> {
114 pub fn new(
115 reader: &'a dyn Reader,
116 position: usize,
117 length: usize,
118 data_type: &'a DataType,
119 value_arr: ArrayRef,
120 ) -> Self {
121 assert!(matches!(data_type, DataType::Dictionary(_, _)));
122 Self {
123 reader,
124 position,
125 length,
126 data_type,
127 value_arr,
128 }
129 }
130
131 async fn decode_impl(&self, params: impl Into<ReadBatchParams>) -> Result<ArrayRef> {
132 let index_type = if let DataType::Dictionary(key_type, _) = &self.data_type {
133 assert!(key_type.as_ref().is_dictionary_key_type());
134 key_type.as_ref()
135 } else {
136 return Err(Error::Arrow {
137 message: format!("Not a dictionary type: {}", self.data_type),
138 location: location!(),
139 });
140 };
141
142 let decoder = PlainDecoder::new(self.reader, index_type, self.position, self.length)?;
143 let keys = decoder.get(params.into()).await?;
144
145 match index_type {
146 DataType::Int8 => self.make_dict_array::<Int8Type>(keys).await,
147 DataType::Int16 => self.make_dict_array::<Int16Type>(keys).await,
148 DataType::Int32 => self.make_dict_array::<Int32Type>(keys).await,
149 DataType::Int64 => self.make_dict_array::<Int64Type>(keys).await,
150 DataType::UInt8 => self.make_dict_array::<UInt8Type>(keys).await,
151 DataType::UInt16 => self.make_dict_array::<UInt16Type>(keys).await,
152 DataType::UInt32 => self.make_dict_array::<UInt32Type>(keys).await,
153 DataType::UInt64 => self.make_dict_array::<UInt64Type>(keys).await,
154 _ => Err(Error::Arrow {
155 message: format!("Dictionary encoding does not support index type: {index_type}",),
156 location: location!(),
157 }),
158 }
159 }
160
161 async fn make_dict_array<T: ArrowDictionaryKeyType + Sync + Send>(
162 &self,
163 index_array: ArrayRef,
164 ) -> Result<ArrayRef> {
165 let keys: PrimitiveArray<T> = as_primitive_array(index_array.as_ref()).clone();
166 Ok(Arc::new(DictionaryArray::try_new(
167 keys,
168 self.value_arr.clone(),
169 )?))
170 }
171}
172
173#[async_trait]
174impl Decoder for DictionaryDecoder<'_> {
175 async fn decode(&self) -> Result<ArrayRef> {
176 self.decode_impl(..).await
177 }
178
179 async fn take(&self, indices: &UInt32Array) -> Result<ArrayRef> {
180 self.decode_impl(indices.clone()).await
181 }
182}
183
184#[async_trait]
185impl AsyncIndex<usize> for DictionaryDecoder<'_> {
186 type Output = Result<ArrayRef>;
187
188 async fn get(&self, _index: usize) -> Self::Output {
189 Err(Error::NotSupported {
190 source: "DictionaryDecoder does not support get()"
191 .to_string()
192 .into(),
193 location: location!(),
194 })
195 }
196}
197
198#[async_trait]
199impl AsyncIndex<ReadBatchParams> for DictionaryDecoder<'_> {
200 type Output = Result<ArrayRef>;
201
202 async fn get(&self, params: ReadBatchParams) -> Self::Output {
203 self.decode_impl(params.clone()).await
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    use crate::local::LocalObjectReader;
    use arrow_array::StringArray;
    use arrow_buffer::ArrowNativeType;
    use tokio::io::AsyncWriteExt;

    /// Round-trip two dictionary arrays keyed by `T` through `DictionaryEncoder`
    /// and `DictionaryDecoder`, checking both a full decode and `take`.
    async fn test_dict_decoder_for_type<T: ArrowDictionaryKeyType>() {
        let value_array: StringArray = vec![Some("a"), Some("b"), Some("c"), Some("d")]
            .into_iter()
            .collect();
        let value_array_ref = Arc::new(value_array) as ArrayRef;

        let keys1: PrimitiveArray<T> = vec![T::Native::from_usize(0), T::Native::from_usize(1)]
            .into_iter()
            .collect();
        let arr1: DictionaryArray<T> =
            DictionaryArray::try_new(keys1, value_array_ref.clone()).unwrap();

        let keys2: PrimitiveArray<T> = vec![T::Native::from_usize(1), T::Native::from_usize(3)]
            .into_iter()
            .collect();
        let arr2: DictionaryArray<T> =
            DictionaryArray::try_new(keys2, value_array_ref.clone()).unwrap();

        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("foo");

        let pos;
        {
            let mut object_writer = tokio::fs::File::create(&path).await.unwrap();
            // Write through `DictionaryEncoder` so the encoder is covered by this test too.
            let key_type = arr1.keys().data_type();
            let mut encoder = DictionaryEncoder::new(&mut object_writer, key_type);
            let arrs: [&dyn Array; 2] = [&arr1, &arr2];
            pos = encoder.encode(&arrs[..]).await.unwrap();
            object_writer.shutdown().await.unwrap();
        }

        let reader = LocalObjectReader::open_local_path(&path, 2048, None)
            .await
            .unwrap();
        let decoder = DictionaryDecoder::new(
            reader.as_ref(),
            pos,
            arr1.len() + arr2.len(),
            arr1.data_type(),
            value_array_ref.clone(),
        );

        let decoded_data = decoder.decode().await.unwrap();
        let expected_data: DictionaryArray<T> = vec!["a", "b", "b", "d"].into_iter().collect();
        assert_eq!(
            &expected_data,
            decoded_data
                .as_any()
                .downcast_ref::<DictionaryArray<T>>()
                .unwrap()
        );

        // `take` must decode only the requested rows.
        let taken = decoder
            .take(&UInt32Array::from(vec![1, 3]))
            .await
            .unwrap();
        let expected_taken: DictionaryArray<T> = vec!["b", "d"].into_iter().collect();
        assert_eq!(
            &expected_taken,
            taken.as_any().downcast_ref::<DictionaryArray<T>>().unwrap()
        );
    }

    #[tokio::test]
    async fn test_dict_decoder() {
        test_dict_decoder_for_type::<Int8Type>().await;
        test_dict_decoder_for_type::<Int16Type>().await;
        test_dict_decoder_for_type::<Int32Type>().await;
        test_dict_decoder_for_type::<Int64Type>().await;

        test_dict_decoder_for_type::<UInt8Type>().await;
        test_dict_decoder_for_type::<UInt16Type>().await;
        test_dict_decoder_for_type::<UInt32Type>().await;
        test_dict_decoder_for_type::<UInt64Type>().await;
    }
}