datafusion_common/cse.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with
19//! a [`CSEController`], that defines how to eliminate common subtrees from a particular
20//! [`TreeNode`] tree.
21
22use crate::hash_utils::combine_hashes;
23use crate::tree_node::{
24 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
25 TreeNodeVisitor,
26};
27use crate::Result;
28use indexmap::IndexMap;
29use std::collections::HashMap;
30use std::hash::{BuildHasher, Hash, Hasher, RandomState};
31use std::marker::PhantomData;
32use std::sync::Arc;
33
/// Hashes the direct content of a [`TreeNode`] without recursing into its children.
///
/// This method is useful to incrementally compute hashes, such as in [`CSE`], which
/// builds a deep hash of a node and its descendants during the bottom-up phase of the
/// first traversal and so avoids computing the hash of the node and then the hash of
/// its descendants separately.
///
/// If a node doesn't have any children then the value returned by `hash_node()` is
/// similar to `.hash()`, but it is not necessarily the same value.
pub trait HashNode {
    /// Feeds only the node's own content (not its children) into `state`.
    fn hash_node<H: Hasher>(&self, state: &mut H);
}

/// An `Arc<T>` hashes exactly like the `T` it points to.
impl<T: HashNode + ?Sized> HashNode for Arc<T> {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        (**self).hash_node(state);
    }
}
51}
52
/// Reports whether a node has a canonical (normalized) form.
///
/// Normalization maps semantically equivalent nodes, such as `a + b` and `b + a`, onto
/// one canonical shape so they can be compared for equality. Common Subexpression
/// Elimination (CSE) relies on this to treat such nodes as the same subexpression.
pub trait Normalizeable {
    /// Returns `true` when this node can be normalized, e.g. when the order of its
    /// children does not matter.
    fn can_normalize(&self) -> bool;
}

