// lance_datafusion/planner.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6//! Exec plan planner
7
8use std::borrow::Cow;
9use std::collections::{BTreeSet, VecDeque};
10use std::sync::Arc;
11
12use crate::expr::safe_coerce_scalar;
13use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
14use crate::sql::{parse_sql_expr, parse_sql_filter};
15use arrow::compute::CastOptions;
16use arrow_array::ListArray;
17use arrow_buffer::OffsetBuffer;
18use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
19use arrow_select::concat::concat;
20use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
21use datafusion::common::DFSchema;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DFResult;
24use datafusion::execution::config::SessionConfig;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::runtime_env::RuntimeEnvBuilder;
27use datafusion::execution::session_state::SessionStateBuilder;
28use datafusion::logical_expr::expr::ScalarFunction;
29use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
30use datafusion::logical_expr::{
31    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
32    WindowUDF,
33};
34use datafusion::optimizer::simplify_expressions::SimplifyContext;
35use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
36use datafusion::sql::sqlparser::ast::{
37    Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo, Expr as SQLExpr,
38    Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript, TimezoneInfo,
39    UnaryOperator, Value,
40};
41use datafusion::{
42    common::Column,
43    logical_expr::{col, Between, BinaryExpr, Like, Operator},
44    physical_expr::execution_props::ExecutionProps,
45    physical_plan::PhysicalExpr,
46    prelude::Expr,
47    scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use snafu::location;
53
54use lance_core::{Error, Result};
55
56#[derive(Debug, Clone)]
57struct CastListF16Udf {
58    signature: Signature,
59}
60
61impl CastListF16Udf {
62    pub fn new() -> Self {
63        Self {
64            signature: Signature::any(1, Volatility::Immutable),
65        }
66    }
67}
68
69impl ScalarUDFImpl for CastListF16Udf {
70    fn as_any(&self) -> &dyn std::any::Any {
71        self
72    }
73
74    fn name(&self) -> &str {
75        "_cast_list_f16"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
83        let input = &arg_types[0];
84        match input {
85            ArrowDataType::FixedSizeList(field, size) => {
86                if field.data_type() != &ArrowDataType::Float32
87                    && field.data_type() != &ArrowDataType::Float16
88                {
89                    return Err(datafusion::error::DataFusionError::Execution(
90                        "cast_list_f16 only supports list of float32 or float16".to_string(),
91                    ));
92                }
93                Ok(ArrowDataType::FixedSizeList(
94                    Arc::new(Field::new(
95                        field.name(),
96                        ArrowDataType::Float16,
97                        field.is_nullable(),
98                    )),
99                    *size,
100                ))
101            }
102            ArrowDataType::List(field) => {
103                if field.data_type() != &ArrowDataType::Float32
104                    && field.data_type() != &ArrowDataType::Float16
105                {
106                    return Err(datafusion::error::DataFusionError::Execution(
107                        "cast_list_f16 only supports list of float32 or float16".to_string(),
108                    ));
109                }
110                Ok(ArrowDataType::List(Arc::new(Field::new(
111                    field.name(),
112                    ArrowDataType::Float16,
113                    field.is_nullable(),
114                ))))
115            }
116            _ => Err(datafusion::error::DataFusionError::Execution(
117                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
118            )),
119        }
120    }
121
122    fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
123        let ColumnarValue::Array(arr) = &args[0] else {
124            return Err(datafusion::error::DataFusionError::Execution(
125                "cast_list_f16 only supports array arguments".to_string(),
126            ));
127        };
128
129        let to_type = match arr.data_type() {
130            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
131                Arc::new(Field::new(
132                    field.name(),
133                    ArrowDataType::Float16,
134                    field.is_nullable(),
135                )),
136                *size,
137            ),
138            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
139                field.name(),
140                ArrowDataType::Float16,
141                field.is_nullable(),
142            ))),
143            _ => {
144                return Err(datafusion::error::DataFusionError::Execution(
145                    "cast_list_f16 only supports array arguments".to_string(),
146                ));
147            }
148        };
149
150        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
151        Ok(ColumnarValue::Array(res))
152    }
153}
154
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Planner configuration handed back from `options()` (DataFusion defaults).
    options: datafusion::config::ConfigOptions,
    // Fully-built session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners captured from the session builder, since
    // `SessionState` does not expose them directly (see `Default` impl).
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
161
162impl Default for LanceContextProvider {
163    fn default() -> Self {
164        let config = SessionConfig::new();
165        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
166        let mut state_builder = SessionStateBuilder::new()
167            .with_config(config)
168            .with_runtime_env(runtime)
169            .with_default_features();
170
171        // SessionState does not expose expr_planners, so we need to get the default ones from
172        // the builder and store them to return from get_expr_planners
173
174        // unwrap safe because with_default_features sets expr_planners
175        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
176
177        Self {
178            options: ConfigOptions::default(),
179            state: state_builder.build(),
180            expr_planners,
181        }
182    }
183}
184
185impl ContextProvider for LanceContextProvider {
186    fn get_table_source(
187        &self,
188        name: datafusion::sql::TableReference,
189    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
190        Err(datafusion::error::DataFusionError::NotImplemented(format!(
191            "Attempt to reference inner table {} not supported",
192            name
193        )))
194    }
195
196    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
197        self.state.aggregate_functions().get(name).cloned()
198    }
199
200    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
201        self.state.window_functions().get(name).cloned()
202    }
203
204    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
205        match f {
206            // TODO: cast should go thru CAST syntax instead of UDF
207            // Going thru UDF makes it hard for the optimizer to find no-ops
208            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
209            _ => self.state.scalar_functions().get(f).cloned(),
210        }
211    }
212
213    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
214        // Variables (things like @@LANGUAGE) not supported
215        None
216    }
217
218    fn options(&self) -> &datafusion::config::ConfigOptions {
219        &self.options
220    }
221
222    fn udf_names(&self) -> Vec<String> {
223        self.state.scalar_functions().keys().cloned().collect()
224    }
225
226    fn udaf_names(&self) -> Vec<String> {
227        self.state.aggregate_functions().keys().cloned().collect()
228    }
229
230    fn udwf_names(&self) -> Vec<String> {
231        self.state.window_functions().keys().cloned().collect()
232    }
233
234    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
235        &self.expr_planners
236    }
237}
238
/// Converts SQL filter/expression strings into DataFusion logical expressions
/// resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Supplies function lookup and expression planners to the SQL planner.
    context_provider: LanceContextProvider,
}
243
244impl Planner {
245    pub fn new(schema: SchemaRef) -> Self {
246        Self {
247            schema,
248            context_provider: LanceContextProvider::default(),
249        }
250    }
251
252    fn column(idents: &[Ident]) -> Expr {
253        let mut column = col(&idents[0].value);
254        for ident in &idents[1..] {
255            column = Expr::ScalarFunction(ScalarFunction {
256                args: vec![
257                    column,
258                    Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))),
259                ],
260                func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
261            });
262        }
263        column
264    }
265
266    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
267        Ok(match op {
268            BinaryOperator::Plus => Operator::Plus,
269            BinaryOperator::Minus => Operator::Minus,
270            BinaryOperator::Multiply => Operator::Multiply,
271            BinaryOperator::Divide => Operator::Divide,
272            BinaryOperator::Modulo => Operator::Modulo,
273            BinaryOperator::StringConcat => Operator::StringConcat,
274            BinaryOperator::Gt => Operator::Gt,
275            BinaryOperator::Lt => Operator::Lt,
276            BinaryOperator::GtEq => Operator::GtEq,
277            BinaryOperator::LtEq => Operator::LtEq,
278            BinaryOperator::Eq => Operator::Eq,
279            BinaryOperator::NotEq => Operator::NotEq,
280            BinaryOperator::And => Operator::And,
281            BinaryOperator::Or => Operator::Or,
282            _ => {
283                return Err(Error::invalid_input(
284                    format!("Operator {op} is not supported"),
285                    location!(),
286                ));
287            }
288        })
289    }
290
291    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
292        Ok(Expr::BinaryExpr(BinaryExpr::new(
293            Box::new(self.parse_sql_expr(left)?),
294            self.binary_op(op)?,
295            Box::new(self.parse_sql_expr(right)?),
296        )))
297    }
298
    /// Parses a unary operator application (NOT / unary minus).
    ///
    /// A minus applied directly to a numeric literal is folded into a negative
    /// literal here (first as i64, falling back to f64) rather than emitting a
    /// runtime `Negative` node; any other operand gets `Expr::Negative`.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // `n` is the literal's digits without the sign; negate after parsing.
                    SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        // Not an i64 (too large or has a fraction): try f64.
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    // Non-literal operand: leave negation to the engine.
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
333
334    // See datafusion `sqlToRel::parse_sql_number()`
335    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
336        use datafusion::logical_expr::lit;
337        let value: Cow<str> = if negative {
338            Cow::Owned(format!("-{}", value))
339        } else {
340            Cow::Borrowed(value)
341        };
342        if let Ok(n) = value.parse::<i64>() {
343            Ok(lit(n))
344        } else {
345            value.parse::<f64>().map(lit).map_err(|_| {
346                Error::invalid_input(
347                    format!("'{value}' is not supported number value."),
348                    location!(),
349                )
350            })
351        }
352    }
353
354    fn value(&self, value: &Value) -> Result<Expr> {
355        Ok(match value {
356            Value::Number(v, _) => self.number(v.as_str(), false)?,
357            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
358            Value::HexStringLiteral(hsl) => {
359                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)))
360            }
361            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
362            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
363            Value::Null => Expr::Literal(ScalarValue::Null),
364            _ => todo!(),
365        })
366    }
367
368    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
369        match func_args {
370            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
371            _ => Err(Error::invalid_input(
372                format!("Unsupported function args: {:?}", func_args),
373                location!(),
374            )),
375        }
376    }
377
378    // We now use datafusion to parse functions.  This allows us to use datafusion's
379    // entire collection of functions (previously we had just hard-coded support for two functions).
380    //
381    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
382    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
383    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
384    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
385        match &func.args {
386            FunctionArguments::List(args) => {
387                if func.name.0.len() != 1 {
388                    return Err(Error::invalid_input(
389                        format!("Function name must have 1 part, got: {:?}", func.name.0),
390                        location!(),
391                    ));
392                }
393                Ok(Expr::IsNotNull(Box::new(
394                    self.parse_function_args(&args.args[0])?,
395                )))
396            }
397            _ => Err(Error::invalid_input(
398                format!("Unsupported function args: {:?}", &func.args),
399                location!(),
400            )),
401        }
402    }
403
404    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
405        if let SQLExpr::Function(function) = &function {
406            if !function.name.0.is_empty() && function.name.0[0].value == "is_valid" {
407                return self.legacy_parse_function(function);
408            }
409        }
410        let sql_to_rel = SqlToRel::new_with_options(
411            &self.context_provider,
412            ParserOptions {
413                parse_float_as_decimal: false,
414                enable_ident_normalization: false,
415                support_varchar_with_length: false,
416                enable_options_value_normalization: false,
417            },
418        );
419
420        let mut planner_context = PlannerContext::default();
421        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
422        Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)
423    }
424
425    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
426        const SUPPORTED_TYPES: [&str; 13] = [
427            "int [unsigned]",
428            "tinyint [unsigned]",
429            "smallint [unsigned]",
430            "bigint [unsigned]",
431            "float",
432            "double",
433            "string",
434            "binary",
435            "date",
436            "timestamp(precision)",
437            "datetime(precision)",
438            "decimal(precision,scale)",
439            "boolean",
440        ];
441        match data_type {
442            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
443            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
444            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
445            SQLDataType::Double => Ok(ArrowDataType::Float64),
446            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
447            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
448            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
449            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
450            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
451            SQLDataType::UnsignedTinyInt(_) => Ok(ArrowDataType::UInt8),
452            SQLDataType::UnsignedSmallInt(_) => Ok(ArrowDataType::UInt16),
453            SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
454                Ok(ArrowDataType::UInt32)
455            }
456            SQLDataType::UnsignedBigInt(_) => Ok(ArrowDataType::UInt64),
457            SQLDataType::Date => Ok(ArrowDataType::Date32),
458            SQLDataType::Timestamp(resolution, tz) => {
459                match tz {
460                    TimezoneInfo::None => {}
461                    _ => {
462                        return Err(Error::invalid_input(
463                            "Timezone not supported in timestamp".to_string(),
464                            location!(),
465                        ));
466                    }
467                };
468                let time_unit = match resolution {
469                    // Default to microsecond to match PyArrow
470                    None => TimeUnit::Microsecond,
471                    Some(0) => TimeUnit::Second,
472                    Some(3) => TimeUnit::Millisecond,
473                    Some(6) => TimeUnit::Microsecond,
474                    Some(9) => TimeUnit::Nanosecond,
475                    _ => {
476                        return Err(Error::invalid_input(
477                            format!("Unsupported datetime resolution: {:?}", resolution),
478                            location!(),
479                        ));
480                    }
481                };
482                Ok(ArrowDataType::Timestamp(time_unit, None))
483            }
484            SQLDataType::Datetime(resolution) => {
485                let time_unit = match resolution {
486                    None => TimeUnit::Microsecond,
487                    Some(0) => TimeUnit::Second,
488                    Some(3) => TimeUnit::Millisecond,
489                    Some(6) => TimeUnit::Microsecond,
490                    Some(9) => TimeUnit::Nanosecond,
491                    _ => {
492                        return Err(Error::invalid_input(
493                            format!("Unsupported datetime resolution: {:?}", resolution),
494                            location!(),
495                        ));
496                    }
497                };
498                Ok(ArrowDataType::Timestamp(time_unit, None))
499            }
500            SQLDataType::Decimal(number_info) => match number_info {
501                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
502                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
503                }
504                _ => Err(Error::invalid_input(
505                    format!(
506                        "Must provide precision and scale for decimal: {:?}",
507                        number_info
508                    ),
509                    location!(),
510                )),
511            },
512            _ => Err(Error::invalid_input(
513                format!(
514                    "Unsupported data type: {:?}. Supported types: {:?}",
515                    data_type, SUPPORTED_TYPES
516                ),
517                location!(),
518            )),
519        }
520    }
521
522    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
523        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
524        for planner in self.context_provider.get_expr_planners() {
525            match planner.plan_field_access(field_access_expr, &df_schema)? {
526                PlannerResult::Planned(expr) => return Ok(expr),
527                PlannerResult::Original(expr) => {
528                    field_access_expr = expr;
529                }
530            }
531        }
532        Err(Error::invalid_input(
533            "Field access could not be planned",
534            location!(),
535        ))
536    }
537
    /// Recursively converts a sqlparser AST expression into a DataFusion
    /// logical [Expr]. Unsupported constructs produce `Error::invalid_input`.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Users can pass string literals wrapped in `"`.
                // (Normally SQL only allows single quotes.)
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
                // Users can wrap identifiers with ` to reference non-standard
                // names, such as uppercase or spaces.
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            // `a.b.c` becomes nested get_field calls (see `Self::column`).
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(value),
            // Array literal, e.g. `[1, 2, 3]`. Elements must be literals (an
            // optional leading minus is allowed) and are coerced to the type
            // of the first element.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value) = self.value(value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literal, e.g. `-1` in `[-1, 2]`.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(Value::Number(number, _)) = expr.as_ref() {
                                if let Expr::Literal(value) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Element type comes from the first value; remaining values are
                // coerced to it. An empty array gets a Null element type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars into a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values))))
            }
            // For example, DATE '2020-01-01'
            SQLExpr::TypedString { data_type, value } => {
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value.clone())))),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            // `expr [NOT] IN (a, b, ...)`
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expression: just unwrap it.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE is LIKE with case-insensitivity (final `true` argument).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            // `col['key'][0]`-style access: string keys become struct-field
            // access, any other key is treated as a list index.
            SQLExpr::MapAccess { column, keys } => {
                let mut expr = self.parse_sql_expr(column)?;

                for key in keys {
                    let field_access = match &key.key {
                        SQLExpr::Value(
                            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                        ) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        SQLExpr::JsonAccess { .. } => {
                            return Err(Error::invalid_input(
                                "JSON access is not supported",
                                location!(),
                            ));
                        }
                        key => {
                            let key = Box::new(self.parse_sql_expr(key)?);
                            GetFieldAccess::ListIndex { key }
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };

                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            // `expr[subscript]`: same string-key/list-index rules as MapAccess;
            // slice subscripts (`a[1:2]`) are rejected.
            SQLExpr::Subscript { expr, subscript } => {
                let expr = self.parse_sql_expr(expr)?;

                let field_access = match subscript.as_ref() {
                    Subscript::Index { index } => match index {
                        SQLExpr::Value(
                            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                        ) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        SQLExpr::JsonAccess { .. } => {
                            return Err(Error::invalid_input(
                                "JSON access is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                    },
                    Subscript::Slice { .. } => {
                        return Err(Error::invalid_input(
                            "Slice subscript is not supported",
                            location!(),
                        ));
                    }
                };

                let field_access_expr = RawFieldAccessExpr { expr, field_access };
                self.plan_field_access(field_access_expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                // Parse the main expression and bounds
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
775
776    /// Create Logical [Expr] from a SQL filter clause.
777    ///
778    /// Note: the returned expression must be passed through [optimize_expr()]
779    /// before being passed to [create_physical_expr()].
780    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
781        // Allow sqlparser to parse filter as part of ONE SQL statement.
782        let ast_expr = parse_sql_filter(filter)?;
783        let expr = self.parse_sql_expr(&ast_expr)?;
784        let schema = Schema::try_from(self.schema.as_ref())?;
785        let resolved = resolve_expr(&expr, &schema)?;
786        coerce_filter_type_to_boolean(resolved)
787    }
788
789    /// Create Logical [Expr] from a SQL expression.
790    ///
791    /// Note: the returned expression must be passed through [optimize_filter()]
792    /// before being passed to [create_physical_expr()].
793    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
794        let ast_expr = parse_sql_expr(expr)?;
795        let expr = self.parse_sql_expr(&ast_expr)?;
796        let schema = Schema::try_from(self.schema.as_ref())?;
797        let resolved = resolve_expr(&expr, &schema)?;
798        Ok(resolved)
799    }
800
801    /// Try to decode bytes from hex literal string.
802    ///
803    /// Copied from datafusion because this is not public.
804    ///
805    /// TODO: use SqlToRel from Datafusion directly?
806    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
807        let hex_bytes = s.as_bytes();
808        let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2);
809
810        let start_idx = hex_bytes.len() % 2;
811        if start_idx > 0 {
812            // The first byte is formed of only one char.
813            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
814        }
815
816        for i in (start_idx..hex_bytes.len()).step_by(2) {
817            let high = Self::try_decode_hex_char(hex_bytes[i])?;
818            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
819            decoded_bytes.push((high << 4) | low);
820        }
821
822        Some(decoded_bytes)
823    }
824
825    /// Try to decode a byte from a hex char.
826    ///
827    /// None will be returned if the input char is hex-invalid.
828    const fn try_decode_hex_char(c: u8) -> Option<u8> {
829        match c {
830            b'A'..=b'F' => Some(c - b'A' + 10),
831            b'a'..=b'f' => Some(c - b'a' + 10),
832            b'0'..=b'9' => Some(c - b'0'),
833            _ => None,
834        }
835    }
836
837    /// Optimize the filter expression and coerce data types.
838    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
839        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
840
841        // DataFusion needs the simplify and coerce passes to be applied before
842        // expressions can be handled by the physical planner.
843        let props = ExecutionProps::default();
844        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
845        let simplifier =
846            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
847
848        let expr = simplifier.simplify(expr)?;
849        let expr = simplifier.coerce(expr, &df_schema)?;
850
851        Ok(expr)
852    }
853
854    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
855    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
856        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
857
858        Ok(datafusion::physical_expr::create_physical_expr(
859            expr,
860            df_schema.as_ref(),
861            &Default::default(),
862        )?)
863    }
864
865    /// Collect the columns in the expression.
866    ///
867    /// The columns are returned in sorted order.
868    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
869        let mut visitor = ColumnCapturingVisitor {
870            current_path: VecDeque::new(),
871            columns: BTreeSet::new(),
872        };
873        expr.visit(&mut visitor).unwrap();
874        visitor.columns.into_iter().collect()
875    }
876}
877
/// Visitor that records the (dot-separated) names of all columns referenced
/// by an expression tree, including nested `get_field` accesses.
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Accumulated column names; BTreeSet keeps them sorted and de-duplicated.
    columns: BTreeSet<String>,
}
883
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    /// Pre-order (`f_down`) visit that assembles dotted column paths.
    ///
    /// For nested accesses like `get_field(get_field(col, "a"), "b")` the
    /// outermost call is visited first, so each string-literal field name is
    /// pushed to the *front* of `current_path`; when the base `Expr::Column`
    /// is reached, the accumulated parts are appended after the column name.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Base of a (possibly nested) field access: emit the full
                // dotted path and reset the accumulator.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    path.push_str(&part);
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // `get_field(expr, name)`: capture the field name only if
                    // it is a string literal; otherwise the path is dynamic
                    // and cannot be recorded.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node breaks the chain of field accesses.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
