datafusion_functions/string/
contains.rs1use crate::utils::make_scalar_function;
19use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::contains as arrow_contains;
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
23use datafusion_common::exec_err;
24use datafusion_common::DataFusionError;
25use datafusion_common::Result;
26use datafusion_expr::{
27 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28 Volatility,
29};
30use datafusion_macros::user_doc;
31use std::any::Any;
32use std::sync::Arc;
33
34#[user_doc(
35 doc_section(label = "String Functions"),
36 description = "Return true if search_str is found within string (case-sensitive).",
37 syntax_example = "contains(str, search_str)",
38 sql_example = r#"```sql
39> select contains('the quick brown fox', 'row');
40+---------------------------------------------------+
41| contains(Utf8("the quick brown fox"),Utf8("row")) |
42+---------------------------------------------------+
43| true |
44+---------------------------------------------------+
45```"#,
46 standard_argument(name = "str", prefix = "String"),
47 argument(name = "search_str", description = "The string to search for in str.")
48)]
49#[derive(Debug)]
50pub struct ContainsFunc {
51 signature: Signature,
52}
53
54impl Default for ContainsFunc {
55 fn default() -> Self {
56 ContainsFunc::new()
57 }
58}
59
60impl ContainsFunc {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::string(2, Volatility::Immutable),
64 }
65 }
66}
67
68impl ScalarUDFImpl for ContainsFunc {
69 fn as_any(&self) -> &dyn Any {
70 self
71 }
72
73 fn name(&self) -> &str {
74 "contains"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
82 Ok(Boolean)
83 }
84
85 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86 make_scalar_function(contains, vec![])(&args.args)
87 }
88
89 fn documentation(&self) -> Option<&Documentation> {
90 self.doc()
91 }
92}
93
94pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
96 match (args[0].data_type(), args[1].data_type()) {
97 (Utf8View, Utf8View) => {
98 let mod_str = args[0].as_string_view();
99 let match_str = args[1].as_string_view();
100 let res = arrow_contains(mod_str, match_str)?;
101 Ok(Arc::new(res) as ArrayRef)
102 }
103 (Utf8, Utf8) => {
104 let mod_str = args[0].as_string::<i32>();
105 let match_str = args[1].as_string::<i32>();
106 let res = arrow_contains(mod_str, match_str)?;
107 Ok(Arc::new(res) as ArrayRef)
108 }
109 (LargeUtf8, LargeUtf8) => {
110 let mod_str = args[0].as_string::<i64>();
111 let match_str = args[1].as_string::<i64>();
112 let res = arrow_contains(mod_str, match_str)?;
113 Ok(Arc::new(res) as ArrayRef)
114 }
115 other => {
116 exec_err!("Unsupported data type {other:?} for function `contains`.")
117 }
118 }
119}
120
#[cfg(test)]
mod test {
    use super::ContainsFunc;
    use arrow::array::{BooleanArray, StringArray};
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Invokes the UDF with an array haystack and a scalar needle made of
    /// characters that are special in regular expressions; the expected
    /// result treats the needle as a plain substring.
    #[test]
    fn test_contains_udf() {
        const ROWS: usize = 2;

        let haystack = StringArray::from(vec![Some("xxx?()"), Some("yyy?()")]);
        let needle = ScalarValue::Utf8(Some("x?(".to_string()));

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(haystack)),
                ColumnarValue::Scalar(needle),
            ],
            number_rows: ROWS,
            return_type: &DataType::Boolean,
        };

        let actual = ContainsFunc::new().invoke_with_args(args).unwrap();

        let expected_values = BooleanArray::from(vec![Some(true), Some(false)]);
        let expect = ColumnarValue::Array(Arc::new(expected_values));

        assert_eq!(
            *actual.into_array(ROWS).unwrap(),
            *expect.into_array(ROWS).unwrap()
        );
    }
}