datafusion_functions/string/
starts_with.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::ArrayRef;
22use arrow::datatypes::DataType;
23use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
24use datafusion_expr::type_coercion::binary::{
25 binary_to_string_coercion, string_coercion,
26};
27
28use crate::utils::make_scalar_function;
29use datafusion_common::types::logical_string;
30use datafusion_common::{internal_err, Result, ScalarValue};
31use datafusion_expr::{
32 cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
33 ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
36
37fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
40 if let Some(coercion_data_type) =
41 string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
42 binary_to_string_coercion(args[0].data_type(), args[1].data_type())
43 })
44 {
45 let arg0 = if args[0].data_type() == &coercion_data_type {
46 Arc::clone(&args[0])
47 } else {
48 arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
49 };
50 let arg1 = if args[1].data_type() == &coercion_data_type {
51 Arc::clone(&args[1])
52 } else {
53 arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
54 };
55 let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
56 Ok(Arc::new(result) as ArrayRef)
57 } else {
58 internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")
59 }
60}
61
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string starts with a substring.",
    syntax_example = "starts_with(str, substr)",
    sql_example = r#"```sql
> select starts_with('datafusion','data');
+----------------------------------------------+
| starts_with(Utf8("datafusion"),Utf8("data")) |
+----------------------------------------------+
| true |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
#[derive(Debug)]
/// Scalar UDF `starts_with(str, substr)`: tests whether `str` begins with `substr`.
///
/// User-facing documentation is supplied by the `user_doc` attribute above.
pub struct StartsWithFunc {
    /// Declared argument signature, constructed in `StartsWithFunc::new`.
    signature: Signature,
}
81
82impl Default for StartsWithFunc {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl StartsWithFunc {
89 pub fn new() -> Self {
90 Self {
91 signature: Signature::coercible(
92 vec![
93 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
94 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
95 ],
96 Volatility::Immutable,
97 ),
98 }
99 }
100}
101
102impl ScalarUDFImpl for StartsWithFunc {
103 fn as_any(&self) -> &dyn Any {
104 self
105 }
106
107 fn name(&self) -> &str {
108 "starts_with"
109 }
110
111 fn signature(&self) -> &Signature {
112 &self.signature
113 }
114
115 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116 Ok(DataType::Boolean)
117 }
118
119 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
120 match args.args[0].data_type() {
121 DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
122 make_scalar_function(starts_with, vec![])(&args.args)
123 }
124 _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?,
125 }
126 }
127
128 fn simplify(
129 &self,
130 args: Vec<Expr>,
131 info: &dyn SimplifyInfo,
132 ) -> Result<ExprSimplifyResult> {
133 if let Expr::Literal(scalar_value) = &args[1] {
134 let like_expr = match scalar_value {
140 ScalarValue::Utf8(Some(pattern))
141 | ScalarValue::LargeUtf8(Some(pattern))
142 | ScalarValue::Utf8View(Some(pattern)) => {
143 let escaped_pattern = pattern.replace("%", "\\%");
144 let like_pattern = format!("{}%", escaped_pattern);
145 Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
146 }
147 _ => return Ok(ExprSimplifyResult::Original(args)),
148 };
149
150 let expr_data_type = info.get_data_type(&args[0])?;
151 let pattern_data_type = info.get_data_type(&like_expr)?;
152
153 if let Some(coercion_data_type) =
154 string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
155 binary_to_string_coercion(&expr_data_type, &pattern_data_type)
156 })
157 {
158 let expr = if expr_data_type == coercion_data_type {
159 args[0].clone()
160 } else {
161 cast(args[0].clone(), coercion_data_type.clone())
162 };
163
164 let pattern = if pattern_data_type == coercion_data_type {
165 like_expr
166 } else {
167 cast(like_expr, coercion_data_type)
168 };
169
170 return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
171 negated: false,
172 expr: Box::new(expr),
173 pattern: Box::new(pattern),
174 escape_char: None,
175 case_insensitive: false,
176 })));
177 }
178 }
179
180 Ok(ExprSimplifyResult::Original(args))
181 }
182
183 fn documentation(&self) -> Option<&Documentation> {
184 self.doc()
185 }
186}
187
#[cfg(test)]
mod tests {
    use crate::utils::test::test_function;
    use arrow::array::{Array, BooleanArray};
    use arrow::datatypes::DataType::Boolean;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use super::*;

    /// Runs every `(input, prefix, expected)` case through `StartsWithFunc`.
    /// Each case is run once per string scalar flavour, in this order: Utf8,
    /// LargeUtf8, Utf8View.
    #[test]
    fn test_functions() -> Result<()> {
        let cases: [(Option<&str>, Option<&str>, Option<bool>); 4] = [
            (Some("alphabet"), Some("alph"), Some(true)),
            (Some("alphabet"), Some("bet"), Some(false)),
            (
                Some("somewhat large string"),
                Some("somewhat large"),
                Some(true),
            ),
            (Some("somewhat large string"), Some("large"), Some(false)),
        ];
        let string_scalars: [fn(Option<String>) -> ScalarValue; 3] = [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ];

        for (input, prefix, expected) in cases {
            for make_scalar in string_scalars {
                let args = vec![
                    ColumnarValue::Scalar(make_scalar(input.map(str::to_string))),
                    ColumnarValue::Scalar(make_scalar(prefix.map(str::to_string))),
                ];
                test_function!(
                    StartsWithFunc::new(),
                    args,
                    Ok(expected),
                    bool,
                    Boolean,
                    BooleanArray
                );
            }
        }

        Ok(())
    }
}