1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23 ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26use datafusion_common::{exec_err, internal_err, Result};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_macros::user_doc;
31
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
    syntax_example = "strpos(str, substr)",
    alternative_syntax = "position(substr in origstr)",
    sql_example = r#"```sql
> select strpos('datafusion', 'fus');
+----------------------------------------+
| strpos(Utf8("datafusion"),Utf8("fus")) |
+----------------------------------------+
| 5 |
+----------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring expression to search for.")
)]
/// Scalar UDF `strpos` (also exposed as `instr` and `position`): the 1-based
/// position of a substring within a string, or 0 when it does not occur.
#[derive(Debug)]
pub struct StrposFunc {
    /// Accepted argument types; built in `StrposFunc::new` as two string arguments.
    signature: Signature,
    /// Alternative names (`instr`, `position`) returned from `aliases()`.
    aliases: Vec<String>,
}
53
54impl Default for StrposFunc {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl StrposFunc {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::string(2, Volatility::Immutable),
64 aliases: vec![String::from("instr"), String::from("position")],
65 }
66 }
67}
68
69impl ScalarUDFImpl for StrposFunc {
70 fn as_any(&self) -> &dyn Any {
71 self
72 }
73
74 fn name(&self) -> &str {
75 "strpos"
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83 internal_err!("return_type_from_args should be used instead")
84 }
85
86 fn return_type_from_args(
87 &self,
88 args: datafusion_expr::ReturnTypeArgs,
89 ) -> Result<datafusion_expr::ReturnInfo> {
90 utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| {
91 datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x))
92 })
93 }
94
95 fn invoke_with_args(
96 &self,
97 args: datafusion_expr::ScalarFunctionArgs,
98 ) -> Result<ColumnarValue> {
99 make_scalar_function(strpos, vec![])(&args.args)
100 }
101
102 fn aliases(&self) -> &[String] {
103 &self.aliases
104 }
105
106 fn documentation(&self) -> Option<&Documentation> {
107 self.doc()
108 }
109}
110
111fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
112 match (args[0].data_type(), args[1].data_type()) {
113 (DataType::Utf8, DataType::Utf8) => {
114 let string_array = args[0].as_string::<i32>();
115 let substring_array = args[1].as_string::<i32>();
116 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
117 }
118 (DataType::Utf8, DataType::LargeUtf8) => {
119 let string_array = args[0].as_string::<i32>();
120 let substring_array = args[1].as_string::<i64>();
121 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
122 }
123 (DataType::LargeUtf8, DataType::Utf8) => {
124 let string_array = args[0].as_string::<i64>();
125 let substring_array = args[1].as_string::<i32>();
126 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
127 }
128 (DataType::LargeUtf8, DataType::LargeUtf8) => {
129 let string_array = args[0].as_string::<i64>();
130 let substring_array = args[1].as_string::<i64>();
131 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
132 }
133 (DataType::Utf8View, DataType::Utf8View) => {
134 let string_array = args[0].as_string_view();
135 let substring_array = args[1].as_string_view();
136 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
137 }
138 (DataType::Utf8View, DataType::Utf8) => {
139 let string_array = args[0].as_string_view();
140 let substring_array = args[1].as_string::<i32>();
141 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
142 }
143 (DataType::Utf8View, DataType::LargeUtf8) => {
144 let string_array = args[0].as_string_view();
145 let substring_array = args[1].as_string::<i64>();
146 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
147 }
148
149 other => {
150 exec_err!("Unsupported data type combination {other:?} for function strpos")
151 }
152 }
153}
154
155fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
159 string_array: V1,
160 substring_array: V2,
161) -> Result<ArrayRef>
162where
163 V1: StringArrayType<'a, Item = &'a str>,
164 V2: StringArrayType<'a, Item = &'a str>,
165{
166 let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
167 let string_iter = string_array.iter();
168 let substring_iter = substring_array.iter();
169
170 let result = string_iter
171 .zip(substring_iter)
172 .map(|(string, substring)| match (string, substring) {
173 (Some(string), Some(substring)) => {
174 if ascii_only {
177 if substring.is_empty() {
179 T::Native::from_usize(1)
180 } else {
181 T::Native::from_usize(
182 string
183 .as_bytes()
184 .windows(substring.len())
185 .position(|w| w == substring.as_bytes())
186 .map(|x| x + 1)
187 .unwrap_or(0),
188 )
189 }
190 } else {
191 T::Native::from_usize(
194 string
195 .find(substring)
196 .map(|x| string[..x].chars().count() + 1)
197 .unwrap_or(0),
198 )
199 }
200 }
201 _ => None,
202 })
203 .collect::<PrimitiveArray<T>>();
204
205 Ok(Arc::new(result) as ArrayRef)
206}
207
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::DataType;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    /// Invokes `strpos($lhs, $rhs)` on scalar arguments and checks the result is `$result`.
    ///
    /// Macro arguments:
    /// - `$t1` / `$t2`: the `ScalarValue` variants for the two arguments.
    /// - `$t3`, `$t4`, `$t5`: forwarded to `test_function!`; they name the
    ///   result's native type, `DataType` and array type respectively.
    macro_rules! test_strpos {
        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
                ],
                Ok(Some($result)),
                $t3,
                $t4,
                $t5
            )
        };
    }

    #[test]
    fn test_strpos_functions() {
        test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);

        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);

        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);

        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);

        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);

        test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn test_strpos_edge_cases() {
        // Needle longer than the haystack: the ASCII byte-window search finds nothing.
        test_strpos!("ab", "abc" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("ab", "abc" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("ab", "abc" -> 0; Utf8View Utf8View i32 Int32 Int32Array);

        // Match at the very end, and the first of overlapping matches.
        test_strpos!("alphabet", "bet" -> 6; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("aaa", "aa" -> 1; Utf8 Utf8 i32 Int32 Int32Array);

        // Non-ASCII string: positions count characters, not bytes.
        test_strpos!("ДатаFusion", "Fus" -> 5; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("数据融合", "融合" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("数据融合", "融合" -> 3; Utf8View Utf8View i32 Int32 Int32Array);

        // ASCII string with a non-ASCII substring that does not occur.
        test_strpos!("datafusion", "数据" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let strpos = StrposFunc::new();
            let args = datafusion_expr::ReturnTypeArgs {
                arg_types: &[DataType::Utf8, DataType::Utf8],
                nullables: &[string_array_nullable, substring_nullable],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts();

            nullable
        }

        assert!(!get_nullable(false, false));

        assert!(get_nullable(true, false));
        assert!(get_nullable(false, true));
        assert!(get_nullable(true, true));
    }
}