datafusion_functions/string/
starts_with.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::ArrayRef;
22use arrow::datatypes::DataType;
23use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
24use datafusion_expr::type_coercion::binary::{
25    binary_to_string_coercion, string_coercion,
26};
27
28use crate::utils::make_scalar_function;
29use datafusion_common::types::logical_string;
30use datafusion_common::{internal_err, Result, ScalarValue};
31use datafusion_expr::{
32    cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
33    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
36
37/// Returns true if string starts with prefix.
38/// starts_with('alphabet', 'alph') = 't'
39fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
40    if let Some(coercion_data_type) =
41        string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
42            binary_to_string_coercion(args[0].data_type(), args[1].data_type())
43        })
44    {
45        let arg0 = if args[0].data_type() == &coercion_data_type {
46            Arc::clone(&args[0])
47        } else {
48            arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
49        };
50        let arg1 = if args[1].data_type() == &coercion_data_type {
51            Arc::clone(&args[1])
52        } else {
53            arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
54        };
55        let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
56        Ok(Arc::new(result) as ArrayRef)
57    } else {
58        internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")
59    }
60}
61
/// Scalar UDF implementing SQL `starts_with(str, substr)`.
///
/// The user-facing documentation (section, description, syntax, SQL example
/// and argument descriptions) is generated from the `user_doc` attribute and
/// is what `ScalarUDFImpl::documentation` returns for this function. The
/// struct itself only stores the function's [`Signature`].
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string starts with a substring.",
    syntax_example = "starts_with(str, substr)",
    sql_example = r#"```sql
> select starts_with('datafusion','data');
+----------------------------------------------+
| starts_with(Utf8("datafusion"),Utf8("data")) |
+----------------------------------------------+
| true                                         |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
#[derive(Debug)]
pub struct StartsWithFunc {
    /// Argument signature built by `StartsWithFunc::new`: two arguments, each
    /// declared with `Coercion::new_exact` over the native logical string class.
    signature: Signature,
}
81
82impl Default for StartsWithFunc {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl StartsWithFunc {
89    pub fn new() -> Self {
90        Self {
91            signature: Signature::coercible(
92                vec![
93                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
94                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
95                ],
96                Volatility::Immutable,
97            ),
98        }
99    }
100}
101
102impl ScalarUDFImpl for StartsWithFunc {
103    fn as_any(&self) -> &dyn Any {
104        self
105    }
106
107    fn name(&self) -> &str {
108        "starts_with"
109    }
110
111    fn signature(&self) -> &Signature {
112        &self.signature
113    }
114
115    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116        Ok(DataType::Boolean)
117    }
118
119    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
120        match args.args[0].data_type() {
121            DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
122                make_scalar_function(starts_with, vec![])(&args.args)
123            }
124            _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?,
125        }
126    }
127
128    fn simplify(
129        &self,
130        args: Vec<Expr>,
131        info: &dyn SimplifyInfo,
132    ) -> Result<ExprSimplifyResult> {
133        if let Expr::Literal(scalar_value) = &args[1] {
134            // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
135            // Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%'
136            //   1. 'ja%'         (input pattern)
137            //   2. 'ja\%'        (escape special char '%')
138            //   3. 'ja\%%'       (add suffix for starts_with)
139            let like_expr = match scalar_value {
140                ScalarValue::Utf8(Some(pattern))
141                | ScalarValue::LargeUtf8(Some(pattern))
142                | ScalarValue::Utf8View(Some(pattern)) => {
143                    let escaped_pattern = pattern.replace("%", "\\%");
144                    let like_pattern = format!("{}%", escaped_pattern);
145                    Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
146                }
147                _ => return Ok(ExprSimplifyResult::Original(args)),
148            };
149
150            let expr_data_type = info.get_data_type(&args[0])?;
151            let pattern_data_type = info.get_data_type(&like_expr)?;
152
153            if let Some(coercion_data_type) =
154                string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
155                    binary_to_string_coercion(&expr_data_type, &pattern_data_type)
156                })
157            {
158                let expr = if expr_data_type == coercion_data_type {
159                    args[0].clone()
160                } else {
161                    cast(args[0].clone(), coercion_data_type.clone())
162                };
163
164                let pattern = if pattern_data_type == coercion_data_type {
165                    like_expr
166                } else {
167                    cast(like_expr, coercion_data_type)
168                };
169
170                return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
171                    negated: false,
172                    expr: Box::new(expr),
173                    pattern: Box::new(pattern),
174                    escape_char: None,
175                    case_insensitive: false,
176                })));
177            }
178        }
179
180        Ok(ExprSimplifyResult::Original(args))
181    }
182
183    fn documentation(&self) -> Option<&Documentation> {
184        self.doc()
185    }
186}
187
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray};
    use arrow::datatypes::DataType::Boolean;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use super::*;

    #[test]
    fn test_functions() -> Result<()> {
        // (string, prefix, expected) triples; each one is checked against
        // every supported string representation below.
        let test_cases = vec![
            (Some("alphabet"), Some("alph"), Some(true)),
            (Some("alphabet"), Some("bet"), Some(false)),
            (
                Some("somewhat large string"),
                Some("somewhat large"),
                Some(true),
            ),
            (Some("somewhat large string"), Some("large"), Some(false)),
            // An empty prefix matches every non-null string.
            (Some("alphabet"), Some(""), Some(true)),
            // NULL in either argument yields NULL.
            (None, Some("alph"), None),
            (Some("alphabet"), None, None),
        ]
        .into_iter()
        .flat_map(|(string, prefix, expected)| {
            // Build the same argument pair for a given ScalarValue string variant.
            let make_args = |to_scalar: fn(Option<String>) -> ScalarValue| {
                vec![
                    ColumnarValue::Scalar(to_scalar(string.map(str::to_string))),
                    ColumnarValue::Scalar(to_scalar(prefix.map(str::to_string))),
                ]
            };

            vec![
                (make_args(ScalarValue::Utf8), expected),
                (make_args(ScalarValue::LargeUtf8), expected),
                (make_args(ScalarValue::Utf8View), expected),
            ]
        });

        for (args, expected) in test_cases {
            test_function!(
                StartsWithFunc::new(),
                args,
                Ok(expected),
                bool,
                Boolean,
                BooleanArray
            );
        }

        Ok(())
    }
}