datafusion_physical_expr/statistics/
stats_solver.rs1use std::sync::Arc;
19
20use crate::expressions::Literal;
21use crate::intervals::cp_solver::PropagationResult;
22use crate::physical_expr::PhysicalExpr;
23use crate::utils::{build_dag, ExprTreeNode};
24
25use arrow::datatypes::{DataType, Schema};
26use datafusion_common::{Result, ScalarValue};
27use datafusion_expr::statistics::Distribution;
28use datafusion_expr_common::interval_arithmetic::Interval;
29
30use petgraph::adj::DefaultIx;
31use petgraph::prelude::Bfs;
32use petgraph::stable_graph::{NodeIndex, StableGraph};
33use petgraph::visit::DfsPostOrder;
34use petgraph::Outgoing;
35
36#[derive(Clone, Debug)]
39pub struct ExprStatisticsGraph {
40 graph: StableGraph<ExprStatisticsGraphNode, usize>,
41 root: NodeIndex,
42}
43
44#[derive(Clone, Debug)]
47pub struct ExprStatisticsGraphNode {
48 expr: Arc<dyn PhysicalExpr>,
49 dist: Distribution,
50}
51
52impl ExprStatisticsGraphNode {
53 fn new_uniform(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Result<Self> {
56 Distribution::new_uniform(interval)
57 .map(|dist| ExprStatisticsGraphNode { expr, dist })
58 }
59
60 fn new_bernoulli(expr: Arc<dyn PhysicalExpr>) -> Result<Self> {
63 Distribution::new_bernoulli(ScalarValue::Float64(None))
64 .map(|dist| ExprStatisticsGraphNode { expr, dist })
65 }
66
67 fn new_generic(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
70 let interval = Interval::make_unbounded(dt)?;
71 let dist = Distribution::new_from_interval(interval)?;
72 Ok(ExprStatisticsGraphNode { expr, dist })
73 }
74
75 pub fn distribution(&self) -> &Distribution {
78 &self.dist
79 }
80
81 pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
88 let expr = Arc::clone(&node.expr);
89 if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
90 let value = literal.value();
91 Interval::try_new(value.clone(), value.clone())
92 .and_then(|interval| Self::new_uniform(expr, interval))
93 } else {
94 expr.data_type(schema).and_then(|dt| {
95 if dt.eq(&DataType::Boolean) {
96 Self::new_bernoulli(expr)
97 } else {
98 Self::new_generic(expr, &dt)
99 }
100 })
101 }
102 }
103}
104
105impl ExprStatisticsGraph {
106 pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
107 let (root, graph) = build_dag(expr, &|node| {
109 ExprStatisticsGraphNode::make_node(node, schema)
110 })?;
111 Ok(Self { graph, root })
112 }
113
114 pub fn assign_statistics(&mut self, assignments: &[(usize, Distribution)]) {
118 for (index, stats) in assignments {
119 let node_index = NodeIndex::from(*index as DefaultIx);
120 self.graph[node_index].dist = stats.clone();
121 }
122 }
123
124 pub fn evaluate_statistics(&mut self) -> Result<&Distribution> {
127 let mut dfs = DfsPostOrder::new(&self.graph, self.root);
128 while let Some(idx) = dfs.next(&self.graph) {
129 let neighbors = self.graph.neighbors_directed(idx, Outgoing);
130 let mut children_statistics = neighbors
131 .map(|child| self.graph[child].distribution())
132 .collect::<Vec<_>>();
133 if !children_statistics.is_empty() {
135 children_statistics.reverse();
137 self.graph[idx].dist = self.graph[idx]
138 .expr
139 .evaluate_statistics(&children_statistics)?;
140 }
141 }
142 Ok(self.graph[self.root].distribution())
143 }
144
145 pub fn propagate_statistics(
148 &mut self,
149 given_stats: Distribution,
150 ) -> Result<PropagationResult> {
151 let root_range = self.graph[self.root].dist.range()?;
153 let given_range = given_stats.range()?;
154 if let Some(interval) = root_range.intersect(&given_range)? {
155 if interval != root_range {
156 let subset = root_range.contains(given_range)?;
159 self.graph[self.root].dist = if subset == Interval::CERTAINLY_TRUE {
160 given_stats
162 } else {
163 Distribution::new_from_interval(interval)?
165 };
166 }
167 } else {
168 return Ok(PropagationResult::Infeasible);
169 }
170
171 let mut bfs = Bfs::new(&self.graph, self.root);
172
173 while let Some(node) = bfs.next(&self.graph) {
174 let neighbors = self.graph.neighbors_directed(node, Outgoing);
175 let mut children = neighbors.collect::<Vec<_>>();
176 if children.is_empty() {
179 continue;
180 }
181 children.reverse();
183 let children_stats = children
184 .iter()
185 .map(|child| self.graph[*child].distribution())
186 .collect::<Vec<_>>();
187 let node_statistics = self.graph[node].distribution();
188 let propagated_statistics = self.graph[node]
189 .expr
190 .propagate_statistics(node_statistics, &children_stats)?;
191 if let Some(propagated_stats) = propagated_statistics {
192 for (child_idx, stats) in children.into_iter().zip(propagated_stats) {
193 self.graph[child_idx].dist = stats;
194 }
195 } else {
196 return Ok(PropagationResult::Infeasible);
198 }
199 }
200 Ok(PropagationResult::Success)
201 }
202}
203
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::expressions::{binary, try_cast, Column};
    use crate::intervals::cp_solver::PropagationResult;
    use crate::statistics::stats_solver::ExprStatisticsGraph;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr_common::interval_arithmetic::Interval;
    use datafusion_expr_common::operator::Operator;
    use datafusion_expr_common::statistics::Distribution;
    use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Builds the binary expression `left op right`, first casting both
    /// operands to the input types chosen by [`BinaryTypeCoercer`].
    pub fn binary_expr(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let (lhs, rhs) = BinaryTypeCoercer::new(
            &left.data_type(schema)?,
            &op,
            &right.data_type(schema)?,
        )
        .get_input_types()?;

        binary(
            try_cast(left, schema, lhs)?,
            op,
            try_cast(right, schema, rhs)?,
            schema,
        )
    }

    /// End-to-end check on `(a + b) = (c - d)`: assigns uniform distributions
    /// to four nodes, evaluates the root, then propagates a Bernoulli
    /// distribution with parameter one from the root.
    #[test]
    fn test_stats_integration() -> Result<()> {
        let fields = ["a", "b", "c", "d"]
            .iter()
            .map(|name| Field::new(*name, DataType::Float64, false))
            .collect::<Vec<_>>();
        let schema = Schema::new(fields);
        let schema = &schema;

        let column = |name: &str, index: usize| {
            Arc::new(Column::new(name, index)) as Arc<dyn PhysicalExpr>
        };
        let uniform = |lower: f64, upper: f64| -> Result<Distribution> {
            Distribution::new_uniform(Interval::make(Some(lower), Some(upper))?)
        };

        let sum = binary_expr(column("a", 0), Operator::Plus, column("b", 1), schema)?;
        let difference =
            binary_expr(column("c", 2), Operator::Minus, column("d", 3), schema)?;
        let expr = binary_expr(sum, Operator::Eq, difference, schema)?;

        let mut graph = ExprStatisticsGraph::try_new(expr, schema)?;
        // Node indices 0, 1, 3 and 4 are presumably the leaf columns a, b, c
        // and d in graph-construction order -- confirm against `build_dag`.
        let assignments = [
            (0usize, uniform(0., 1.)?),
            (1usize, uniform(0., 2.)?),
            (3usize, uniform(1., 3.)?),
            (4usize, uniform(1., 5.)?),
        ];
        graph.assign_statistics(&assignments);

        let evaluated = graph.evaluate_statistics()?;
        let expected = Distribution::new_bernoulli(ScalarValue::Float64(None))?;
        assert_eq!(evaluated, &expected);

        let one = ScalarValue::new_one(&DataType::Float64)?;
        let outcome = graph.propagate_statistics(Distribution::new_bernoulli(one)?)?;
        assert_eq!(outcome, PropagationResult::Success);
        Ok(())
    }
}