mod guarantee;
pub use guarantee::{Guarantee, LiteralGuarantee};
use std::borrow::{Borrow, Cow};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::expressions::{BinaryExpr, Column};
use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData};
use arrow::compute::{and_kleene, is_not_null, SlicesIterator};
use arrow::datatypes::SchemaRef;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRewriter, VisitRecursion,
};
use datafusion_common::Result;
use datafusion_expr::Operator;
use itertools::Itertools;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableGraph;
/// Flattens a predicate made of nested `AND` binary expressions into the list
/// of its operands.
///
/// For example, `a1 = a2 AND b1 <= b2 AND c1 != c2` yields
/// `[a1 = a2, b1 <= b2, c1 != c2]`. A predicate whose top-level expression is
/// not an `AND` is returned as a single-element vector.
pub fn split_conjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    let collected = Vec::new();
    split_impl(Operator::And, predicate, collected)
}
/// Flattens a predicate made of nested `OR` binary expressions into the list
/// of its operands.
///
/// For example, `a1 = a2 OR b1 <= b2 OR c1 != c2` yields
/// `[a1 = a2, b1 <= b2, c1 != c2]`. A predicate whose top-level expression is
/// not an `OR` is returned as a single-element vector.
pub fn split_disjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    let collected = Vec::new();
    split_impl(Operator::Or, predicate, collected)
}
/// Recursive worker shared by the conjunction/disjunction splitters.
///
/// Walks `predicate` left-to-right. Whenever it is a [`BinaryExpr`] whose
/// operator equals `operator`, both sides are split further; any other
/// expression is appended to `exprs` as a leaf. The accumulated vector is
/// returned, preserving left-to-right order.
fn split_impl<'a>(
    operator: Operator,
    predicate: &'a Arc<dyn PhysicalExpr>,
    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
) -> Vec<&'a Arc<dyn PhysicalExpr>> {
    // Only a binary expression using the requested operator is split further.
    let splittable = predicate
        .as_any()
        .downcast_ref::<BinaryExpr>()
        .filter(|binary| binary.op() == &operator);

    if let Some(binary) = splittable {
        let after_left = split_impl(operator, binary.left(), exprs);
        return split_impl(operator, binary.right(), after_left);
    }

    // Leaf: a non-binary expression, or a binary one with a different operator.
    exprs.push(predicate);
    exprs
}
/// Maps the columns required by a parent, expressed against the projection's
/// output, back to the columns the projection reads from its input.
///
/// Only projection entries whose expression is a plain [`Column`] take part in
/// the mapping; they are keyed by their output name (if a name occurs more
/// than once, the last entry wins). Each element of `parent_required` that is
/// a [`Column`] is looked up by name; elements that are not columns, or whose
/// name has no plain-column projection entry, are silently dropped.
pub fn map_columns_before_projection(
    parent_required: &[Arc<dyn PhysicalExpr>],
    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Vec<Arc<dyn PhysicalExpr>> {
    // Borrow names and columns rather than cloning them into the lookup
    // table: only the columns that are actually returned get cloned below.
    let column_mapping: HashMap<&str, &Column> = proj_exprs
        .iter()
        .filter_map(|(expr, name)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|column| (name.as_str(), column))
        })
        .collect();

    parent_required
        .iter()
        .filter_map(|required| {
            let column = required.as_any().downcast_ref::<Column>()?;
            column_mapping.get(column.name()).copied()
        })
        .map(|column| Arc::new(column.clone()) as _)
        .collect()
}
/// Extracts the expression part of every sort expression in `sequence`.
///
/// Accepts owned or borrowed [`PhysicalSortExpr`] items and returns the
/// `expr` of each one (a cheap `Arc` clone), in iteration order. Sort options
/// are discarded.
pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
    sequence: impl IntoIterator<Item = T>,
) -> Vec<Arc<dyn PhysicalExpr>> {
    sequence
        .into_iter()
        .map(|elem| {
            let sort_expr: &PhysicalSortExpr = elem.borrow();
            Arc::clone(&sort_expr.expr)
        })
        .collect()
}
/// Finds, for every expression in `targets`, the index of the first equal
/// expression in `items`.
///
/// Targets with no equal counterpart in `items` are skipped, so the result
/// may be shorter than `targets`. Indices are reported in `targets` order.
pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
    targets: impl IntoIterator<Item = T>,
    items: &[Arc<dyn PhysicalExpr>],
) -> Vec<usize> {
    let mut indices = Vec::new();
    for target in targets {
        let wanted: &Arc<dyn PhysicalExpr> = target.borrow();
        let found = items.iter().position(|candidate| candidate.eq(wanted));
        if let Some(index) = found {
            indices.push(index);
        }
    }
    indices
}
/// A tree that mirrors the shape of a physical expression, where every node
/// can additionally carry an optional payload of type `T`.
#[derive(Clone, Debug)]
pub struct ExprTreeNode<T> {
/// The physical expression wrapped by this node.
expr: Arc<dyn PhysicalExpr>,
/// Payload attached to this node; `None` right after construction.
data: Option<T>,
/// One node per child of `expr`, in the order reported by `expr.children()`.
child_nodes: Vec<ExprTreeNode<T>>,
}
impl<T> ExprTreeNode<T> {
    /// Recursively builds a tree for `expr`, creating one node per
    /// sub-expression. Every node starts without a payload (`data` is `None`).
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        let child_nodes = expr
            .children()
            .into_iter()
            .map(Self::new)
            .collect::<Vec<_>>();
        Self {
            expr,
            data: None,
            child_nodes,
        }
    }

    /// Returns the physical expression wrapped by this node.
    pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
        &self.expr
    }

    /// Returns the direct children of this node.
    pub fn children(&self) -> &[ExprTreeNode<T>] {
        self.child_nodes.as_slice()
    }
}
impl<T: Clone> TreeNode for ExprTreeNode<T> {
    /// Exposes the direct children as borrowed nodes.
    fn children_nodes(&self) -> Vec<Cow<Self>> {
        self.child_nodes.iter().map(Cow::Borrowed).collect()
    }

