datafusion_functions/string/
concat.rs1use arrow::array::{as_largestring_array, Array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26 ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Concatenates multiple strings together.",
39 syntax_example = "concat(str[, ..., str_n])",
40 sql_example = r#"```sql
41> select concat('data', 'f', 'us', 'ion');
42+-------------------------------------------------------+
43| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
44+-------------------------------------------------------+
45| datafusion |
46+-------------------------------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 argument(
50 name = "str_n",
51 description = "Subsequent string expressions to concatenate."
52 ),
53 related_udf(name = "concat_ws")
54)]
55#[derive(Debug)]
56pub struct ConcatFunc {
57 signature: Signature,
58}
59
60impl Default for ConcatFunc {
61 fn default() -> Self {
62 ConcatFunc::new()
63 }
64}
65
66impl ConcatFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::variadic(
71 vec![Utf8View, Utf8, LargeUtf8],
72 Volatility::Immutable,
73 ),
74 }
75 }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79 fn as_any(&self) -> &dyn Any {
80 self
81 }
82
83 fn name(&self) -> &str {
84 "concat"
85 }
86
87 fn signature(&self) -> &Signature {
88 &self.signature
89 }
90
91 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92 use DataType::*;
93 let mut dt = &Utf8;
94 arg_types.iter().for_each(|data_type| {
95 if data_type == &Utf8View {
96 dt = data_type;
97 }
98 if data_type == &LargeUtf8 && dt != &Utf8View {
99 dt = data_type;
100 }
101 });
102
103 Ok(dt.to_owned())
104 }
105
106 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 let ScalarFunctionArgs { args, .. } = args;
110
111 let mut return_datatype = DataType::Utf8;
112 args.iter().for_each(|col| {
113 if col.data_type() == DataType::Utf8View {
114 return_datatype = col.data_type();
115 }
116 if col.data_type() == DataType::LargeUtf8
117 && return_datatype != DataType::Utf8View
118 {
119 return_datatype = col.data_type();
120 }
121 });
122
123 let array_len = args
124 .iter()
125 .filter_map(|x| match x {
126 ColumnarValue::Array(array) => Some(array.len()),
127 _ => None,
128 })
129 .next();
130
131 if array_len.is_none() {
133 let mut result = String::new();
134 for arg in args {
135 let ColumnarValue::Scalar(scalar) = arg else {
136 return internal_err!("concat expected scalar value, got {arg:?}");
137 };
138
139 match scalar.try_as_str() {
140 Some(Some(v)) => result.push_str(v),
141 Some(None) => {} None => plan_err!(
143 "Concat function does not support scalar type {:?}",
144 scalar
145 )?,
146 }
147 }
148
149 return match return_datatype {
150 DataType::Utf8View => {
151 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152 }
153 DataType::Utf8 => {
154 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155 }
156 DataType::LargeUtf8 => {
157 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158 }
159 other => {
160 plan_err!("Concat function does not support datatype of {other}")
161 }
162 };
163 }
164
165 let len = array_len.unwrap();
167 let mut data_size = 0;
168 let mut columns = Vec::with_capacity(args.len());
169
170 for arg in &args {
171 match arg {
172 ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174 | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175 if let Some(s) = maybe_value {
176 data_size += s.len() * len;
177 columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178 }
179 }
180 ColumnarValue::Array(array) => {
181 match array.data_type() {
182 DataType::Utf8 => {
183 let string_array = as_string_array(array)?;
184
185 data_size += string_array.values().len();
186 let column = if array.is_nullable() {
187 ColumnarValueRef::NullableArray(string_array)
188 } else {
189 ColumnarValueRef::NonNullableArray(string_array)
190 };
191 columns.push(column);
192 },
193 DataType::LargeUtf8 => {
194 let string_array = as_largestring_array(array);
195
196 data_size += string_array.values().len();
197 let column = if array.is_nullable() {
198 ColumnarValueRef::NullableLargeStringArray(string_array)
199 } else {
200 ColumnarValueRef::NonNullableLargeStringArray(string_array)
201 };
202 columns.push(column);
203 },
204 DataType::Utf8View => {
205 let string_array = as_string_view_array(array)?;
206
207 data_size += string_array.len();
208 let column = if array.is_nullable() {
209 ColumnarValueRef::NullableStringViewArray(string_array)
210 } else {
211 ColumnarValueRef::NonNullableStringViewArray(string_array)
212 };
213 columns.push(column);
214 },
215 other => {
216 return plan_err!("Input was {other} which is not a supported datatype for concat function")
217 }
218 };
219 }
220 _ => unreachable!("concat"),
221 }
222 }
223
224 match return_datatype {
225 DataType::Utf8 => {
226 let mut builder = StringArrayBuilder::with_capacity(len, data_size);
227 for i in 0..len {
228 columns
229 .iter()
230 .for_each(|column| builder.write::<true>(column, i));
231 builder.append_offset();
232 }
233
234 let string_array = builder.finish(None);
235 Ok(ColumnarValue::Array(Arc::new(string_array)))
236 }
237 DataType::Utf8View => {
238 let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
239 for i in 0..len {
240 columns
241 .iter()
242 .for_each(|column| builder.write::<true>(column, i));
243 builder.append_offset();
244 }
245
246 let string_array = builder.finish();
247 Ok(ColumnarValue::Array(Arc::new(string_array)))
248 }
249 DataType::LargeUtf8 => {
250 let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
251 for i in 0..len {
252 columns
253 .iter()
254 .for_each(|column| builder.write::<true>(column, i));
255 builder.append_offset();
256 }
257
258 let string_array = builder.finish(None);
259 Ok(ColumnarValue::Array(Arc::new(string_array)))
260 }
261 _ => unreachable!(),
262 }
263 }
264
265 fn simplify(
274 &self,
275 args: Vec<Expr>,
276 _info: &dyn SimplifyInfo,
277 ) -> Result<ExprSimplifyResult> {
278 simplify_concat(args)
279 }
280
281 fn documentation(&self) -> Option<&Documentation> {
282 self.doc()
283 }
284
285 fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
286 Ok(true)
287 }
288}
289
290pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
291 let mut new_args = Vec::with_capacity(args.len());
292 let mut contiguous_scalar = "".to_string();
293
294 let return_type = {
295 let data_types: Vec<_> = args
296 .iter()
297 .filter_map(|expr| match expr {
298 Expr::Literal(l) => Some(l.data_type()),
299 _ => None,
300 })
301 .collect();
302 ConcatFunc::new().return_type(&data_types)
303 }?;
304
305 for arg in args.clone() {
306 match arg {
307 Expr::Literal(ScalarValue::Utf8(None)) => {}
308 Expr::Literal(ScalarValue::LargeUtf8(None)) => {
309 }
310 Expr::Literal(ScalarValue::Utf8View(None)) => { }
311
312 Expr::Literal(ScalarValue::Utf8(Some(v))) => {
316 contiguous_scalar += &v;
317 }
318 Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => {
319 contiguous_scalar += &v;
320 }
321 Expr::Literal(ScalarValue::Utf8View(Some(v))) => {
322 contiguous_scalar += &v;
323 }
324
325 Expr::Literal(x) => {
326 return internal_err!(
327 "The scalar {x} should be casted to string type during the type coercion."
328 )
329 }
330 arg => {
334 if !contiguous_scalar.is_empty() {
335 match return_type {
336 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
337 DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
338 DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
339 _ => unreachable!(),
340 }
341 contiguous_scalar = "".to_string();
342 }
343 new_args.push(arg);
344 }
345 }
346 }
347
348 if !contiguous_scalar.is_empty() {
349 match return_type {
350 DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
351 DataType::LargeUtf8 => {
352 new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
353 }
354 DataType::Utf8View => {
355 new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
356 }
357 _ => unreachable!(),
358 }
359 }
360
361 if !args.eq(&new_args) {
362 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
363 ScalarFunction {
364 func: concat(),
365 args: new_args,
366 },
367 )))
368 } else {
369 Ok(ExprSimplifyResult::Original(args))
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use DataType::*;

    /// Scalar-only invocations: plain concatenation, null handling and the
    /// choice of output string type.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        // A null scalar contributes nothing.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        // Only nulls: the result is the empty string, not null.
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8View wins over LargeUtf8 and Utf8 when picking the output type.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // LargeUtf8 wins over Utf8.
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Mixed array/scalar invocation: null array entries contribute nothing and
    /// the presence of `Utf8View` inputs makes the result a `StringViewArray`.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            number_rows: 3,
            // The inputs include Utf8View, so the declared return type matches
            // the StringViewArray expected below.
            return_type: &Utf8View,
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;

        let ColumnarValue::Array(array) = &result else {
            panic!("concat over array inputs must return an array, got {result:?}");
        };
        assert_eq!(&expected, array);
        Ok(())
    }
}