arrow_cast/
base64.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Functions for base64 encoding the values of a [`GenericBinaryArray`] into a
//! [`GenericStringArray`], and for decoding base64 encoded values back into binary data
//!
//! [`StringArray`]: arrow_array::StringArray
21
22use arrow_array::{Array, GenericBinaryArray, GenericStringArray, OffsetSizeTrait};
23use arrow_buffer::{Buffer, OffsetBuffer};
24use arrow_schema::ArrowError;
25use base64::encoded_len;
26use base64::engine::Config;
27
28pub use base64::prelude::*;
29
30/// Bas64 encode each element of `array` with the provided [`Engine`]
31pub fn b64_encode<E: Engine, O: OffsetSizeTrait>(
32    engine: &E,
33    array: &GenericBinaryArray<O>,
34) -> GenericStringArray<O> {
35    let lengths = array.offsets().windows(2).map(|w| {
36        let len = w[1].as_usize() - w[0].as_usize();
37        encoded_len(len, engine.config().encode_padding()).unwrap()
38    });
39    let offsets = OffsetBuffer::<O>::from_lengths(lengths);
40    let buffer_len = offsets.last().unwrap().as_usize();
41    let mut buffer = vec![0_u8; buffer_len];
42    let mut offset = 0;
43
44    for i in 0..array.len() {
45        let len = engine
46            .encode_slice(array.value(i), &mut buffer[offset..])
47            .unwrap();
48        offset += len;
49    }
50    assert_eq!(offset, buffer_len);
51
52    // Safety: Base64 is valid UTF-8
53    unsafe {
54        GenericStringArray::new_unchecked(offsets, Buffer::from_vec(buffer), array.nulls().cloned())
55    }
56}
57
58/// Base64 decode each element of `array` with the provided [`Engine`]
59pub fn b64_decode<E: Engine, O: OffsetSizeTrait>(
60    engine: &E,
61    array: &GenericBinaryArray<O>,
62) -> Result<GenericBinaryArray<O>, ArrowError> {
63    let estimated_len = array.values().len(); // This is an overestimate
64    let mut buffer = vec![0; estimated_len];
65
66    let mut offsets = Vec::with_capacity(array.len() + 1);
67    offsets.push(O::usize_as(0));
68    let mut offset = 0;
69
70    for v in array.iter() {
71        if let Some(v) = v {
72            let len = engine.decode_slice(v, &mut buffer[offset..]).unwrap();
73            // This cannot overflow as `len` is less than `v.len()` and `a` is valid
74            offset += len;
75        }
76        offsets.push(O::usize_as(offset));
77    }
78
79    // Safety: offsets monotonically increasing by construction
80    let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) };
81
82    Ok(GenericBinaryArray::new(
83        offsets,
84        Buffer::from_vec(buffer),
85        array.nulls().cloned(),
86    ))
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::BinaryArray;
    use rand::{thread_rng, Rng};

    /// Round-trips `a` through `b64_encode` and `b64_decode` using engine `e`,
    /// fully validating both intermediate arrays and checking the decoded
    /// output equals the input
    fn test_engine<E: Engine>(e: &E, a: &BinaryArray) {
        let as_text = b64_encode(e, a);
        as_text.to_data().validate_full().unwrap();

        // Reinterpret the encoded strings as binary so they can be fed to the decoder
        let as_binary = BinaryArray::from(as_text);
        let round_tripped = b64_decode(e, &as_binary).unwrap();
        round_tripped.to_data().validate_full().unwrap();

        assert_eq!(&round_tripped, a);
    }

    /// Round-trips randomly generated byte strings with both the padded and
    /// the unpadded standard engines
    #[test]
    fn test_b64() {
        let mut rng = thread_rng();
        let rows = rng.gen_range(1024..1050);

        // Each row is a non-null value holding between 0 and 15 random bytes
        let mut values = Vec::new();
        for _ in 0..rows {
            let width = rng.gen_range(0..16);
            let bytes: Vec<u8> = (0..width).map(|_| rng.gen()).collect();
            values.push(Some(bytes));
        }
        let data: BinaryArray = values.into_iter().collect();

        test_engine(&BASE64_STANDARD, &data);
        test_engine(&BASE64_STANDARD_NO_PAD, &data);
    }
}