polars_plan/dsl/
expr_dyn_fn.rs

1use std::fmt::Formatter;
2use std::ops::Deref;
3use std::sync::Arc;
4
5#[cfg(feature = "serde")]
6use polars_utils::pl_serialize::deserialize_map_bytes;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9
10use super::*;
11
12/// A wrapper trait for any closure `Fn(Vec<Series>) -> PolarsResult<Series>`
13pub trait ColumnsUdf: Send + Sync {
14    fn as_any(&self) -> &dyn std::any::Any {
15        unimplemented!("as_any not implemented for this 'opaque' function")
16    }
17
18    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>>;
19
20    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
21        polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
22    }
23}
24
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Serializes the wrapped UDF as a byte string.
    ///
    /// The bytes come from [`ColumnsUdf::try_serialize`]; if that fails, its
    /// message is forwarded as a custom serializer error.
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        let mut bytes = Vec::new();
        match self.0.try_serialize(&mut bytes) {
            Ok(()) => serializer.serialize_bytes(&bytes),
            Err(err) => Err(<S::Error as serde::ser::Error>::custom(err.to_string())),
        }
    }
}
39
#[cfg(feature = "serde")]
impl<T: Serialize + Clone> Serialize for LazySerde<T> {
    /// Serializes whichever representation is currently held: the raw bytes
    /// or the already-deserialized value.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match self {
            LazySerde::Bytes(raw) => raw.serialize(serializer),
            LazySerde::Deserialized(value) => value.serialize(serializer),
        }
    }
}
52
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
    /// Reads the payload as raw bytes and stores it in the `Bytes` variant;
    /// the bytes are not decoded into `T` here.
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> Result<Self, D::Error> {
        bytes::Bytes::deserialize(deserializer).map(Self::Bytes)
    }
}
63
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Deserializes a UDF from its byte representation.
    ///
    /// With the `python` feature enabled, bytes starting with
    /// `python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK` are handed to
    /// `PythonUdfExpression::try_deserialize`. Every other input, and every
    /// input when `python` is disabled, is rejected with a custom error.
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> std::result::Result<Self, D::Error> {
        use serde::de::Error;
        #[cfg(feature = "python")]
        {
            deserialize_map_bytes(deserializer, &mut |buf| {
                // Anything without the Python marker is an opaque function we cannot rebuild.
                if !buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    return Err(D::Error::custom(
                        "deserialization not supported for this 'opaque' function",
                    ));
                }
                let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
                    .map_err(|err| D::Error::custom(err.to_string()))?;
                Ok(SpecialEq::new(udf))
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            // Mark the argument as used; nothing can be decoded without `python`.
            let _ = deserializer;

            Err(D::Error::custom(
                "deserialization not supported for this 'opaque' function",
            ))
        }
    }
}
96
97impl<F> ColumnsUdf for F
98where
99    F: Fn(&mut [Column]) -> PolarsResult<Option<Column>> + Send + Sync,
100{
101    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
102        self(s)
103    }
104}
105
106impl Debug for dyn ColumnsUdf {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        write!(f, "ColumnUdf")
109    }
110}
111
/// A wrapper trait for any binary closure `Fn(Column, Column) -> PolarsResult<Column>`
///
/// It is object-safe, so implementations can be stored as `dyn ColumnBinaryUdf`.
pub trait ColumnBinaryUdf: Send + Sync {
    /// Applies the binary function to the owned columns `a` and `b`.
    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column>;
}
116
117impl<F> ColumnBinaryUdf for F
118where
119    F: Fn(Column, Column) -> PolarsResult<Column> + Send + Sync,
120{
121    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column> {
122        self(a, b)
123    }
124}
125
126impl Debug for dyn ColumnBinaryUdf {
127    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128        write!(f, "ColumnBinaryUdf")
129    }
130}
131
132impl Default for SpecialEq<Arc<dyn ColumnBinaryUdf>> {
133    fn default() -> Self {
134        panic!("implementation error");
135    }
136}
137
138impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
139    fn default() -> Self {
140        let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None;
141        SpecialEq::new(Arc::new(output_field))
142    }
143}
144
/// Object-safe wrapper for a function that derives a new name from an existing one.
///
/// Any `Send + Sync` closure `Fn(&PlSmallStr) -> PolarsResult<PlSmallStr>`
/// implements this trait through the blanket impl in this file.
pub trait RenameAliasFn: Send + Sync {
    /// Computes the replacement for `name`.
    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;

