use std::ops::Deref;
use std::sync::Mutex;
use arrow::datatypes::ArrowSchemaRef;
use either::Either;
use polars_core::prelude::*;
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::prelude::*;
impl DslPlan {
pub fn compute_schema(&self) -> PolarsResult<SchemaRef> {
let mut lp_arena = Default::default();
let mut expr_arena = Default::default();
let node = to_alp(
self.clone(),
&mut expr_arena,
&mut lp_arena,
&mut OptFlags::schema_only(),
)?;
Ok(lp_arena.get(node).schema(&lp_arena).into_owned())
}
}
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FileInfo {
pub schema: SchemaRef,
pub reader_schema: Option<Either<ArrowSchemaRef, SchemaRef>>,
pub row_estimation: (Option<usize>, usize),
}
impl FileInfo {
pub fn new(
schema: SchemaRef,
reader_schema: Option<Either<ArrowSchemaRef, SchemaRef>>,
row_estimation: (Option<usize>, usize),
) -> Self {
Self {
schema: schema.clone(),
reader_schema,
row_estimation,
}
}
pub fn update_schema_with_hive_schema(&mut self, hive_schema: SchemaRef) {
let schema = Arc::make_mut(&mut self.schema);
for field in hive_schema.iter_fields() {
if let Ok(existing) = schema.try_get_mut(&field.name) {
*existing = field.dtype().clone();
} else {
schema
.insert_at_index(schema.len(), field.name, field.dtype.clone())
.unwrap();
}
}
}
}
#[cfg(feature = "streaming")]
fn estimate_sizes(
known_size: Option<usize>,
estimated_size: usize,
filter_count: usize,
) -> (Option<usize>, usize) {
match (known_size, filter_count) {
(Some(known_size), 0) => (Some(known_size), estimated_size),
(None, 0) => (None, estimated_size),
(_, _) => (
None,
(estimated_size as f32 * 0.9f32.powf(filter_count as f32)) as usize,
),
}
}
#[cfg(feature = "streaming")]
pub fn set_estimated_row_counts(
root: Node,
lp_arena: &mut Arena<IR>,
expr_arena: &Arena<AExpr>,
mut _filter_count: usize,
scratch: &mut Vec<Node>,
) -> (Option<usize>, usize, usize) {
use IR::*;
fn apply_slice(out: &mut (Option<usize>, usize, usize), slice: Option<(i64, usize)>) {
if let Some((_, len)) = slice {
out.0 = out.0.map(|known_size| std::cmp::min(len, known_size));
out.1 = std::cmp::min(len, out.1);
}
}
match lp_arena.get(root) {
Filter { predicate, input } => {
_filter_count += expr_arena
.iter(predicate.node())
.filter(|(_, ae)| matches!(ae, AExpr::BinaryExpr { .. }))
.count()
+ 1;
set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch)
},
Slice { input, len, .. } => {
let len = *len as usize;
let mut out =
set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch);
apply_slice(&mut out, Some((0, len)));
out
},
Union { .. } => {
if let Union {
inputs,
mut options,
} = lp_arena.take(root)
{
let mut sum_output = (None, 0usize);
for input in &inputs {
let mut out =
set_estimated_row_counts(*input, lp_arena, expr_arena, 0, scratch);
if let Some((_offset, len)) = options.slice {
apply_slice(&mut out, Some((0, len)))
}
let out = estimate_sizes(out.0, out.1, out.2);
sum_output.1 = sum_output.1.saturating_add(out.1);
}
options.rows = sum_output;
lp_arena.replace(root, Union { inputs, options });
(sum_output.0, sum_output.1, 0)
} else {
unreachable!()
}
},
Join { .. } => {
if let Join {
input_left,
input_right,
mut options,
schema,
left_on,
right_on,
} = lp_arena.take(root)
{
let mut_options = Arc::make_mut(&mut options);
let (known_size, estimated_size, filter_count_left) =
set_estimated_row_counts(input_left, lp_arena, expr_arena, 0, scratch);
mut_options.rows_left =
estimate_sizes(known_size, estimated_size, filter_count_left);
let (known_size, estimated_size, filter_count_right) =
set_estimated_row_counts(input_right, lp_arena, expr_arena, 0, scratch);
mut_options.rows_right =
estimate_sizes(known_size, estimated_size, filter_count_right);
let mut out = match options.args.how {
JoinType::Left => {
let (known_size, estimated_size) = options.rows_left;
(known_size, estimated_size, filter_count_left)
},
JoinType::Cross | JoinType::Full { .. } => {
let (known_size_left, estimated_size_left) = options.rows_left;
let (known_size_right, estimated_size_right) = options.rows_right;
match (known_size_left, known_size_right) {
(Some(l), Some(r)) => {
(Some(l * r), estimated_size_left, estimated_size_right)
},
_ => (None, estimated_size_left * estimated_size_right, 0),
}
},
_ => {
let (known_size_left, estimated_size_left) = options.rows_left;
let (known_size_right, estimated_size_right) = options.rows_right;
if estimated_size_left > estimated_size_right {
(known_size_left, estimated_size_left, 0)
} else {
(known_size_right, estimated_size_right, 0)
}
},
};
apply_slice(&mut out, options.args.slice);
lp_arena.replace(
root,
Join {
input_left,
input_right,
options,
schema,
left_on,
right_on,
},
);
out
} else {
unreachable!()
}
},
DataFrameScan { df, .. } => {
let len = df.height();
(Some(len), len, _filter_count)
},
Scan { file_info, .. } => {
let (known_size, estimated_size) = file_info.row_estimation;
(known_size, estimated_size, _filter_count)
},
#[cfg(feature = "python")]
PythonScan { .. } => {
(None, usize::MAX, _filter_count)
},
lp => {
lp.copy_inputs(scratch);
let mut sum_output = (None, 0, 0);
while let Some(input) = scratch.pop() {
let out =
set_estimated_row_counts(input, lp_arena, expr_arena, _filter_count, scratch);
sum_output.1 += out.1;
sum_output.2 += out.2;
sum_output.0 = match sum_output.0 {
None => out.0,
p => p,
};
}
sum_output
},
}
}
pub(crate) fn det_join_schema(
schema_left: &SchemaRef,
schema_right: &SchemaRef,
left_on: &[Expr],
right_on: &[Expr],
options: &JoinOptions,
) -> PolarsResult<SchemaRef> {
match &options.args.how {
#[cfg(feature = "semi_anti_join")]
JoinType::Semi | JoinType::Anti => Ok(schema_left.clone()),
JoinType::Right => {
let mut arena = Arena::with_capacity(8);
let mut join_on_left: PlHashSet<_> = PlHashSet::with_capacity(left_on.len());
for e in left_on {
let field = e.to_field_amortized(schema_left, Context::Default, &mut arena)?;
join_on_left.insert(field.name);
}
let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
for e in right_on {
let field = e.to_field_amortized(schema_right, Context::Default, &mut arena)?;
join_on_right.insert(field.name);
}
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());
let should_coalesce = options.args.should_coalesce();
if !should_coalesce {
for (name, dtype) in schema_left.iter() {
new_schema.with_column(name.clone(), dtype.clone());
}
} else {
for (name, dtype) in schema_left.iter() {
if !join_on_left.contains(name) {
new_schema.with_column(name.clone(), dtype.clone());
}
}
}
for (name, dtype) in schema_right.iter() {
{
let left_is_removed = join_on_left.contains(name.as_str()) && should_coalesce;
if schema_left.contains(name.as_str()) && !left_is_removed {
let new_name = format_pl_smallstr!("{}{}", name, options.args.suffix());
new_schema.with_column(new_name, dtype.clone());
} else {
new_schema.with_column(name.clone(), dtype.clone());
}
}
}
Ok(Arc::new(new_schema))
},
_how => {
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());
for (name, dtype) in schema_left.iter() {
new_schema.with_column(name.clone(), dtype.clone());
}
let should_coalesce = options.args.should_coalesce();
let mut arena = Arena::with_capacity(8);
#[cfg(feature = "asof_join")]
if matches!(_how, JoinType::AsOf(_)) {
for (left_on, right_on) in left_on.iter().zip(right_on) {
let field_left =
left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?;
let field_right =
right_on.to_field_amortized(schema_right, Context::Default, &mut arena)?;
if should_coalesce && field_left.name != field_right.name {
if schema_left.contains(&field_right.name) {
new_schema.with_column(
_join_suffix_name(&field_right.name, options.args.suffix()),
field_right.dtype,
);
} else {
new_schema.with_column(field_right.name, field_right.dtype);
}
}
}
}
let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
for e in right_on {
let field = e.to_field_amortized(schema_right, Context::Default, &mut arena)?;
join_on_right.insert(field.name);
}
for (name, dtype) in schema_right.iter() {
if !join_on_right.contains(name.as_str()) || (!should_coalesce) {
#[cfg(feature = "asof_join")]
if let JoinType::AsOf(asof_options) = &options.args.how {
if let Some(right_by) = &asof_options.right_by {
{
if right_by.contains(name) {
continue;
}
}
}
}
if schema_left.contains(name.as_str()) {
let new_name = format_pl_smallstr!("{}{}", name, options.args.suffix());
new_schema.with_column(new_name, dtype.clone());
} else {
new_schema.with_column(name.clone(), dtype.clone());
}
}
}
Ok(Arc::new(new_schema))
},
}
}
#[derive(Default)]
pub struct CachedSchema(Mutex<Option<SchemaRef>>);
impl AsRef<Mutex<Option<SchemaRef>>> for CachedSchema {
fn as_ref(&self) -> &Mutex<Option<SchemaRef>> {
&self.0
}
}
impl Deref for CachedSchema {
type Target = Mutex<Option<SchemaRef>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Clone for CachedSchema {
fn clone(&self) -> Self {
let inner = self.0.lock().unwrap();
Self(Mutex::new(inner.clone()))
}
}
impl CachedSchema {
pub fn get(&self) -> Option<SchemaRef> {
self.0.lock().unwrap().clone()
}
}