polars_plan/dsl/
python_udf.rs1use std::io::Cursor;
2use std::sync::Arc;
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::column::Column;
7use polars_core::frame::DataFrame;
8use polars_core::schema::Schema;
9use pyo3::prelude::*;
10use pyo3::pybacked::PyBackedBytes;
11use pyo3::types::PyBytes;
12
13use super::expr_dyn_fn::*;
14use crate::constants::MAP_LIST_NAME;
15use crate::prelude::*;
16
17pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
19 fn(s: Column, lambda: &PyObject) -> PolarsResult<Column>,
20> = None;
21pub static mut CALL_DF_UDF_PYTHON: Option<
22 fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
23> = None;
24
25pub use polars_utils::python_function::{
26 PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK,
27};
28
29pub struct PythonUdfExpression {
30 python_function: PyObject,
31 output_type: Option<DataType>,
32 is_elementwise: bool,
33 returns_scalar: bool,
34}
35
36impl PythonUdfExpression {
37 pub fn new(
38 lambda: PyObject,
39 output_type: Option<DataType>,
40 is_elementwise: bool,
41 returns_scalar: bool,
42 ) -> Self {
43 Self {
44 python_function: lambda,
45 output_type,
46 is_elementwise,
47 returns_scalar,
48 }
49 }
50
51 #[cfg(feature = "serde")]
52 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
53 use polars_utils::pl_serialize;
56 debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
57 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
58
59 let use_cloudpickle = buf[0];
61 if use_cloudpickle != 0 {
62 let ser_py_version = &buf[1..3];
63 let cur_py_version = *PYTHON3_VERSION;
64 polars_ensure!(
65 ser_py_version == cur_py_version,
66 InvalidOperation:
67 "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
68 (3, cur_py_version[0], cur_py_version[1]),
69 (3, ser_py_version[0], ser_py_version[1] )
70 );
71 }
72 let buf = &buf[3..];
73
74 let mut reader = Cursor::new(buf);
76 let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
77 pl_serialize::deserialize_from_reader(&mut reader)?;
78
79 let remainder = &buf[reader.position() as usize..];
80
81 Python::with_gil(|py| {
83 let pickle = PyModule::import(py, "pickle")
84 .expect("unable to import 'pickle'")
85 .getattr("loads")
86 .unwrap();
87 let arg = (PyBytes::new(py, remainder),);
88 let python_function = pickle.call1(arg).map_err(from_pyerr)?;
89 Ok(Arc::new(Self::new(
90 python_function.into(),
91 output_type,
92 is_elementwise,
93 returns_scalar,
94 )) as Arc<dyn ColumnsUdf>)
95 })
96 }
97}
98
99fn from_pyerr(e: PyErr) -> PolarsError {
100 PolarsError::ComputeError(format!("error raised in python: {e}").into())
101}
102
103impl DataFrameUdf for polars_utils::python_function::PythonFunction {
104 fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
105 let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
106 func(df, &self.0)
107 }
108}
109
110impl ColumnsUdf for PythonUdfExpression {
111 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
112 let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
113
114 let output_type = self
115 .output_type
116 .clone()
117 .unwrap_or_else(|| DataType::Unknown(Default::default()));
118 let mut out = func(s[0].clone(), &self.python_function)?;
119 if !matches!(output_type, DataType::Unknown(_)) {
120 let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
121 polars_err!(
122 SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
123 output_type, out.dtype(),
124 )
125 })?;
126 if must_cast {
127 out = out.cast(&output_type)?;
128 }
129 }
130
131 Ok(Some(out))
132 }
133
134 #[cfg(feature = "serde")]
135 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
136 use polars_utils::pl_serialize;
139 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
140
141 Python::with_gil(|py| {
142 let pickle = PyModule::import(py, "pickle")
144 .expect("unable to import 'pickle'")
145 .getattr("dumps")
146 .unwrap();
147 let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
148 let (dumped, use_cloudpickle) = match pickle_result {
149 Ok(dumped) => (dumped, false),
150 Err(_) => {
151 let cloudpickle = PyModule::import(py, "cloudpickle")
152 .map_err(from_pyerr)?
153 .getattr("dumps")
154 .unwrap();
155 let dumped = cloudpickle
156 .call1((self.python_function.clone_ref(py),))
157 .map_err(from_pyerr)?;
158 (dumped, true)
159 },
160 };
161
162 buf.push(use_cloudpickle as u8);
164 buf.extend_from_slice(&*PYTHON3_VERSION);
165
166 pl_serialize::serialize_into_writer(
168 &mut *buf,
169 &(
170 self.output_type.clone(),
171 self.is_elementwise,
172 self.returns_scalar,
173 ),
174 )?;
175
176 let dumped = dumped.extract::<PyBackedBytes>().unwrap();
178 buf.extend_from_slice(&dumped);
179 Ok(())
180 })
181 }
182}
183
184pub struct PythonGetOutput {
186 return_dtype: Option<DataType>,
187}
188
189impl PythonGetOutput {
190 pub fn new(return_dtype: Option<DataType>) -> Self {
191 Self { return_dtype }
192 }
193
194 #[cfg(feature = "serde")]
195 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
196 use polars_utils::pl_serialize;
199 debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
200 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
201
202 let mut reader = Cursor::new(buf);
203 let return_dtype: Option<DataType> = pl_serialize::deserialize_from_reader(&mut reader)?;
204
205 Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
206 }
207}
208
209impl FunctionOutputField for PythonGetOutput {
210 fn get_field(
211 &self,
212 _input_schema: &Schema,
213 _cntxt: Context,
214 fields: &[Field],
215 ) -> PolarsResult<Field> {
216 let name = fields[0].name();
218 let return_dtype = match self.return_dtype {
219 Some(ref dtype) => dtype.clone(),
220 None => DataType::Unknown(Default::default()),
221 };
222 Ok(Field::new(name.clone(), return_dtype))
223 }
224
225 #[cfg(feature = "serde")]
226 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
227 use polars_utils::pl_serialize;
228
229 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
230 pl_serialize::serialize_into_writer(&mut *buf, &self.return_dtype)
231 }
232}
233
234impl Expr {
235 pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
236 let (collect_groups, name) = if agg_list {
237 (ApplyOptions::ApplyList, MAP_LIST_NAME)
238 } else if func.is_elementwise {
239 (ApplyOptions::ElementWise, "python_udf")
240 } else {
241 (ApplyOptions::GroupWise, "python_udf")
242 };
243
244 let returns_scalar = func.returns_scalar;
245 let return_dtype = func.output_type.clone();
246
247 let output_field = PythonGetOutput::new(return_dtype);
248 let output_type = SpecialEq::new(Arc::new(output_field) as Arc<dyn FunctionOutputField>);
249
250 let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
251 if returns_scalar {
252 flags |= FunctionFlags::RETURNS_SCALAR;
253 }
254
255 Expr::AnonymousFunction {
256 input: vec![self],
257 function: new_column_udf(func),
258 output_type,
259 options: FunctionOptions {
260 collect_groups,
261 fmt_str: name,
262 flags,
263 ..Default::default()
264 },
265 }
266 }
267}