/// Equality "after normalization", used by Common Subexpression Elimination (CSE).
///
/// Two nodes whose original forms differ but which are semantically equivalent once
/// normalized compare equal through [`NormalizeEq::normalize_eq`]. This lets CSE
/// recognize equivalent subexpressions regardless of their internal representation.
pub trait NormalizeEq: Eq + Normalizeable {
    /// Returns `true` when `self` and `other` are equal after normalization.
    fn normalize_eq(&self, other: &Self) -> bool;
}
73
74/// Identifier that represents a [`TreeNode`] tree.
75///
76/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
77/// "have no collision (as low as possible)"
78#[derive(Debug, Eq)]
79struct Identifier<'n, N: NormalizeEq> {
80 // Hash of `node` built up incrementally during the first, visiting traversal.
81 // Its value is not necessarily equal to default hash of the node. E.g. it is not
82 // equal to `expr.hash()` if the node is `Expr`.
83 hash: u64,
84 node: &'n N,
85}
86
87impl<N: NormalizeEq> Clone for Identifier<'_, N> {
88 fn clone(&self) -> Self {
89 *self
90 }
91}
92impl<N: NormalizeEq> Copy for Identifier<'_, N> {}
93
94impl<N: NormalizeEq> Hash for Identifier<'_, N> {
95 fn hash<H: Hasher>(&self, state: &mut H) {
96 state.write_u64(self.hash);
97 }
98}
99
100impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
101 fn eq(&self, other: &Self) -> bool {
102 self.hash == other.hash && self.node.normalize_eq(other.node)
103 }
104}
105
106impl<'n, N> Identifier<'n, N>
107where
108 N: HashNode + NormalizeEq,
109{
110 fn new(node: &'n N, random_state: &RandomState) -> Self {
111 let mut hasher = random_state.build_hasher();
112 node.hash_node(&mut hasher);
113 let hash = hasher.finish();
114 Self { hash, node }
115 }
116
117 fn combine(mut self, other: Option<Self>) -> Self {
118 other.map_or(self, |other_id| {
119 self.hash = combine_hashes(self.hash, other_id.hash);
120 self
121 })
122 }
123}
124
125/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the
126/// preorder index of the nodes.
127///
128/// This cache is filled by [`CSEVisitor`] during the first traversal and is
129/// used by [`CSERewriter`] during the second traversal.
130///
131/// The purpose of this cache is to quickly find the identifier of a node during the
132/// second traversal.
133///
134/// Elements in this array are added during `f_down` so the indexes represent the preorder
135/// index of nodes and thus element 0 belongs to the root of the tree.
136///
137/// The elements of the array are tuples that contain:
138/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start
139/// from 0.
140/// - The optional [`Identifier`] of the node. If none the node should not be considered
141/// for CSE.
142///
143/// # Example
144/// An expression tree like `(a + b)` would have the following `IdArray`:
145/// ```text
146/// [
147/// (2, Some(Identifier(hash_of("a + b"), &"a + b"))),
148/// (1, Some(Identifier(hash_of("a"), &"a"))),
149/// (0, Some(Identifier(hash_of("b"), &"b")))
150/// ]
151/// ```
152type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
153
/// How many times a node is evaluated. A node can be considered common if it is
/// evaluated surely at least 2 times, or surely only once but also conditionally.
#[derive(PartialEq, Eq)]
enum NodeEvaluation {
    /// Seen exactly once so far, in a position that is always evaluated.
    SurelyOnce,
    /// Seen so far only in conditionally evaluated positions (at least once).
    ConditionallyAtLeastOnce,
    /// Seen often enough to be extracted as a common node.
    Common,
}
162
163/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
164type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
165
166/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
167/// extracted during the second, rewriting traversal.
168type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
169
170type ChildrenList<N> = (Vec<N>, Vec<N>);
171
172/// The [`TreeNode`] specific definition of elimination.
173pub trait CSEController {
174 /// The type of the tree nodes.
175 type Node;
176
177 /// Splits the children to normal and conditionally evaluated ones or returns `None`
178 /// if all are always evaluated.
179 fn conditional_children(node: &Self::Node) -> Option<ChildrenList<&Self::Node>>;
180
181 // Returns true if a node is valid. If a node is invalid then it can't be eliminated.
182 // Validity is propagated up which means no subtree can be eliminated that contains
183 // an invalid node.
184 // (E.g. volatile expressions are not valid and subtrees containing such a node can't
185 // be extracted.)
186 fn is_valid(node: &Self::Node) -> bool;
187
188 // Returns true if a node should be ignored during CSE. Contrary to validity of a node,
189 // it is not propagated up.
190 fn is_ignored(&self, node: &Self::Node) -> bool;
191
192 // Generates a new name for the extracted subtree.
193 fn generate_alias(&self) -> String;
194
195 // Replaces a node to the generated alias.
196 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
197
198 // A helper method called on each node during top-down traversal during the second,
199 // rewriting traversal of CSE.
200 fn rewrite_f_down(&mut self, _node: &Self::Node) {}
201
202 // A helper method called on each node during bottom-up traversal during the second,
203 // rewriting traversal of CSE.
204 fn rewrite_f_up(&mut self, _node: &Self::Node) {}
205}
206
/// Outcome of [`CSE::extract_common_nodes`]. Either no common [`TreeNode`]s were found,
/// or the extracted common nodes are returned together with the rewritten and the
/// original node lists.
#[derive(Debug)]
pub enum FoundCommonNodes<N> {
    /// No common [`TreeNode`]s were found.
    No {
        /// The input lists of [`TreeNode`]s, returned unchanged.
        original_nodes_list: Vec<Vec<N>>,
    },

