polars_plan/plans/functions/
dsl.rsuse strum_macros::IntoStaticStr;
use super::*;
/// A user-defined Python function applied to a whole plan node, carried
/// through the DSL opaquely and lowered unchanged to
/// `FunctionIR::OpaquePython` by `DslFunction::into_function_ir`.
///
/// Only compiled with the `python` feature.
#[cfg(feature = "python")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct OpaquePythonUdf {
    /// The Python callable to invoke.
    pub function: PythonFunction,
    // NOTE(review): presumably the UDF's output schema, `None` when unknown — confirm.
    pub schema: Option<SchemaRef>,
    // NOTE(review): presumably whether predicate pushdown may pass through this UDF — confirm.
    pub predicate_pd: bool,
    // NOTE(review): presumably whether projection pushdown may pass through this UDF — confirm.
    pub projection_pd: bool,
    // NOTE(review): presumably whether the UDF can run in the streaming engine — confirm.
    pub streamable: bool,
    // NOTE(review): presumably whether the produced output is checked against `schema` — confirm.
    pub validate_output: bool,
}
/// A plan-level function as expressed in the DSL, before lowering to [`FunctionIR`].
///
/// `Display` and `Debug` print the variant name in SCREAMING_SNAKE_CASE (derived
/// by `strum`'s `IntoStaticStr`, e.g. `RowIndex` -> `ROW_INDEX`), except for
/// [`DslFunction::FunctionIR`], which prints the wrapped IR function instead.
///
/// NOTE(review): with the `serde` feature enabled, variant order may be part of the
/// serialized format; prefer appending new variants over reordering — confirm.
#[derive(Clone, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum DslFunction {
    /// An already-lowered IR function; `into_function_ir` returns it unchanged.
    /// serde skips this variant, so it cannot be serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    FunctionIR(FunctionIR),
    /// An opaque Python UDF, lowered to `FunctionIR::OpaquePython` as-is.
    #[cfg(feature = "python")]
    OpaquePython(OpaquePythonUdf),
    /// Explode the columns matched by `columns`.
    ///
    /// Not handled by `into_function_ir` (it panics there); presumably lowered
    /// by the caller instead.
    Explode {
        columns: Vec<Selector>,
        // NOTE(review): presumably allows the selection to match no columns — confirm.
        allow_empty: bool,
    },
    /// Unpivot the input. The `on` and `index` selectors are expanded and
    /// checked against the input schema during lowering.
    #[cfg(feature = "pivot")]
    Unpivot {
        args: UnpivotArgsDSL,
    },
    /// Add a row-index column named `name`.
    RowIndex {
        name: PlSmallStr,
        // NOTE(review): presumably the starting index, defaulting when `None` — confirm.
        offset: Option<IdxSize>,
    },
    /// Rename columns `existing` to `new` (presumably pairwise by position — confirm).
    ///
    /// When `strict` is set, every name in `existing` must be present in the
    /// input schema, otherwise lowering fails with `ColumnNotFound`.
    Rename {
        existing: Arc<[PlSmallStr]>,
        new: Arc<[PlSmallStr]>,
        strict: bool,
    },
    /// Unnest the columns matched by the selectors; they are expanded and
    /// checked against the input schema during lowering.
    Unnest(Vec<Selector>),
    /// A column-wise statistic. Not handled by `into_function_ir` (it panics there).
    Stats(StatsFunction),
    /// Fill NaN values (presumably with the given expression — confirm).
    /// Not handled by `into_function_ir` (it panics there).
    FillNan(Expr),
    /// Drop columns. Not handled by `into_function_ir` (it panics there).
    Drop(DropFunction),
}
/// Arguments of [`DslFunction::Drop`].
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DropFunction {
    /// Selectors for the columns to remove.
    pub(crate) to_drop: Vec<Selector>,
    // NOTE(review): presumably errors when a selected column does not exist — confirm.
    pub(crate) strict: bool,
}
/// The statistic requested by [`DslFunction::Stats`].
// NOTE(review): presumably each variant is applied column-wise to the whole frame — confirm.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StatsFunction {
    /// Variance; `ddof` is presumably the delta degrees of freedom — confirm.
    Var {
        ddof: u8,
    },
    /// Standard deviation; `ddof` is presumably the delta degrees of freedom — confirm.
    Std {
        ddof: u8,
    },
    /// Quantile given by the `quantile` expression, computed with `method`.
    Quantile {
        quantile: Expr,
        method: QuantileMethod,
    },
    Median,
    Mean,
    Sum,
    Min,
    Max,
}
/// Checks that every name yielded by `columns` is present in `input_schema`.
///
/// `operation_name` is used only to make the error message point at the
/// operation being lowered.
///
/// # Errors
/// Returns `ColumnNotFound` for the first missing column. The message includes
/// the offending column and a debug dump of `input_schema`.
pub(crate) fn validate_columns_in_input<S: AsRef<str>, I: IntoIterator<Item = S>>(
    columns: I,
    input_schema: &Schema,
    operation_name: &str,
) -> PolarsResult<()> {
    columns.into_iter().try_for_each(|column| {
        let column_name = column.as_ref();
        polars_ensure!(
            input_schema.contains(column_name),
            ColumnNotFound: "'{}' on column: '{}' is invalid\n\nSchema at this point: {:?}",
            operation_name,
            column_name,
            input_schema
        );
        Ok(())
    })
}
impl DslFunction {
    /// Lowers this DSL function into its [`FunctionIR`] counterpart.
    ///
    /// Selectors are expanded against `input_schema`, and the resulting column
    /// names are validated against it.
    ///
    /// # Errors
    /// - Propagates errors from selector expansion (`Unpivot`, `Unnest`).
    /// - Returns `ColumnNotFound` when an expanded `unpivot`/`unnest` column, or
    ///   (with `strict`) a `rename` source column, is missing from `input_schema`.
    ///
    /// # Panics
    /// Panics on `Stats`, `FillNan`, `Drop` and `Explode`. These variants are not
    /// lowered here and must be handled by the caller before this is reached.
    pub(crate) fn into_function_ir(self, input_schema: &Schema) -> PolarsResult<FunctionIR> {
        let function = match self {
            #[cfg(feature = "pivot")]
            DslFunction::Unpivot { args } => {
                let on = expand_selectors(args.on, input_schema, &[])?;
                let index = expand_selectors(args.index, input_schema, &[])?;
                validate_columns_in_input(on.as_ref(), input_schema, "unpivot")?;
                validate_columns_in_input(index.as_ref(), input_schema, "unpivot")?;
                let args = UnpivotArgsIR {
                    on: on.iter().cloned().collect(),
                    index: index.iter().cloned().collect(),
                    // `args` is owned (its `on`/`index` were already moved out above),
                    // so the remaining fields can be moved instead of cloned.
                    variable_name: args.variable_name,
                    value_name: args.value_name,
                };
                FunctionIR::Unpivot {
                    args: Arc::new(args),
                    // Left at its default here.
                    schema: Default::default(),
                }
            },
            DslFunction::FunctionIR(func) => func,
            DslFunction::RowIndex { name, offset } => FunctionIR::RowIndex {
                name,
                offset,
                schema: Default::default(),
            },
            DslFunction::Rename {
                existing,
                new,
                strict,
            } => {
                // True when any target name already exists in the input schema.
                let swapping = new.iter().any(|name| input_schema.get(name).is_some());
                if strict {
                    validate_columns_in_input(existing.as_ref(), input_schema, "rename")?;
                }
                FunctionIR::Rename {
                    existing,
                    new,
                    swapping,
                    schema: Default::default(),
                }
            },
            DslFunction::Unnest(selectors) => {
                let columns = expand_selectors(selectors, input_schema, &[])?;
                // Report the operation the user actually called in the error message.
                validate_columns_in_input(columns.as_ref(), input_schema, "unnest")?;
                FunctionIR::Unnest { columns }
            },
            #[cfg(feature = "python")]
            DslFunction::OpaquePython(inner) => FunctionIR::OpaquePython(inner),
            other @ (DslFunction::Stats(_)
            | DslFunction::FillNan(_)
            | DslFunction::Drop(_)
            | DslFunction::Explode { .. }) => {
                unreachable!(
                    "impl error: `{other}` must be lowered before reaching `into_function_ir`"
                )
            },
        };
        Ok(function)
    }
}
impl Debug for DslFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
impl Display for DslFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use DslFunction::*;
match self {
FunctionIR(inner) => write!(f, "{inner}"),
v => {
let s: &str = v.into();
write!(f, "{s}")
},
}
}
}
impl From<FunctionIR> for DslFunction {
fn from(value: FunctionIR) -> Self {
DslFunction::FunctionIR(value)
}
}