1use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator};
12use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20use snafu::location;
21fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
23 match expr {
24 Expr::Literal(scalar_value) => {
25 Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(
26 format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
27 location!(),
28 ))?))
29 }
30 _ => Err(Error::invalid_input(
31 format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
32 location!(),
33 )),
34 }
35}
36
37pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
40 match expr {
41 Expr::Literal(ScalarValue::Utf8(Some(s))) => Some(s),
42 _ => None,
43 }
44}
45
46pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
52 let mut field_path = Vec::new();
53 let mut current_expr = expr;
54 loop {
56 match current_expr {
57 Expr::Column(c) => {
58 field_path.push(c.name.as_str());
59 break;
60 }
61 Expr::ScalarFunction(udf) => {
62 if udf.name() == GetFieldFunc::default().name() {
63 let name = get_as_string_scalar_opt(&udf.args[1])?;
64 field_path.push(name);
65 current_expr = &udf.args[0];
66 } else {
67 return None;
68 }
69 }
70 _ => return None,
71 }
72 }
73
74 let mut path_iter = field_path.iter().rev();
75 let mut field = schema.field(path_iter.next()?)?;
76 for name in path_iter {
77 if field.data_type().is_struct() {
78 field = field.children.iter().find(|f| &f.name == name)?;
79 } else {
80 return None;
81 }
82 }
83 Some(field.data_type())
84}
85
86pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
93 match expr {
94 Expr::Between(Between {
95 expr: inner_expr,
96 low,
97 high,
98 negated,
99 }) => {
100 if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
101 Ok(Expr::Between(Between {
102 expr: inner_expr.clone(),
103 low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
104 high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
105 negated: *negated,
106 }))
107 } else {
108 Ok(expr.clone())
109 }
110 }
111 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
112 if matches!(op, Operator::And | Operator::Or) {
113 Ok(Expr::BinaryExpr(BinaryExpr {
114 left: Box::new(resolve_expr(left.as_ref(), schema)?),
115 op: *op,
116 right: Box::new(resolve_expr(right.as_ref(), schema)?),
117 }))
118 } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
119 match right.as_ref() {
120 Expr::Literal(_) => Ok(Expr::BinaryExpr(BinaryExpr {
121 left: left.clone(),
122 op: *op,
123 right: Box::new(resolve_value(right.as_ref(), &left_type)?),
124 })),
125 Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
127 left: left.clone(),
128 op: *op,
129 right: Box::new(Expr::BinaryExpr(BinaryExpr {
130 left: coerce_expr(&r.left, &left_type).map(Box::new)?,
131 op: r.op,
132 right: coerce_expr(&r.right, &left_type).map(Box::new)?,
133 })),
134 })),
135 _ => Ok(expr.clone()),
136 }
137 } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
138 match left.as_ref() {
139 Expr::Literal(_) => Ok(Expr::BinaryExpr(BinaryExpr {
140 left: Box::new(resolve_value(left.as_ref(), &right_type)?),
141 op: *op,
142 right: right.clone(),
143 })),
144 _ => Ok(expr.clone()),
145 }
146 } else {
147 Ok(expr.clone())
148 }
149 }
150 Expr::InList(in_list) => {
151 if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
152 if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
153 let resolved_values = in_list
154 .list
155 .iter()
156 .map(|val| coerce_expr(val, &resolved_type))
157 .collect::<Result<Vec<_>>>()?;
158 Ok(Expr::in_list(
159 in_list.expr.as_ref().clone(),
160 resolved_values,
161 in_list.negated,
162 ))
163 } else {
164 Ok(expr.clone())
165 }
166 } else {
167 Ok(expr.clone())
168 }
169 }
170 _ => {
171 Ok(expr.clone())
173 }
174 }
175}
176
177pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
184 match expr {
185 Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
186 left: Box::new(coerce_expr(left, dtype)?),
187 op: *op,
188 right: Box::new(coerce_expr(right, dtype)?),
189 })),
190 Expr::Literal(l) => Ok(resolve_value(&Expr::Literal(l.clone()), dtype)?),
191 _ => Ok(expr.clone()),
192 }
193}
194
195pub fn coerce_filter_type_to_boolean(expr: Expr) -> Result<Expr> {
201 match &expr {
202 Expr::ScalarFunction(ScalarFunction { func, .. }) => {
205 if func.name() == "regexp_match" {
206 Ok(Expr::IsNotNull(Box::new(expr)))
207 } else {
208 Ok(expr)
209 }
210 }
211 _ => Ok(expr),
212 }
213}
214
/// Extension methods for building DataFusion [`Expr`] values.
pub trait ExprExt {
    /// Returns an expression that accesses the field called `name` on `self`.
    ///
    /// The implementation for [`Expr`] in this file expresses the access as a
    /// call to the `get_field` scalar function ([`GetFieldFunc`]), with `self`
    /// and a `Utf8` literal holding `name` as its two arguments.
    fn field_newstyle(&self, name: &str) -> Expr;
}
230
231impl ExprExt for Expr {
232 fn field_newstyle(&self, name: &str) -> Expr {
233 Self::ScalarFunction(ScalarFunction {
234 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
235 args: vec![
236 self.clone(),
237 Self::Literal(ScalarValue::Utf8(Some(name.to_string()))),
238 ],
239 })
240 }
241}
242
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Lance schema with a single non-nullable column `a` of the given type.
    fn schema_with_a(data_type: DataType) -> Schema {
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", data_type, false)]);
        Schema::try_from(&arrow_schema).unwrap()
    }

    /// Reference to column `a`.
    fn column_a() -> Expr {
        Expr::Column("a".to_string().into())
    }

    /// A `Utf8` literal compared against a `LargeUtf8` column becomes `LargeUtf8`.
    #[test]
    fn test_resolve_large_utf8() {
        let schema = schema_with_a(DataType::LargeUtf8);
        let filter = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let resolved = resolve_expr(&filter, &schema).unwrap();
        if let Expr::BinaryExpr(be) = resolved {
            assert_eq!(
                be.right.as_ref(),
                &Expr::Literal(ScalarValue::LargeUtf8(Some("a".to_string())))
            );
        } else {
            unreachable!("Expected BinaryExpr");
        }
    }

    /// Literals inside a binary expression on the right-hand side are coerced
    /// to the type of the column on the left.
    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = schema_with_a(DataType::Float64);
        let int_lit = |v: i64| Expr::Literal(ScalarValue::Int64(Some(v)));
        let float_lit = |v: f64| Expr::Literal(ScalarValue::Float64(Some(v)));

        let arithmetic = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(int_lit(2)),
            op: Operator::Minus,
            right: Box::new(int_lit(-1)),
        });
        let filter = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(arithmetic),
        });

        let resolved = resolve_expr(&filter, &schema).unwrap();

        let outer = match resolved {
            Expr::BinaryExpr(be) => be,
            _ => panic!("Expected BinaryExpr"),
        };
        let inner = match outer.right.as_ref() {
            Expr::BinaryExpr(be) => be,
            _ => panic!("Expected BinaryExpr"),
        };
        assert_eq!(inner.left.as_ref(), &float_lit(2.0));
        assert_eq!(inner.right.as_ref(), &float_lit(-1.0));
    }

    /// IN-list values are coerced to the column type for both `IN` and `NOT IN`.
    #[test]
    fn test_resolve_in_expr() {
        let schema = schema_with_a(DataType::Float32);

        for negated in [false, true] {
            let filter = Expr::in_list(
                column_a(),
                vec![Expr::Literal(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                column_a(),
                vec![Expr::Literal(ScalarValue::Float32(Some(0.0)))],
                negated,
            );

            let resolved = resolve_expr(&filter, &schema).unwrap();
            assert_eq!(resolved, expected);
        }
    }

    /// Column types resolve for top-level and nested struct fields, and yield
    /// `None` for unknown names and non-column expressions.
    #[test]
    fn test_resolve_column_type() {
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let cases = [
            // Resolvable paths.
            (col("int"), Some(DataType::Int32)),
            (col("st").field("str"), Some(DataType::Utf8)),
            (
                col("st").field("st").field("float"),
                Some(DataType::Float64),
            ),
            // Unknown top-level names, including names that only exist nested.
            (col("x"), None),
            (col("str"), None),
            (col("float"), None),
            // Not a column reference at all.
            (col("st").field("str").eq(lit("x")), None),
        ];

        for (expr, expected) in cases {
            assert_eq!(resolve_column_type(&expr, &schema), expected);
        }
    }
}