1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
23 OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29 exec_err, internal_err, utils::take_function_args, Result, ScalarValue,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
    syntax_example = "find_in_set(str, strlist)",
    sql_example = r#"```sql
> select find_in_set('b', 'a,b,c,d');
+----------------------------------------+
| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
+----------------------------------------+
| 2 |
+----------------------------------------+
```"#,
    argument(name = "str", description = "String expression to find in strlist."),
    argument(
        name = "strlist",
        description = "A string list is a string composed of substrings separated by , characters."
    )
)]
/// Scalar UDF `find_in_set(str, strlist)`.
///
/// Yields the 1-based position of `str` among the comma-separated entries of
/// `strlist`, `0` when no entry is equal to `str`, and NULL when either
/// argument is NULL.
#[derive(Debug)]
pub struct FindInSetFunc {
    /// Accepted argument type combinations and volatility; built in `new`.
    signature: Signature,
}
60
61impl Default for FindInSetFunc {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl FindInSetFunc {
68 pub fn new() -> Self {
69 use DataType::*;
70 Self {
71 signature: Signature::one_of(
72 vec![
73 Exact(vec![Utf8View, Utf8View]),
74 Exact(vec![Utf8, Utf8]),
75 Exact(vec![LargeUtf8, LargeUtf8]),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "find_in_set"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 utf8_to_int_type(&arg_types[0], "find_in_set")
98 }
99
100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101 let ScalarFunctionArgs { args, .. } = args;
102
103 let [string, str_list] = take_function_args(self.name(), args)?;
104
105 match (string, str_list) {
106 (
108 ColumnarValue::Scalar(
109 ScalarValue::Utf8View(string)
110 | ScalarValue::Utf8(string)
111 | ScalarValue::LargeUtf8(string),
112 ),
113 ColumnarValue::Scalar(
114 ScalarValue::Utf8View(str_list)
115 | ScalarValue::Utf8(str_list)
116 | ScalarValue::LargeUtf8(str_list),
117 ),
118 ) => {
119 let res = match (string, str_list) {
120 (Some(string), Some(str_list)) => {
121 let position = str_list
122 .split(',')
123 .position(|s| s == string)
124 .map_or(0, |idx| idx + 1);
125
126 Some(position as i32)
127 }
128 _ => None,
129 };
130 Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
131 }
132
133 (
135 ColumnarValue::Array(str_array),
136 ColumnarValue::Scalar(
137 ScalarValue::Utf8View(str_list_literal)
138 | ScalarValue::Utf8(str_list_literal)
139 | ScalarValue::LargeUtf8(str_list_literal),
140 ),
141 ) => {
142 let result_array = match str_list_literal {
143 None => new_null_array(str_array.data_type(), str_array.len()),
145 Some(str_list_literal) => {
146 let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
147 let result = match str_array.data_type() {
148 DataType::Utf8 => {
149 let string_array = str_array.as_string::<i32>();
150 find_in_set_right_literal::<Int32Type, _>(
151 string_array,
152 str_list,
153 )
154 }
155 DataType::LargeUtf8 => {
156 let string_array = str_array.as_string::<i64>();
157 find_in_set_right_literal::<Int64Type, _>(
158 string_array,
159 str_list,
160 )
161 }
162 DataType::Utf8View => {
163 let string_array = str_array.as_string_view();
164 find_in_set_right_literal::<Int32Type, _>(
165 string_array,
166 str_list,
167 )
168 }
169 other => {
170 exec_err!("Unsupported data type {other:?} for function find_in_set")
171 }
172 };
173 Arc::new(result?)
174 }
175 };
176 Ok(ColumnarValue::Array(result_array))
177 }
178
179 (
181 ColumnarValue::Scalar(
182 ScalarValue::Utf8View(string_literal)
183 | ScalarValue::Utf8(string_literal)
184 | ScalarValue::LargeUtf8(string_literal),
185 ),
186 ColumnarValue::Array(str_list_array),
187 ) => {
188 let res = match string_literal {
189 None => {
191 new_null_array(str_list_array.data_type(), str_list_array.len())
192 }
193 Some(string) => {
194 let result = match str_list_array.data_type() {
195 DataType::Utf8 => {
196 let str_list = str_list_array.as_string::<i32>();
197 find_in_set_left_literal::<Int32Type, _>(string, str_list)
198 }
199 DataType::LargeUtf8 => {
200 let str_list = str_list_array.as_string::<i64>();
201 find_in_set_left_literal::<Int64Type, _>(string, str_list)
202 }
203 DataType::Utf8View => {
204 let str_list = str_list_array.as_string_view();
205 find_in_set_left_literal::<Int32Type, _>(string, str_list)
206 }
207 other => {
208 exec_err!("Unsupported data type {other:?} for function find_in_set")
209 }
210 };
211 Arc::new(result?)
212 }
213 };
214 Ok(ColumnarValue::Array(res))
215 }
216
217 (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
219 let res = find_in_set(base_array, exp_array)?;
220
221 Ok(ColumnarValue::Array(res))
222 }
223 _ => {
224 internal_err!("Invalid argument types for `find_in_set` function")
225 }
226 }
227 }
228
229 fn documentation(&self) -> Option<&Documentation> {
230 self.doc()
231 }
232}
233
234fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result<ArrayRef> {
238 match str.data_type() {
239 DataType::Utf8 => {
240 let string_array = str.as_string::<i32>();
241 let str_list_array = str_list.as_string::<i32>();
242 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
243 }
244 DataType::LargeUtf8 => {
245 let string_array = str.as_string::<i64>();
246 let str_list_array = str_list.as_string::<i64>();
247 find_in_set_general::<Int64Type, _>(string_array, str_list_array)
248 }
249 DataType::Utf8View => {
250 let string_array = str.as_string_view();
251 let str_list_array = str_list.as_string_view();
252 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
253 }
254 other => {
255 exec_err!("Unsupported data type {other:?} for function find_in_set")
256 }
257 }
258}
259
260pub fn find_in_set_general<'a, T, V>(
261 string_array: V,
262 str_list_array: V,
263) -> Result<ArrayRef>
264where
265 T: ArrowPrimitiveType,
266 T::Native: OffsetSizeTrait,
267 V: ArrayAccessor<Item = &'a str>,
268{
269 let string_iter = ArrayIter::new(string_array);
270 let str_list_iter = ArrayIter::new(str_list_array);
271
272 let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
273
274 string_iter
275 .zip(str_list_iter)
276 .for_each(
277 |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
278 (Some(string), Some(str_list)) => {
279 let position = str_list
280 .split(',')
281 .position(|s| s == string)
282 .map_or(0, |idx| idx + 1);
283 builder.append_value(T::Native::from_usize(position).unwrap());
284 }
285 _ => builder.append_null(),
286 },
287 );
288
289 Ok(Arc::new(builder.finish()) as ArrayRef)
290}
291
292fn find_in_set_left_literal<'a, T, V>(
293 string: String,
294 str_list_array: V,
295) -> Result<ArrayRef>
296where
297 T: ArrowPrimitiveType,
298 T::Native: OffsetSizeTrait,
299 V: ArrayAccessor<Item = &'a str>,
300{
301 let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
302
303 let str_list_iter = ArrayIter::new(str_list_array);
304
305 str_list_iter.for_each(|str_list_opt| match str_list_opt {
306 Some(str_list) => {
307 let position = str_list
308 .split(',')
309 .position(|s| s == string)
310 .map_or(0, |idx| idx + 1);
311 builder.append_value(T::Native::from_usize(position).unwrap());
312 }
313 None => builder.append_null(),
314 });
315
316 Ok(Arc::new(builder.finish()) as ArrayRef)
317}
318
319fn find_in_set_right_literal<'a, T, V>(
320 string_array: V,
321 str_list: Vec<&str>,
322) -> Result<ArrayRef>
323where
324 T: ArrowPrimitiveType,
325 T::Native: OffsetSizeTrait,
326 V: ArrayAccessor<Item = &'a str>,
327{
328 let mut builder = PrimitiveArray::<T>::builder(string_array.len());
329
330 let string_iter = ArrayIter::new(string_array);
331
332 string_iter.for_each(|string_opt| match string_opt {
333 Some(string) => {
334 let position = str_list
335 .iter()
336 .position(|s| *s == string)
337 .map_or(0, |idx| idx + 1);
338 builder.append_value(T::Native::from_usize(position).unwrap());
339 }
340 None => builder.append_null(),
341 });
342
343 Ok(Arc::new(builder.finish()) as ArrayRef)
344}
345
// Unit tests for `find_in_set`: scalar/scalar inputs via `test_function!`,
// and array/scalar combinations via the local `test_find_in_set!` macro.
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    // Scalar/scalar cases; every case expects an Int32 result.
    #[test]
    fn test_functions() -> Result<()> {
        // Match on the first entry -> 1.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        // Non-ASCII entries; the match is the third entry.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        // Search string not present -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Entries containing spaces are compared as whole entries.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        // Empty search string against a list with no empty entry -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Empty list -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // NULL list -> NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        // NULL search string -> NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    // Generates a `#[test]` named `$test_name` that invokes the UDF obtained
    // from `crate::unicode::find_in_set()` with `$args` and compares the
    // result, downcast to `Int32Array`, against `$expected`.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                // Row count: length of the last array argument, or 1 when
                // every argument is a scalar.
                let cardinality = args
                    .iter()
                    .fold(Option::<usize>::None, |acc, arg| match arg {
                        ColumnarValue::Scalar(_) => acc,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                let result = fis.invoke_with_args(ScalarFunctionArgs {
                    args,
                    number_rows: cardinality,
                    return_type: &return_type,
                });
                assert!(result.is_ok());

                let result = result?
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    // Array of search strings against a literal list.
    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    // Literal search string against an array of lists.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    // All-NULL search-string array against a literal list -> all NULL.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    // Literal search string against an all-NULL list array -> all NULL.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
}