use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub use super::expr_dyn_fn::*;
use crate::prelude::*;
/// Aggregation expressions.
///
/// Every variant wraps the expression that is being aggregated; that wrapped
/// expression can be borrowed uniformly through the `AsRef<Expr>` impl on this type.
#[derive(PartialEq, Clone, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AggExpr {
/// Minimum of `input`.
///
/// NOTE(review): `propagate_nans` presumably decides whether NaN values take
/// precedence over regular values in the result - confirm against the executor.
Min {
input: Arc<Expr>,
propagate_nans: bool,
},
/// Maximum of `input`.
///
/// NOTE(review): `propagate_nans` presumably decides whether NaN values take
/// precedence over regular values in the result - confirm against the executor.
Max {
input: Arc<Expr>,
propagate_nans: bool,
},
/// Median of the wrapped expression.
Median(Arc<Expr>),
/// Number of unique values of the wrapped expression.
NUnique(Arc<Expr>),
/// First value of the wrapped expression.
First(Arc<Expr>),
/// Last value of the wrapped expression.
Last(Arc<Expr>),
/// Mean of the wrapped expression.
Mean(Arc<Expr>),
/// NOTE(review): presumably collects the values of the wrapped expression into a
/// single list value - confirm.
Implode(Arc<Expr>),
/// Count over the wrapped expression.
///
/// NOTE(review): the meaning of the `bool` is not visible in this file (possibly an
/// "include nulls" switch) - confirm.
Count(Arc<Expr>, bool),
/// Quantile of `expr`. The requested quantile is itself an expression, and
/// `interpol` carries the `QuantileInterpolOptions` to use.
Quantile {
expr: Arc<Expr>,
quantile: Arc<Expr>,
interpol: QuantileInterpolOptions,
},
/// Sum of the wrapped expression.
Sum(Arc<Expr>),
/// NOTE(review): presumably yields the group indices for the wrapped expression -
/// confirm.
AggGroups(Arc<Expr>),
/// Standard deviation of the wrapped expression.
///
/// NOTE(review): the `u8` is presumably the delta degrees of freedom (ddof) - confirm.
Std(Arc<Expr>, u8),
/// Variance of the wrapped expression.
///
/// NOTE(review): the `u8` is presumably the delta degrees of freedom (ddof) - confirm.
Var(Arc<Expr>, u8),
}
impl AsRef<Expr> for AggExpr {
fn as_ref(&self) -> &Expr {
use AggExpr::*;
match self {
Min { input, .. } => input,
Max { input, .. } => input,
Median(e) => e,
NUnique(e) => e,
First(e) => e,
Last(e) => e,
Mean(e) => e,
Implode(e) => e,
Count(e, _) => e,
Quantile { expr, .. } => expr,
Sum(e) => e,
AggGroups(e) => e,
Std(e, _) => e,
Var(e, _) => e,
}
}
}
/// A node of the expression tree.
///
/// Sub-expressions are stored as `Arc<Expr>` or `Vec<Expr>`. `PartialEq` is derived,
/// while `Hash` and `Eq` are implemented by hand further down in this file.
/// The type is `#[must_use]`: building an expression and discarding it triggers a warning.
#[derive(Clone, PartialEq)]
#[must_use]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Expr {
/// An expression paired with a `ColumnName` (the alias).
Alias(Arc<Expr>, ColumnName),
/// A single column name.
Column(ColumnName),
/// A list of column names.
Columns(Arc<[ColumnName]>),
/// A list of data types.
///
/// NOTE(review): presumably selects every column whose dtype is in the list - confirm.
DtypeColumn(Vec<DataType>),
/// A list of `i64` indices.
///
/// NOTE(review): presumably column positions - confirm (including how negative
/// values are treated).
IndexColumn(Arc<[i64]>),
/// A literal value.
Literal(LiteralValue),
/// Binary operation `op` with operands `left` and `right`.
BinaryExpr {
left: Arc<Expr>,
op: Operator,
right: Arc<Expr>,
},
/// Cast of `expr` to `data_type`, configured by `options`.
Cast {
expr: Arc<Expr>,
data_type: DataType,
options: CastOptions,
},
/// Sort of `expr` configured by `options`.
Sort {
expr: Arc<Expr>,
options: SortOptions,
},
/// `expr` indexed by the `idx` expression.
///
/// NOTE(review): the exact effect of `returns_scalar` is not visible here - confirm.
Gather {
expr: Arc<Expr>,
idx: Arc<Expr>,
returns_scalar: bool,
},
/// `expr` ordered by the `by` expressions, configured by `sort_options`.
SortBy {
expr: Arc<Expr>,
by: Vec<Expr>,
sort_options: SortMultipleOptions,
},
/// An aggregation; see `AggExpr`.
Agg(AggExpr),
/// A `predicate` expression with a `truthy` and a `falsy` branch expression.
Ternary {
predicate: Arc<Expr>,
truthy: Arc<Expr>,
falsy: Arc<Expr>,
},
/// A `FunctionExpr` applied to the `input` expressions, with its `FunctionOptions`.
Function {
input: Vec<Expr>,
function: FunctionExpr,
options: FunctionOptions,
},
/// Explode of the wrapped expression.
Explode(Arc<Expr>),
/// `input` filtered by the `by` expression.
Filter {
input: Arc<Expr>,
by: Arc<Expr>,
},
/// Window expression: `function` together with the `partition_by` expressions, an
/// optional `order_by` expression (with its `SortOptions`) and a `WindowType`.
Window {
function: Arc<Expr>,
partition_by: Vec<Expr>,
order_by: Option<(Arc<Expr>, SortOptions)>,
options: WindowType,
},
/// Wildcard marker (carries no payload).
///
/// NOTE(review): presumably stands for "all columns" - confirm.
Wildcard,
/// Slice of `input`; `offset` and `length` are themselves expressions.
Slice {
input: Arc<Expr>,
offset: Arc<Expr>,
length: Arc<Expr>,
},
/// An expression paired with a list of `Excluded` entries (column names or dtypes).
Exclude(Arc<Expr>, Vec<Excluded>),
/// NOTE(review): presumably keeps the original name of the wrapped expression's
/// output - confirm.
KeepName(Arc<Expr>),
/// Length marker (carries no payload).
Len,
/// A single `i64` index.
///
/// NOTE(review): presumably the n-th column - confirm.
Nth(i64),
/// `expr` paired with a `RenameAliasFn` trait object.
///
/// This variant is skipped by serde when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", serde(skip))]
RenameAlias {
function: SpecialEq<Arc<dyn RenameAliasFn>>,
expr: Arc<Expr>,
},
/// A list of field names. Only available with the `dtype-struct` feature.
#[cfg(feature = "dtype-struct")]
Field(Arc<[ColumnName]>),
/// A `SeriesUdf` trait object applied to the `input` expressions.
AnonymousFunction {
input: Vec<Expr>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
/// Skipped by serde when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", serde(skip))]
output_type: GetOutput,
options: FunctionOptions,
},
/// A `DslPlan` paired with a list of strings.
///
/// NOTE(review): the strings are presumably output column names - confirm.
SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
/// A column selector; see `super::selector::Selector`.
Selector(super::selector::Selector),
}
#[allow(clippy::derived_hash_with_manual_eq)]
impl Hash for Expr {
    /// Hashes the variant tag followed by a selection of the variant's payload.
    ///
    /// Some payloads are deliberately hashed only partially or not at all:
    /// literals, binary operators and function kinds contribute just their enum
    /// discriminant, and the function objects of `RenameAlias` / `AnonymousFunction`
    /// as well as the plan of `SubPlan` are skipped. Hashing less than equality
    /// compares is allowed by the `Hash` contract; it only costs extra collisions.
    fn hash<H: Hasher>(&self, state: &mut H) {
        use std::mem::discriminant;

        // The variant tag always goes in first.
        discriminant(self).hash(state);

        match self {
            // No payload: the tag is the whole hash input.
            Expr::Wildcard | Expr::Len => {},

            // Single-payload variants.
            Expr::Column(column) => column.hash(state),
            Expr::Columns(columns) => columns.hash(state),
            #[cfg(feature = "dtype-struct")]
            Expr::Field(fields) => fields.hash(state),
            Expr::DtypeColumn(data_types) => data_types.hash(state),
            Expr::IndexColumn(positions) => positions.hash(state),
            Expr::Nth(position) => position.hash(state),
            Expr::Selector(selector) => selector.hash(state),
            Expr::Agg(agg) => agg.hash(state),

            // Only the kind of literal is hashed, not its value.
            Expr::Literal(literal) => discriminant(literal).hash(state),

            // The plan itself is skipped; only the names are hashed.
            Expr::SubPlan(_, column_names) => column_names.hash(state),

            // Wrappers around one child expression (the rename function is skipped).
            Expr::KeepName(inner)
            | Expr::Explode(inner)
            | Expr::RenameAlias {
                function: _,
                expr: inner,
            } => inner.hash(state),

            Expr::Alias(inner, alias) => {
                inner.hash(state);
                alias.hash(state);
            },
            Expr::Exclude(inner, excluded) => {
                inner.hash(state);
                excluded.hash(state);
            },
            Expr::Filter { input, by } => {
                input.hash(state);
                by.hash(state);
            },
            Expr::Sort { expr, options } => {
                expr.hash(state);
                options.hash(state);
            },
            Expr::BinaryExpr { left, op, right } => {
                // Operands first, then the operator kind.
                left.hash(state);
                right.hash(state);
                discriminant(op).hash(state);
            },
            Expr::Cast {
                expr,
                data_type,
                options,
            } => {
                expr.hash(state);
                data_type.hash(state);
                options.hash(state);
            },
            Expr::Gather {
                expr,
                idx,
                returns_scalar,
            } => {
                expr.hash(state);
                idx.hash(state);
                returns_scalar.hash(state);
            },
            Expr::SortBy {
                expr,
                by,
                sort_options,
            } => {
                expr.hash(state);
                by.hash(state);
                sort_options.hash(state);
            },
            Expr::Ternary {
                predicate,
                truthy,
                falsy,
            } => {
                predicate.hash(state);
                truthy.hash(state);
                falsy.hash(state);
            },
            Expr::Slice {
                input,
                offset,
                length,
            } => {
                input.hash(state);
                offset.hash(state);
                length.hash(state);
            },
            Expr::Window {
                function,
                partition_by,
                order_by,
                options,
            } => {
                function.hash(state);
                partition_by.hash(state);
                order_by.hash(state);
                options.hash(state);
            },
            Expr::Function {
                input,
                function,
                options,
            } => {
                // Only the kind of function is hashed, not its parameters.
                input.hash(state);
                discriminant(function).hash(state);
                options.hash(state);
            },
            Expr::AnonymousFunction {
                input,
                function: _,
                output_type: _,
                options,
            } => {
                // The UDF object and its output type are skipped.
                input.hash(state);
                options.hash(state);
            },
        }
    }
}
// `Eq` is a marker trait without methods: this asserts that the derived `PartialEq`
// on `Expr` is a full equivalence relation, as hash-based collections require of keys.
impl Eq for Expr {}
impl Default for Expr {
    /// The default expression is the null literal.
    fn default() -> Self {
        Self::Literal(LiteralValue::Null)
    }
}
/// One entry of the exclusion list carried by `Expr::Exclude`: either a column name
/// or a data type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Excluded {
/// Exclusion identified by a column name.
Name(ColumnName),
/// Exclusion identified by a data type.
Dtype(DataType),
}
impl Expr {
    /// Computes the [`Field`] of this expression for the given `schema` and `ctxt`.
    ///
    /// A fresh scratch arena is created on every call; see `to_field_amortized` for the
    /// variant that takes a caller-supplied arena.
    pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
        let mut scratch = Arena::with_capacity(5);
        self.to_field_amortized(schema, ctxt, &mut scratch)
    }

    /// Computes the [`Field`] of this expression using the caller-supplied `expr_arena`.
    ///
    /// The expression is cloned and converted with `to_aexpr` into `expr_arena`; the
    /// resulting root node is then asked for its field.
    pub(crate) fn to_field_amortized(
        &self,
        schema: &Schema,
        ctxt: Context,
        expr_arena: &mut Arena<AExpr>,
    ) -> PolarsResult<Field> {
        let node = to_aexpr(self.clone(), expr_arena);
        let aexpr = expr_arena.get(node);
        aexpr.to_field(schema, ctxt, expr_arena)
    }
}
/// Binary operators usable in `Expr::BinaryExpr`.
///
/// The token quoted for each variant is the one its `Display` impl prints.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
/// Displayed as `==`.
Eq,
/// Displayed as `==v`.
///
/// NOTE(review): presumably equality that also compares validity (null-ness) - confirm.
EqValidity,
/// Displayed as `!=`.
NotEq,
/// Displayed as `!=v`.
///
/// NOTE(review): presumably inequality that also compares validity (null-ness) - confirm.
NotEqValidity,
/// Displayed as `<`.
Lt,
/// Displayed as `<=`.
LtEq,
/// Displayed as `>`.
Gt,
/// Displayed as `>=`.
GtEq,
/// Displayed as `+`.
Plus,
/// Displayed as `-`.
Minus,
/// Displayed as `*`.
Multiply,
/// Displayed as `//`. Note that this differs from `TrueDivide`, which prints `/`.
Divide,
/// Displayed as `/`.
TrueDivide,
/// Displayed as `floor_div`.
FloorDivide,
/// Displayed as `%`.
Modulus,
/// Displayed as `&`.
And,
/// Displayed as `|`.
Or,
/// Displayed as `^`.
Xor,
/// Displayed as `&`, the same token as `And`.
LogicalAnd,
/// Displayed as `|`, the same token as `Or`.
LogicalOr,
}
impl Display for Operator {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use Operator::*;
let tkn = match self {
Eq => "==",
EqValidity => "==v",
NotEq => "!=",
NotEqValidity => "!=v",
Lt => "<",
LtEq => "<=",
Gt => ">",
GtEq => ">=",
Plus => "+",
Minus => "-",
Multiply => "*",
Divide => "//",
TrueDivide => "/",
FloorDivide => "floor_div",
Modulus => "%",
And | LogicalAnd => "&",
Or | LogicalOr => "|",
Xor => "^",
};
write!(f, "{tkn}")
}
}
impl Operator {
    /// Returns `true` for the comparison operators (including the validity-aware
    /// ones) and for `And`, `Or` and `Xor`.
    ///
    /// NOTE(review): `LogicalAnd` and `LogicalOr` are not part of this set and are
    /// therefore reported as arithmetic by `is_arithmetic` - confirm this is intended.
    pub(crate) fn is_comparison(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Eq
            | Self::EqValidity
            | Self::NotEq
            | Self::NotEqValidity
            | Self::Lt
            | Self::LtEq
            | Self::Gt
            | Self::GtEq
            | Self::And
            | Self::Or
            | Self::Xor => true,
            Self::Plus
            | Self::Minus
            | Self::Multiply
            | Self::Divide
            | Self::TrueDivide
            | Self::FloorDivide
            | Self::Modulus
            | Self::LogicalAnd
            | Self::LogicalOr => false,
        }
    }

    /// Returns `true` for every operator that `is_comparison` rejects.
    pub(crate) fn is_arithmetic(&self) -> bool {
        !self.is_comparison()
    }
}