1use std::any::Any;
19use std::sync::Arc;
20
21use crate::strings::make_and_append_view;
22use crate::utils::make_scalar_function;
23use arrow::array::{
24 Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType,
25 StringViewArray, StringViewBuilder,
26};
27use arrow::buffer::ScalarBuffer;
28use arrow::datatypes::DataType;
29use datafusion_common::cast::as_int64_array;
30use datafusion_common::{exec_err, plan_err, Result};
31use datafusion_expr::{
32 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
33};
34use datafusion_macros::user_doc;
35
// Scalar UDF implementing `substr(str, start_pos[, length])`.
//
// The `user_doc` attribute below carries the user-facing documentation
// (description, syntax, SQL example and per-argument help) for this function.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Extracts a substring of a specified number of characters from a specific starting position in a string.",
    syntax_example = "substr(str, start_pos[, length])",
    alternative_syntax = "substring(str from start_pos for length)",
    sql_example = r#"```sql
> select substr('datafusion', 5, 3);
+----------------------------------------------+
| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
+----------------------------------------------+
| fus |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "start_pos",
        description = "Character position to start the substring at. The first character in the string has a position of 1."
    ),
    argument(
        name = "length",
        description = "Number of characters to extract. If not specified, returns the rest of the string after the start position."
    )
)]
#[derive(Debug)]
pub struct SubstrFunc {
    // Signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    // Alternative names handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
64
65impl Default for SubstrFunc {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl SubstrFunc {
72 pub fn new() -> Self {
73 Self {
74 signature: Signature::user_defined(Volatility::Immutable),
75 aliases: vec![String::from("substring")],
76 }
77 }
78}
79
80impl ScalarUDFImpl for SubstrFunc {
81 fn as_any(&self) -> &dyn Any {
82 self
83 }
84
85 fn name(&self) -> &str {
86 "substr"
87 }
88
89 fn signature(&self) -> &Signature {
90 &self.signature
91 }
92
93 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95 Ok(DataType::Utf8View)
96 }
97
98 fn invoke_with_args(
99 &self,
100 args: datafusion_expr::ScalarFunctionArgs,
101 ) -> Result<ColumnarValue> {
102 make_scalar_function(substr, vec![])(&args.args)
103 }
104
105 fn aliases(&self) -> &[String] {
106 &self.aliases
107 }
108
109 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
110 if arg_types.len() < 2 || arg_types.len() > 3 {
111 return plan_err!(
112 "The {} function requires 2 or 3 arguments, but got {}.",
113 self.name(),
114 arg_types.len()
115 );
116 }
117 let first_data_type = match &arg_types[0] {
118 DataType::Null => Ok(DataType::Utf8),
119 DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()),
120 DataType::Dictionary(key_type, value_type) => {
121 if key_type.is_integer() {
122 match value_type.as_ref() {
123 DataType::Null => Ok(DataType::Utf8),
124 DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(*value_type.clone()),
125 _ => plan_err!(
126 "The first argument of the {} function can only be a string, but got {:?}.",
127 self.name(),
128 arg_types[0]
129 ),
130 }
131 } else {
132 plan_err!(
133 "The first argument of the {} function can only be a string, but got {:?}.",
134 self.name(),
135 arg_types[0]
136 )
137 }
138 }
139 _ => plan_err!(
140 "The first argument of the {} function can only be a string, but got {:?}.",
141 self.name(),
142 arg_types[0]
143 )
144 }?;
145
146 if ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[1]) {
147 return plan_err!(
148 "The second argument of the {} function can only be an integer, but got {:?}.",
149 self.name(),
150 arg_types[1]
151 );
152 }
153
154 if arg_types.len() == 3
155 && ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[2])
156 {
157 return plan_err!(
158 "The third argument of the {} function can only be an integer, but got {:?}.",
159 self.name(),
160 arg_types[2]
161 );
162 }
163
164 if arg_types.len() == 2 {
165 Ok(vec![first_data_type.to_owned(), DataType::Int64])
166 } else {
167 Ok(vec![
168 first_data_type.to_owned(),
169 DataType::Int64,
170 DataType::Int64,
171 ])
172 }
173 }
174
175 fn documentation(&self) -> Option<&Documentation> {
176 self.doc()
177 }
178}
179
180pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
185 match args[0].data_type() {
186 DataType::Utf8 => {
187 let string_array = args[0].as_string::<i32>();
188 string_substr::<_>(string_array, &args[1..])
189 }
190 DataType::LargeUtf8 => {
191 let string_array = args[0].as_string::<i64>();
192 string_substr::<_>(string_array, &args[1..])
193 }
194 DataType::Utf8View => {
195 let string_array = args[0].as_string_view();
196 string_view_substr(string_array, &args[1..])
197 }
198 other => exec_err!(
199 "Unsupported data type {other:?} for function substr,\
200 expected Utf8View, Utf8 or LargeUtf8."
201 ),
202 }
203}
204
/// Converts a 1-based character `start` and an optional character `count`
/// into a byte range `(start, end)` that can be used to slice `input`.
///
/// * `start` may be zero or negative; positions before the first character
///   consume part of `count` but select nothing.
/// * `count == None` means "to the end of the string".
/// * When `is_input_ascii_only` is `true` the character offsets are returned
///   directly as byte offsets; otherwise `input` is walked character by
///   character to find the byte boundaries.
///
/// The returned range always satisfies `start <= end <= input.len()`.
fn get_true_start_end(
    input: &str,
    start: i64,
    count: Option<u64>,
    is_input_ascii_only: bool,
) -> (usize, usize) {
    let byte_len = input.len() as i64;

    // 1-based -> 0-based. `i64::MIN` has no predecessor; it is left as is and
    // clamped to 0 below.
    let start = start.checked_sub(1).unwrap_or(start);

    let end = match count {
        // Saturate instead of overflowing: a huge count simply means
        // "up to the end of the string" once clamped.
        Some(count) => {
            let count = i64::try_from(count).unwrap_or(i64::MAX);
            start.saturating_add(count)
        }
        None => byte_len,
    };
    let has_count = count.is_some();

    let start = start.clamp(0, byte_len) as usize;
    let end = end.clamp(0, byte_len) as usize;
    // `end >= start` holds because clamping is monotonic and count is non-negative.
    let count = end - start;

    if is_input_ascii_only {
        // One byte per character: character offsets are byte offsets.
        return (start, end);
    }

    // Defaults cover "start is past the last character" and
    // "fewer than `count` characters remain".
    let (mut byte_start, mut byte_end) = (input.len(), input.len());
    let mut counting = false;
    let mut taken = 0;
    for (char_idx, (byte_idx, _)) in input.char_indices().enumerate() {
        if char_idx == start {
            byte_start = byte_idx;
            if has_count {
                counting = true;
            } else {
                // No count: the substring runs to the end of the input.
                break;
            }
        }
        if counting {
            if taken == count {
                byte_end = byte_idx;
                break;
            }
            taken += 1;
        }
    }
    (byte_start, byte_end)
}
267
268fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
279 string_array: &V,
280 start: &Int64Array,
281 count: Option<&Int64Array>,
282) -> bool {
283 let is_short_prefix = match count {
284 Some(count) => {
285 let short_prefix_threshold = 32.0;
286 let n_sample = 10;
287
288 let avg_prefix_len = start
291 .iter()
292 .zip(count.iter())
293 .take(n_sample)
294 .map(|(start, count)| {
295 let start = start.unwrap_or(0);
296 let count = count.unwrap_or(0);
297 start + count
299 })
300 .sum::<i64>();
301
302 avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold
303 }
304 None => false,
305 };
306
307 if is_short_prefix {
308 false
310 } else {
311 string_array.is_ascii()
312 }
313}
314
315fn string_view_substr(
318 string_view_array: &StringViewArray,
319 args: &[ArrayRef],
320) -> Result<ArrayRef> {
321 let mut views_buf = Vec::with_capacity(string_view_array.len());
322 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
323
324 let start_array = as_int64_array(&args[0])?;
325 let count_array_opt = if args.len() == 2 {
326 Some(as_int64_array(&args[1])?)
327 } else {
328 None
329 };
330
331 let enable_ascii_fast_path =
332 enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);
333
334 match args.len() {
337 1 => {
338 for ((str_opt, raw_view), start_opt) in string_view_array
339 .iter()
340 .zip(string_view_array.views().iter())
341 .zip(start_array.iter())
342 {
343 if let (Some(str), Some(start)) = (str_opt, start_opt) {
344 let (start, end) =
345 get_true_start_end(str, start, None, enable_ascii_fast_path);
346 let substr = &str[start..end];
347
348 make_and_append_view(
349 &mut views_buf,
350 &mut null_builder,
351 raw_view,
352 substr,
353 start as u32,
354 );
355 } else {
356 null_builder.append_null();
357 views_buf.push(0);
358 }
359 }
360 }
361 2 => {
362 let count_array = count_array_opt.unwrap();
363 for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
364 .iter()
365 .zip(string_view_array.views().iter())
366 .zip(start_array.iter())
367 .zip(count_array.iter())
368 {
369 if let (Some(str), Some(start), Some(count)) =
370 (str_opt, start_opt, count_opt)
371 {
372 if count < 0 {
373 return exec_err!(
374 "negative substring length not allowed: substr(<str>, {start}, {count})"
375 );
376 } else {
377 if start == i64::MIN {
378 return exec_err!(
379 "negative overflow when calculating skip value"
380 );
381 }
382 let (start, end) = get_true_start_end(
383 str,
384 start,
385 Some(count as u64),
386 enable_ascii_fast_path,
387 );
388 let substr = &str[start..end];
389
390 make_and_append_view(
391 &mut views_buf,
392 &mut null_builder,
393 raw_view,
394 substr,
395 start as u32,
396 );
397 }
398 } else {
399 null_builder.append_null();
400 views_buf.push(0);
401 }
402 }
403 }
404 other => {
405 return exec_err!(
406 "substr was called with {other} arguments. It requires 2 or 3."
407 )
408 }
409 }
410
411 let views_buf = ScalarBuffer::from(views_buf);
412 let nulls_buf = null_builder.finish();
413
414 unsafe {
419 let array = StringViewArray::new_unchecked(
420 views_buf,
421 string_view_array.data_buffers().to_vec(),
422 nulls_buf,
423 );
424 Ok(Arc::new(array) as ArrayRef)
425 }
426}
427
428fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
429where
430 V: StringArrayType<'a>,
431{
432 let start_array = as_int64_array(&args[0])?;
433 let count_array_opt = if args.len() == 2 {
434 Some(as_int64_array(&args[1])?)
435 } else {
436 None
437 };
438
439 let enable_ascii_fast_path =
440 enable_ascii_fast_path(&string_array, start_array, count_array_opt);
441
442 match args.len() {
443 1 => {
444 let iter = ArrayIter::new(string_array);
445 let mut result_builder = StringViewBuilder::new();
446 for (string, start) in iter.zip(start_array.iter()) {
447 match (string, start) {
448 (Some(string), Some(start)) => {
449 let (start, end) = get_true_start_end(
450 string,
451 start,
452 None,
453 enable_ascii_fast_path,
454 ); let substr = &string[start..end];
456 result_builder.append_value(substr);
457 }
458 _ => {
459 result_builder.append_null();
460 }
461 }
462 }
463 Ok(Arc::new(result_builder.finish()) as ArrayRef)
464 }
465 2 => {
466 let iter = ArrayIter::new(string_array);
467 let count_array = count_array_opt.unwrap();
468 let mut result_builder = StringViewBuilder::new();
469
470 for ((string, start), count) in
471 iter.zip(start_array.iter()).zip(count_array.iter())
472 {
473 match (string, start, count) {
474 (Some(string), Some(start), Some(count)) => {
475 if count < 0 {
476 return exec_err!(
477 "negative substring length not allowed: substr(<str>, {start}, {count})"
478 );
479 } else {
480 if start == i64::MIN {
481 return exec_err!(
482 "negative overflow when calculating skip value"
483 );
484 }
485 let (start, end) = get_true_start_end(
486 string,
487 start,
488 Some(count as u64),
489 enable_ascii_fast_path,
490 ); let substr = &string[start..end];
492 result_builder.append_value(substr);
493 }
494 }
495 _ => {
496 result_builder.append_null();
497 }
498 }
499 }
500 Ok(Arc::new(result_builder.finish()) as ArrayRef)
501 }
502 other => {
503 exec_err!("substr was called with {other} arguments. It requires 2 or 3.")
504 }
505 }
506}
507
// Unit tests for `SubstrFunc`, driven through the `test_function!` macro with
// scalar arguments. Every case expects a `Utf8View` / `StringViewArray` result.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringViewArray};
    use arrow::datatypes::DataType::Utf8View;

    use datafusion_common::{exec_err, Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substr::SubstrFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        // ---- Utf8View input ----

        // Null string -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 0 returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII input with start and length: positions are characters, not bytes.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this és longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some(" é")),
            &str,
            Utf8View,
            StringViewArray
        );
        // ASCII input, start only: rest of the string from position 5.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this is longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some(" is longer than 12B")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII input, start only.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "joséésoj"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start and length inside the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Length running past the end is truncated to the end of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );

        // ---- `ScalarValue::from(&str)` input, start only ----

        // Start position 0 returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII input.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative start without a length returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
            ],
            Ok(Some("joséésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 1 is the first character.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("lphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative start without a length returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-3i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start past the end yields an empty string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(30i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null start -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        // ---- `ScalarValue::from(&str)` input, start and length ----

        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Length running past the end is truncated to the end of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start 0 with length 5: position 0 consumes one unit of the length,
        // leaving the first four characters.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start -5 with length 10: the range ends after character 4 (-5 + 10 = 5, exclusive).
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start -5 with length 4: the range ends before the first character (-5 + 4 = -1).
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(4i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start -5 with length 5: the range still ends before the first character (-5 + 5 = 0).
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null start with a length -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null length -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative length is rejected with an execution error.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            exec_err!("negative substring length not allowed: substr(<str>, 1, -1)"),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII input with start and length.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("és")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Only compiled when the `unicode_expressions` feature is disabled.
        // NOTE(review): `internal_err!` is not imported in this module and the
        // arguments use `&[..]` rather than `vec![..]`; this case looks stale
        // and would likely not compile if the cfg were ever active - confirm.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            SubstrFunc::new(),
            &[
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            internal_err!(
                "function substr requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8View,
            StringViewArray
        );

        // ---- `i64::MIN` start position ----

        // Without a length, `i64::MIN` behaves like any other negative start.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abc")),
                ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)),
            ],
            Ok(Some("abc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // With a length, `i64::MIN` is rejected with an execution error.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("overflow")),
                ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            exec_err!("negative overflow when calculating skip value"),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }
}