1use std::borrow::Cow;
9use std::collections::{BTreeSet, VecDeque};
10use std::sync::Arc;
11
12use crate::expr::safe_coerce_scalar;
13use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
14use crate::sql::{parse_sql_expr, parse_sql_filter};
15use arrow::compute::CastOptions;
16use arrow_array::ListArray;
17use arrow_buffer::OffsetBuffer;
18use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
19use arrow_select::concat::concat;
20use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
21use datafusion::common::DFSchema;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DFResult;
24use datafusion::execution::config::SessionConfig;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::runtime_env::RuntimeEnvBuilder;
27use datafusion::execution::session_state::SessionStateBuilder;
28use datafusion::logical_expr::expr::ScalarFunction;
29use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
30use datafusion::logical_expr::{
31 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
32 WindowUDF,
33};
34use datafusion::optimizer::simplify_expressions::SimplifyContext;
35use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
36use datafusion::sql::sqlparser::ast::{
37 Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo, Expr as SQLExpr,
38 Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript, TimezoneInfo,
39 UnaryOperator, Value,
40};
41use datafusion::{
42 common::Column,
43 logical_expr::{col, Between, BinaryExpr, Like, Operator},
44 physical_expr::execution_props::ExecutionProps,
45 physical_plan::PhysicalExpr,
46 prelude::Expr,
47 scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use snafu::location;
53
54use lance_core::{Error, Result};
55
#[derive(Debug, Clone)]
/// Scalar UDF (`_cast_list_f16`) that casts a List/FixedSizeList of
/// float32/float16 values to the same list shape with Float16 elements.
struct CastListF16Udf {
    // Accepts exactly one argument of any type; actual type validation
    // happens in `return_type`/`invoke`.
    signature: Signature,
}
60
61impl CastListF16Udf {
62 pub fn new() -> Self {
63 Self {
64 signature: Signature::any(1, Volatility::Immutable),
65 }
66 }
67}
68
69impl ScalarUDFImpl for CastListF16Udf {
70 fn as_any(&self) -> &dyn std::any::Any {
71 self
72 }
73
74 fn name(&self) -> &str {
75 "_cast_list_f16"
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
83 let input = &arg_types[0];
84 match input {
85 ArrowDataType::FixedSizeList(field, size) => {
86 if field.data_type() != &ArrowDataType::Float32
87 && field.data_type() != &ArrowDataType::Float16
88 {
89 return Err(datafusion::error::DataFusionError::Execution(
90 "cast_list_f16 only supports list of float32 or float16".to_string(),
91 ));
92 }
93 Ok(ArrowDataType::FixedSizeList(
94 Arc::new(Field::new(
95 field.name(),
96 ArrowDataType::Float16,
97 field.is_nullable(),
98 )),
99 *size,
100 ))
101 }
102 ArrowDataType::List(field) => {
103 if field.data_type() != &ArrowDataType::Float32
104 && field.data_type() != &ArrowDataType::Float16
105 {
106 return Err(datafusion::error::DataFusionError::Execution(
107 "cast_list_f16 only supports list of float32 or float16".to_string(),
108 ));
109 }
110 Ok(ArrowDataType::List(Arc::new(Field::new(
111 field.name(),
112 ArrowDataType::Float16,
113 field.is_nullable(),
114 ))))
115 }
116 _ => Err(datafusion::error::DataFusionError::Execution(
117 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
118 )),
119 }
120 }
121
122 fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
123 let ColumnarValue::Array(arr) = &args[0] else {
124 return Err(datafusion::error::DataFusionError::Execution(
125 "cast_list_f16 only supports array arguments".to_string(),
126 ));
127 };
128
129 let to_type = match arr.data_type() {
130 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
131 Arc::new(Field::new(
132 field.name(),
133 ArrowDataType::Float16,
134 field.is_nullable(),
135 )),
136 *size,
137 ),
138 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
139 field.name(),
140 ArrowDataType::Float16,
141 field.is_nullable(),
142 ))),
143 _ => {
144 return Err(datafusion::error::DataFusionError::Execution(
145 "cast_list_f16 only supports array arguments".to_string(),
146 ));
147 }
148 };
149
150 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
151 Ok(ColumnarValue::Array(res))
152 }
153}
154
/// Context provider handed to DataFusion's SQL planner. Lance never
/// resolves inner tables, so this only supplies function registries,
/// config options, and expression planners from a default session state.
struct LanceContextProvider {
    // Default DataFusion config options (not the session's — see `options()`).
    options: datafusion::config::ConfigOptions,
    // Session state used purely as a registry of scalar/aggregate/window fns.
    state: SessionState,
    // Planners captured from the builder before `build()` consumes it.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
161
162impl Default for LanceContextProvider {
163 fn default() -> Self {
164 let config = SessionConfig::new();
165 let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
166 let mut state_builder = SessionStateBuilder::new()
167 .with_config(config)
168 .with_runtime_env(runtime)
169 .with_default_features();
170
171 let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
176
177 Self {
178 options: ConfigOptions::default(),
179 state: state_builder.build(),
180 expr_planners,
181 }
182 }
183}
184
185impl ContextProvider for LanceContextProvider {
186 fn get_table_source(
187 &self,
188 name: datafusion::sql::TableReference,
189 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
190 Err(datafusion::error::DataFusionError::NotImplemented(format!(
191 "Attempt to reference inner table {} not supported",
192 name
193 )))
194 }
195
196 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
197 self.state.aggregate_functions().get(name).cloned()
198 }
199
200 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
201 self.state.window_functions().get(name).cloned()
202 }
203
204 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
205 match f {
206 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
209 _ => self.state.scalar_functions().get(f).cloned(),
210 }
211 }
212
213 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
214 None
216 }
217
218 fn options(&self) -> &datafusion::config::ConfigOptions {
219 &self.options
220 }
221
222 fn udf_names(&self) -> Vec<String> {
223 self.state.scalar_functions().keys().cloned().collect()
224 }
225
226 fn udaf_names(&self) -> Vec<String> {
227 self.state.aggregate_functions().keys().cloned().collect()
228 }
229
230 fn udwf_names(&self) -> Vec<String> {
231 self.state.window_functions().keys().cloned().collect()
232 }
233
234 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
235 &self.expr_planners
236 }
237}
238
/// Parses Lance filter/expression strings into DataFusion logical and
/// physical expressions, resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema the expressions are resolved and type-coerced against.
    schema: SchemaRef,
    // Function registry + expr planners backing SQL-to-expr conversion.
    context_provider: LanceContextProvider,
}
243
244impl Planner {
245 pub fn new(schema: SchemaRef) -> Self {
246 Self {
247 schema,
248 context_provider: LanceContextProvider::default(),
249 }
250 }
251
252 fn column(idents: &[Ident]) -> Expr {
253 let mut column = col(&idents[0].value);
254 for ident in &idents[1..] {
255 column = Expr::ScalarFunction(ScalarFunction {
256 args: vec![
257 column,
258 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))),
259 ],
260 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
261 });
262 }
263 column
264 }
265
    /// Map a sqlparser binary operator onto its DataFusion equivalent.
    ///
    /// Returns `Error::invalid_input` for operators Lance does not support
    /// (bitwise ops, JSON operators, etc.).
    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }
