datafusion_functions/encoding/
inner.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Encoding expressions
19
20use arrow::{
21    array::{
22        Array, ArrayRef, BinaryArray, GenericByteArray, OffsetSizeTrait, StringArray,
23    },
24    datatypes::{ByteArrayType, DataType},
25};
26use arrow_buffer::{Buffer, OffsetBufferBuilder};
27use base64::{engine::general_purpose, Engine as _};
28use datafusion_common::{
29    cast::{as_generic_binary_array, as_generic_string_array},
30    not_impl_err, plan_err,
31    utils::take_function_args,
32};
33use datafusion_common::{exec_err, ScalarValue};
34use datafusion_common::{DataFusionError, Result};
35use datafusion_expr::{ColumnarValue, Documentation};
36use std::sync::Arc;
37use std::{fmt, str::FromStr};
38
39use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
40use datafusion_macros::user_doc;
41use std::any::Any;
42
/// Scalar UDF backing the SQL function `encode(expression, format)`.
///
/// The user-facing documentation is generated from the `user_doc` attribute
/// below; the runtime behavior lives in the `ScalarUDFImpl` implementation.
#[user_doc(
    doc_section(label = "Binary String Functions"),
    description = "Encode binary data into a textual representation.",
    syntax_example = "encode(expression, format)",
    argument(
        name = "expression",
        description = "Expression containing string or binary data"
    ),
    argument(
        name = "format",
        description = "Supported formats are: `base64`, `hex`"
    ),
    related_udf(name = "decode")
)]
#[derive(Debug)]
pub struct EncodeFunc {
    /// Signature handed back from `ScalarUDFImpl::signature`.
    signature: Signature,
}
61
impl Default for EncodeFunc {
    /// Builds the default instance by delegating to `Self::new()`.
    fn default() -> Self {
        Self::new()
    }
}
67
68impl EncodeFunc {
69    pub fn new() -> Self {
70        Self {
71            signature: Signature::user_defined(Volatility::Immutable),
72        }
73    }
74}
75
76impl ScalarUDFImpl for EncodeFunc {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80    fn name(&self) -> &str {
81        "encode"
82    }
83
84    fn signature(&self) -> &Signature {
85        &self.signature
86    }
87
88    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
89        use DataType::*;
90
91        Ok(match arg_types[0] {
92            Utf8 => Utf8,
93            LargeUtf8 => LargeUtf8,
94            Utf8View => Utf8,
95            Binary => Utf8,
96            LargeBinary => LargeUtf8,
97            Null => Null,
98            _ => {
99                return plan_err!(
100                    "The encode function can only accept Utf8 or Binary or Null."
101                );
102            }
103        })
104    }
105
106    fn invoke_with_args(
107        &self,
108        args: datafusion_expr::ScalarFunctionArgs,
109    ) -> Result<ColumnarValue> {
110        encode(&args.args)
111    }
112
113    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
114        let [expression, format] = take_function_args(self.name(), arg_types)?;
115
116        if format != &DataType::Utf8 {
117            return Err(DataFusionError::Plan("2nd argument should be Utf8".into()));
118        }
119
120        match expression {
121            DataType::Utf8 | DataType::Utf8View | DataType::Null => {
122                Ok(vec![DataType::Utf8; 2])
123            }
124            DataType::LargeUtf8 => Ok(vec![DataType::LargeUtf8, DataType::Utf8]),
125            DataType::Binary => Ok(vec![DataType::Binary, DataType::Utf8]),
126            DataType::LargeBinary => Ok(vec![DataType::LargeBinary, DataType::Utf8]),
127            _ => plan_err!(
128                "1st argument should be Utf8 or Binary or Null, got {:?}",
129                arg_types[0]
130            ),
131        }
132    }
133
134    fn documentation(&self) -> Option<&Documentation> {
135        self.doc()
136    }
137}
138
/// Scalar UDF backing the SQL function `decode(expression, format)`.
///
/// The user-facing documentation is generated from the `user_doc` attribute
/// below; the runtime behavior lives in the `ScalarUDFImpl` implementation.
#[user_doc(
    doc_section(label = "Binary String Functions"),
    description = "Decode binary data from textual representation in string.",
    syntax_example = "decode(expression, format)",
    argument(
        name = "expression",
        description = "Expression containing encoded string data"
    ),
    argument(name = "format", description = "Same arguments as [encode](#encode)"),
    related_udf(name = "encode")
)]
#[derive(Debug)]
pub struct DecodeFunc {
    /// Signature handed back from `ScalarUDFImpl::signature`.
    signature: Signature,
}
154
impl Default for DecodeFunc {
    /// Builds the default instance by delegating to `Self::new()`.
    fn default() -> Self {
        Self::new()
    }
}
160
161impl DecodeFunc {
162    pub fn new() -> Self {
163        Self {
164            signature: Signature::user_defined(Volatility::Immutable),
165        }
166    }
167}
168
169impl ScalarUDFImpl for DecodeFunc {
170    fn as_any(&self) -> &dyn Any {
171        self
172    }
173    fn name(&self) -> &str {
174        "decode"
175    }
176
177    fn signature(&self) -> &Signature {
178        &self.signature
179    }
180
181    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
182        Ok(arg_types[0].to_owned())
183    }
184
185    fn invoke_with_args(
186        &self,
187        args: datafusion_expr::ScalarFunctionArgs,
188    ) -> Result<ColumnarValue> {
189        decode(&args.args)
190    }
191
192    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
193        if arg_types.len() != 2 {
194            return plan_err!(
195                "{} expects to get 2 arguments, but got {}",
196                self.name(),
197                arg_types.len()
198            );
199        }
200
201        if arg_types[1] != DataType::Utf8 {
202            return plan_err!("2nd argument should be Utf8");
203        }
204
205        match arg_types[0] {
206            DataType::Utf8 | DataType::Utf8View | DataType::Binary | DataType::Null => {
207                Ok(vec![DataType::Binary, DataType::Utf8])
208            }
209            DataType::LargeUtf8 | DataType::LargeBinary => {
210                Ok(vec![DataType::LargeBinary, DataType::Utf8])
211            }
212            _ => plan_err!(
213                "1st argument should be Utf8 or Binary or Null, got {:?}",
214                arg_types[0]
215            ),
216        }
217    }
218
219    fn documentation(&self) -> Option<&Documentation> {
220        self.doc()
221    }
222}
223
/// Textual encodings understood by `encode` and `decode`.
///
/// Values are parsed from the lowercase names `base64` and `hex`
/// (see the `FromStr` implementation).
#[derive(Debug, Copy, Clone)]
enum Encoding {
    /// Base64, handled by the `base64` crate's `STANDARD_NO_PAD` engine.
    Base64,
    /// Hexadecimal, handled by the `hex` crate.
    Hex,
}
229
230fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result<ColumnarValue> {
231    match value {
232        ColumnarValue::Array(a) => match a.data_type() {
233            DataType::Utf8 => encoding.encode_utf8_array::<i32>(a.as_ref()),
234            DataType::LargeUtf8 => encoding.encode_utf8_array::<i64>(a.as_ref()),
235            DataType::Utf8View => encoding.encode_utf8_array::<i32>(a.as_ref()),
236            DataType::Binary => encoding.encode_binary_array::<i32>(a.as_ref()),
237            DataType::LargeBinary => encoding.encode_binary_array::<i64>(a.as_ref()),
238            other => exec_err!(
239                "Unsupported data type {other:?} for function encode({encoding})"
240            ),
241        },
242        ColumnarValue::Scalar(scalar) => {
243            match scalar {
244                ScalarValue::Utf8(a) => {
245                    Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes())))
246                }
247                ScalarValue::LargeUtf8(a) => Ok(encoding
248                    .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))),
249                ScalarValue::Utf8View(a) => {
250                    Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes())))
251                }
252                ScalarValue::Binary(a) => Ok(
253                    encoding.encode_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice()))
254                ),
255                ScalarValue::LargeBinary(a) => Ok(encoding
256                    .encode_large_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice()))),
257                other => exec_err!(
258                    "Unsupported data type {other:?} for function encode({encoding})"
259                ),
260            }
261        }
262    }
263}
264
265fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result<ColumnarValue> {
266    match value {
267        ColumnarValue::Array(a) => match a.data_type() {
268            DataType::Utf8 => encoding.decode_utf8_array::<i32>(a.as_ref()),
269            DataType::LargeUtf8 => encoding.decode_utf8_array::<i64>(a.as_ref()),
270            DataType::Utf8View => encoding.decode_utf8_array::<i32>(a.as_ref()),
271            DataType::Binary => encoding.decode_binary_array::<i32>(a.as_ref()),
272            DataType::LargeBinary => encoding.decode_binary_array::<i64>(a.as_ref()),
273            other => exec_err!(
274                "Unsupported data type {other:?} for function decode({encoding})"
275            ),
276        },
277        ColumnarValue::Scalar(scalar) => {
278            match scalar {
279                ScalarValue::Utf8(a) => {
280                    encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))
281                }
282                ScalarValue::LargeUtf8(a) => encoding
283                    .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())),
284                ScalarValue::Utf8View(a) => {
285                    encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))
286                }
287                ScalarValue::Binary(a) => {
288                    encoding.decode_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice()))
289                }
290                ScalarValue::LargeBinary(a) => encoding
291                    .decode_large_scalar(a.as_ref().map(|v: &Vec<u8>| v.as_slice())),
292                other => exec_err!(
293                    "Unsupported data type {other:?} for function decode({encoding})"
294                ),
295            }
296        }
297    }
298}
299
/// Returns the hexadecimal text for `input`, as produced by `hex::encode`.
fn hex_encode(input: &[u8]) -> String {
    hex::encode(input)
}
303
/// Returns the base64 text for `input`, using the `STANDARD_NO_PAD` engine.
fn base64_encode(input: &[u8]) -> String {
    general_purpose::STANDARD_NO_PAD.encode(input)
}
307
308fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result<usize> {
309    // only write input / 2 bytes to buf
310    let out_len = input.len() / 2;
311    let buf = &mut buf[..out_len];
312    hex::decode_to_slice(input, buf).map_err(|e| {
313        DataFusionError::Internal(format!("Failed to decode from hex: {}", e))
314    })?;
315    Ok(out_len)
316}
317
318fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result<usize> {
319    general_purpose::STANDARD_NO_PAD
320        .decode_slice(input, buf)
321        .map_err(|e| {
322            DataFusionError::Internal(format!("Failed to decode from base64: {}", e))
323        })
324}
325
/// Encodes every non-null element of a byte array (`$INPUT`, a reference to a
/// `GenericByteArray`) with the function `$METHOD` (`fn(&[u8]) -> String`) and
/// evaluates to an `ArrayRef` holding the encoded strings.
///
/// The output string array uses the same offset width as the input, so
/// `LargeUtf8` / `LargeBinary` inputs yield a `LargeUtf8` array and 32-bit
/// inputs yield a `Utf8` array. Null elements stay null.
macro_rules! encode_to_array {
    ($METHOD: ident, $INPUT:expr) => {{
        // A local generic function lets the output offset type follow the
        // input array's `ByteArrayType::Offset` without an extra macro argument.
        fn encode_all<B: ByteArrayType>(input: &GenericByteArray<B>) -> ArrayRef {
            let encoded: arrow::array::GenericStringArray<B::Offset> = input
                .iter()
                .map(|x| x.map(|x| $METHOD(AsRef::<[u8]>::as_ref(x))))
                .collect();
            Arc::new(encoded)
        }
        encode_all($INPUT)
    }};
}
335
336fn decode_to_array<F, T: ByteArrayType>(
337    method: F,
338    input: &GenericByteArray<T>,
339    conservative_upper_bound_size: usize,
340) -> Result<ArrayRef>
341where
342    F: Fn(&[u8], &mut [u8]) -> Result<usize>,
343{
344    let mut values = vec![0; conservative_upper_bound_size];
345    let mut offsets = OffsetBufferBuilder::new(input.len());
346    let mut total_bytes_decoded = 0;
347    for v in input {
348        if let Some(v) = v {
349            let cursor = &mut values[total_bytes_decoded..];
350            let decoded = method(v.as_ref(), cursor)?;
351            total_bytes_decoded += decoded;
352            offsets.push_length(decoded);
353        } else {
354            offsets.push_length(0);
355        }
356    }
357    // We reserved an upper bound size for the values buffer, but we only use the actual size
358    values.truncate(total_bytes_decoded);
359    let binary_array = BinaryArray::try_new(
360        offsets.finish(),
361        Buffer::from_vec(values),
362        input.nulls().cloned(),
363    )?;
364    Ok(Arc::new(binary_array))
365}
366
367impl Encoding {
368    fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue {
369        ColumnarValue::Scalar(match self {
370            Self::Base64 => ScalarValue::Utf8(
371                value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)),
372            ),
373            Self::Hex => ScalarValue::Utf8(value.map(hex::encode)),
374        })
375    }
376
377    fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue {
378        ColumnarValue::Scalar(match self {
379            Self::Base64 => ScalarValue::LargeUtf8(
380                value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)),
381            ),
382            Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)),
383        })
384    }
385
386    fn encode_binary_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
387    where
388        T: OffsetSizeTrait,
389    {
390        let input_value = as_generic_binary_array::<T>(value)?;
391        let array: ArrayRef = match self {
392            Self::Base64 => encode_to_array!(base64_encode, input_value),
393            Self::Hex => encode_to_array!(hex_encode, input_value),
394        };
395        Ok(ColumnarValue::Array(array))
396    }
397
398    fn encode_utf8_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
399    where
400        T: OffsetSizeTrait,
401    {
402        let input_value = as_generic_string_array::<T>(value)?;
403        let array: ArrayRef = match self {
404            Self::Base64 => encode_to_array!(base64_encode, input_value),
405            Self::Hex => encode_to_array!(hex_encode, input_value),
406        };
407        Ok(ColumnarValue::Array(array))
408    }
409
410    fn decode_scalar(self, value: Option<&[u8]>) -> Result<ColumnarValue> {
411        let value = match value {
412            Some(value) => value,
413            None => return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))),
414        };
415
416        let out = match self {
417            Self::Base64 => {
418                general_purpose::STANDARD_NO_PAD
419                    .decode(value)
420                    .map_err(|e| {
421                        DataFusionError::Internal(format!(
422                            "Failed to decode value using base64: {}",
423                            e
424                        ))
425                    })?
426            }
427            Self::Hex => hex::decode(value).map_err(|e| {
428                DataFusionError::Internal(format!(
429                    "Failed to decode value using hex: {}",
430                    e
431                ))
432            })?,
433        };
434
435        Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out))))
436    }
437
438    fn decode_large_scalar(self, value: Option<&[u8]>) -> Result<ColumnarValue> {
439        let value = match value {
440            Some(value) => value,
441            None => return Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(None))),
442        };
443
444        let out = match self {
445            Self::Base64 => {
446                general_purpose::STANDARD_NO_PAD
447                    .decode(value)
448                    .map_err(|e| {
449                        DataFusionError::Internal(format!(
450                            "Failed to decode value using base64: {}",
451                            e
452                        ))
453                    })?
454            }
455            Self::Hex => hex::decode(value).map_err(|e| {
456                DataFusionError::Internal(format!(
457                    "Failed to decode value using hex: {}",
458                    e
459                ))
460            })?,
461        };
462
463        Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(out))))
464    }
465
466    fn decode_binary_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
467    where
468        T: OffsetSizeTrait,
469    {
470        let input_value = as_generic_binary_array::<T>(value)?;
471        let array = self.decode_byte_array(input_value)?;
472        Ok(ColumnarValue::Array(array))
473    }
474
475    fn decode_utf8_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
476    where
477        T: OffsetSizeTrait,
478    {
479        let input_value = as_generic_string_array::<T>(value)?;
480        let array = self.decode_byte_array(input_value)?;
481        Ok(ColumnarValue::Array(array))
482    }
483
484    fn decode_byte_array<T: ByteArrayType>(
485        &self,
486        input_value: &GenericByteArray<T>,
487    ) -> Result<ArrayRef> {
488        match self {
489            Self::Base64 => {
490                let upper_bound =
491                    base64::decoded_len_estimate(input_value.values().len());
492                decode_to_array(base64_decode, input_value, upper_bound)
493            }
494            Self::Hex => {
495                // Calculate the upper bound for decoded byte size
496                // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded
497                // So the upper bound is half the length of the input values.
498                let upper_bound = input_value.values().len() / 2;
499                decode_to_array(hex_decode, input_value, upper_bound)
500            }
501        }
502    }
503}
504
505impl fmt::Display for Encoding {
506    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
507        write!(f, "{}", format!("{self:?}").to_lowercase())
508    }
509}
510
511impl FromStr for Encoding {
512    type Err = DataFusionError;
513    fn from_str(name: &str) -> Result<Encoding> {
514        Ok(match name {
515            "base64" => Self::Base64,
516            "hex" => Self::Hex,
517            _ => {
518                let options = [Self::Base64, Self::Hex]
519                    .iter()
520                    .map(|i| i.to_string())
521                    .collect::<Vec<_>>()
522                    .join(", ");
523                return plan_err!(
524                    "There is no built-in encoding named '{name}', currently supported encodings are: {options}"
525                );
526            }
527        })
528    }
529}
530
531/// Encodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`].
532/// Second argument is the encoding to use.
533/// Standard encodings are base64 and hex.
534fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
535    let [expression, format] = take_function_args("encode", args)?;
536
537    let encoding = match format {
538        ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
539            Some(Some(method)) => method.parse::<Encoding>(),
540            _ => not_impl_err!(
541                "Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}"
542            ),
543        },
544        ColumnarValue::Array(_) => not_impl_err!(
545            "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported"
546        ),
547    }?;
548    encode_process(expression, encoding)
549}
550
551/// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`].
552/// Second argument is the encoding to use.
553/// Standard encodings are base64 and hex.
554fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
555    let [expression, format] = take_function_args("decode", args)?;
556
557    let encoding = match format {
558        ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
559            Some(Some(method))=> method.parse::<Encoding>(),
560            _ => not_impl_err!(
561                "Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. Got {scalar:?}"
562            ),
563        },
564        ColumnarValue::Array(_) => not_impl_err!(
565            "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported"
566        ),
567    }?;
568    decode_process(expression, encoding)
569}