datafusion_functions/string/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_str_type};
22use arrow::array::{
23    ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24    OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::types::{logical_int64, logical_string, NativeType};
30use datafusion_common::{exec_err, DataFusionError, Result};
31use datafusion_expr::{ColumnarValue, Documentation, Volatility};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
34use datafusion_macros::user_doc;
35
/// Scalar UDF implementing the SQL `repeat(str, n)` string function.
///
/// The user-facing documentation is supplied through the `user_doc` attribute
/// below; the argument coercion rules live in the stored [`Signature`].
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a string with an input string repeated a specified number.",
    syntax_example = "repeat(str, n)",
    sql_example = r#"```sql
> select repeat('data', 3);
+-------------------------------+
| repeat(Utf8("data"),Int64(3)) |
+-------------------------------+
| datadatadata                  |
+-------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "n",
        description = "Number of times to repeat the input string."
    )
)]
#[derive(Debug)]
pub struct RepeatFunc {
    /// Signature returned by `ScalarUDFImpl::signature` for this function.
    signature: Signature,
}
58
59impl Default for RepeatFunc {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl RepeatFunc {
66    pub fn new() -> Self {
67        Self {
68            signature: Signature::coercible(
69                vec![
70                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71                    // Accept all integer types but cast them to i64
72                    Coercion::new_implicit(
73                        TypeSignatureClass::Native(logical_int64()),
74                        vec![TypeSignatureClass::Integer],
75                        NativeType::Int64,
76                    ),
77                ],
78                Volatility::Immutable,
79            ),
80        }
81    }
82}
83
84impl ScalarUDFImpl for RepeatFunc {
85    fn as_any(&self) -> &dyn Any {
86        self
87    }
88
89    fn name(&self) -> &str {
90        "repeat"
91    }
92
93    fn signature(&self) -> &Signature {
94        &self.signature
95    }
96
97    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
98        utf8_to_str_type(&arg_types[0], "repeat")
99    }
100
101    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
102        make_scalar_function(repeat, vec![])(&args.args)
103    }
104
105    fn documentation(&self) -> Option<&Documentation> {
106        self.doc()
107    }
108}
109
110/// Repeats string the specified number of times.
111/// repeat('Pg', 4) = 'PgPgPgPg'
112fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
113    let number_array = as_int64_array(&args[1])?;
114    match args[0].data_type() {
115        Utf8View => {
116            let string_view_array = args[0].as_string_view();
117            repeat_impl::<i32, &StringViewArray>(
118                string_view_array,
119                number_array,
120                i32::MAX as usize,
121            )
122        }
123        Utf8 => {
124            let string_array = args[0].as_string::<i32>();
125            repeat_impl::<i32, &GenericStringArray<i32>>(
126                string_array,
127                number_array,
128                i32::MAX as usize,
129            )
130        }
131        LargeUtf8 => {
132            let string_array = args[0].as_string::<i64>();
133            repeat_impl::<i64, &GenericStringArray<i64>>(
134                string_array,
135                number_array,
136                i64::MAX as usize,
137            )
138        }
139        other => exec_err!(
140            "Unsupported data type {other:?} for function repeat. \
141        Expected Utf8, Utf8View or LargeUtf8."
142        ),
143    }
144}
145
146fn repeat_impl<'a, T, S>(
147    string_array: S,
148    number_array: &Int64Array,
149    max_str_len: usize,
150) -> Result<ArrayRef>
151where
152    T: OffsetSizeTrait,
153    S: StringArrayType<'a>,
154{
155    let mut total_capacity = 0;
156    string_array.iter().zip(number_array.iter()).try_for_each(
157        |(string, number)| -> Result<(), DataFusionError> {
158            match (string, number) {
159                (Some(string), Some(number)) if number >= 0 => {
160                    let item_capacity = string.len() * number as usize;
161                    if item_capacity > max_str_len {
162                        return exec_err!(
163                            "string size overflow on repeat, max size is {}, but got {}",
164                            max_str_len,
165                            number as usize * string.len()
166                        );
167                    }
168                    total_capacity += item_capacity;
169                }
170                _ => (),
171            }
172            Ok(())
173        },
174    )?;
175
176    let mut builder =
177        GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
178
179    string_array.iter().zip(number_array.iter()).try_for_each(
180        |(string, number)| -> Result<(), DataFusionError> {
181            match (string, number) {
182                (Some(string), Some(number)) if number >= 0 => {
183                    builder.append_value(string.repeat(number as usize));
184                }
185                (Some(_), Some(_)) => builder.append_value(""),
186                _ => builder.append_null(),
187            }
188            Ok(())
189        },
190    )?;
191    let array = builder.finish();
192
193    Ok(Arc::new(array) as ArrayRef)
194}
195
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{exec_err, Result};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    /// Exercises `repeat` through the UDF interface with scalar arguments.
    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 input: regular value, null string, null count.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Zero and negative counts both produce an empty (non-null) string.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // Utf8View input: the result is returned as Utf8.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // A result one byte larger than i32::MAX must be rejected for Utf8.
        test_function!(
            RepeatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))),
            ],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}