1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23 ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26use datafusion_common::types::logical_string;
27use datafusion_common::{exec_err, internal_err, Result};
28use datafusion_expr::{
29 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
30 Volatility,
31};
32use datafusion_macros::user_doc;
33
// User-facing documentation metadata for `strpos`.
// NOTE(review): presumably the `user_doc` macro generates the `doc()` method that
// `documentation()` calls — confirm against the macro definition.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
    syntax_example = "strpos(str, substr)",
    alternative_syntax = "position(substr in origstr)",
    sql_example = r#"```sql
> select strpos('datafusion', 'fus');
+----------------------------------------+
| strpos(Utf8("datafusion"),Utf8("fus")) |
+----------------------------------------+
| 5 |
+----------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring expression to search for.")
)]
#[derive(Debug)]
/// Scalar UDF implementation of `strpos` (also registered under the aliases
/// `instr` and `position`).
pub struct StrposFunc {
    /// Accepted argument types and volatility of the function.
    signature: Signature,
    /// Alternative names under which the function can be called.
    aliases: Vec<String>,
}
55
56impl Default for StrposFunc {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl StrposFunc {
63 pub fn new() -> Self {
64 Self {
65 signature: Signature::coercible(
66 vec![
67 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
68 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
69 ],
70 Volatility::Immutable,
71 ),
72 aliases: vec![String::from("instr"), String::from("position")],
73 }
74 }
75}
76
77impl ScalarUDFImpl for StrposFunc {
78 fn as_any(&self) -> &dyn Any {
79 self
80 }
81
82 fn name(&self) -> &str {
83 "strpos"
84 }
85
86 fn signature(&self) -> &Signature {
87 &self.signature
88 }
89
90 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
91 internal_err!("return_type_from_args should be used instead")
92 }
93
94 fn return_type_from_args(
95 &self,
96 args: datafusion_expr::ReturnTypeArgs,
97 ) -> Result<datafusion_expr::ReturnInfo> {
98 utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| {
99 datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x))
100 })
101 }
102
103 fn invoke_with_args(
104 &self,
105 args: datafusion_expr::ScalarFunctionArgs,
106 ) -> Result<ColumnarValue> {
107 make_scalar_function(strpos, vec![])(&args.args)
108 }
109
110 fn aliases(&self) -> &[String] {
111 &self.aliases
112 }
113
114 fn documentation(&self) -> Option<&Documentation> {
115 self.doc()
116 }
117}
118
119fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
120 match (args[0].data_type(), args[1].data_type()) {
121 (DataType::Utf8, DataType::Utf8) => {
122 let string_array = args[0].as_string::<i32>();
123 let substring_array = args[1].as_string::<i32>();
124 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
125 }
126 (DataType::Utf8, DataType::Utf8View) => {
127 let string_array = args[0].as_string::<i32>();
128 let substring_array = args[1].as_string_view();
129 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
130 }
131 (DataType::Utf8, DataType::LargeUtf8) => {
132 let string_array = args[0].as_string::<i32>();
133 let substring_array = args[1].as_string::<i64>();
134 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
135 }
136 (DataType::LargeUtf8, DataType::Utf8) => {
137 let string_array = args[0].as_string::<i64>();
138 let substring_array = args[1].as_string::<i32>();
139 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
140 }
141 (DataType::LargeUtf8, DataType::Utf8View) => {
142 let string_array = args[0].as_string::<i64>();
143 let substring_array = args[1].as_string_view();
144 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
145 }
146 (DataType::LargeUtf8, DataType::LargeUtf8) => {
147 let string_array = args[0].as_string::<i64>();
148 let substring_array = args[1].as_string::<i64>();
149 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
150 }
151 (DataType::Utf8View, DataType::Utf8View) => {
152 let string_array = args[0].as_string_view();
153 let substring_array = args[1].as_string_view();
154 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
155 }
156 (DataType::Utf8View, DataType::Utf8) => {
157 let string_array = args[0].as_string_view();
158 let substring_array = args[1].as_string::<i32>();
159 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
160 }
161 (DataType::Utf8View, DataType::LargeUtf8) => {
162 let string_array = args[0].as_string_view();
163 let substring_array = args[1].as_string::<i64>();
164 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
165 }
166
167 other => {
168 exec_err!("Unsupported data type combination {other:?} for function strpos")
169 }
170 }
171}
172
173fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
177 string_array: V1,
178 substring_array: V2,
179) -> Result<ArrayRef>
180where
181 V1: StringArrayType<'a, Item = &'a str>,
182 V2: StringArrayType<'a, Item = &'a str>,
183{
184 let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
185 let string_iter = string_array.iter();
186 let substring_iter = substring_array.iter();
187
188 let result = string_iter
189 .zip(substring_iter)
190 .map(|(string, substring)| match (string, substring) {
191 (Some(string), Some(substring)) => {
192 if ascii_only {
195 if substring.is_empty() {
197 T::Native::from_usize(1)
198 } else {
199 T::Native::from_usize(
200 string
201 .as_bytes()
202 .windows(substring.len())
203 .position(|w| w == substring.as_bytes())
204 .map(|x| x + 1)
205 .unwrap_or(0),
206 )
207 }
208 } else {
209 T::Native::from_usize(
212 string
213 .find(substring)
214 .map(|x| string[..x].chars().count() + 1)
215 .unwrap_or(0),
216 )
217 }
218 }
219 _ => None,
220 })
221 .collect::<PrimitiveArray<T>>();
222
223 Ok(Arc::new(result) as ArrayRef)
224}
225
// Unit tests for `StrposFunc`: scalar evaluation across string-type
// combinations, and nullability of the declared return type.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::DataType;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    // Runs `StrposFunc` on two scalar string arguments and expects `Ok(Some($result))`.
    //
    // `$t1` / `$t2` name the `ScalarValue` variants used for the string and the
    // substring. `$t3`, `$t4` and `$t5` are forwarded unchanged to `test_function!`
    // (presumably the native result type, the expected Arrow data type and the
    // expected array type — confirm against `test_function!`).
    macro_rules! test_strpos {
        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
                ],
                Ok(Some($result)),
                $t3,
                $t4,
                $t5
            )
        };
    }

    // Each group below repeats the same seven cases for one type combination:
    // match in the middle, match at the start, no match, empty substring,
    // empty string, both empty, and a non-ASCII input where the expected
    // position (15) is counted in characters rather than bytes.
    #[test]
    fn test_strpos_functions() {
        // Utf8 / Utf8 -> Int32
        test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);

        // LargeUtf8 / LargeUtf8 -> Int64
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Utf8 / LargeUtf8 -> Int32
        test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);

        // LargeUtf8 / Utf8 -> Int64
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);

        // Utf8View / Utf8View -> Int32
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);

        // Utf8View / Utf8 -> Int32
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);

        // Utf8View / LargeUtf8 -> Int32
        test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    // The return value is reported as nullable exactly when at least one of
    // the two arguments is nullable.
    #[test]
    fn nullable_return_type() {
        // Returns the nullability that `return_type_from_args` reports for two
        // `Utf8` arguments with the given nullability flags.
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let strpos = StrposFunc::new();
            let args = datafusion_expr::ReturnTypeArgs {
                arg_types: &[DataType::Utf8, DataType::Utf8],
                nullables: &[string_array_nullable, substring_nullable],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts();

            nullable
        }

        // Neither input nullable -> non-nullable result.
        assert!(!get_nullable(false, false));

        // Any nullable input -> nullable result.
        assert!(get_nullable(true, false));
        assert!(get_nullable(false, true));
        assert!(get_nullable(true, true));
    }
}