datafusion_functions/unicode/
substrindex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
use arrow::array::{
    ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
    GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, StringBuilder,
};
25use arrow::datatypes::{DataType, Int32Type, Int64Type};
26
27use crate::utils::{make_scalar_function, utf8_to_str_type};
28use datafusion_common::{exec_err, utils::take_function_args, Result};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34
#[user_doc(
    doc_section(label = "String Functions"),
    description = r#"Returns the substring from str before count occurrences of the delimiter delim.
If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#,
    syntax_example = "substr_index(str, delim, count)",
    sql_example = r#"```sql
> select substr_index('www.apache.org', '.', 1);
+---------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
+---------------------------------------------------------+
| www                                                     |
+---------------------------------------------------------+
> select substr_index('www.apache.org', '.', -1);
+----------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
+----------------------------------------------------------+
| org                                                      |
+----------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "delim",
        description = "The string to find in str to split str."
    ),
    argument(
        name = "count",
        description = "The number of times to search for the delimiter. Can be either a positive or negative number."
    )
)]
/// Scalar UDF `substr_index` (also registered under the alias
/// `substring_index`): returns the part of a string that precedes or follows
/// a given number of occurrences of a delimiter.
#[derive(Debug)]
pub struct SubstrIndexFunc {
    /// Accepted argument type combinations and the function's volatility.
    signature: Signature,
    /// Alternative names under which the function can be called.
    aliases: Vec<String>,
}
70
71impl Default for SubstrIndexFunc {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl SubstrIndexFunc {
78    pub fn new() -> Self {
79        use DataType::*;
80        Self {
81            signature: Signature::one_of(
82                vec![
83                    Exact(vec![Utf8View, Utf8View, Int64]),
84                    Exact(vec![Utf8, Utf8, Int64]),
85                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
86                ],
87                Volatility::Immutable,
88            ),
89            aliases: vec![String::from("substring_index")],
90        }
91    }
92}
93
94impl ScalarUDFImpl for SubstrIndexFunc {
95    fn as_any(&self) -> &dyn Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "substr_index"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        utf8_to_str_type(&arg_types[0], "substr_index")
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: datafusion_expr::ScalarFunctionArgs,
114    ) -> Result<ColumnarValue> {
115        make_scalar_function(substr_index, vec![])(&args.args)
116    }
117
118    fn aliases(&self) -> &[String] {
119        &self.aliases
120    }
121
122    fn documentation(&self) -> Option<&Documentation> {
123        self.doc()
124    }
125}
126
127/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
128/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www
129/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
130/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
131/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
132fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
133    let [str, delim, count] = take_function_args("substr_index", args)?;
134
135    match str.data_type() {
136        DataType::Utf8 => {
137            let string_array = str.as_string::<i32>();
138            let delimiter_array = delim.as_string::<i32>();
139            let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
140            substr_index_general::<Int32Type, _, _>(
141                string_array,
142                delimiter_array,
143                count_array,
144            )
145        }
146        DataType::LargeUtf8 => {
147            let string_array = str.as_string::<i64>();
148            let delimiter_array = delim.as_string::<i64>();
149            let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
150            substr_index_general::<Int64Type, _, _>(
151                string_array,
152                delimiter_array,
153                count_array,
154            )
155        }
156        DataType::Utf8View => {
157            let string_array = str.as_string_view();
158            let delimiter_array = delim.as_string_view();
159            let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
160            substr_index_general::<Int32Type, _, _>(
161                string_array,
162                delimiter_array,
163                count_array,
164            )
165        }
166        other => {
167            exec_err!("Unsupported data type {other:?} for function substr_index")
168        }
169    }
170}
171
172pub fn substr_index_general<
173    'a,
174    T: ArrowPrimitiveType,
175    V: ArrayAccessor<Item = &'a str>,
176    P: ArrayAccessor<Item = i64>,
177>(
178    string_array: V,
179    delimiter_array: V,
180    count_array: P,
181) -> Result<ArrayRef>
182where
183    T::Native: OffsetSizeTrait,
184{
185    let mut builder = StringBuilder::new();
186    let string_iter = ArrayIter::new(string_array);
187    let delimiter_array_iter = ArrayIter::new(delimiter_array);
188    let count_array_iter = ArrayIter::new(count_array);
189    string_iter
190        .zip(delimiter_array_iter)
191        .zip(count_array_iter)
192        .for_each(|((string, delimiter), n)| match (string, delimiter, n) {
193            (Some(string), Some(delimiter), Some(n)) => {
194                // In MySQL, these cases will return an empty string.
195                if n == 0 || string.is_empty() || delimiter.is_empty() {
196                    builder.append_value("");
197                    return;
198                }
199
200                let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
201                let length = if n > 0 {
202                    let split = string.split(delimiter);
203                    split
204                        .take(occurrences)
205                        .map(|s| s.len() + delimiter.len())
206                        .sum::<usize>()
207                        - delimiter.len()
208                } else {
209                    let split = string.rsplit(delimiter);
210                    split
211                        .take(occurrences)
212                        .map(|s| s.len() + delimiter.len())
213                        .sum::<usize>()
214                        - delimiter.len()
215                };
216                if n > 0 {
217                    match string.get(..length) {
218                        Some(substring) => builder.append_value(substring),
219                        None => builder.append_null(),
220                    }
221                } else {
222                    match string.get(string.len().saturating_sub(length)..) {
223                        Some(substring) => builder.append_value(substring),
224                        None => builder.append_null(),
225                    }
226                }
227            }
228            _ => builder.append_null(),
229        });
230
231    Ok(Arc::new(builder.finish()) as ArrayRef)
232}
233
// Unit tests for `SubstrIndexFunc`, driven through the shared
// `test_function!` macro with scalar `Utf8` arguments.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substrindex::SubstrIndexFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        // Positive count 1: everything before the first delimiter.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("www")),
            &str,
            Utf8,
            StringArray
        );
        // Positive count 2: everything before the second delimiter.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("www.apache")),
            &str,
            Utf8,
            StringArray
        );
        // Negative count -2: everything after the second delimiter from the right.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-2i64)),
            ],
            Ok(Some("apache.org")),
            &str,
            Utf8,
            StringArray
        );
        // Negative count -1: everything after the last delimiter.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("org")),
            &str,
            Utf8,
            StringArray
        );
        // A count of zero yields the empty string.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // An empty input string yields the empty string.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // An empty delimiter yields the empty string.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}