917
918#[cfg(test)]
919mod tests {
920
921    use crate::logical_expr::ExprExt;
922
923    use super::*;
924
925    use arrow::datatypes::Float64Type;
926    use arrow_array::{
927        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
928        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
929        TimestampNanosecondArray, TimestampSecondArray,
930    };
931    use arrow_schema::{DataType, Fields, Schema};
932    use datafusion::{
933        logical_expr::{lit, Cast},
934        prelude::{array_element, get_field},
935    };
936    use datafusion_functions::core::expr_ext::FieldAccessor;
937
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct so
        // the filter can exercise field-access (`st.x`) paths.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` (double equals) form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // `=` (single equals) form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Only rows 4 and 5 satisfy all three conjuncts.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1008
    #[test]
    fn test_nested_col_refs() {
        // Two levels of struct nesting so both `a.b` and `a.b.c` references
        // (bare, backquoted, and subscript syntax) can be checked.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `<expr> = 'val'` and asserts the left-hand side of the
        // resulting binary expression equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column reference, bare and backquoted.
        let expected = Expr::Column(Column {
            relation: None,
            name: "s0".to_string(),
        });
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting lowers to a single get_field call.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column {
                    relation: None,
                    name: "st".to_string(),
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting lower to nested get_field calls.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column {
                            relation: None,
                            name: "st".to_string(),
                        }),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1091
    #[test]
    fn test_nested_list_refs() {
        // List of structs, so both index access (`l[0]`) and chained
        // subscript access (`l[0]['f1']`) can be checked.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        // Index subscript lowers to array_element.
        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        // A string subscript on the element lowers to get_field.
        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);

        // FIXME: This should work, but sqlparser doesn't recognize anything
        // after the period for some reason.
        // let expr = planner.parse_expr("l[0].f1").unwrap();
        // assert_eq!(expr, expected);
    }
1119
1120    #[test]
1121    fn test_negative_expressions() {
1122        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1123
1124        let planner = Planner::new(schema.clone());
1125
1126        let expected = col("x")
1127            .gt(lit(-3_i64))
1128            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1129
1130        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1131
1132        assert_eq!(expr, expected);
1133
1134        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1135
1136        let batch = RecordBatch::try_new(
1137            schema,
1138            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1139        )
1140        .unwrap();
1141        let predicates = physical_expr.evaluate(&batch).unwrap();
1142        assert_eq!(
1143            predicates.into_array(0).unwrap().as_ref(),
1144            &BooleanArray::from(vec![
1145                false, false, false, true, true, true, true, false, false, false
1146            ])
1147        );
1148    }
1149
    #[test]
    fn test_negative_array_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        // An array literal of negative values should parse to a single list
        // scalar whose elements come out as Float64.
        let expected = Expr::Literal(ScalarValue::List(Arc::new(
            ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
            )]),
        )));

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }
1168
1169    #[test]
1170    fn test_sql_like() {
1171        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1172
1173        let planner = Planner::new(schema.clone());
1174
1175        let expected = col("s").like(lit("str-4"));
1176        // single quote
1177        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1178        assert_eq!(expr, expected);
1179        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1180
1181        let batch = RecordBatch::try_new(
1182            schema,
1183            vec![Arc::new(StringArray::from_iter_values(
1184                (0..10).map(|v| format!("str-{}", v)),
1185            ))],
1186        )
1187        .unwrap();
1188        let predicates = physical_expr.evaluate(&batch).unwrap();
1189        assert_eq!(
1190            predicates.into_array(0).unwrap().as_ref(),
1191            &BooleanArray::from(vec![
1192                false, false, false, false, true, false, false, false, false, false
1193            ])
1194        );
1195    }
1196
1197    #[test]
1198    fn test_not_like() {
1199        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1200
1201        let planner = Planner::new(schema.clone());
1202
1203        let expected = col("s").not_like(lit("str-4"));
1204        // single quote
1205        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1206        assert_eq!(expr, expected);
1207        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1208
1209        let batch = RecordBatch::try_new(
1210            schema,
1211            vec![Arc::new(StringArray::from_iter_values(
1212                (0..10).map(|v| format!("str-{}", v)),
1213            ))],
1214        )
1215        .unwrap();
1216        let predicates = physical_expr.evaluate(&batch).unwrap();
1217        assert_eq!(
1218            predicates.into_array(0).unwrap().as_ref(),
1219            &BooleanArray::from(vec![
1220                true, true, true, true, false, true, true, true, true, true
1221            ])
1222        );
1223    }
1224
    #[test]
    fn test_sql_is_in() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        // `IN (...)` should lower to an InList expression (negated = false).
        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows "str-4" and "str-5" (indices 4 and 5) match the list.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1252
    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Values are non-null only when v % 3 == 0 (indices 0, 3, 6, 9).
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        // IS NOT NULL yields the complement of the mask above.
        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }
