datafusion_functions/string/
common.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Common utilities for implementing string functions
19
20use std::fmt::{Display, Formatter};
21use std::sync::Arc;
22
23use crate::strings::make_and_append_view;
24use arrow::array::{
25    new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder,
26    NullBufferBuilder, OffsetSizeTrait, StringBuilder, StringViewArray,
27};
28use arrow::buffer::{Buffer, ScalarBuffer};
29use arrow::datatypes::DataType;
30use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
31use datafusion_common::Result;
32use datafusion_common::{exec_err, ScalarValue};
33use datafusion_expr::ColumnarValue;
34
/// Selects which end(s) of a string [`general_trim`] strips characters from.
///
/// The `Display` implementation renders each variant as `ltrim`, `rtrim`
/// or `btrim` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TrimType {
    /// Strip matching characters from the start of the string only.
    Left,
    /// Strip matching characters from the end of the string only.
    Right,
    /// Strip matching characters from both the start and the end.
    Both,
}
40
41impl Display for TrimType {
42    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
43        match self {
44            TrimType::Left => write!(f, "ltrim"),
45            TrimType::Right => write!(f, "rtrim"),
46            TrimType::Both => write!(f, "btrim"),
47        }
48    }
49}
50
51pub(crate) fn general_trim<T: OffsetSizeTrait>(
52    args: &[ArrayRef],
53    trim_type: TrimType,
54    use_string_view: bool,
55) -> Result<ArrayRef> {
56    let func = match trim_type {
57        TrimType::Left => |input, pattern: &str| {
58            let pattern = pattern.chars().collect::<Vec<char>>();
59            let ltrimmed_str =
60                str::trim_start_matches::<&[char]>(input, pattern.as_ref());
61            // `ltrimmed_str` is actually `input`[start_offset..],
62            // so `start_offset` = len(`input`) - len(`ltrimmed_str`)
63            let start_offset = input.len() - ltrimmed_str.len();
64
65            (ltrimmed_str, start_offset as u32)
66        },
67        TrimType::Right => |input, pattern: &str| {
68            let pattern = pattern.chars().collect::<Vec<char>>();
69            let rtrimmed_str = str::trim_end_matches::<&[char]>(input, pattern.as_ref());
70
71            // `ltrimmed_str` is actually `input`[0..new_len], so `start_offset` is 0
72            (rtrimmed_str, 0)
73        },
74        TrimType::Both => |input, pattern: &str| {
75            let pattern = pattern.chars().collect::<Vec<char>>();
76            let ltrimmed_str =
77                str::trim_start_matches::<&[char]>(input, pattern.as_ref());
78            // `btrimmed_str` can be got by rtrim(ltrim(`input`)),
79            // so its `start_offset` should be same as ltrim situation above
80            let start_offset = input.len() - ltrimmed_str.len();
81            let btrimmed_str =
82                str::trim_end_matches::<&[char]>(ltrimmed_str, pattern.as_ref());
83
84            (btrimmed_str, start_offset as u32)
85        },
86    };
87
88    if use_string_view {
89        string_view_trim(func, args)
90    } else {
91        string_trim::<T>(func, args)
92    }
93}
94
95/// Applies the trim function to the given string view array(s)
96/// and returns a new string view array with the trimmed values.
97///
98/// # `trim_func`: The function to apply to each string view.
99///
100/// ## Arguments
101/// - The original string
102/// - the pattern to trim
103///
104/// ## Returns
105///  - trimmed str (must be a substring of the first argument)
106///  - start offset, needed in `string_view_trim`
107///
108/// ## Examples
109///
110/// For `ltrim`:
111/// - `fn("  abc", " ") -> ("abc", 2)`
112/// - `fn("abd", " ") -> ("abd", 0)`
113///
114/// For `btrim`:
115/// - `fn("  abc  ", " ") -> ("abc", 2)`
116/// - `fn("abd", " ") -> ("abd", 0)`
117// removing 'a will cause compiler complaining lifetime of `func`
118fn string_view_trim<'a>(
119    trim_func: fn(&'a str, &'a str) -> (&'a str, u32),
120    args: &'a [ArrayRef],
121) -> Result<ArrayRef> {
122    let string_view_array = as_string_view_array(&args[0])?;
123    let mut views_buf = Vec::with_capacity(string_view_array.len());
124    let mut null_builder = NullBufferBuilder::new(string_view_array.len());
125
126    match args.len() {
127        1 => {
128            let array_iter = string_view_array.iter();
129            let views_iter = string_view_array.views().iter();
130            for (src_str_opt, raw_view) in array_iter.zip(views_iter) {
131                trim_and_append_str(
132                    src_str_opt,
133                    Some(" "),
134                    trim_func,
135                    &mut views_buf,
136                    &mut null_builder,
137                    raw_view,
138                );
139            }
140        }
141        2 => {
142            let characters_array = as_string_view_array(&args[1])?;
143
144            if characters_array.len() == 1 {
145                // Only one `trim characters` exist
146                if characters_array.is_null(0) {
147                    return Ok(new_null_array(
148                        // The schema is expecting utf8 as null
149                        &DataType::Utf8View,
150                        string_view_array.len(),
151                    ));
152                }
153
154                let characters = characters_array.value(0);
155                let array_iter = string_view_array.iter();
156                let views_iter = string_view_array.views().iter();
157                for (src_str_opt, raw_view) in array_iter.zip(views_iter) {
158                    trim_and_append_str(
159                        src_str_opt,
160                        Some(characters),
161                        trim_func,
162                        &mut views_buf,
163                        &mut null_builder,
164                        raw_view,
165                    );
166                }
167            } else {
168                // A specific `trim characters` for a row in the string view array
169                let characters_iter = characters_array.iter();
170                let array_iter = string_view_array.iter();
171                let views_iter = string_view_array.views().iter();
172                for ((src_str_opt, raw_view), characters_opt) in
173                    array_iter.zip(views_iter).zip(characters_iter)
174                {
175                    trim_and_append_str(
176                        src_str_opt,
177                        characters_opt,
178                        trim_func,
179                        &mut views_buf,
180                        &mut null_builder,
181                        raw_view,
182                    );
183                }
184            }
185        }
186        other => {
187            return exec_err!(
188            "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
189            );
190        }
191    }
192
193    let views_buf = ScalarBuffer::from(views_buf);
194    let nulls_buf = null_builder.finish();
195
196    // Safety:
197    // (1) The blocks of the given views are all provided
198    // (2) Each of the range `view.offset+start..end` of view in views_buf is within
199    // the bounds of each of the blocks
200    unsafe {
201        let array = StringViewArray::new_unchecked(
202            views_buf,
203            string_view_array.data_buffers().to_vec(),
204            nulls_buf,
205        );
206        Ok(Arc::new(array) as ArrayRef)
207    }
208}
209
210/// Trims the given string and appends the trimmed string to the views buffer
211/// and the null buffer.
212///
213/// Calls `trim_func` on the string value in `original_view`, for non_null
214/// values and appends the updated view to the views buffer / null_builder.
215///
216/// Arguments
217/// - `src_str_opt`: The original string value (represented by the view)
218/// - `trim_characters_opt`: The characters to trim from the string
219/// - `trim_func`: The function to apply to the string (see [`string_view_trim`] for details)
220/// - `views_buf`: The buffer to append the updated views to
221/// - `null_builder`: The buffer to append the null values to
222/// - `original_view`: The original view value (that contains src_str_opt)
223fn trim_and_append_str<'a>(
224    src_str_opt: Option<&'a str>,
225    trim_characters_opt: Option<&'a str>,
226    trim_func: fn(&'a str, &'a str) -> (&'a str, u32),
227    views_buf: &mut Vec<u128>,
228    null_builder: &mut NullBufferBuilder,
229    original_view: &u128,
230) {
231    if let (Some(src_str), Some(characters)) = (src_str_opt, trim_characters_opt) {
232        let (trim_str, start_offset) = trim_func(src_str, characters);
233        make_and_append_view(
234            views_buf,
235            null_builder,
236            original_view,
237            trim_str,
238            start_offset,
239        );
240    } else {
241        null_builder.append_null();
242        views_buf.push(0);
243    }
244}
245
246/// Applies the trim function to the given string array(s)
247/// and returns a new string array with the trimmed values.
248///
249/// See [`string_view_trim`] for details on `func`
250fn string_trim<'a, T: OffsetSizeTrait>(
251    func: fn(&'a str, &'a str) -> (&'a str, u32),
252    args: &'a [ArrayRef],
253) -> Result<ArrayRef> {
254    let string_array = as_generic_string_array::<T>(&args[0])?;
255
256    match args.len() {
257        1 => {
258            let result = string_array
259                .iter()
260                .map(|string| string.map(|string: &str| func(string, " ").0))
261                .collect::<GenericStringArray<T>>();
262
263            Ok(Arc::new(result) as ArrayRef)
264        }
265        2 => {
266            let characters_array = as_generic_string_array::<T>(&args[1])?;
267
268            if characters_array.len() == 1 {
269                if characters_array.is_null(0) {
270                    return Ok(new_null_array(
271                        string_array.data_type(),
272                        string_array.len(),
273                    ));
274                }
275
276                let characters = characters_array.value(0);
277                let result = string_array
278                    .iter()
279                    .map(|item| item.map(|string| func(string, characters).0))
280                    .collect::<GenericStringArray<T>>();
281                return Ok(Arc::new(result) as ArrayRef);
282            }
283
284            let result = string_array
285                .iter()
286                .zip(characters_array.iter())
287                .map(|(string, characters)| match (string, characters) {
288                    (Some(string), Some(characters)) => Some(func(string, characters).0),
289                    _ => None,
290                })
291                .collect::<GenericStringArray<T>>();
292
293            Ok(Arc::new(result) as ArrayRef)
294        }
295        other => {
296            exec_err!(
297            "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
298            )
299        }
300    }
301}
302
303pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
304    case_conversion(args, |string| string.to_lowercase(), name)
305}
306
307pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
308    case_conversion(args, |string| string.to_uppercase(), name)
309}
310
311fn case_conversion<'a, F>(
312    args: &'a [ColumnarValue],
313    op: F,
314    name: &str,
315) -> Result<ColumnarValue>
316where
317    F: Fn(&'a str) -> String,
318{
319    match &args[0] {
320        ColumnarValue::Array(array) => match array.data_type() {
321            DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
322                array, op,
323            )?)),
324            DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
325                i64,
326                _,
327            >(array, op)?)),
328            DataType::Utf8View => {
329                let string_array = as_string_view_array(array)?;
330                let mut string_builder = StringBuilder::with_capacity(
331                    string_array.len(),
332                    string_array.get_array_memory_size(),
333                );
334
335                for str in string_array.iter() {
336                    if let Some(str) = str {
337                        string_builder.append_value(op(str));
338                    } else {
339                        string_builder.append_null();
340                    }
341                }
342
343                Ok(ColumnarValue::Array(Arc::new(string_builder.finish())))
344            }
345            other => exec_err!("Unsupported data type {other:?} for function {name}"),
346        },
347        ColumnarValue::Scalar(scalar) => match scalar {
348            ScalarValue::Utf8(a) => {
349                let result = a.as_ref().map(|x| op(x));
350                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
351            }
352            ScalarValue::LargeUtf8(a) => {
353                let result = a.as_ref().map(|x| op(x));
354                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
355            }
356            ScalarValue::Utf8View(a) => {
357                let result = a.as_ref().map(|x| op(x));
358                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
359            }
360            other => exec_err!("Unsupported data type {other:?} for function {name}"),
361        },
362    }
363}
364
365fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
366where
367    O: OffsetSizeTrait,
368    F: Fn(&'a str) -> String,
369{
370    const PRE_ALLOC_BYTES: usize = 8;
371
372    let string_array = as_generic_string_array::<O>(array)?;
373    let value_data = string_array.value_data();
374
375    // All values are ASCII.
376    if value_data.is_ascii() {
377        return case_conversion_ascii_array::<O, _>(string_array, op);
378    }
379
380    // Values contain non-ASCII.
381    let item_len = string_array.len();
382    let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
383    let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
384
385    if string_array.null_count() == 0 {
386        let iter =
387            (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
388        builder.extend(iter);
389    } else {
390        let iter = string_array.iter().map(|string| string.map(&op));
391        builder.extend(iter);
392    }
393    Ok(Arc::new(builder.finish()))
394}
395
396/// All values of string_array are ASCII, and when converting case, there is no changes in the byte
397/// array length. Therefore, the StringArray can be treated as a complete ASCII string for
398/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
399fn case_conversion_ascii_array<'a, O, F>(
400    string_array: &'a GenericStringArray<O>,
401    op: F,
402) -> Result<ArrayRef>
403where
404    O: OffsetSizeTrait,
405    F: Fn(&'a str) -> String,
406{
407    let value_data = string_array.value_data();
408    // SAFETY: all items stored in value_data satisfy UTF8.
409    // ref: impl ByteArrayNativeType for str {...}
410    let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
411
412    // conversion
413    let converted_values = op(str_values);
414    assert_eq!(converted_values.len(), str_values.len());
415    let bytes = converted_values.into_bytes();
416
417    // build result
418    let values = Buffer::from_vec(bytes);
419    let offsets = string_array.offsets().clone();
420    let nulls = string_array.nulls().cloned();
421    // SAFETY: offsets and nulls are consistent with the input array.
422    Ok(Arc::new(unsafe {
423        GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
424    }))
425}