datafusion_functions/string/
levenshtein.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_int_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::utils::datafusion_strsim;
27use datafusion_common::{exec_err, utils::take_function_args, Result};
28use datafusion_expr::{ColumnarValue, Documentation};
29use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
30use datafusion_macros::user_doc;
31
32#[user_doc(
33 doc_section(label = "String Functions"),
34 description = "Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.",
35 syntax_example = "levenshtein(str1, str2)",
36 sql_example = r#"```sql
37> select levenshtein('kitten', 'sitting');
38+---------------------------------------------+
39| levenshtein(Utf8("kitten"),Utf8("sitting")) |
40+---------------------------------------------+
41| 3 |
42+---------------------------------------------+
43```"#,
44 argument(
45 name = "str1",
46 description = "String expression to compute Levenshtein distance with str2."
47 ),
48 argument(
49 name = "str2",
50 description = "String expression to compute Levenshtein distance with str1."
51 )
52)]
53#[derive(Debug)]
54pub struct LevenshteinFunc {
55 signature: Signature,
56}
57
58impl Default for LevenshteinFunc {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl LevenshteinFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::string(2, Volatility::Immutable),
68 }
69 }
70}
71
72impl ScalarUDFImpl for LevenshteinFunc {
73 fn as_any(&self) -> &dyn Any {
74 self
75 }
76
77 fn name(&self) -> &str {
78 "levenshtein"
79 }
80
81 fn signature(&self) -> &Signature {
82 &self.signature
83 }
84
85 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86 utf8_to_int_type(&arg_types[0], "levenshtein")
87 }
88
89 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90 match args.args[0].data_type() {
91 DataType::Utf8View | DataType::Utf8 => {
92 make_scalar_function(levenshtein::<i32>, vec![])(&args.args)
93 }
94 DataType::LargeUtf8 => {
95 make_scalar_function(levenshtein::<i64>, vec![])(&args.args)
96 }
97 other => {
98 exec_err!("Unsupported data type {other:?} for function levenshtein")
99 }
100 }
101 }
102
103 fn documentation(&self) -> Option<&Documentation> {
104 self.doc()
105 }
106}
107
108pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
111 let [str1, str2] = take_function_args("levenshtein", args)?;
112
113 match str1.data_type() {
114 DataType::Utf8View => {
115 let str1_array = as_string_view_array(&str1)?;
116 let str2_array = as_string_view_array(&str2)?;
117 let result = str1_array
118 .iter()
119 .zip(str2_array.iter())
120 .map(|(string1, string2)| match (string1, string2) {
121 (Some(string1), Some(string2)) => {
122 Some(datafusion_strsim::levenshtein(string1, string2) as i32)
123 }
124 _ => None,
125 })
126 .collect::<Int32Array>();
127 Ok(Arc::new(result) as ArrayRef)
128 }
129 DataType::Utf8 => {
130 let str1_array = as_generic_string_array::<T>(&str1)?;
131 let str2_array = as_generic_string_array::<T>(&str2)?;
132 let result = str1_array
133 .iter()
134 .zip(str2_array.iter())
135 .map(|(string1, string2)| match (string1, string2) {
136 (Some(string1), Some(string2)) => {
137 Some(datafusion_strsim::levenshtein(string1, string2) as i32)
138 }
139 _ => None,
140 })
141 .collect::<Int32Array>();
142 Ok(Arc::new(result) as ArrayRef)
143 }
144 DataType::LargeUtf8 => {
145 let str1_array = as_generic_string_array::<T>(&str1)?;
146 let str2_array = as_generic_string_array::<T>(&str2)?;
147 let result = str1_array
148 .iter()
149 .zip(str2_array.iter())
150 .map(|(string1, string2)| match (string1, string2) {
151 (Some(string1), Some(string2)) => {
152 Some(datafusion_strsim::levenshtein(string1, string2) as i64)
153 }
154 _ => None,
155 })
156 .collect::<Int64Array>();
157 Ok(Arc::new(result) as ArrayRef)
158 }
159 other => {
160 exec_err!(
161 "levenshtein was called with {other} datatype arguments. It requires Utf8View, Utf8 or LargeUtf8."
162 )
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use arrow::array::{LargeStringArray, StringArray, StringViewArray};

    use datafusion_common::cast::{as_int32_array, as_int64_array};

    use super::*;

    /// Utf8 inputs without nulls yield Int32 distances.
    #[test]
    fn to_levenshtein() -> Result<()> {
        let string1_array: ArrayRef =
            Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"]));
        let string2_array: ArrayRef =
            Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"]));
        // The test returns `Result`, so failures propagate with `?` and the
        // underlying error is reported rather than a panic.
        let res = levenshtein::<i32>(&[string1_array, string2_array])?;
        let result = as_int32_array(&res)?;
        let expected = Int32Array::from(vec![2, 3, 2, 3]);
        assert_eq!(&expected, result);

        Ok(())
    }

    /// A null on either side makes the output row null; an empty string is
    /// an ordinary value.
    #[test]
    fn levenshtein_propagates_nulls() -> Result<()> {
        let string1_array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("kitten"),
            None,
            Some(""),
            None,
        ]));
        let string2_array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("sitting"),
            Some("abc"),
            Some("abc"),
            None,
        ]));
        let res = levenshtein::<i32>(&[string1_array, string2_array])?;
        let result = as_int32_array(&res)?;
        let expected = Int32Array::from(vec![Some(3), None, Some(3), None]);
        assert_eq!(&expected, result);

        Ok(())
    }

    /// LargeUtf8 inputs yield Int64 distances.
    #[test]
    fn levenshtein_large_utf8() -> Result<()> {
        let string1_array: ArrayRef =
            Arc::new(LargeStringArray::from(vec!["kitten", "flaw"]));
        let string2_array: ArrayRef =
            Arc::new(LargeStringArray::from(vec!["sitting", "lawn"]));
        let res = levenshtein::<i64>(&[string1_array, string2_array])?;
        let result = as_int64_array(&res)?;
        let expected = Int64Array::from(vec![3, 2]);
        assert_eq!(&expected, result);

        Ok(())
    }

    /// Utf8View inputs yield Int32 distances.
    #[test]
    fn levenshtein_utf8_view() -> Result<()> {
        let string1_array: ArrayRef =
            Arc::new(StringViewArray::from(vec!["kitten", ""]));
        let string2_array: ArrayRef =
            Arc::new(StringViewArray::from(vec!["sitting", "abc"]));
        let res = levenshtein::<i32>(&[string1_array, string2_array])?;
        let result = as_int32_array(&res)?;
        let expected = Int32Array::from(vec![3, 3]);
        assert_eq!(&expected, result);

        Ok(())
    }
}