datafusion_functions/unicode/
lpad.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Write;
20use std::sync::Arc;
21
22use arrow::array::{
23    Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24    OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use unicode_segmentation::UnicodeSegmentation;
28use DataType::{LargeUtf8, Utf8, Utf8View};
29
30use crate::utils::{make_scalar_function, utf8_to_str_type};
31use datafusion_common::cast::as_int64_array;
32use datafusion_common::{exec_err, Result};
33use datafusion_expr::TypeSignature::Exact;
34use datafusion_expr::{
35    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
36};
37use datafusion_macros::user_doc;
38
/// Scalar UDF implementing the SQL `lpad` string function.
///
/// The user-facing documentation (description, syntax, example) is declared
/// in the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Pads the left side of a string with another string to a specified string length.",
    syntax_example = "lpad(str, n[, padding_str])",
    sql_example = r#"```sql
> select lpad('Dolly', 10, 'hello');
+---------------------------------------------+
| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) |
+---------------------------------------------+
| helloDolly                                  |
+---------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "n", description = "String length to pad to."),
    argument(
        name = "padding_str",
        description = "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
    ),
    related_udf(name = "rpad")
)]
#[derive(Debug)]
pub struct LPadFunc {
    /// Argument-type combinations accepted by `lpad`; populated in `LPadFunc::new`.
    signature: Signature,
}
63
64impl Default for LPadFunc {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl LPadFunc {
71    pub fn new() -> Self {
72        use DataType::*;
73        Self {
74            signature: Signature::one_of(
75                vec![
76                    Exact(vec![Utf8View, Int64]),
77                    Exact(vec![Utf8View, Int64, Utf8View]),
78                    Exact(vec![Utf8View, Int64, Utf8]),
79                    Exact(vec![Utf8View, Int64, LargeUtf8]),
80                    Exact(vec![Utf8, Int64]),
81                    Exact(vec![Utf8, Int64, Utf8View]),
82                    Exact(vec![Utf8, Int64, Utf8]),
83                    Exact(vec![Utf8, Int64, LargeUtf8]),
84                    Exact(vec![LargeUtf8, Int64]),
85                    Exact(vec![LargeUtf8, Int64, Utf8View]),
86                    Exact(vec![LargeUtf8, Int64, Utf8]),
87                    Exact(vec![LargeUtf8, Int64, LargeUtf8]),
88                ],
89                Volatility::Immutable,
90            ),
91        }
92    }
93}
94
95impl ScalarUDFImpl for LPadFunc {
96    fn as_any(&self) -> &dyn Any {
97        self
98    }
99
100    fn name(&self) -> &str {
101        "lpad"
102    }
103
104    fn signature(&self) -> &Signature {
105        &self.signature
106    }
107
108    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
109        utf8_to_str_type(&arg_types[0], "lpad")
110    }
111
112    fn invoke_with_args(
113        &self,
114        args: datafusion_expr::ScalarFunctionArgs,
115    ) -> Result<ColumnarValue> {
116        let args = &args.args;
117        match args[0].data_type() {
118            Utf8 | Utf8View => make_scalar_function(lpad::<i32>, vec![])(args),
119            LargeUtf8 => make_scalar_function(lpad::<i64>, vec![])(args),
120            other => exec_err!("Unsupported data type {other:?} for function lpad"),
121        }
122    }
123
124    fn documentation(&self) -> Option<&Documentation> {
125        self.doc()
126    }
127}
128
129/// Extends the string to length 'length' by prepending the characters fill (a space by default).
130/// If the string is already longer than length then it is truncated (on the right).
131/// lpad('hi', 5, 'xy') = 'xyxhi'
132pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
133    if args.len() <= 1 || args.len() > 3 {
134        return exec_err!(
135            "lpad was called with {} arguments. It requires at least 2 and at most 3.",
136            args.len()
137        );
138    }
139
140    let length_array = as_int64_array(&args[1])?;
141
142    match (args.len(), args[0].data_type()) {
143        (2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray<i32>, T>(
144            args[0].as_string_view(),
145            length_array,
146            None,
147        ),
148        (2, Utf8 | LargeUtf8) => lpad_impl::<
149            &GenericStringArray<T>,
150            &GenericStringArray<T>,
151            T,
152        >(args[0].as_string::<T>(), length_array, None),
153        (3, Utf8View) => lpad_with_replace::<&StringViewArray, T>(
154            args[0].as_string_view(),
155            length_array,
156            &args[2],
157        ),
158        (3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray<T>, T>(
159            args[0].as_string::<T>(),
160            length_array,
161            &args[2],
162        ),
163        (_, _) => unreachable!("lpad"),
164    }
165}
166
167fn lpad_with_replace<'a, V, T: OffsetSizeTrait>(
168    string_array: V,
169    length_array: &Int64Array,
170    fill_array: &'a ArrayRef,
171) -> Result<ArrayRef>
172where
173    V: StringArrayType<'a>,
174{
175    match fill_array.data_type() {
176        Utf8View => lpad_impl::<V, &StringViewArray, T>(
177            string_array,
178            length_array,
179            Some(fill_array.as_string_view()),
180        ),
181        LargeUtf8 => lpad_impl::<V, &GenericStringArray<i64>, T>(
182            string_array,
183            length_array,
184            Some(fill_array.as_string::<i64>()),
185        ),
186        Utf8 => lpad_impl::<V, &GenericStringArray<i32>, T>(
187            string_array,
188            length_array,
189            Some(fill_array.as_string::<i32>()),
190        ),
191        other => {
192            exec_err!("Unsupported data type {other:?} for function lpad")
193        }
194    }
195}
196
197fn lpad_impl<'a, V, V2, T>(
198    string_array: V,
199    length_array: &Int64Array,
200    fill_array: Option<V2>,
201) -> Result<ArrayRef>
202where
203    V: StringArrayType<'a>,
204    V2: StringArrayType<'a>,
205    T: OffsetSizeTrait,
206{
207    let array = if fill_array.is_none() {
208        let mut builder: GenericStringBuilder<T> = GenericStringBuilder::new();
209
210        for (string, length) in string_array.iter().zip(length_array.iter()) {
211            if let (Some(string), Some(length)) = (string, length) {
212                if length > i32::MAX as i64 {
213                    return exec_err!("lpad requested length {length} too large");
214                }
215
216                let length = if length < 0 { 0 } else { length as usize };
217                if length == 0 {
218                    builder.append_value("");
219                    continue;
220                }
221
222                let graphemes = string.graphemes(true).collect::<Vec<&str>>();
223                if length < graphemes.len() {
224                    builder.append_value(graphemes[..length].concat());
225                } else {
226                    builder.write_str(" ".repeat(length - graphemes.len()).as_str())?;
227                    builder.write_str(string)?;
228                    builder.append_value("");
229                }
230            } else {
231                builder.append_null();
232            }
233        }
234
235        builder.finish()
236    } else {
237        let mut builder: GenericStringBuilder<T> = GenericStringBuilder::new();
238
239        for ((string, length), fill) in string_array
240            .iter()
241            .zip(length_array.iter())
242            .zip(fill_array.unwrap().iter())
243        {
244            if let (Some(string), Some(length), Some(fill)) = (string, length, fill) {
245                if length > i32::MAX as i64 {
246                    return exec_err!("lpad requested length {length} too large");
247                }
248
249                let length = if length < 0 { 0 } else { length as usize };
250                if length == 0 {
251                    builder.append_value("");
252                    continue;
253                }
254
255                let graphemes = string.graphemes(true).collect::<Vec<&str>>();
256                let fill_chars = fill.chars().collect::<Vec<char>>();
257
258                if length < graphemes.len() {
259                    builder.append_value(graphemes[..length].concat());
260                } else if fill_chars.is_empty() {
261                    builder.append_value(string);
262                } else {
263                    for l in 0..length - graphemes.len() {
264                        let c = *fill_chars.get(l % fill_chars.len()).unwrap();
265                        builder.write_char(c)?;
266                    }
267                    builder.write_str(string)?;
268                    builder.append_value("");
269                }
270            } else {
271                builder.append_null();
272            }
273        }
274
275        builder.finish()
276    };
277
278    Ok(Arc::new(array) as ArrayRef)
279}
280
#[cfg(test)]
mod tests {
    use crate::unicode::lpad::LPadFunc;
    use crate::utils::test::test_function;

    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    /// Runs one `lpad` case against every supported combination of string
    /// types.
    ///
    /// * `(input, length, expected)` exercises the two-argument form with
    ///   `Utf8`, `LargeUtf8` and `Utf8View` input.
    /// * `(input, length, fill, expected)` exercises the three-argument form
    ///   for all nine input/fill type pairs.
    ///
    /// `Utf8View` input is checked against a `Utf8` result, matching the
    /// other `Utf8View` expectations in this macro.
    macro_rules! test_lpad {
        ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => {
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH)
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );

            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH)
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );

            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH)
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
        };

        ($INPUT:expr, $LENGTH:expr, $REPLACE:expr, $EXPECTED:expr) => {
            // utf8, utf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8, largeutf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8, utf8view
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );

            // largeutf8, utf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );
            // largeutf8, largeutf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );
            // largeutf8, utf8view
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE))
                ],
                $EXPECTED,
                &str,
                LargeUtf8,
                LargeStringArray
            );

            // utf8view, utf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8view, largeutf8
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
            // utf8view, utf8view
            test_function!(
                LPadFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)),
                    ColumnarValue::Scalar($LENGTH),
                    ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE))
                ],
                $EXPECTED,
                &str,
                Utf8,
                StringArray
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Default (space) padding.
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(5i64)),
            Ok(Some(" josé"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Ok(Some("   hi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(0i64)),
            Ok(Some(""))
        );
        // Negative lengths are clamped to zero.
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(-1i64)),
            Ok(Some(""))
        );
        // A target length shorter than the input truncates on the right.
        test_lpad!(
            Some("hello".into()),
            ScalarValue::Int64(Some(2i64)),
            Ok(Some("he"))
        );
        test_lpad!(
            Some("éñx".into()),
            ScalarValue::Int64(Some(2i64)),
            Ok(Some("éñ"))
        );
        // Null propagation.
        test_lpad!(Some("hi".into()), ScalarValue::Int64(None), Ok(None));
        test_lpad!(None, ScalarValue::Int64(Some(5i64)), Ok(None));

        // Explicit fill string.
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Some("xy".into()),
            Ok(Some("xyxhi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(21i64)),
            Some("abcdef".into()),
            Ok(Some("abcdefabcdefabcdefahi"))
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Some(" ".into()),
            Ok(Some("   hi"))
        );
        // An empty fill leaves the input unchanged.
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            Some("".into()),
            Ok(Some("hi"))
        );
        // Truncation ignores the fill string.
        test_lpad!(
            Some("hello".into()),
            ScalarValue::Int64(Some(2i64)),
            Some("xy".into()),
            Ok(Some("he"))
        );
        test_lpad!(
            None,
            ScalarValue::Int64(Some(5i64)),
            Some("xy".into()),
            Ok(None)
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(None),
            Some("xy".into()),
            Ok(None)
        );
        test_lpad!(
            Some("hi".into()),
            ScalarValue::Int64(Some(5i64)),
            None,
            Ok(None)
        );
        // Multi-byte characters in the input and in the fill.
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(10i64)),
            Some("xy".into()),
            Ok(Some("xyxyxyjosé"))
        );
        test_lpad!(
            Some("josé".into()),
            ScalarValue::Int64(Some(10i64)),
            Some("éñ".into()),
            Ok(Some("éñéñéñjosé"))
        );

        Ok(())
    }
}