290
291 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
292 Ok(Expr::BinaryExpr(BinaryExpr::new(
293 Box::new(self.parse_sql_expr(left)?),
294 self.binary_op(op)?,
295 Box::new(self.parse_sql_expr(right)?),
296 )))
297 }
298
299 fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
300 Ok(match op {
301 UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
302 Expr::Not(Box::new(self.parse_sql_expr(expr)?))
303 }
304
305 UnaryOperator::Minus => {
306 use datafusion::logical_expr::lit;
307 match expr {
308 SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
309 Ok(n) => lit(-n),
310 Err(_) => lit(-n
311 .parse::<f64>()
312 .map_err(|_e| {
313 Error::invalid_input(
314 format!("negative operator can be only applied to integer and float operands, got: {n}"),
315 location!(),
316 )
317 })?),
318 },
319 _ => {
320 Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
321 }
322 }
323 }
324
325 _ => {
326 return Err(Error::invalid_input(
327 format!("Unary operator '{:?}' is not supported", op),
328 location!(),
329 ));
330 }
331 })
332 }
333
334 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
336 use datafusion::logical_expr::lit;
337 let value: Cow<str> = if negative {
338 Cow::Owned(format!("-{}", value))
339 } else {
340 Cow::Borrowed(value)
341 };
342 if let Ok(n) = value.parse::<i64>() {
343 Ok(lit(n))
344 } else {
345 value.parse::<f64>().map(lit).map_err(|_| {
346 Error::invalid_input(
347 format!("'{value}' is not supported number value."),
348 location!(),
349 )
350 })
351 }
352 }
353
354 fn value(&self, value: &Value) -> Result<Expr> {
355 Ok(match value {
356 Value::Number(v, _) => self.number(v.as_str(), false)?,
357 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
358 Value::HexStringLiteral(hsl) => {
359 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)))
360 }
361 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
362 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
363 Value::Null => Expr::Literal(ScalarValue::Null),
364 _ => todo!(),
365 })
366 }
367
368 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
369 match func_args {
370 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
371 _ => Err(Error::invalid_input(
372 format!("Unsupported function args: {:?}", func_args),
373 location!(),
374 )),
375 }
376 }
377
378 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
385 match &func.args {
386 FunctionArguments::List(args) => {
387 if func.name.0.len() != 1 {
388 return Err(Error::invalid_input(
389 format!("Function name must have 1 part, got: {:?}", func.name.0),
390 location!(),
391 ));
392 }
393 Ok(Expr::IsNotNull(Box::new(
394 self.parse_function_args(&args.args[0])?,
395 )))
396 }
397 _ => Err(Error::invalid_input(
398 format!("Unsupported function args: {:?}", &func.args),
399 location!(),
400 )),
401 }
402 }
403
404 fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
405 if let SQLExpr::Function(function) = &function {
406 if !function.name.0.is_empty() && function.name.0[0].value == "is_valid" {
407 return self.legacy_parse_function(function);
408 }
409 }
410 let sql_to_rel = SqlToRel::new_with_options(
411 &self.context_provider,
412 ParserOptions {
413 parse_float_as_decimal: false,
414 enable_ident_normalization: false,
415 support_varchar_with_length: false,
416 enable_options_value_normalization: false,
417 },
418 );
419
420 let mut planner_context = PlannerContext::default();
421 let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
422 Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)
423 }
424
425 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
426 const SUPPORTED_TYPES: [&str; 13] = [
427 "int [unsigned]",
428 "tinyint [unsigned]",
429 "smallint [unsigned]",
430 "bigint [unsigned]",
431 "float",
432 "double",
433 "string",
434 "binary",
435 "date",
436 "timestamp(precision)",
437 "datetime(precision)",
438 "decimal(precision,scale)",
439 "boolean",
440 ];
441 match data_type {
442 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
443 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
444 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
445 SQLDataType::Double => Ok(ArrowDataType::Float64),
446 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
447 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
448 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
449 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
450 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
451 SQLDataType::UnsignedTinyInt(_) => Ok(ArrowDataType::UInt8),
452 SQLDataType::UnsignedSmallInt(_) => Ok(ArrowDataType::UInt16),
453 SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
454 Ok(ArrowDataType::UInt32)
455 }
456 SQLDataType::UnsignedBigInt(_) => Ok(ArrowDataType::UInt64),
457 SQLDataType::Date => Ok(ArrowDataType::Date32),
458 SQLDataType::Timestamp(resolution, tz) => {
459 match tz {
460 TimezoneInfo::None => {}
461 _ => {
462 return Err(Error::invalid_input(
463 "Timezone not supported in timestamp".to_string(),
464 location!(),
465 ));
466 }
467 };
468 let time_unit = match resolution {
469 None => TimeUnit::Microsecond,
471 Some(0) => TimeUnit::Second,
472 Some(3) => TimeUnit::Millisecond,
473 Some(6) => TimeUnit::Microsecond,
474 Some(9) => TimeUnit::Nanosecond,
475 _ => {
476 return Err(Error::invalid_input(
477 format!("Unsupported datetime resolution: {:?}", resolution),
478 location!(),
479 ));
480 }
481 };
482 Ok(ArrowDataType::Timestamp(time_unit, None))
483 }
484 SQLDataType::Datetime(resolution) => {
485 let time_unit = match resolution {
486 None => TimeUnit::Microsecond,
487 Some(0) => TimeUnit::Second,
488 Some(3) => TimeUnit::Millisecond,
489 Some(6) => TimeUnit::Microsecond,
490 Some(9) => TimeUnit::Nanosecond,
491 _ => {
492 return Err(Error::invalid_input(
493 format!("Unsupported datetime resolution: {:?}", resolution),
494 location!(),
495 ));
496 }
497 };
498 Ok(ArrowDataType::Timestamp(time_unit, None))
499 }
500 SQLDataType::Decimal(number_info) => match number_info {
501 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
502 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
503 }
504 _ => Err(Error::invalid_input(
505 format!(
506 "Must provide precision and scale for decimal: {:?}",
507 number_info
508 ),
509 location!(),
510 )),
511 },
512 _ => Err(Error::invalid_input(
513 format!(
514 "Unsupported data type: {:?}. Supported types: {:?}",
515 data_type, SUPPORTED_TYPES
516 ),
517 location!(),
518 )),
519 }
520 }
521
522 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
523 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
524 for planner in self.context_provider.get_expr_planners() {
525 match planner.plan_field_access(field_access_expr, &df_schema)? {
526 PlannerResult::Planned(expr) => return Ok(expr),
527 PlannerResult::Original(expr) => {
528 field_access_expr = expr;
529 }
530 }
531 }
532 Err(Error::invalid_input(
533 "Field access could not be planned",
534 location!(),
535 ))
536 }
537
    /// Recursively convert a sqlparser AST expression into a DataFusion
    /// logical [`Expr`], applying Lance conventions: double-quoted
    /// identifiers are string literals, backticked identifiers are verbatim
    /// column names, arrays must be homogeneous literals, and nested field
    /// / list access is planned via the registered expression planners.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // "foo" => string literal; `foo` => column named exactly foo;
                // bare foo => column (with get_field chaining via Self::column).
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(value),
            // Array literals: every element must be a literal (optionally a
            // negated number); elements are coerced to the first element's
            // type and packed into a single-row ListArray scalar.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value) = self.value(value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literal: fold the sign into the
                        // parse (see `number`).
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(Value::Number(number, _)) = expr.as_ref() {
                                if let Expr::Literal(value) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // The first element fixes the list's element type; later
                // elements must coerce to it or the array is rejected.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    // Empty array literal: element type is Null.
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize each scalar as a 1-row array, concat them,
                // and wrap everything in one list offset [0, n].
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values))))
            }
            // e.g. DATE '2020-01-01': a cast of the string literal.
            SQLExpr::TypedString { data_type, value } => {
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value.clone())))),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // LIKE / ILIKE differ only in the case-insensitive flag
            // (last argument of Like::new).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            // col['a'][0]... : string keys are struct-field accesses,
            // everything else is treated as a list index.
            SQLExpr::MapAccess { column, keys } => {
                let mut expr = self.parse_sql_expr(column)?;

                for key in keys {
                    let field_access = match &key.key {
                        SQLExpr::Value(
                            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                        ) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        SQLExpr::JsonAccess { .. } => {
                            return Err(Error::invalid_input(
                                "JSON access is not supported",
                                location!(),
                            ));
                        }
                        key => {
                            let key = Box::new(self.parse_sql_expr(key)?);
                            GetFieldAccess::ListIndex { key }
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };

                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            // col[...] subscript: same key rules as MapAccess; slices are
            // rejected.
            SQLExpr::Subscript { expr, subscript } => {
                let expr = self.parse_sql_expr(expr)?;

                let field_access = match subscript.as_ref() {
                    Subscript::Index { index } => match index {
                        SQLExpr::Value(
                            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                        ) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        SQLExpr::JsonAccess { .. } => {
                            return Err(Error::invalid_input(
                                "JSON access is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                    },
                    Subscript::Slice { .. } => {
                        return Err(Error::invalid_input(
                            "Slice subscript is not supported",
                            location!(),
                        ));
                    }
                };

                let field_access_expr = RawFieldAccessExpr { expr, field_access };
                self.plan_field_access(field_access_expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
775
776 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
781 let ast_expr = parse_sql_filter(filter)?;
783 let expr = self.parse_sql_expr(&ast_expr)?;
784 let schema = Schema::try_from(self.schema.as_ref())?;
785 let resolved = resolve_expr(&expr, &schema)?;
786 coerce_filter_type_to_boolean(resolved)
787 }
788
789 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
794 let ast_expr = parse_sql_expr(expr)?;
795 let expr = self.parse_sql_expr(&ast_expr)?;
796 let schema = Schema::try_from(self.schema.as_ref())?;
797 let resolved = resolve_expr(&expr, &schema)?;
798 Ok(resolved)
799 }
800
801 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
807 let hex_bytes = s.as_bytes();
808 let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2);
809
810 let start_idx = hex_bytes.len() % 2;
811 if start_idx > 0 {
812 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
814 }
815
816 for i in (start_idx..hex_bytes.len()).step_by(2) {
817 let high = Self::try_decode_hex_char(hex_bytes[i])?;
818 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
819 decoded_bytes.push((high << 4) | low);
820 }
821
822 Some(decoded_bytes)
823 }
824
825 const fn try_decode_hex_char(c: u8) -> Option<u8> {
829 match c {
830 b'A'..=b'F' => Some(c - b'A' + 10),
831 b'a'..=b'f' => Some(c - b'a' + 10),
832 b'0'..=b'9' => Some(c - b'0'),
833 _ => None,
834 }
835 }
836
837 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
839 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
840
841 let props = ExecutionProps::default();
844 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
845 let simplifier =
846 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
847
848 let expr = simplifier.simplify(expr)?;
849 let expr = simplifier.coerce(expr, &df_schema)?;
850
851 Ok(expr)
852 }
853
854 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
856 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
857
858 Ok(datafusion::physical_expr::create_physical_expr(
859 expr,
860 df_schema.as_ref(),
861 &Default::default(),
862 )?)
863 }
864
865 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
869 let mut visitor = ColumnCapturingVisitor {
870 current_path: VecDeque::new(),
871 columns: BTreeSet::new(),
872 };
873 expr.visit(&mut visitor).unwrap();
874 visitor.columns.into_iter().collect()
875 }
876}
877
/// Tree visitor that records every column path referenced by an expression,
/// flattening `get_field` chains into dotted names (e.g. "st.x").
struct ColumnCapturingVisitor {
    // Field-name segments of the get_field chain currently being descended,
    // front = outermost-to-be-applied segment.
    current_path: VecDeque<String>,
    // Completed dotted column paths (BTreeSet gives sorted, unique output).
    columns: BTreeSet<String>,
}
883
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    // Pre-order visit: `get_field(get_field(col, "a"), "b")` is walked
    // top-down, so each field name is pushed to the FRONT of
    // `current_path`, and the full dotted path is emitted when the root
    // column node is reached. Any other node breaks the chain.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Root of a (possibly empty) get_field chain: join the
                // accumulated segments onto the column name.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    path.push_str(&part);
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // args[1] is the field name; non-literal names cannot be
                    // represented as a static path, so drop the chain.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
917
918#[cfg(test)]
919mod tests {
920
921 use crate::logical_expr::ExprExt;
922
923 use super::*;
924
925 use arrow::datatypes::Float64Type;
926 use arrow_array::{
927 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
928 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
929 TimestampNanosecondArray, TimestampSecondArray,
930 };
931 use arrow_schema::{DataType, Fields, Schema};
932 use datafusion::{
933 logical_expr::{lit, Cast},
934 prelude::{array_element, get_field},
935 };
936 use datafusion_functions::core::expr_ext::FieldAccessor;
937
    // End-to-end: parse a compound filter (both `==` and `=` spellings),
    // check the logical expr, then evaluate the physical expr on a batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3 AND st.x <= 5.0 AND s in {str-4, str-5}.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1008
    // Nested struct references should lower to `get_field` chains,
    // regardless of quoting style (bare, backtick, subscript).
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and checks the left side of the resulting
        // equality against `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column {
            relation: None,
            name: "s0".to_string(),
        });
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column {
                    relation: None,
                    name: "st".to_string(),
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column {
                            relation: None,
                            name: "st".to_string(),
                        }),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1091
    // List element access (`l[0]`) and struct-field-in-list (`l[0]['f1']`)
    // should match DataFusion's array_element / get_field builders.
    #[test]
    fn test_nested_list_refs() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);

    }
