// polars_plan/dsl/python_udf.rs

1use std::io::Cursor;
2use std::sync::Arc;
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::column::Column;
7use polars_core::frame::DataFrame;
8use polars_core::schema::Schema;
9use pyo3::prelude::*;
10use pyo3::pybacked::PyBackedBytes;
11use pyo3::types::PyBytes;
12
13use super::expr_dyn_fn::*;
14use crate::constants::MAP_LIST_NAME;
15use crate::prelude::*;
16
17// Will be overwritten on Python Polars start up.
18pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
19    fn(s: Column, lambda: &PyObject) -> PolarsResult<Column>,
20> = None;
21pub static mut CALL_DF_UDF_PYTHON: Option<
22    fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
23> = None;
24
25pub use polars_utils::python_function::{
26    PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK,
27};
28
/// A Python callable used as a column UDF, together with the metadata used to plan and
/// serialize it.
pub struct PythonUdfExpression {
    /// The Python callable invoked on the first input column.
    python_function: PyObject,
    /// Declared return dtype; `None` leaves the output dtype unknown and skips the output cast.
    output_type: Option<DataType>,
    /// When not applied over aggregated lists, selects `ApplyOptions::ElementWise`
    /// (true) or `ApplyOptions::GroupWise` (false).
    is_elementwise: bool,
    /// Whether the expression is flagged with `FunctionFlags::RETURNS_SCALAR`.
    returns_scalar: bool,
}
35
36impl PythonUdfExpression {
37    pub fn new(
38        lambda: PyObject,
39        output_type: Option<DataType>,
40        is_elementwise: bool,
41        returns_scalar: bool,
42    ) -> Self {
43        Self {
44            python_function: lambda,
45            output_type,
46            is_elementwise,
47            returns_scalar,
48        }
49    }
50
51    #[cfg(feature = "serde")]
52    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
53        // Handle byte mark
54
55        use polars_utils::pl_serialize;
56        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
57        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
58
59        // Handle pickle metadata
60        let use_cloudpickle = buf[0];
61        if use_cloudpickle != 0 {
62            let ser_py_version = &buf[1..3];
63            let cur_py_version = *PYTHON3_VERSION;
64            polars_ensure!(
65                ser_py_version == cur_py_version,
66                InvalidOperation:
67                "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
68                (3, cur_py_version[0], cur_py_version[1]),
69                (3, ser_py_version[0], ser_py_version[1] )
70            );
71        }
72        let buf = &buf[3..];
73
74        // Load UDF metadata
75        let mut reader = Cursor::new(buf);
76        let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
77            pl_serialize::deserialize_from_reader(&mut reader)?;
78
79        let remainder = &buf[reader.position() as usize..];
80
81        // Load UDF
82        Python::with_gil(|py| {
83            let pickle = PyModule::import(py, "pickle")
84                .expect("unable to import 'pickle'")
85                .getattr("loads")
86                .unwrap();
87            let arg = (PyBytes::new(py, remainder),);
88            let python_function = pickle.call1(arg).map_err(from_pyerr)?;
89            Ok(Arc::new(Self::new(
90                python_function.into(),
91                output_type,
92                is_elementwise,
93                returns_scalar,
94            )) as Arc<dyn ColumnsUdf>)
95        })
96    }
97}
98
99fn from_pyerr(e: PyErr) -> PolarsError {
100    PolarsError::ComputeError(format!("error raised in python: {e}").into())
101}
102
103impl DataFrameUdf for polars_utils::python_function::PythonFunction {
104    fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
105        let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
106        func(df, &self.0)
107    }
108}
109
110impl ColumnsUdf for PythonUdfExpression {
111    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
112        let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
113
114        let output_type = self
115            .output_type
116            .clone()
117            .unwrap_or_else(|| DataType::Unknown(Default::default()));
118        let mut out = func(s[0].clone(), &self.python_function)?;
119        if !matches!(output_type, DataType::Unknown(_)) {
120            let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
121                polars_err!(
122                    SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
123                    output_type, out.dtype(),
124                )
125            })?;
126            if must_cast {
127                out = out.cast(&output_type)?;
128            }
129        }
130
131        Ok(Some(out))
132    }
133
134    #[cfg(feature = "serde")]
135    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
136        // Write byte marks
137
138        use polars_utils::pl_serialize;
139        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
140
141        Python::with_gil(|py| {
142            // Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
143            let pickle = PyModule::import(py, "pickle")
144                .expect("unable to import 'pickle'")
145                .getattr("dumps")
146                .unwrap();
147            let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
148            let (dumped, use_cloudpickle) = match pickle_result {
149                Ok(dumped) => (dumped, false),
150                Err(_) => {
151                    let cloudpickle = PyModule::import(py, "cloudpickle")
152                        .map_err(from_pyerr)?
153                        .getattr("dumps")
154                        .unwrap();
155                    let dumped = cloudpickle
156                        .call1((self.python_function.clone_ref(py),))
157                        .map_err(from_pyerr)?;
158                    (dumped, true)
159                },
160            };
161
162            // Write pickle metadata
163            buf.push(use_cloudpickle as u8);
164            buf.extend_from_slice(&*PYTHON3_VERSION);
165
166            // Write UDF metadata
167            pl_serialize::serialize_into_writer(
168                &mut *buf,
169                &(
170                    self.output_type.clone(),
171                    self.is_elementwise,
172                    self.returns_scalar,
173                ),
174            )?;
175
176            // Write UDF
177            let dumped = dumped.extract::<PyBackedBytes>().unwrap();
178            buf.extend_from_slice(&dumped);
179            Ok(())
180        })
181    }
182}
183
/// Serializable version of [`GetOutput`] for Python UDFs.
///
/// Resolves a UDF's output field from the first input field's name and the declared
/// return dtype.
pub struct PythonGetOutput {
    /// Declared return dtype; `None` resolves to an unknown dtype.
    return_dtype: Option<DataType>,
}
188
189impl PythonGetOutput {
190    pub fn new(return_dtype: Option<DataType>) -> Self {
191        Self { return_dtype }
192    }
193
194    #[cfg(feature = "serde")]
195    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
196        // Skip header.
197
198        use polars_utils::pl_serialize;
199        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
200        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
201
202        let mut reader = Cursor::new(buf);
203        let return_dtype: Option<DataType> = pl_serialize::deserialize_from_reader(&mut reader)?;
204
205        Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
206    }
207}
208
209impl FunctionOutputField for PythonGetOutput {
210    fn get_field(
211        &self,
212        _input_schema: &Schema,
213        _cntxt: Context,
214        fields: &[Field],
215    ) -> PolarsResult<Field> {
216        // Take the name of first field, just like [`GetOutput::map_field`].
217        let name = fields[0].name();
218        let return_dtype = match self.return_dtype {
219            Some(ref dtype) => dtype.clone(),
220            None => DataType::Unknown(Default::default()),
221        };
222        Ok(Field::new(name.clone(), return_dtype))
223    }
224
225    #[cfg(feature = "serde")]
226    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
227        use polars_utils::pl_serialize;
228
229        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
230        pl_serialize::serialize_into_writer(&mut *buf, &self.return_dtype)
231    }
232}
233
234impl Expr {
235    pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
236        let (collect_groups, name) = if agg_list {
237            (ApplyOptions::ApplyList, MAP_LIST_NAME)
238        } else if func.is_elementwise {
239            (ApplyOptions::ElementWise, "python_udf")
240        } else {
241            (ApplyOptions::GroupWise, "python_udf")
242        };
243
244        let returns_scalar = func.returns_scalar;
245        let return_dtype = func.output_type.clone();
246
247        let output_field = PythonGetOutput::new(return_dtype);
248        let output_type = SpecialEq::new(Arc::new(output_field) as Arc<dyn FunctionOutputField>);
249
250        let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
251        if returns_scalar {
252            flags |= FunctionFlags::RETURNS_SCALAR;
253        }
254
255        Expr::AnonymousFunction {
256            input: vec![self],
257            function: new_column_udf(func),
258            output_type,
259            options: FunctionOptions {
260                collect_groups,
261                fmt_str: name,
262                flags,
263                ..Default::default()
264            },
265        }
266    }
267}