datafusion_physical_plan/joins/cross_join.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the cross join plan for loading the left side of the cross join
//! and producing batches in parallel for the right partitions

use std::{any::Any, sync::Arc, task::Poll};

use super::utils::{
    adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter,
    BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
    StatefulStreamResult,
};
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::execution_plan::{boundedness_from_children, EmissionType};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{
    join_allows_pushdown, join_table_borders, new_join_children,
    physical_to_column_exprs, ProjectionExec,
};
use crate::{
    handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution,
    ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};

use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::compute::concat_batches;
use arrow::datatypes::{Fields, Schema, SchemaRef};
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;

use async_trait::async_trait;
use futures::{ready, Stream, StreamExt, TryStreamExt};

/// Data of the left side that is buffered into memory
#[derive(Debug)]
struct JoinLeftData {
    /// Single RecordBatch with all rows from the left side
    merged_batch: RecordBatch,
    /// Track memory reservation for merged_batch. Relies on drop
    /// semantics to release reservation when JoinLeftData is dropped.
    _reservation: MemoryReservation,
}

#[allow(rustdoc::private_intra_doc_links)]
/// Cross Join Execution Plan
///
/// This operator is used when there are no predicates between two tables and
/// returns the Cartesian product of the two tables.
///
/// Buffers the left input into memory and then streams batches from each
/// partition of the right input, combining each right batch with the buffered
/// left input to generate the output.
///
/// # Clone / Shared State
///
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it cannot be [`Clone`].
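///
/// # Example
///
/// A minimal usage sketch (not taken from this crate's tests); `left` and
/// `right` are assumed to be `Arc<dyn ExecutionPlan>` inputs built elsewhere:
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion_execution::TaskContext;
///
/// // Build the cross join; its output schema is the left schema followed by
/// // the right schema.
/// let join = CrossJoinExec::new(Arc::clone(&left), Arc::clone(&right));
///
/// // Execute one output partition (the right side drives the partitioning).
/// let stream = join.execute(0, Arc::new(TaskContext::default()))?;
/// ```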
#[derive(Debug)]
pub struct CrossJoinExec {
    /// left (build) side, which is loaded into memory
    pub left: Arc<dyn ExecutionPlan>,
    /// right (probe) side, which is combined with the left side
    pub right: Arc<dyn ExecutionPlan>,
    /// The schema once the join is applied
    schema: SchemaRef,
    /// Buffered copy of left (build) side in memory.
    ///
    /// This structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the left side loading.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution plan metrics
    metrics: ExecutionPlanMetricsSet,
    /// Properties such as schema, equivalence properties, ordering, partitioning, etc.
    cache: PlanProperties,
}

impl CrossJoinExec {
    /// Create a new [CrossJoinExec].
    pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
        // left then right
        let (all_columns, metadata) = {
            let left_schema = left.schema();
            let right_schema = right.schema();
            let left_fields = left_schema.fields().iter();
            let right_fields = right_schema.fields().iter();

            let mut metadata = left_schema.metadata().clone();
            metadata.extend(right_schema.metadata().clone());

            (
                left_fields.chain(right_fields).cloned().collect::<Fields>(),
                metadata,
            )
        };

        let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
        let cache = Self::compute_properties(&left, &right, Arc::clone(&schema));

        CrossJoinExec {
            left,
            right,
            schema,
            left_fut: Default::default(),
            metrics: ExecutionPlanMetricsSet::default(),
            cache,
        }
    }

    /// left (build) side, which is loaded into memory
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// right side, which is combined with the left side
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
    ) -> PlanProperties {
        // Calculate equivalence properties
        // TODO: Check equivalence properties of cross join, it may preserve
        //       ordering in some cases.
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &JoinType::Full,
            schema,
            &[false, false],
            None,
            &[],
        );

        // Get output partitioning:
        // TODO: Optimize the cross join implementation to generate M * N
        //       partitions.
        let output_partitioning = adjust_right_output_partitioning(
            right.output_partitioning(),
            left.schema().fields.len(),
        );

        PlanProperties::new(
            eq_properties,
            output_partitioning,
            EmissionType::Final,
            boundedness_from_children([left, right]),
        )
    }

    /// Returns a new `ExecutionPlan` that computes the same join as this one,
    /// with the left and right inputs swapped. The output columns are
    /// reordered afterwards so that the original output schema is preserved.
    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
        let new_join =
            CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
        reorder_output_after_swap(
            Arc::new(new_join),
            &self.left.schema(),
            &self.right.schema(),
        )
    }
}

/// Asynchronously collect the result of the left child
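///
/// All left partitions are first coalesced into a single partition, each
/// incoming batch is buffered while growing the memory reservation (and
/// updating the build-side metrics), and the buffered batches are finally
/// concatenated into the single `RecordBatch` stored in `JoinLeftData`.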
async fn load_left_input(
    left: Arc<dyn ExecutionPlan>,
    context: Arc<TaskContext>,
    metrics: BuildProbeJoinMetrics,
    reservation: MemoryReservation,
) -> Result<JoinLeftData> {
    // merge all left parts into a single stream
    let left_schema = left.schema();
    let merge = if left.output_partitioning().partition_count() != 1 {
        Arc::new(CoalescePartitionsExec::new(left))
    } else {
        left
    };
    let stream = merge.execute(0, context)?;

    // Load all batches and count the rows
    let (batches, _metrics, reservation) = stream
        .try_fold(
            (Vec::new(), metrics, reservation),
            |(mut batches, metrics, mut reservation), batch| async {
                let batch_size = batch.get_array_memory_size();
                // Reserve memory for incoming batch
                reservation.try_grow(batch_size)?;
                // Update metrics
                metrics.build_mem_used.add(batch_size);
                metrics.build_input_batches.add(1);
                metrics.build_input_rows.add(batch.num_rows());
                // Push batch to output
                batches.push(batch);
                Ok((batches, metrics, reservation))
            },
        )
        .await?;

    let merged_batch = concat_batches(&left_schema, &batches)?;

    Ok(JoinLeftData {
        merged_batch,
        _reservation: reservation,
    })
}

impl DisplayAs for CrossJoinExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CrossJoinExec")
            }
        }
    }
}

impl ExecutionPlan for CrossJoinExec {
    fn name(&self) -> &'static str {
        "CrossJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(CrossJoinExec::new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
        )))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![
            Distribution::SinglePartition,
            Distribution::UnspecifiedDistribution,
        ]
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let stream = self.right.execute(partition, Arc::clone(&context))?;

        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);

        // Initialization of operator-level reservation
        let reservation =
            MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());

        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();

        let left_fut = self.left_fut.once(|| {
            load_left_input(
                Arc::clone(&self.left),
                context,
                join_metrics.clone(),
                reservation,
            )
        });

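        // Choose how output batches are emitted: if the session config asks to
        // enforce `batch_size` in joins, split oversized batches with a
        // `BatchSplitter`; otherwise pass batches through unchanged.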
        if enforce_batch_size_in_joins {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    fn statistics(&self) -> Result<Statistics> {
        Ok(stats_cartesian_product(
            self.left.statistics()?,
            self.right.statistics()?,
        ))
    }

    /// Tries to push `projection` down through this [`CrossJoinExec`]. If the
    /// pushdown is possible, returns the rewritten plan with the
    /// [`CrossJoinExec`] as the top node; otherwise returns `None`.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };

        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );

        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }

        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;

        Ok(Some(Arc::new(CrossJoinExec::new(
            Arc::new(new_left),
            Arc::new(new_right),
        ))))
    }
}

/// Computes the statistics of the Cartesian product of `left_stats` and `right_stats`.
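///
/// For example (mirroring the tests below), a left side with 11 rows crossed
/// with a right side of 7 rows yields `11 * 7 = 77` rows, and a left column
/// with 3 nulls ends up with `3 * 7 = 21` nulls, since every left row is
/// repeated once per right row.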
fn stats_cartesian_product(
    left_stats: Statistics,
    right_stats: Statistics,
) -> Statistics {
    let left_row_count = left_stats.num_rows;
    let right_row_count = right_stats.num_rows;

    // calculate global stats
    let num_rows = left_row_count.multiply(&right_row_count);
    // the result size is two times a*b because you have the columns of both left and right
    let total_byte_size = left_stats
        .total_byte_size
        .multiply(&right_stats.total_byte_size)
        .multiply(&Precision::Exact(2));

    let left_col_stats = left_stats.column_statistics;
    let right_col_stats = right_stats.column_statistics;

    // the null counts must be multiplied by the row counts of the other side (if defined)
    // Min, max and distinct_count on the other hand are invariants.
    let cross_join_stats = left_col_stats
        .into_iter()
        .map(|s| ColumnStatistics {
            null_count: s.null_count.multiply(&right_row_count),
            distinct_count: s.distinct_count,
            min_value: s.min_value,
            max_value: s.max_value,
            sum_value: s
                .sum_value
                .get_value()
                // Cast the row count into the same type as any existing sum value
                .and_then(|v| {
                    Precision::<ScalarValue>::from(right_row_count)
                        .cast_to(&v.data_type())
                        .ok()
                })
                .map(|row_count| s.sum_value.multiply(&row_count))
                .unwrap_or(Precision::Absent),
        })
        .chain(right_col_stats.into_iter().map(|s| {
            ColumnStatistics {
                null_count: s.null_count.multiply(&left_row_count),
                distinct_count: s.distinct_count,
                min_value: s.min_value,
                max_value: s.max_value,
                sum_value: s
                    .sum_value
                    .get_value()
                    // Cast the row count into the same type as any existing sum value
                    .and_then(|v| {
                        Precision::<ScalarValue>::from(left_row_count)
                            .cast_to(&v.data_type())
                            .ok()
                    })
                    .map(|row_count| s.sum_value.multiply(&row_count))
                    .unwrap_or(Precision::Absent),
            }
        }))
        .collect();

    Statistics {
        num_rows,
        total_byte_size,
        column_statistics: cross_join_stats,
    }
}

/// A stream that issues [RecordBatch]es as they arrive from the right side of the join.
struct CrossJoinStream<T> {
    /// Schema of the output record batches
    schema: Arc<Schema>,
    /// Future for data from left side
    left_fut: OnceFut<JoinLeftData>,
    /// Right side stream
    right: SendableRecordBatchStream,
    /// Index of the current row on the left side
    left_index: usize,
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// State of the stream
    state: CrossJoinStreamState,
    /// Left data (copy of the entire buffered left side)
    left_data: RecordBatch,
    /// Batch transformer
    batch_transformer: T,
}

impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// Represents the states of `CrossJoinStream`.
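///
/// Transitions: `WaitBuildSide` waits for the buffered left side; once it is
/// available (and non-empty), the stream moves to `FetchProbeBatch`, which
/// polls the right input. Each fetched right batch is processed in
/// `BuildBatches`, emitting one output batch per buffered left row (possibly
/// split by the batch transformer), before returning to `FetchProbeBatch`.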
enum CrossJoinStreamState {
    WaitBuildSide,
    FetchProbeBatch,
    /// Holds the currently processed right side batch
    BuildBatches(RecordBatch),
}

impl CrossJoinStreamState {
    /// Tries to extract the `RecordBatch` held by the `BuildBatches` state.
    /// Returns an error if the state is not `BuildBatches`.
    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
        match self {
            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
        }
    }
}

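/// Builds one output batch by repeating row `left_index` of `left_data`
/// `batch.num_rows()` times and appending the columns of the right-side
/// `batch` unchanged.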
fn build_batch(
    left_index: usize,
    batch: &RecordBatch,
    left_data: &RecordBatch,
    schema: &Schema,
) -> Result<RecordBatch> {
    // Repeat value on the left n times
    let arrays = left_data
        .columns()
        .iter()
        .map(|arr| {
            let scalar = ScalarValue::try_from_array(arr, left_index)?;
            scalar.to_array_of_size(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    RecordBatch::try_new_with_options(
        Arc::new(schema.clone()),
        arrays
            .iter()
            .chain(batch.columns().iter())
            .cloned()
            .collect(),
        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
    )
    .map_err(Into::into)
}

#[async_trait]
impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}

impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Separate implementation function that unpins the [`CrossJoinStream`] so
    /// that partial borrows work correctly
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    handle_state!(self.build_batches())
                }
            };
        }
    }

    /// Collects the build (left) side of the join into the state. If the build
    /// side is empty, execution terminates. Otherwise, the state is updated to
    /// fetch the probe (right) batch.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        let build_timer = self.join_metrics.build_time.timer();
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Fetches the probe (right) batch, updates the metrics, and saves the batch
    /// in the state. Then, the state is updated to build result batches.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Joins the indexed row of the left data with the current probe batch.
    /// Once all results are produced, the state is set to fetch a new probe batch.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            match self.batch_transformer.next() {
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                Some((batch, last)) => {
                    if last {
                        self.left_index += 1;
                    }

                    self.join_metrics.output_batches.add(1);
                    self.join_metrics.output_rows.add(batch.num_rows());
                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::common;
    use crate::test::build_table_scan_i32;

    use datafusion_common::{assert_batches_sorted_eq, assert_contains};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;

    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
        let join = CrossJoinExec::new(left, right);
        let columns_header = columns(&join.schema());

        let stream = join.execute(0, context)?;
        let batches = common::collect(stream).await?;

        Ok((columns_header, batches))
    }

    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_stats_cartesian_product_with_unknown_size() {
        let left_row_count = 11;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Absent, // we don't know the row count on the right
                    null_count: Precision::Absent, // we don't know the row count on the right
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent, // we don't know the row count on the right
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_join() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 6]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let (columns, batches) = join_collect(left, right, task_ctx).await?;

        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
        let expected = [
            "+----+----+----+----+----+----+",
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 4  | 7  | 10 | 12 | 14 |",
            "| 1  | 4  | 7  | 11 | 13 | 15 |",
            "| 2  | 5  | 8  | 10 | 12 | 14 |",
            "| 2  | 5  | 8  | 11 | 13 | 15 |",
            "| 3  | 6  | 9  | 10 | 12 | 14 |",
            "| 3  | 6  | 9  | 11 | 13 | 15 |",
            "+----+----+----+----+----+----+",
        ];

        assert_batches_sorted_eq!(expected, &batches);

        Ok(())
    }

    #[tokio::test]
    async fn test_overallocation() -> Result<()> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(100, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let err = join_collect(left, right, task_ctx).await.unwrap_err();

        assert_contains!(
            err.to_string(),
            "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec"
        );

        Ok(())
    }

    /// Returns the column names of the schema
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }
}