1use std::{any::Any, sync::Arc, task::Poll};
22
23use super::utils::{
24 adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter,
25 BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
26 StatefulStreamResult,
27};
28use crate::coalesce_partitions::CoalescePartitionsExec;
29use crate::execution_plan::{boundedness_from_children, EmissionType};
30use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
31use crate::projection::{
32 join_allows_pushdown, join_table_borders, new_join_children,
33 physical_to_column_exprs, ProjectionExec,
34};
35use crate::{
36 handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution,
37 ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
38 SendableRecordBatchStream, Statistics,
39};
40
41use arrow::array::{RecordBatch, RecordBatchOptions};
42use arrow::compute::concat_batches;
43use arrow::datatypes::{Fields, Schema, SchemaRef};
44use datafusion_common::stats::Precision;
45use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
46use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
47use datafusion_execution::TaskContext;
48use datafusion_physical_expr::equivalence::join_equivalence_properties;
49
50use async_trait::async_trait;
51use futures::{ready, Stream, StreamExt, TryStreamExt};
52
/// Build-side (left) input of a cross join, fully buffered in memory.
#[derive(Debug)]
struct JoinLeftData {
    /// All left-input batches concatenated into a single batch.
    merged_batch: RecordBatch,
    /// Memory reservation grown while the left input was loaded. It is never
    /// read; it is kept so the reservation lives as long as this data
    /// (presumably released on drop — see `MemoryReservation`).
    _reservation: MemoryReservation,
}
62
/// Physical operator computing the cartesian product of two inputs.
///
/// The left input is fully buffered into a single batch (see [`JoinLeftData`]).
/// Output partition `i` pairs every buffered left row with the batches of
/// right partition `i`.
#[allow(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
pub struct CrossJoinExec {
    /// Left (build) input; collected into a single batch before probing.
    pub left: Arc<dyn ExecutionPlan>,
    /// Right (probe) input; streamed one partition per output partition.
    pub right: Arc<dyn ExecutionPlan>,
    /// Output schema: all left fields followed by all right fields.
    schema: SchemaRef,
    /// Holder for the future that loads the left input. It is initialised by
    /// `execute` through `OnceAsync::once`, so presumably only the first
    /// partition to execute actually starts the load and the rest share it.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution metrics for this operator.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once in `new`.
    cache: PlanProperties,
}
98
99impl CrossJoinExec {
100 pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
102 let (all_columns, metadata) = {
104 let left_schema = left.schema();
105 let right_schema = right.schema();
106 let left_fields = left_schema.fields().iter();
107 let right_fields = right_schema.fields().iter();
108
109 let mut metadata = left_schema.metadata().clone();
110 metadata.extend(right_schema.metadata().clone());
111
112 (
113 left_fields.chain(right_fields).cloned().collect::<Fields>(),
114 metadata,
115 )
116 };
117
118 let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
119 let cache = Self::compute_properties(&left, &right, Arc::clone(&schema));
120
121 CrossJoinExec {
122 left,
123 right,
124 schema,
125 left_fut: Default::default(),
126 metrics: ExecutionPlanMetricsSet::default(),
127 cache,
128 }
129 }
130
131 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
133 &self.left
134 }
135
136 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
138 &self.right
139 }
140
141 fn compute_properties(
143 left: &Arc<dyn ExecutionPlan>,
144 right: &Arc<dyn ExecutionPlan>,
145 schema: SchemaRef,
146 ) -> PlanProperties {
147 let eq_properties = join_equivalence_properties(
151 left.equivalence_properties().clone(),
152 right.equivalence_properties().clone(),
153 &JoinType::Full,
154 schema,
155 &[false, false],
156 None,
157 &[],
158 );
159
160 let output_partitioning = adjust_right_output_partitioning(
164 right.output_partitioning(),
165 left.schema().fields.len(),
166 );
167
168 PlanProperties::new(
169 eq_properties,
170 output_partitioning,
171 EmissionType::Final,
172 boundedness_from_children([left, right]),
173 )
174 }
175
176 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
180 let new_join =
181 CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
182 reorder_output_after_swap(
183 Arc::new(new_join),
184 &self.left.schema(),
185 &self.right.schema(),
186 )
187 }
188}
189
190async fn load_left_input(
192 left: Arc<dyn ExecutionPlan>,
193 context: Arc<TaskContext>,
194 metrics: BuildProbeJoinMetrics,
195 reservation: MemoryReservation,
196) -> Result<JoinLeftData> {
197 let left_schema = left.schema();
199 let merge = if left.output_partitioning().partition_count() != 1 {
200 Arc::new(CoalescePartitionsExec::new(left))
201 } else {
202 left
203 };
204 let stream = merge.execute(0, context)?;
205
206 let (batches, _metrics, reservation) = stream
208 .try_fold(
209 (Vec::new(), metrics, reservation),
210 |(mut batches, metrics, mut reservation), batch| async {
211 let batch_size = batch.get_array_memory_size();
212 reservation.try_grow(batch_size)?;
214 metrics.build_mem_used.add(batch_size);
216 metrics.build_input_batches.add(1);
217 metrics.build_input_rows.add(batch.num_rows());
218 batches.push(batch);
220 Ok((batches, metrics, reservation))
221 },
222 )
223 .await?;
224
225 let merged_batch = concat_batches(&left_schema, &batches)?;
226
227 Ok(JoinLeftData {
228 merged_batch,
229 _reservation: reservation,
230 })
231}
232
233impl DisplayAs for CrossJoinExec {
234 fn fmt_as(
235 &self,
236 t: DisplayFormatType,
237 f: &mut std::fmt::Formatter,
238 ) -> std::fmt::Result {
239 match t {
240 DisplayFormatType::Default | DisplayFormatType::Verbose => {
241 write!(f, "CrossJoinExec")
242 }
243 }
244 }
245}
246
247impl ExecutionPlan for CrossJoinExec {
248 fn name(&self) -> &'static str {
249 "CrossJoinExec"
250 }
251
252 fn as_any(&self) -> &dyn Any {
253 self
254 }
255
256 fn properties(&self) -> &PlanProperties {
257 &self.cache
258 }
259
260 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
261 vec![&self.left, &self.right]
262 }
263
264 fn metrics(&self) -> Option<MetricsSet> {
265 Some(self.metrics.clone_inner())
266 }
267
268 fn with_new_children(
269 self: Arc<Self>,
270 children: Vec<Arc<dyn ExecutionPlan>>,
271 ) -> Result<Arc<dyn ExecutionPlan>> {
272 Ok(Arc::new(CrossJoinExec::new(
273 Arc::clone(&children[0]),
274 Arc::clone(&children[1]),
275 )))
276 }
277
278 fn required_input_distribution(&self) -> Vec<Distribution> {
279 vec![
280 Distribution::SinglePartition,
281 Distribution::UnspecifiedDistribution,
282 ]
283 }
284
285 fn execute(
286 &self,
287 partition: usize,
288 context: Arc<TaskContext>,
289 ) -> Result<SendableRecordBatchStream> {
290 let stream = self.right.execute(partition, Arc::clone(&context))?;
291
292 let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
293
294 let reservation =
296 MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
297
298 let batch_size = context.session_config().batch_size();
299 let enforce_batch_size_in_joins =
300 context.session_config().enforce_batch_size_in_joins();
301
302 let left_fut = self.left_fut.once(|| {
303 load_left_input(
304 Arc::clone(&self.left),
305 context,
306 join_metrics.clone(),
307 reservation,
308 )
309 });
310
311 if enforce_batch_size_in_joins {
312 Ok(Box::pin(CrossJoinStream {
313 schema: Arc::clone(&self.schema),
314 left_fut,
315 right: stream,
316 left_index: 0,
317 join_metrics,
318 state: CrossJoinStreamState::WaitBuildSide,
319 left_data: RecordBatch::new_empty(self.left().schema()),
320 batch_transformer: BatchSplitter::new(batch_size),
321 }))
322 } else {
323 Ok(Box::pin(CrossJoinStream {
324 schema: Arc::clone(&self.schema),
325 left_fut,
326 right: stream,
327 left_index: 0,
328 join_metrics,
329 state: CrossJoinStreamState::WaitBuildSide,
330 left_data: RecordBatch::new_empty(self.left().schema()),
331 batch_transformer: NoopBatchTransformer::new(),
332 }))
333 }
334 }
335
336 fn statistics(&self) -> Result<Statistics> {
337 Ok(stats_cartesian_product(
338 self.left.statistics()?,
339 self.right.statistics()?,
340 ))
341 }
342
343 fn try_swapping_with_projection(
347 &self,
348 projection: &ProjectionExec,
349 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
350 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
352 else {
353 return Ok(None);
354 };
355
356 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
357 self.left().schema().fields().len(),
358 &projection_as_columns,
359 );
360
361 if !join_allows_pushdown(
362 &projection_as_columns,
363 &self.schema(),
364 far_right_left_col_ind,
365 far_left_right_col_ind,
366 ) {
367 return Ok(None);
368 }
369
370 let (new_left, new_right) = new_join_children(
371 &projection_as_columns,
372 far_right_left_col_ind,
373 far_left_right_col_ind,
374 self.left(),
375 self.right(),
376 )?;
377
378 Ok(Some(Arc::new(CrossJoinExec::new(
379 Arc::new(new_left),
380 Arc::new(new_right),
381 ))))
382 }
383}
384
385fn stats_cartesian_product(
387 left_stats: Statistics,
388 right_stats: Statistics,
389) -> Statistics {
390 let left_row_count = left_stats.num_rows;
391 let right_row_count = right_stats.num_rows;
392
393 let num_rows = left_row_count.multiply(&right_row_count);
395 let total_byte_size = left_stats
397 .total_byte_size
398 .multiply(&right_stats.total_byte_size)
399 .multiply(&Precision::Exact(2));
400
401 let left_col_stats = left_stats.column_statistics;
402 let right_col_stats = right_stats.column_statistics;
403
404 let cross_join_stats = left_col_stats
407 .into_iter()
408 .map(|s| ColumnStatistics {
409 null_count: s.null_count.multiply(&right_row_count),
410 distinct_count: s.distinct_count,
411 min_value: s.min_value,
412 max_value: s.max_value,
413 sum_value: s
414 .sum_value
415 .get_value()
416 .and_then(|v| {
418 Precision::<ScalarValue>::from(right_row_count)
419 .cast_to(&v.data_type())
420 .ok()
421 })
422 .map(|row_count| s.sum_value.multiply(&row_count))
423 .unwrap_or(Precision::Absent),
424 })
425 .chain(right_col_stats.into_iter().map(|s| {
426 ColumnStatistics {
427 null_count: s.null_count.multiply(&left_row_count),
428 distinct_count: s.distinct_count,
429 min_value: s.min_value,
430 max_value: s.max_value,
431 sum_value: s
432 .sum_value
433 .get_value()
434 .and_then(|v| {
436 Precision::<ScalarValue>::from(left_row_count)
437 .cast_to(&v.data_type())
438 .ok()
439 })
440 .map(|row_count| s.sum_value.multiply(&row_count))
441 .unwrap_or(Precision::Absent),
442 }
443 }))
444 .collect();
445
446 Statistics {
447 num_rows,
448 total_byte_size,
449 column_statistics: cross_join_stats,
450 }
451}
452
/// Output stream of [`CrossJoinExec`] for one output partition.
///
/// `T` post-processes each built batch before it is emitted (for example,
/// `execute` uses a `BatchSplitter` to split batches to the session batch size).
struct CrossJoinStream<T> {
    /// Output schema (left columns followed by right columns).
    schema: Arc<Schema>,
    /// Future resolving to the fully loaded left input.
    left_fut: OnceFut<JoinLeftData>,
    /// Right-input batches for this partition.
    right: SendableRecordBatchStream,
    /// Index of the left row currently being paired with the probe batch.
    left_index: usize,
    /// Build/probe metrics for this partition.
    join_metrics: BuildProbeJoinMetrics,
    /// Current state of the stream's state machine.
    state: CrossJoinStreamState,
    /// Left input as a single batch; empty until the build side is loaded.
    left_data: RecordBatch,
    /// Transformer applied to each built batch before it is emitted.
    batch_transformer: T,
}
472
473impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
474 fn schema(&self) -> SchemaRef {
475 Arc::clone(&self.schema)
476 }
477}
478
/// States of [`CrossJoinStream`].
enum CrossJoinStreamState {
    /// Waiting for the left (build) input to be fully loaded.
    WaitBuildSide,
    /// Ready to pull the next batch from the right (probe) input.
    FetchProbeBatch,
    /// Pairing left rows, one at a time, with the contained probe batch.
    BuildBatches(RecordBatch),
}
486
487impl CrossJoinStreamState {
488 fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
491 match self {
492 CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
493 _ => internal_err!("Expected RecordBatch in BuildBatches state"),
494 }
495 }
496}
497
498fn build_batch(
499 left_index: usize,
500 batch: &RecordBatch,
501 left_data: &RecordBatch,
502 schema: &Schema,
503) -> Result<RecordBatch> {
504 let arrays = left_data
506 .columns()
507 .iter()
508 .map(|arr| {
509 let scalar = ScalarValue::try_from_array(arr, left_index)?;
510 scalar.to_array_of_size(batch.num_rows())
511 })
512 .collect::<Result<Vec<_>>>()?;
513
514 RecordBatch::try_new_with_options(
515 Arc::new(schema.clone()),
516 arrays
517 .iter()
518 .chain(batch.columns().iter())
519 .cloned()
520 .collect(),
521 &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
522 )
523 .map_err(Into::into)
524}
525
526#[async_trait]
527impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
528 type Item = Result<RecordBatch>;
529
530 fn poll_next(
531 mut self: std::pin::Pin<&mut Self>,
532 cx: &mut std::task::Context<'_>,
533 ) -> Poll<Option<Self::Item>> {
534 self.poll_next_impl(cx)
535 }
536}
537
538impl<T: BatchTransformer> CrossJoinStream<T> {
539 fn poll_next_impl(
542 &mut self,
543 cx: &mut std::task::Context<'_>,
544 ) -> Poll<Option<Result<RecordBatch>>> {
545 loop {
546 return match self.state {
547 CrossJoinStreamState::WaitBuildSide => {
548 handle_state!(ready!(self.collect_build_side(cx)))
549 }
550 CrossJoinStreamState::FetchProbeBatch => {
551 handle_state!(ready!(self.fetch_probe_batch(cx)))
552 }
553 CrossJoinStreamState::BuildBatches(_) => {
554 handle_state!(self.build_batches())
555 }
556 };
557 }
558 }
559
560 fn collect_build_side(
563 &mut self,
564 cx: &mut std::task::Context<'_>,
565 ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
566 let build_timer = self.join_metrics.build_time.timer();
567 let left_data = match ready!(self.left_fut.get(cx)) {
568 Ok(left_data) => left_data,
569 Err(e) => return Poll::Ready(Err(e)),
570 };
571 build_timer.done();
572
573 let left_data = left_data.merged_batch.clone();
574 let result = if left_data.num_rows() == 0 {
575 StatefulStreamResult::Ready(None)
576 } else {
577 self.left_data = left_data;
578 self.state = CrossJoinStreamState::FetchProbeBatch;
579 StatefulStreamResult::Continue
580 };
581 Poll::Ready(Ok(result))
582 }
583
584 fn fetch_probe_batch(
587 &mut self,
588 cx: &mut std::task::Context<'_>,
589 ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
590 self.left_index = 0;
591 let right_data = match ready!(self.right.poll_next_unpin(cx)) {
592 Some(Ok(right_data)) => right_data,
593 Some(Err(e)) => return Poll::Ready(Err(e)),
594 None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
595 };
596 self.join_metrics.input_batches.add(1);
597 self.join_metrics.input_rows.add(right_data.num_rows());
598
599 self.state = CrossJoinStreamState::BuildBatches(right_data);
600 Poll::Ready(Ok(StatefulStreamResult::Continue))
601 }
602
603 fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
606 let right_batch = self.state.try_as_record_batch()?;
607 if self.left_index < self.left_data.num_rows() {
608 match self.batch_transformer.next() {
609 None => {
610 let join_timer = self.join_metrics.join_time.timer();
611 let result = build_batch(
612 self.left_index,
613 right_batch,
614 &self.left_data,
615 &self.schema,
616 );
617 join_timer.done();
618
619 self.batch_transformer.set_batch(result?);
620 }
621 Some((batch, last)) => {
622 if last {
623 self.left_index += 1;
624 }
625
626 self.join_metrics.output_batches.add(1);
627 self.join_metrics.output_rows.add(batch.num_rows());
628 return Ok(StatefulStreamResult::Ready(Some(batch)));
629 }
630 }
631 } else {
632 self.state = CrossJoinStreamState::FetchProbeBatch;
633 }
634 Ok(StatefulStreamResult::Continue)
635 }
636}
637
/// Tests for `CrossJoinExec`: statistics estimation, join output, and memory
/// limits.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common;
    use crate::test::build_table_scan_i32;

    use datafusion_common::{assert_batches_sorted_eq, assert_contains};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;

    /// Runs a `CrossJoinExec` over `left` and `right` and returns the output
    /// column names together with every batch of output partition 0.
    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
        let join = CrossJoinExec::new(left, right);
        let columns_header = columns(&join.schema());

        let stream = join.execute(0, context)?;
        let batches = common::collect(stream).await?;

        Ok((columns_header, batches))
    }

    /// With exact row counts on both sides:
    /// - null counts and sums are scaled by the other side's row count;
    /// - `total_byte_size` becomes `2 * left_bytes * right_bytes`.
    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// When the right row count is unknown:
    /// - the row count and byte size become `Absent`;
    /// - left-column sums and null counts (scaled by right rows) become `Absent`;
    /// - right columns are still scaled by the known left row count.
    #[tokio::test]
    async fn test_stats_cartesian_product_with_unknown_size() {
        let left_row_count = 11;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    // Scaled by the unknown right row count.
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    // Scaled by the unknown right row count.
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// A 3-row left input crossed with a 2-row right input yields all 6
    /// combinations, with left columns first.
    #[tokio::test]
    async fn test_join() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 6]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let (columns, batches) = join_collect(left, right, task_ctx).await?;

        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
        let expected = [
            "+----+----+----+----+----+----+",
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
            "+----+----+----+----+----+----+",
            "| 1  | 4  | 7  | 10 | 12 | 14 |",
            "| 1  | 4  | 7  | 11 | 13 | 15 |",
            "| 2  | 5  | 8  | 10 | 12 | 14 |",
            "| 2  | 5  | 8  | 11 | 13 | 15 |",
            "| 3  | 6  | 9  | 10 | 12 | 14 |",
            "| 3  | 6  | 9  | 11 | 13 | 15 |",
            "+----+----+----+----+----+----+",
        ];

        assert_batches_sorted_eq!(expected, &batches);

        Ok(())
    }

    /// With a 100-byte memory pool, buffering the 10-row left input fails, and
    /// the error names the `CrossJoinExec` consumer.
    #[tokio::test]
    async fn test_overallocation() -> Result<()> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(100, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let err = join_collect(left, right, task_ctx).await.unwrap_err();

        assert_contains!(
            err.to_string(),
            "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec"
        );

        Ok(())
    }

    /// Returns the field names of `schema`, in order.
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }
}