1119
    // Unary minus on literals and on parenthesized sub-expressions should
    // both parse and evaluate correctly.
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Rows with -3 < x < 2 pass the filter.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1149
1150 #[test]
1151 fn test_negative_array_expressions() {
1152 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1153
1154 let planner = Planner::new(schema);
1155
1156 let expected = Expr::Literal(ScalarValue::List(Arc::new(
1157 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1158 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1159 )]),
1160 )));
1161
1162 let expr = planner
1163 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1164 .unwrap();
1165
1166 assert_eq!(expr, expected);
1167 }
1168
1169 #[test]
1170 fn test_sql_like() {
1171 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1172
1173 let planner = Planner::new(schema.clone());
1174
1175 let expected = col("s").like(lit("str-4"));
1176 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1178 assert_eq!(expr, expected);
1179 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1180
1181 let batch = RecordBatch::try_new(
1182 schema,
1183 vec![Arc::new(StringArray::from_iter_values(
1184 (0..10).map(|v| format!("str-{}", v)),
1185 ))],
1186 )
1187 .unwrap();
1188 let predicates = physical_expr.evaluate(&batch).unwrap();
1189 assert_eq!(
1190 predicates.into_array(0).unwrap().as_ref(),
1191 &BooleanArray::from(vec![
1192 false, false, false, false, true, false, false, false, false, false
1193 ])
1194 );
1195 }
1196
1197 #[test]
1198 fn test_not_like() {
1199 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1200
1201 let planner = Planner::new(schema.clone());
1202
1203 let expected = col("s").not_like(lit("str-4"));
1204 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1206 assert_eq!(expr, expected);
1207 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1208
1209 let batch = RecordBatch::try_new(
1210 schema,
1211 vec![Arc::new(StringArray::from_iter_values(
1212 (0..10).map(|v| format!("str-{}", v)),
1213 ))],
1214 )
1215 .unwrap();
1216 let predicates = physical_expr.evaluate(&batch).unwrap();
1217 assert_eq!(
1218 predicates.into_array(0).unwrap().as_ref(),
1219 &BooleanArray::from(vec![
1220 true, true, true, true, false, true, true, true, true, true
1221 ])
1222 );
1223 }
1224
1225 #[test]
1226 fn test_sql_is_in() {
1227 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1228
1229 let planner = Planner::new(schema.clone());
1230
1231 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1232 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1234 assert_eq!(expr, expected);
1235 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1236
1237 let batch = RecordBatch::try_new(
1238 schema,
1239 vec![Arc::new(StringArray::from_iter_values(
1240 (0..10).map(|v| format!("str-{}", v)),
1241 ))],
1242 )
1243 .unwrap();
1244 let predicates = physical_expr.evaluate(&batch).unwrap();
1245 assert_eq!(
1246 predicates.into_array(0).unwrap().as_ref(),
1247 &BooleanArray::from(vec![
1248 false, false, false, false, true, true, false, false, false, false
1249 ])
1250 );
1251 }
1252
1253 #[test]
1254 fn test_sql_is_null() {
1255 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1256
1257 let planner = Planner::new(schema.clone());
1258
1259 let expected = col("s").is_null();
1260 let expr = planner.parse_filter("s IS NULL").unwrap();
1261 assert_eq!(expr, expected);
1262 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1263
1264 let batch = RecordBatch::try_new(
1265 schema,
1266 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1267 if v % 3 == 0 {
1268 Some(format!("str-{}", v))
1269 } else {
1270 None
1271 }
1272 })))],
1273 )
1274 .unwrap();
1275 let predicates = physical_expr.evaluate(&batch).unwrap();
1276 assert_eq!(
1277 predicates.into_array(0).unwrap().as_ref(),
1278 &BooleanArray::from(vec![
1279 false, true, true, false, true, true, false, true, true, false
1280 ])
1281 );
1282
1283 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1284 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1285 let predicates = physical_expr.evaluate(&batch).unwrap();
1286 assert_eq!(
1287 predicates.into_array(0).unwrap().as_ref(),
1288 &BooleanArray::from(vec![
1289 true, false, false, true, false, false, true, false, false, true,
1290 ])
1291 );
1292 }
1293
1294 #[test]
1295 fn test_sql_invert() {
1296 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1297
1298 let planner = Planner::new(schema.clone());
1299
1300 let expr = planner.parse_filter("NOT s").unwrap();
1301 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1302
1303 let batch = RecordBatch::try_new(
1304 schema,
1305 vec![Arc::new(BooleanArray::from_iter(
1306 (0..10).map(|v| Some(v % 3 == 0)),
1307 ))],
1308 )
1309 .unwrap();
1310 let predicates = physical_expr.evaluate(&batch).unwrap();
1311 assert_eq!(
1312 predicates.into_array(0).unwrap().as_ref(),
1313 &BooleanArray::from(vec![
1314 false, true, true, false, true, true, false, true, true, false
1315 ])
1316 );
1317 }
1318
    #[test]
    fn test_sql_cast() {
        // Each case pairs a filter containing a CAST with the Arrow type that
        // the SQL type name should resolve to: timestamp precisions, decimal
        // precision/scale, float widths, and signed/unsigned integer widths.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // Column `x` is declared with the expected type so the planner has
            // a schema to resolve the filter against.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the casted operand from the SQL text: everything between
            // "cast(" and " as".
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Strip the surrounding quotes of string literals (no-op for the
            // `cast(1 ...)` cases).
            let expected_value_str = expected_value_str.trim_matches('\'');

            // Expect `x = CAST(<literal> AS <expected type>)`, where the
            // literal is still an unconverted Utf8 or Int64 scalar.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value))) => {
                                // All numeric cases cast the literal 1.
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1399
1400 #[test]
1401 fn test_sql_literals() {
1402 let cases = &[
1403 (
1404 "x = timestamp '2021-01-01 00:00:00'",
1405 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1406 ),
1407 (
1408 "x = timestamp(0) '2021-01-01 00:00:00'",
1409 ArrowDataType::Timestamp(TimeUnit::Second, None),
1410 ),
1411 (
1412 "x = timestamp(9) '2021-01-01 00:00:00.123'",
1413 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1414 ),
1415 ("x = date '2021-01-01'", ArrowDataType::Date32),
1416 ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
1417 ];
1418
1419 for (sql, expected_data_type) in cases {
1420 let schema = Arc::new(Schema::new(vec![Field::new(
1421 "x",
1422 expected_data_type.clone(),
1423 true,
1424 )]));
1425 let planner = Planner::new(schema.clone());
1426 let expr = planner.parse_filter(sql).unwrap();
1427
1428 let expected_value_str = sql.split('\'').nth(1).unwrap();
1429
1430 match expr {
1431 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1432 Expr::Cast(Cast { expr, data_type }) => {
1433 match expr.as_ref() {
1434 Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
1435 assert_eq!(value_str, expected_value_str);
1436 }
1437 _ => panic!("Expected cast to be applied to literal"),
1438 }
1439 assert_eq!(data_type, expected_data_type);
1440 }
1441 _ => panic!("Expected right to be a cast"),
1442 },
1443 _ => panic!("Expected binary expression"),
1444 }
1445 }
1446 }
1447
1448 #[test]
1449 fn test_sql_array_literals() {
1450 let cases = [
1451 (
1452 "x = [1, 2, 3]",
1453 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1454 ),
1455 (
1456 "x = [1, 2, 3]",
1457 ArrowDataType::FixedSizeList(
1458 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1459 3,
1460 ),
1461 ),
1462 ];
1463
1464 for (sql, expected_data_type) in cases {
1465 let schema = Arc::new(Schema::new(vec![Field::new(
1466 "x",
1467 expected_data_type.clone(),
1468 true,
1469 )]));
1470 let planner = Planner::new(schema.clone());
1471 let expr = planner.parse_filter(sql).unwrap();
1472 let expr = planner.optimize_expr(expr).unwrap();
1473
1474 match expr {
1475 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1476 Expr::Literal(value) => {
1477 assert_eq!(&value.data_type(), &expected_data_type);
1478 }
1479 _ => panic!("Expected right to be a literal"),
1480 },
1481 _ => panic!("Expected binary expression"),
1482 }
1483 }
1484 }
1485
1486 #[test]
1487 fn test_sql_between() {
1488 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1489 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1490 use std::sync::Arc;
1491
1492 let schema = Arc::new(Schema::new(vec![
1493 Field::new("x", DataType::Int32, false),
1494 Field::new("y", DataType::Float64, false),
1495 Field::new(
1496 "ts",
1497 DataType::Timestamp(TimeUnit::Microsecond, None),
1498 false,
1499 ),
1500 ]));
1501
1502 let planner = Planner::new(schema.clone());
1503
1504 let expr = planner
1506 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1507 .unwrap();
1508 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1509
1510 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1514 (0..10).map(|i| base_ts + i * 1_000_000), );
1516
1517 let batch = RecordBatch::try_new(
1518 schema,
1519 vec![
1520 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1521 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1522 Arc::new(ts_array),
1523 ],
1524 )
1525 .unwrap();
1526
1527 let predicates = physical_expr.evaluate(&batch).unwrap();
1528 assert_eq!(
1529 predicates.into_array(0).unwrap().as_ref(),
1530 &BooleanArray::from(vec![
1531 false, false, false, true, true, true, true, true, false, false
1532 ])
1533 );
1534
1535 let expr = planner
1537 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1538 .unwrap();
1539 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1540
1541 let predicates = physical_expr.evaluate(&batch).unwrap();
1542 assert_eq!(
1543 predicates.into_array(0).unwrap().as_ref(),
1544 &BooleanArray::from(vec![
1545 true, true, true, false, false, false, false, false, true, true
1546 ])
1547 );
1548
1549 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1551 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1552
1553 let predicates = physical_expr.evaluate(&batch).unwrap();
1554 assert_eq!(
1555 predicates.into_array(0).unwrap().as_ref(),
1556 &BooleanArray::from(vec![
1557 false, false, false, true, true, true, true, false, false, false
1558 ])
1559 );
1560
1561 let expr = planner
1563 .parse_filter(
1564 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1565 )
1566 .unwrap();
1567 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1568
1569 let predicates = physical_expr.evaluate(&batch).unwrap();
1570 assert_eq!(
1571 predicates.into_array(0).unwrap().as_ref(),
1572 &BooleanArray::from(vec![
1573 false, false, false, true, true, true, true, true, false, false
1574 ])
1575 );
1576 }
1577
1578 #[test]
1579 fn test_sql_comparison() {
1580 let batch: Vec<(&str, ArrayRef)> = vec![
1582 (
1583 "timestamp_s",
1584 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1585 ),
1586 (
1587 "timestamp_ms",
1588 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1589 ),
1590 (
1591 "timestamp_us",
1592 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1593 ),
1594 (
1595 "timestamp_ns",
1596 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1597 ),
1598 ];
1599 let batch = RecordBatch::try_from_iter(batch).unwrap();
1600
1601 let planner = Planner::new(batch.schema());
1602
1603 let expressions = &[
1605 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1606 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1607 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1608 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1609 ];
1610
1611 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1612 std::iter::repeat(Some(false))
1613 .take(5)
1614 .chain(std::iter::repeat(Some(true)).take(5)),
1615 ));
1616 for expression in expressions {
1617 let logical_expr = planner.parse_filter(expression).unwrap();
1619 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1620 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1621
1622 let result = physical_expr.evaluate(&batch).unwrap();
1624 let result = result.into_array(batch.num_rows()).unwrap();
1625 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1626 }
1627 }
1628
1629 #[test]
1630 fn test_columns_in_expr() {
1631 let expr = col("s0").gt(lit("value")).and(
1632 col("st")
1633 .field("st")
1634 .field("s2")
1635 .eq(lit("value"))
1636 .or(col("st")
1637 .field("s1")
1638 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1639 );
1640
1641 let columns = Planner::column_names_in_expr(&expr);
1642 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1643 }
1644
1645 #[test]
1646 fn test_parse_binary_expr() {
1647 let bin_str = "x'616263'";
1648
1649 let schema = Arc::new(Schema::new(vec![Field::new(
1650 "binary",
1651 DataType::Binary,
1652 true,
1653 )]));
1654 let planner = Planner::new(schema);
1655 let expr = planner.parse_expr(bin_str).unwrap();
1656 assert_eq!(
1657 expr,
1658 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])))
1659 );
1660 }
1661
1662 #[test]
1663 fn test_lance_context_provider_expr_planners() {
1664 let ctx_provider = LanceContextProvider::default();
1665 assert!(!ctx_provider.get_expr_planners().is_empty());
1666 }
1667}