lance_datafusion/
logical_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extends logical expression.
5
6use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator};
12use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20use snafu::location;
21/// Resolve a Value
22fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
23    match expr {
24        Expr::Literal(scalar_value) => {
25            Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(
26                format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
27                location!(),
28            ))?))
29        }
30        _ => Err(Error::invalid_input(
31            format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
32            location!(),
33        )),
34    }
35}
36
37/// A simple helper function that interprets an Expr as a string scalar
38/// or returns None if it is not.
39pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
40    match expr {
41        Expr::Literal(ScalarValue::Utf8(Some(s))) => Some(s),
42        _ => None,
43    }
44}
45
46/// Given a Expr::Column or Expr::GetIndexedField, get the data type of referenced
47/// field in the schema.
48///
49/// If the column is not found in the schema, return None. If the expression is
50/// not a field reference, also returns None.
51pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
52    let mut field_path = Vec::new();
53    let mut current_expr = expr;
54    // We are looping from outer-most reference to inner-most.
55    loop {
56        match current_expr {
57            Expr::Column(c) => {
58                field_path.push(c.name.as_str());
59                break;
60            }
61            Expr::ScalarFunction(udf) => {
62                if udf.name() == GetFieldFunc::default().name() {
63                    let name = get_as_string_scalar_opt(&udf.args[1])?;
64                    field_path.push(name);
65                    current_expr = &udf.args[0];
66                } else {
67                    return None;
68                }
69            }
70            _ => return None,
71        }
72    }
73
74    let mut path_iter = field_path.iter().rev();
75    let mut field = schema.field(path_iter.next()?)?;
76    for name in path_iter {
77        if field.data_type().is_struct() {
78            field = field.children.iter().find(|f| &f.name == name)?;
79        } else {
80            return None;
81        }
82    }
83    Some(field.data_type())
84}
85
86/// Resolve logical expression `expr`.
87///
88/// Parameters
89///
90/// - *expr*: a datafusion logical expression
91/// - *schema*: lance schema.
92pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
93    match expr {
94        Expr::Between(Between {
95            expr: inner_expr,
96            low,
97            high,
98            negated,
99        }) => {
100            if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
101                Ok(Expr::Between(Between {
102                    expr: inner_expr.clone(),
103                    low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
104                    high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
105                    negated: *negated,
106                }))
107            } else {
108                Ok(expr.clone())
109            }
110        }
111        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
112            if matches!(op, Operator::And | Operator::Or) {
113                Ok(Expr::BinaryExpr(BinaryExpr {
114                    left: Box::new(resolve_expr(left.as_ref(), schema)?),
115                    op: *op,
116                    right: Box::new(resolve_expr(right.as_ref(), schema)?),
117                }))
118            } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
119                match right.as_ref() {
120                    Expr::Literal(_) => Ok(Expr::BinaryExpr(BinaryExpr {
121                        left: left.clone(),
122                        op: *op,
123                        right: Box::new(resolve_value(right.as_ref(), &left_type)?),
124                    })),
125                    // For cases complex expressions (not just literals) on right hand side like x = 1 + 1 + -2*2
126                    Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
127                        left: left.clone(),
128                        op: *op,
129                        right: Box::new(Expr::BinaryExpr(BinaryExpr {
130                            left: coerce_expr(&r.left, &left_type).map(Box::new)?,
131                            op: r.op,
132                            right: coerce_expr(&r.right, &left_type).map(Box::new)?,
133                        })),
134                    })),
135                    _ => Ok(expr.clone()),
136                }
137            } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
138                match left.as_ref() {
139                    Expr::Literal(_) => Ok(Expr::BinaryExpr(BinaryExpr {
140                        left: Box::new(resolve_value(left.as_ref(), &right_type)?),
141                        op: *op,
142                        right: right.clone(),
143                    })),
144                    _ => Ok(expr.clone()),
145                }
146            } else {
147                Ok(expr.clone())
148            }
149        }
150        Expr::InList(in_list) => {
151            if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
152                if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
153                    let resolved_values = in_list
154                        .list
155                        .iter()
156                        .map(|val| coerce_expr(val, &resolved_type))
157                        .collect::<Result<Vec<_>>>()?;
158                    Ok(Expr::in_list(
159                        in_list.expr.as_ref().clone(),
160                        resolved_values,
161                        in_list.negated,
162                    ))
163                } else {
164                    Ok(expr.clone())
165                }
166            } else {
167                Ok(expr.clone())
168            }
169        }
170        _ => {
171            // Passthrough
172            Ok(expr.clone())
173        }
174    }
175}
176
177/// Coerce expression of literals to column type.
178///
179/// Parameters
180///
181/// - *expr*: a datafusion logical expression
182/// - *dtype*: a lance data type
183pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
184    match expr {
185        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
186            left: Box::new(coerce_expr(left, dtype)?),
187            op: *op,
188            right: Box::new(coerce_expr(right, dtype)?),
189        })),
190        Expr::Literal(l) => Ok(resolve_value(&Expr::Literal(l.clone()), dtype)?),
191        _ => Ok(expr.clone()),
192    }
193}
194
195/// Coerce logical expression for filters to boolean.
196///
197/// Parameters
198///
199/// - *expr*: a datafusion logical expression
200pub fn coerce_filter_type_to_boolean(expr: Expr) -> Result<Expr> {
201    match &expr {
202        // TODO: consider making this dispatch more generic, i.e. fun.output_type -> coerce
203        // instead of hardcoding coerce method for each function
204        Expr::ScalarFunction(ScalarFunction { func, .. }) => {
205            if func.name() == "regexp_match" {
206                Ok(Expr::IsNotNull(Box::new(expr)))
207            } else {
208                Ok(expr)
209            }
210        }
211        _ => Ok(expr),
212    }
213}
214
215// As part of the DF 37 release there are now two different ways to
216// represent a nested field access in `Expr`.  The old way is to use
217// `Expr::field` which returns a `GetStructField` and the new way is
218// to use `Expr::ScalarFunction` with a `GetFieldFunc` UDF.
219//
220// Currently, the old path leads to bugs in DF.  This is probably a
221// bug and will probably be fixed in a future version.  In the meantime
222// we need to make sure we are always using the new way to avoid this
223// bug.  This trait adds field_newstyle which lets us easily create
224// logical `Expr` that use the new style.
225pub trait ExprExt {
226    // Helper function to replace Expr::field in DF 37 since DF
227    // confuses itself with the GetStructField returned by Expr::field
228    fn field_newstyle(&self, name: &str) -> Expr;
229}
230
231impl ExprExt for Expr {
232    fn field_newstyle(&self, name: &str) -> Expr {
233        Self::ScalarFunction(ScalarFunction {
234            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
235            args: vec![
236                self.clone(),
237                Self::Literal(ScalarValue::Utf8(Some(name.to_string()))),
238            ],
239        })
240    }
241}
242
#[cfg(test)]
pub mod tests {
    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Build a Lance schema from a list of Arrow fields.
    fn lance_schema(fields: Vec<Field>) -> Schema {
        Schema::try_from(&ArrowSchema::new(fields)).unwrap()
    }

