datafusion_functions/string/
replace.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::{exec_err, Result};
27use datafusion_expr::{ColumnarValue, Documentation, Volatility};
28use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
29use datafusion_macros::user_doc;
/// Scalar UDF implementing SQL `replace(str, substr, replacement)`.
///
/// The SQL-facing documentation (description, syntax, example and argument
/// descriptions) is declared in the `#[user_doc]` attribute below; the
/// `documentation` method of this type's `ScalarUDFImpl` returns it via
/// `self.doc()` (presumably generated by that attribute macro).
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
#[derive(Debug)]
pub struct ReplaceFunc {
    /// Accepted argument types; built in `ReplaceFunc::new` as a three-argument
    /// string signature.
    signature: Signature,
}
53
54impl Default for ReplaceFunc {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl ReplaceFunc {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::string(3, Volatility::Immutable),
64 }
65 }
66}
67
68impl ScalarUDFImpl for ReplaceFunc {
69 fn as_any(&self) -> &dyn Any {
70 self
71 }
72
73 fn name(&self) -> &str {
74 "replace"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
82 utf8_to_str_type(&arg_types[0], "replace")
83 }
84
85 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86 match args.args[0].data_type() {
87 DataType::Utf8 => make_scalar_function(replace::<i32>, vec![])(&args.args),
88 DataType::LargeUtf8 => {
89 make_scalar_function(replace::<i64>, vec![])(&args.args)
90 }
91 DataType::Utf8View => make_scalar_function(replace_view, vec![])(&args.args),
92 other => {
93 exec_err!("Unsupported data type {other:?} for function replace")
94 }
95 }
96 }
97
98 fn documentation(&self) -> Option<&Documentation> {
99 self.doc()
100 }
101}
102
103fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
104 let string_array = as_string_view_array(&args[0])?;
105 let from_array = as_string_view_array(&args[1])?;
106 let to_array = as_string_view_array(&args[2])?;
107
108 let result = string_array
109 .iter()
110 .zip(from_array.iter())
111 .zip(to_array.iter())
112 .map(|((string, from), to)| match (string, from, to) {
113 (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
114 _ => None,
115 })
116 .collect::<StringArray>();
117
118 Ok(Arc::new(result) as ArrayRef)
119}
120fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
123 let string_array = as_generic_string_array::<T>(&args[0])?;
124 let from_array = as_generic_string_array::<T>(&args[1])?;
125 let to_array = as_generic_string_array::<T>(&args[2])?;
126
127 let result = string_array
128 .iter()
129 .zip(from_array.iter())
130 .zip(to_array.iter())
131 .map(|((string, from), to)| match (string, from, to) {
132 (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
133 _ => None,
134 })
135 .collect::<GenericStringArray<T>>();
136
137 Ok(Arc::new(result) as ArrayRef)
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::Array;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "aabbb"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "aabbbcw"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))),
            ],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// A NULL in any argument position must produce NULL, for every input type.
    #[test]
    fn test_null_propagation() -> Result<()> {
        // NULL `str`
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // NULL `substr`
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("aabb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // NULL `replacement` on the Utf8View path (result is Utf8)
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("aabb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// An empty `substr` follows Rust `str::replace` semantics: the replacement
    /// is inserted at every character boundary.
    #[test]
    fn test_empty_pattern() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::new()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("x")))),
            ],
            Ok(Some("xaxbxcx")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}