    /// Writes a serialized representation of this function into `_buf`.
    ///
    /// The default implementation always fails with a `ComputeError`.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this renaming function")
    }
}
152
153impl<F> RenameAliasFn for F
154where
155    F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
156{
157    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
158        self(name)
159    }
160}
161
162impl Debug for dyn RenameAliasFn {
163    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164        write!(f, "RenameAliasFn")
165    }
166}
167
/// Newtype wrapper whose equality depends on the wrapped type.
///
/// The `PartialEq` impls in this file compare `Arc` payloads by pointer
/// identity and `Series` payloads by value.
#[derive(Clone)]
pub struct SpecialEq<T>(T);
172
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Series> {
    /// Serializes transparently as the wrapped `Series`.
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        let Self(series) = self;
        series.serialize(serializer)
    }
}
182
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Series> {
    /// Deserializes a `Series` and wraps it.
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> std::result::Result<Self, D::Error> {
        Series::deserialize(deserializer).map(SpecialEq)
    }
}
193
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<DslPlan>> {
    /// Serializes transparently as the wrapped plan.
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        // Kept as a method call so resolution through the `Arc` is unchanged.
        self.0.serialize(serializer)
    }
}
203
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<DslPlan>> {
    /// Deserializes a `DslPlan`, puts it behind a fresh `Arc`, and wraps it.
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> std::result::Result<Self, D::Error> {
        DslPlan::deserialize(deserializer).map(|plan| SpecialEq(Arc::new(plan)))
    }
}
214
215impl<T> SpecialEq<T> {
216    pub fn new(val: T) -> Self {
217        SpecialEq(val)
218    }
219}
220
221impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
222    fn eq(&self, other: &Self) -> bool {
223        Arc::ptr_eq(&self.0, &other.0)
224    }
225}
226
227impl PartialEq for SpecialEq<Series> {
228    fn eq(&self, other: &Self) -> bool {
229        self.0 == other.0
230    }
231}
232
233impl<T> Debug for SpecialEq<T> {
234    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
235        write!(f, "no_eq")
236    }
237}
238
239impl<T> Deref for SpecialEq<T> {
240    type Target = T;
241
242    fn deref(&self) -> &Self::Target {
243        &self.0
244    }
245}
246
/// Object-safe wrapper for a function that resolves the output field of a
/// binary UDF.
///
/// Any `Send + Sync` closure `Fn(&Schema, Context, &Field, &Field) -> Option<Field>`
/// implements this trait through the blanket impl in this file.
pub trait BinaryUdfOutputField: Send + Sync {
    /// Returns the output field for the inputs `field_a` and `field_b`, given
    /// the input schema and context, or `None` if no field is produced.
    fn get_field(
        &self,
        input_schema: &Schema,
        cntxt: Context,
        field_a: &Field,
        field_b: &Field,
    ) -> Option<Field>;
}
256
257impl<F> BinaryUdfOutputField for F
258where
259    F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
260{
261    fn get_field(
262        &self,
263        input_schema: &Schema,
264        cntxt: Context,
265        field_a: &Field,
266        field_b: &Field,
267    ) -> Option<Field> {
268        self(input_schema, cntxt, field_a, field_b)
269    }
270}
271
/// Object-safe wrapper for a function that resolves the output field of a
/// function expression from its input fields.
///
/// Any `Send + Sync` closure `Fn(&Schema, Context, &[Field]) -> PolarsResult<Field>`
/// implements this trait through the blanket impl in this file.
pub trait FunctionOutputField: Send + Sync {
    /// Returns the output field for the input `fields`, given the input
    /// schema and context.
    fn get_field(
        &self,
        input_schema: &Schema,
        cntxt: Context,
        fields: &[Field],
    ) -> PolarsResult<Field>;

