1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use polars_utils::unitvec;

use super::*;
use crate::prelude::*;

/// A copyable handle to a logical plan (`IR`) that lives inside an `Arena<IR>`.
///
/// The handle stores only the arena index; the plan itself is always read from
/// or written to the arena that owns it.
#[derive(Debug, Clone, Copy)]
pub struct IRNode {
    /// Index of the referenced plan in the `IR` arena.
    node: Node,
}

impl IRNode {
    pub fn new(node: Node) -> Self {
        Self { node }
    }

    pub fn node(&self) -> Node {
        self.node
    }

    pub fn replace_node(&mut self, node: Node) {
        self.node = node;
    }

    /// Replace the current `Node` with a new `IR`.
    pub fn replace(&mut self, ae: IR, arena: &mut Arena<IR>) {
        let node = self.node;
        arena.replace(node, ae);
    }

    pub fn to_alp<'a>(&self, arena: &'a Arena<IR>) -> &'a IR {
        arena.get(self.node)
    }

    pub fn to_alp_mut<'a>(&mut self, arena: &'a mut Arena<IR>) -> &'a mut IR {
        arena.get_mut(self.node)
    }

    pub fn assign(&mut self, ir_node: IR, arena: &mut Arena<IR>) {
        let node = arena.add(ir_node);
        self.node = node;
    }
}

/// The pair of arenas an `IRNode` tree walk operates on:
/// `.0` holds the logical plans (`IR`), `.1` holds the expressions (`AExpr`).
pub type IRNodeArena = (Arena<IR>, Arena<AExpr>);

impl TreeWalker for IRNode {
    type Arena = IRNodeArena;

    /// Visit every direct input of this plan node with `op`.
    ///
    /// `Continue` and `Skip` returned by `op` both move on to the next input;
    /// `Stop` ends the visit immediately and is returned to the caller. After all
    /// inputs have been visited, `Continue` is returned.
    ///
    /// # Errors
    /// Propagates the first error returned by `op`.
    fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
        &self,
        op: &mut F,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        let mut scratch = unitvec![];

        self.to_alp(&arena.0).copy_inputs(&mut scratch);
        for &node in scratch.as_slice() {
            let lp_node = IRNode::new(node);
            match op(&lp_node, arena)? {
                // let the recursion continue
                VisitRecursion::Continue | VisitRecursion::Skip => {},
                // early stop
                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
            }
        }
        Ok(VisitRecursion::Continue)
    }

    /// Rewrite every direct input of this plan node with `op`, in place.
    ///
    /// The plan at `self.node` is moved out of the arena while its inputs are
    /// rewritten. It is then rebuilt with the new input nodes and its original
    /// expressions, and written back to the same slot, so `self` is returned unchanged.
    ///
    /// # Errors
    /// Propagates the first error returned by `op`. Before returning it, the original
    /// plan is put back at `self.node`, so the arena is not left holding the
    /// placeholder that `take` leaves behind. Any changes `op` already made to
    /// earlier inputs before failing are not rolled back.
    fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
        self,
        op: &mut F,
        arena: &mut Self::Arena,
    ) -> PolarsResult<Self> {
        let mut inputs = vec![];
        let mut exprs = vec![];

        // Move the plan out (instead of cloning) so it can be consumed by
        // `with_exprs_and_input` below.
        let lp = arena.0.take(self.node);
        lp.copy_inputs(&mut inputs);
        lp.copy_exprs(&mut exprs);

        // rewrite the nodes
        for node in &mut inputs {
            match op(IRNode::new(*node), arena) {
                Ok(rewritten) => *node = rewritten.node,
                Err(err) => {
                    // Put the taken plan back; otherwise the slot would keep the
                    // placeholder and any later reader of the arena would see it.
                    arena.0.replace(self.node, lp);
                    return Err(err);
                },
            }
        }

        let lp = lp.with_exprs_and_input(exprs, inputs);
        arena.0.replace(self.node, lp);
        Ok(self)
    }
}

#[cfg(feature = "cse")]
/// Infallible counterpart of `try_with_ir_arena`.
///
/// Runs `func` on the combined `(lp_arena, expr_arena)` pair and returns its result.
/// Both arenas are handed back to the caller once `func` returns.
pub(crate) fn with_ir_arena<F: FnOnce(&mut IRNodeArena) -> T, T>(
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    func: F,
) -> T {
    let result: PolarsResult<T> =
        try_with_ir_arena(lp_arena, expr_arena, move |arenas| Ok(func(arenas)));
    match result {
        Ok(value) => value,
        // The wrapped closure always returns `Ok`, so this arm cannot be reached.
        Err(_) => unreachable!("infallible closure produced an error"),
    }
}

#[cfg(feature = "cse")]
/// Run `func` on the plan and expression arenas bundled as one `IRNodeArena`.
///
/// Both arenas are moved out of the caller's references (leaving `Default` values
/// behind) so `func` can work on them as a single owned pair. Afterwards they are
/// moved back, including any changes `func` made. This happens whether `func`
/// returned `Ok` or `Err`; its result is then returned as-is.
pub(crate) fn try_with_ir_arena<F: FnOnce(&mut IRNodeArena) -> PolarsResult<T>, T>(
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    func: F,
) -> PolarsResult<T> {
    let mut arenas: IRNodeArena = (std::mem::take(lp_arena), std::mem::take(expr_arena));

    let result = func(&mut arenas);

    // Return the (possibly modified) arenas to the caller before reporting the outcome.
    let (lp, expr) = arenas;
    *lp_arena = lp;
    *expr_arena = expr;
    result
}