datafusion_functions/string/
levenshtein.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_int_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::utils::datafusion_strsim;
27use datafusion_common::{exec_err, utils::take_function_args, Result};
28use datafusion_expr::{ColumnarValue, Documentation};
29use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
30use datafusion_macros::user_doc;
31
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.",
    syntax_example = "levenshtein(str1, str2)",
    sql_example = r#"```sql
> select levenshtein('kitten', 'sitting');
+---------------------------------------------+
| levenshtein(Utf8("kitten"),Utf8("sitting")) |
+---------------------------------------------+
| 3                                           |
+---------------------------------------------+
```"#,
    argument(
        name = "str1",
        description = "String expression to compute Levenshtein distance with str2."
    ),
    argument(
        name = "str2",
        description = "String expression to compute Levenshtein distance with str1."
    )
)]
/// Scalar UDF implementing SQL `levenshtein(str1, str2)`.
///
/// The SQL-facing documentation is declared by the `#[user_doc]` attribute
/// above; the `doc()` accessor returned by `ScalarUDFImpl::documentation` is
/// presumably generated by that macro.
#[derive(Debug)]
pub struct LevenshteinFunc {
    /// Argument signature, built in `new` as
    /// `Signature::string(2, Volatility::Immutable)`.
    signature: Signature,
}
57
58impl Default for LevenshteinFunc {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl LevenshteinFunc {
65    pub fn new() -> Self {
66        Self {
67            signature: Signature::string(2, Volatility::Immutable),
68        }
69    }
70}
71
72impl ScalarUDFImpl for LevenshteinFunc {
73    fn as_any(&self) -> &dyn Any {
74        self
75    }
76
77    fn name(&self) -> &str {
78        "levenshtein"
79    }
80
81    fn signature(&self) -> &Signature {
82        &self.signature
83    }
84
85    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86        utf8_to_int_type(&arg_types[0], "levenshtein")
87    }
88
89    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90        match args.args[0].data_type() {
91            DataType::Utf8View | DataType::Utf8 => {
92                make_scalar_function(levenshtein::<i32>, vec![])(&args.args)
93            }
94            DataType::LargeUtf8 => {
95                make_scalar_function(levenshtein::<i64>, vec![])(&args.args)
96            }
97            other => {
98                exec_err!("Unsupported data type {other:?} for function levenshtein")
99            }
100        }
101    }
102
103    fn documentation(&self) -> Option<&Documentation> {
104        self.doc()
105    }
106}
107
108///Returns the Levenshtein distance between the two given strings.
109/// LEVENSHTEIN('kitten', 'sitting') = 3
110pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
111    let [str1, str2] = take_function_args("levenshtein", args)?;
112
113    match str1.data_type() {
114        DataType::Utf8View => {
115            let str1_array = as_string_view_array(&str1)?;
116            let str2_array = as_string_view_array(&str2)?;
117            let result = str1_array
118                .iter()
119                .zip(str2_array.iter())
120                .map(|(string1, string2)| match (string1, string2) {
121                    (Some(string1), Some(string2)) => {
122                        Some(datafusion_strsim::levenshtein(string1, string2) as i32)
123                    }
124                    _ => None,
125                })
126                .collect::<Int32Array>();
127            Ok(Arc::new(result) as ArrayRef)
128        }
129        DataType::Utf8 => {
130            let str1_array = as_generic_string_array::<T>(&str1)?;
131            let str2_array = as_generic_string_array::<T>(&str2)?;
132            let result = str1_array
133                .iter()
134                .zip(str2_array.iter())
135                .map(|(string1, string2)| match (string1, string2) {
136                    (Some(string1), Some(string2)) => {
137                        Some(datafusion_strsim::levenshtein(string1, string2) as i32)
138                    }
139                    _ => None,
140                })
141                .collect::<Int32Array>();
142            Ok(Arc::new(result) as ArrayRef)
143        }
144        DataType::LargeUtf8 => {
145            let str1_array = as_generic_string_array::<T>(&str1)?;
146            let str2_array = as_generic_string_array::<T>(&str2)?;
147            let result = str1_array
148                .iter()
149                .zip(str2_array.iter())
150                .map(|(string1, string2)| match (string1, string2) {
151                    (Some(string1), Some(string2)) => {
152                        Some(datafusion_strsim::levenshtein(string1, string2) as i64)
153                    }
154                    _ => None,
155                })
156                .collect::<Int64Array>();
157            Ok(Arc::new(result) as ArrayRef)
158        }
159        other => {
160            exec_err!(
161                "levenshtein was called with {other} datatype arguments. It requires Utf8View, Utf8 or LargeUtf8."
162            )
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use arrow::array::{LargeStringArray, StringArray, StringViewArray};

    use datafusion_common::cast::{as_int32_array, as_int64_array};

    use super::*;

    #[test]
    fn to_levenshtein() -> Result<()> {
        let string1_array =
            Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"]));
        let string2_array =
            Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"]));
        let res = levenshtein::<i32>(&[string1_array, string2_array])?;
        let result = as_int32_array(&res)?;
        let expected = Int32Array::from(vec![2, 3, 2, 3]);
        assert_eq!(&expected, result);

        Ok(())
    }

    /// Nulls on either side produce null rows; distances count characters,
    /// not bytes ("café" vs "cafe" differs by one edit).
    #[test]
    fn levenshtein_nulls_and_unicode() -> Result<()> {
        let lhs = Arc::new(StringArray::from(vec![
            Some("kitten"),
            None,
            Some("café"),
            Some(""),
        ]));
        let rhs = Arc::new(StringArray::from(vec![
            None,
            Some("abc"),
            Some("cafe"),
            Some("abc"),
        ]));
        let res = levenshtein::<i32>(&[lhs, rhs])?;
        let expected = Int32Array::from(vec![None, None, Some(1), Some(3)]);
        assert_eq!(&expected, as_int32_array(&res)?);

        Ok(())
    }

    /// `LargeUtf8` input produces an `Int64Array`.
    #[test]
    fn levenshtein_large_utf8() -> Result<()> {
        let lhs = Arc::new(LargeStringArray::from(vec!["kitten", "flaw"]));
        let rhs = Arc::new(LargeStringArray::from(vec!["sitting", "lawn"]));
        let res = levenshtein::<i64>(&[lhs, rhs])?;
        let expected = Int64Array::from(vec![3, 2]);
        assert_eq!(&expected, as_int64_array(&res)?);

        Ok(())
    }

    /// `Utf8View` input produces an `Int32Array`.
    #[test]
    fn levenshtein_utf8_view() -> Result<()> {
        let lhs = Arc::new(StringViewArray::from(vec![
            "kitten",
            "a value longer than twelve bytes",
        ]));
        let rhs = Arc::new(StringViewArray::from(vec![
            "sitting",
            "a value longer than twelve bytes",
        ]));
        let res = levenshtein::<i32>(&[lhs, rhs])?;
        let expected = Int32Array::from(vec![3, 0]);
        assert_eq!(&expected, as_int32_array(&res)?);

        Ok(())
    }
}