    /// Writes a serialized representation of this resolver into `_buf`.
    ///
    /// The default implementation always fails with a `ComputeError`.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this output field")
    }
}
284
285pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;
286
287impl Default for GetOutput {
288    fn default() -> Self {
289        SpecialEq::new(Arc::new(
290            |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
291        ))
292    }
293}
294
295impl GetOutput {
296    pub fn same_type() -> Self {
297        Default::default()
298    }
299
300    pub fn first() -> Self {
301        SpecialEq::new(Arc::new(
302            |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
303        ))
304    }
305
306    pub fn from_type(dt: DataType) -> Self {
307        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
308            Ok(Field::new(flds[0].name().clone(), dt.clone()))
309        }))
310    }
311
312    pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
313        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
314            f(&flds[0])
315        }))
316    }
317
318    pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
319        f: F,
320    ) -> Self {
321        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
322            f(flds)
323        }))
324    }
325
326    pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
327        f: F,
328    ) -> Self {
329        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
330            let mut fld = flds[0].clone();
331            let new_type = f(fld.dtype())?;
332            fld.coerce(new_type);
333            Ok(fld)
334        }))
335    }
336
337    pub fn float_type() -> Self {
338        Self::map_dtype(|dt| {
339            Ok(match dt {
340                DataType::Float32 => DataType::Float32,
341                _ => DataType::Float64,
342            })
343        })
344    }
345
346    pub fn super_type() -> Self {
347        Self::map_dtypes(|dtypes| {
348            let mut st = dtypes[0].clone();
349            for dt in &dtypes[1..] {
350                st = try_get_supertype(&st, dt)?;
351            }
352            Ok(st)
353        })
354    }
355
356    pub fn map_dtypes<F>(f: F) -> Self
357    where
358        F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
359    {
360        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
361            let mut fld = flds[0].clone();
362            let dtypes = flds.iter().map(|fld| fld.dtype()).collect::<Vec<_>>();
363            let new_type = f(&dtypes)?;
364            fld.coerce(new_type);
365            Ok(fld)
366        }))
367    }
368}
369
370impl<F> FunctionOutputField for F
371where
372    F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
373{
374    fn get_field(
375        &self,
376        input_schema: &Schema,
377        cntxt: Context,
378        fields: &[Field],
379    ) -> PolarsResult<Field> {
380        self(input_schema, cntxt, fields)
381    }
382}
383
#[cfg(feature = "serde")]
impl Serialize for GetOutput {
    /// Serializes the resolver as a byte string.
    ///
    /// The bytes come from [`FunctionOutputField::try_serialize`]; if that
    /// fails, its message is forwarded as a custom serializer error.
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::Error as _;

        let mut payload: Vec<u8> = Vec::new();
        if let Err(err) = self.0.try_serialize(&mut payload) {
            return Err(S::Error::custom(err.to_string()));
        }
        serializer.serialize_bytes(&payload)
    }
}
398
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
    /// Deserializes a resolver from its byte representation.
    ///
    /// With the `python` feature enabled, bytes starting with
    /// `python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK` are handed to
    /// `PythonGetOutput::try_deserialize`. Every other input, and every input
    /// when `python` is disabled, is rejected with a custom error.
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> std::result::Result<Self, D::Error> {
        use serde::de::Error;
        #[cfg(feature = "python")]
        {
            deserialize_map_bytes(deserializer, &mut |buf| {
                // Only payloads carrying the Python marker can be rebuilt.
                if !buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    return Err(D::Error::custom(
                        "deserialization not supported for this output field",
                    ));
                }
                let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
                    .map_err(|err| D::Error::custom(err.to_string()))?;
                Ok(SpecialEq::new(get_output))
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            // Mark the argument as used; nothing can be decoded without `python`.
            let _ = deserializer;

            Err(D::Error::custom(
                "deserialization not supported for this output field",
            ))
        }
    }
}
430
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Serializes the renaming function as a byte string.
    ///
    /// The bytes come from [`RenameAliasFn::try_serialize`]; if that fails,
    /// its message is forwarded as a custom serializer error.
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        let mut encoded = Vec::new();
        self.0
            .try_serialize(&mut encoded)
            .map_err(|err| <S::Error as serde::ser::Error>::custom(err.to_string()))
            .and_then(|()| serializer.serialize_bytes(&encoded))
    }
}
445
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Always fails: the deserializer is ignored and a custom error is returned.
    fn deserialize<D: Deserializer<'a>>(_deserializer: D) -> std::result::Result<Self, D::Error> {
        Err(<D::Error as serde::de::Error>::custom(
            "deserialization not supported for this renaming function",
        ))
    }
}
457}