    /// Replaces every child with the result of `transform`, stopping at (and
    /// returning) the first error.
    fn map_children<F>(mut self, mut transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        let originals = std::mem::take(&mut self.child_nodes);
        let mut rewritten = Vec::with_capacity(originals.len());
        for child in originals {
            rewritten.push(transform(child)?);
        }
        self.child_nodes = rewritten;
        Ok(self)
    }
}
/// Rewriter state used to turn an [`ExprTreeNode`] tree into a graph in which
/// equal sub-expressions share a single graph node.
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
/// The graph under construction; node weights are produced by `constructor`.
graph: StableGraph<T, usize>,
/// Expressions already added to `graph`, paired with the index of their node.
/// (Despite the name, the entries are expressions.)
visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
/// Callback that builds the node weight for an expression tree node.
constructor: &'a F,
}
impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> TreeNodeRewriter
    for PhysicalExprDAEGBuilder<'a, T, F>
{
    type N = ExprTreeNode<NodeIndex>;

    /// Records `node` in the graph and stores the resulting graph index in
    /// `node.data`.
    ///
    /// If an equal expression has been seen before, its existing graph node is
    /// reused. Otherwise a new graph node is created via `constructor` and
    /// connected to the graph nodes of its children with edges of weight `0`.
    ///
    /// Panics if a child node has no `data` yet when the edges are added.
    fn mutate(
        &mut self,
        mut node: ExprTreeNode<NodeIndex>,
    ) -> Result<ExprTreeNode<NodeIndex>> {
        let expr = &node.expr;
        let already_seen = self
            .visited_plans
            .iter()
            .find(|(seen, _)| expr.eq(seen))
            .map(|(_, idx)| *idx);

        let node_idx = if let Some(existing) = already_seen {
            existing
        } else {
            let weight = (self.constructor)(&node)?;
            let fresh = self.graph.add_node(weight);
            for child in &node.child_nodes {
                self.graph.add_edge(fresh, child.data.unwrap(), 0);
            }
            self.visited_plans.push((expr.clone(), fresh));
            fresh
        };

        node.data = Some(node_idx);
        Ok(node)
    }
}
/// Builds a graph from `expr` in which equal sub-expressions are represented
/// by a single shared node.
///
/// `constructor` is called once per distinct sub-expression to produce the
/// weight stored in the corresponding graph node; any error it returns is
/// propagated. Returns the index of the node for the root expression together
/// with the graph.
pub fn build_dag<T, F>(
    expr: Arc<dyn PhysicalExpr>,
    constructor: &F,
) -> Result<(NodeIndex, StableGraph<T, usize>)>
where
    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
{
    let mut builder = PhysicalExprDAEGBuilder {
        graph: StableGraph::<T, usize>::new(),
        visited_plans: Vec::new(),
        constructor,
    };
    // Rewriting assigns a graph index to every tree node, root included.
    let root = ExprTreeNode::<NodeIndex>::new(expr).rewrite(&mut builder)?;
    let root_idx = root.data.unwrap();
    Ok((root_idx, builder.graph))
}
/// Collects every distinct [`Column`] referenced anywhere inside `expr`.
///
/// The expression tree is visited recursively; each node that is a `Column`
/// is added to the returned set.
///
/// # Panics
///
/// Panics if the traversal reports an error. The visitor closure used here
/// always returns `Ok`.
pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
    let mut columns = HashSet::<Column>::new();
    expr.apply(&mut |expr| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            // Hash lookup instead of a linear scan over the set; the column
            // is only cloned when it is not present yet.
            if !columns.contains(column) {
                columns.insert(column.clone());
            }
        }
        Ok(VisitRecursion::Continue)
    })
    // The visitor above never returns an error.
    .expect("no way to return error during recursion");
    columns
}
/// Rewrites every [`Column`] in `pred` so that its index matches the position
/// of the same-named field in `schema`.
///
/// When a column name is not present in `schema`:
/// * if `ignore_not_found` is `true`, the column is given index `usize::MAX`;
/// * otherwise the schema lookup error is returned.
///
/// Expressions that are not columns are left untouched.
pub fn reassign_predicate_columns(
    pred: Arc<dyn PhysicalExpr>,
    schema: &SchemaRef,
    ignore_not_found: bool,
) -> Result<Arc<dyn PhysicalExpr>> {
    pred.transform_down(&|expr| {
        let column = match expr.as_any().downcast_ref::<Column>() {
            Some(column) => column,
            None => return Ok(Transformed::No(expr)),
        };

        let lookup = schema.index_of(column.name());
        let index = match (lookup, ignore_not_found) {
            (Ok(idx), _) => idx,
            // Sentinel index for columns missing from the schema.
            (Err(_), true) => usize::MAX,
            (Err(e), false) => return Err(e.into()),
        };

        let rewritten: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new(column.name(), index));
        Ok(Transformed::Yes(rewritten))
    })
}
pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec<PhysicalSortExpr> {
order_bys
.iter()
.map(|e| PhysicalSortExpr {
expr: e.expr.clone(),
options: !e.options,
})
.collect()
}
/// Scatters the values of `truthy` into an array of `mask.len()` elements.
///
/// Positions where `mask` is `true` receive the next unused value of `truthy`
/// (taken in order); positions where `mask` is `false` or null become null.
pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result<ArrayRef> {
    let source = truthy.to_data();

    // Turn null mask entries into `false` before slicing the mask into runs
    // of `true` values.
    let mask = and_kleene(mask, &is_not_null(mask)?)?;
    let total = mask.len();

    let mut builder = MutableArrayData::new(vec![&source], true, total);

    // `written`: number of output slots produced so far.
    // `next_value`: position of the next unused element of `truthy`.
    let mut written = 0;
    let mut next_value = 0;

    for (start, end) in SlicesIterator::new(&mask) {
        // Gap before this run of `true` values: fill it with nulls.
        if start > written {
            builder.extend_nulls(start - written);
        }
        // Copy as many `truthy` values as the run is long.
        let run_len = end - start;
        builder.extend(0, next_value, next_value + run_len);
        next_value += run_len;
        written = end;
    }

    // Everything after the last run is null.
    if written < total {
        builder.extend_nulls(total - written);
    }

    Ok(make_array(builder.freeze()))
}
/// Concatenates `left` and `right` and removes duplicate sort expressions.
///
/// The first occurrence of each sort expression is kept, so the result lists
/// the distinct elements of `left` in order, followed by the elements of
/// `right` not already present.
pub fn merge_vectors(
    left: &[PhysicalSortExpr],
    right: &[PhysicalSortExpr],
) -> Vec<PhysicalSortExpr> {
    // Deduplicate on references first so that only the surviving elements are
    // cloned (`unique` keeps a copy of every item it has seen).
    left.iter().chain(right.iter()).unique().cloned().collect()
}
#[cfg(test)]
mod tests {
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use super::*;
use crate::expressions::{binary, cast, col, in_list, lit, Column, Literal};
use crate::PhysicalSortExpr;
use arrow_array::Int32Array;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::cast::{as_boolean_array, as_int32_array};
use datafusion_common::{Result, ScalarValue};
use petgraph::visit::Bfs;
/// Test-only node property recording which kind of expression a graph node
/// was built from.
#[derive(Clone)]
struct DummyProperty {
/// Expression kind label, as assigned by `make_dummy_node`.
expr_type: String,
}
/// Test-only graph node weight: an expression together with its dummy property.
#[derive(Clone)]
struct PhysicalExprDummyNode {
/// The expression this graph node was created for.
pub expr: Arc<dyn PhysicalExpr>,
/// The property attached to the node.
pub property: DummyProperty,
}
impl Display for PhysicalExprDummyNode {
    /// Displays the node as its wrapped expression.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.expr))
    }
}
/// Builds the test node weight for `node`, labelling it as `"Binary"`,
/// `"Column"`, `"Literal"` or `"Other"` depending on the concrete expression
/// type.
fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
    let any = node.expression().as_any();
    let label = if any.is::<BinaryExpr>() {
        "Binary"
    } else if any.is::<Column>() {
        "Column"
    } else if any.is::<Literal>() {
        "Literal"
    } else {
        "Other"
    };
    let property = DummyProperty {
        expr_type: label.to_owned(),
    };
    Ok(PhysicalExprDummyNode {
        expr: node.expression().clone(),
        property,
    })
}
#[test]
fn test_build_dag() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("0", DataType::Int32, true),
        Field::new("1", DataType::Int32, true),
        Field::new("2", DataType::Int32, true),
    ]);

    // CAST("0" + "1" AS Int64) > CAST("2" AS Int64) + 10
    let left_sum = binary(
        col("0", &schema)?,
        Operator::Plus,
        col("1", &schema)?,
        &schema,
    )?;
    let left = cast(left_sum, &schema, DataType::Int64)?;
    let right = binary(
        cast(col("2", &schema)?, &schema, DataType::Int64)?,
        Operator::Plus,
        lit(ScalarValue::Int64(Some(10))),
        &schema,
    )?;
    let expr = binary(left, Operator::Gt, right, &schema)?;

    let (root, graph) = build_dag(expr, &make_dummy_node)?;

    // Gather the property of every node reachable from the root.
    let mut visited_props = vec![];
    let mut bfs = Bfs::new(&graph, root);
    while let Some(node_index) = bfs.next(&graph) {
        visited_props.push(graph[node_index].property.clone());
    }

    let count_of = |kind: &str| {
        visited_props
            .iter()
            .filter(|property| property.expr_type == kind)
            .count()
    };
    assert_eq!(count_of("Binary"), 3);
    assert_eq!(count_of("Column"), 3);
    assert_eq!(count_of("Literal"), 1);
    assert_eq!(count_of("Other"), 2);
    Ok(())
}
#[test]
fn test_convert_to_expr() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
    let sort_exprs = vec![PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: Default::default(),
    }];
    // The converted expression must equal the one inside the sort expression.
    let converted = convert_to_expr(&sort_exprs);
    assert!(converted[0].eq(&sort_exprs[0].expr));
    Ok(())
}
#[test]
fn test_get_indices_of_exprs_strict() {
    let column = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };
    let list1 = vec![
        column("a", 0),
        column("b", 1),
        column("c", 2),
        column("d", 3),
    ];
    let list2 = vec![column("b", 1), column("c", 2), column("a", 0)];

    // "d" has no counterpart in `list2`, so it is skipped.
    let in_list2 = get_indices_of_exprs_strict(&list1, &list2);
    assert_eq!(in_list2, vec![2, 0, 1]);
    let in_list1 = get_indices_of_exprs_strict(&list2, &list1);
    assert_eq!(in_list1, vec![1, 2, 0]);
}
#[test]
fn test_reassign_predicate_columns_in_list() {
    let int_field = Field::new("should_not_matter", DataType::Int64, true);
    let dict_field = Field::new(
        "id",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        true,
    );
    let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
    let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));

    // Builds `id IN ('2')` with the column resolved against `schema`.
    let id_in_list = |schema: &Schema| {
        in_list(
            Arc::new(Column::new_with_schema("id", schema).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            schema,
        )
        .unwrap()
    };

    let pred = id_in_list(&*schema_big);
    let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
    let expected = id_in_list(&*schema_small);
    assert_eq!(actual.as_ref(), expected.as_any());
}
#[test]
fn test_collect_columns() -> Result<()> {
    let col1 = Column::new("col1", 2);
    let col2 = Column::new("col2", 5);

    // A single column collects to itself.
    let expr1: Arc<dyn PhysicalExpr> = Arc::new(col1.clone());
    let expected = vec![col1.clone()].into_iter().collect::<HashSet<_>>();
    assert_eq!(collect_columns(&expr1), expected);

    let expr2: Arc<dyn PhysicalExpr> = Arc::new(col2.clone());
    let expected = vec![col2.clone()].into_iter().collect::<HashSet<_>>();
    assert_eq!(collect_columns(&expr2), expected);

    // A binary expression collects the columns of both sides.
    let expr3: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2));
    let expected = vec![col1, col2].into_iter().collect::<HashSet<_>>();
    assert_eq!(collect_columns(&expr3), expected);
    Ok(())
}
#[test]
fn scatter_int() -> Result<()> {
    let mask = BooleanArray::from(vec![true, true, false, false, true]);
    let truthy = Int32Array::from(vec![1, 10, 11, 100]);

    // The last value of `truthy` (100) is left unused.
    let scattered = scatter(&mask, &truthy)?;
    let expected =
        Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]);
    assert_eq!(&expected, as_int32_array(&scattered)?);
    Ok(())
}
#[test]
fn scatter_int_end_with_false() -> Result<()> {
    let mask = BooleanArray::from(vec![true, false, true, false, false, false]);
    let truthy = Int32Array::from(vec![1, 10, 11, 100]);

    // The trailing `false` entries must be padded with nulls.
    let scattered = scatter(&mask, &truthy)?;
    let expected =
        Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]);
    assert_eq!(&expected, as_int32_array(&scattered)?);
    Ok(())
}
#[test]
fn scatter_with_null_mask() -> Result<()> {
    let mask =
        BooleanArray::from_iter(vec![Some(false), None, Some(true), Some(true), None]);
    let truthy = Int32Array::from(vec![1, 10, 11]);

    // Null mask entries behave like `false`.
    let scattered = scatter(&mask, &truthy)?;
    let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]);
    assert_eq!(&expected, as_int32_array(&scattered)?);
    Ok(())
}
#[test]
fn scatter_boolean() -> Result<()> {
    let mask = BooleanArray::from(vec![true, true, false, false, true]);
    let truthy = BooleanArray::from(vec![false, false, false, true]);

    // Only the first three values of `truthy` are consumed.
    let scattered = scatter(&mask, &truthy)?;
    let expected = BooleanArray::from_iter(vec![
        Some(false),
        Some(false),
        None,
        None,
        Some(false),
    ]);
    assert_eq!(&expected, as_boolean_array(&scattered)?);
    Ok(())
}
}