// polars_plan/plans/visitor/lp.rs

1use polars_utils::unitvec;
2
3use super::*;
4use crate::prelude::*;
5
6#[derive(Copy, Clone, Debug)]
7pub struct IRNode {
8    node: Node,
9}
10
11impl IRNode {
12    pub fn new(node: Node) -> Self {
13        Self { node }
14    }
15
16    pub fn node(&self) -> Node {
17        self.node
18    }
19
20    pub fn replace_node(&mut self, node: Node) {
21        self.node = node;
22    }
23
24    /// Replace the current `Node` with a new `IR`.
25    pub fn replace(&mut self, ae: IR, arena: &mut Arena<IR>) {
26        let node = self.node;
27        arena.replace(node, ae);
28    }
29
30    pub fn to_alp<'a>(&self, arena: &'a Arena<IR>) -> &'a IR {
31        arena.get(self.node)
32    }
33
34    pub fn to_alp_mut<'a>(&mut self, arena: &'a mut Arena<IR>) -> &'a mut IR {
35        arena.get_mut(self.node)
36    }
37
38    pub fn assign(&mut self, ir_node: IR, arena: &mut Arena<IR>) {
39        let node = arena.add(ir_node);
40        self.node = node;
41    }
42}
43
44pub type IRNodeArena = (Arena<IR>, Arena<AExpr>);
45
46impl TreeWalker for IRNode {
47    type Arena = IRNodeArena;
48
49    fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
50        &self,
51        op: &mut F,
52        arena: &Self::Arena,
53    ) -> PolarsResult<VisitRecursion> {
54        let mut scratch = unitvec![];
55
56        self.to_alp(&arena.0).copy_inputs(&mut scratch);
57        for &node in scratch.as_slice() {
58            let lp_node = IRNode::new(node);
59            match op(&lp_node, arena)? {
60                // let the recursion continue
61                VisitRecursion::Continue | VisitRecursion::Skip => {},
62                // early stop
63                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
64            }
65        }
66        Ok(VisitRecursion::Continue)
67    }
68
69    fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
70        self,
71        op: &mut F,
72        arena: &mut Self::Arena,
73    ) -> PolarsResult<Self> {
74        let mut inputs = vec![];
75        let mut exprs = vec![];
76
77        let lp = arena.0.take(self.node);
78        lp.copy_inputs(&mut inputs);
79        lp.copy_exprs(&mut exprs);
80
81        // rewrite the nodes
82        for node in &mut inputs {
83            let lp_node = IRNode::new(*node);
84            *node = op(lp_node, arena)?.node;
85        }
86
87        let lp = lp.with_exprs_and_input(exprs, inputs);
88        arena.0.replace(self.node, lp);
89        Ok(self)
90    }
91}
92
/// Runs `func` with temporary ownership of both arenas, bundled as an [`IRNodeArena`].
///
/// This is the infallible counterpart of `try_with_ir_arena`, which does the actual work
/// of moving the arenas out and back in. `func` is wrapped so that it always yields `Ok`.
#[cfg(feature = "cse")]
pub(crate) fn with_ir_arena<F: FnOnce(&mut IRNodeArena) -> T, T>(
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    func: F,
) -> T {
    // `func` cannot fail, so lift it into a `PolarsResult` that is always `Ok`.
    let infallible = |arenas: &mut IRNodeArena| -> PolarsResult<T> { Ok(func(arenas)) };
    let result = try_with_ir_arena(lp_arena, expr_arena, infallible);
    result.unwrap()
}
101
/// Runs the fallible `func` with temporary ownership of both arenas.
///
/// Steps:
/// 1. Move `lp_arena` and `expr_arena` out, leaving `Default::default()` in their place.
/// 2. Bundle them into an [`IRNodeArena`] and pass it to `func`.
/// 3. After `func` returns (with `Ok` or `Err`), move both arenas back.
///
/// `func`'s result is returned unchanged.
///
/// NOTE(review): if `func` panics, the arenas are not moved back and the caller keeps the
/// default values.
#[cfg(feature = "cse")]
pub(crate) fn try_with_ir_arena<F: FnOnce(&mut IRNodeArena) -> PolarsResult<T>, T>(
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    func: F,
) -> PolarsResult<T> {
    let mut bundled: IRNodeArena = (std::mem::take(lp_arena), std::mem::take(expr_arena));
    let result = func(&mut bundled);

    // Hand ownership back to the caller's arenas regardless of `result`.
    let (lp, expr) = bundled;
    *lp_arena = lp;
    *expr_arena = expr;

    result
}