datafusion_functions/unicode/
rpad.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::{make_scalar_function, utf8_to_str_type};
19use arrow::array::{
20    ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
21    OffsetSizeTrait, StringArrayType, StringViewArray,
22};
23use arrow::datatypes::DataType;
24use datafusion_common::cast::as_int64_array;
25use datafusion_common::DataFusionError;
26use datafusion_common::{exec_err, Result};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_macros::user_doc;
32use std::any::Any;
33use std::fmt::Write;
34use std::sync::Arc;
35use unicode_segmentation::UnicodeSegmentation;
36use DataType::{LargeUtf8, Utf8, Utf8View};
37
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Pads the right side of a string with another string to a specified string length.",
    syntax_example = "rpad(str, n[, padding_str])",
    sql_example = r#"```sql
>  select rpad('datafusion', 20, '_-');
+-----------------------------------------------+
| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
+-----------------------------------------------+
| datafusion_-_-_-_-_-                          |
+-----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "n", description = "String length to pad to."),
    argument(
        name = "padding_str",
        description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
    ),
    related_udf(name = "lpad")
)]
#[derive(Debug)]
// Scalar UDF type backing the SQL `rpad(str, n[, padding_str])` function.
pub struct RPadFunc {
    // Argument-type combinations this function accepts.
    signature: Signature,
}
62
63impl Default for RPadFunc {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl RPadFunc {
70    pub fn new() -> Self {
71        use DataType::*;
72        Self {
73            signature: Signature::one_of(
74                vec![
75                    Exact(vec![Utf8View, Int64]),
76                    Exact(vec![Utf8View, Int64, Utf8View]),
77                    Exact(vec![Utf8View, Int64, Utf8]),
78                    Exact(vec![Utf8View, Int64, LargeUtf8]),
79                    Exact(vec![Utf8, Int64]),
80                    Exact(vec![Utf8, Int64, Utf8View]),
81                    Exact(vec![Utf8, Int64, Utf8]),
82                    Exact(vec![Utf8, Int64, LargeUtf8]),
83                    Exact(vec![LargeUtf8, Int64]),
84                    Exact(vec![LargeUtf8, Int64, Utf8View]),
85                    Exact(vec![LargeUtf8, Int64, Utf8]),
86                    Exact(vec![LargeUtf8, Int64, LargeUtf8]),
87                ],
88                Volatility::Immutable,
89            ),
90        }
91    }
92}
93
94impl ScalarUDFImpl for RPadFunc {
95    fn as_any(&self) -> &dyn Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "rpad"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        utf8_to_str_type(&arg_types[0], "rpad")
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: datafusion_expr::ScalarFunctionArgs,
114    ) -> Result<ColumnarValue> {
115        let args = &args.args;
116        match (
117            args.len(),
118            args[0].data_type(),
119            args.get(2).map(|arg| arg.data_type()),
120        ) {
121            (2, Utf8 | Utf8View, _) => {
122                make_scalar_function(rpad::<i32, i32>, vec![])(args)
123            }
124            (2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
125            (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
126                make_scalar_function(rpad::<i32, i32>, vec![])(args)
127            }
128            (3, LargeUtf8, Some(LargeUtf8)) => {
129                make_scalar_function(rpad::<i64, i64>, vec![])(args)
130            }
131            (3, Utf8 | Utf8View, Some(LargeUtf8)) => {
132                make_scalar_function(rpad::<i32, i64>, vec![])(args)
133            }
134            (3, LargeUtf8, Some(Utf8 | Utf8View)) => {
135                make_scalar_function(rpad::<i64, i32>, vec![])(args)
136            }
137            (_, _, _) => {
138                exec_err!("Unsupported combination of data types for function rpad")
139            }
140        }
141    }
142
143    fn documentation(&self) -> Option<&Documentation> {
144        self.doc()
145    }
146}
147
148pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
149    args: &[ArrayRef],
150) -> Result<ArrayRef> {
151    if args.len() < 2 || args.len() > 3 {
152        return exec_err!(
153            "rpad was called with {} arguments. It requires 2 or 3 arguments.",
154            args.len()
155        );
156    }
157
158    let length_array = as_int64_array(&args[1])?;
159    match (
160        args.len(),
161        args[0].data_type(),
162        args.get(2).map(|arg| arg.data_type()),
163    ) {
164        (2, Utf8View, _) => {
165            rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
166                args[0].as_string_view(),
167                length_array,
168                None,
169            )
170        }
171        (3, Utf8View, Some(Utf8View)) => {
172            rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
173                args[0].as_string_view(),
174                length_array,
175                Some(args[2].as_string_view()),
176            )
177        }
178        (3, Utf8View, Some(Utf8 | LargeUtf8)) => {
179            rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
180                args[0].as_string_view(),
181                length_array,
182                Some(args[2].as_string::<FillArrayLen>()),
183            )
184        }
185        (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
186            &GenericStringArray<StringArrayLen>,
187            &StringViewArray,
188            StringArrayLen,
189        >(
190            args[0].as_string::<StringArrayLen>(),
191            length_array,
192            Some(args[2].as_string_view()),
193        ),
194        (_, _, _) => rpad_impl::<
195            &GenericStringArray<StringArrayLen>,
196            &GenericStringArray<FillArrayLen>,
197            StringArrayLen,
198        >(
199            args[0].as_string::<StringArrayLen>(),
200            length_array,
201            args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
202        ),
203    }
204}
205
206/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
207/// rpad('hi', 5, 'xy') = 'hixyx'
208pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
209    string_array: StringArrType,
210    length_array: &Int64Array,
211    fill_array: Option<FillArrType>,
212) -> Result<ArrayRef>
213where
214    StringArrType: StringArrayType<'a>,
215    FillArrType: StringArrayType<'a>,
216    StringArrayLen: OffsetSizeTrait,
217{
218    let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
219
220    match fill_array {
221        None => {
222            string_array.iter().zip(length_array.iter()).try_for_each(
223                |(string, length)| -> Result<(), DataFusionError> {
224                    match (string, length) {
225                        (Some(string), Some(length)) => {
226                            if length > i32::MAX as i64 {
227                                return exec_err!(
228                                    "rpad requested length {} too large",
229                                    length
230                                );
231                            }
232                            let length = if length < 0 { 0 } else { length as usize };
233                            if length == 0 {
234                                builder.append_value("");
235                            } else {
236                                let graphemes =
237                                    string.graphemes(true).collect::<Vec<&str>>();
238                                if length < graphemes.len() {
239                                    builder.append_value(graphemes[..length].concat());
240                                } else {
241                                    builder.write_str(string)?;
242                                    builder.write_str(
243                                        &" ".repeat(length - graphemes.len()),
244                                    )?;
245                                    builder.append_value("");
246                                }
247                            }
248                        }
249                        _ => builder.append_null(),
250                    }
251                    Ok(())
252                },
253            )?;
254        }
255        Some(fill_array) => {
256            string_array
257                .iter()
258                .zip(length_array.iter())
259                .zip(fill_array.iter())
260                .try_for_each(
261                    |((string, length), fill)| -> Result<(), DataFusionError> {
262                        match (string, length, fill) {
263                            (Some(string), Some(length), Some(fill)) => {
264                                if length > i32::MAX as i64 {
265                                    return exec_err!(
266                                        "rpad requested length {} too large",
267                                        length
268                                    );
269                                }
270                                let length = if length < 0 { 0 } else { length as usize };
271                                let graphemes =
272                                    string.graphemes(true).collect::<Vec<&str>>();
273
274                                if length < graphemes.len() {
275                                    builder.append_value(graphemes[..length].concat());
276                                } else if fill.is_empty() {
277                                    builder.append_value(string);
278                                } else {
279                                    builder.write_str(string)?;
280                                    fill.chars()
281                                        .cycle()
282                                        .take(length - graphemes.len())
283                                        .for_each(|ch| builder.write_char(ch).unwrap());
284                                    builder.append_value("");
285                                }
286                            }
287                            _ => builder.append_null(),
288                        }
289                        Ok(())
290                    },
291                )?;
292        }
293    }
294
295    Ok(Arc::new(builder.finish()) as ArrayRef)
296}
297
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::rpad::RPadFunc;
    use crate::utils::test::test_function;

    /// Exercises `rpad` through the UDF interface with scalar arguments:
    /// default fill, explicit fill, truncation, null propagation and
    /// multi-byte characters.
    #[test]
    fn test_functions() -> Result<()> {
        // Default (space) fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("josé ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // A target length shorter than the string truncates it.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("jo")),
            &str,
            Utf8,
            StringArray
        );
        // A negative target length is treated as zero.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(-3i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Null string or null length yields null.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Explicit fill string, cycled as needed.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("hixyx")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(21i64)),
                ColumnarValue::Scalar(ScalarValue::from("abcdef")),
            ],
            Ok(Some("hiabcdefabcdefabcdefa")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(" ")),
            ],
            Ok(Some("hi   ")),
            &str,
            Utf8,
            StringArray
        );
        // An empty fill leaves a short string unchanged.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("")),
            ],
            Ok(Some("hi")),
            &str,
            Utf8,
            StringArray
        );
        // Truncation also applies when a fill string is supplied.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("jos")),
            &str,
            Utf8,
            StringArray
        );
        // Any null argument yields null in the three-argument form.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("hi")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Multi-byte characters in the string and in the fill.
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("xy")),
            ],
            Ok(Some("joséxyxyxy")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
                ColumnarValue::Scalar(ScalarValue::from("éñ")),
            ],
            Ok(Some("josééñéñéñ")),
            &str,
            Utf8,
            StringArray
        );
        // Only compiled when the `unicode_expressions` feature is disabled.
        // Kept in the same shape as the cases above (owned `vec![...]`
        // arguments, fully-qualified error macro) so it builds if enabled.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            RPadFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("josé")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            datafusion_common::internal_err!(
                "function rpad requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}