polars_plan/plans/python/
pyarrow.rs

1use std::fmt::Write;
2
3use polars_core::datatypes::AnyValue;
4use polars_core::prelude::{TimeUnit, TimeZone};
5
6use crate::prelude::*;
7
/// Options that steer how a predicate is rendered as a pyarrow expression.
#[derive(Debug, Default, Copy, Clone)]
pub struct PyarrowArgs {
    /// Whether a literal `Series` may be rendered as a Python list.
    ///
    /// pyarrow doesn't allow `filter([True, False])`
    /// but does allow `filter(field("a").isin([True, False]))`,
    /// so this is only switched on while rendering the values of an `isin`.
    allow_literal_series: bool,
}
14
15fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String {
16    // note: `to_py_datetime` and the `Datetime`
17    // dtype have to be in-scope on the python side
18    match tz {
19        None => format!("to_py_datetime({},'{}')", v, tu.to_ascii()),
20        Some(tz) => format!("to_py_datetime({},'{}',{})", v, tu.to_ascii(), tz),
21    }
22}
23
24// convert to a pyarrow expression that can be evaluated with pythons eval
25pub fn predicate_to_pa(
26    predicate: Node,
27    expr_arena: &Arena<AExpr>,
28    args: PyarrowArgs,
29) -> Option<String> {
30    match expr_arena.get(predicate) {
31        AExpr::BinaryExpr { left, right, op } => {
32            if op.is_comparison() {
33                let left = predicate_to_pa(*left, expr_arena, args)?;
34                let right = predicate_to_pa(*right, expr_arena, args)?;
35                Some(format!("({left} {op} {right})"))
36            } else {
37                None
38            }
39        },
40        AExpr::Column(name) => Some(format!("pa.compute.field('{}')", name)),
41        AExpr::Literal(LiteralValue::Series(s)) => {
42            if !args.allow_literal_series || s.is_empty() || s.len() > 100 {
43                None
44            } else {
45                let mut list_repr = String::with_capacity(s.len() * 5);
46                list_repr.push('[');
47                for av in s.rechunk().iter() {
48                    if let AnyValue::Boolean(v) = av {
49                        let s = if v { "True" } else { "False" };
50                        write!(list_repr, "{},", s).unwrap();
51                    } else if let AnyValue::Datetime(v, tu, tz) = av {
52                        let dtm = to_py_datetime(v, &tu, tz);
53                        write!(list_repr, "{dtm},").unwrap();
54                    } else if let AnyValue::Date(v) = av {
55                        write!(list_repr, "to_py_date({v}),").unwrap();
56                    } else {
57                        write!(list_repr, "{av},").unwrap();
58                    }
59                }
60                // pop last comma
61                list_repr.pop();
62                list_repr.push(']');
63                Some(list_repr)
64            }
65        },
66        AExpr::Literal(lv) => {
67            let av = lv.to_any_value()?;
68            let dtype = av.dtype();
69            match av.as_borrowed() {
70                AnyValue::String(s) => Some(format!("'{s}'")),
71                AnyValue::Boolean(val) => {
72                    // python bools are capitalized
73                    if val {
74                        Some("pa.compute.scalar(True)".to_string())
75                    } else {
76                        Some("pa.compute.scalar(False)".to_string())
77                    }
78                },
79                #[cfg(feature = "dtype-date")]
80                AnyValue::Date(v) => {
81                    // the function `to_py_date` and the `Date`
82                    // dtype have to be in scope on the python side
83                    Some(format!("to_py_date({v})"))
84                },
85                #[cfg(feature = "dtype-datetime")]
86                AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz)),
87                // Activate once pyarrow supports them
88                // #[cfg(feature = "dtype-time")]
89                // AnyValue::Time(v) => {
90                //     // the function `to_py_time` has to be in scope
91                //     // on the python side
92                //     Some(format!("to_py_time(value={v})"))
93                // }
94                // #[cfg(feature = "dtype-duration")]
95                // AnyValue::Duration(v, tu) => {
96                //     // the function `to_py_timedelta` has to be in scope
97                //     // on the python side
98                //     Some(format!(
99                //         "to_py_timedelta(value={}, tu='{}')",
100                //         v,
101                //         tu.to_ascii()
102                //     ))
103                // }
104                av => {
105                    if dtype.is_float() {
106                        let val = av.extract::<f64>()?;
107                        Some(format!("{val}"))
108                    } else if dtype.is_integer() {
109                        let val = av.extract::<i64>()?;
110                        Some(format!("{val}"))
111                    } else {
112                        None
113                    }
114                },
115            }
116        },
117        #[cfg(feature = "is_in")]
118        AExpr::Function {
119            function: FunctionExpr::Boolean(BooleanFunction::IsIn),
120            input,
121            ..
122        } => {
123            let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;
124            let mut args = args;
125            args.allow_literal_series = true;
126            let values = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;
127
128            Some(format!("({col}).isin({values})"))
129        },
130        #[cfg(feature = "is_between")]
131        AExpr::Function {
132            function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }),
133            input,
134            ..
135        } => {
136            if !matches!(expr_arena.get(input.first()?.node()), AExpr::Column(_)) {
137                None
138            } else {
139                let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;
140                let left_cmp_op = match closed {
141                    ClosedInterval::None | ClosedInterval::Right => Operator::Gt,
142                    ClosedInterval::Both | ClosedInterval::Left => Operator::GtEq,
143                };
144                let right_cmp_op = match closed {
145                    ClosedInterval::None | ClosedInterval::Left => Operator::Lt,
146                    ClosedInterval::Both | ClosedInterval::Right => Operator::LtEq,
147                };
148
149                let lower = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;
150                let upper = predicate_to_pa(input.get(2)?.node(), expr_arena, args)?;
151
152                Some(format!(
153                    "(({col} {left_cmp_op} {lower}) & ({col} {right_cmp_op} {upper}))"
154                ))
155            }
156        },
157        AExpr::Function {
158            function, input, ..
159        } => {
160            let input = input.first().unwrap().node();
161            let input = predicate_to_pa(input, expr_arena, args)?;
162
163            match function {
164                FunctionExpr::Boolean(BooleanFunction::Not) => Some(format!("~({input})")),
165                FunctionExpr::Boolean(BooleanFunction::IsNull) => {
166                    Some(format!("({input}).is_null()"))
167                },
168                FunctionExpr::Boolean(BooleanFunction::IsNotNull) => {
169                    Some(format!("~({input}).is_null()"))
170                },
171                _ => None,
172            }
173        },
174        _ => None,
175    }
176}