    /// Common [`TreeNode`]s were found.
    Yes {
        /// The extracted common [`TreeNode`]s, each paired with its alias.
        common_nodes: Vec<(N, String)>,

        /// The input [`TreeNode`]s with the common subtrees replaced.
        new_nodes_list: Vec<Vec<N>>,

        /// The original input [`TreeNode`]s.
        original_nodes_list: Vec<Vec<N>>,
    },
}
226
227/// Go through a [`TreeNode`] tree and generate identifiers for each subtrees.
228///
229/// An identifier contains information of the [`TreeNode`] itself and its subtrees.
230/// This visitor implementation use a stack `visit_stack` to track traversal, which
231/// lets us know when a subtree's visiting is finished. When `pre_visit` is called
232/// (traversing to a new node), an `EnterMark` and an `NodeItem` will be pushed into stack.
233/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `NodeItem`
234/// before the first `EnterMark` is considered to be sub-tree of the leaving node.
235///
236/// This visitor also records identifier in `id_array`. Makes the following traverse
237/// pass can get the identifier of a node without recalculate it. We assign each node
238/// in the tree a series number, start from 1, maintained by `series_number`.
239/// Series number represents the order we left (`f_up()`) a node. Has the property
240/// that child node's series number always smaller than parent's. While `id_array` is
241/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
242/// get the index of `id_array` for each node.
243///
244/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
245/// because they should not be recognized as common subtree.
246struct CSEVisitor<'a, 'n, N, C>
247where
248 N: NormalizeEq,
249 C: CSEController<Node = N>,
250{
251 /// statistics of [`TreeNode`]s
252 node_stats: &'a mut NodeStats<'n, N>,
253
254 /// cache to speed up second traversal
255 id_array: &'a mut IdArray<'n, N>,
256
257 /// inner states
258 visit_stack: Vec<VisitRecord<'n, N>>,
259
260 /// preorder index, start from 0.
261 down_index: usize,
262
263 /// postorder index, start from 0.
264 up_index: usize,
265
266 /// a [`RandomState`] to generate hashes during the first traversal
267 random_state: &'a RandomState,
268
269 /// a flag to indicate that common [`TreeNode`]s found
270 found_common: bool,
271
272 /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`]
273 /// might not be executed depending on the runtime values of other [`TreeNode`]s, and
274 /// thus can not be extracted as a common [`TreeNode`].
275 conditional: bool,
276
277 controller: &'a C,
278}
279
280/// Record item that used when traversing a [`TreeNode`] tree.
281enum VisitRecord<'n, N>
282where
283 N: NormalizeEq,
284{
285 /// Marks the beginning of [`TreeNode`]. It contains:
286 /// - The post-order index assigned during the first, visiting traversal.
287 EnterMark(usize),
288
289 /// Marks an accumulated subtree. It contains:
290 /// - The accumulated identifier of a subtree.
291 /// - A accumulated boolean flag if the subtree is valid for CSE.
292 /// The flag is propagated up from children to parent. (E.g. volatile expressions
293 /// are not valid and can't be extracted, but non-volatile children of volatile
294 /// expressions can be extracted.)
295 NodeItem(Identifier<'n, N>, bool),
296}
297
298impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
299where
300 N: TreeNode + HashNode + NormalizeEq,
301 C: CSEController<Node = N>,
302{
303 /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
304 /// it. Returns a tuple that contains:
305 /// - The pre-order index of the [`TreeNode`] we marked.
306 /// - The accumulated identifier of the children of the marked [`TreeNode`].
307 /// - An accumulated boolean flag from the children of the marked [`TreeNode`] if all
308 /// children are valid for CSE (i.e. it is safe to extract the [`TreeNode`] as a
309 /// common [`TreeNode`] from its children POV).
310 /// (E.g. if any of the children of the marked expression is not valid (e.g. is
311 /// volatile) then the expression is also not valid, so we can propagate this
312 /// information up from children to parents via `visit_stack` during the first,
313 /// visiting traversal and no need to test the expression's validity beforehand with
314 /// an extra traversal).
315 fn pop_enter_mark(
316 &mut self,
317 can_normalize: bool,
318 ) -> (usize, Option<Identifier<'n, N>>, bool) {
319 let mut node_ids: Vec<Identifier<'n, N>> = vec![];
320 let mut is_valid = true;
321
322 while let Some(item) = self.visit_stack.pop() {
323 match item {
324 VisitRecord::EnterMark(down_index) => {
325 if can_normalize {
326 node_ids.sort_by_key(|i| i.hash);
327 }
328 let node_id = node_ids
329 .into_iter()
330 .fold(None, |accum, item| Some(item.combine(accum)));
331 return (down_index, node_id, is_valid);
332 }
333 VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
334 node_ids.push(sub_node_id);
335 is_valid &= sub_node_is_valid;
336 }
337 }
338 }
339 unreachable!("EnterMark should paired with NodeItem");
340 }
341}
342
343impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
344where
345 N: TreeNode + HashNode + NormalizeEq,
346 C: CSEController<Node = N>,
347{
348 type Node = N;
349
350 fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
351 self.id_array.push((0, None));
352 self.visit_stack
353 .push(VisitRecord::EnterMark(self.down_index));
354 self.down_index += 1;
355
356 // If a node can short-circuit then some of its children might not be executed so
357 // count the occurrence either normal or conditional.
358 Ok(if self.conditional {
359 // If we are already in a conditionally evaluated subtree then continue
360 // traversal.
361 TreeNodeRecursion::Continue
362 } else {
363 // If we are already in a node that can short-circuit then start new
364 // traversals on its normal conditional children.
365 match C::conditional_children(node) {
366 Some((normal, conditional)) => {
367 normal
368 .into_iter()
369 .try_for_each(|n| n.visit(self).map(|_| ()))?;
370 self.conditional = true;
371 conditional
372 .into_iter()
373 .try_for_each(|n| n.visit(self).map(|_| ()))?;
374 self.conditional = false;
375
376 TreeNodeRecursion::Jump
377 }
378
379 // In case of non-short-circuit node continue the traversal.
380 _ => TreeNodeRecursion::Continue,
381 }
382 })
383 }
384
385 fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
386 let (down_index, sub_node_id, sub_node_is_valid) =
387 self.pop_enter_mark(node.can_normalize());
388
389 let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
390 let is_valid = C::is_valid(node) && sub_node_is_valid;
391
392 self.id_array[down_index].0 = self.up_index;
393 if is_valid && !self.controller.is_ignored(node) {
394 self.id_array[down_index].1 = Some(node_id);
395 self.node_stats
396 .entry(node_id)
397 .and_modify(|evaluation| {
398 if *evaluation == NodeEvaluation::SurelyOnce
399 || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
400 && !self.conditional
401 {
402 *evaluation = NodeEvaluation::Common;
403 self.found_common = true;
404 }
405 })
406 .or_insert_with(|| {
407 if self.conditional {
408 NodeEvaluation::ConditionallyAtLeastOnce
409 } else {
410 NodeEvaluation::SurelyOnce
411 }
412 });
413 }
414 self.visit_stack
415 .push(VisitRecord::NodeItem(node_id, is_valid));
416 self.up_index += 1;
417
418 Ok(TreeNodeRecursion::Continue)
419 }
420}
421
422/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
423/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
424/// replaced [`TreeNode`] tree.
425struct CSERewriter<'a, 'n, N, C>
426where
427 N: NormalizeEq,
428 C: CSEController<Node = N>,
429{
430 /// statistics of [`TreeNode`]s
431 node_stats: &'a NodeStats<'n, N>,
432
433 /// cache to speed up second traversal
434 id_array: &'a IdArray<'n, N>,
435
436 /// common [`TreeNode`]s, that are replaced during the second traversal, are collected
437 /// to this map
438 common_nodes: &'a mut CommonNodes<'n, N>,
439
440 // preorder index, starts from 0.
441 down_index: usize,
442
443 controller: &'a mut C,
444}
445
446impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
447where
448 N: TreeNode + NormalizeEq,
449 C: CSEController<Node = N>,
450{
451 type Node = N;
452
453 fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
454 self.controller.rewrite_f_down(&node);
455
456 let (up_index, node_id) = self.id_array[self.down_index];
457 self.down_index += 1;
458
459 // Handle nodes with identifiers only
460 if let Some(node_id) = node_id {
461 let evaluation = self.node_stats.get(&node_id).unwrap();
462 if *evaluation == NodeEvaluation::Common {
463 // step index to skip all sub-node (which has smaller series number).
464 while self.down_index < self.id_array.len()
465 && self.id_array[self.down_index].0 < up_index
466 {
467 self.down_index += 1;
468 }
469
470 // We *must* replace all original nodes with same `node_id`, not just the first
471 // node which is inserted into the common_nodes. This is because nodes with the same
472 // `node_id` are semantically equivalent, but not exactly the same.
473 //
474 // For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
475 // In this case, we should replace the common expression `1 + a` with a new variable
476 // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
477 // `__common_cse_1`.
478 //
479 // The final result would be:
480 // - `__common_cse_1 as a + 1`
481 // - `__common_cse_1 as 1 + a`
482 //
483 // This way, we can efficiently handle semantically equivalent expressions without
484 // incorrectly treating them as identical.
485 let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
486 {
487 self.controller.rewrite(&node, alias)
488 } else {
489 let node_alias = self.controller.generate_alias();
490 let rewritten = self.controller.rewrite(&node, &node_alias);
491 self.common_nodes.insert(node_id, (node, node_alias));
492 rewritten
493 };
494
495 return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
496 }
497 }
498
499 Ok(Transformed::no(node))
500 }
501
502 fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
503 self.controller.rewrite_f_up(&node);
504
505 Ok(Transformed::no(node))
506 }
507}
508
509/// The main entry point of Common Subexpression Elimination.
510///
511/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular
512/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the
513/// [`CSE::extract_common_nodes()`] method.
514pub struct CSE<N, C: CSEController<Node = N>> {
515 random_state: RandomState,
516 phantom_data: PhantomData<N>,
517 controller: C,
518}
519
520impl<N, C> CSE<N, C>
521where
522 N: TreeNode + HashNode + Clone + NormalizeEq,
523 C: CSEController<Node = N>,
524{
525 pub fn new(controller: C) -> Self {
526 Self {
527 random_state: RandomState::new(),
528 phantom_data: PhantomData,
529 controller,
530 }
531 }
532
533 /// Add an identifier to `id_array` for every [`TreeNode`] in this tree.
534 fn node_to_id_array<'n>(
535 &self,
536 node: &'n N,
537 node_stats: &mut NodeStats<'n, N>,
538 id_array: &mut IdArray<'n, N>,
539 ) -> Result<bool> {
540 let mut visitor = CSEVisitor {
541 node_stats,
542 id_array,
543 visit_stack: vec![],
544 down_index: 0,
545 up_index: 0,
546 random_state: &self.random_state,
547 found_common: false,
548 conditional: false,
549 controller: &self.controller,
550 };
551 node.visit(&mut visitor)?;
552
553 Ok(visitor.found_common)
554 }
555
556 /// Returns the identifier list for each element in `nodes` and a flag to indicate if
557 /// rewrite phase of CSE make sense.
558 ///
559 /// Returns and array with 1 element for each input node in `nodes`
560 ///
561 /// Each element is itself the result of [`CSE::node_to_id_array`] for that node
562 /// (e.g. the identifiers for each node in the tree)
563 fn to_arrays<'n>(
564 &self,
565 nodes: &'n [N],
566 node_stats: &mut NodeStats<'n, N>,
567 ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
568 let mut found_common = false;
569 nodes
570 .iter()
571 .map(|n| {
572 let mut id_array = vec![];
573 self.node_to_id_array(n, node_stats, &mut id_array)
574 .map(|fc| {
575 found_common |= fc;
576
577 id_array
578 })
579 })
580 .collect::<Result<Vec<_>>>()
581 .map(|id_arrays| (found_common, id_arrays))
582 }
583
584 /// Replace common subtrees in `node` with the corresponding temporary
585 /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]
586 fn replace_common_node<'n>(
587 &mut self,
588 node: N,
589 id_array: &IdArray<'n, N>,
590 node_stats: &NodeStats<'n, N>,
591 common_nodes: &mut CommonNodes<'n, N>,
592 ) -> Result<N> {
593 if id_array.is_empty() {
594 Ok(Transformed::no(node))
595 } else {
596 node.rewrite(&mut CSERewriter {
597 node_stats,
598 id_array,
599 common_nodes,
600 down_index: 0,
601 controller: &mut self.controller,
602 })
603 }
604 .data()
605 }
606
607 /// Replace common subtrees in `nodes_list` with the corresponding temporary
608 /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
609 fn rewrite_nodes_list<'n>(
610 &mut self,
611 nodes_list: Vec<Vec<N>>,
612 arrays_list: &[Vec<IdArray<'n, N>>],
613 node_stats: &NodeStats<'n, N>,
614 common_nodes: &mut CommonNodes<'n, N>,
615 ) -> Result<Vec<Vec<N>>> {
616 nodes_list
617 .into_iter()
618 .zip(arrays_list.iter())
619 .map(|(nodes, arrays)| {
620 nodes
621 .into_iter()
622 .zip(arrays.iter())
623 .map(|(node, id_array)| {
624 self.replace_common_node(node, id_array, node_stats, common_nodes)
625 })
626 .collect::<Result<Vec<_>>>()
627 })
628 .collect::<Result<Vec<_>>>()
629 }
630
631 /// Extracts common [`TreeNode`]s and rewrites `nodes_list`.
632 ///
633 /// Returns [`FoundCommonNodes`] recording the result of the extraction.
634 pub fn extract_common_nodes(
635 &mut self,
636 nodes_list: Vec<Vec<N>>,
637 ) -> Result<FoundCommonNodes<N>> {
638 let mut found_common = false;
639 let mut node_stats = NodeStats::new();
640
641 let id_arrays_list = nodes_list
642 .iter()
643 .map(|nodes| {
644 self.to_arrays(nodes, &mut node_stats)
645 .map(|(fc, id_arrays)| {
646 found_common |= fc;
647
648 id_arrays
649 })
650 })
651 .collect::<Result<Vec<_>>>()?;
652 if found_common {
653 let mut common_nodes = CommonNodes::new();
654 let new_nodes_list = self.rewrite_nodes_list(
655 // Must clone the list of nodes as Identifiers use references to original
656 // nodes so we have to keep them intact.
657 nodes_list.clone(),
658 &id_arrays_list,
659 &node_stats,
660 &mut common_nodes,
661 )?;
662 assert!(!common_nodes.is_empty());
663
664 Ok(FoundCommonNodes::Yes {
665 common_nodes: common_nodes.into_values().collect(),
666 new_nodes_list,
667 original_nodes_list: nodes_list,
668 })
669 } else {
670 Ok(FoundCommonNodes::No {
671 original_nodes_list: nodes_list,
672 })
673 }
674 }
675}
676
#[cfg(test)]
mod test {
    use crate::alias::AliasGenerator;
    use crate::cse::{
        CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
        Normalizeable, CSE,
    };
    use crate::tree_node::tests::TestTreeNode;
    use crate::Result;
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};

    const CSE_PREFIX: &str = "__common_node";

    /// Selects which non-leaf nodes [`TestTreeNodeCSEController`] ignores.
    #[derive(Clone, Copy)]
    pub enum TestTreeNodeMask {
        /// Ignore leaves as well as the `avg` and `sum` aggregate nodes.
        Normal,
        /// Ignore leaves only, so aggregate nodes can be extracted too.
        NormalAndAggregates,
    }

    /// A [`CSEController`] for `TestTreeNode<String>` trees. No node has conditionally
    /// evaluated children, and every node is valid.
    pub struct TestTreeNodeCSEController<'a> {
        alias_generator: &'a AliasGenerator,
        mask: TestTreeNodeMask,
    }

    impl<'a> TestTreeNodeCSEController<'a> {
        fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self {
            Self {
                alias_generator,
                mask,
            }
        }
    }

    impl CSEController for TestTreeNodeCSEController<'_> {
        type Node = TestTreeNode<String>;

        fn conditional_children(
            _: &Self::Node,
        ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> {
            None
        }

        fn is_valid(_node: &Self::Node) -> bool {
            true
        }

        fn is_ignored(&self, node: &Self::Node) -> bool {
            // Leaves are ignored regardless of the mask.
            if node.is_leaf() {
                return true;
            }
            let is_aggregate = matches!(node.data.as_str(), "avg" | "sum");
            match self.mask {
                TestTreeNodeMask::Normal => is_aggregate,
                TestTreeNodeMask::NormalAndAggregates => false,
            }
        }

        fn generate_alias(&self) -> String {
            self.alias_generator.next(CSE_PREFIX)
        }

        fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
            TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
        }
    }

    impl HashNode for TestTreeNode<String> {
        fn hash_node<H: Hasher>(&self, state: &mut H) {
            self.data.hash(state);
        }
    }

    impl Normalizeable for TestTreeNode<String> {
        fn can_normalize(&self) -> bool {
            false
        }
    }

    impl NormalizeEq for TestTreeNode<String> {
        fn normalize_eq(&self, other: &Self) -> bool {
            self == other
        }
    }

    /// Builds a leaf node.
    fn leaf(data: &str) -> TestTreeNode<String> {
        TestTreeNode::new_leaf(data.to_string())
    }

    /// Builds a node with the given children.
    fn inner(children: Vec<TestTreeNode<String>>, data: &str) -> TestTreeNode<String> {
        TestTreeNode::new(children, data.to_string())
    }

    /// An expected identifier for `node`, with the hash zeroed like `take_hashes` does.
    fn some_id(
        node: &TestTreeNode<String>,
    ) -> Option<Identifier<'_, TestTreeNode<String>>> {
        Some(Identifier { hash: 0, node })
    }

    /// Collects the distinct identifier hashes of `id_array` and zeroes them in place,
    /// so the array can be compared with expectations independent of the random seed.
    fn take_hashes(id_array: &mut IdArray<'_, TestTreeNode<String>>) -> HashSet<u64> {
        let mut hashes = HashSet::new();
        for (_, id) in id_array.iter_mut() {
            if let Some(id) = id {
                hashes.insert(id.hash);
                id.hash = 0;
            }
        }
        hashes
    }

    #[test]
    fn id_array_visitor() -> Result<()> {
        let alias_generator = AliasGenerator::new();

        // (sum(a + 1) - avg(c)) * 2
        let root = inner(
            vec![
                inner(
                    vec![
                        inner(vec![inner(vec![leaf("a"), leaf("1")], "+")], "sum"),
                        inner(vec![leaf("c")], "avg"),
                    ],
                    "-",
                ),
                leaf("2"),
            ],
            "*",
        );

        let [minus, _] = root.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [sum, avg] = minus.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [plus] = sum.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };

        // Aggregates are ignored.
        let eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::Normal,
        ));
        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
        assert_eq!(take_hashes(&mut id_array).len(), 3);

        let expected = vec![
            (8, some_id(&root)),
            (6, some_id(minus)),
            (3, None),
            (2, some_id(plus)),
            (0, None),
            (1, None),
            (5, None),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);

        // Aggregates are included.
        let eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::NormalAndAggregates,
        ));
        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
        assert_eq!(take_hashes(&mut id_array).len(), 5);

        let expected = vec![
            (8, some_id(&root)),
            (6, some_id(minus)),
            (3, some_id(sum)),
            (2, some_id(plus)),
            (0, None),
            (1, None),
            (5, some_id(avg)),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);

        Ok(())
    }
}