1use std::any::Any;
24use std::cmp::Ordering;
25use std::collections::{HashMap, VecDeque};
26use std::fmt::Formatter;
27use std::fs::File;
28use std::io::BufReader;
29use std::mem::size_of;
30use std::ops::Range;
31use std::pin::Pin;
32use std::sync::atomic::AtomicUsize;
33use std::sync::atomic::Ordering::Relaxed;
34use std::sync::Arc;
35use std::task::{Context, Poll};
36
37use crate::execution_plan::{boundedness_from_children, EmissionType};
38use crate::expressions::PhysicalSortExpr;
39use crate::joins::utils::{
40 build_join_schema, check_join_is_valid, estimate_join_statistics,
41 reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
42 JoinOnRef,
43};
44use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
45use crate::projection::{
46 join_allows_pushdown, join_table_borders, new_join_children,
47 physical_to_column_exprs, update_join_on, ProjectionExec,
48};
49use crate::spill::spill_record_batches;
50use crate::{
51 metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
52 ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream,
53 SendableRecordBatchStream, Statistics,
54};
55
56use arrow::array::{types::UInt64Type, *};
57use arrow::compute::{
58 self, concat_batches, filter_record_batch, is_not_null, take, SortOptions,
59};
60use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
61use arrow::error::ArrowError;
62use arrow::ipc::reader::StreamReader;
63use datafusion_common::{
64 exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, JoinSide,
65 JoinType, Result,
66};
67use datafusion_execution::disk_manager::RefCountedTempFile;
68use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
69use datafusion_execution::runtime_env::RuntimeEnv;
70use datafusion_execution::TaskContext;
71use datafusion_physical_expr::equivalence::join_equivalence_properties;
72use datafusion_physical_expr::PhysicalExprRef;
73use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
74
75use futures::{Stream, StreamExt};
76
/// Execution plan that joins two inputs on equi-join keys with a sort-merge
/// algorithm.
///
/// Both inputs are required to be sorted on their join keys with
/// `sort_options` (see `required_input_ordering`) and hash-partitioned on
/// those keys (see `required_input_distribution`).
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
    /// Left input plan
    pub left: Arc<dyn ExecutionPlan>,
    /// Right input plan
    pub right: Arc<dyn ExecutionPlan>,
    /// Pairs of (left key expression, right key expression) to join on
    pub on: JoinOn,
    /// Optional non-equi filter applied to candidate matches
    pub filter: Option<JoinFilter>,
    /// How the join is performed
    pub join_type: JoinType,
    /// Output schema of the join
    schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Sort expressions the left input must be ordered by (built from `on` + `sort_options`)
    left_sort_exprs: LexOrdering,
    /// Sort expressions the right input must be ordered by (built from `on` + `sort_options`)
    right_sort_exprs: LexOrdering,
    /// Sort options of the join keys, one per `on` pair, shared by both sides
    pub sort_options: Vec<SortOptions>,
    /// Whether NULL join keys are treated as equal to each other (presumably;
    /// the key comparison itself is outside this view — confirm)
    pub null_equals_null: bool,
    /// Cached plan properties (equivalences, output partitioning, emission type, boundedness)
    cache: PlanProperties,
}
152
153impl SortMergeJoinExec {
154 pub fn try_new(
159 left: Arc<dyn ExecutionPlan>,
160 right: Arc<dyn ExecutionPlan>,
161 on: JoinOn,
162 filter: Option<JoinFilter>,
163 join_type: JoinType,
164 sort_options: Vec<SortOptions>,
165 null_equals_null: bool,
166 ) -> Result<Self> {
167 let left_schema = left.schema();
168 let right_schema = right.schema();
169
170 if join_type == JoinType::RightSemi {
171 return not_impl_err!(
172 "SortMergeJoinExec does not support JoinType::RightSemi"
173 );
174 }
175
176 check_join_is_valid(&left_schema, &right_schema, &on)?;
177 if sort_options.len() != on.len() {
178 return plan_err!(
179 "Expected number of sort options: {}, actual: {}",
180 on.len(),
181 sort_options.len()
182 );
183 }
184
185 let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
186 .iter()
187 .zip(sort_options.iter())
188 .map(|((l, r), sort_op)| {
189 let left = PhysicalSortExpr {
190 expr: Arc::clone(l),
191 options: *sort_op,
192 };
193 let right = PhysicalSortExpr {
194 expr: Arc::clone(r),
195 options: *sort_op,
196 };
197 (left, right)
198 })
199 .unzip();
200
201 let schema =
202 Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
203 let cache =
204 Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on);
205 Ok(Self {
206 left,
207 right,
208 on,
209 filter,
210 join_type,
211 schema,
212 metrics: ExecutionPlanMetricsSet::new(),
213 left_sort_exprs: LexOrdering::new(left_sort_exprs),
214 right_sort_exprs: LexOrdering::new(right_sort_exprs),
215 sort_options,
216 null_equals_null,
217 cache,
218 })
219 }
220
221 pub fn probe_side(join_type: &JoinType) -> JoinSide {
224 match join_type {
227 JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
228 JoinSide::Right
229 }
230 JoinType::Inner
231 | JoinType::Left
232 | JoinType::Full
233 | JoinType::LeftAnti
234 | JoinType::LeftSemi
235 | JoinType::LeftMark => JoinSide::Left,
236 }
237 }
238
239 fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
241 match join_type {
242 JoinType::Inner => vec![true, false],
243 JoinType::Left
244 | JoinType::LeftSemi
245 | JoinType::LeftAnti
246 | JoinType::LeftMark => vec![true, false],
247 JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
248 vec![false, true]
249 }
250 _ => vec![false, false],
251 }
252 }
253
254 pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
256 &self.on
257 }
258
259 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
261 &self.right
262 }
263
264 pub fn join_type(&self) -> JoinType {
266 self.join_type
267 }
268
269 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
271 &self.left
272 }
273
274 pub fn filter(&self) -> &Option<JoinFilter> {
276 &self.filter
277 }
278
279 pub fn sort_options(&self) -> &[SortOptions] {
281 &self.sort_options
282 }
283
284 pub fn null_equals_null(&self) -> bool {
286 self.null_equals_null
287 }
288
289 fn compute_properties(
291 left: &Arc<dyn ExecutionPlan>,
292 right: &Arc<dyn ExecutionPlan>,
293 schema: SchemaRef,
294 join_type: JoinType,
295 join_on: JoinOnRef,
296 ) -> PlanProperties {
297 let eq_properties = join_equivalence_properties(
299 left.equivalence_properties().clone(),
300 right.equivalence_properties().clone(),
301 &join_type,
302 schema,
303 &Self::maintains_input_order(join_type),
304 Some(Self::probe_side(&join_type)),
305 join_on,
306 );
307
308 let output_partitioning =
309 symmetric_join_output_partitioning(left, right, &join_type);
310
311 PlanProperties::new(
312 eq_properties,
313 output_partitioning,
314 EmissionType::Incremental,
315 boundedness_from_children([left, right]),
316 )
317 }
318
319 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
320 let left = self.left();
321 let right = self.right();
322 let new_join = SortMergeJoinExec::try_new(
323 Arc::clone(right),
324 Arc::clone(left),
325 self.on()
326 .iter()
327 .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
328 .collect::<Vec<_>>(),
329 self.filter().as_ref().map(JoinFilter::swap),
330 self.join_type().swap(),
331 self.sort_options.clone(),
332 self.null_equals_null,
333 )?;
334
335 if matches!(
338 self.join_type(),
339 JoinType::LeftSemi
340 | JoinType::RightSemi
341 | JoinType::LeftAnti
342 | JoinType::RightAnti
343 ) {
344 Ok(Arc::new(new_join))
345 } else {
346 reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
347 }
348 }
349}
350
351impl DisplayAs for SortMergeJoinExec {
352 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
353 match t {
354 DisplayFormatType::Default | DisplayFormatType::Verbose => {
355 let on = self
356 .on
357 .iter()
358 .map(|(c1, c2)| format!("({}, {})", c1, c2))
359 .collect::<Vec<String>>()
360 .join(", ");
361 write!(
362 f,
363 "SortMergeJoin: join_type={:?}, on=[{}]{}",
364 self.join_type,
365 on,
366 self.filter.as_ref().map_or("".to_string(), |f| format!(
367 ", filter={}",
368 f.expression()
369 ))
370 )
371 }
372 }
373 }
374}
375
376impl ExecutionPlan for SortMergeJoinExec {
377 fn name(&self) -> &'static str {
378 "SortMergeJoinExec"
379 }
380
381 fn as_any(&self) -> &dyn Any {
382 self
383 }
384
385 fn properties(&self) -> &PlanProperties {
386 &self.cache
387 }
388
389 fn required_input_distribution(&self) -> Vec<Distribution> {
390 let (left_expr, right_expr) = self
391 .on
392 .iter()
393 .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
394 .unzip();
395 vec![
396 Distribution::HashPartitioned(left_expr),
397 Distribution::HashPartitioned(right_expr),
398 ]
399 }
400
401 fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
402 vec![
403 Some(LexRequirement::from(self.left_sort_exprs.clone())),
404 Some(LexRequirement::from(self.right_sort_exprs.clone())),
405 ]
406 }
407
408 fn maintains_input_order(&self) -> Vec<bool> {
409 Self::maintains_input_order(self.join_type)
410 }
411
412 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
413 vec![&self.left, &self.right]
414 }
415
416 fn with_new_children(
417 self: Arc<Self>,
418 children: Vec<Arc<dyn ExecutionPlan>>,
419 ) -> Result<Arc<dyn ExecutionPlan>> {
420 match &children[..] {
421 [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
422 Arc::clone(left),
423 Arc::clone(right),
424 self.on.clone(),
425 self.filter.clone(),
426 self.join_type,
427 self.sort_options.clone(),
428 self.null_equals_null,
429 )?)),
430 _ => internal_err!("SortMergeJoin wrong number of children"),
431 }
432 }
433
434 fn execute(
435 &self,
436 partition: usize,
437 context: Arc<TaskContext>,
438 ) -> Result<SendableRecordBatchStream> {
439 let left_partitions = self.left.output_partitioning().partition_count();
440 let right_partitions = self.right.output_partitioning().partition_count();
441 if left_partitions != right_partitions {
442 return internal_err!(
443 "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
444 consider using RepartitionExec"
445 );
446 }
447 let (on_left, on_right) = self.on.iter().cloned().unzip();
448 let (streamed, buffered, on_streamed, on_buffered) =
449 if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
450 (
451 Arc::clone(&self.left),
452 Arc::clone(&self.right),
453 on_left,
454 on_right,
455 )
456 } else {
457 (
458 Arc::clone(&self.right),
459 Arc::clone(&self.left),
460 on_right,
461 on_left,
462 )
463 };
464
465 let streamed = streamed.execute(partition, Arc::clone(&context))?;
467 let buffered = buffered.execute(partition, Arc::clone(&context))?;
468
469 let batch_size = context.session_config().batch_size();
471
472 let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
474 .register(context.memory_pool());
475
476 Ok(Box::pin(SortMergeJoinStream::try_new(
478 Arc::clone(&self.schema),
479 self.sort_options.clone(),
480 self.null_equals_null,
481 streamed,
482 buffered,
483 on_streamed,
484 on_buffered,
485 self.filter.clone(),
486 self.join_type,
487 batch_size,
488 SortMergeJoinMetrics::new(partition, &self.metrics),
489 reservation,
490 context.runtime_env(),
491 )?))
492 }
493
494 fn metrics(&self) -> Option<MetricsSet> {
495 Some(self.metrics.clone_inner())
496 }
497
498 fn statistics(&self) -> Result<Statistics> {
499 estimate_join_statistics(
503 Arc::clone(&self.left),
504 Arc::clone(&self.right),
505 self.on.clone(),
506 &self.join_type,
507 &self.schema,
508 )
509 }
510
511 fn try_swapping_with_projection(
515 &self,
516 projection: &ProjectionExec,
517 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
518 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
520 else {
521 return Ok(None);
522 };
523
524 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
525 self.left().schema().fields().len(),
526 &projection_as_columns,
527 );
528
529 if !join_allows_pushdown(
530 &projection_as_columns,
531 &self.schema(),
532 far_right_left_col_ind,
533 far_left_right_col_ind,
534 ) {
535 return Ok(None);
536 }
537
538 let Some(new_on) = update_join_on(
539 &projection_as_columns[0..=far_right_left_col_ind as _],
540 &projection_as_columns[far_left_right_col_ind as _..],
541 self.on(),
542 self.left().schema().fields().len(),
543 ) else {
544 return Ok(None);
545 };
546
547 let (new_left, new_right) = new_join_children(
548 &projection_as_columns,
549 far_right_left_col_ind,
550 far_left_right_col_ind,
551 self.children()[0],
552 self.children()[1],
553 )?;
554
555 Ok(Some(Arc::new(SortMergeJoinExec::try_new(
556 Arc::new(new_left),
557 Arc::new(new_right),
558 new_on,
559 self.filter.clone(),
560 self.join_type,
561 self.sort_options.clone(),
562 self.null_equals_null,
563 )?)))
564 }
565}
566
/// Per-partition metrics of the sort-merge join stream.
// NOTE(review): the reason for `dead_code` is not visible here; presumably some
// fields are only written, never read, by this module — confirm.
#[allow(dead_code)]
struct SortMergeJoinMetrics {
    /// Time spent inside `poll_next` of the join stream
    join_time: metrics::Time,
    /// Number of batches consumed from the streamed and buffered inputs
    input_batches: Count,
    /// Number of rows consumed from the streamed and buffered inputs
    input_rows: Count,
    /// Number of output batches (updated outside this view)
    output_batches: Count,
    /// Number of output rows (updated outside this view)
    output_rows: Count,
    /// Peak size of the memory reservation held for buffered batches
    peak_mem_used: metrics::Gauge,
    /// Number of buffered batches spilled to disk
    spill_count: Count,
    /// Spilled volume, accumulated from each batch's in-memory size estimate
    /// (not the number of bytes actually written to disk)
    spilled_bytes: Count,
    /// Number of rows spilled to disk
    spilled_rows: Count,
}
590
591impl SortMergeJoinMetrics {
592 #[allow(dead_code)]
593 pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
594 let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
595 let input_batches =
596 MetricBuilder::new(metrics).counter("input_batches", partition);
597 let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
598 let output_batches =
599 MetricBuilder::new(metrics).counter("output_batches", partition);
600 let output_rows = MetricBuilder::new(metrics).output_rows(partition);
601 let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition);
602 let spill_count = MetricBuilder::new(metrics).spill_count(partition);
603 let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition);
604 let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
605
606 Self {
607 join_time,
608 input_batches,
609 input_rows,
610 output_batches,
611 output_rows,
612 peak_mem_used,
613 spill_count,
614 spilled_bytes,
615 spilled_rows,
616 }
617 }
618}
619
/// Top-level state of the sort-merge join stream (`poll_next`).
#[derive(Debug, PartialEq, Eq)]
enum SortMergeJoinState {
    /// Decide which side to advance next (based on the last key comparison)
    /// and reset that side's state
    Init,
    /// Poll the streamed row and/or buffered batches until both are ready or exhausted
    Polling,
    /// Join the polled data (`join_partial`) and possibly emit an output batch
    JoinOutput,
    /// Both inputs are exhausted; flush any remaining output, then end the stream
    Exhausted,
}
632
/// State of the streamed side, as driven by `poll_streamed_row`.
#[derive(Debug, PartialEq, Eq)]
enum StreamedState {
    /// Advance to the next row of the current batch, or start polling for a new batch
    Init,
    /// Polling the streamed input for a new non-empty batch
    Polling,
    /// A streamed row is ready to be joined
    Ready,
    /// The streamed input has no more batches
    Exhausted,
}
645
/// State of the buffered side, as driven by `poll_buffered_batches`.
#[derive(Debug, PartialEq, Eq)]
enum BufferedState {
    /// Release fully consumed batches, then start the next key range
    Init,
    /// Polling the first batch of a new key range (nothing is buffered yet)
    PollingFirst,
    /// Extending the current key range over rows/batches with the same join key
    PollingRest,
    /// A complete key range is buffered and ready to be joined
    Ready,
    /// The buffered input has no more batches
    Exhausted,
}
660
/// A run of (streamed row, buffered row) index pairs that all refer to the
/// same buffered batch index (see `StreamedBatch::append_output_pair`).
struct StreamedJoinedChunk {
    /// Index of the buffered batch these pairs refer to. `None` presumably
    /// means the streamed rows are emitted without a buffered partner (confirm
    /// against callers)
    buffered_batch_idx: Option<usize>,
    /// Row indices into the current streamed batch
    streamed_indices: UInt64Builder,
    /// Row indices into the buffered batch; a null entry records "no buffered row"
    buffered_indices: UInt64Builder,
}
671
/// The streamed batch currently being joined, plus its pending output index pairs.
struct StreamedBatch {
    /// The streamed record batch
    pub batch: RecordBatch,
    /// Index of the current streamed row within `batch`
    pub idx: usize,
    /// Join key arrays evaluated over `batch`
    pub join_arrays: Vec<ArrayRef>,
    /// Chunks of (streamed, buffered) index pairs produced so far
    pub output_indices: Vec<StreamedJoinedChunk>,
    /// Buffered batch index of the most recently started chunk in `output_indices`
    pub buffered_batch_idx: Option<usize>,
    /// Set of streamed row indices; usage is outside this view (presumably rows
    /// that passed the join filter — confirm)
    pub join_filter_matched_idxs: HashSet<u64>,
}
692
693impl StreamedBatch {
694 fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
695 let join_arrays = join_arrays(&batch, on_column);
696 StreamedBatch {
697 batch,
698 idx: 0,
699 join_arrays,
700 output_indices: vec![],
701 buffered_batch_idx: None,
702 join_filter_matched_idxs: HashSet::new(),
703 }
704 }
705
706 fn new_empty(schema: SchemaRef) -> Self {
707 StreamedBatch {
708 batch: RecordBatch::new_empty(schema),
709 idx: 0,
710 join_arrays: vec![],
711 output_indices: vec![],
712 buffered_batch_idx: None,
713 join_filter_matched_idxs: HashSet::new(),
714 }
715 }
716
717 fn append_output_pair(
720 &mut self,
721 buffered_batch_idx: Option<usize>,
722 buffered_idx: Option<usize>,
723 ) {
724 if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
727 {
728 self.output_indices.push(StreamedJoinedChunk {
729 buffered_batch_idx,
730 streamed_indices: UInt64Builder::with_capacity(1),
731 buffered_indices: UInt64Builder::with_capacity(1),
732 });
733 self.buffered_batch_idx = buffered_batch_idx;
734 };
735 let current_chunk = self.output_indices.last_mut().unwrap();
736
737 current_chunk.streamed_indices.append_value(self.idx as u64);
739 if let Some(idx) = buffered_idx {
740 current_chunk.buffered_indices.append_value(idx as u64);
741 } else {
742 current_chunk.buffered_indices.append_null();
743 }
744 }
745}
746
/// A buffered-side batch together with the row range that shares the current join key.
#[derive(Debug)]
struct BufferedBatch {
    /// The record batch; `None` after it has been spilled to disk
    pub batch: Option<RecordBatch>,
    /// Rows of this batch whose join key equals the current key range
    pub range: Range<usize>,
    /// Join key arrays evaluated over the batch
    pub join_arrays: Vec<ArrayRef>,
    /// Buffered row indices; usage is outside this view (presumably rows
    /// emitted null-joined — confirm)
    pub null_joined: Vec<usize>,
    /// Estimated in-memory size, used to grow/shrink the memory reservation
    pub size_estimation: usize,
    /// Map keyed by buffered row index; usage is outside this view
    /// (presumably tracks rows rejected by the join filter — confirm)
    pub join_filter_not_matched_map: HashMap<u64, bool>,
    /// Number of rows in the batch. Kept separately so it stays available
    /// after `batch` has been spilled
    pub num_rows: usize,
    /// Temp file holding the batch when it was spilled; `None` while in memory
    pub spill_file: Option<RefCountedTempFile>,
}
775
776impl BufferedBatch {
777 fn new(
778 batch: RecordBatch,
779 range: Range<usize>,
780 on_column: &[PhysicalExprRef],
781 ) -> Self {
782 let join_arrays = join_arrays(&batch, on_column);
783
784 let size_estimation = batch.get_array_memory_size()
791 + join_arrays
792 .iter()
793 .map(|arr| arr.get_array_memory_size())
794 .sum::<usize>()
795 + batch.num_rows().next_power_of_two() * size_of::<usize>()
796 + size_of::<Range<usize>>()
797 + size_of::<usize>();
798
799 let num_rows = batch.num_rows();
800 BufferedBatch {
801 batch: Some(batch),
802 range,
803 join_arrays,
804 null_joined: vec![],
805 size_estimation,
806 join_filter_not_matched_map: HashMap::new(),
807 num_rows,
808 spill_file: None,
809 }
810 }
811}
812
/// Stream that produces the sort-merge join output for one partition.
struct SortMergeJoinStream {
    /// Current state of the join state machine
    pub state: SortMergeJoinState,
    /// Output schema of the join
    pub schema: SchemaRef,
    /// Sort options of the join keys
    pub sort_options: Vec<SortOptions>,
    /// Whether NULL keys compare equal (see `SortMergeJoinExec::null_equals_null`)
    pub null_equals_null: bool,
    /// Schema of the streamed input
    pub streamed_schema: SchemaRef,
    /// Schema of the buffered input
    pub buffered_schema: SchemaRef,
    /// Streamed input stream
    pub streamed: SendableRecordBatchStream,
    /// Buffered input stream
    pub buffered: SendableRecordBatchStream,
    /// Streamed batch currently being joined
    pub streamed_batch: StreamedBatch,
    /// Buffered batches for the current key range
    pub buffered_data: BufferedData,
    /// Flag on the streamed side; reset to `false` when advancing to a new
    /// streamed row (set outside this view)
    pub streamed_joined: bool,
    /// Flag on the buffered side; reset to `false` when advancing to a new
    /// buffered key range (set outside this view)
    pub buffered_joined: bool,
    /// State of the streamed side
    pub streamed_state: StreamedState,
    /// State of the buffered side
    pub buffered_state: BufferedState,
    /// Result of the last streamed-vs-buffered key comparison
    pub current_ordering: Ordering,
    /// Join key expressions of the streamed side
    pub on_streamed: Vec<PhysicalExprRef>,
    /// Join key expressions of the buffered side
    pub on_buffered: Vec<PhysicalExprRef>,
    /// Optional join filter
    pub filter: Option<JoinFilter>,
    /// Joined batches staged before being filtered and/or emitted
    pub staging_output_record_batches: JoinedRecordBatches,
    /// Output accumulated for filtered joins until it reaches `batch_size`
    pub output: RecordBatch,
    /// Size counter compared against `batch_size` to decide when to emit
    /// (maintained outside this view)
    pub output_size: usize,
    /// Target number of rows per output batch
    pub batch_size: usize,
    /// Join type
    pub join_type: JoinType,
    /// Join metrics
    pub join_metrics: SortMergeJoinMetrics,
    /// Memory reservation for buffered batches
    pub reservation: MemoryReservation,
    /// Runtime environment; its disk manager is used for spilling
    pub runtime_env: Arc<RuntimeEnv>,
    /// Number of non-empty streamed batches received so far
    pub streamed_batch_counter: AtomicUsize,
}
874
/// Joined output staged before filtering/emission, with per-row bookkeeping
/// consumed by `get_corrected_filter_mask`.
struct JoinedRecordBatches {
    /// Joined record batches
    pub batches: Vec<RecordBatch>,
    /// Join-filter result per staged row
    pub filter_mask: BooleanBuilder,
    /// Per-row index; rows with equal (batch id, row index) form one group
    /// (presumably the streamed row index — confirm)
    pub row_indices: UInt64Builder,
    /// Per-row batch identifier paired with `row_indices`
    pub batch_ids: Vec<usize>,
}
888
889impl JoinedRecordBatches {
890 fn clear(&mut self) {
891 self.batches.clear();
892 self.batch_ids.clear();
893 self.filter_mask = BooleanBuilder::new();
894 self.row_indices = UInt64Builder::new();
895 }
896}
897impl RecordBatchStream for SortMergeJoinStream {
898 fn schema(&self) -> SchemaRef {
899 Arc::clone(&self.schema)
900 }
901}
902
903#[inline(always)]
908fn last_index_for_row(
909 row_index: usize,
910 indices: &UInt64Array,
911 batch_ids: &[usize],
912 indices_len: usize,
913) -> bool {
914 row_index == indices_len - 1
915 || batch_ids[row_index] != batch_ids[row_index + 1]
916 || indices.value(row_index) != indices.value(row_index + 1)
917}
918
919fn get_corrected_filter_mask(
925 join_type: JoinType,
926 row_indices: &UInt64Array,
927 batch_ids: &[usize],
928 filter_mask: &BooleanArray,
929 expected_size: usize,
930) -> Option<BooleanArray> {
931 let row_indices_length = row_indices.len();
932 let mut corrected_mask: BooleanBuilder =
933 BooleanBuilder::with_capacity(row_indices_length);
934 let mut seen_true = false;
935
936 match join_type {
937 JoinType::Left | JoinType::Right => {
938 for i in 0..row_indices_length {
939 let last_index =
940 last_index_for_row(i, row_indices, batch_ids, row_indices_length);
941 if filter_mask.value(i) {
942 seen_true = true;
943 corrected_mask.append_value(true);
944 } else if seen_true || !filter_mask.value(i) && !last_index {
945 corrected_mask.append_null(); } else {
947 corrected_mask.append_value(false); }
949
950 if last_index {
951 seen_true = false;
952 }
953 }
954
955 corrected_mask.append_n(expected_size - corrected_mask.len(), false);
957 Some(corrected_mask.finish())
958 }
959 JoinType::LeftMark => {
960 for i in 0..row_indices_length {
961 let last_index =
962 last_index_for_row(i, row_indices, batch_ids, row_indices_length);
963 if filter_mask.value(i) && !seen_true {
964 seen_true = true;
965 corrected_mask.append_value(true);
966 } else if seen_true || !filter_mask.value(i) && !last_index {
967 corrected_mask.append_null(); } else {
969 corrected_mask.append_value(false); }
971
972 if last_index {
973 seen_true = false;
974 }
975 }
976
977 corrected_mask.append_n(expected_size - corrected_mask.len(), false);
979 Some(corrected_mask.finish())
980 }
981 JoinType::LeftSemi => {
982 for i in 0..row_indices_length {
983 let last_index =
984 last_index_for_row(i, row_indices, batch_ids, row_indices_length);
985 if filter_mask.value(i) && !seen_true {
986 seen_true = true;
987 corrected_mask.append_value(true);
988 } else {
989 corrected_mask.append_null(); }
991
992 if last_index {
993 seen_true = false;
994 }
995 }
996
997 Some(corrected_mask.finish())
998 }
999 JoinType::LeftAnti | JoinType::RightAnti => {
1000 for i in 0..row_indices_length {
1001 let last_index =
1002 last_index_for_row(i, row_indices, batch_ids, row_indices_length);
1003
1004 if filter_mask.value(i) {
1005 seen_true = true;
1006 }
1007
1008 if last_index {
1009 if !seen_true {
1010 corrected_mask.append_value(true);
1011 } else {
1012 corrected_mask.append_null();
1013 }
1014
1015 seen_true = false;
1016 } else {
1017 corrected_mask.append_null();
1018 }
1019 }
1020 corrected_mask.append_n(expected_size - corrected_mask.len(), true);
1023 Some(corrected_mask.finish())
1024 }
1025 JoinType::Full => {
1026 let mut mask: Vec<Option<bool>> = vec![Some(true); row_indices_length];
1027 let mut last_true_idx = 0;
1028 let mut first_row_idx = 0;
1029 let mut seen_false = false;
1030
1031 for i in 0..row_indices_length {
1032 let last_index =
1033 last_index_for_row(i, row_indices, batch_ids, row_indices_length);
1034 let val = filter_mask.value(i);
1035 let is_null = filter_mask.is_null(i);
1036
1037 if val {
1038 if !seen_true {
1040 last_true_idx = i;
1041 }
1042 seen_true = true;
1043 }
1044
1045 if is_null || val {
1046 mask[i] = Some(true);
1047 } else if !is_null && !val && (seen_true || seen_false) {
1048 mask[i] = None;
1049 } else {
1050 mask[i] = Some(false);
1051 }
1052
1053 if !is_null && !val {
1054 seen_false = true;
1055 }
1056
1057 if last_index {
1058 if seen_true {
1061 #[allow(clippy::needless_range_loop)]
1062 for j in first_row_idx..last_true_idx {
1063 mask[j] = None;
1064 }
1065 }
1066
1067 seen_true = false;
1068 seen_false = false;
1069 last_true_idx = 0;
1070 first_row_idx = i + 1;
1071 }
1072 }
1073
1074 Some(BooleanArray::from(mask))
1075 }
1076 _ => None,
1078 }
1079}
1080
impl Stream for SortMergeJoinStream {
    type Item = Result<RecordBatch>;

