datafusion_functions/string/
replace.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::types::logical_string;
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::type_coercion::binary::{
29 binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33 TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
// Scalar UDF type backing the SQL function `replace(str, substr, replacement)`.
//
// The `user_doc` attribute below holds the user-facing documentation for the
// function: its section label, description, syntax, a SQL example, and the
// descriptions of the three arguments.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
#[derive(Debug)]
pub struct ReplaceFunc {
    // Argument signature of the function; populated by `ReplaceFunc::new`
    // and handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
}
59
60impl Default for ReplaceFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl ReplaceFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::coercible(
70 vec![
71 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74 ],
75 Volatility::Immutable,
76 ),
77 }
78 }
79}
80
81impl ScalarUDFImpl for ReplaceFunc {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "replace"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95 if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
96 .and_then(|dt| string_coercion(&dt, &arg_types[2]))
97 .or_else(|| {
98 binary_to_string_coercion(&arg_types[0], &arg_types[1])
99 .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
100 })
101 {
102 utf8_to_str_type(&coercion_data_type, "replace")
103 } else {
104 exec_err!("Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View")
105 }
106 }
107
108 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 let data_types = args
110 .args
111 .iter()
112 .map(|arg| arg.data_type())
113 .collect::<Vec<_>>();
114
115 if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
116 .and_then(|dt| string_coercion(&dt, &data_types[2]))
117 .or_else(|| {
118 binary_to_string_coercion(&data_types[0], &data_types[1])
119 .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
120 })
121 {
122 let mut converted_args = Vec::with_capacity(args.args.len());
123 for arg in &args.args {
124 if arg.data_type() == coercion_type {
125 converted_args.push(arg.clone());
126 } else {
127 let converted = arg.cast_to(&coercion_type, None)?;
128 converted_args.push(converted);
129 }
130 }
131
132 match coercion_type {
133 DataType::Utf8 => {
134 make_scalar_function(replace::<i32>, vec![])(&converted_args)
135 }
136 DataType::LargeUtf8 => {
137 make_scalar_function(replace::<i64>, vec![])(&converted_args)
138 }
139 DataType::Utf8View => {
140 make_scalar_function(replace_view, vec![])(&converted_args)
141 }
142 other => exec_err!(
143 "Unsupported coercion data type {other:?} for function replace"
144 ),
145 }
146 } else {
147 exec_err!(
148 "Unsupported data type {:?}, {:?}, {:?} for function replace.",
149 data_types[0],
150 data_types[1],
151 data_types[2]
152 )
153 }
154 }
155
156 fn documentation(&self) -> Option<&Documentation> {
157 self.doc()
158 }
159}
160
161fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
162 let string_array = as_string_view_array(&args[0])?;
163 let from_array = as_string_view_array(&args[1])?;
164 let to_array = as_string_view_array(&args[2])?;
165
166 let result = string_array
167 .iter()
168 .zip(from_array.iter())
169 .zip(to_array.iter())
170 .map(|((string, from), to)| match (string, from, to) {
171 (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
172 _ => None,
173 })
174 .collect::<StringArray>();
175
176 Ok(Arc::new(result) as ArrayRef)
177}
178
179fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
182 let string_array = as_generic_string_array::<T>(&args[0])?;
183 let from_array = as_generic_string_array::<T>(&args[1])?;
184 let to_array = as_generic_string_array::<T>(&args[2])?;
185
186 let result = string_array
187 .iter()
188 .zip(from_array.iter())
189 .zip(to_array.iter())
190 .map(|((string, from), to)| match (string, from, to) {
191 (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
192 _ => None,
193 })
194 .collect::<GenericStringArray<T>>();
195
196 Ok(Arc::new(result) as ArrayRef)
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::Array;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    /// Non-null `Utf8` scalar argument holding `s`.
    fn utf8(s: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(s.to_string())))
    }

    /// Non-null `LargeUtf8` scalar argument holding `s`.
    fn large_utf8(s: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s.to_string())))
    }

    /// Non-null `Utf8View` scalar argument holding `s`.
    fn utf8_view(s: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s.to_string())))
    }

    /// Exercises `replace` once per supported string type with scalar inputs.
    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 arguments produce a Utf8 result.
        test_function!(
            ReplaceFunc::new(),
            vec![utf8("aabbdqcbb"), utf8("bb"), utf8("ccc")],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8 arguments produce a LargeUtf8 result.
        test_function!(
            ReplaceFunc::new(),
            vec![large_utf8("aabbb"), large_utf8("bbb"), large_utf8("cc")],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // Utf8View arguments are expected to come back as plain Utf8.
        test_function!(
            ReplaceFunc::new(),
            vec![utf8_view("aabbbcw"), utf8_view("bb"), utf8_view("cc")],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}