1mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column};
25use crate::tree_node::ExprContext;
26use crate::PhysicalExpr;
27use crate::PhysicalSortExpr;
28
29use arrow::datatypes::SchemaRef;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{HashMap, HashSet, Result};
34use datafusion_expr::Operator;
35
36use datafusion_physical_expr_common::sort_expr::LexOrdering;
37use itertools::Itertools;
38use petgraph::graph::NodeIndex;
39use petgraph::stable_graph::StableGraph;
40
41pub fn split_conjunction(
45 predicate: &Arc<dyn PhysicalExpr>,
46) -> Vec<&Arc<dyn PhysicalExpr>> {
47 split_impl(Operator::And, predicate, vec![])
48}
49
50pub fn split_disjunction(
54 predicate: &Arc<dyn PhysicalExpr>,
55) -> Vec<&Arc<dyn PhysicalExpr>> {
56 split_impl(Operator::Or, predicate, vec![])
57}
58
59fn split_impl<'a>(
60 operator: Operator,
61 predicate: &'a Arc<dyn PhysicalExpr>,
62 mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
63) -> Vec<&'a Arc<dyn PhysicalExpr>> {
64 match predicate.as_any().downcast_ref::<BinaryExpr>() {
65 Some(binary) if binary.op() == &operator => {
66 let exprs = split_impl(operator, binary.left(), exprs);
67 split_impl(operator, binary.right(), exprs)
68 }
69 Some(_) | None => {
70 exprs.push(predicate);
71 exprs
72 }
73 }
74}
75
76pub fn map_columns_before_projection(
85 parent_required: &[Arc<dyn PhysicalExpr>],
86 proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
87) -> Vec<Arc<dyn PhysicalExpr>> {
88 if parent_required.is_empty() {
89 return vec![];
91 }
92 let column_mapping = proj_exprs
93 .iter()
94 .filter_map(|(expr, name)| {
95 expr.as_any()
96 .downcast_ref::<Column>()
97 .map(|column| (name.clone(), column.clone()))
98 })
99 .collect::<HashMap<_, _>>();
100 parent_required
101 .iter()
102 .filter_map(|r| {
103 r.as_any()
104 .downcast_ref::<Column>()
105 .and_then(|c| column_mapping.get(c.name()))
106 })
107 .map(|e| Arc::new(e.clone()) as _)
108 .collect()
109}
110
111pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
114 sequence: impl IntoIterator<Item = T>,
115) -> Vec<Arc<dyn PhysicalExpr>> {
116 sequence
117 .into_iter()
118 .map(|elem| Arc::clone(&elem.borrow().expr))
119 .collect()
120}
121
122pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
125 targets: impl IntoIterator<Item = T>,
126 items: &[Arc<dyn PhysicalExpr>],
127) -> Vec<usize> {
128 targets
129 .into_iter()
130 .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
131 .collect()
132}
133
134pub type ExprTreeNode<T> = ExprContext<Option<T>>;
135
136struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
141 graph: StableGraph<T, usize>,
143 visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
145 constructor: &'a F,
147}
148
149impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
150 fn mutate(
153 &mut self,
154 mut node: ExprTreeNode<NodeIndex>,
155 ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
156 let expr = &node.expr;
158
159 let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
161 Some((_, idx)) => *idx,
163 None => {
167 let node_idx = self.graph.add_node((self.constructor)(&node)?);
168 for expr_node in node.children.iter() {
169 self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
170 }
171 self.visited_plans.push((Arc::clone(expr), node_idx));
172 node_idx
173 }
174 };
175 node.data = Some(node_idx);
177 Ok(Transformed::yes(node))
179 }
180}
181
182pub fn build_dag<T, F>(
184 expr: Arc<dyn PhysicalExpr>,
185 constructor: &F,
186) -> Result<(NodeIndex, StableGraph<T, usize>)>
187where
188 F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
189{
190 let init = ExprTreeNode::new_default(expr);
192 let mut builder = PhysicalExprDAEGBuilder {
194 graph: StableGraph::<T, usize>::new(),
195 visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
196 constructor,
197 };
198 let root = init.transform_up(|node| builder.mutate(node)).data()?;
200 Ok((root.data.unwrap(), builder.graph))
202}
203
204pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
206 let mut columns = HashSet::<Column>::new();
207 expr.apply(|expr| {
208 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
209 columns.get_or_insert_owned(column);
210 }
211 Ok(TreeNodeRecursion::Continue)
212 })
213 .expect("no way to return error during recursion");
215 columns
216}
217
218pub fn reassign_predicate_columns(
221 pred: Arc<dyn PhysicalExpr>,
222 schema: &SchemaRef,
223 ignore_not_found: bool,
224) -> Result<Arc<dyn PhysicalExpr>> {
225 pred.transform_down(|expr| {
226 let expr_any = expr.as_any();
227
228 if let Some(column) = expr_any.downcast_ref::<Column>() {
229 let index = match schema.index_of(column.name()) {
230 Ok(idx) => idx,
231 Err(_) if ignore_not_found => usize::MAX,
232 Err(e) => return Err(e.into()),
233 };
234 return Ok(Transformed::yes(Arc::new(Column::new(
235 column.name(),
236 index,
237 ))));
238 }
239 Ok(Transformed::no(expr))
240 })
241 .data()
242}
243
244pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering {
246 left.iter()
247 .cloned()
248 .chain(right.iter().cloned())
249 .unique()
250 .collect()
251}
252
253#[cfg(test)]
254pub(crate) mod tests {
255 use std::any::Any;
256 use std::fmt::{Display, Formatter};
257
258 use super::*;
259 use crate::expressions::{binary, cast, col, in_list, lit, Literal};
260
261 use arrow::array::{ArrayRef, Float32Array, Float64Array};
262 use arrow::datatypes::{DataType, Field, Schema};
263 use datafusion_common::{exec_err, DataFusionError, ScalarValue};
264 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
265 use datafusion_expr::{
266 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
267 };
268
269 use petgraph::visit::Bfs;
270
271 #[derive(Debug, Clone)]
272 pub struct TestScalarUDF {
273 pub(crate) signature: Signature,
274 }
275
276 impl TestScalarUDF {
277 pub fn new() -> Self {
278 use DataType::*;
279 Self {
280 signature: Signature::uniform(
281 1,
282 vec![Float64, Float32],
283 Volatility::Immutable,
284 ),
285 }
286 }
287 }
288
289 impl ScalarUDFImpl for TestScalarUDF {
290 fn as_any(&self) -> &dyn Any {
291 self
292 }
293 fn name(&self) -> &str {
294 "test-scalar-udf"
295 }
296
297 fn signature(&self) -> &Signature {
298 &self.signature
299 }
300
301 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
302 let arg_type = &arg_types[0];
303
304 match arg_type {
305 DataType::Float32 => Ok(DataType::Float32),
306 _ => Ok(DataType::Float64),
307 }
308 }
309
310 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
311 Ok(input[0].sort_properties)
312 }
313
314 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
315 let args = ColumnarValue::values_to_arrays(&args.args)?;
316
317 let arr: ArrayRef = match args[0].data_type() {
318 DataType::Float64 => Arc::new({
319 let arg = &args[0]
320 .as_any()
321 .downcast_ref::<Float64Array>()
322 .ok_or_else(|| {
323 DataFusionError::Internal(format!(
324 "could not cast {} to {}",
325 self.name(),
326 std::any::type_name::<Float64Array>()
327 ))
328 })?;
329
330 arg.iter()
331 .map(|a| a.map(f64::floor))
332 .collect::<Float64Array>()
333 }),
334 DataType::Float32 => Arc::new({
335 let arg = &args[0]
336 .as_any()
337 .downcast_ref::<Float32Array>()
338 .ok_or_else(|| {
339 DataFusionError::Internal(format!(
340 "could not cast {} to {}",
341 self.name(),
342 std::any::type_name::<Float32Array>()
343 ))
344 })?;
345
346 arg.iter()
347 .map(|a| a.map(f32::floor))
348 .collect::<Float32Array>()
349 }),
350 other => {
351 return exec_err!(
352 "Unsupported data type {other:?} for function {}",
353 self.name()
354 );
355 }
356 };
357 Ok(ColumnarValue::Array(arr))
358 }
359 }
360
361 #[derive(Clone)]
362 struct DummyProperty {
363 expr_type: String,
364 }
365
366 #[derive(Clone)]
369 struct PhysicalExprDummyNode {
370 pub expr: Arc<dyn PhysicalExpr>,
371 pub property: DummyProperty,
372 }
373
374 impl Display for PhysicalExprDummyNode {
375 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
376 write!(f, "{}", self.expr)
377 }
378 }
379
380 fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
381 let expr = Arc::clone(&node.expr);
382 let dummy_property = if expr.as_any().is::<BinaryExpr>() {
383 "Binary"
384 } else if expr.as_any().is::<Column>() {
385 "Column"
386 } else if expr.as_any().is::<Literal>() {
387 "Literal"
388 } else {
389 "Other"
390 }
391 .to_owned();
392 Ok(PhysicalExprDummyNode {
393 expr,
394 property: DummyProperty {
395 expr_type: dummy_property,
396 },
397 })
398 }
399
400 #[test]
401 fn test_build_dag() -> Result<()> {
402 let schema = Schema::new(vec![
403 Field::new("0", DataType::Int32, true),
404 Field::new("1", DataType::Int32, true),
405 Field::new("2", DataType::Int32, true),
406 ]);
407 let expr = binary(
408 cast(
409 binary(
410 col("0", &schema)?,
411 Operator::Plus,
412 col("1", &schema)?,
413 &schema,
414 )?,
415 &schema,
416 DataType::Int64,
417 )?,
418 Operator::Gt,
419 binary(
420 cast(col("2", &schema)?, &schema, DataType::Int64)?,
421 Operator::Plus,
422 lit(ScalarValue::Int64(Some(10))),
423 &schema,
424 )?,
425 &schema,
426 )?;
427 let mut vector_dummy_props = vec![];
428 let (root, graph) = build_dag(expr, &make_dummy_node)?;
429 let mut bfs = Bfs::new(&graph, root);
430 while let Some(node_index) = bfs.next(&graph) {
431 let node = &graph[node_index];
432 vector_dummy_props.push(node.property.clone());
433 }
434
435 assert_eq!(
436 vector_dummy_props
437 .iter()
438 .filter(|property| property.expr_type == "Binary")
439 .count(),
440 3
441 );
442 assert_eq!(
443 vector_dummy_props
444 .iter()
445 .filter(|property| property.expr_type == "Column")
446 .count(),
447 3
448 );
449 assert_eq!(
450 vector_dummy_props
451 .iter()
452 .filter(|property| property.expr_type == "Literal")
453 .count(),
454 1
455 );
456 assert_eq!(
457 vector_dummy_props
458 .iter()
459 .filter(|property| property.expr_type == "Other")
460 .count(),
461 2
462 );
463 Ok(())
464 }
465
466 #[test]
467 fn test_convert_to_expr() -> Result<()> {
468 let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
469 let sort_expr = vec![PhysicalSortExpr {
470 expr: col("a", &schema)?,
471 options: Default::default(),
472 }];
473 assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
474 Ok(())
475 }
476
477 #[test]
478 fn test_get_indices_of_exprs_strict() {
479 let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
480 Arc::new(Column::new("a", 0)),
481 Arc::new(Column::new("b", 1)),
482 Arc::new(Column::new("c", 2)),
483 Arc::new(Column::new("d", 3)),
484 ];
485 let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
486 Arc::new(Column::new("b", 1)),
487 Arc::new(Column::new("c", 2)),
488 Arc::new(Column::new("a", 0)),
489 ];
490 assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
491 assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
492 }
493
494 #[test]
495 fn test_reassign_predicate_columns_in_list() {
496 let int_field = Field::new("should_not_matter", DataType::Int64, true);
497 let dict_field = Field::new(
498 "id",
499 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
500 true,
501 );
502 let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
503 let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
504 let pred = in_list(
505 Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
506 vec![lit(ScalarValue::Dictionary(
507 Box::new(DataType::Int32),
508 Box::new(ScalarValue::from("2")),
509 ))],
510 &false,
511 &schema_big,
512 )
513 .unwrap();
514
515 let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
516
517 let expected = in_list(
518 Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
519 vec![lit(ScalarValue::Dictionary(
520 Box::new(DataType::Int32),
521 Box::new(ScalarValue::from("2")),
522 ))],
523 &false,
524 &schema_small,
525 )
526 .unwrap();
527
528 assert_eq!(actual.as_ref(), expected.as_ref());
529 }
530
531 #[test]
532 fn test_collect_columns() -> Result<()> {
533 let expr1 = Arc::new(Column::new("col1", 2)) as _;
534 let mut expected = HashSet::new();
535 expected.insert(Column::new("col1", 2));
536 assert_eq!(collect_columns(&expr1), expected);
537
538 let expr2 = Arc::new(Column::new("col2", 5)) as _;
539 let mut expected = HashSet::new();
540 expected.insert(Column::new("col2", 5));
541 assert_eq!(collect_columns(&expr2), expected);
542
543 let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
544 let mut expected = HashSet::new();
545 expected.insert(Column::new("col1", 2));
546 expected.insert(Column::new("col2", 5));
547 assert_eq!(collect_columns(&expr3), expected);
548 Ok(())
549 }
550}