    /// Drives the join state machine (`Init` -> `Polling` -> `JoinOutput` ->
    /// ... -> `Exhausted`). It returns when an output batch is ready, an input
    /// is pending, or the join is finished.
    ///
    /// All time spent in this call, including calls that return `Pending`, is
    /// recorded in the `join_time` metric.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let join_time = self.join_metrics.join_time.clone();
        let _timer = join_time.timer();
        loop {
            match &self.state {
                SortMergeJoinState::Init => {
                    let streamed_exhausted =
                        self.streamed_state == StreamedState::Exhausted;
                    let buffered_exhausted =
                        self.buffered_state == BufferedState::Exhausted;
                    self.state = if streamed_exhausted && buffered_exhausted {
                        SortMergeJoinState::Exhausted
                    } else {
                        // `current_ordering` is the last streamed-vs-buffered key
                        // comparison. Less/Equal advances the streamed side;
                        // Greater advances the buffered side.
                        match self.current_ordering {
                            Ordering::Less | Ordering::Equal => {
                                if !streamed_exhausted {
                                    // Filtered joins of these types do not emit
                                    // from `JoinOutput`. Their staged output is
                                    // filtered here, before moving to the next
                                    // streamed row, and accumulated in `self.output`.
                                    if self.filter.is_some()
                                        && matches!(
                                            self.join_type,
                                            JoinType::Left
                                                | JoinType::LeftSemi
                                                | JoinType::LeftMark
                                                | JoinType::Right
                                                | JoinType::LeftAnti
                                                | JoinType::RightAnti
                                                | JoinType::Full
                                        )
                                    {
                                        self.freeze_all()?;

                                        if !self
                                            .staging_output_record_batches
                                            .batches
                                            .is_empty()
                                        {
                                            let out_filtered_batch =
                                                self.filter_joined_batch()?;

                                            self.output = concat_batches(
                                                &self.schema(),
                                                vec![&self.output, &out_filtered_batch],
                                            )?;

                                            // Emit once at least `batch_size` rows have accumulated.
                                            if self.output.num_rows() >= self.batch_size {
                                                let record_batch = std::mem::replace(
                                                    &mut self.output,
                                                    RecordBatch::new_empty(
                                                        out_filtered_batch.schema(),
                                                    ),
                                                );
                                                return Poll::Ready(Some(Ok(
                                                    record_batch,
                                                )));
                                            }
                                        }
                                    }

                                    self.streamed_joined = false;
                                    self.streamed_state = StreamedState::Init;
                                }
                            }
                            Ordering::Greater => {
                                if !buffered_exhausted {
                                    self.buffered_joined = false;
                                    self.buffered_state = BufferedState::Init;
                                }
                            }
                        }
                        SortMergeJoinState::Polling
                    };
                }
                SortMergeJoinState::Polling => {
                    // Advance each side that is neither ready nor exhausted.
                    if ![StreamedState::Exhausted, StreamedState::Ready]
                        .contains(&self.streamed_state)
                    {
                        match self.poll_streamed_row(cx)? {
                            Poll::Ready(_) => {}
                            Poll::Pending => return Poll::Pending,
                        }
                    }

                    if ![BufferedState::Exhausted, BufferedState::Ready]
                        .contains(&self.buffered_state)
                    {
                        match self.poll_buffered_batches(cx)? {
                            Poll::Ready(_) => {}
                            Poll::Pending => return Poll::Pending,
                        }
                    }
                    let streamed_exhausted =
                        self.streamed_state == StreamedState::Exhausted;
                    let buffered_exhausted =
                        self.buffered_state == BufferedState::Exhausted;
                    if streamed_exhausted && buffered_exhausted {
                        self.state = SortMergeJoinState::Exhausted;
                        continue;
                    }
                    self.current_ordering = self.compare_streamed_buffered()?;
                    self.state = SortMergeJoinState::JoinOutput;
                }
                SortMergeJoinState::JoinOutput => {
                    self.join_partial()?;

                    if self.output_size < self.batch_size {
                        // Not enough output yet: move on once the buffered key
                        // range has been fully scanned.
                        if self.buffered_data.scanning_finished() {
                            self.buffered_data.scanning_reset();
                            self.state = SortMergeJoinState::Init;
                        }
                    } else {
                        self.freeze_all()?;
                        if !self.staging_output_record_batches.batches.is_empty() {
                            let record_batch = self.output_record_batch_and_reset()?;
                            // Filtered joins of these types are emitted from
                            // `Init`/`Exhausted` after filtering, so the batch is
                            // not returned here. NOTE(review): presumably
                            // `output_record_batch_and_reset` keeps the staged
                            // batches for these join types — confirm.
                            if self.filter.is_some()
                                && matches!(
                                    self.join_type,
                                    JoinType::Left
                                        | JoinType::LeftSemi
                                        | JoinType::Right
                                        | JoinType::LeftAnti
                                        | JoinType::RightAnti
                                        | JoinType::LeftMark
                                        | JoinType::Full
                                )
                            {
                                continue;
                            }

                            return Poll::Ready(Some(Ok(record_batch)));
                        }
                        // NOTE(review): this returns `Pending` without any child
                        // stream having returned `Pending` in this pass, so no
                        // waker is registered. If reachable, the task may never
                        // be woken again — confirm this path is unreachable.
                        return Poll::Pending;
                    }
                }
                SortMergeJoinState::Exhausted => {
                    // Flush: staged batches first, then accumulated `output`,
                    // then end the stream.
                    self.freeze_all()?;

                    if !self.staging_output_record_batches.batches.is_empty() {
                        if self.filter.is_some()
                            && matches!(
                                self.join_type,
                                JoinType::Left
                                    | JoinType::LeftSemi
                                    | JoinType::Right
                                    | JoinType::LeftAnti
                                    | JoinType::RightAnti
                                    | JoinType::Full
                                    | JoinType::LeftMark
                            )
                        {
                            let record_batch = self.filter_joined_batch()?;
                            return Poll::Ready(Some(Ok(record_batch)));
                        } else {
                            let record_batch = self.output_record_batch_and_reset()?;
                            return Poll::Ready(Some(Ok(record_batch)));
                        }
                    } else if self.output.num_rows() > 0 {
                        let schema = self.output.schema();
                        let record_batch = std::mem::replace(
                            &mut self.output,
                            RecordBatch::new_empty(schema),
                        );
                        return Poll::Ready(Some(Ok(record_batch)));
                    } else {
                        return Poll::Ready(None);
                    }
                }
            }
        }
    }
}
1268
1269impl SortMergeJoinStream {
1270 #[allow(clippy::too_many_arguments)]
1271 pub fn try_new(
1272 schema: SchemaRef,
1273 sort_options: Vec<SortOptions>,
1274 null_equals_null: bool,
1275 streamed: SendableRecordBatchStream,
1276 buffered: SendableRecordBatchStream,
1277 on_streamed: Vec<Arc<dyn PhysicalExpr>>,
1278 on_buffered: Vec<Arc<dyn PhysicalExpr>>,
1279 filter: Option<JoinFilter>,
1280 join_type: JoinType,
1281 batch_size: usize,
1282 join_metrics: SortMergeJoinMetrics,
1283 reservation: MemoryReservation,
1284 runtime_env: Arc<RuntimeEnv>,
1285 ) -> Result<Self> {
1286 let streamed_schema = streamed.schema();
1287 let buffered_schema = buffered.schema();
1288 Ok(Self {
1289 state: SortMergeJoinState::Init,
1290 sort_options,
1291 null_equals_null,
1292 schema: Arc::clone(&schema),
1293 streamed_schema: Arc::clone(&streamed_schema),
1294 buffered_schema,
1295 streamed,
1296 buffered,
1297 streamed_batch: StreamedBatch::new_empty(streamed_schema),
1298 buffered_data: BufferedData::default(),
1299 streamed_joined: false,
1300 buffered_joined: false,
1301 streamed_state: StreamedState::Init,
1302 buffered_state: BufferedState::Init,
1303 current_ordering: Ordering::Equal,
1304 on_streamed,
1305 on_buffered,
1306 filter,
1307 staging_output_record_batches: JoinedRecordBatches {
1308 batches: vec![],
1309 filter_mask: BooleanBuilder::new(),
1310 row_indices: UInt64Builder::new(),
1311 batch_ids: vec![],
1312 },
1313 output: RecordBatch::new_empty(schema),
1314 output_size: 0,
1315 batch_size,
1316 join_type,
1317 join_metrics,
1318 reservation,
1319 runtime_env,
1320 streamed_batch_counter: AtomicUsize::new(0),
1321 })
1322 }
1323
1324 fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
1326 loop {
1327 match &self.streamed_state {
1328 StreamedState::Init => {
1329 if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
1330 {
1331 self.streamed_batch.idx += 1;
1332 self.streamed_state = StreamedState::Ready;
1333 return Poll::Ready(Some(Ok(())));
1334 } else {
1335 self.streamed_state = StreamedState::Polling;
1336 }
1337 }
1338 StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
1339 Poll::Pending => {
1340 return Poll::Pending;
1341 }
1342 Poll::Ready(None) => {
1343 self.streamed_state = StreamedState::Exhausted;
1344 }
1345 Poll::Ready(Some(batch)) => {
1346 if batch.num_rows() > 0 {
1347 self.freeze_streamed()?;
1348 self.join_metrics.input_batches.add(1);
1349 self.join_metrics.input_rows.add(batch.num_rows());
1350 self.streamed_batch =
1351 StreamedBatch::new(batch, &self.on_streamed);
1352 self.streamed_batch_counter
1355 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1356 self.streamed_state = StreamedState::Ready;
1357 }
1358 }
1359 },
1360 StreamedState::Ready => {
1361 return Poll::Ready(Some(Ok(())));
1362 }
1363 StreamedState::Exhausted => {
1364 return Poll::Ready(None);
1365 }
1366 }
1367 }
1368 }
1369
1370 fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> {
1371 if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() {
1373 self.reservation
1374 .try_shrink(buffered_batch.size_estimation)?;
1375 }
1376
1377 Ok(())
1378 }
1379
1380 fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
1381 match self.reservation.try_grow(buffered_batch.size_estimation) {
1382 Ok(_) => {
1383 self.join_metrics
1384 .peak_mem_used
1385 .set_max(self.reservation.size());
1386 Ok(())
1387 }
1388 Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
1389 let spill_file = self
1391 .runtime_env
1392 .disk_manager
1393 .create_tmp_file("sort_merge_join_buffered_spill")?;
1394
1395 if let Some(batch) = buffered_batch.batch {
1396 spill_record_batches(
1397 &[batch],
1398 spill_file.path().into(),
1399 Arc::clone(&self.buffered_schema),
1400 )?;
1401 buffered_batch.spill_file = Some(spill_file);
1402 buffered_batch.batch = None;
1403
1404 self.join_metrics.spill_count.add(1);
1406 self.join_metrics
1407 .spilled_bytes
1408 .add(buffered_batch.size_estimation);
1409 self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
1410 Ok(())
1411 } else {
1412 internal_err!("Buffered batch has empty body")
1413 }
1414 }
1415 Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
1416 }?;
1417
1418 self.buffered_data.batches.push_back(buffered_batch);
1419 Ok(())
1420 }
1421
/// Drives the buffered-side state machine.
///
/// It runs until the buffered queue holds a complete run of rows that share one
/// join key (state `Ready`), until the buffered input is exhausted, or until the
/// buffered stream is not ready yet.
///
/// Returns:
/// * `Poll::Ready(Some(Ok(())))` when a key group is ready;
/// * `Poll::Ready(None)` when the buffered input ends while the queue is empty;
/// * `Poll::Pending` when the buffered stream is pending.
///
/// Errors from the buffered stream, from freezing output and from memory
/// accounting are propagated.
fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
    loop {
        match &self.buffered_state {
            BufferedState::Init => {
                // Dequeue head batches whose rows have all been consumed
                // (`range.end == num_rows`). Flush the output that may still
                // reference them, then release their memory.
                while !self.buffered_data.batches.is_empty() {
                    let head_batch = self.buffered_data.head_batch();
                    if head_batch.range.end == head_batch.num_rows {
                        self.freeze_dequeuing_buffered()?;
                        if let Some(mut buffered_batch) =
                            self.buffered_data.batches.pop_front()
                        {
                            self.produce_buffered_not_matched(&mut buffered_batch)?;
                            self.free_reservation(buffered_batch)?;
                        }
                    } else {
                        break;
                    }
                }
                if self.buffered_data.batches.is_empty() {
                    self.buffered_state = BufferedState::PollingFirst;
                } else {
                    // Begin a new key group in the tail batch at the row right
                    // after its current range, initially one row wide.
                    let tail_batch = self.buffered_data.tail_batch_mut();
                    tail_batch.range.start = tail_batch.range.end;
                    tail_batch.range.end += 1;
                    self.buffered_state = BufferedState::PollingRest;
                }
            }
            // Wait for the first non-empty buffered batch. Its first row
            // (range 0..1) starts the key group.
            BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
                Poll::Pending => {
                    return Poll::Pending;
                }
                Poll::Ready(None) => {
                    self.buffered_state = BufferedState::Exhausted;
                    return Poll::Ready(None);
                }
                Poll::Ready(Some(batch)) => {
                    self.join_metrics.input_batches.add(1);
                    self.join_metrics.input_rows.add(batch.num_rows());

                    // Empty batches are counted in the metrics but otherwise
                    // ignored; the loop polls again.
                    if batch.num_rows() > 0 {
                        let buffered_batch =
                            BufferedBatch::new(batch, 0..1, &self.on_buffered);

                        self.allocate_reservation(buffered_batch)?;
                        self.buffered_state = BufferedState::PollingRest;
                    }
                }
            },
            BufferedState::PollingRest => {
                if self.buffered_data.tail_batch().range.end
                    < self.buffered_data.tail_batch().num_rows
                {
                    // Grow the group row by row while the next row's key
                    // equals the group's first row (the head batch at
                    // `range.start`).
                    while self.buffered_data.tail_batch().range.end
                        < self.buffered_data.tail_batch().num_rows
                    {
                        if is_join_arrays_equal(
                            &self.buffered_data.head_batch().join_arrays,
                            self.buffered_data.head_batch().range.start,
                            &self.buffered_data.tail_batch().join_arrays,
                            self.buffered_data.tail_batch().range.end,
                        )? {
                            self.buffered_data.tail_batch_mut().range.end += 1;
                        } else {
                            // A different key ends the group.
                            self.buffered_state = BufferedState::Ready;
                            return Poll::Ready(Some(Ok(())));
                        }
                    }
                } else {
                    // The tail batch is fully scanned, so the group may continue
                    // in the next batch. Append that batch with an empty range
                    // (0..0) and keep scanning it.
                    match self.buffered.poll_next_unpin(cx)? {
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                        Poll::Ready(None) => {
                            // End of input also ends the current group.
                            self.buffered_state = BufferedState::Ready;
                        }
                        Poll::Ready(Some(batch)) => {
                            self.join_metrics.input_batches.add(1);
                            self.join_metrics.input_rows.add(batch.num_rows());
                            if batch.num_rows() > 0 {
                                let buffered_batch = BufferedBatch::new(
                                    batch,
                                    0..0,
                                    &self.on_buffered,
                                );
                                self.allocate_reservation(buffered_batch)?;
                            }
                        }
                    }
                }
            }
            BufferedState::Ready => {
                return Poll::Ready(Some(Ok(())));
            }
            BufferedState::Exhausted => {
                return Poll::Ready(None);
            }
        }
    }
}
1527
1528 fn compare_streamed_buffered(&self) -> Result<Ordering> {
1530 if self.streamed_state == StreamedState::Exhausted {
1531 return Ok(Ordering::Greater);
1532 }
1533 if !self.buffered_data.has_buffered_rows() {
1534 return Ok(Ordering::Less);
1535 }
1536
1537 compare_join_arrays(
1538 &self.streamed_batch.join_arrays,
1539 self.streamed_batch.idx,
1540 &self.buffered_data.head_batch().join_arrays,
1541 self.buffered_data.head_batch().range.start,
1542 &self.sort_options,
1543 self.null_equals_null,
1544 )
1545 }
1546
/// Records output index pairs for the current streamed row, based on
/// `self.current_ordering` between that row and the current buffered key group.
///
/// Pairs go into `self.streamed_batch` (or into the buffered batch's
/// `null_joined` list), and `output_size` is incremented per pair. Scanning of the
/// buffered group stops once `output_size` reaches `batch_size`, and resumes on a
/// later call.
///
/// * `Less`: no buffered match. Outer/anti/mark join types emit the streamed row
///   once, without a buffered row index.
/// * `Equal`: keys match.
///   - Inner/Left/Right/Full pair the streamed row with every buffered row in the
///     group.
///   - LeftSemi/LeftMark emit it once; with a filter they pair it with each
///     buffered row instead.
///   - Anti joins with a filter pair it with each buffered row.
/// * `Greater`: no streamed match. `Full` records the group's buffered rows as
///   null-joined.
fn join_partial(&mut self) -> Result<()> {
    // Whether the streamed row takes part in the output.
    let mut join_streamed = false;
    // Whether the rows of the buffered group are scanned / emitted.
    let mut join_buffered = false;
    // LeftMark only: mark the streamed row as matched.
    let mut mark_row_as_match = false;

    match self.current_ordering {
        Ordering::Less => {
            // NOTE(review): `RightSemi` is listed here but has no `Equal` handling
            // below; confirm whether SMJ is ever planned with `RightSemi`.
            if matches!(
                self.join_type,
                JoinType::Left
                    | JoinType::Right
                    | JoinType::RightSemi
                    | JoinType::Full
                    | JoinType::LeftAnti
                    | JoinType::RightAnti
                    | JoinType::LeftMark
            ) {
                join_streamed = !self.streamed_joined;
            }
        }
        Ordering::Equal => {
            if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) {
                mark_row_as_match = matches!(self.join_type, JoinType::LeftMark);
                if self.filter.is_some() {
                    // With a filter, keep pairing the streamed row with buffered
                    // rows until a pair has passed the filter. Rows already
                    // recorded in `join_filter_matched_idxs` are skipped.
                    join_streamed = !self
                        .streamed_batch
                        .join_filter_matched_idxs
                        .contains(&(self.streamed_batch.idx as u64))
                        && !self.streamed_joined;
                    join_buffered = join_streamed;
                } else {
                    join_streamed = !self.streamed_joined;
                }
            }
            if matches!(
                self.join_type,
                JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
            ) {
                join_streamed = true;
                join_buffered = true;
            };

            if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti)
                && self.filter.is_some()
            {
                join_streamed = !self.streamed_joined;
                join_buffered = join_streamed;
            }
        }
        Ordering::Greater => {
            if matches!(self.join_type, JoinType::Full) {
                join_buffered = !self.buffered_joined;
            };
        }
    }
    if !join_streamed && !join_buffered {
        // Nothing to emit for this row; skip the rest of the buffered group.
        self.buffered_data.scanning_finish();
        return Ok(());
    }

    if join_buffered {
        // Walk the buffered group, bounded by the output batch size.
        while !self.buffered_data.scanning_finished()
            && self.output_size < self.batch_size
        {
            let scanning_idx = self.buffered_data.scanning_idx();
            if join_streamed {
                self.streamed_batch.append_output_pair(
                    Some(self.buffered_data.scanning_batch_idx),
                    Some(scanning_idx),
                );
            } else {
                // Buffered row without a streamed partner (Full, `Greater`).
                self.buffered_data
                    .scanning_batch_mut()
                    .null_joined
                    .push(scanning_idx);
            }
            self.output_size += 1;
            self.buffered_data.scanning_advance();

            if self.buffered_data.scanning_finished() {
                self.streamed_joined = join_streamed;
                self.buffered_joined = true;
            }
        }
    } else {
        // Emit the streamed row once. The buffered row index is `Some(0)` only
        // when marking a LeftMark match, otherwise `None`.
        let scanning_batch_idx = if self.buffered_data.scanning_finished() {
            None
        } else {
            Some(self.buffered_data.scanning_batch_idx)
        };
        let scanning_idx = mark_row_as_match.then_some(0);

        self.streamed_batch
            .append_output_pair(scanning_batch_idx, scanning_idx);
        self.output_size += 1;
        self.buffered_data.scanning_finish();
        self.streamed_joined = true;
    }
    Ok(())
}
1665
1666 fn freeze_all(&mut self) -> Result<()> {
1667 self.freeze_buffered(self.buffered_data.batches.len())?;
1668 self.freeze_streamed()?;
1669 Ok(())
1670 }
1671
1672 fn freeze_dequeuing_buffered(&mut self) -> Result<()> {
1677 self.freeze_streamed()?;
1678 self.freeze_buffered(1)?;
1680 Ok(())
1681 }
1682
1683 fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> {
1689 if !matches!(self.join_type, JoinType::Full) {
1690 return Ok(());
1691 }
1692 for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
1693 let buffered_indices = UInt64Array::from_iter_values(
1694 buffered_batch.null_joined.iter().map(|&index| index as u64),
1695 );
1696 if let Some(record_batch) = produce_buffered_null_batch(
1697 &self.schema,
1698 &self.streamed_schema,
1699 &buffered_indices,
1700 buffered_batch,
1701 )? {
1702 let num_rows = record_batch.num_rows();
1703 self.staging_output_record_batches
1704 .filter_mask
1705 .append_nulls(num_rows);
1706 self.staging_output_record_batches
1707 .row_indices
1708 .append_nulls(num_rows);
1709 self.staging_output_record_batches.batch_ids.resize(
1710 self.staging_output_record_batches.batch_ids.len() + num_rows,
1711 0,
1712 );
1713
1714 self.staging_output_record_batches
1715 .batches
1716 .push(record_batch);
1717 }
1718 buffered_batch.null_joined.clear();
1719 }
1720 Ok(())
1721 }
1722
1723 fn produce_buffered_not_matched(
1724 &mut self,
1725 buffered_batch: &mut BufferedBatch,
1726 ) -> Result<()> {
1727 if !matches!(self.join_type, JoinType::Full) {
1728 return Ok(());
1729 }
1730
1731 let not_matched_buffered_indices = buffered_batch
1734 .join_filter_not_matched_map
1735 .iter()
1736 .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
1737 .collect::<Vec<_>>();
1738
1739 let buffered_indices =
1740 UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied());
1741
1742 if let Some(record_batch) = produce_buffered_null_batch(
1743 &self.schema,
1744 &self.streamed_schema,
1745 &buffered_indices,
1746 buffered_batch,
1747 )? {
1748 let num_rows = record_batch.num_rows();
1749
1750 self.staging_output_record_batches
1751 .filter_mask
1752 .append_nulls(num_rows);
1753 self.staging_output_record_batches
1754 .row_indices
1755 .append_nulls(num_rows);
1756 self.staging_output_record_batches.batch_ids.resize(
1757 self.staging_output_record_batches.batch_ids.len() + num_rows,
1758 0,
1759 );
1760 self.staging_output_record_batches
1761 .batches
1762 .push(record_batch);
1763 }
1764 buffered_batch.join_filter_not_matched_map.clear();
1765
1766 Ok(())
1767 }
1768
1769 fn freeze_streamed(&mut self) -> Result<()> {
1772 for chunk in self.streamed_batch.output_indices.iter_mut() {
1773 let left_indices = chunk.streamed_indices.finish();
1775
1776 if left_indices.is_empty() {
1777 continue;
1778 }
1779
1780 let mut left_columns = self
1781 .streamed_batch
1782 .batch
1783 .columns()
1784 .iter()
1785 .map(|column| take(column, &left_indices, None))
1786 .collect::<Result<Vec<_>, ArrowError>>()?;
1787
1788 let right_indices: UInt64Array = chunk.buffered_indices.finish();
1790 let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) {
1791 vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
1792 } else if matches!(
1793 self.join_type,
1794 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti
1795 ) {
1796 vec![]
1797 } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1798 fetch_right_columns_by_idxs(
1799 &self.buffered_data,
1800 buffered_idx,
1801 &right_indices,
1802 )?
1803 } else {
1804 create_unmatched_columns(
1807 self.join_type,
1808 &self.buffered_schema,
1809 right_indices.len(),
1810 )
1811 };
1812
1813 let filter_columns = if chunk.buffered_batch_idx.is_some() {
1816 if !matches!(self.join_type, JoinType::Right) {
1817 if matches!(
1818 self.join_type,
1819 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
1820 ) {
1821 let right_cols = fetch_right_columns_by_idxs(
1822 &self.buffered_data,
1823 chunk.buffered_batch_idx.unwrap(),
1824 &right_indices,
1825 )?;
1826
1827 get_filter_column(&self.filter, &left_columns, &right_cols)
1828 } else if matches!(self.join_type, JoinType::RightAnti) {
1829 let right_cols = fetch_right_columns_by_idxs(
1830 &self.buffered_data,
1831 chunk.buffered_batch_idx.unwrap(),
1832 &right_indices,
1833 )?;
1834
1835 get_filter_column(&self.filter, &right_cols, &left_columns)
1836 } else {
1837 get_filter_column(&self.filter, &left_columns, &right_columns)
1838 }
1839 } else {
1840 get_filter_column(&self.filter, &right_columns, &left_columns)
1841 }
1842 } else {
1843 vec![]
1846 };
1847
1848 let columns = if !matches!(self.join_type, JoinType::Right) {
1849 left_columns.extend(right_columns);
1850 left_columns
1851 } else {
1852 right_columns.extend(left_columns);
1853 right_columns
1854 };
1855
1856 let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
1857 if !filter_columns.is_empty() {
1859 if let Some(f) = &self.filter {
1860 let filter_batch =
1862 RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?;
1863
1864 let filter_result = f
1865 .expression()
1866 .evaluate(&filter_batch)?
1867 .into_array(filter_batch.num_rows())?;
1868
1869 let pre_mask =
1871 datafusion_common::cast::as_boolean_array(&filter_result)?;
1872
1873 let mask = if pre_mask.null_count() > 0 {
1876 compute::prep_null_mask_filter(
1877 datafusion_common::cast::as_boolean_array(&filter_result)?,
1878 )
1879 } else {
1880 pre_mask.clone()
1881 };
1882
1883 if matches!(
1885 self.join_type,
1886 JoinType::Left
1887 | JoinType::LeftSemi
1888 | JoinType::Right
1889 | JoinType::LeftAnti
1890 | JoinType::RightAnti
1891 | JoinType::LeftMark
1892 | JoinType::Full
1893 ) {
1894 self.staging_output_record_batches
1895 .batches
1896 .push(output_batch);
1897 } else {
1898 let filtered_batch = filter_record_batch(&output_batch, &mask)?;
1899 self.staging_output_record_batches
1900 .batches
1901 .push(filtered_batch);
1902 }
1903
1904 if !matches!(self.join_type, JoinType::Full) {
1905 self.staging_output_record_batches.filter_mask.extend(&mask);
1906 } else {
1907 self.staging_output_record_batches
1908 .filter_mask
1909 .extend(pre_mask);
1910 }
1911 self.staging_output_record_batches
1912 .row_indices
1913 .extend(&left_indices);
1914 self.staging_output_record_batches.batch_ids.resize(
1915 self.staging_output_record_batches.batch_ids.len()
1916 + left_indices.len(),
1917 self.streamed_batch_counter.load(Relaxed),
1918 );
1919
1920 if matches!(self.join_type, JoinType::Full) {
1925 let buffered_batch = &mut self.buffered_data.batches
1926 [chunk.buffered_batch_idx.unwrap()];
1927
1928 for i in 0..pre_mask.len() {
1929 if right_indices.is_null(i) {
1932 continue;
1933 }
1934
1935 let buffered_index = right_indices.value(i);
1936
1937 buffered_batch.join_filter_not_matched_map.insert(
1938 buffered_index,
1939 *buffered_batch
1940 .join_filter_not_matched_map
1941 .get(&buffered_index)
1942 .unwrap_or(&true)
1943 && !pre_mask.value(i),
1944 );
1945 }
1946 }
1947 } else {
1948 self.staging_output_record_batches
1949 .batches
1950 .push(output_batch);
1951 }
1952 } else {
1953 self.staging_output_record_batches
1954 .batches
1955 .push(output_batch);
1956 }
1957 }
1958
1959 self.streamed_batch.output_indices.clear();
1960
1961 Ok(())
1962 }
1963
1964 fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
1965 let record_batch =
1966 concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1967 self.join_metrics.output_batches.add(1);
1968 self.join_metrics.output_rows.add(record_batch.num_rows());
1969 if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size {
1975 self.output_size = 0;
1976 } else {
1977 self.output_size -= record_batch.num_rows();
1978 }
1979
1980 if !(self.filter.is_some()
1981 && matches!(
1982 self.join_type,
1983 JoinType::Left
1984 | JoinType::LeftSemi
1985 | JoinType::Right
1986 | JoinType::LeftAnti
1987 | JoinType::RightAnti
1988 | JoinType::LeftMark
1989 | JoinType::Full
1990 ))
1991 {
1992 self.staging_output_record_batches.batches.clear();
1993 }
1994 Ok(record_batch)
1995 }
1996
1997 fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
1998 let record_batch =
1999 concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
2000 let mut out_indices = self.staging_output_record_batches.row_indices.finish();
2001 let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
2002 let mut batch_ids = &self.staging_output_record_batches.batch_ids;
2003 let default_batch_ids = vec![0; record_batch.num_rows()];
2004
2005 if out_indices.null_count() == out_indices.len()
2009 && out_indices.len() != record_batch.num_rows()
2010 {
2011 out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]);
2012 out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]);
2013 batch_ids = &default_batch_ids;
2014 }
2015
2016 if out_mask.is_empty() {
2017 self.staging_output_record_batches.batches.clear();
2018 return Ok(record_batch);
2019 }
2020
2021 let maybe_corrected_mask = get_corrected_filter_mask(
2022 self.join_type,
2023 &out_indices,
2024 batch_ids,
2025 &out_mask,
2026 record_batch.num_rows(),
2027 );
2028
2029 let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask {
2030 filtered_join_mask
2031 } else {
2032 &out_mask
2033 };
2034
2035 self.filter_record_batch_by_join_type(record_batch, corrected_mask)
2036 }
2037
2038 fn filter_record_batch_by_join_type(
2039 &mut self,
2040 record_batch: RecordBatch,
2041 corrected_mask: &BooleanArray,
2042 ) -> Result<RecordBatch> {
2043 let mut filtered_record_batch =
2044 filter_record_batch(&record_batch, corrected_mask)?;
2045 let left_columns_length = self.streamed_schema.fields.len();
2046 let right_columns_length = self.buffered_schema.fields.len();
2047
2048 if matches!(
2049 self.join_type,
2050 JoinType::Left | JoinType::LeftMark | JoinType::Right
2051 ) {
2052 let null_mask = compute::not(corrected_mask)?;
2053 let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
2054
2055 let mut right_columns = create_unmatched_columns(
2056 self.join_type,
2057 &self.buffered_schema,
2058 null_joined_batch.num_rows(),
2059 );
2060
2061 let columns = if !matches!(self.join_type, JoinType::Right) {
2062 let mut left_columns = null_joined_batch
2063 .columns()
2064 .iter()
2065 .take(right_columns_length)
2066 .cloned()
2067 .collect::<Vec<_>>();
2068
2069 left_columns.extend(right_columns);
2070 left_columns
2071 } else {
2072 let left_columns = null_joined_batch
2073 .columns()
2074 .iter()
2075 .skip(left_columns_length)
2076 .cloned()
2077 .collect::<Vec<_>>();
2078
2079 right_columns.extend(left_columns);
2080 right_columns
2081 };
2082
2083 let null_joined_streamed_batch =
2085 RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
2086
2087 filtered_record_batch = concat_batches(
2088 &self.schema,
2089 &[filtered_record_batch, null_joined_streamed_batch],
2090 )?;
2091 } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
2092 let output_column_indices = (0..left_columns_length).collect::<Vec<_>>();
2093 filtered_record_batch =
2094 filtered_record_batch.project(&output_column_indices)?;
2095 } else if matches!(self.join_type, JoinType::RightAnti) {
2096 let output_column_indices = (0..right_columns_length).collect::<Vec<_>>();
2097 filtered_record_batch =
2098 filtered_record_batch.project(&output_column_indices)?;
2099 } else if matches!(self.join_type, JoinType::Full)
2100 && corrected_mask.false_count() > 0
2101 {
2102 let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
2104 let joined_filter_not_matched_batch =
2105 filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?;
2106
2107 let right_null_columns = self
2109 .buffered_schema
2110 .fields()
2111 .iter()
2112 .map(|f| {
2113 new_null_array(
2114 f.data_type(),
2115 joined_filter_not_matched_batch.num_rows(),
2116 )
2117 })
2118 .collect::<Vec<_>>();
2119
2120 let mut result_joined = joined_filter_not_matched_batch
2121 .columns()
2122 .iter()
2123 .take(left_columns_length)
2124 .cloned()
2125 .collect::<Vec<_>>();
2126
2127 result_joined.extend(right_null_columns);
2128
2129 let left_null_joined_batch =
2130 RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?;
2131
2132 let mut result_joined = self
2134 .streamed_schema
2135 .fields()
2136 .iter()
2137 .map(|f| {
2138 new_null_array(
2139 f.data_type(),
2140 joined_filter_not_matched_batch.num_rows(),
2141 )
2142 })
2143 .collect::<Vec<_>>();
2144
2145 let right_data = joined_filter_not_matched_batch
2146 .columns()
2147 .iter()
2148 .skip(left_columns_length)
2149 .cloned()
2150 .collect::<Vec<_>>();
2151
2152 result_joined.extend(right_data);
2153
2154 filtered_record_batch = concat_batches(
2155 &self.schema,
2156 &[filtered_record_batch, left_null_joined_batch],
2157 )?;
2158 }
2159
2160 self.staging_output_record_batches.clear();
2161
2162 Ok(filtered_record_batch)
2163 }
2164}
2165
2166fn create_unmatched_columns(
2167 join_type: JoinType,
2168 schema: &SchemaRef,
2169 size: usize,
2170) -> Vec<ArrayRef> {
2171 if matches!(join_type, JoinType::LeftMark) {
2172 vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
2173 } else {
2174 schema
2175 .fields()
2176 .iter()
2177 .map(|f| new_null_array(f.data_type(), size))
2178 .collect::<Vec<_>>()
2179 }
2180}
2181
2182fn get_filter_column(
2184 join_filter: &Option<JoinFilter>,
2185 streamed_columns: &[ArrayRef],
2186 buffered_columns: &[ArrayRef],
2187) -> Vec<ArrayRef> {
2188 let mut filter_columns = vec![];
2189
2190 if let Some(f) = join_filter {
2191 let left_columns = f
2192 .column_indices()
2193 .iter()
2194 .filter(|col_index| col_index.side == JoinSide::Left)
2195 .map(|i| Arc::clone(&streamed_columns[i.index]))
2196 .collect::<Vec<_>>();
2197
2198 let right_columns = f
2199 .column_indices()
2200 .iter()
2201 .filter(|col_index| col_index.side == JoinSide::Right)
2202 .map(|i| Arc::clone(&buffered_columns[i.index]))
2203 .collect::<Vec<_>>();
2204
2205 filter_columns.extend(left_columns);
2206 filter_columns.extend(right_columns);
2207 }
2208
2209 filter_columns
2210}
2211
2212fn produce_buffered_null_batch(
2213 schema: &SchemaRef,
2214 streamed_schema: &SchemaRef,
2215 buffered_indices: &PrimitiveArray<UInt64Type>,
2216 buffered_batch: &BufferedBatch,
2217) -> Result<Option<RecordBatch>> {
2218 if buffered_indices.is_empty() {
2219 return Ok(None);
2220 }
2221
2222 let right_columns =
2224 fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?;
2225
2226 let mut left_columns = streamed_schema
2228 .fields()
2229 .iter()
2230 .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
2231 .collect::<Vec<_>>();
2232
2233 left_columns.extend(right_columns);
2234
2235 Ok(Some(RecordBatch::try_new(
2236 Arc::clone(schema),
2237 left_columns,
2238 )?))
2239}
2240
2241#[inline(always)]
2243fn fetch_right_columns_by_idxs(
2244 buffered_data: &BufferedData,
2245 buffered_batch_idx: usize,
2246 buffered_indices: &UInt64Array,
2247) -> Result<Vec<ArrayRef>> {
2248 fetch_right_columns_from_batch_by_idxs(
2249 &buffered_data.batches[buffered_batch_idx],
2250 buffered_indices,
2251 )
2252}
2253
2254#[inline(always)]
2255fn fetch_right_columns_from_batch_by_idxs(
2256 buffered_batch: &BufferedBatch,
2257 buffered_indices: &UInt64Array,
2258) -> Result<Vec<ArrayRef>> {
2259 match (&buffered_batch.spill_file, &buffered_batch.batch) {
2260 (None, Some(batch)) => Ok(batch
2262 .columns()
2263 .iter()
2264 .map(|column| take(column, &buffered_indices, None))
2265 .collect::<Result<Vec<_>, ArrowError>>()
2266 .map_err(Into::<DataFusionError>::into)?),
2267 (Some(spill_file), None) => {
2269 let mut buffered_cols: Vec<ArrayRef> =
2270 Vec::with_capacity(buffered_indices.len());
2271
2272 let file = BufReader::new(File::open(spill_file.path())?);
2273 let reader = StreamReader::try_new(file, None)?;
2274
2275 for batch in reader {
2276 batch?.columns().iter().for_each(|column| {
2277 buffered_cols.extend(take(column, &buffered_indices, None))
2278 });
2279 }
2280
2281 Ok(buffered_cols)
2282 }
2283 (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()),
2285 }
2286}
2287
/// Queue of buffered-side batches, plus a cursor for scanning the rows of the
/// current buffered key group.
#[derive(Debug, Default)]
struct BufferedData {
    /// Buffered batches. New batches are pushed to the back, and fully consumed
    /// batches are popped from the front.
    pub batches: VecDeque<BufferedBatch>,
    /// Index into `batches` of the batch being scanned. It equals
    /// `batches.len()` once scanning has finished.
    pub scanning_batch_idx: usize,
    /// Offset of the scan within the current batch's `range`.
    pub scanning_offset: usize,
}
2298
2299impl BufferedData {
2300 pub fn head_batch(&self) -> &BufferedBatch {
2301 self.batches.front().unwrap()
2302 }
2303
2304 pub fn tail_batch(&self) -> &BufferedBatch {
2305 self.batches.back().unwrap()
2306 }
2307
2308 pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
2309 self.batches.back_mut().unwrap()
2310 }
2311
2312 pub fn has_buffered_rows(&self) -> bool {
2313 self.batches.iter().any(|batch| !batch.range.is_empty())
2314 }
2315
2316 pub fn scanning_reset(&mut self) {
2317 self.scanning_batch_idx = 0;
2318 self.scanning_offset = 0;
2319 }
2320
2321 pub fn scanning_advance(&mut self) {
2322 self.scanning_offset += 1;
2323 while !self.scanning_finished() && self.scanning_batch_finished() {
2324 self.scanning_batch_idx += 1;
2325 self.scanning_offset = 0;
2326 }
2327 }
2328
2329 pub fn scanning_batch(&self) -> &BufferedBatch {
2330 &self.batches[self.scanning_batch_idx]
2331 }
2332
2333 pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
2334 &mut self.batches[self.scanning_batch_idx]
2335 }
2336
2337 pub fn scanning_idx(&self) -> usize {
2338 self.scanning_batch().range.start + self.scanning_offset
2339 }
2340
2341 pub fn scanning_batch_finished(&self) -> bool {
2342 self.scanning_offset == self.scanning_batch().range.len()
2343 }
2344
2345 pub fn scanning_finished(&self) -> bool {
2346 self.scanning_batch_idx == self.batches.len()
2347 }
2348
2349 pub fn scanning_finish(&mut self) {
2350 self.scanning_batch_idx = self.batches.len();
2351 self.scanning_offset = 0;
2352 }
2353}
2354
2355fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> {
2357 on_column
2358 .iter()
2359 .map(|c| {
2360 let num_rows = batch.num_rows();
2361 let c = c.evaluate(batch).unwrap();
2362 c.into_array(num_rows).unwrap()
2363 })
2364 .collect()
2365}
2366
2367fn compare_join_arrays(
2369 left_arrays: &[ArrayRef],
2370 left: usize,
2371 right_arrays: &[ArrayRef],
2372 right: usize,
2373 sort_options: &[SortOptions],
2374 null_equals_null: bool,
2375) -> Result<Ordering> {
2376 let mut res = Ordering::Equal;
2377 for ((left_array, right_array), sort_options) in
2378 left_arrays.iter().zip(right_arrays).zip(sort_options)
2379 {
2380 macro_rules! compare_value {
2381 ($T:ty) => {{
2382 let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
2383 let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
2384 match (left_array.is_null(left), right_array.is_null(right)) {
2385 (false, false) => {
2386 let left_value = &left_array.value(left);
2387 let right_value = &right_array.value(right);
2388 res = left_value.partial_cmp(right_value).unwrap();
2389 if sort_options.descending {
2390 res = res.reverse();
2391 }
2392 }
2393 (true, false) => {
2394 res = if sort_options.nulls_first {
2395 Ordering::Less
2396 } else {
2397 Ordering::Greater
2398 };
2399 }
2400 (false, true) => {
2401 res = if sort_options.nulls_first {
2402 Ordering::Greater
2403 } else {
2404 Ordering::Less
2405 };
2406 }
2407 _ => {
2408 res = if null_equals_null {
2409 Ordering::Equal
2410 } else {
2411 Ordering::Less
2412 };
2413 }
2414 }
2415 }};
2416 }
2417
2418 match left_array.data_type() {
2419 DataType::Null => {}
2420 DataType::Boolean => compare_value!(BooleanArray),
2421 DataType::Int8 => compare_value!(Int8Array),
2422 DataType::Int16 => compare_value!(Int16Array),
2423 DataType::Int32 => compare_value!(Int32Array),
2424 DataType::Int64 => compare_value!(Int64Array),
2425 DataType::UInt8 => compare_value!(UInt8Array),
2426 DataType::UInt16 => compare_value!(UInt16Array),
2427 DataType::UInt32 => compare_value!(UInt32Array),
2428 DataType::UInt64 => compare_value!(UInt64Array),
2429 DataType::Float32 => compare_value!(Float32Array),
2430 DataType::Float64 => compare_value!(Float64Array),
2431 DataType::Utf8 => compare_value!(StringArray),
2432 DataType::LargeUtf8 => compare_value!(LargeStringArray),
2433 DataType::Decimal128(..) => compare_value!(Decimal128Array),
2434 DataType::Timestamp(time_unit, None) => match time_unit {
2435 TimeUnit::Second => compare_value!(TimestampSecondArray),
2436 TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
2437 TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
2438 TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
2439 },
2440 DataType::Date32 => compare_value!(Date32Array),
2441 DataType::Date64 => compare_value!(Date64Array),
2442 dt => {
2443 return not_impl_err!(
2444 "Unsupported data type in sort merge join comparator: {}",
2445 dt
2446 );
2447 }
2448 }
2449 if !res.is_eq() {
2450 break;
2451 }
2452 }
2453 Ok(res)
2454}
2455
2456fn is_join_arrays_equal(
2459 left_arrays: &[ArrayRef],
2460 left: usize,
2461 right_arrays: &[ArrayRef],
2462 right: usize,
2463) -> Result<bool> {
2464 let mut is_equal = true;
2465 for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
2466 macro_rules! compare_value {
2467 ($T:ty) => {{
2468 match (left_array.is_null(left), right_array.is_null(right)) {
2469 (false, false) => {
2470 let left_array =
2471 left_array.as_any().downcast_ref::<$T>().unwrap();
2472 let right_array =
2473 right_array.as_any().downcast_ref::<$T>().unwrap();
2474 if left_array.value(left) != right_array.value(right) {
2475 is_equal = false;
2476 }
2477 }
2478 (true, false) => is_equal = false,
2479 (false, true) => is_equal = false,
2480 _ => {}
2481 }
2482 }};
2483 }
2484
2485 match left_array.data_type() {
2486 DataType::Null => {}
2487 DataType::Boolean => compare_value!(BooleanArray),
2488 DataType::Int8 => compare_value!(Int8Array),
2489 DataType::Int16 => compare_value!(Int16Array),
2490 DataType::Int32 => compare_value!(Int32Array),
2491 DataType::Int64 => compare_value!(Int64Array),
2492 DataType::UInt8 => compare_value!(UInt8Array),
2493 DataType::UInt16 => compare_value!(UInt16Array),
2494 DataType::UInt32 => compare_value!(UInt32Array),
2495 DataType::UInt64 => compare_value!(UInt64Array),
2496 DataType::Float32 => compare_value!(Float32Array),
2497 DataType::Float64 => compare_value!(Float64Array),
2498 DataType::Utf8 => compare_value!(StringArray),
2499 DataType::LargeUtf8 => compare_value!(LargeStringArray),
2500 DataType::Decimal128(..) => compare_value!(Decimal128Array),
2501 DataType::Timestamp(time_unit, None) => match time_unit {
2502 TimeUnit::Second => compare_value!(TimestampSecondArray),
2503 TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
2504 TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
2505 TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
2506 },
2507 DataType::Date32 => compare_value!(Date32Array),
2508 DataType::Date64 => compare_value!(Date64Array),
2509 dt => {
2510 return not_impl_err!(
2511 "Unsupported data type in sort merge join comparator: {}",
2512 dt
2513 );
2514 }
2515 }
2516 if !is_equal {
2517 return Ok(false);
2518 }
2519 }
2520 Ok(true)
2521}
2522
2523#[cfg(test)]
2524mod tests {
2525 use std::sync::Arc;
2526
2527 use arrow::array::{
2528 builder::{BooleanBuilder, UInt64Builder},
2529 BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array,
2530 };
2531 use arrow::compute::{concat_batches, filter_record_batch, SortOptions};
2532 use arrow::datatypes::{DataType, Field, Schema};
2533
2534 use datafusion_common::JoinSide;
2535 use datafusion_common::JoinType::*;
2536 use datafusion_common::{
2537 assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
2538 };
2539 use datafusion_execution::config::SessionConfig;
2540 use datafusion_execution::disk_manager::DiskManagerConfig;
2541 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2542 use datafusion_execution::TaskContext;
2543 use datafusion_expr::Operator;
2544 use datafusion_physical_expr::expressions::BinaryExpr;
2545
2546 use crate::expressions::Column;
2547 use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches};
2548 use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
2549 use crate::joins::SortMergeJoinExec;
2550 use crate::test::TestMemoryExec;
2551 use crate::test::{build_table_i32, build_table_i32_two_cols};
2552 use crate::{common, ExecutionPlan};
2553
2554 fn build_table(
2555 a: (&str, &Vec<i32>),
2556 b: (&str, &Vec<i32>),
2557 c: (&str, &Vec<i32>),
2558 ) -> Arc<dyn ExecutionPlan> {
2559 let batch = build_table_i32(a, b, c);
2560 let schema = batch.schema();
2561 TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2562 }
2563
2564 fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
2565 let schema = batches.first().unwrap().schema();
2566 TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap()
2567 }
2568
2569 fn build_date_table(
2570 a: (&str, &Vec<i32>),
2571 b: (&str, &Vec<i32>),
2572 c: (&str, &Vec<i32>),
2573 ) -> Arc<dyn ExecutionPlan> {
2574 let schema = Schema::new(vec![
2575 Field::new(a.0, DataType::Date32, false),
2576 Field::new(b.0, DataType::Date32, false),
2577 Field::new(c.0, DataType::Date32, false),
2578 ]);
2579
2580 let batch = RecordBatch::try_new(
2581 Arc::new(schema),
2582 vec![
2583 Arc::new(Date32Array::from(a.1.clone())),
2584 Arc::new(Date32Array::from(b.1.clone())),
2585 Arc::new(Date32Array::from(c.1.clone())),
2586 ],
2587 )
2588 .unwrap();
2589
2590 let schema = batch.schema();
2591 TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2592 }
2593
2594 fn build_date64_table(
2595 a: (&str, &Vec<i64>),
2596 b: (&str, &Vec<i64>),
2597 c: (&str, &Vec<i64>),
2598 ) -> Arc<dyn ExecutionPlan> {
2599 let schema = Schema::new(vec![
2600 Field::new(a.0, DataType::Date64, false),
2601 Field::new(b.0, DataType::Date64, false),
2602 Field::new(c.0, DataType::Date64, false),
2603 ]);
2604
2605 let batch = RecordBatch::try_new(
2606 Arc::new(schema),
2607 vec![
2608 Arc::new(Date64Array::from(a.1.clone())),
2609 Arc::new(Date64Array::from(b.1.clone())),
2610 Arc::new(Date64Array::from(c.1.clone())),
2611 ],
2612 )
2613 .unwrap();
2614
2615 let schema = batch.schema();
2616 TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2617 }
2618
2619 pub fn build_table_i32_nullable(
2621 a: (&str, &Vec<Option<i32>>),
2622 b: (&str, &Vec<Option<i32>>),
2623 c: (&str, &Vec<Option<i32>>),
2624 ) -> Arc<dyn ExecutionPlan> {
2625 let schema = Arc::new(Schema::new(vec![
2626 Field::new(a.0, DataType::Int32, true),
2627 Field::new(b.0, DataType::Int32, true),
2628 Field::new(c.0, DataType::Int32, true),
2629 ]));
2630 let batch = RecordBatch::try_new(
2631 Arc::clone(&schema),
2632 vec![
2633 Arc::new(Int32Array::from(a.1.clone())),
2634 Arc::new(Int32Array::from(b.1.clone())),
2635 Arc::new(Int32Array::from(c.1.clone())),
2636 ],
2637 )
2638 .unwrap();
2639 TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2640 }
2641
2642 pub fn build_table_two_cols(
2643 a: (&str, &Vec<i32>),
2644 b: (&str, &Vec<i32>),
2645 ) -> Arc<dyn ExecutionPlan> {
2646 let batch = build_table_i32_two_cols(a, b);
2647 let schema = batch.schema();
2648 TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2649 }
2650
2651 fn join(
2652 left: Arc<dyn ExecutionPlan>,
2653 right: Arc<dyn ExecutionPlan>,
2654 on: JoinOn,
2655 join_type: JoinType,
2656 ) -> Result<SortMergeJoinExec> {
2657 let sort_options = vec![SortOptions::default(); on.len()];
2658 SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false)
2659 }
2660
2661 fn join_with_options(
2662 left: Arc<dyn ExecutionPlan>,
2663 right: Arc<dyn ExecutionPlan>,
2664 on: JoinOn,
2665 join_type: JoinType,
2666 sort_options: Vec<SortOptions>,
2667 null_equals_null: bool,
2668 ) -> Result<SortMergeJoinExec> {
2669 SortMergeJoinExec::try_new(
2670 left,
2671 right,
2672 on,
2673 None,
2674 join_type,
2675 sort_options,
2676 null_equals_null,
2677 )
2678 }
2679
2680 fn join_with_filter(
2681 left: Arc<dyn ExecutionPlan>,
2682 right: Arc<dyn ExecutionPlan>,
2683 on: JoinOn,
2684 filter: JoinFilter,
2685 join_type: JoinType,
2686 sort_options: Vec<SortOptions>,
2687 null_equals_null: bool,
2688 ) -> Result<SortMergeJoinExec> {
2689 SortMergeJoinExec::try_new(
2690 left,
2691 right,
2692 on,
2693 Some(filter),
2694 join_type,
2695 sort_options,
2696 null_equals_null,
2697 )
2698 }
2699
2700 async fn join_collect(
2701 left: Arc<dyn ExecutionPlan>,
2702 right: Arc<dyn ExecutionPlan>,
2703 on: JoinOn,
2704 join_type: JoinType,
2705 ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2706 let sort_options = vec![SortOptions::default(); on.len()];
2707 join_collect_with_options(left, right, on, join_type, sort_options, false).await
2708 }
2709
2710 async fn join_collect_with_filter(
2711 left: Arc<dyn ExecutionPlan>,
2712 right: Arc<dyn ExecutionPlan>,
2713 on: JoinOn,
2714 filter: JoinFilter,
2715 join_type: JoinType,
2716 ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2717 let sort_options = vec![SortOptions::default(); on.len()];
2718
2719 let task_ctx = Arc::new(TaskContext::default());
2720 let join =
2721 join_with_filter(left, right, on, filter, join_type, sort_options, false)?;
2722 let columns = columns(&join.schema());
2723
2724 let stream = join.execute(0, task_ctx)?;
2725 let batches = common::collect(stream).await?;
2726 Ok((columns, batches))
2727 }
2728
2729 async fn join_collect_with_options(
2730 left: Arc<dyn ExecutionPlan>,
2731 right: Arc<dyn ExecutionPlan>,
2732 on: JoinOn,
2733 join_type: JoinType,
2734 sort_options: Vec<SortOptions>,
2735 null_equals_null: bool,
2736 ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2737 let task_ctx = Arc::new(TaskContext::default());
2738 let join = join_with_options(
2739 left,
2740 right,
2741 on,
2742 join_type,
2743 sort_options,
2744 null_equals_null,
2745 )?;
2746 let columns = columns(&join.schema());
2747
2748 let stream = join.execute(0, task_ctx)?;
2749 let batches = common::collect(stream).await?;
2750 Ok((columns, batches))
2751 }
2752
2753 async fn join_collect_batch_size_equals_two(
2754 left: Arc<dyn ExecutionPlan>,
2755 right: Arc<dyn ExecutionPlan>,
2756 on: JoinOn,
2757 join_type: JoinType,
2758 ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2759 let task_ctx = TaskContext::default()
2760 .with_session_config(SessionConfig::new().with_batch_size(2));
2761 let task_ctx = Arc::new(task_ctx);
2762 let join = join(left, right, on, join_type)?;
2763 let columns = columns(&join.schema());
2764
2765 let stream = join.execute(0, task_ctx)?;
2766 let batches = common::collect(stream).await?;
2767 Ok((columns, batches))
2768 }
2769
2770 #[tokio::test]
2771 async fn join_inner_one() -> Result<()> {
2772 let left = build_table(
2773 ("a1", &vec![1, 2, 3]),
2774 ("b1", &vec![4, 5, 5]), ("c1", &vec![7, 8, 9]),
2776 );
2777 let right = build_table(
2778 ("a2", &vec![10, 20, 30]),
2779 ("b1", &vec![4, 5, 6]),
2780 ("c2", &vec![70, 80, 90]),
2781 );
2782
2783 let on = vec![(
2784 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2785 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
2786 )];
2787
2788 let (_, batches) = join_collect(left, right, on, Inner).await?;
2789
2790 let expected = [
2791 "+----+----+----+----+----+----+",
2792 "| a1 | b1 | c1 | a2 | b1 | c2 |",
2793 "+----+----+----+----+----+----+",
2794 "| 1 | 4 | 7 | 10 | 4 | 70 |",
2795 "| 2 | 5 | 8 | 20 | 5 | 80 |",
2796 "| 3 | 5 | 9 | 20 | 5 | 80 |",
2797 "+----+----+----+----+----+----+",
2798 ];
2799 assert_batches_eq!(expected, &batches);
2801 Ok(())
2802 }
2803
2804 #[tokio::test]
2805 async fn join_inner_two() -> Result<()> {
2806 let left = build_table(
2807 ("a1", &vec![1, 2, 2]),
2808 ("b2", &vec![1, 2, 2]),
2809 ("c1", &vec![7, 8, 9]),
2810 );
2811 let right = build_table(
2812 ("a1", &vec![1, 2, 3]),
2813 ("b2", &vec![1, 2, 2]),
2814 ("c2", &vec![70, 80, 90]),
2815 );
2816 let on = vec![
2817 (
2818 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2819 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2820 ),
2821 (
2822 Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2823 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2824 ),
2825 ];
2826
2827 let (_columns, batches) = join_collect(left, right, on, Inner).await?;
2828 let expected = [
2829 "+----+----+----+----+----+----+",
2830 "| a1 | b2 | c1 | a1 | b2 | c2 |",
2831 "+----+----+----+----+----+----+",
2832 "| 1 | 1 | 7 | 1 | 1 | 70 |",
2833 "| 2 | 2 | 8 | 2 | 2 | 80 |",
2834 "| 2 | 2 | 9 | 2 | 2 | 80 |",
2835 "+----+----+----+----+----+----+",
2836 ];
2837 assert_batches_eq!(expected, &batches);
2839 Ok(())
2840 }
2841
2842 #[tokio::test]
2843 async fn join_inner_two_two() -> Result<()> {
2844 let left = build_table(
2845 ("a1", &vec![1, 1, 2]),
2846 ("b2", &vec![1, 1, 2]),
2847 ("c1", &vec![7, 8, 9]),
2848 );
2849 let right = build_table(
2850 ("a1", &vec![1, 1, 3]),
2851 ("b2", &vec![1, 1, 2]),
2852 ("c2", &vec![70, 80, 90]),
2853 );
2854 let on = vec![
2855 (
2856 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2857 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2858 ),
2859 (
2860 Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2861 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2862 ),
2863 ];
2864
2865 let (_columns, batches) = join_collect(left, right, on, Inner).await?;
2866 let expected = [
2867 "+----+----+----+----+----+----+",
2868 "| a1 | b2 | c1 | a1 | b2 | c2 |",
2869 "+----+----+----+----+----+----+",
2870 "| 1 | 1 | 7 | 1 | 1 | 70 |",
2871 "| 1 | 1 | 7 | 1 | 1 | 80 |",
2872 "| 1 | 1 | 8 | 1 | 1 | 70 |",
2873 "| 1 | 1 | 8 | 1 | 1 | 80 |",
2874 "+----+----+----+----+----+----+",
2875 ];
2876 assert_batches_eq!(expected, &batches);
2878 Ok(())
2879 }
2880
2881 #[tokio::test]
2882 async fn join_inner_with_nulls() -> Result<()> {
2883 let left = build_table_i32_nullable(
2884 ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
2885 ("b2", &vec![None, Some(1), Some(2), Some(2)]), ("c1", &vec![Some(1), None, Some(8), Some(9)]), );
2888 let right = build_table_i32_nullable(
2889 ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
2890 ("b2", &vec![None, Some(1), Some(2), Some(2)]),
2891 ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
2892 );
2893 let on = vec![
2894 (
2895 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2896 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2897 ),
2898 (
2899 Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2900 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2901 ),
2902 ];
2903
2904 let (_, batches) = join_collect(left, right, on, Inner).await?;
2905 let expected = [
2906 "+----+----+----+----+----+----+",
2907 "| a1 | b2 | c1 | a1 | b2 | c2 |",
2908 "+----+----+----+----+----+----+",
2909 "| 1 | 1 | | 1 | 1 | 70 |",
2910 "| 2 | 2 | 8 | 2 | 2 | 80 |",
2911 "| 2 | 2 | 9 | 2 | 2 | 80 |",
2912 "+----+----+----+----+----+----+",
2913 ];
2914 assert_batches_eq!(expected, &batches);
2916 Ok(())
2917 }
2918
2919 #[tokio::test]
2920 async fn join_inner_with_nulls_with_options() -> Result<()> {
2921 let left = build_table_i32_nullable(
2922 ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]),
2923 ("b2", &vec![Some(2), Some(2), Some(1), None]), ("c1", &vec![Some(9), Some(8), None, Some(1)]), );
2926 let right = build_table_i32_nullable(
2927 ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]),
2928 ("b2", &vec![Some(2), Some(2), Some(1), None]),
2929 ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]),
2930 );
2931 let on = vec![
2932 (
2933 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2934 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2935 ),
2936 (
2937 Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2938 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2939 ),
2940 ];
2941 let (_, batches) = join_collect_with_options(
2942 left,
2943 right,
2944 on,
2945 Inner,
2946 vec![
2947 SortOptions {
2948 descending: true,
2949 nulls_first: false,
2950 };
2951 2
2952 ],
2953 true,
2954 )
2955 .await?;
2956 let expected = [
2957 "+----+----+----+----+----+----+",
2958 "| a1 | b2 | c1 | a1 | b2 | c2 |",
2959 "+----+----+----+----+----+----+",
2960 "| 2 | 2 | 9 | 2 | 2 | 80 |",
2961 "| 2 | 2 | 8 | 2 | 2 | 80 |",
2962 "| 1 | 1 | | 1 | 1 | 70 |",
2963 "| 1 | | 1 | 1 | | 10 |",
2964 "+----+----+----+----+----+----+",
2965 ];
2966 assert_batches_eq!(expected, &batches);
2968 Ok(())
2969 }
2970
2971 #[tokio::test]
2972 async fn join_inner_output_two_batches() -> Result<()> {
2973 let left = build_table(
2974 ("a1", &vec![1, 2, 2]),
2975 ("b2", &vec![1, 2, 2]),
2976 ("c1", &vec![7, 8, 9]),
2977 );
2978 let right = build_table(
2979 ("a1", &vec![1, 2, 3]),
2980 ("b2", &vec![1, 2, 2]),
2981 ("c2", &vec![70, 80, 90]),
2982 );
2983 let on = vec![
2984 (
2985 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2986 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2987 ),
2988 (
2989 Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2990 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2991 ),
2992 ];
2993
2994 let (_, batches) =
2995 join_collect_batch_size_equals_two(left, right, on, Inner).await?;
2996 let expected = [
2997 "+----+----+----+----+----+----+",
2998 "| a1 | b2 | c1 | a1 | b2 | c2 |",
2999 "+----+----+----+----+----+----+",
3000 "| 1 | 1 | 7 | 1 | 1 | 70 |",
3001 "| 2 | 2 | 8 | 2 | 2 | 80 |",
3002 "| 2 | 2 | 9 | 2 | 2 | 80 |",
3003 "+----+----+----+----+----+----+",
3004 ];
3005 assert_eq!(batches.len(), 2);
3006 assert_eq!(batches[0].num_rows(), 2);
3007 assert_eq!(batches[1].num_rows(), 1);
3008 assert_batches_eq!(expected, &batches);
3010 Ok(())
3011 }
3012
3013 #[tokio::test]
3014 async fn join_left_one() -> Result<()> {
3015 let left = build_table(
3016 ("a1", &vec![1, 2, 3]),
3017 ("b1", &vec![4, 5, 7]), ("c1", &vec![7, 8, 9]),
3019 );
3020 let right = build_table(
3021 ("a2", &vec![10, 20, 30]),
3022 ("b1", &vec![4, 5, 6]),
3023 ("c2", &vec![70, 80, 90]),
3024 );
3025 let on = vec![(
3026 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3027 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3028 )];
3029
3030 let (_, batches) = join_collect(left, right, on, Left).await?;
3031 let expected = [
3032 "+----+----+----+----+----+----+",
3033 "| a1 | b1 | c1 | a2 | b1 | c2 |",
3034 "+----+----+----+----+----+----+",
3035 "| 1 | 4 | 7 | 10 | 4 | 70 |",
3036 "| 2 | 5 | 8 | 20 | 5 | 80 |",
3037 "| 3 | 7 | 9 | | | |",
3038 "+----+----+----+----+----+----+",
3039 ];
3040 assert_batches_eq!(expected, &batches);
3042 Ok(())
3043 }
3044
3045 #[tokio::test]
3046 async fn join_right_one() -> Result<()> {
3047 let left = build_table(
3048 ("a1", &vec![1, 2, 3]),
3049 ("b1", &vec![4, 5, 7]),
3050 ("c1", &vec![7, 8, 9]),
3051 );
3052 let right = build_table(
3053 ("a2", &vec![10, 20, 30]),
3054 ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]),
3056 );
3057 let on = vec![(
3058 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3059 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3060 )];
3061
3062 let (_, batches) = join_collect(left, right, on, Right).await?;
3063 let expected = [
3064 "+----+----+----+----+----+----+",
3065 "| a1 | b1 | c1 | a2 | b1 | c2 |",
3066 "+----+----+----+----+----+----+",
3067 "| 1 | 4 | 7 | 10 | 4 | 70 |",
3068 "| 2 | 5 | 8 | 20 | 5 | 80 |",
3069 "| | | | 30 | 6 | 90 |",
3070 "+----+----+----+----+----+----+",
3071 ];
3072 assert_batches_eq!(expected, &batches);
3074 Ok(())
3075 }
3076
3077 #[tokio::test]
3078 async fn join_full_one() -> Result<()> {
3079 let left = build_table(
3080 ("a1", &vec![1, 2, 3]),
3081 ("b1", &vec![4, 5, 7]), ("c1", &vec![7, 8, 9]),
3083 );
3084 let right = build_table(
3085 ("a2", &vec![10, 20, 30]),
3086 ("b2", &vec![4, 5, 6]),
3087 ("c2", &vec![70, 80, 90]),
3088 );
3089 let on = vec![(
3090 Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
3091 Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
3092 )];
3093
3094 let (_, batches) = join_collect(left, right, on, Full).await?;
3095 let expected = [
3096 "+----+----+----+----+----+----+",
3097 "| a1 | b1 | c1 | a2 | b2 | c2 |",
3098 "+----+----+----+----+----+----+",
3099 "| | | | 30 | 6 | 90 |",
3100 "| 1 | 4 | 7 | 10 | 4 | 70 |",
3101 "| 2 | 5 | 8 | 20 | 5 | 80 |",
3102 "| 3 | 7 | 9 | | | |",
3103 "+----+----+----+----+----+----+",
3104 ];
3105 assert_batches_sorted_eq!(expected, &batches);
3106 Ok(())
3107 }
3108
3109 #[tokio::test]
3110 async fn join_left_anti() -> Result<()> {
3111 let left = build_table(
3112 ("a1", &vec![1, 2, 2, 3, 5]),
3113 ("b1", &vec![4, 5, 5, 7, 7]), ("c1", &vec![7, 8, 8, 9, 11]),
3115 );
3116 let right = build_table(
3117 ("a2", &vec![10, 20, 30]),
3118 ("b1", &vec![4, 5, 6]),
3119 ("c2", &vec![70, 80, 90]),
3120 );
3121 let on = vec![(
3122 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3123 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3124 )];
3125
3126 let (_, batches) = join_collect(left, right, on, LeftAnti).await?;
3127 let expected = [
3128 "+----+----+----+",
3129 "| a1 | b1 | c1 |",
3130 "+----+----+----+",
3131 "| 3 | 7 | 9 |",
3132 "| 5 | 7 | 11 |",
3133 "+----+----+----+",
3134 ];
3135 assert_batches_eq!(expected, &batches);
3137 Ok(())
3138 }
3139
3140 #[tokio::test]
3141 async fn join_right_anti_one_one() -> Result<()> {
3142 let left = build_table(
3143 ("a1", &vec![1, 2, 2]),
3144 ("b1", &vec![4, 5, 5]),
3145 ("c1", &vec![7, 8, 8]),
3146 );
3147 let right =
3148 build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
3149 let on = vec![(
3150 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3151 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3152 )];
3153
3154 let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3155 let expected = [
3156 "+----+----+",
3157 "| a2 | b1 |",
3158 "+----+----+",
3159 "| 30 | 6 |",
3160 "+----+----+",
3161 ];
3162 assert_batches_eq!(expected, &batches);
3164
3165 let left2 = build_table(
3166 ("a1", &vec![1, 2, 2]),
3167 ("b1", &vec![4, 5, 5]),
3168 ("c1", &vec![7, 8, 8]),
3169 );
3170 let right2 = build_table(
3171 ("a2", &vec![10, 20, 30]),
3172 ("b1", &vec![4, 5, 6]),
3173 ("c2", &vec![70, 80, 90]),
3174 );
3175
3176 let on = vec![(
3177 Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _,
3178 Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _,
3179 )];
3180
3181 let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?;
3182 let expected2 = [
3183 "+----+----+----+",
3184 "| a2 | b1 | c2 |",
3185 "+----+----+----+",
3186 "| 30 | 6 | 90 |",
3187 "+----+----+----+",
3188 ];
3189 assert_batches_eq!(expected2, &batches2);
3191
3192 Ok(())
3193 }
3194
3195 #[tokio::test]
3196 async fn join_right_anti_two_two() -> Result<()> {
3197 let left = build_table(
3198 ("a1", &vec![1, 2, 2]),
3199 ("b1", &vec![4, 5, 5]),
3200 ("c1", &vec![7, 8, 8]),
3201 );
3202 let right =
3203 build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
3204 let on = vec![
3205 (
3206 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3207 Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
3208 ),
3209 (
3210 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3211 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3212 ),
3213 ];
3214
3215 let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3216 let expected = [
3217 "+----+----+",
3218 "| a2 | b1 |",
3219 "+----+----+",
3220 "| 10 | 4 |",
3221 "| 20 | 5 |",
3222 "| 30 | 6 |",
3223 "+----+----+",
3224 ];
3225 assert_batches_eq!(expected, &batches);
3227
3228 let left = build_table(
3229 ("a1", &vec![1, 2, 2]),
3230 ("b1", &vec![4, 5, 5]),
3231 ("c1", &vec![7, 8, 8]),
3232 );
3233 let right = build_table(
3234 ("a2", &vec![10, 20, 30]),
3235 ("b1", &vec![4, 5, 6]),
3236 ("c2", &vec![70, 80, 90]),
3237 );
3238
3239 let on = vec![
3240 (
3241 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3242 Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
3243 ),
3244 (
3245 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3246 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3247 ),
3248 ];
3249
3250 let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3251 let expected = [
3252 "+----+----+----+",
3253 "| a2 | b1 | c2 |",
3254 "+----+----+----+",
3255 "| 10 | 4 | 70 |",
3256 "| 20 | 5 | 80 |",
3257 "| 30 | 6 | 90 |",
3258 "+----+----+----+",
3259 ];
3260 assert_batches_eq!(expected, &batches);
3262
3263 Ok(())
3264 }
3265
3266 #[tokio::test]
3267 async fn join_right_anti_two_with_filter() -> Result<()> {
3268 let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30]));
3269 let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20]));
3270 let on = vec![
3271 (
3272 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3273 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
3274 ),
3275 (
3276 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3277 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3278 ),
3279 ];
3280 let filter = JoinFilter::new(
3281 Arc::new(BinaryExpr::new(
3282 Arc::new(Column::new("c2", 1)),
3283 Operator::Gt,
3284 Arc::new(Column::new("c1", 0)),
3285 )),
3286 vec![
3287 ColumnIndex {
3288 index: 2,
3289 side: JoinSide::Left,
3290 },
3291 ColumnIndex {
3292 index: 2,
3293 side: JoinSide::Right,
3294 },
3295 ],
3296 Arc::new(Schema::new(vec![
3297 Field::new("c1", DataType::Int32, true),
3298 Field::new("c2", DataType::Int32, true),
3299 ])),
3300 );
3301 let (_, batches) =
3302 join_collect_with_filter(left, right, on, filter, RightAnti).await?;
3303 let expected = [
3304 "+----+----+----+",
3305 "| a1 | b1 | c2 |",
3306 "+----+----+----+",
3307 "| 1 | 10 | 20 |",
3308 "+----+----+----+",
3309 ];
3310 assert_batches_eq!(expected, &batches);
3311 Ok(())
3312 }
3313
3314 #[tokio::test]
3315 async fn join_right_anti_with_nulls() -> Result<()> {
3316 let left = build_table_i32_nullable(
3317 ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
3318 ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
3319 ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
3320 );
3321 let right = build_table_i32_nullable(
3322 ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
3323 ("b1", &vec![Some(4), Some(5), None, Some(6)]), ("c2", &vec![Some(7), Some(8), Some(8), None]), );
3326 let on = vec![
3327 (
3328 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3329 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
3330 ),
3331 (
3332 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3333 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3334 ),
3335 ];
3336
3337 let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3338 let expected = [
3339 "+----+----+----+",
3340 "| a1 | b1 | c2 |",
3341 "+----+----+----+",
3342 "| 2 | | 8 |",
3343 "+----+----+----+",
3344 ];
3345 assert_batches_eq!(expected, &batches);
3347 Ok(())
3348 }
3349
3350 #[tokio::test]
3351 async fn join_right_anti_with_nulls_with_options() -> Result<()> {
3352 let left = build_table_i32_nullable(
3353 ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]),
3354 ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]),
3355 ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]),
3356 );
3357 let right = build_table_i32_nullable(
3358 ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
3359 ("b1", &vec![None, Some(5), Some(5), Some(4)]), ("c2", &vec![Some(9), None, Some(8), Some(7)]), );
3362 let on = vec![
3363 (
3364 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3365 Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
3366 ),
3367 (
3368 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3369 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3370 ),
3371 ];
3372
3373 let (_, batches) = join_collect_with_options(
3374 left,
3375 right,
3376 on,
3377 RightAnti,
3378 vec![
3379 SortOptions {
3380 descending: true,
3381 nulls_first: false,
3382 };
3383 2
3384 ],
3385 true,
3386 )
3387 .await?;
3388
3389 let expected = [
3390 "+----+----+----+",
3391 "| a1 | b1 | c2 |",
3392 "+----+----+----+",
3393 "| 3 | | 9 |",
3394 "| 2 | 5 | |",
3395 "| 2 | 5 | 8 |",
3396 "+----+----+----+",
3397 ];
3398 assert_batches_eq!(expected, &batches);
3400 Ok(())
3401 }
3402
3403 #[tokio::test]
3404 async fn join_right_anti_output_two_batches() -> Result<()> {
3405 let left = build_table(
3406 ("a1", &vec![1, 2, 2]),
3407 ("b1", &vec![4, 5, 5]),
3408 ("c1", &vec![7, 8, 8]),
3409 );
3410 let right = build_table(
3411 ("a2", &vec![10, 20, 30]),
3412 ("b1", &vec![4, 5, 6]),
3413 ("c2", &vec![70, 80, 90]),
3414 );
3415 let on = vec![
3416 (
3417 Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3418 Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
3419 ),
3420 (
3421 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3422 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3423 ),
3424 ];
3425
3426 let (_, batches) =
3427 join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?;
3428 let expected = [
3429 "+----+----+----+",
3430 "| a1 | b1 | c1 |",
3431 "+----+----+----+",
3432 "| 1 | 4 | 7 |",
3433 "| 2 | 5 | 8 |",
3434 "| 2 | 5 | 8 |",
3435 "+----+----+----+",
3436 ];
3437 assert_eq!(batches.len(), 2);
3438 assert_eq!(batches[0].num_rows(), 2);
3439 assert_eq!(batches[1].num_rows(), 1);
3440 assert_batches_eq!(expected, &batches);
3441 Ok(())
3442 }
3443
3444 #[tokio::test]
3445 async fn join_semi() -> Result<()> {
3446 let left = build_table(
3447 ("a1", &vec![1, 2, 2, 3]),
3448 ("b1", &vec![4, 5, 5, 7]), ("c1", &vec![7, 8, 8, 9]),
3450 );
3451 let right = build_table(
3452 ("a2", &vec![10, 20, 30]),
3453 ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]),
3455 );
3456 let on = vec![(
3457 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3458 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3459 )];
3460
3461 let (_, batches) = join_collect(left, right, on, LeftSemi).await?;
3462 let expected = [
3463 "+----+----+----+",
3464 "| a1 | b1 | c1 |",
3465 "+----+----+----+",
3466 "| 1 | 4 | 7 |",
3467 "| 2 | 5 | 8 |",
3468 "| 2 | 5 | 8 |",
3469 "+----+----+----+",
3470 ];
3471 assert_batches_eq!(expected, &batches);
3473 Ok(())
3474 }
3475
3476 #[tokio::test]
3477 async fn join_left_mark() -> Result<()> {
3478 let left = build_table(
3479 ("a1", &vec![1, 2, 2, 3]),
3480 ("b1", &vec![4, 5, 5, 7]), ("c1", &vec![7, 8, 8, 9]),
3482 );
3483 let right = build_table(
3484 ("a2", &vec![10, 20, 30, 40]),
3485 ("b1", &vec![4, 4, 5, 6]), ("c2", &vec![60, 70, 80, 90]),
3487 );
3488 let on = vec![(
3489 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3490 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3491 )];
3492
3493 let (_, batches) = join_collect(left, right, on, LeftMark).await?;
3494 let expected = [
3495 "+----+----+----+-------+",
3496 "| a1 | b1 | c1 | mark |",
3497 "+----+----+----+-------+",
3498 "| 1 | 4 | 7 | true |",
3499 "| 2 | 5 | 8 | true |",
3500 "| 2 | 5 | 8 | true |",
3501 "| 3 | 7 | 9 | false |",
3502 "+----+----+----+-------+",
3503 ];
3504 assert_batches_eq!(expected, &batches);
3506 Ok(())
3507 }
3508
3509 #[tokio::test]
3510 async fn join_with_duplicated_column_names() -> Result<()> {
3511 let left = build_table(
3512 ("a", &vec![1, 2, 3]),
3513 ("b", &vec![4, 5, 7]),
3514 ("c", &vec![7, 8, 9]),
3515 );
3516 let right = build_table(
3517 ("a", &vec![10, 20, 30]),
3518 ("b", &vec![1, 2, 7]),
3519 ("c", &vec![70, 80, 90]),
3520 );
3521 let on = vec![(
3522 Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
3524 Arc::new(Column::new_with_schema("b", &right.schema())?) as _,
3525 )];
3526
3527 let (_, batches) = join_collect(left, right, on, Inner).await?;
3528 let expected = [
3529 "+---+---+---+----+---+----+",
3530 "| a | b | c | a | b | c |",
3531 "+---+---+---+----+---+----+",
3532 "| 1 | 4 | 7 | 10 | 1 | 70 |",
3533 "| 2 | 5 | 8 | 20 | 2 | 80 |",
3534 "+---+---+---+----+---+----+",
3535 ];
3536 assert_batches_eq!(expected, &batches);
3538 Ok(())
3539 }
3540
3541 #[tokio::test]
3542 async fn join_date32() -> Result<()> {
3543 let left = build_date_table(
3544 ("a1", &vec![1, 2, 3]),
3545 ("b1", &vec![19107, 19108, 19108]), ("c1", &vec![7, 8, 9]),
3547 );
3548 let right = build_date_table(
3549 ("a2", &vec![10, 20, 30]),
3550 ("b1", &vec![19107, 19108, 19109]),
3551 ("c2", &vec![70, 80, 90]),
3552 );
3553
3554 let on = vec![(
3555 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3556 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3557 )];
3558
3559 let (_, batches) = join_collect(left, right, on, Inner).await?;
3560
3561 let expected = ["+------------+------------+------------+------------+------------+------------+",
3562 "| a1 | b1 | c1 | a2 | b1 | c2 |",
3563 "+------------+------------+------------+------------+------------+------------+",
3564 "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |",
3565 "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
3566 "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
3567 "+------------+------------+------------+------------+------------+------------+"];
3568 assert_batches_eq!(expected, &batches);
3570 Ok(())
3571 }
3572
3573 #[tokio::test]
3574 async fn join_date64() -> Result<()> {
3575 let left = build_date64_table(
3576 ("a1", &vec![1, 2, 3]),
3577 ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), ("c1", &vec![7, 8, 9]),
3579 );
3580 let right = build_date64_table(
3581 ("a2", &vec![10, 20, 30]),
3582 ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
3583 ("c2", &vec![70, 80, 90]),
3584 );
3585
3586 let on = vec![(
3587 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3588 Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3589 )];
3590
3591 let (_, batches) = join_collect(left, right, on, Inner).await?;
3592
3593 let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
3594 "| a1 | b1 | c1 | a2 | b1 | c2 |",
3595 "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
3596 "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |",
3597 "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
3598 "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
3599 "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+"];
3600 assert_batches_eq!(expected, &batches);
3602 Ok(())
3603 }
3604
3605 #[tokio::test]
3606 async fn join_left_sort_order() -> Result<()> {
3607 let left = build_table(
3608 ("a1", &vec![0, 1, 2, 3, 4, 5]),
3609 ("b1", &vec![3, 4, 5, 6, 6, 7]),
3610 ("c1", &vec![4, 5, 6, 7, 8, 9]),
3611 );
3612 let right = build_table(
3613 ("a2", &vec![0, 10, 20, 30, 40]),
3614 ("b2", &vec![2, 4, 6, 6, 8]),
3615 ("c2", &vec![50, 60, 70, 80, 90]),
3616 );
3617 let on = vec![(
3618 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3619 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3620 )];
3621
3622 let (_, batches) = join_collect(left, right, on, Left).await?;
3623 let expected = [
3624 "+----+----+----+----+----+----+",
3625 "| a1 | b1 | c1 | a2 | b2 | c2 |",
3626 "+----+----+----+----+----+----+",
3627 "| 0 | 3 | 4 | | | |",
3628 "| 1 | 4 | 5 | 10 | 4 | 60 |",
3629 "| 2 | 5 | 6 | | | |",
3630 "| 3 | 6 | 7 | 20 | 6 | 70 |",
3631 "| 3 | 6 | 7 | 30 | 6 | 80 |",
3632 "| 4 | 6 | 8 | 20 | 6 | 70 |",
3633 "| 4 | 6 | 8 | 30 | 6 | 80 |",
3634 "| 5 | 7 | 9 | | | |",
3635 "+----+----+----+----+----+----+",
3636 ];
3637 assert_batches_eq!(expected, &batches);
3638 Ok(())
3639 }
3640
3641 #[tokio::test]
3642 async fn join_right_sort_order() -> Result<()> {
3643 let left = build_table(
3644 ("a1", &vec![0, 1, 2, 3]),
3645 ("b1", &vec![3, 4, 5, 7]),
3646 ("c1", &vec![6, 7, 8, 9]),
3647 );
3648 let right = build_table(
3649 ("a2", &vec![0, 10, 20, 30]),
3650 ("b2", &vec![2, 4, 5, 6]),
3651 ("c2", &vec![60, 70, 80, 90]),
3652 );
3653 let on = vec![(
3654 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3655 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3656 )];
3657
3658 let (_, batches) = join_collect(left, right, on, Right).await?;
3659 let expected = [
3660 "+----+----+----+----+----+----+",
3661 "| a1 | b1 | c1 | a2 | b2 | c2 |",
3662 "+----+----+----+----+----+----+",
3663 "| | | | 0 | 2 | 60 |",
3664 "| 1 | 4 | 7 | 10 | 4 | 70 |",
3665 "| 2 | 5 | 8 | 20 | 5 | 80 |",
3666 "| | | | 30 | 6 | 90 |",
3667 "+----+----+----+----+----+----+",
3668 ];
3669 assert_batches_eq!(expected, &batches);
3670 Ok(())
3671 }
3672
3673 #[tokio::test]
3674 async fn join_left_multiple_batches() -> Result<()> {
3675 let left_batch_1 = build_table_i32(
3676 ("a1", &vec![0, 1, 2]),
3677 ("b1", &vec![3, 4, 5]),
3678 ("c1", &vec![4, 5, 6]),
3679 );
3680 let left_batch_2 = build_table_i32(
3681 ("a1", &vec![3, 4, 5, 6]),
3682 ("b1", &vec![6, 6, 7, 9]),
3683 ("c1", &vec![7, 8, 9, 9]),
3684 );
3685 let right_batch_1 = build_table_i32(
3686 ("a2", &vec![0, 10, 20]),
3687 ("b2", &vec![2, 4, 6]),
3688 ("c2", &vec![50, 60, 70]),
3689 );
3690 let right_batch_2 = build_table_i32(
3691 ("a2", &vec![30, 40]),
3692 ("b2", &vec![6, 8]),
3693 ("c2", &vec![80, 90]),
3694 );
3695 let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
3696 let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
3697 let on = vec![(
3698 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3699 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3700 )];
3701
3702 let (_, batches) = join_collect(left, right, on, Left).await?;
3703 let expected = vec![
3704 "+----+----+----+----+----+----+",
3705 "| a1 | b1 | c1 | a2 | b2 | c2 |",
3706 "+----+----+----+----+----+----+",
3707 "| 0 | 3 | 4 | | | |",
3708 "| 1 | 4 | 5 | 10 | 4 | 60 |",
3709 "| 2 | 5 | 6 | | | |",
3710 "| 3 | 6 | 7 | 20 | 6 | 70 |",
3711 "| 3 | 6 | 7 | 30 | 6 | 80 |",
3712 "| 4 | 6 | 8 | 20 | 6 | 70 |",
3713 "| 4 | 6 | 8 | 30 | 6 | 80 |",
3714 "| 5 | 7 | 9 | | | |",
3715 "| 6 | 9 | 9 | | | |",
3716 "+----+----+----+----+----+----+",
3717 ];
3718 assert_batches_eq!(expected, &batches);
3719 Ok(())
3720 }
3721
3722 #[tokio::test]
3723 async fn join_right_multiple_batches() -> Result<()> {
3724 let right_batch_1 = build_table_i32(
3725 ("a2", &vec![0, 1, 2]),
3726 ("b2", &vec![3, 4, 5]),
3727 ("c2", &vec![4, 5, 6]),
3728 );
3729 let right_batch_2 = build_table_i32(
3730 ("a2", &vec![3, 4, 5, 6]),
3731 ("b2", &vec![6, 6, 7, 9]),
3732 ("c2", &vec![7, 8, 9, 9]),
3733 );
3734 let left_batch_1 = build_table_i32(
3735 ("a1", &vec![0, 10, 20]),
3736 ("b1", &vec![2, 4, 6]),
3737 ("c1", &vec![50, 60, 70]),
3738 );
3739 let left_batch_2 = build_table_i32(
3740 ("a1", &vec![30, 40]),
3741 ("b1", &vec![6, 8]),
3742 ("c1", &vec![80, 90]),
3743 );
3744 let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
3745 let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
3746 let on = vec![(
3747 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3748 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3749 )];
3750
3751 let (_, batches) = join_collect(left, right, on, Right).await?;
3752 let expected = vec![
3753 "+----+----+----+----+----+----+",
3754 "| a1 | b1 | c1 | a2 | b2 | c2 |",
3755 "+----+----+----+----+----+----+",
3756 "| | | | 0 | 3 | 4 |",
3757 "| 10 | 4 | 60 | 1 | 4 | 5 |",
3758 "| | | | 2 | 5 | 6 |",
3759 "| 20 | 6 | 70 | 3 | 6 | 7 |",
3760 "| 30 | 6 | 80 | 3 | 6 | 7 |",
3761 "| 20 | 6 | 70 | 4 | 6 | 8 |",
3762 "| 30 | 6 | 80 | 4 | 6 | 8 |",
3763 "| | | | 5 | 7 | 9 |",
3764 "| | | | 6 | 9 | 9 |",
3765 "+----+----+----+----+----+----+",
3766 ];
3767 assert_batches_eq!(expected, &batches);
3768 Ok(())
3769 }
3770
3771 #[tokio::test]
3772 async fn join_full_multiple_batches() -> Result<()> {
3773 let left_batch_1 = build_table_i32(
3774 ("a1", &vec![0, 1, 2]),
3775 ("b1", &vec![3, 4, 5]),
3776 ("c1", &vec![4, 5, 6]),
3777 );
3778 let left_batch_2 = build_table_i32(
3779 ("a1", &vec![3, 4, 5, 6]),
3780 ("b1", &vec![6, 6, 7, 9]),
3781 ("c1", &vec![7, 8, 9, 9]),
3782 );
3783 let right_batch_1 = build_table_i32(
3784 ("a2", &vec![0, 10, 20]),
3785 ("b2", &vec![2, 4, 6]),
3786 ("c2", &vec![50, 60, 70]),
3787 );
3788 let right_batch_2 = build_table_i32(
3789 ("a2", &vec![30, 40]),
3790 ("b2", &vec![6, 8]),
3791 ("c2", &vec![80, 90]),
3792 );
3793 let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
3794 let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
3795 let on = vec![(
3796 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3797 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3798 )];
3799
3800 let (_, batches) = join_collect(left, right, on, Full).await?;
3801 let expected = vec![
3802 "+----+----+----+----+----+----+",
3803 "| a1 | b1 | c1 | a2 | b2 | c2 |",
3804 "+----+----+----+----+----+----+",
3805 "| | | | 0 | 2 | 50 |",
3806 "| | | | 40 | 8 | 90 |",
3807 "| 0 | 3 | 4 | | | |",
3808 "| 1 | 4 | 5 | 10 | 4 | 60 |",
3809 "| 2 | 5 | 6 | | | |",
3810 "| 3 | 6 | 7 | 20 | 6 | 70 |",
3811 "| 3 | 6 | 7 | 30 | 6 | 80 |",
3812 "| 4 | 6 | 8 | 20 | 6 | 70 |",
3813 "| 4 | 6 | 8 | 30 | 6 | 80 |",
3814 "| 5 | 7 | 9 | | | |",
3815 "| 6 | 9 | 9 | | | |",
3816 "+----+----+----+----+----+----+",
3817 ];
3818 assert_batches_sorted_eq!(expected, &batches);
3819 Ok(())
3820 }
3821
3822 #[tokio::test]
3823 async fn overallocation_single_batch_no_spill() -> Result<()> {
3824 let left = build_table(
3825 ("a1", &vec![0, 1, 2, 3, 4, 5]),
3826 ("b1", &vec![1, 2, 3, 4, 5, 6]),
3827 ("c1", &vec![4, 5, 6, 7, 8, 9]),
3828 );
3829 let right = build_table(
3830 ("a2", &vec![0, 10, 20, 30, 40]),
3831 ("b2", &vec![1, 3, 4, 6, 8]),
3832 ("c2", &vec![50, 60, 70, 80, 90]),
3833 );
3834 let on = vec![(
3835 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3836 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3837 )];
3838 let sort_options = vec![SortOptions::default(); on.len()];
3839
3840 let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
3841
3842 let runtime = RuntimeEnvBuilder::new()
3844 .with_memory_limit(100, 1.0)
3845 .with_disk_manager(DiskManagerConfig::Disabled)
3846 .build_arc()?;
3847 let session_config = SessionConfig::default().with_batch_size(50);
3848
3849 for join_type in join_types {
3850 let task_ctx = TaskContext::default()
3851 .with_session_config(session_config.clone())
3852 .with_runtime(Arc::clone(&runtime));
3853 let task_ctx = Arc::new(task_ctx);
3854
3855 let join = join_with_options(
3856 Arc::clone(&left),
3857 Arc::clone(&right),
3858 on.clone(),
3859 join_type,
3860 sort_options.clone(),
3861 false,
3862 )?;
3863
3864 let stream = join.execute(0, task_ctx)?;
3865 let err = common::collect(stream).await.unwrap_err();
3866
3867 assert_contains!(err.to_string(), "Failed to allocate additional");
3868 assert_contains!(err.to_string(), "SMJStream[0]");
3869 assert_contains!(err.to_string(), "Disk spilling disabled");
3870 assert!(join.metrics().is_some());
3871 assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3872 assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3873 assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3874 }
3875
3876 Ok(())
3877 }
3878
3879 #[tokio::test]
3880 async fn overallocation_multi_batch_no_spill() -> Result<()> {
3881 let left_batch_1 = build_table_i32(
3882 ("a1", &vec![0, 1]),
3883 ("b1", &vec![1, 1]),
3884 ("c1", &vec![4, 5]),
3885 );
3886 let left_batch_2 = build_table_i32(
3887 ("a1", &vec![2, 3]),
3888 ("b1", &vec![1, 1]),
3889 ("c1", &vec![6, 7]),
3890 );
3891 let left_batch_3 = build_table_i32(
3892 ("a1", &vec![4, 5]),
3893 ("b1", &vec![1, 1]),
3894 ("c1", &vec![8, 9]),
3895 );
3896 let right_batch_1 = build_table_i32(
3897 ("a2", &vec![0, 10]),
3898 ("b2", &vec![1, 1]),
3899 ("c2", &vec![50, 60]),
3900 );
3901 let right_batch_2 = build_table_i32(
3902 ("a2", &vec![20, 30]),
3903 ("b2", &vec![1, 1]),
3904 ("c2", &vec![70, 80]),
3905 );
3906 let right_batch_3 =
3907 build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
3908 let left =
3909 build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
3910 let right =
3911 build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
3912 let on = vec![(
3913 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3914 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3915 )];
3916 let sort_options = vec![SortOptions::default(); on.len()];
3917
3918 let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
3919
3920 let runtime = RuntimeEnvBuilder::new()
3922 .with_memory_limit(100, 1.0)
3923 .with_disk_manager(DiskManagerConfig::Disabled)
3924 .build_arc()?;
3925 let session_config = SessionConfig::default().with_batch_size(50);
3926
3927 for join_type in join_types {
3928 let task_ctx = TaskContext::default()
3929 .with_session_config(session_config.clone())
3930 .with_runtime(Arc::clone(&runtime));
3931 let task_ctx = Arc::new(task_ctx);
3932 let join = join_with_options(
3933 Arc::clone(&left),
3934 Arc::clone(&right),
3935 on.clone(),
3936 join_type,
3937 sort_options.clone(),
3938 false,
3939 )?;
3940
3941 let stream = join.execute(0, task_ctx)?;
3942 let err = common::collect(stream).await.unwrap_err();
3943
3944 assert_contains!(err.to_string(), "Failed to allocate additional");
3945 assert_contains!(err.to_string(), "SMJStream[0]");
3946 assert_contains!(err.to_string(), "Disk spilling disabled");
3947 assert!(join.metrics().is_some());
3948 assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3949 assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3950 assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3951 }
3952
3953 Ok(())
3954 }
3955
3956 #[tokio::test]
3957 async fn overallocation_single_batch_spill() -> Result<()> {
3958 let left = build_table(
3959 ("a1", &vec![0, 1, 2, 3, 4, 5]),
3960 ("b1", &vec![1, 2, 3, 4, 5, 6]),
3961 ("c1", &vec![4, 5, 6, 7, 8, 9]),
3962 );
3963 let right = build_table(
3964 ("a2", &vec![0, 10, 20, 30, 40]),
3965 ("b2", &vec![1, 3, 4, 6, 8]),
3966 ("c2", &vec![50, 60, 70, 80, 90]),
3967 );
3968 let on = vec![(
3969 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3970 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3971 )];
3972 let sort_options = vec![SortOptions::default(); on.len()];
3973
3974 let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
3975
3976 let runtime = RuntimeEnvBuilder::new()
3978 .with_memory_limit(100, 1.0)
3979 .with_disk_manager(DiskManagerConfig::NewOs)
3980 .build_arc()?;
3981
3982 for batch_size in [1, 50] {
3983 let session_config = SessionConfig::default().with_batch_size(batch_size);
3984
3985 for join_type in &join_types {
3986 let task_ctx = TaskContext::default()
3987 .with_session_config(session_config.clone())
3988 .with_runtime(Arc::clone(&runtime));
3989 let task_ctx = Arc::new(task_ctx);
3990
3991 let join = join_with_options(
3992 Arc::clone(&left),
3993 Arc::clone(&right),
3994 on.clone(),
3995 *join_type,
3996 sort_options.clone(),
3997 false,
3998 )?;
3999
4000 let stream = join.execute(0, task_ctx)?;
4001 let spilled_join_result = common::collect(stream).await.unwrap();
4002
4003 assert!(join.metrics().is_some());
4004 assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
4005 assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
4006 assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
4007
4008 let task_ctx_no_spill =
4010 TaskContext::default().with_session_config(session_config.clone());
4011 let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
4012
4013 let join = join_with_options(
4014 Arc::clone(&left),
4015 Arc::clone(&right),
4016 on.clone(),
4017 *join_type,
4018 sort_options.clone(),
4019 false,
4020 )?;
4021 let stream = join.execute(0, task_ctx_no_spill)?;
4022 let no_spilled_join_result = common::collect(stream).await.unwrap();
4023
4024 assert!(join.metrics().is_some());
4025 assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
4026 assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
4027 assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
4028 assert_eq!(spilled_join_result, no_spilled_join_result);
4030 }
4031 }
4032
4033 Ok(())
4034 }
4035
4036 #[tokio::test]
4037 async fn overallocation_multi_batch_spill() -> Result<()> {
4038 let left_batch_1 = build_table_i32(
4039 ("a1", &vec![0, 1]),
4040 ("b1", &vec![1, 1]),
4041 ("c1", &vec![4, 5]),
4042 );
4043 let left_batch_2 = build_table_i32(
4044 ("a1", &vec![2, 3]),
4045 ("b1", &vec![1, 1]),
4046 ("c1", &vec![6, 7]),
4047 );
4048 let left_batch_3 = build_table_i32(
4049 ("a1", &vec![4, 5]),
4050 ("b1", &vec![1, 1]),
4051 ("c1", &vec![8, 9]),
4052 );
4053 let right_batch_1 = build_table_i32(
4054 ("a2", &vec![0, 10]),
4055 ("b2", &vec![1, 1]),
4056 ("c2", &vec![50, 60]),
4057 );
4058 let right_batch_2 = build_table_i32(
4059 ("a2", &vec![20, 30]),
4060 ("b2", &vec![1, 1]),
4061 ("c2", &vec![70, 80]),
4062 );
4063 let right_batch_3 =
4064 build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
4065 let left =
4066 build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
4067 let right =
4068 build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
4069 let on = vec![(
4070 Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
4071 Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
4072 )];
4073 let sort_options = vec![SortOptions::default(); on.len()];
4074
4075 let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
4076
4077 let runtime = RuntimeEnvBuilder::new()
4079 .with_memory_limit(500, 1.0)
4080 .with_disk_manager(DiskManagerConfig::NewOs)
4081 .build_arc()?;
4082
4083 for batch_size in [1, 50] {
4084 let session_config = SessionConfig::default().with_batch_size(batch_size);
4085
4086 for join_type in &join_types {
4087 let task_ctx = TaskContext::default()
4088 .with_session_config(session_config.clone())
4089 .with_runtime(Arc::clone(&runtime));
4090 let task_ctx = Arc::new(task_ctx);
4091 let join = join_with_options(
4092 Arc::clone(&left),
4093 Arc::clone(&right),
4094 on.clone(),
4095 *join_type,
4096 sort_options.clone(),
4097 false,
4098 )?;
4099
4100 let stream = join.execute(0, task_ctx)?;
4101 let spilled_join_result = common::collect(stream).await.unwrap();
4102 assert!(join.metrics().is_some());
4103 assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
4104 assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
4105 assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
4106
4107 let task_ctx_no_spill =
4109 TaskContext::default().with_session_config(session_config.clone());
4110 let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
4111
4112 let join = join_with_options(
4113 Arc::clone(&left),
4114 Arc::clone(&right),
4115 on.clone(),
4116 *join_type,
4117 sort_options.clone(),
4118 false,
4119 )?;
4120 let stream = join.execute(0, task_ctx_no_spill)?;
4121 let no_spilled_join_result = common::collect(stream).await.unwrap();
4122
4123 assert!(join.metrics().is_some());
4124 assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
4125 assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
4126 assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
4127 assert_eq!(spilled_join_result, no_spilled_join_result);
4129 }
4130 }
4131
4132 Ok(())
4133 }
4134
4135 fn build_joined_record_batches() -> Result<JoinedRecordBatches> {
4136 let schema = Arc::new(Schema::new(vec![
4137 Field::new("a", DataType::Int32, true),
4138 Field::new("b", DataType::Int32, true),
4139 Field::new("x", DataType::Int32, true),
4140 Field::new("y", DataType::Int32, true),
4141 ]));
4142
4143 let mut batches = JoinedRecordBatches {
4144 batches: vec![],
4145 filter_mask: BooleanBuilder::new(),
4146 row_indices: UInt64Builder::new(),
4147 batch_ids: vec![],
4148 };
4149
4150 batches.batches.push(RecordBatch::try_new(
4152 Arc::clone(&schema),
4153 vec![
4154 Arc::new(Int32Array::from(vec![1, 1])),
4155 Arc::new(Int32Array::from(vec![10, 10])),
4156 Arc::new(Int32Array::from(vec![1, 1])),
4157 Arc::new(Int32Array::from(vec![11, 9])),
4158 ],
4159 )?);
4160
4161 batches.batches.push(RecordBatch::try_new(
4162 Arc::clone(&schema),
4163 vec![
4164 Arc::new(Int32Array::from(vec![1])),
4165 Arc::new(Int32Array::from(vec![11])),
4166 Arc::new(Int32Array::from(vec![1])),
4167 Arc::new(Int32Array::from(vec![12])),
4168 ],
4169 )?);
4170
4171 batches.batches.push(RecordBatch::try_new(
4172 Arc::clone(&schema),
4173 vec![
4174 Arc::new(Int32Array::from(vec![1, 1])),
4175 Arc::new(Int32Array::from(vec![12, 12])),
4176 Arc::new(Int32Array::from(vec![1, 1])),
4177 Arc::new(Int32Array::from(vec![11, 13])),
4178 ],
4179 )?);
4180
4181 batches.batches.push(RecordBatch::try_new(
4182 Arc::clone(&schema),
4183 vec![
4184 Arc::new(Int32Array::from(vec![1])),
4185 Arc::new(Int32Array::from(vec![13])),
4186 Arc::new(Int32Array::from(vec![1])),
4187 Arc::new(Int32Array::from(vec![12])),
4188 ],
4189 )?);
4190
4191 batches.batches.push(RecordBatch::try_new(
4192 Arc::clone(&schema),
4193 vec![
4194 Arc::new(Int32Array::from(vec![1, 1])),
4195 Arc::new(Int32Array::from(vec![14, 14])),
4196 Arc::new(Int32Array::from(vec![1, 1])),
4197 Arc::new(Int32Array::from(vec![12, 11])),
4198 ],
4199 )?);
4200
4201 let streamed_indices = vec![0, 0];
4202 batches.batch_ids.extend(vec![0; streamed_indices.len()]);
4203 batches
4204 .row_indices
4205 .extend(&UInt64Array::from(streamed_indices));
4206
4207 let streamed_indices = vec![1];
4208 batches.batch_ids.extend(vec![0; streamed_indices.len()]);
4209 batches
4210 .row_indices
4211 .extend(&UInt64Array::from(streamed_indices));
4212
4213 let streamed_indices = vec![0, 0];
4214 batches.batch_ids.extend(vec![1; streamed_indices.len()]);
4215 batches
4216 .row_indices
4217 .extend(&UInt64Array::from(streamed_indices));
4218
4219 let streamed_indices = vec![0];
4220 batches.batch_ids.extend(vec![2; streamed_indices.len()]);
4221 batches
4222 .row_indices
4223 .extend(&UInt64Array::from(streamed_indices));
4224
4225 let streamed_indices = vec![0, 0];
4226 batches.batch_ids.extend(vec![3; streamed_indices.len()]);
4227 batches
4228 .row_indices
4229 .extend(&UInt64Array::from(streamed_indices));
4230
4231 batches
4232 .filter_mask
4233 .extend(&BooleanArray::from(vec![true, false]));
4234 batches.filter_mask.extend(&BooleanArray::from(vec![true]));
4235 batches
4236 .filter_mask
4237 .extend(&BooleanArray::from(vec![false, true]));
4238 batches.filter_mask.extend(&BooleanArray::from(vec![false]));
4239 batches
4240 .filter_mask
4241 .extend(&BooleanArray::from(vec![false, false]));
4242
4243 Ok(batches)
4244 }
4245
4246 #[tokio::test]
4247 async fn test_left_outer_join_filtered_mask() -> Result<()> {
4248 let mut joined_batches = build_joined_record_batches()?;
4249 let schema = joined_batches.batches.first().unwrap().schema();
4250
4251 let output = concat_batches(&schema, &joined_batches.batches)?;
4252 let out_mask = joined_batches.filter_mask.finish();
4253 let out_indices = joined_batches.row_indices.finish();
4254
4255 assert_eq!(
4256 get_corrected_filter_mask(
4257 Left,
4258 &UInt64Array::from(vec![0]),
4259 &[0usize],
4260 &BooleanArray::from(vec![true]),
4261 output.num_rows()
4262 )
4263 .unwrap(),
4264 BooleanArray::from(vec![
4265 true, false, false, false, false, false, false, false
4266 ])
4267 );
4268
4269 assert_eq!(
4270 get_corrected_filter_mask(
4271 Left,
4272 &UInt64Array::from(vec![0]),
4273 &[0usize],
4274 &BooleanArray::from(vec![false]),
4275 output.num_rows()
4276 )
4277 .unwrap(),
4278 BooleanArray::from(vec![
4279 false, false, false, false, false, false, false, false
4280 ])
4281 );
4282
4283 assert_eq!(
4284 get_corrected_filter_mask(
4285 Left,
4286 &UInt64Array::from(vec![0, 0]),
4287 &[0usize; 2],
4288 &BooleanArray::from(vec![true, true]),
4289 output.num_rows()
4290 )
4291 .unwrap(),
4292 BooleanArray::from(vec![
4293 true, true, false, false, false, false, false, false
4294 ])
4295 );
4296
4297 assert_eq!(
4298 get_corrected_filter_mask(
4299 Left,
4300 &UInt64Array::from(vec![0, 0, 0]),
4301 &[0usize; 3],
4302 &BooleanArray::from(vec![true, true, true]),
4303 output.num_rows()
4304 )
4305 .unwrap(),
4306 BooleanArray::from(vec![true, true, true, false, false, false, false, false])
4307 );
4308
4309 assert_eq!(
4310 get_corrected_filter_mask(
4311 Left,
4312 &UInt64Array::from(vec![0, 0, 0]),
4313 &[0usize; 3],
4314 &BooleanArray::from(vec![true, false, true]),
4315 output.num_rows()
4316 )
4317 .unwrap(),
4318 BooleanArray::from(vec![
4319 Some(true),
4320 None,
4321 Some(true),
4322 Some(false),
4323 Some(false),
4324 Some(false),
4325 Some(false),
4326 Some(false)
4327 ])
4328 );
4329
4330 assert_eq!(
4331 get_corrected_filter_mask(
4332 Left,
4333 &UInt64Array::from(vec![0, 0, 0]),
4334 &[0usize; 3],
4335 &BooleanArray::from(vec![false, false, true]),
4336 output.num_rows()
4337 )
4338 .unwrap(),
4339 BooleanArray::from(vec![
4340 None,
4341 None,
4342 Some(true),
4343 Some(false),
4344 Some(false),
4345 Some(false),
4346 Some(false),
4347 Some(false)
4348 ])
4349 );
4350
4351 assert_eq!(
4352 get_corrected_filter_mask(
4353 Left,
4354 &UInt64Array::from(vec![0, 0, 0]),
4355 &[0usize; 3],
4356 &BooleanArray::from(vec![false, true, true]),
4357 output.num_rows()
4358 )
4359 .unwrap(),
4360 BooleanArray::from(vec![
4361 None,
4362 Some(true),
4363 Some(true),
4364 Some(false),
4365 Some(false),
4366 Some(false),
4367 Some(false),
4368 Some(false)
4369 ])
4370 );
4371
4372 assert_eq!(
4373 get_corrected_filter_mask(
4374 Left,
4375 &UInt64Array::from(vec![0, 0, 0]),
4376 &[0usize; 3],
4377 &BooleanArray::from(vec![false, false, false]),
4378 output.num_rows()
4379 )
4380 .unwrap(),
4381 BooleanArray::from(vec![
4382 None,
4383 None,
4384 Some(false),
4385 Some(false),
4386 Some(false),
4387 Some(false),
4388 Some(false),
4389 Some(false)
4390 ])
4391 );
4392
4393 let corrected_mask = get_corrected_filter_mask(
4394 Left,
4395 &out_indices,
4396 &joined_batches.batch_ids,
4397 &out_mask,
4398 output.num_rows(),
4399 )
4400 .unwrap();
4401
4402 assert_eq!(
4403 corrected_mask,
4404 BooleanArray::from(vec![
4405 Some(true),
4406 None,
4407 Some(true),
4408 None,
4409 Some(true),
4410 Some(false),
4411 None,
4412 Some(false)
4413 ])
4414 );
4415
4416 let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4417
4418 assert_batches_eq!(
4419 &[
4420 "+---+----+---+----+",
4421 "| a | b | x | y |",
4422 "+---+----+---+----+",
4423 "| 1 | 10 | 1 | 11 |",
4424 "| 1 | 11 | 1 | 12 |",
4425 "| 1 | 12 | 1 | 13 |",
4426 "+---+----+---+----+",
4427 ],
4428 &[filtered_rb]
4429 );
4430
4431 let null_mask = arrow::compute::not(&corrected_mask)?;
4434 assert_eq!(
4435 null_mask,
4436 BooleanArray::from(vec![
4437 Some(false),
4438 None,
4439 Some(false),
4440 None,
4441 Some(false),
4442 Some(true),
4443 None,
4444 Some(true)
4445 ])
4446 );
4447
4448 let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4449
4450 assert_batches_eq!(
4451 &[
4452 "+---+----+---+----+",
4453 "| a | b | x | y |",
4454 "+---+----+---+----+",
4455 "| 1 | 13 | 1 | 12 |",
4456 "| 1 | 14 | 1 | 11 |",
4457 "+---+----+---+----+",
4458 ],
4459 &[null_joined_batch]
4460 );
4461 Ok(())
4462 }
4463
4464 #[tokio::test]
4465 async fn test_left_semi_join_filtered_mask() -> Result<()> {
4466 let mut joined_batches = build_joined_record_batches()?;
4467 let schema = joined_batches.batches.first().unwrap().schema();
4468
4469 let output = concat_batches(&schema, &joined_batches.batches)?;
4470 let out_mask = joined_batches.filter_mask.finish();
4471 let out_indices = joined_batches.row_indices.finish();
4472
4473 assert_eq!(
4474 get_corrected_filter_mask(
4475 LeftSemi,
4476 &UInt64Array::from(vec![0]),
4477 &[0usize],
4478 &BooleanArray::from(vec![true]),
4479 output.num_rows()
4480 )
4481 .unwrap(),
4482 BooleanArray::from(vec![true])
4483 );
4484
4485 assert_eq!(
4486 get_corrected_filter_mask(
4487 LeftSemi,
4488 &UInt64Array::from(vec![0]),
4489 &[0usize],
4490 &BooleanArray::from(vec![false]),
4491 output.num_rows()
4492 )
4493 .unwrap(),
4494 BooleanArray::from(vec![None])
4495 );
4496
4497 assert_eq!(
4498 get_corrected_filter_mask(
4499 LeftSemi,
4500 &UInt64Array::from(vec![0, 0]),
4501 &[0usize; 2],
4502 &BooleanArray::from(vec![true, true]),
4503 output.num_rows()
4504 )
4505 .unwrap(),
4506 BooleanArray::from(vec![Some(true), None])
4507 );
4508
4509 assert_eq!(
4510 get_corrected_filter_mask(
4511 LeftSemi,
4512 &UInt64Array::from(vec![0, 0, 0]),
4513 &[0usize; 3],
4514 &BooleanArray::from(vec![true, true, true]),
4515 output.num_rows()
4516 )
4517 .unwrap(),
4518 BooleanArray::from(vec![Some(true), None, None])
4519 );
4520
4521 assert_eq!(
4522 get_corrected_filter_mask(
4523 LeftSemi,
4524 &UInt64Array::from(vec![0, 0, 0]),
4525 &[0usize; 3],
4526 &BooleanArray::from(vec![true, false, true]),
4527 output.num_rows()
4528 )
4529 .unwrap(),
4530 BooleanArray::from(vec![Some(true), None, None])
4531 );
4532
4533 assert_eq!(
4534 get_corrected_filter_mask(
4535 LeftSemi,
4536 &UInt64Array::from(vec![0, 0, 0]),
4537 &[0usize; 3],
4538 &BooleanArray::from(vec![false, false, true]),
4539 output.num_rows()
4540 )
4541 .unwrap(),
4542 BooleanArray::from(vec![None, None, Some(true),])
4543 );
4544
4545 assert_eq!(
4546 get_corrected_filter_mask(
4547 LeftSemi,
4548 &UInt64Array::from(vec![0, 0, 0]),
4549 &[0usize; 3],
4550 &BooleanArray::from(vec![false, true, true]),
4551 output.num_rows()
4552 )
4553 .unwrap(),
4554 BooleanArray::from(vec![None, Some(true), None])
4555 );
4556
4557 assert_eq!(
4558 get_corrected_filter_mask(
4559 LeftSemi,
4560 &UInt64Array::from(vec![0, 0, 0]),
4561 &[0usize; 3],
4562 &BooleanArray::from(vec![false, false, false]),
4563 output.num_rows()
4564 )
4565 .unwrap(),
4566 BooleanArray::from(vec![None, None, None])
4567 );
4568
4569 let corrected_mask = get_corrected_filter_mask(
4570 LeftSemi,
4571 &out_indices,
4572 &joined_batches.batch_ids,
4573 &out_mask,
4574 output.num_rows(),
4575 )
4576 .unwrap();
4577
4578 assert_eq!(
4579 corrected_mask,
4580 BooleanArray::from(vec![
4581 Some(true),
4582 None,
4583 Some(true),
4584 None,
4585 Some(true),
4586 None,
4587 None,
4588 None
4589 ])
4590 );
4591
4592 let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4593
4594 assert_batches_eq!(
4595 &[
4596 "+---+----+---+----+",
4597 "| a | b | x | y |",
4598 "+---+----+---+----+",
4599 "| 1 | 10 | 1 | 11 |",
4600 "| 1 | 11 | 1 | 12 |",
4601 "| 1 | 12 | 1 | 13 |",
4602 "+---+----+---+----+",
4603 ],
4604 &[filtered_rb]
4605 );
4606
4607 let null_mask = arrow::compute::not(&corrected_mask)?;
4609 assert_eq!(
4610 null_mask,
4611 BooleanArray::from(vec![
4612 Some(false),
4613 None,
4614 Some(false),
4615 None,
4616 Some(false),
4617 None,
4618 None,
4619 None
4620 ])
4621 );
4622
4623 let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4624
4625 assert_batches_eq!(
4626 &[
4627 "+---+---+---+---+",
4628 "| a | b | x | y |",
4629 "+---+---+---+---+",
4630 "+---+---+---+---+",
4631 ],
4632 &[null_joined_batch]
4633 );
4634 Ok(())
4635 }
4636
4637 #[tokio::test]
4638 async fn test_anti_join_filtered_mask() -> Result<()> {
4639 for join_type in [LeftAnti, RightAnti] {
4640 let mut joined_batches = build_joined_record_batches()?;
4641 let schema = joined_batches.batches.first().unwrap().schema();
4642
4643 let output = concat_batches(&schema, &joined_batches.batches)?;
4644 let out_mask = joined_batches.filter_mask.finish();
4645 let out_indices = joined_batches.row_indices.finish();
4646
4647 assert_eq!(
4648 get_corrected_filter_mask(
4649 join_type,
4650 &UInt64Array::from(vec![0]),
4651 &[0usize],
4652 &BooleanArray::from(vec![true]),
4653 1
4654 )
4655 .unwrap(),
4656 BooleanArray::from(vec![None])
4657 );
4658
4659 assert_eq!(
4660 get_corrected_filter_mask(
4661 join_type,
4662 &UInt64Array::from(vec![0]),
4663 &[0usize],
4664 &BooleanArray::from(vec![false]),
4665 1
4666 )
4667 .unwrap(),
4668 BooleanArray::from(vec![Some(true)])
4669 );
4670
4671 assert_eq!(
4672 get_corrected_filter_mask(
4673 join_type,
4674 &UInt64Array::from(vec![0, 0]),
4675 &[0usize; 2],
4676 &BooleanArray::from(vec![true, true]),
4677 2
4678 )
4679 .unwrap(),
4680 BooleanArray::from(vec![None, None])
4681 );
4682
4683 assert_eq!(
4684 get_corrected_filter_mask(
4685 join_type,
4686 &UInt64Array::from(vec![0, 0, 0]),
4687 &[0usize; 3],
4688 &BooleanArray::from(vec![true, true, true]),
4689 3
4690 )
4691 .unwrap(),
4692 BooleanArray::from(vec![None, None, None])
4693 );
4694
4695 assert_eq!(
4696 get_corrected_filter_mask(
4697 join_type,
4698 &UInt64Array::from(vec![0, 0, 0]),
4699 &[0usize; 3],
4700 &BooleanArray::from(vec![true, false, true]),
4701 3
4702 )
4703 .unwrap(),
4704 BooleanArray::from(vec![None, None, None])
4705 );
4706
4707 assert_eq!(
4708 get_corrected_filter_mask(
4709 join_type,
4710 &UInt64Array::from(vec![0, 0, 0]),
4711 &[0usize; 3],
4712 &BooleanArray::from(vec![false, false, true]),
4713 3
4714 )
4715 .unwrap(),
4716 BooleanArray::from(vec![None, None, None])
4717 );
4718
4719 assert_eq!(
4720 get_corrected_filter_mask(
4721 join_type,
4722 &UInt64Array::from(vec![0, 0, 0]),
4723 &[0usize; 3],
4724 &BooleanArray::from(vec![false, true, true]),
4725 3
4726 )
4727 .unwrap(),
4728 BooleanArray::from(vec![None, None, None])
4729 );
4730
4731 assert_eq!(
4732 get_corrected_filter_mask(
4733 join_type,
4734 &UInt64Array::from(vec![0, 0, 0]),
4735 &[0usize; 3],
4736 &BooleanArray::from(vec![false, false, false]),
4737 3
4738 )
4739 .unwrap(),
4740 BooleanArray::from(vec![None, None, Some(true)])
4741 );
4742
4743 let corrected_mask = get_corrected_filter_mask(
4744 join_type,
4745 &out_indices,
4746 &joined_batches.batch_ids,
4747 &out_mask,
4748 output.num_rows(),
4749 )
4750 .unwrap();
4751
4752 assert_eq!(
4753 corrected_mask,
4754 BooleanArray::from(vec![
4755 None,
4756 None,
4757 None,
4758 None,
4759 None,
4760 Some(true),
4761 None,
4762 Some(true)
4763 ])
4764 );
4765
4766 let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4767
4768 assert_batches_eq!(
4769 &[
4770 "+---+----+---+----+",
4771 "| a | b | x | y |",
4772 "+---+----+---+----+",
4773 "| 1 | 13 | 1 | 12 |",
4774 "| 1 | 14 | 1 | 11 |",
4775 "+---+----+---+----+",
4776 ],
4777 &[filtered_rb]
4778 );
4779
4780 let null_mask = arrow::compute::not(&corrected_mask)?;
4782 assert_eq!(
4783 null_mask,
4784 BooleanArray::from(vec![
4785 None,
4786 None,
4787 None,
4788 None,
4789 None,
4790 Some(false),
4791 None,
4792 Some(false),
4793 ])
4794 );
4795
4796 let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4797
4798 assert_batches_eq!(
4799 &[
4800 "+---+---+---+---+",
4801 "| a | b | x | y |",
4802 "+---+---+---+---+",
4803 "+---+---+---+---+",
4804 ],
4805 &[null_joined_batch]
4806 );
4807 }
4808 Ok(())
4809 }
4810
4811 fn columns(schema: &Schema) -> Vec<String> {
4813 schema.fields().iter().map(|f| f.name().clone()).collect()
4814 }
4815}