datafusion_physical_expr/statistics/
stats_solver.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use crate::expressions::Literal;
21use crate::intervals::cp_solver::PropagationResult;
22use crate::physical_expr::PhysicalExpr;
23use crate::utils::{build_dag, ExprTreeNode};
24
25use arrow::datatypes::{DataType, Schema};
26use datafusion_common::{Result, ScalarValue};
27use datafusion_expr::statistics::Distribution;
28use datafusion_expr_common::interval_arithmetic::Interval;
29
30use petgraph::adj::DefaultIx;
31use petgraph::prelude::Bfs;
32use petgraph::stable_graph::{NodeIndex, StableGraph};
33use petgraph::visit::DfsPostOrder;
34use petgraph::Outgoing;
35
36/// This object implements a directed acyclic expression graph (DAEG) that
37/// is used to compute statistics/distributions for expressions hierarchically.
38#[derive(Clone, Debug)]
39pub struct ExprStatisticsGraph {
40    graph: StableGraph<ExprStatisticsGraphNode, usize>,
41    root: NodeIndex,
42}
43
44/// This is a node in the DAEG; it encapsulates a reference to the actual
45/// [`PhysicalExpr`] as well as its statistics/distribution.
46#[derive(Clone, Debug)]
47pub struct ExprStatisticsGraphNode {
48    expr: Arc<dyn PhysicalExpr>,
49    dist: Distribution,
50}
51
52impl ExprStatisticsGraphNode {
53    /// Constructs a new DAEG node based on the given interval with a
54    /// `Uniform` distribution.
55    fn new_uniform(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Result<Self> {
56        Distribution::new_uniform(interval)
57            .map(|dist| ExprStatisticsGraphNode { expr, dist })
58    }
59
60    /// Constructs a new DAEG node with a `Bernoulli` distribution having an
61    /// unknown success probability.
62    fn new_bernoulli(expr: Arc<dyn PhysicalExpr>) -> Result<Self> {
63        Distribution::new_bernoulli(ScalarValue::Float64(None))
64            .map(|dist| ExprStatisticsGraphNode { expr, dist })
65    }
66
67    /// Constructs a new DAEG node with a `Generic` distribution having no
68    /// definite summary statistics.
69    fn new_generic(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
70        let interval = Interval::make_unbounded(dt)?;
71        let dist = Distribution::new_from_interval(interval)?;
72        Ok(ExprStatisticsGraphNode { expr, dist })
73    }
74
75    /// Get the [`Distribution`] object representing the statistics of the
76    /// expression.
77    pub fn distribution(&self) -> &Distribution {
78        &self.dist
79    }
80
81    /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
82    /// object. Literals are created with `Uniform` distributions with a
83    /// definite, singleton interval. Expressions with a `Boolean` data type
84    /// result in a`Bernoulli` distribution with an unknown success probability.
85    /// Any other expression starts with an `Unknown` distribution with an
86    /// indefinite range (i.e. `[-∞, ∞]`).
87    pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
88        let expr = Arc::clone(&node.expr);
89        if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
90            let value = literal.value();
91            Interval::try_new(value.clone(), value.clone())
92                .and_then(|interval| Self::new_uniform(expr, interval))
93        } else {
94            expr.data_type(schema).and_then(|dt| {
95                if dt.eq(&DataType::Boolean) {
96                    Self::new_bernoulli(expr)
97                } else {
98                    Self::new_generic(expr, &dt)
99                }
100            })
101        }
102    }
103}
104
105impl ExprStatisticsGraph {
106    pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
107        // Build the full graph:
108        let (root, graph) = build_dag(expr, &|node| {
109            ExprStatisticsGraphNode::make_node(node, schema)
110        })?;
111        Ok(Self { graph, root })
112    }
113
114    /// This function assigns given distributions to expressions in the DAEG.
115    /// The argument `assignments` associates indices of sought expressions
116    /// with their corresponding new distributions.
117    pub fn assign_statistics(&mut self, assignments: &[(usize, Distribution)]) {
118        for (index, stats) in assignments {
119            let node_index = NodeIndex::from(*index as DefaultIx);
120            self.graph[node_index].dist = stats.clone();
121        }
122    }
123
124    /// Computes statistics/distributions for an expression via a bottom-up
125    /// traversal.
126    pub fn evaluate_statistics(&mut self) -> Result<&Distribution> {
127        let mut dfs = DfsPostOrder::new(&self.graph, self.root);
128        while let Some(idx) = dfs.next(&self.graph) {
129            let neighbors = self.graph.neighbors_directed(idx, Outgoing);
130            let mut children_statistics = neighbors
131                .map(|child| self.graph[child].distribution())
132                .collect::<Vec<_>>();
133            // Note that all distributions are assumed to be independent.
134            if !children_statistics.is_empty() {
135                // Reverse to align with `PhysicalExpr`'s children:
136                children_statistics.reverse();
137                self.graph[idx].dist = self.graph[idx]
138                    .expr
139                    .evaluate_statistics(&children_statistics)?;
140            }
141        }
142        Ok(self.graph[self.root].distribution())
143    }
144
145    /// Runs a propagation mechanism in a top-down manner to update statistics
146    /// of leaf nodes.
147    pub fn propagate_statistics(
148        &mut self,
149        given_stats: Distribution,
150    ) -> Result<PropagationResult> {
151        // Adjust the root node with the given statistics:
152        let root_range = self.graph[self.root].dist.range()?;
153        let given_range = given_stats.range()?;
154        if let Some(interval) = root_range.intersect(&given_range)? {
155            if interval != root_range {
156                // If the given statistics enable us to obtain a more precise
157                // range for the root, update it:
158                let subset = root_range.contains(given_range)?;
159                self.graph[self.root].dist = if subset == Interval::CERTAINLY_TRUE {
160                    // Given statistics is strictly more informative, use it as is:
161                    given_stats
162                } else {
163                    // Intersecting ranges gives us a more precise range:
164                    Distribution::new_from_interval(interval)?
165                };
166            }
167        } else {
168            return Ok(PropagationResult::Infeasible);
169        }
170
171        let mut bfs = Bfs::new(&self.graph, self.root);
172
173        while let Some(node) = bfs.next(&self.graph) {
174            let neighbors = self.graph.neighbors_directed(node, Outgoing);
175            let mut children = neighbors.collect::<Vec<_>>();
176            // If the current expression is a leaf, its statistics is now final.
177            // So, just continue with the propagation procedure:
178            if children.is_empty() {
179                continue;
180            }
181            // Reverse to align with `PhysicalExpr`'s children:
182            children.reverse();
183            let children_stats = children
184                .iter()
185                .map(|child| self.graph[*child].distribution())
186                .collect::<Vec<_>>();
187            let node_statistics = self.graph[node].distribution();
188            let propagated_statistics = self.graph[node]
189                .expr
190                .propagate_statistics(node_statistics, &children_stats)?;
191            if let Some(propagated_stats) = propagated_statistics {
192                for (child_idx, stats) in children.into_iter().zip(propagated_stats) {
193                    self.graph[child_idx].dist = stats;
194                }
195            } else {
196                // The constraint is infeasible, report:
197                return Ok(PropagationResult::Infeasible);
198            }
199        }
200        Ok(PropagationResult::Success)
201    }
202}
203
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::expressions::{binary, try_cast, Column};
    use crate::intervals::cp_solver::PropagationResult;
    use crate::statistics::stats_solver::ExprStatisticsGraph;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr_common::interval_arithmetic::Interval;
    use datafusion_expr_common::operator::Operator;
    use datafusion_expr_common::statistics::Distribution;
    use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Builds the binary expression `left op right`, first casting both
    /// operands to the input types chosen by [`BinaryTypeCoercer`].
    pub fn binary_expr(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let lhs_type = left.data_type(schema)?;
        let rhs_type = right.data_type(schema)?;
        let (lhs, rhs) =
            BinaryTypeCoercer::new(&lhs_type, &op, &rhs_type).get_input_types()?;

        binary(
            try_cast(left, schema, lhs)?,
            op,
            try_cast(right, schema, rhs)?,
            schema,
        )
    }

    /// Evaluates and then propagates statistics through `(a + b) = (c - d)`.
    #[test]
    fn test_stats_integration() -> Result<()> {
        let fields = ["a", "b", "c", "d"]
            .iter()
            .map(|name| Field::new(*name, DataType::Float64, false))
            .collect::<Vec<_>>();
        let schema = Schema::new(fields);
        let schema = &schema;

        let a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let b: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
        let c: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c", 2));
        let d: Arc<dyn PhysicalExpr> = Arc::new(Column::new("d", 3));

        let sum = binary_expr(a, Operator::Plus, b, schema)?;
        let difference = binary_expr(c, Operator::Minus, d, schema)?;
        let expr = binary_expr(sum, Operator::Eq, difference, schema)?;

        let mut graph = ExprStatisticsGraph::try_new(expr, schema)?;

        // Shorthand for a uniform distribution over `[low, high]`.
        let uniform = |low: f64, high: f64| -> Result<Distribution> {
            Distribution::new_uniform(Interval::make(Some(low), Some(high))?)
        };
        // Indices 2, 5 and 6 are `BinaryExpr` nodes; statistics are assigned
        // to the remaining indices only.
        graph.assign_statistics(&[
            (0, uniform(0., 1.)?),
            (1, uniform(0., 2.)?),
            (3, uniform(1., 3.)?),
            (4, uniform(1., 5.)?),
        ]);

        let evaluated = graph.evaluate_statistics()?;
        let expected = Distribution::new_bernoulli(ScalarValue::Float64(None))?;
        assert_eq!(evaluated, &expected);

        let one = ScalarValue::new_one(&DataType::Float64)?;
        let outcome = graph.propagate_statistics(Distribution::new_bernoulli(one)?)?;
        assert_eq!(outcome, PropagationResult::Success);
        Ok(())
    }
}