datafusion_functions/string/
repeat.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_str_type};
22use arrow::array::{
23 ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
24 OffsetSizeTrait, StringArrayType, StringViewArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::types::{logical_int64, logical_string, NativeType};
30use datafusion_common::{exec_err, DataFusionError, Result};
31use datafusion_expr::{ColumnarValue, Documentation, Volatility};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
34use datafusion_macros::user_doc;
35
// Scalar UDF `repeat(str, n)`: returns `str` concatenated with itself `n` times.
// The `user_doc` attribute below carries the user-facing documentation text;
// the `ScalarUDFImpl::documentation` impl returns `self.doc()`, which is
// presumably generated by this attribute.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a string with an input string repeated a specified number.",
    syntax_example = "repeat(str, n)",
    sql_example = r#"```sql
> select repeat('data', 3);
+-------------------------------+
| repeat(Utf8("data"),Int64(3)) |
+-------------------------------+
| datadatadata                  |
+-------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "n",
        description = "Number of times to repeat the input string."
    )
)]
#[derive(Debug)]
pub struct RepeatFunc {
    // Accepted argument types for `repeat(str, n)`; built in `RepeatFunc::new`.
    signature: Signature,
}
58
59impl Default for RepeatFunc {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl RepeatFunc {
66 pub fn new() -> Self {
67 Self {
68 signature: Signature::coercible(
69 vec![
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 Coercion::new_implicit(
73 TypeSignatureClass::Native(logical_int64()),
74 vec![TypeSignatureClass::Integer],
75 NativeType::Int64,
76 ),
77 ],
78 Volatility::Immutable,
79 ),
80 }
81 }
82}
83
84impl ScalarUDFImpl for RepeatFunc {
85 fn as_any(&self) -> &dyn Any {
86 self
87 }
88
89 fn name(&self) -> &str {
90 "repeat"
91 }
92
93 fn signature(&self) -> &Signature {
94 &self.signature
95 }
96
97 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
98 utf8_to_str_type(&arg_types[0], "repeat")
99 }
100
101 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
102 make_scalar_function(repeat, vec![])(&args.args)
103 }
104
105 fn documentation(&self) -> Option<&Documentation> {
106 self.doc()
107 }
108}
109
110fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
113 let number_array = as_int64_array(&args[1])?;
114 match args[0].data_type() {
115 Utf8View => {
116 let string_view_array = args[0].as_string_view();
117 repeat_impl::<i32, &StringViewArray>(
118 string_view_array,
119 number_array,
120 i32::MAX as usize,
121 )
122 }
123 Utf8 => {
124 let string_array = args[0].as_string::<i32>();
125 repeat_impl::<i32, &GenericStringArray<i32>>(
126 string_array,
127 number_array,
128 i32::MAX as usize,
129 )
130 }
131 LargeUtf8 => {
132 let string_array = args[0].as_string::<i64>();
133 repeat_impl::<i64, &GenericStringArray<i64>>(
134 string_array,
135 number_array,
136 i64::MAX as usize,
137 )
138 }
139 other => exec_err!(
140 "Unsupported data type {other:?} for function repeat. \
141 Expected Utf8, Utf8View or LargeUtf8."
142 ),
143 }
144}
145
146fn repeat_impl<'a, T, S>(
147 string_array: S,
148 number_array: &Int64Array,
149 max_str_len: usize,
150) -> Result<ArrayRef>
151where
152 T: OffsetSizeTrait,
153 S: StringArrayType<'a>,
154{
155 let mut total_capacity = 0;
156 string_array.iter().zip(number_array.iter()).try_for_each(
157 |(string, number)| -> Result<(), DataFusionError> {
158 match (string, number) {
159 (Some(string), Some(number)) if number >= 0 => {
160 let item_capacity = string.len() * number as usize;
161 if item_capacity > max_str_len {
162 return exec_err!(
163 "string size overflow on repeat, max size is {}, but got {}",
164 max_str_len,
165 number as usize * string.len()
166 );
167 }
168 total_capacity += item_capacity;
169 }
170 _ => (),
171 }
172 Ok(())
173 },
174 )?;
175
176 let mut builder =
177 GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
178
179 string_array.iter().zip(number_array.iter()).try_for_each(
180 |(string, number)| -> Result<(), DataFusionError> {
181 match (string, number) {
182 (Some(string), Some(number)) if number >= 0 => {
183 builder.append_value(string.repeat(number as usize));
184 }
185 (Some(_), Some(_)) => builder.append_value(""),
186 _ => builder.append_null(),
187 }
188 Ok(())
189 },
190 )?;
191 let array = builder.finish();
192
193 Ok(Arc::new(array) as ArrayRef)
194}
195
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{exec_err, Result};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::repeat::RepeatFunc;
    use crate::utils::test::test_function;

    /// Scalar `Utf8` argument; `None` is a SQL NULL.
    fn utf8(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(value.map(String::from)))
    }

    /// Scalar `Utf8View` argument; `None` is a SQL NULL.
    fn utf8_view(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(value.map(String::from)))
    }

    /// Scalar `Int64` argument; `None` is a SQL NULL.
    fn int64(value: Option<i64>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int64(value))
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 input: regular value, NULL string, NULL count.
        test_function!(
            RepeatFunc::new(),
            vec![utf8(Some("Pg")), int64(Some(4))],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8(None), int64(Some(4))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8(Some("Pg")), int64(None)],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Utf8View input: same cases, result is still a Utf8 StringArray.
        test_function!(
            RepeatFunc::new(),
            vec![utf8_view(Some("Pg")), int64(Some(4))],
            Ok(Some("PgPgPgPg")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8_view(None), int64(Some(4))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            RepeatFunc::new(),
            vec![utf8_view(Some("Pg")), int64(None)],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // A repeated value larger than i32::MAX bytes is rejected for Utf8.
        test_function!(
            RepeatFunc::new(),
            vec![utf8(Some("Pg")), int64(Some(1073741824))],
            exec_err!(
                "string size overflow on repeat, max size is {}, but got {}",
                i32::MAX,
                2usize * 1073741824
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}