    /// Column reference to the field `a` used by most tests.
    fn column_a() -> Expr {
        Expr::Column("a".to_string().into())
    }

    #[test]
    fn test_resolve_large_utf8() {
        let schema = lance_schema(vec![Field::new("a", DataType::LargeUtf8, false)]);
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let resolved = resolve_expr(&expr, &schema).unwrap();
        let Expr::BinaryExpr(be) = resolved else {
            unreachable!("Expected BinaryExpr")
        };
        // The Utf8 literal must have been widened to the column's LargeUtf8 type.
        assert_eq!(
            be.right.as_ref(),
            &Expr::Literal(ScalarValue::LargeUtf8(Some("a".to_string())))
        );
    }

    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float64, false)]);
        let arithmetic = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Literal(ScalarValue::Int64(Some(2)))),
            op: Operator::Minus,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(-1)))),
        });
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(arithmetic),
        });

        let resolved = resolve_expr(&expr, &schema).unwrap();
        let Expr::BinaryExpr(be) = resolved else {
            panic!("Expected BinaryExpr")
        };
        let Expr::BinaryExpr(r_be) = be.right.as_ref() else {
            panic!("Expected BinaryExpr")
        };
        // Both integer operands must have been coerced to Float64.
        assert_eq!(
            r_be.left.as_ref(),
            &Expr::Literal(ScalarValue::Float64(Some(2.0)))
        );
        assert_eq!(
            r_be.right.as_ref(),
            &Expr::Literal(ScalarValue::Float64(Some(-1.0)))
        );
    }

    #[test]
    fn test_resolve_in_expr() {
        // Type coercion should apply for `A IN (0)` or `A NOT IN (0)`
        let schema = lance_schema(vec![Field::new("a", DataType::Float32, false)]);

        for negated in [false, true] {
            let expr = Expr::in_list(
                column_a(),
                vec![Expr::Literal(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                column_a(),
                vec![Expr::Literal(ScalarValue::Float32(Some(0.0)))],
                negated,
            );

            let resolved = resolve_expr(&expr, &schema).unwrap();
            assert_eq!(resolved, expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        // Schema layout: int: Int32, st: { str: Utf8, st: { float: Float64 } }
        let inner_struct = Field::new(
            "st",
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into()),
            true,
        );
        let outer_struct = Field::new(
            "st",
            DataType::Struct(vec![Field::new("str", DataType::Utf8, true), inner_struct].into()),
            true,
        );
        let schema = lance_schema(vec![
            Field::new("int", DataType::Int32, true),
            outer_struct,
        ]);

        // Top-level and nested references resolve to the leaf type.
        assert_eq!(
            resolve_column_type(&col("int"), &schema),
            Some(DataType::Int32)
        );
        assert_eq!(
            resolve_column_type(&col("st").field("str"), &schema),
            Some(DataType::Utf8)
        );
        assert_eq!(
            resolve_column_type(&col("st").field("st").field("float"), &schema),
            Some(DataType::Float64)
        );

        // Unknown columns, nested names used at the top level, and
        // non-reference expressions all resolve to None.
        assert_eq!(resolve_column_type(&col("x"), &schema), None);
        assert_eq!(resolve_column_type(&col("str"), &schema), None);
        assert_eq!(resolve_column_type(&col("float"), &schema), None);
        assert_eq!(
            resolve_column_type(&col("st").field("str").eq(lit("x")), &schema),
            None
        );
    }
}