1use std::fmt::Formatter;
2use std::ops::Deref;
3use std::sync::Arc;
4
5#[cfg(feature = "serde")]
6use polars_utils::pl_serialize::deserialize_map_bytes;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9
10use super::*;
11
12pub trait ColumnsUdf: Send + Sync {
14 fn as_any(&self) -> &dyn std::any::Any {
15 unimplemented!("as_any not implemented for this 'opaque' function")
16 }
17
18 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>>;
19
20 fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
21 polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
22 }
23}
24
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Serializes the UDF as an opaque byte blob produced by `ColumnsUdf::try_serialize`.
    ///
    /// A failure from `try_serialize` is reported as a custom serializer error carrying the
    /// displayed Polars error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut bytes = Vec::new();
        if let Err(err) = self.0.try_serialize(&mut bytes) {
            return Err(<S::Error as serde::ser::Error>::custom(format!("{err}")));
        }
        serializer.serialize_bytes(&bytes)
    }
}
39
#[cfg(feature = "serde")]
impl<T: Serialize + Clone> Serialize for LazySerde<T> {
    /// Serializes whichever representation is currently held: the raw bytes as-is, or the
    /// already-deserialized value through its own `Serialize` implementation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Self::Bytes(raw) => raw.serialize(serializer),
            Self::Deserialized(value) => value.serialize(serializer),
        }
    }
}
52
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
    /// Reads the payload as raw `bytes::Bytes` and stores it unparsed in `LazySerde::Bytes`;
    /// turning it into a `T` is deferred.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        bytes::Bytes::deserialize(deserializer).map(Self::Bytes)
    }
}
63
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Deserializes an opaque column UDF from its serialized byte representation.
    ///
    /// With the `python` feature, payloads starting with
    /// `python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK` are decoded by
    /// `python_udf::PythonUdfExpression::try_deserialize`; any other payload is rejected with
    /// a custom error. Without the `python` feature, deserialization always fails.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;
        #[cfg(feature = "python")]
        {
            // The outer `?` propagates errors from reading the bytes; the closure's own
            // `Result` becomes this function's return value.
            deserialize_map_bytes(deserializer, &mut |buf| {
                // Only payloads tagged with the Python magic marker are recognized.
                if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
                        .map_err(|e| D::Error::custom(format!("{e}")))?;
                    Ok(SpecialEq::new(udf))
                } else {
                    Err(D::Error::custom(
                        "deserialization not supported for this 'opaque' function",
                    ))
                }
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            // The deserializer is not needed in this configuration; discard it explicitly.
            _ = deserializer;

            Err(D::Error::custom(
                "deserialization not supported for this 'opaque' function",
            ))
        }
    }
}
96
97impl<F> ColumnsUdf for F
98where
99 F: Fn(&mut [Column]) -> PolarsResult<Option<Column>> + Send + Sync,
100{
101 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
102 self(s)
103 }
104}
105
106impl Debug for dyn ColumnsUdf {
107 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108 write!(f, "ColumnUdf")
109 }
110}
111
112pub trait ColumnBinaryUdf: Send + Sync {
114 fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column>;
115}
116
117impl<F> ColumnBinaryUdf for F
118where
119 F: Fn(Column, Column) -> PolarsResult<Column> + Send + Sync,
120{
121 fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column> {
122 self(a, b)
123 }
124}
125
126impl Debug for dyn ColumnBinaryUdf {
127 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128 write!(f, "ColumnBinaryUdf")
129 }
130}
131
132impl Default for SpecialEq<Arc<dyn ColumnBinaryUdf>> {
133 fn default() -> Self {
134 panic!("implementation error");
135 }
136}
137
138impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
139 fn default() -> Self {
140 let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None;
141 SpecialEq::new(Arc::new(output_field))
142 }
143}
144
145pub trait RenameAliasFn: Send + Sync {
146 fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;
147
148 fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
149 polars_bail!(ComputeError: "serialization not supported for this renaming function")
150 }
151}
152
153impl<F> RenameAliasFn for F
154where
155 F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
156{
157 fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
158 self(name)
159 }
160}
161
162impl Debug for dyn RenameAliasFn {
163 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164 write!(f, "RenameAliasFn")
165 }
166}
167
/// Newtype wrapper whose equality semantics are chosen per wrapped type.
///
/// `Arc` payloads compare by pointer identity, and `Series` payloads use the series' own
/// equality. `Debug` always prints the placeholder `no_eq`.
#[derive(Clone)]
pub struct SpecialEq<T>(T);
172
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Series> {
    /// Serializes transparently as the wrapped `Series`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Serialize::serialize(&self.0, serializer)
    }
}
182
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Series> {
    /// Deserializes a `Series` and wraps it.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        Series::deserialize(deserializer).map(SpecialEq)
    }
}
193
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<DslPlan>> {
    /// Serializes transparently as the shared `DslPlan`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let plan: &DslPlan = &self.0;
        plan.serialize(serializer)
    }
}
203
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<DslPlan>> {
    /// Deserializes a `DslPlan` and wraps it in a fresh `Arc`.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        DslPlan::deserialize(deserializer).map(|plan| SpecialEq(Arc::new(plan)))
    }
}
214
215impl<T> SpecialEq<T> {
216 pub fn new(val: T) -> Self {
217 SpecialEq(val)
218 }
219}
220
221impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
222 fn eq(&self, other: &Self) -> bool {
223 Arc::ptr_eq(&self.0, &other.0)
224 }
225}
226
227impl PartialEq for SpecialEq<Series> {
228 fn eq(&self, other: &Self) -> bool {
229 self.0 == other.0
230 }
231}
232
233impl<T> Debug for SpecialEq<T> {
234 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
235 write!(f, "no_eq")
236 }
237}
238
239impl<T> Deref for SpecialEq<T> {
240 type Target = T;
241
242 fn deref(&self) -> &Self::Target {
243 &self.0
244 }
245}
246
247pub trait BinaryUdfOutputField: Send + Sync {
248 fn get_field(
249 &self,
250 input_schema: &Schema,
251 cntxt: Context,
252 field_a: &Field,
253 field_b: &Field,
254 ) -> Option<Field>;
255}
256
257impl<F> BinaryUdfOutputField for F
258where
259 F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
260{
261 fn get_field(
262 &self,
263 input_schema: &Schema,
264 cntxt: Context,
265 field_a: &Field,
266 field_b: &Field,
267 ) -> Option<Field> {
268 self(input_schema, cntxt, field_a, field_b)
269 }
270}
271
272pub trait FunctionOutputField: Send + Sync {
273 fn get_field(
274 &self,
275 input_schema: &Schema,
276 cntxt: Context,
277 fields: &[Field],
278 ) -> PolarsResult<Field>;
279
280 fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
281 polars_bail!(ComputeError: "serialization not supported for this output field")
282 }
283}
284
285pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;
286
287impl Default for GetOutput {
288 fn default() -> Self {
289 SpecialEq::new(Arc::new(
290 |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
291 ))
292 }
293}
294
295impl GetOutput {
296 pub fn same_type() -> Self {
297 Default::default()
298 }
299
300 pub fn first() -> Self {
301 SpecialEq::new(Arc::new(
302 |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
303 ))
304 }
305
306 pub fn from_type(dt: DataType) -> Self {
307 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
308 Ok(Field::new(flds[0].name().clone(), dt.clone()))
309 }))
310 }
311
312 pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
313 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
314 f(&flds[0])
315 }))
316 }
317
318 pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
319 f: F,
320 ) -> Self {
321 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
322 f(flds)
323 }))
324 }
325
326 pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
327 f: F,
328 ) -> Self {
329 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
330 let mut fld = flds[0].clone();
331 let new_type = f(fld.dtype())?;
332 fld.coerce(new_type);
333 Ok(fld)
334 }))
335 }
336
337 pub fn float_type() -> Self {
338 Self::map_dtype(|dt| {
339 Ok(match dt {
340 DataType::Float32 => DataType::Float32,
341 _ => DataType::Float64,
342 })
343 })
344 }
345
346 pub fn super_type() -> Self {
347 Self::map_dtypes(|dtypes| {
348 let mut st = dtypes[0].clone();
349 for dt in &dtypes[1..] {
350 st = try_get_supertype(&st, dt)?;
351 }
352 Ok(st)
353 })
354 }
355
356 pub fn map_dtypes<F>(f: F) -> Self
357 where
358 F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
359 {
360 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
361 let mut fld = flds[0].clone();
362 let dtypes = flds.iter().map(|fld| fld.dtype()).collect::<Vec<_>>();
363 let new_type = f(&dtypes)?;
364 fld.coerce(new_type);
365 Ok(fld)
366 }))
367 }
368}
369
370impl<F> FunctionOutputField for F
371where
372 F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
373{
374 fn get_field(
375 &self,
376 input_schema: &Schema,
377 cntxt: Context,
378 fields: &[Field],
379 ) -> PolarsResult<Field> {
380 self(input_schema, cntxt, fields)
381 }
382}
383
#[cfg(feature = "serde")]
impl Serialize for GetOutput {
    /// Serializes the resolver as an opaque byte blob produced by
    /// `FunctionOutputField::try_serialize`.
    ///
    /// A failure from `try_serialize` is reported as a custom serializer error carrying the
    /// displayed Polars error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut bytes = Vec::new();
        if let Err(err) = self.0.try_serialize(&mut bytes) {
            return Err(<S::Error as serde::ser::Error>::custom(format!("{err}")));
        }
        serializer.serialize_bytes(&bytes)
    }
}
398
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
    /// Deserializes an output-field resolver from its serialized byte representation.
    ///
    /// With the `python` feature, payloads starting with
    /// `python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK` are decoded by
    /// `python_udf::PythonGetOutput::try_deserialize`; any other payload is rejected with a
    /// custom error. Without the `python` feature, deserialization always fails.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;
        #[cfg(feature = "python")]
        {
            // The outer `?` propagates errors from reading the bytes; the closure's own
            // `Result` becomes this function's return value.
            deserialize_map_bytes(deserializer, &mut |buf| {
                // Only payloads tagged with the Python magic marker are recognized.
                if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
                        .map_err(|e| D::Error::custom(format!("{e}")))?;
                    Ok(SpecialEq::new(get_output))
                } else {
                    Err(D::Error::custom(
                        "deserialization not supported for this output field",
                    ))
                }
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            // The deserializer is not needed in this configuration; discard it explicitly.
            _ = deserializer;

            Err(D::Error::custom(
                "deserialization not supported for this output field",
            ))
        }
    }
}
430
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Serializes the renaming function as an opaque byte blob produced by
    /// `RenameAliasFn::try_serialize`.
    ///
    /// A failure from `try_serialize` is reported as a custom serializer error carrying the
    /// displayed Polars error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut bytes = Vec::new();
        if let Err(err) = self.0.try_serialize(&mut bytes) {
            return Err(<S::Error as serde::ser::Error>::custom(format!("{err}")));
        }
        serializer.serialize_bytes(&bytes)
    }
}
445
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Renaming functions cannot be deserialized; this always returns a custom error and
    /// never reads from the deserializer.
    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        Err(<D::Error as serde::de::Error>::custom(
            "deserialization not supported for this renaming function",
        ))
    }
}
457}