1use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType};
19use arrow::datatypes::{DataType, Int64Type};
20use arrow::datatypes::{
21 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
22};
23use arrow::error::ArrowError;
24use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
25use datafusion_expr::{
26 ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
27 TypeSignature::Uniform, Volatility,
28};
29use datafusion_macros::user_doc;
30use itertools::izip;
31use regex::Regex;
32use std::collections::hash_map::Entry;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36#[user_doc(
37 doc_section(label = "Regular Expression Functions"),
38 description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.",
39 syntax_example = "regexp_count(str, regexp[, start, flags])",
40 sql_example = r#"```sql
41> select regexp_count('abcAbAbc', 'abc', 2, 'i');
42+---------------------------------------------------------------+
43| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) |
44+---------------------------------------------------------------+
45| 1 |
46+---------------------------------------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 standard_argument(name = "regexp", prefix = "Regular"),
50 argument(
51 name = "start",
52 description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function."
53 ),
54 argument(
55 name = "flags",
56 description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
57 - **i**: case-insensitive: letters match both upper and lower case
58 - **m**: multi-line mode: ^ and $ match begin/end of line
59 - **s**: allow . to match \n
60 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
61 - **U**: swap the meaning of x* and x*?"#
62 )
63)]
64#[derive(Debug)]
65pub struct RegexpCountFunc {
66 signature: Signature,
67}
68
69impl Default for RegexpCountFunc {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl RegexpCountFunc {
76 pub fn new() -> Self {
77 Self {
78 signature: Signature::one_of(
79 vec![
80 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
81 Exact(vec![Utf8View, Utf8View, Int64]),
82 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
83 Exact(vec![Utf8, Utf8, Int64]),
84 Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
85 Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
86 Exact(vec![Utf8, Utf8, Int64, Utf8]),
87 ],
88 Volatility::Immutable,
89 ),
90 }
91 }
92}
93
94impl ScalarUDFImpl for RegexpCountFunc {
95 fn as_any(&self) -> &dyn std::any::Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "regexp_count"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108 Ok(Int64)
109 }
110
111 fn invoke_with_args(
112 &self,
113 args: datafusion_expr::ScalarFunctionArgs,
114 ) -> Result<ColumnarValue> {
115 let args = &args.args;
116
117 let len = args
118 .iter()
119 .fold(Option::<usize>::None, |acc, arg| match arg {
120 ColumnarValue::Scalar(_) => acc,
121 ColumnarValue::Array(a) => Some(a.len()),
122 });
123
124 let is_scalar = len.is_none();
125 let inferred_length = len.unwrap_or(1);
126 let args = args
127 .iter()
128 .map(|arg| arg.to_array(inferred_length))
129 .collect::<Result<Vec<_>>>()?;
130
131 let result = regexp_count_func(&args);
132 if is_scalar {
133 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
135 result.map(ColumnarValue::Scalar)
136 } else {
137 result.map(ColumnarValue::Array)
138 }
139 }
140
141 fn documentation(&self) -> Option<&Documentation> {
142 self.doc()
143 }
144}
145
146pub fn regexp_count_func(args: &[ArrayRef]) -> Result<ArrayRef> {
147 let args_len = args.len();
148 if !(2..=4).contains(&args_len) {
149 return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4.");
150 }
151
152 let values = &args[0];
153 match values.data_type() {
154 Utf8 | LargeUtf8 | Utf8View => (),
155 other => {
156 return internal_err!(
157 "Unsupported data type {other:?} for function regexp_count"
158 );
159 }
160 }
161
162 regexp_count(
163 values,
164 &args[1],
165 if args_len > 2 { Some(&args[2]) } else { None },
166 if args_len > 3 { Some(&args[3]) } else { None },
167 )
168 .map_err(|e| e.into())
169}
170
171pub fn regexp_count(
187 values: &dyn Array,
188 regex_array: &dyn Datum,
189 start_array: Option<&dyn Datum>,
190 flags_array: Option<&dyn Datum>,
191) -> Result<ArrayRef, ArrowError> {
192 let (regex_array, is_regex_scalar) = regex_array.get();
193 let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| {
194 let (start, is_start_scalar) = start.get();
195 (Some(start), is_start_scalar)
196 });
197 let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| {
198 let (flags, is_flags_scalar) = flags.get();
199 (Some(flags), is_flags_scalar)
200 });
201
202 match (values.data_type(), regex_array.data_type(), flags_array) {
203 (Utf8, Utf8, None) => regexp_count_inner(
204 values.as_string::<i32>(),
205 regex_array.as_string::<i32>(),
206 is_regex_scalar,
207 start_array.map(|start| start.as_primitive::<Int64Type>()),
208 is_start_scalar,
209 None,
210 is_flags_scalar,
211 ),
212 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner(
213 values.as_string::<i32>(),
214 regex_array.as_string::<i32>(),
215 is_regex_scalar,
216 start_array.map(|start| start.as_primitive::<Int64Type>()),
217 is_start_scalar,
218 Some(flags_array.as_string::<i32>()),
219 is_flags_scalar,
220 ),
221 (LargeUtf8, LargeUtf8, None) => regexp_count_inner(
222 values.as_string::<i64>(),
223 regex_array.as_string::<i64>(),
224 is_regex_scalar,
225 start_array.map(|start| start.as_primitive::<Int64Type>()),
226 is_start_scalar,
227 None,
228 is_flags_scalar,
229 ),
230 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner(
231 values.as_string::<i64>(),
232 regex_array.as_string::<i64>(),
233 is_regex_scalar,
234 start_array.map(|start| start.as_primitive::<Int64Type>()),
235 is_start_scalar,
236 Some(flags_array.as_string::<i64>()),
237 is_flags_scalar,
238 ),
239 (Utf8View, Utf8View, None) => regexp_count_inner(
240 values.as_string_view(),
241 regex_array.as_string_view(),
242 is_regex_scalar,
243 start_array.map(|start| start.as_primitive::<Int64Type>()),
244 is_start_scalar,
245 None,
246 is_flags_scalar,
247 ),
248 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner(
249 values.as_string_view(),
250 regex_array.as_string_view(),
251 is_regex_scalar,
252 start_array.map(|start| start.as_primitive::<Int64Type>()),
253 is_start_scalar,
254 Some(flags_array.as_string_view()),
255 is_flags_scalar,
256 ),
257 _ => Err(ArrowError::ComputeError(
258 "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
259 )),
260 }
261}
262
263pub fn regexp_count_inner<'a, S>(
264 values: S,
265 regex_array: S,
266 is_regex_scalar: bool,
267 start_array: Option<&Int64Array>,
268 is_start_scalar: bool,
269 flags_array: Option<S>,
270 is_flags_scalar: bool,
271) -> Result<ArrayRef, ArrowError>
272where
273 S: StringArrayType<'a>,
274{
275 let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
276 (Some(regex_array.value(0)), true)
277 } else {
278 (None, false)
279 };
280
281 let (start_array, start_scalar, is_start_scalar) =
282 if let Some(start_array) = start_array {
283 if is_start_scalar || start_array.len() == 1 {
284 (None, Some(start_array.value(0)), true)
285 } else {
286 (Some(start_array), None, false)
287 }
288 } else {
289 (None, Some(1), true)
290 };
291
292 let (flags_array, flags_scalar, is_flags_scalar) =
293 if let Some(flags_array) = flags_array {
294 if is_flags_scalar || flags_array.len() == 1 {
295 (None, Some(flags_array.value(0)), true)
296 } else {
297 (Some(flags_array), None, false)
298 }
299 } else {
300 (None, None, true)
301 };
302
303 let mut regex_cache = HashMap::new();
304
305 match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
306 (true, true, true) => {
307 let regex = match regex_scalar {
308 None | Some("") => {
309 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
310 }
311 Some(regex) => regex,
312 };
313
314 let pattern = compile_regex(regex, flags_scalar)?;
315
316 Ok(Arc::new(
317 values
318 .iter()
319 .map(|value| count_matches(value, &pattern, start_scalar))
320 .collect::<Result<Int64Array, ArrowError>>()?,
321 ))
322 }
323 (true, true, false) => {
324 let regex = match regex_scalar {
325 None | Some("") => {
326 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
327 }
328 Some(regex) => regex,
329 };
330
331 let flags_array = flags_array.unwrap();
332 if values.len() != flags_array.len() {
333 return Err(ArrowError::ComputeError(format!(
334 "flags_array must be the same length as values array; got {} and {}",
335 flags_array.len(),
336 values.len(),
337 )));
338 }
339
340 Ok(Arc::new(
341 values
342 .iter()
343 .zip(flags_array.iter())
344 .map(|(value, flags)| {
345 let pattern =
346 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
347 count_matches(value, pattern, start_scalar)
348 })
349 .collect::<Result<Int64Array, ArrowError>>()?,
350 ))
351 }
352 (true, false, true) => {
353 let regex = match regex_scalar {
354 None | Some("") => {
355 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
356 }
357 Some(regex) => regex,
358 };
359
360 let pattern = compile_regex(regex, flags_scalar)?;
361
362 let start_array = start_array.unwrap();
363
364 Ok(Arc::new(
365 values
366 .iter()
367 .zip(start_array.iter())
368 .map(|(value, start)| count_matches(value, &pattern, start))
369 .collect::<Result<Int64Array, ArrowError>>()?,
370 ))
371 }
372 (true, false, false) => {
373 let regex = match regex_scalar {
374 None | Some("") => {
375 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
376 }
377 Some(regex) => regex,
378 };
379
380 let flags_array = flags_array.unwrap();
381 if values.len() != flags_array.len() {
382 return Err(ArrowError::ComputeError(format!(
383 "flags_array must be the same length as values array; got {} and {}",
384 flags_array.len(),
385 values.len(),
386 )));
387 }
388
389 Ok(Arc::new(
390 izip!(
391 values.iter(),
392 start_array.unwrap().iter(),
393 flags_array.iter()
394 )
395 .map(|(value, start, flags)| {
396 let pattern =
397 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
398
399 count_matches(value, pattern, start)
400 })
401 .collect::<Result<Int64Array, ArrowError>>()?,
402 ))
403 }
404 (false, true, true) => {
405 if values.len() != regex_array.len() {
406 return Err(ArrowError::ComputeError(format!(
407 "regex_array must be the same length as values array; got {} and {}",
408 regex_array.len(),
409 values.len(),
410 )));
411 }
412
413 Ok(Arc::new(
414 values
415 .iter()
416 .zip(regex_array.iter())
417 .map(|(value, regex)| {
418 let regex = match regex {
419 None | Some("") => return Ok(0),
420 Some(regex) => regex,
421 };
422
423 let pattern = compile_and_cache_regex(
424 regex,
425 flags_scalar,
426 &mut regex_cache,
427 )?;
428 count_matches(value, pattern, start_scalar)
429 })
430 .collect::<Result<Int64Array, ArrowError>>()?,
431 ))
432 }
433 (false, true, false) => {
434 if values.len() != regex_array.len() {
435 return Err(ArrowError::ComputeError(format!(
436 "regex_array must be the same length as values array; got {} and {}",
437 regex_array.len(),
438 values.len(),
439 )));
440 }
441
442 let flags_array = flags_array.unwrap();
443 if values.len() != flags_array.len() {
444 return Err(ArrowError::ComputeError(format!(
445 "flags_array must be the same length as values array; got {} and {}",
446 flags_array.len(),
447 values.len(),
448 )));
449 }
450
451 Ok(Arc::new(
452 izip!(values.iter(), regex_array.iter(), flags_array.iter())
453 .map(|(value, regex, flags)| {
454 let regex = match regex {
455 None | Some("") => return Ok(0),
456 Some(regex) => regex,
457 };
458
459 let pattern =
460 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
461
462 count_matches(value, pattern, start_scalar)
463 })
464 .collect::<Result<Int64Array, ArrowError>>()?,
465 ))
466 }
467 (false, false, true) => {
468 if values.len() != regex_array.len() {
469 return Err(ArrowError::ComputeError(format!(
470 "regex_array must be the same length as values array; got {} and {}",
471 regex_array.len(),
472 values.len(),
473 )));
474 }
475
476 let start_array = start_array.unwrap();
477 if values.len() != start_array.len() {
478 return Err(ArrowError::ComputeError(format!(
479 "start_array must be the same length as values array; got {} and {}",
480 start_array.len(),
481 values.len(),
482 )));
483 }
484
485 Ok(Arc::new(
486 izip!(values.iter(), regex_array.iter(), start_array.iter())
487 .map(|(value, regex, start)| {
488 let regex = match regex {
489 None | Some("") => return Ok(0),
490 Some(regex) => regex,
491 };
492
493 let pattern = compile_and_cache_regex(
494 regex,
495 flags_scalar,
496 &mut regex_cache,
497 )?;
498 count_matches(value, pattern, start)
499 })
500 .collect::<Result<Int64Array, ArrowError>>()?,
501 ))
502 }
503 (false, false, false) => {
504 if values.len() != regex_array.len() {
505 return Err(ArrowError::ComputeError(format!(
506 "regex_array must be the same length as values array; got {} and {}",
507 regex_array.len(),
508 values.len(),
509 )));
510 }
511
512 let start_array = start_array.unwrap();
513 if values.len() != start_array.len() {
514 return Err(ArrowError::ComputeError(format!(
515 "start_array must be the same length as values array; got {} and {}",
516 start_array.len(),
517 values.len(),
518 )));
519 }
520
521 let flags_array = flags_array.unwrap();
522 if values.len() != flags_array.len() {
523 return Err(ArrowError::ComputeError(format!(
524 "flags_array must be the same length as values array; got {} and {}",
525 flags_array.len(),
526 values.len(),
527 )));
528 }
529
530 Ok(Arc::new(
531 izip!(
532 values.iter(),
533 regex_array.iter(),
534 start_array.iter(),
535 flags_array.iter()
536 )
537 .map(|(value, regex, start, flags)| {
538 let regex = match regex {
539 None | Some("") => return Ok(0),
540 Some(regex) => regex,
541 };
542
543 let pattern =
544 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
545 count_matches(value, pattern, start)
546 })
547 .collect::<Result<Int64Array, ArrowError>>()?,
548 ))
549 }
550 }
551}
552
553fn compile_and_cache_regex<'strings, 'cache>(
554 regex: &'strings str,
555 flags: Option<&'strings str>,
556 regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
557) -> Result<&'cache Regex, ArrowError>
558where
559 'strings: 'cache,
560{
561 let result = match regex_cache.entry((regex, flags)) {
562 Entry::Occupied(occupied_entry) => occupied_entry.into_mut(),
563 Entry::Vacant(vacant_entry) => {
564 let compiled = compile_regex(regex, flags)?;
565 vacant_entry.insert(compiled)
566 }
567 };
568 Ok(result)
569}
570
571fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError> {
572 let pattern = match flags {
573 None | Some("") => regex.to_string(),
574 Some(flags) => {
575 if flags.contains("g") {
576 return Err(ArrowError::ComputeError(
577 "regexp_count() does not support global flag".to_string(),
578 ));
579 }
580 format!("(?{}){}", flags, regex)
581 }
582 };
583
584 Regex::new(&pattern).map_err(|_| {
585 ArrowError::ComputeError(format!(
586 "Regular expression did not compile: {}",
587 pattern
588 ))
589 })
590}
591
592fn count_matches(
593 value: Option<&str>,
594 pattern: &Regex,
595 start: Option<i64>,
596) -> Result<i64, ArrowError> {
597 let value = match value {
598 None | Some("") => return Ok(0),
599 Some(value) => value,
600 };
601
602 if let Some(start) = start {
603 if start < 1 {
604 return Err(ArrowError::ComputeError(
605 "regexp_count() requires start to be 1 based".to_string(),
606 ));
607 }
608
609 let find_slice = value.chars().skip(start as usize - 1).collect::<String>();
610 let count = pattern.find_iter(find_slice.as_str()).count();
611 Ok(count as i64)
612 } else {
613 let count = pattern.find_iter(value).count();
614 Ok(count as i64)
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{GenericStringArray, StringViewArray};
    use datafusion_expr::ScalarFunctionArgs;

    #[test]
    fn test_regexp_count() {
        test_case_sensitive_regexp_count_scalar();
        test_case_sensitive_regexp_count_scalar_start();
        test_case_insensitive_regexp_count_scalar_flags();
        test_case_sensitive_regexp_count_start_scalar_complex();

        test_case_sensitive_regexp_count_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array::<StringViewArray>();

        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array_start::<StringViewArray>();

        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i32>>();
        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i64>>();
        test_case_insensitive_regexp_count_array_flags::<StringViewArray>();

        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array_complex::<StringViewArray>();

        test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
    }

    /// The string types every scalar case is exercised with, in test order.
    fn string_types() -> [DataType; 3] {
        [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]
    }

    /// Wraps `value` in the `ScalarValue` variant for the string `data_type`.
    fn string_scalar(data_type: &DataType, value: Option<String>) -> ScalarValue {
        match data_type {
            DataType::Utf8 => ScalarValue::Utf8(value),
            DataType::LargeUtf8 => ScalarValue::LargeUtf8(value),
            DataType::Utf8View => ScalarValue::Utf8View(value),
            other => panic!("not a string type: {other:?}"),
        }
    }

    /// Invokes the UDF on all-scalar arguments and asserts the scalar count.
    fn assert_scalar_count(args: Vec<ScalarValue>, expected: Option<i64>) {
        let number_rows = args.len();
        let result = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args.into_iter().map(ColumnarValue::Scalar).collect(),
            number_rows,
            return_type: &Int64,
        });
        match result {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(count))) => {
                assert_eq!(count, expected, "regexp_count scalar test failed");
            }
            _ => panic!("Unexpected result"),
        }
    }

    /// Runs `regexp_count_func` on `args` and asserts the resulting counts.
    fn assert_array_count(args: &[ArrayRef], expected: Vec<i64>) {
        let result = regexp_count_func(args).unwrap();
        assert_eq!(result.as_ref(), &Int64Array::from(expected));
    }

    fn test_case_sensitive_regexp_count_scalar() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let expected = [0, 1, 2, 1, 3];

        for (&value, &count) in values.iter().zip(&expected) {
            for data_type in string_types() {
                assert_scalar_count(
                    vec![
                        string_scalar(&data_type, Some(value.to_string())),
                        string_scalar(&data_type, Some(regex.to_string())),
                    ],
                    Some(count),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_count_scalar_start() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let start = ScalarValue::Int64(Some(2));
        let expected = [0, 1, 1, 0, 2];

        for (&value, &count) in values.iter().zip(&expected) {
            for data_type in string_types() {
                assert_scalar_count(
                    vec![
                        string_scalar(&data_type, Some(value.to_string())),
                        string_scalar(&data_type, Some(regex.to_string())),
                        start.clone(),
                    ],
                    Some(count),
                );
            }
        }
    }

    fn test_case_insensitive_regexp_count_scalar_flags() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let start = ScalarValue::Int64(Some(1));
        let flags = "i";
        let expected = [0, 1, 2, 2, 3];

        for (&value, &count) in values.iter().zip(&expected) {
            for data_type in string_types() {
                assert_scalar_count(
                    vec![
                        string_scalar(&data_type, Some(value.to_string())),
                        string_scalar(&data_type, Some(regex.to_string())),
                        start.clone(),
                        string_scalar(&data_type, Some(flags.to_string())),
                    ],
                    Some(count),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_count_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);

        assert_array_count(&[Arc::new(values), Arc::new(regex)], vec![0, 1, 2, 2, 2]);
    }

    fn test_case_sensitive_regexp_count_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);

        assert_array_count(
            &[Arc::new(values), Arc::new(regex), Arc::new(start)],
            vec![0, 0, 1, 1, 0],
        );
    }

    fn test_case_insensitive_regexp_count_array_flags<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        // A single-element start array is broadcast to every row.
        let start = Int64Array::from(vec![1]);
        let flags = A::from(vec!["", "i", "", "", "i"]);

        assert_array_count(
            &[
                Arc::new(values),
                Arc::new(regex),
                Arc::new(start),
                Arc::new(flags),
            ],
            vec![0, 1, 2, 2, 3],
        );
    }

    fn test_case_sensitive_regexp_count_start_scalar_complex() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = ["", "abc", "a", "bc", "ab"];
        let start = ScalarValue::Int64(Some(5));
        let flags = ["", "i", "", "", "i"];
        let expected = [0, 0, 0, 1, 1];

        let rows = values.iter().zip(&regex).zip(&flags).zip(&expected);
        for (((&value, &pattern), &flag), &count) in rows {
            for data_type in string_types() {
                assert_scalar_count(
                    vec![
                        string_scalar(&data_type, Some(value.to_string())),
                        string_scalar(&data_type, Some(pattern.to_string())),
                        start.clone(),
                        string_scalar(&data_type, Some(flag.to_string())),
                    ],
                    Some(count),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_count_array_complex<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let flags = A::from(vec!["", "i", "", "", "i"]);

        assert_array_count(
            &[
                Arc::new(values),
                Arc::new(regex),
                Arc::new(start),
                Arc::new(flags),
            ],
            vec![0, 1, 1, 1, 1],
        );
    }

    /// Identical patterns with different flags must not share a cache entry.
    fn test_case_regexp_count_cache_check<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["aaa", "Aaa", "aaa"]);
        let regex = A::from(vec!["aaa", "aaa", "aaa"]);
        let start = Int64Array::from(vec![1, 1, 1]);
        let flags = A::from(vec!["", "i", ""]);

        assert_array_count(
            &[
                Arc::new(values),
                Arc::new(regex),
                Arc::new(start),
                Arc::new(flags),
            ],
            vec![1, 1, 1],
        );
    }
}