datafusion_functions/unicode/
strpos.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23    ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26use datafusion_common::{exec_err, internal_err, Result};
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_macros::user_doc;
31
32#[user_doc(
33    doc_section(label = "String Functions"),
34    description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
35    syntax_example = "strpos(str, substr)",
36    alternative_syntax = "position(substr in origstr)",
37    sql_example = r#"```sql
38> select strpos('datafusion', 'fus');
39+----------------------------------------+
40| strpos(Utf8("datafusion"),Utf8("fus")) |
41+----------------------------------------+
42| 5                                      |
43+----------------------------------------+ 
44```"#,
45    standard_argument(name = "str", prefix = "String"),
46    argument(name = "substr", description = "Substring expression to search for.")
47)]
48#[derive(Debug)]
49pub struct StrposFunc {
50    signature: Signature,
51    aliases: Vec<String>,
52}
53
54impl Default for StrposFunc {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl StrposFunc {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::string(2, Volatility::Immutable),
64            aliases: vec![String::from("instr"), String::from("position")],
65        }
66    }
67}
68
69impl ScalarUDFImpl for StrposFunc {
70    fn as_any(&self) -> &dyn Any {
71        self
72    }
73
74    fn name(&self) -> &str {
75        "strpos"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83        internal_err!("return_type_from_args should be used instead")
84    }
85
86    fn return_type_from_args(
87        &self,
88        args: datafusion_expr::ReturnTypeArgs,
89    ) -> Result<datafusion_expr::ReturnInfo> {
90        utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| {
91            datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x))
92        })
93    }
94
95    fn invoke_with_args(
96        &self,
97        args: datafusion_expr::ScalarFunctionArgs,
98    ) -> Result<ColumnarValue> {
99        make_scalar_function(strpos, vec![])(&args.args)
100    }
101
102    fn aliases(&self) -> &[String] {
103        &self.aliases
104    }
105
106    fn documentation(&self) -> Option<&Documentation> {
107        self.doc()
108    }
109}
110
111fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
112    match (args[0].data_type(), args[1].data_type()) {
113        (DataType::Utf8, DataType::Utf8) => {
114            let string_array = args[0].as_string::<i32>();
115            let substring_array = args[1].as_string::<i32>();
116            calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
117        }
118        (DataType::Utf8, DataType::LargeUtf8) => {
119            let string_array = args[0].as_string::<i32>();
120            let substring_array = args[1].as_string::<i64>();
121            calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
122        }
123        (DataType::LargeUtf8, DataType::Utf8) => {
124            let string_array = args[0].as_string::<i64>();
125            let substring_array = args[1].as_string::<i32>();
126            calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
127        }
128        (DataType::LargeUtf8, DataType::LargeUtf8) => {
129            let string_array = args[0].as_string::<i64>();
130            let substring_array = args[1].as_string::<i64>();
131            calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
132        }
133        (DataType::Utf8View, DataType::Utf8View) => {
134            let string_array = args[0].as_string_view();
135            let substring_array = args[1].as_string_view();
136            calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
137        }
138        (DataType::Utf8View, DataType::Utf8) => {
139            let string_array = args[0].as_string_view();
140            let substring_array = args[1].as_string::<i32>();
141            calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
142        }
143        (DataType::Utf8View, DataType::LargeUtf8) => {
144            let string_array = args[0].as_string_view();
145            let substring_array = args[1].as_string::<i64>();
146            calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
147        }
148
149        other => {
150            exec_err!("Unsupported data type combination {other:?} for function strpos")
151        }
152    }
153}
154
155/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
156/// strpos('high', 'ig') = 2
157/// The implementation uses UTF-8 code points as characters
158fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
159    string_array: V1,
160    substring_array: V2,
161) -> Result<ArrayRef>
162where
163    V1: StringArrayType<'a, Item = &'a str>,
164    V2: StringArrayType<'a, Item = &'a str>,
165{
166    let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
167    let string_iter = string_array.iter();
168    let substring_iter = substring_array.iter();
169
170    let result = string_iter
171        .zip(substring_iter)
172        .map(|(string, substring)| match (string, substring) {
173            (Some(string), Some(substring)) => {
174                // If only ASCII characters are present, we can use the slide window method to find
175                // the sub vector in the main vector. This is faster than string.find() method.
176                if ascii_only {
177                    // If the substring is empty, the result is 1.
178                    if substring.is_empty() {
179                        T::Native::from_usize(1)
180                    } else {
181                        T::Native::from_usize(
182                            string
183                                .as_bytes()
184                                .windows(substring.len())
185                                .position(|w| w == substring.as_bytes())
186                                .map(|x| x + 1)
187                                .unwrap_or(0),
188                        )
189                    }
190                } else {
191                    // The `find` method returns the byte index of the substring.
192                    // We count the number of chars up to that byte index.
193                    T::Native::from_usize(
194                        string
195                            .find(substring)
196                            .map(|x| string[..x].chars().count() + 1)
197                            .unwrap_or(0),
198                    )
199                }
200            }
201            _ => None,
202        })
203        .collect::<PrimitiveArray<T>>();
204
205    Ok(Arc::new(result) as ArrayRef)
206}
207
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType;
    use arrow::datatypes::DataType::{Int32, Int64};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    // Runs `strpos(lhs, rhs)` on scalar inputs of the given `ScalarValue`
    // variants and checks the result against `$result`.
    macro_rules! test_strpos {
        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
                ],
                Ok(Some($result)),
                $t3,
                $t4,
                $t5
            )
        };
    }

    // Applies the shared list of string/substring cases to one combination of
    // input types and expected output type.
    macro_rules! test_strpos_cases {
        ($string:ident $substring:ident $native:ident $data_type:ident $array:ident) => {
            test_strpos!("alphabet", "ph" -> 3; $string $substring $native $data_type $array);
            test_strpos!("alphabet", "a" -> 1; $string $substring $native $data_type $array);
            test_strpos!("alphabet", "z" -> 0; $string $substring $native $data_type $array);
            test_strpos!("alphabet", "" -> 1; $string $substring $native $data_type $array);
            test_strpos!("", "a" -> 0; $string $substring $native $data_type $array);
            test_strpos!("", "" -> 1; $string $substring $native $data_type $array);
            test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; $string $substring $native $data_type $array);
        };
    }

    #[test]
    fn test_strpos_functions() {
        // Utf8 string: Int32 result
        test_strpos_cases!(Utf8 Utf8 i32 Int32 Int32Array);

        // LargeUtf8 string with LargeUtf8 substring: Int64 result
        test_strpos_cases!(LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Utf8 string with LargeUtf8 substring: Int32 result
        test_strpos_cases!(Utf8 LargeUtf8 i32 Int32 Int32Array);

        // LargeUtf8 string with Utf8 substring: Int64 result
        test_strpos_cases!(LargeUtf8 Utf8 i64 Int64 Int64Array);

        // Utf8View string: Int32 result for every substring type
        test_strpos_cases!(Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos_cases!(Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos_cases!(Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        let strpos = StrposFunc::new();

        // Reports whether the result is nullable for the given argument
        // nullability flags.
        let result_nullable = |string_nullable: bool, substring_nullable: bool| -> bool {
            let args = datafusion_expr::ReturnTypeArgs {
                arg_types: &[DataType::Utf8, DataType::Utf8],
                nullables: &[string_nullable, substring_nullable],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts();
            nullable
        };

        // Only non-nullable inputs give a non-nullable result.
        assert!(!result_nullable(false, false));

        // If any of the arguments is nullable, the result is nullable.
        assert!(result_nullable(true, false));
        assert!(result_nullable(false, true));
        assert!(result_nullable(true, true));
    }
}