datafusion_functions/unicode/
substrindex.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23 PrimitiveArray, StringBuilder,
24};
25use arrow::datatypes::{DataType, Int32Type, Int64Type};
26
27use crate::utils::{make_scalar_function, utf8_to_str_type};
28use datafusion_common::{exec_err, utils::take_function_args, Result};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36 doc_section(label = "String Functions"),
37 description = r#"Returns the substring from str before count occurrences of the delimiter delim.
38If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
39If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#,
40 syntax_example = "substr_index(str, delim, count)",
41 sql_example = r#"```sql
42> select substr_index('www.apache.org', '.', 1);
43+---------------------------------------------------------+
44| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
45+---------------------------------------------------------+
46| www |
47+---------------------------------------------------------+
48> select substr_index('www.apache.org', '.', -1);
49+----------------------------------------------------------+
50| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
51+----------------------------------------------------------+
52| org |
53+----------------------------------------------------------+
54```"#,
55 standard_argument(name = "str", prefix = "String"),
56 argument(
57 name = "delim",
58 description = "The string to find in str to split str."
59 ),
60 argument(
61 name = "count",
62 description = "The number of times to search for the delimiter. Can be either a positive or negative number."
63 )
64)]
/// Scalar UDF implementation of `substr_index(str, delim, count)`.
///
/// The user-facing description, syntax and SQL example live in the
/// `#[user_doc(...)]` attribute above; `documentation()` in this file returns
/// `self.doc()`, which is presumably generated from that attribute (confirm
/// against the `datafusion_macros::user_doc` macro).
65#[derive(Debug)]
66pub struct SubstrIndexFunc {
/// Accepted argument type combinations, as built in `SubstrIndexFunc::new`.
67 signature: Signature,
/// Alternative names for the function, returned by `ScalarUDFImpl::aliases`.
68 aliases: Vec<String>,
69}
70
71impl Default for SubstrIndexFunc {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl SubstrIndexFunc {
78 pub fn new() -> Self {
79 use DataType::*;
80 Self {
81 signature: Signature::one_of(
82 vec![
83 Exact(vec![Utf8View, Utf8View, Int64]),
84 Exact(vec![Utf8, Utf8, Int64]),
85 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
86 ],
87 Volatility::Immutable,
88 ),
89 aliases: vec![String::from("substring_index")],
90 }
91 }
92}
93
94impl ScalarUDFImpl for SubstrIndexFunc {
95 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "substr_index"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108 utf8_to_str_type(&arg_types[0], "substr_index")
109 }
110
111 fn invoke_with_args(
112 &self,
113 args: datafusion_expr::ScalarFunctionArgs,
114 ) -> Result<ColumnarValue> {
115 make_scalar_function(substr_index, vec![])(&args.args)
116 }
117
118 fn aliases(&self) -> &[String] {
119 &self.aliases
120 }
121
122 fn documentation(&self) -> Option<&Documentation> {
123 self.doc()
124 }
125}
126
127fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
133 let [str, delim, count] = take_function_args("substr_index", args)?;
134
135 match str.data_type() {
136 DataType::Utf8 => {
137 let string_array = str.as_string::<i32>();
138 let delimiter_array = delim.as_string::<i32>();
139 let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
140 substr_index_general::<Int32Type, _, _>(
141 string_array,
142 delimiter_array,
143 count_array,
144 )
145 }
146 DataType::LargeUtf8 => {
147 let string_array = str.as_string::<i64>();
148 let delimiter_array = delim.as_string::<i64>();
149 let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
150 substr_index_general::<Int64Type, _, _>(
151 string_array,
152 delimiter_array,
153 count_array,
154 )
155 }
156 DataType::Utf8View => {
157 let string_array = str.as_string_view();
158 let delimiter_array = delim.as_string_view();
159 let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
160 substr_index_general::<Int32Type, _, _>(
161 string_array,
162 delimiter_array,
163 count_array,
164 )
165 }
166 other => {
167 exec_err!("Unsupported data type {other:?} for function substr_index")
168 }
169 }
170}
171
172pub fn substr_index_general<
173 'a,
174 T: ArrowPrimitiveType,
175 V: ArrayAccessor<Item = &'a str>,
176 P: ArrayAccessor<Item = i64>,
177>(
178 string_array: V,
179 delimiter_array: V,
180 count_array: P,
181) -> Result<ArrayRef>
182where
183 T::Native: OffsetSizeTrait,
184{
185 let mut builder = StringBuilder::new();
186 let string_iter = ArrayIter::new(string_array);
187 let delimiter_array_iter = ArrayIter::new(delimiter_array);
188 let count_array_iter = ArrayIter::new(count_array);
189 string_iter
190 .zip(delimiter_array_iter)
191 .zip(count_array_iter)
192 .for_each(|((string, delimiter), n)| match (string, delimiter, n) {
193 (Some(string), Some(delimiter), Some(n)) => {
194 if n == 0 || string.is_empty() || delimiter.is_empty() {
196 builder.append_value("");
197 return;
198 }
199
200 let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
201 let length = if n > 0 {
202 let split = string.split(delimiter);
203 split
204 .take(occurrences)
205 .map(|s| s.len() + delimiter.len())
206 .sum::<usize>()
207 - delimiter.len()
208 } else {
209 let split = string.rsplit(delimiter);
210 split
211 .take(occurrences)
212 .map(|s| s.len() + delimiter.len())
213 .sum::<usize>()
214 - delimiter.len()
215 };
216 if n > 0 {
217 match string.get(..length) {
218 Some(substring) => builder.append_value(substring),
219 None => builder.append_null(),
220 }
221 } else {
222 match string.get(string.len().saturating_sub(length)..) {
223 Some(substring) => builder.append_value(substring),
224 None => builder.append_null(),
225 }
226 }
227 }
228 _ => builder.append_null(),
229 });
230
231 Ok(Arc::new(builder.finish()) as ArrayRef)
232}
233
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substrindex::SubstrIndexFunc;
    use crate::utils::test::test_function;

    /// Exercises `substr_index` on scalar `Utf8` inputs: positive and negative
    /// counts, a zero count, an empty string and an empty delimiter.
    #[test]
    fn test_functions() -> Result<()> {
        // (string, delimiter, count, expected result)
        let cases: [(&str, &str, i64, &str); 7] = [
            ("www.apache.org", ".", 1, "www"),
            ("www.apache.org", ".", 2, "www.apache"),
            ("www.apache.org", ".", -2, "apache.org"),
            ("www.apache.org", ".", -1, "org"),
            ("www.apache.org", ".", 0, ""),
            ("", ".", 1, ""),
            ("www.apache.org", "", 1, ""),
        ];

        for (input, delim, count, want) in cases {
            test_function!(
                SubstrIndexFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::from(input)),
                    ColumnarValue::Scalar(ScalarValue::from(delim)),
                    ColumnarValue::Scalar(ScalarValue::from(count)),
                ],
                Ok(Some(want)),
                &str,
                Utf8,
                StringArray
            );
        }

        Ok(())
    }
}