polars_row/
decode.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
use arrow::bitmap::{Bitmap, BitmapBuilder};
use arrow::buffer::Buffer;
use arrow::datatypes::ArrowDataType;
use arrow::offset::OffsetsBuffer;

use self::encode::fixed_size;
use self::row::{RowEncodingCategoricalContext, RowEncodingOptions};
use self::variable::utf8::decode_str;
use super::*;
use crate::fixed::{boolean, decimal, numeric, packed_u32};
use crate::variable::{binary, no_order, utf8};

/// Decode the row-encoded values stored in `arr` into Arrow arrays, one per entry of `dtypes`.
///
/// `rows` serves as a reusable scratch buffer: its previous contents are discarded, it is
/// refilled with one byte slice per value of `arr`, and it is then handed to [`decode_rows`],
/// which updates the slices while decoding.
///
/// # Panics
/// Panics if `arr` contains any nulls.
///
/// # Safety
/// No bounds checks are performed. The caller must ensure that the values of `arr` are valid
/// encodings.
pub unsafe fn decode_rows_from_binary<'a>(
    arr: &'a BinaryArray<i64>,
    opts: &[RowEncodingOptions],
    dicts: &[Option<RowEncodingContext>],
    dtypes: &[ArrowDataType],
    rows: &mut Vec<&'a [u8]>,
) -> Vec<ArrayRef> {
    assert_eq!(arr.null_count(), 0);

    // Keep the caller's allocation, but replace whatever slices it held before.
    rows.truncate(0);
    rows.extend(arr.values_iter());

    let encoded_rows: &mut [&'a [u8]] = rows.as_mut_slice();
    decode_rows(encoded_rows, opts, dicts, dtypes)
}

/// Decode `rows` into Arrow arrays, producing one array per entry of `dtypes`.
///
/// Columns are decoded in order; `opts[i]` and `dicts[i]` are used for `dtypes[i]`.
///
/// # Panics
/// Panics if `opts` or `dicts` does not have the same length as `dtypes`.
///
/// # Safety
/// No bounds checks are performed. The caller must ensure that `rows` are valid encodings.
pub unsafe fn decode_rows(
    // the rows will be updated while the data is decoded
    rows: &mut [&[u8]],
    opts: &[RowEncodingOptions],
    dicts: &[Option<RowEncodingContext>],
    dtypes: &[ArrowDataType],
) -> Vec<ArrayRef> {
    let num_columns = dtypes.len();
    assert_eq!(opts.len(), num_columns);
    assert_eq!(dicts.len(), num_columns);

    let mut columns = Vec::with_capacity(num_columns);
    for ((dtype, opt), dict) in dtypes.iter().zip(opts).zip(dicts) {
        columns.push(decode(rows, *opt, dict.as_ref(), dtype));
    }
    columns
}

/// Strip the leading validity byte from every row and turn those bytes into a validity bitmap.
///
/// A row is considered null when its first byte equals `opt.null_sentinel()`. Every row is
/// advanced by one byte, whether or not a null is found.
///
/// Returns `None` without building a bitmap when no row starts with the null sentinel.
///
/// # Safety
/// Every row must hold at least one byte; the split is done without bounds checks.
unsafe fn decode_validity(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Option<Bitmap> {
    let null_sentinel = opt.null_sentinel();

    // Pass 1: consume validity bytes up to and including the first null. As long as everything
    // is valid there is no need to allocate a bitmap, so bail out if no null shows up.
    let first_null = rows.iter_mut().position(|row| {
        let (head, tail) = row.split_at_unchecked(1);
        *row = tail;
        head[0] == null_sentinel
    })?;

    // Pass 2: everything before `first_null` is valid, `first_null` itself is not, and the
    // remaining rows still need their validity byte consumed and inspected.
    let num_rows = rows.len();
    let mut validity = BitmapBuilder::new();
    validity.reserve(num_rows);
    validity.extend_constant(first_null, true);
    validity.push(false);

    let remaining = rows[first_null + 1..].iter_mut().map(|row| {
        let (head, tail) = row.split_at_unchecked(1);
        *row = tail;
        head[0] != null_sentinel
    });
    validity.extend_trusted_len_iter(remaining);

    validity.into_opt_validity()
}

// We inline this in an attempt to avoid the dispatch cost.
#[inline(always)]
fn dtype_and_data_to_encoded_item_len(
    dtype: &ArrowDataType,
    data: &[u8],
    opt: RowEncodingOptions,
    dict: Option<&RowEncodingContext>,
) -> usize {
    // Fast path: if the size is fixed, we can just divide.
    if let Some(size) = fixed_size(dtype, dict) {
        return size;
    }

    use ArrowDataType as D;
    match dtype {
        D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
            if opt.contains(RowEncodingOptions::NO_ORDER) =>
        unsafe { no_order::len_from_buffer(data, opt) },
        D::Binary | D::LargeBinary | D::BinaryView => unsafe {
            binary::encoded_item_len(data, opt)
        },
        D::Utf8 | D::LargeUtf8 | D::Utf8View => unsafe { utf8::len_from_buffer(data, opt) },

        D::List(list_field) | D::LargeList(list_field) => {
            let mut data = data;
            let mut item_len = 0;

            let list_continuation_token = opt.list_continuation_token();

            while data[0] == list_continuation_token {
                data = &data[1..];
                let len = dtype_and_data_to_encoded_item_len(list_field.dtype(), data, opt, dict);
                data = &data[len..];
                item_len += 1 + len;
            }
            1 + item_len
        },

        D::FixedSizeBinary(_) => todo!(),
        D::FixedSizeList(fsl_field, width) => {
            let mut data = &data[1..];
            let mut item_len = 1; // validity byte

            for _ in 0..*width {
                let len = dtype_and_data_to_encoded_item_len(fsl_field.dtype(), data, opt, dict);
                data = &data[len..];
                item_len += len;
            }
            item_len
        },
        D::Struct(struct_fields) => {
            let mut data = &data[1..];
            let mut item_len = 1; // validity byte

            for struct_field in struct_fields {
                let len = dtype_and_data_to_encoded_item_len(struct_field.dtype(), data, opt, dict);
                data = &data[len..];
                item_len += len;
            }
            item_len
        },

        D::Union(_) => todo!(),
        D::Map(_, _) => todo!(),
        D::Decimal256(_, _) => todo!(),
        D::Extension(_) => todo!(),
        D::Unknown => todo!(),

        _ => unreachable!(),
    }
}

/// Split the next `width` encoded items of `dtype` off the front of every row.
///
/// `nested_rows` is cleared and then filled with `width` slices per row, in row order. Each
/// entry of `rows` is advanced past the bytes that were handed to `nested_rows`.
///
/// # Panics
/// Panics if a row is shorter than the items that are sliced off it.
fn rows_for_fixed_size_list<'a>(
    dtype: &ArrowDataType,
    opt: RowEncodingOptions,
    dict: Option<&RowEncodingContext>,
    width: usize,
    rows: &mut [&'a [u8]],
    nested_rows: &mut Vec<&'a [u8]>,
) {
    nested_rows.clear();
    nested_rows.reserve(rows.len() * width);

    match fixed_size(dtype, dict) {
        // Every item takes exactly `size` bytes, so the rows can be cut at fixed offsets
        // without looking at the data.
        Some(size) => {
            for row in rows.iter_mut() {
                let mut offset = 0;
                for _ in 0..width {
                    nested_rows.push(&row[offset..][..size]);
                    offset += size;
                }
                *row = &row[size * width..];
            }
        },
        // @TODO: This is quite slow since we need to dispatch for possibly every nested type
        None => {
            for row in rows.iter_mut() {
                for _ in 0..width {
                    let item_len = dtype_and_data_to_encoded_item_len(dtype, row, opt, dict);
                    let (item, rest) = row.split_at(item_len);
                    nested_rows.push(item);
                    *row = rest;
                }
            }
        },
    }
}

/// Decode a categorical column that was encoded as two consecutive `u32` values per row.
///
/// The values of the first `u32` are discarded and only its validity is kept; the returned
/// array holds the values of the second `u32` combined with that validity.
/// NOTE(review): presumably the first value is the lexical sort index and the second the
/// category id — confirm against the encoder.
///
/// # Safety
/// The caller must ensure that `rows` are valid encodings for this layout.
unsafe fn decode_lexical_cat(
    rows: &mut [&[u8]],
    opt: RowEncodingOptions,
    _values: &RowEncodingCategoricalContext,
) -> PrimitiveArray<u32> {
    let mut leading = numeric::decode_primitive::<u32>(rows, opt);
    let validity = leading.take_validity();

    let trailing = numeric::decode_primitive::<u32>(rows, opt);
    trailing.with_validity(validity)
}

/// Decode a single column of type `dtype` from `rows` and return it as a boxed array.
///
/// The rows are updated while the data is decoded, so consecutive calls with the same `rows`
/// decode consecutive columns.
///
/// * `rows` - the encoded rows, one slice per output element.
/// * `opt` - encoding options of this column; they select the sentinels/tokens and whether
///   the `NO_ORDER` variant of the variable-size encoding is in use.
/// * `dict` - optional encoding context of this column (categorical, decimal or struct).
/// * `dtype` - the Arrow dtype to decode.
///
/// # Panics
/// Panics if `dict` is a context variant that does not fit `dtype`, or if a list row is
/// terminated by neither the list null sentinel nor the list termination token.
///
/// # Safety
/// The caller must ensure that `rows` are valid encodings of `dtype` under `opt` and `dict`.
unsafe fn decode(
    rows: &mut [&[u8]],
    opt: RowEncodingOptions,
    dict: Option<&RowEncodingContext>,
    dtype: &ArrowDataType,
) -> ArrayRef {
    use ArrowDataType as D;
    match dtype {
        D::Null => NullArray::new(D::Null, rows.len()).to_boxed(),
        D::Boolean => boolean::decode_bool(rows, opt).to_boxed(),
        // Binary and string dtypes share one decoder when `NO_ORDER` is set. This guarded arm
        // has to come before the unguarded binary/string arms below.
        D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
            if opt.contains(RowEncodingOptions::NO_ORDER) =>
        {
            let array = no_order::decode_variable_no_order(rows, opt);

            // String dtypes go through the unchecked UTF-8 view conversion.
            if matches!(dtype, D::Utf8 | D::LargeUtf8 | D::Utf8View) {
                unsafe { array.to_utf8view_unchecked() }.to_boxed()
            } else {
                array.to_boxed()
            }
        },
        D::Binary | D::LargeBinary | D::BinaryView => binary::decode_binview(rows, opt).to_boxed(),
        D::Utf8 | D::LargeUtf8 | D::Utf8View => decode_str(rows, opt).boxed(),

        D::Struct(fields) => {
            // The validity byte of the struct comes first, followed by the fields in order.
            let validity = decode_validity(rows, opt);

            // A struct context carries one optional context per field; without a context
            // every field is decoded without one.
            let values = match dict {
                None => fields
                    .iter()
                    .map(|struct_fld| decode(rows, opt, None, struct_fld.dtype()))
                    .collect(),
                Some(RowEncodingContext::Struct(dicts)) => fields
                    .iter()
                    .zip(dicts)
                    .map(|(struct_fld, dict)| decode(rows, opt, dict.as_ref(), struct_fld.dtype()))
                    .collect(),
                _ => unreachable!(),
            };
            StructArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
        },
        D::FixedSizeList(fsl_field, width) => {
            // The validity byte of the list comes first, followed by `width` encoded items.
            let validity = decode_validity(rows, opt);

            // @TODO: we could consider making this into a scratchpad
            let mut nested_rows = Vec::new();
            // Cut `width` items off every row, then decode all of them as one flat column.
            rows_for_fixed_size_list(fsl_field.dtype(), opt, dict, *width, rows, &mut nested_rows);
            let values = decode(&mut nested_rows, opt, dict, fsl_field.dtype());

            FixedSizeListArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
        },
        // Both list dtypes are decoded into a `ListArray::<i64>`.
        D::List(list_field) | D::LargeList(list_field) => {
            // Stays empty until the first null list is encountered.
            let mut validity = BitmapBuilder::new();

            // @TODO: we could consider making this into a scratchpad
            let num_rows = rows.len();
            let mut nested_rows = Vec::new();
            let mut offsets = Vec::with_capacity(rows.len() + 1);
            offsets.push(0);

            let list_null_sentinel = opt.list_null_sentinel();
            let list_continuation_token = opt.list_continuation_token();
            let list_termination_token = opt.list_termination_token();

            // @TODO: make a specialized loop for fixed size list_field.dtype()
            for (i, row) in rows.iter_mut().enumerate() {
                // Every element is preceded by a continuation token. Collect the encoded
                // elements of all rows so that they can be decoded as one flat column.
                while row[0] == list_continuation_token {
                    *row = &row[1..];
                    let len =
                        dtype_and_data_to_encoded_item_len(list_field.dtype(), row, opt, dict);
                    nested_rows.push(&row[..len]);
                    *row = &row[len..];
                }

                // The offset is recorded before the null check, so null lists get one too.
                offsets.push(nested_rows.len() as i64);

                // @TODO: Might be better to make this a 2-loop system.
                if row[0] == list_null_sentinel {
                    *row = &row[1..];
                    validity.reserve(num_rows);
                    // Back-fill `true` for all rows since the last recorded null.
                    validity.extend_constant(i - validity.len(), true);
                    validity.push(false);
                    continue;
                }

                // A non-null list has to end with the termination token.
                assert_eq!(row[0], list_termination_token);
                *row = &row[1..];
            }

            // An empty builder means that no null was seen; otherwise pad the tail with `true`.
            let validity = if validity.is_empty() {
                None
            } else {
                validity.extend_constant(num_rows - validity.len(), true);
                validity.into_opt_validity()
            };
            assert_eq!(offsets.len(), rows.len() + 1);

            let values = decode(&mut nested_rows, opt, dict, list_field.dtype());

            ListArray::<i64>::new(
                dtype.clone(),
                unsafe { OffsetsBuffer::new_unchecked(Buffer::from(offsets)) },
                values,
                validity,
            )
            .to_boxed()
        },

        dt => {
            // `UInt32` with a context is a categorical; the context selects one of three
            // decoders.
            if matches!(dt, D::UInt32) {
                if let Some(dict) = dict {
                    return match dict {
                        RowEncodingContext::Categorical(ctx) => {
                            if ctx.is_enum {
                                packed_u32::decode(rows, opt, ctx.needed_num_bits()).to_boxed()
                            } else if ctx.lexical_sort_idxs.is_none() {
                                numeric::decode_primitive::<u32>(rows, opt).to_boxed()
                            } else {
                                decode_lexical_cat(rows, opt, ctx).to_boxed()
                            }
                        },
                        _ => unreachable!(),
                    };
                }
            }

            // `Int128` with a context is a decimal with the given precision.
            if matches!(dt, D::Int128) {
                if let Some(dict) = dict {
                    return match dict {
                        RowEncodingContext::Decimal(precision) => {
                            decimal::decode(rows, opt, *precision).to_boxed()
                        },
                        _ => unreachable!(),
                    };
                }
            }

            // Everything else is dispatched to the primitive decoder of the matching type.
            with_match_arrow_primitive_type!(dt, |$T| {
                numeric::decode_primitive::<$T>(rows, opt).to_boxed()
            })
        },
    }
}