datafusion_functions/string/
contains.rs1use crate::utils::make_scalar_function;
19use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::contains as arrow_contains;
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
23use datafusion_common::types::logical_string;
24use datafusion_common::{exec_err, DataFusionError, Result};
25use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
26use datafusion_expr::{
27 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28 TypeSignatureClass, Volatility,
29};
30use datafusion_macros::user_doc;
31use std::any::Any;
32use std::sync::Arc;
33
34#[user_doc(
35 doc_section(label = "String Functions"),
36 description = "Return true if search_str is found within string (case-sensitive).",
37 syntax_example = "contains(str, search_str)",
38 sql_example = r#"```sql
39> select contains('the quick brown fox', 'row');
40+---------------------------------------------------+
41| contains(Utf8("the quick brown fox"),Utf8("row")) |
42+---------------------------------------------------+
43| true |
44+---------------------------------------------------+
45```"#,
46 standard_argument(name = "str", prefix = "String"),
47 argument(name = "search_str", description = "The string to search for in str.")
48)]
49#[derive(Debug)]
50pub struct ContainsFunc {
51 signature: Signature,
52}
53
54impl Default for ContainsFunc {
55 fn default() -> Self {
56 ContainsFunc::new()
57 }
58}
59
60impl ContainsFunc {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::coercible(
64 vec![
65 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
66 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
67 ],
68 Volatility::Immutable,
69 ),
70 }
71 }
72}
73
74impl ScalarUDFImpl for ContainsFunc {
75 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
79 fn name(&self) -> &str {
80 "contains"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
88 Ok(Boolean)
89 }
90
91 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 make_scalar_function(contains, vec![])(&args.args)
93 }
94
95 fn documentation(&self) -> Option<&Documentation> {
96 self.doc()
97 }
98}
99
100fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
102 if let Some(coercion_data_type) =
103 string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
104 binary_to_string_coercion(args[0].data_type(), args[1].data_type())
105 })
106 {
107 let arg0 = if args[0].data_type() == &coercion_data_type {
108 Arc::clone(&args[0])
109 } else {
110 arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
111 };
112 let arg1 = if args[1].data_type() == &coercion_data_type {
113 Arc::clone(&args[1])
114 } else {
115 arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
116 };
117
118 match coercion_data_type {
119 Utf8View => {
120 let mod_str = arg0.as_string_view();
121 let match_str = arg1.as_string_view();
122 let res = arrow_contains(mod_str, match_str)?;
123 Ok(Arc::new(res) as ArrayRef)
124 }
125 Utf8 => {
126 let mod_str = arg0.as_string::<i32>();
127 let match_str = arg1.as_string::<i32>();
128 let res = arrow_contains(mod_str, match_str)?;
129 Ok(Arc::new(res) as ArrayRef)
130 }
131 LargeUtf8 => {
132 let mod_str = arg0.as_string::<i64>();
133 let match_str = arg1.as_string::<i64>();
134 let res = arrow_contains(mod_str, match_str)?;
135 Ok(Arc::new(res) as ArrayRef)
136 }
137 other => {
138 exec_err!("Unsupported data type {other:?} for function `contains`.")
139 }
140 }
141 } else {
142 exec_err!(
143 "Unsupported data type {:?}, {:?} for function `contains`.",
144 args[0].data_type(),
145 args[1].data_type()
146 )
147 }
148}
149
#[cfg(test)]
mod test {
    use super::ContainsFunc;
    use arrow::array::{Array, ArrayRef, BooleanArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Calls the UDF with `haystack` as the first argument and a Utf8 scalar
    /// `needle` as the second, and returns the result as an array.
    fn invoke(haystack: ArrayRef, needle: &str) -> ArrayRef {
        let number_rows = haystack.len();
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(haystack),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(needle.to_string()))),
            ],
            number_rows,
            return_type: &DataType::Boolean,
        };
        ContainsFunc::new()
            .invoke_with_args(args)
            .unwrap()
            .into_array(number_rows)
            .unwrap()
    }

    #[test]
    fn test_contains_udf() {
        // `?` and `(` in the needle must be matched literally.
        let actual = invoke(
            Arc::new(StringArray::from(vec![Some("xxx?()"), Some("yyy?()")])),
            "x?(",
        );
        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), Some(false)]));
        assert_eq!(*actual, *expected);
    }

    #[test]
    fn test_contains_propagates_nulls() {
        let actual = invoke(
            Arc::new(StringArray::from(vec![Some("abc"), None, Some("xyz")])),
            "b",
        );
        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));
        assert_eq!(*actual, *expected);
    }

    #[test]
    fn test_contains_utf8view_with_utf8_needle() {
        // Exercises the coercion path: a Utf8View column against a Utf8 needle.
        let actual = invoke(
            Arc::new(StringViewArray::from(vec![
                Some("the quick brown fox"),
                Some("lazy dog"),
            ])),
            "row",
        );
        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), Some(false)]));
        assert_eq!(*actual, *expected);
    }
}