1293
    #[test]
    fn test_sql_invert() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));

        let planner = Planner::new(schema.clone());

        let expr = planner.parse_filter("NOT s").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Input is true when v % 3 == 0; `NOT s` flips every value.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(BooleanArray::from_iter(
                (0..10).map(|v| Some(v % 3 == 0)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );
    }
1318
    #[test]
    fn test_sql_cast() {
        // Each case pairs a CAST expression with the Arrow type the cast
        // should produce; the column is given that same type so the filter
        // is well-typed.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The planner should keep the operand as a literal and wrap it in
            // an explicit Cast to the expected Arrow type.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value))) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1399
    #[test]
    fn test_sql_literals() {
        // Typed literals (`timestamp '...'`, `date '...'`, `decimal(p,s) '...'`)
        // should lower to a Utf8 literal wrapped in a Cast to the named type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal text is whatever sits between the single quotes.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1447
    #[test]
    fn test_sql_array_literals() {
        // The same SQL array literal should coerce to either a List or a
        // FixedSizeList depending on the column type it is compared against.
        let cases = [
            (
                "x = [1, 2, 3]",
                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
            ),
            (
                "x = [1, 2, 3]",
                ArrowDataType::FixedSizeList(
                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                    3,
                ),
            ),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();
            // optimize_expr runs the coercion pass that adapts the literal to
            // the column's concrete list type.
            let expr = planner.optimize_expr(expr).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Literal(value) => {
                        assert_eq!(&value.data_type(), &expected_data_type);
                    }
                    _ => panic!("Expected right to be a literal"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1485
    #[test]
    fn test_sql_between() {
        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use std::sync::Arc;

        // Integer, float, and timestamp columns so BETWEEN can be exercised
        // across several comparable types.
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Float64, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        // Test integer BETWEEN
        let expr = planner
            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Create timestamp array with values representing:
        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
        let ts_array = TimestampMicrosecondArray::from_iter_values(
            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
        );

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
                Arc::new(ts_array),
            ],
        )
        .unwrap();

        // BETWEEN is inclusive on both bounds: rows 3..=7 match.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );

        // Test NOT BETWEEN
        let expr = planner
            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, false, false, false, false, false, true, true
            ])
        );

        // Test floating point BETWEEN
        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );

        // Test timestamp BETWEEN
        let expr = planner
            .parse_filter(
                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
            )
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );
    }
1577
1578    #[test]
1579    fn test_sql_comparison() {
1580        // Create a batch with all data types
1581        let batch: Vec<(&str, ArrayRef)> = vec![
1582            (
1583                "timestamp_s",
1584                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1585            ),
1586            (
1587                "timestamp_ms",
1588                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1589            ),
1590            (
1591                "timestamp_us",
1592                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1593            ),
1594            (
1595                "timestamp_ns",
1596                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1597            ),
1598        ];
1599        let batch = RecordBatch::try_from_iter(batch).unwrap();
1600
1601        let planner = Planner::new(batch.schema());
1602
1603        // Each expression is meant to select the final 5 rows
1604        let expressions = &[
1605            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1606            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1607            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1608            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1609        ];
1610
1611        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1612            std::iter::repeat(Some(false))
1613                .take(5)
1614                .chain(std::iter::repeat(Some(true)).take(5)),
1615        ));
1616        for expression in expressions {
1617            // convert to physical expression
1618            let logical_expr = planner.parse_filter(expression).unwrap();
1619            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1620            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1621
1622            // Evaluate and assert they have correct results
1623            let result = physical_expr.evaluate(&batch).unwrap();
1624            let result = result.into_array(batch.num_rows()).unwrap();
1625            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1626        }
1627    }
1628
1629    #[test]
1630    fn test_columns_in_expr() {
1631        let expr = col("s0").gt(lit("value")).and(
1632            col("st")
1633                .field("st")
1634                .field("s2")
1635                .eq(lit("value"))
1636                .or(col("st")
1637                    .field("s1")
1638                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1639        );
1640
1641        let columns = Planner::column_names_in_expr(&expr);
1642        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1643    }
1644
1645    #[test]
1646    fn test_parse_binary_expr() {
1647        let bin_str = "x'616263'";
1648
1649        let schema = Arc::new(Schema::new(vec![Field::new(
1650            "binary",
1651            DataType::Binary,
1652            true,
1653        )]));
1654        let planner = Planner::new(schema);
1655        let expr = planner.parse_expr(bin_str).unwrap();
1656        assert_eq!(
1657            expr,
1658            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])))
1659        );
1660    }
1661
1662    #[test]
1663    fn test_lance_context_provider_expr_planners() {
1664        let ctx_provider = LanceContextProvider::default();
1665        assert!(!ctx_provider.get_expr_planners().is_empty());
1666    }
1667}