datafusion_physical_plan/joins/
sort_merge_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the Sort-Merge join execution plan.
//! A Sort-Merge join plan consumes two sorted child plans and produces
//! joined output according to the given join type and other options.
21//! Sort-Merge join feature is currently experimental.
22
23use std::any::Any;
24use std::cmp::Ordering;
25use std::collections::{HashMap, VecDeque};
26use std::fmt::Formatter;
27use std::fs::File;
28use std::io::BufReader;
29use std::mem::size_of;
30use std::ops::Range;
31use std::pin::Pin;
32use std::sync::atomic::AtomicUsize;
33use std::sync::atomic::Ordering::Relaxed;
34use std::sync::Arc;
35use std::task::{Context, Poll};
36
37use crate::execution_plan::{boundedness_from_children, EmissionType};
38use crate::expressions::PhysicalSortExpr;
39use crate::joins::utils::{
40    build_join_schema, check_join_is_valid, estimate_join_statistics,
41    reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
42    JoinOnRef,
43};
44use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
45use crate::projection::{
46    join_allows_pushdown, join_table_borders, new_join_children,
47    physical_to_column_exprs, update_join_on, ProjectionExec,
48};
49use crate::spill::spill_record_batches;
50use crate::{
51    metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
52    ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream,
53    SendableRecordBatchStream, Statistics,
54};
55
56use arrow::array::{types::UInt64Type, *};
57use arrow::compute::{
58    self, concat_batches, filter_record_batch, is_not_null, take, SortOptions,
59};
60use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
61use arrow::error::ArrowError;
62use arrow::ipc::reader::StreamReader;
63use datafusion_common::{
64    exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, JoinSide,
65    JoinType, Result,
66};
67use datafusion_execution::disk_manager::RefCountedTempFile;
68use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
69use datafusion_execution::runtime_env::RuntimeEnv;
70use datafusion_execution::TaskContext;
71use datafusion_physical_expr::equivalence::join_equivalence_properties;
72use datafusion_physical_expr::PhysicalExprRef;
73use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
74
75use futures::{Stream, StreamExt};
76
77/// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge
78/// join algorithm and applies an optional filter post join. Can be used to join arbitrarily large
79/// inputs where one or both of the inputs don't fit in the available memory.
80///
81/// # Join Expressions
82///
83/// Equi-join predicate (e.g. `<col1> = <col2>`) expressions are represented by [`Self::on`].
84///
85/// Non-equality predicates, which can not be pushed down to join inputs (e.g.
86/// `<col1> != <col2>`) are known as "filter expressions" and are evaluated
87/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional
88/// expressions.
89///
90/// # Sorting
91///
92/// Assumes that both the left and right input to the join are pre-sorted. It is not the
93/// responsibility of this execution plan to sort the inputs.
94///
95/// # "Streamed" vs "Buffered"
96///
97/// The number of record batches of streamed input currently present in the memory will depend
98/// on the output batch size of the execution plan. There is no spilling support for streamed input.
99/// The comparisons are performed from values of join keys in streamed input with the values of
100/// join keys in buffered input. One row in streamed record batch could be matched with multiple rows in
101/// buffered input batches. The streamed input is managed through the states in `StreamedState`
102/// and streamed input batches are represented by `StreamedBatch`.
103///
104/// Buffered input is buffered for all record batches having the same value of join key.
105/// If the memory limit increases beyond the specified value and spilling is enabled,
106/// buffered batches could be spilled to disk. If spilling is disabled, the execution
107/// will fail under the same conditions. Multiple record batches of buffered could currently reside
108/// in memory/disk during the execution. The number of buffered batches residing in
109/// memory/disk depends on the number of rows of buffered input having the same value
110/// of join key as that of streamed input rows currently present in memory. Due to pre-sorted inputs,
111/// the algorithm understands when it is not needed anymore, and releases the buffered batches
112/// from memory/disk. The buffered input is managed through the states in `BufferedState`
113/// and buffered input batches are represented by `BufferedBatch`.
114///
115/// Depending on the type of join, left or right input may be selected as streamed or buffered
116/// respectively. For example, in a left-outer join, the left execution plan will be selected as
117/// streamed input while in a right-outer join, the right execution plan will be selected as the
118/// streamed input.
119///
120/// Reference for the algorithm:
121/// <https://en.wikipedia.org/wiki/Sort-merge_join>.
122///
123/// Helpful short video demonstration:
124/// <https://www.youtube.com/watch?v=jiWCPJtDE2c>.
125#[derive(Debug, Clone)]
126pub struct SortMergeJoinExec {
127    /// Left sorted joining execution plan
128    pub left: Arc<dyn ExecutionPlan>,
129    /// Right sorting joining execution plan
130    pub right: Arc<dyn ExecutionPlan>,
131    /// Set of common columns used to join on
132    pub on: JoinOn,
133    /// Filters which are applied while finding matching rows
134    pub filter: Option<JoinFilter>,
135    /// How the join is performed
136    pub join_type: JoinType,
137    /// The schema once the join is applied
138    schema: SchemaRef,
139    /// Execution metrics
140    metrics: ExecutionPlanMetricsSet,
141    /// The left SortExpr
142    left_sort_exprs: LexOrdering,
143    /// The right SortExpr
144    right_sort_exprs: LexOrdering,
145    /// Sort options of join columns used in sorting left and right execution plans
146    pub sort_options: Vec<SortOptions>,
147    /// If null_equals_null is true, null == null else null != null
148    pub null_equals_null: bool,
149    /// Cache holding plan properties like equivalences, output partitioning etc.
150    cache: PlanProperties,
151}
152
153impl SortMergeJoinExec {
154    /// Tries to create a new [SortMergeJoinExec].
155    /// The inputs are sorted using `sort_options` are applied to the columns in the `on`
156    /// # Error
157    /// This function errors when it is not possible to join the left and right sides on keys `on`.
158    pub fn try_new(
159        left: Arc<dyn ExecutionPlan>,
160        right: Arc<dyn ExecutionPlan>,
161        on: JoinOn,
162        filter: Option<JoinFilter>,
163        join_type: JoinType,
164        sort_options: Vec<SortOptions>,
165        null_equals_null: bool,
166    ) -> Result<Self> {
167        let left_schema = left.schema();
168        let right_schema = right.schema();
169
170        if join_type == JoinType::RightSemi {
171            return not_impl_err!(
172                "SortMergeJoinExec does not support JoinType::RightSemi"
173            );
174        }
175
176        check_join_is_valid(&left_schema, &right_schema, &on)?;
177        if sort_options.len() != on.len() {
178            return plan_err!(
179                "Expected number of sort options: {}, actual: {}",
180                on.len(),
181                sort_options.len()
182            );
183        }
184
185        let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
186            .iter()
187            .zip(sort_options.iter())
188            .map(|((l, r), sort_op)| {
189                let left = PhysicalSortExpr {
190                    expr: Arc::clone(l),
191                    options: *sort_op,
192                };
193                let right = PhysicalSortExpr {
194                    expr: Arc::clone(r),
195                    options: *sort_op,
196                };
197                (left, right)
198            })
199            .unzip();
200
201        let schema =
202            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
203        let cache =
204            Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on);
205        Ok(Self {
206            left,
207            right,
208            on,
209            filter,
210            join_type,
211            schema,
212            metrics: ExecutionPlanMetricsSet::new(),
213            left_sort_exprs: LexOrdering::new(left_sort_exprs),
214            right_sort_exprs: LexOrdering::new(right_sort_exprs),
215            sort_options,
216            null_equals_null,
217            cache,
218        })
219    }
220
221    /// Get probe side (e.g streaming side) information for this sort merge join.
222    /// In current implementation, probe side is determined according to join type.
223    pub fn probe_side(join_type: &JoinType) -> JoinSide {
224        // When output schema contains only the right side, probe side is right.
225        // Otherwise probe side is the left side.
226        match join_type {
227            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
228                JoinSide::Right
229            }
230            JoinType::Inner
231            | JoinType::Left
232            | JoinType::Full
233            | JoinType::LeftAnti
234            | JoinType::LeftSemi
235            | JoinType::LeftMark => JoinSide::Left,
236        }
237    }
238
239    /// Calculate order preservation flags for this sort merge join.
240    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
241        match join_type {
242            JoinType::Inner => vec![true, false],
243            JoinType::Left
244            | JoinType::LeftSemi
245            | JoinType::LeftAnti
246            | JoinType::LeftMark => vec![true, false],
247            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
248                vec![false, true]
249            }
250            _ => vec![false, false],
251        }
252    }
253
254    /// Set of common columns used to join on
255    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
256        &self.on
257    }
258
259    /// Ref to right execution plan
260    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
261        &self.right
262    }
263
264    /// Join type
265    pub fn join_type(&self) -> JoinType {
266        self.join_type
267    }
268
269    /// Ref to left execution plan
270    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
271        &self.left
272    }
273
274    /// Ref to join filter
275    pub fn filter(&self) -> &Option<JoinFilter> {
276        &self.filter
277    }
278
279    /// Ref to sort options
280    pub fn sort_options(&self) -> &[SortOptions] {
281        &self.sort_options
282    }
283
284    /// Null equals null
285    pub fn null_equals_null(&self) -> bool {
286        self.null_equals_null
287    }
288
289    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
290    fn compute_properties(
291        left: &Arc<dyn ExecutionPlan>,
292        right: &Arc<dyn ExecutionPlan>,
293        schema: SchemaRef,
294        join_type: JoinType,
295        join_on: JoinOnRef,
296    ) -> PlanProperties {
297        // Calculate equivalence properties:
298        let eq_properties = join_equivalence_properties(
299            left.equivalence_properties().clone(),
300            right.equivalence_properties().clone(),
301            &join_type,
302            schema,
303            &Self::maintains_input_order(join_type),
304            Some(Self::probe_side(&join_type)),
305            join_on,
306        );
307
308        let output_partitioning =
309            symmetric_join_output_partitioning(left, right, &join_type);
310
311        PlanProperties::new(
312            eq_properties,
313            output_partitioning,
314            EmissionType::Incremental,
315            boundedness_from_children([left, right]),
316        )
317    }
318
319    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
320        let left = self.left();
321        let right = self.right();
322        let new_join = SortMergeJoinExec::try_new(
323            Arc::clone(right),
324            Arc::clone(left),
325            self.on()
326                .iter()
327                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
328                .collect::<Vec<_>>(),
329            self.filter().as_ref().map(JoinFilter::swap),
330            self.join_type().swap(),
331            self.sort_options.clone(),
332            self.null_equals_null,
333        )?;
334
335        // TODO: OR this condition with having a built-in projection (like
336        //       ordinary hash join) when we support it.
337        if matches!(
338            self.join_type(),
339            JoinType::LeftSemi
340                | JoinType::RightSemi
341                | JoinType::LeftAnti
342                | JoinType::RightAnti
343        ) {
344            Ok(Arc::new(new_join))
345        } else {
346            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
347        }
348    }
349}
350
351impl DisplayAs for SortMergeJoinExec {
352    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
353        match t {
354            DisplayFormatType::Default | DisplayFormatType::Verbose => {
355                let on = self
356                    .on
357                    .iter()
358                    .map(|(c1, c2)| format!("({}, {})", c1, c2))
359                    .collect::<Vec<String>>()
360                    .join(", ");
361                write!(
362                    f,
363                    "SortMergeJoin: join_type={:?}, on=[{}]{}",
364                    self.join_type,
365                    on,
366                    self.filter.as_ref().map_or("".to_string(), |f| format!(
367                        ", filter={}",
368                        f.expression()
369                    ))
370                )
371            }
372        }
373    }
374}
375
376impl ExecutionPlan for SortMergeJoinExec {
377    fn name(&self) -> &'static str {
378        "SortMergeJoinExec"
379    }
380
381    fn as_any(&self) -> &dyn Any {
382        self
383    }
384
385    fn properties(&self) -> &PlanProperties {
386        &self.cache
387    }
388
389    fn required_input_distribution(&self) -> Vec<Distribution> {
390        let (left_expr, right_expr) = self
391            .on
392            .iter()
393            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
394            .unzip();
395        vec![
396            Distribution::HashPartitioned(left_expr),
397            Distribution::HashPartitioned(right_expr),
398        ]
399    }
400
401    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
402        vec![
403            Some(LexRequirement::from(self.left_sort_exprs.clone())),
404            Some(LexRequirement::from(self.right_sort_exprs.clone())),
405        ]
406    }
407
408    fn maintains_input_order(&self) -> Vec<bool> {
409        Self::maintains_input_order(self.join_type)
410    }
411
412    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
413        vec![&self.left, &self.right]
414    }
415
416    fn with_new_children(
417        self: Arc<Self>,
418        children: Vec<Arc<dyn ExecutionPlan>>,
419    ) -> Result<Arc<dyn ExecutionPlan>> {
420        match &children[..] {
421            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
422                Arc::clone(left),
423                Arc::clone(right),
424                self.on.clone(),
425                self.filter.clone(),
426                self.join_type,
427                self.sort_options.clone(),
428                self.null_equals_null,
429            )?)),
430            _ => internal_err!("SortMergeJoin wrong number of children"),
431        }
432    }
433
434    fn execute(
435        &self,
436        partition: usize,
437        context: Arc<TaskContext>,
438    ) -> Result<SendableRecordBatchStream> {
439        let left_partitions = self.left.output_partitioning().partition_count();
440        let right_partitions = self.right.output_partitioning().partition_count();
441        if left_partitions != right_partitions {
442            return internal_err!(
443                "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
444                 consider using RepartitionExec"
445            );
446        }
447        let (on_left, on_right) = self.on.iter().cloned().unzip();
448        let (streamed, buffered, on_streamed, on_buffered) =
449            if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
450                (
451                    Arc::clone(&self.left),
452                    Arc::clone(&self.right),
453                    on_left,
454                    on_right,
455                )
456            } else {
457                (
458                    Arc::clone(&self.right),
459                    Arc::clone(&self.left),
460                    on_right,
461                    on_left,
462                )
463            };
464
465        // execute children plans
466        let streamed = streamed.execute(partition, Arc::clone(&context))?;
467        let buffered = buffered.execute(partition, Arc::clone(&context))?;
468
469        // create output buffer
470        let batch_size = context.session_config().batch_size();
471
472        // create memory reservation
473        let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
474            .register(context.memory_pool());
475
476        // create join stream
477        Ok(Box::pin(SortMergeJoinStream::try_new(
478            Arc::clone(&self.schema),
479            self.sort_options.clone(),
480            self.null_equals_null,
481            streamed,
482            buffered,
483            on_streamed,
484            on_buffered,
485            self.filter.clone(),
486            self.join_type,
487            batch_size,
488            SortMergeJoinMetrics::new(partition, &self.metrics),
489            reservation,
490            context.runtime_env(),
491        )?))
492    }
493
494    fn metrics(&self) -> Option<MetricsSet> {
495        Some(self.metrics.clone_inner())
496    }
497
498    fn statistics(&self) -> Result<Statistics> {
499        // TODO stats: it is not possible in general to know the output size of joins
500        // There are some special cases though, for example:
501        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
502        estimate_join_statistics(
503            Arc::clone(&self.left),
504            Arc::clone(&self.right),
505            self.on.clone(),
506            &self.join_type,
507            &self.schema,
508        )
509    }
510
511    /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done,
512    /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan.
513    /// Otherwise, it returns None.
514    fn try_swapping_with_projection(
515        &self,
516        projection: &ProjectionExec,
517    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
518        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
519        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
520        else {
521            return Ok(None);
522        };
523
524        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
525            self.left().schema().fields().len(),
526            &projection_as_columns,
527        );
528
529        if !join_allows_pushdown(
530            &projection_as_columns,
531            &self.schema(),
532            far_right_left_col_ind,
533            far_left_right_col_ind,
534        ) {
535            return Ok(None);
536        }
537
538        let Some(new_on) = update_join_on(
539            &projection_as_columns[0..=far_right_left_col_ind as _],
540            &projection_as_columns[far_left_right_col_ind as _..],
541            self.on(),
542            self.left().schema().fields().len(),
543        ) else {
544            return Ok(None);
545        };
546
547        let (new_left, new_right) = new_join_children(
548            &projection_as_columns,
549            far_right_left_col_ind,
550            far_left_right_col_ind,
551            self.children()[0],
552            self.children()[1],
553        )?;
554
555        Ok(Some(Arc::new(SortMergeJoinExec::try_new(
556            Arc::new(new_left),
557            Arc::new(new_right),
558            new_on,
559            self.filter.clone(),
560            self.join_type,
561            self.sort_options.clone(),
562            self.null_equals_null,
563        )?)))
564    }
565}
566
567/// Metrics for SortMergeJoinExec
568#[allow(dead_code)]
569struct SortMergeJoinMetrics {
570    /// Total time for joining probe-side batches to the build-side batches
571    join_time: metrics::Time,
572    /// Number of batches consumed by this operator
573    input_batches: Count,
574    /// Number of rows consumed by this operator
575    input_rows: Count,
576    /// Number of batches produced by this operator
577    output_batches: Count,
578    /// Number of rows produced by this operator
579    output_rows: Count,
580    /// Peak memory used for buffered data.
581    /// Calculated as sum of peak memory values across partitions
582    peak_mem_used: metrics::Gauge,
583    /// count of spills during the execution of the operator
584    spill_count: Count,
585    /// total spilled bytes during the execution of the operator
586    spilled_bytes: Count,
587    /// total spilled rows during the execution of the operator
588    spilled_rows: Count,
589}
590
591impl SortMergeJoinMetrics {
592    #[allow(dead_code)]
593    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
594        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
595        let input_batches =
596            MetricBuilder::new(metrics).counter("input_batches", partition);
597        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
598        let output_batches =
599            MetricBuilder::new(metrics).counter("output_batches", partition);
600        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
601        let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition);
602        let spill_count = MetricBuilder::new(metrics).spill_count(partition);
603        let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition);
604        let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
605
606        Self {
607            join_time,
608            input_batches,
609            input_rows,
610            output_batches,
611            output_rows,
612            peak_mem_used,
613            spill_count,
614            spilled_bytes,
615            spilled_rows,
616        }
617    }
618}
619
/// Phase the sort-merge join stream is currently in
#[derive(Debug, Eq, PartialEq)]
enum SortMergeJoinState {
    /// Start joining with a new streamed row or new buffered batches
    Init,
    /// One streamed row, one buffered batch, or both are being polled
    Polling,
    /// The polled data is being joined and output is being produced
    JoinOutput,
    /// There is no more output
    Exhausted,
}
632
/// Phase the streamed input is currently in
#[derive(Debug, Eq, PartialEq)]
enum StreamedState {
    /// Polling has not started yet
    Init,
    /// One streamed row is being polled
    Polling,
    /// One streamed row is ready to be produced
    Ready,
    /// There are no more streamed rows
    Exhausted,
}
645
/// Phase the buffered input is currently in
#[derive(Debug, Eq, PartialEq)]
enum BufferedState {
    /// Polling has not started yet
    Init,
    /// The first row of the next batch is being polled
    PollingFirst,
    /// The remaining rows of the next batch are being polled
    PollingRest,
    /// One batch is ready to be produced
    Ready,
    /// There are no more buffered batches
    Exhausted,
}
660
661/// Represents a chunk of joined data from streamed and buffered side
662struct StreamedJoinedChunk {
663    /// Index of batch in buffered_data
664    buffered_batch_idx: Option<usize>,
665    /// Array builder for streamed indices
666    streamed_indices: UInt64Builder,
667    /// Array builder for buffered indices
668    /// This could contain nulls if the join is null-joined
669    buffered_indices: UInt64Builder,
670}
671
672/// Represents a record batch from streamed input.
673///
674/// Also stores information of matching rows from buffered batches.
675struct StreamedBatch {
676    /// The streamed record batch
677    pub batch: RecordBatch,
678    /// The index of row in the streamed batch to compare with buffered batches
679    pub idx: usize,
680    /// The join key arrays of streamed batch which are used to compare with buffered batches
681    /// and to produce output. They are produced by evaluating `on` expressions.
682    pub join_arrays: Vec<ArrayRef>,
683    /// Chunks of indices from buffered side (may be nulls) joined to streamed
684    pub output_indices: Vec<StreamedJoinedChunk>,
685    /// Index of currently scanned batch from buffered data
686    pub buffered_batch_idx: Option<usize>,
687    /// Indices that found a match for the given join filter
688    /// Used for semi joins to keep track the streaming index which got a join filter match
689    /// and already emitted to the output.
690    pub join_filter_matched_idxs: HashSet<u64>,
691}
692
693impl StreamedBatch {
694    fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
695        let join_arrays = join_arrays(&batch, on_column);
696        StreamedBatch {
697            batch,
698            idx: 0,
699            join_arrays,
700            output_indices: vec![],
701            buffered_batch_idx: None,
702            join_filter_matched_idxs: HashSet::new(),
703        }
704    }
705
706    fn new_empty(schema: SchemaRef) -> Self {
707        StreamedBatch {
708            batch: RecordBatch::new_empty(schema),
709            idx: 0,
710            join_arrays: vec![],
711            output_indices: vec![],
712            buffered_batch_idx: None,
713            join_filter_matched_idxs: HashSet::new(),
714        }
715    }
716
717    /// Appends new pair consisting of current streamed index and `buffered_idx`
718    /// index of buffered batch with `buffered_batch_idx` index.
719    fn append_output_pair(
720        &mut self,
721        buffered_batch_idx: Option<usize>,
722        buffered_idx: Option<usize>,
723    ) {
724        // If no current chunk exists or current chunk is not for current buffered batch,
725        // create a new chunk
726        if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
727        {
728            self.output_indices.push(StreamedJoinedChunk {
729                buffered_batch_idx,
730                streamed_indices: UInt64Builder::with_capacity(1),
731                buffered_indices: UInt64Builder::with_capacity(1),
732            });
733            self.buffered_batch_idx = buffered_batch_idx;
734        };
735        let current_chunk = self.output_indices.last_mut().unwrap();
736
737        // Append index of streamed batch and index of buffered batch into current chunk
738        current_chunk.streamed_indices.append_value(self.idx as u64);
739        if let Some(idx) = buffered_idx {
740            current_chunk.buffered_indices.append_value(idx as u64);
741        } else {
742            current_chunk.buffered_indices.append_null();
743        }
744    }
745}
746
747/// A buffered batch that contains contiguous rows with same join key
748#[derive(Debug)]
749struct BufferedBatch {
750    /// The buffered record batch
751    /// None if the batch spilled to disk th
752    pub batch: Option<RecordBatch>,
753    /// The range in which the rows share the same join key
754    pub range: Range<usize>,
755    /// Array refs of the join key
756    pub join_arrays: Vec<ArrayRef>,
757    /// Buffered joined index (null joining buffered)
758    pub null_joined: Vec<usize>,
759    /// Size estimation used for reserving / releasing memory
760    pub size_estimation: usize,
761    /// The indices of buffered batch that the join filter doesn't satisfy.
762    /// This is a map between right row index and a boolean value indicating whether all joined row
763    /// of the right row does not satisfy the filter .
764    /// When dequeuing the buffered batch, we need to produce null joined rows for these indices.
765    pub join_filter_not_matched_map: HashMap<u64, bool>,
766    /// Current buffered batch number of rows. Equal to batch.num_rows()
767    /// but if batch is spilled to disk this property is preferable
768    /// and less expensive
769    pub num_rows: usize,
770    /// An optional temp spill file name on the disk if the batch spilled
771    /// None by default
772    /// Some(fileName) if the batch spilled to the disk
773    pub spill_file: Option<RefCountedTempFile>,
774}
775
776impl BufferedBatch {
777    fn new(
778        batch: RecordBatch,
779        range: Range<usize>,
780        on_column: &[PhysicalExprRef],
781    ) -> Self {
782        let join_arrays = join_arrays(&batch, on_column);
783
784        // Estimation is calculated as
785        //   inner batch size
786        // + join keys size
787        // + worst case null_joined (as vector capacity * element size)
788        // + Range size
789        // + size of this estimation
790        let size_estimation = batch.get_array_memory_size()
791            + join_arrays
792                .iter()
793                .map(|arr| arr.get_array_memory_size())
794                .sum::<usize>()
795            + batch.num_rows().next_power_of_two() * size_of::<usize>()
796            + size_of::<Range<usize>>()
797            + size_of::<usize>();
798
799        let num_rows = batch.num_rows();
800        BufferedBatch {
801            batch: Some(batch),
802            range,
803            join_arrays,
804            null_joined: vec![],
805            size_estimation,
806            join_filter_not_matched_map: HashMap::new(),
807            num_rows,
808            spill_file: None,
809        }
810    }
811}
812
/// Sort-Merge join stream that consumes the streamed and buffered input
/// streams and produces the joined output stream.
///
/// The stream is driven as a state machine (see `state`); each input side
/// additionally tracks its own progress in `streamed_state` / `buffered_state`.
struct SortMergeJoinStream {
    /// Current state of the stream
    pub state: SortMergeJoinState,
    /// Output schema (what `RecordBatchStream::schema` returns for this stream)
    pub schema: SchemaRef,
    /// Sort options of the join columns used to sort the streamed and buffered data streams
    pub sort_options: Vec<SortOptions>,
    /// Whether a null join key is considered equal to another null (null == null)
    pub null_equals_null: bool,
    /// Input schema of the streamed side
    pub streamed_schema: SchemaRef,
    /// Input schema of the buffered side
    pub buffered_schema: SchemaRef,
    /// Streamed data stream
    pub streamed: SendableRecordBatchStream,
    /// Buffered data stream
    pub buffered: SendableRecordBatchStream,
    /// Record batch of the streamed side that is currently being processed
    pub streamed_batch: StreamedBatch,
    /// Currently buffered data
    pub buffered_data: BufferedData,
    /// (used in outer join) Has the current streamed row been joined at least once?
    pub streamed_joined: bool,
    /// (used in outer join) Have the current buffered batches been joined at least once?
    pub buffered_joined: bool,
    /// State of the streamed side
    pub streamed_state: StreamedState,
    /// State of the buffered side
    pub buffered_state: BufferedState,
    /// Result of comparing the current streamed row with the buffered batches
    pub current_ordering: Ordering,
    /// Join key columns of the streamed side
    pub on_streamed: Vec<PhysicalExprRef>,
    /// Join key columns of the buffered side
    pub on_buffered: Vec<PhysicalExprRef>,
    /// Optional join filter
    pub filter: Option<JoinFilter>,
    /// Staging output array builders
    pub staging_output_record_batches: JoinedRecordBatches,
    /// Output buffer. Currently used by filtered joins, which need double buffering
    /// to avoid small/empty batches. Non-filtered joins output directly from
    /// `staging_output_record_batches.batches`
    pub output: RecordBatch,
    /// Staging output size, including output batches and staging joined results.
    /// Increased when rows are put into the buffer and decreased after batches are
    /// actually output. Used to trigger output when sufficient rows are ready
    pub output_size: usize,
    /// Target output batch size
    pub batch_size: usize,
    /// How the join is performed
    pub join_type: JoinType,
    /// Metrics
    pub join_metrics: SortMergeJoinMetrics,
    /// Memory reservation, grown for each in-memory buffered batch and shrunk
    /// again when that batch is released
    pub reservation: MemoryReservation,
    /// Runtime env; its disk manager is used to spill buffered batches
    pub runtime_env: Arc<RuntimeEnv>,
    /// Counter incremented for every non-empty streamed batch, giving each
    /// batch a unique number
    pub streamed_batch_counter: AtomicUsize,
}
874
/// Joined batches with the join filter information attached to them
struct JoinedRecordBatches {
    /// Joined batches. Each batch already holds the joined columns from the left and right sources
    pub batches: Vec<RecordBatch>,
    /// Filter match mask, one entry per row (matched / not matched)
    pub filter_mask: BooleanBuilder,
    /// Row indices used to tie rows in `batches` to their entries in `filter_mask`
    pub row_indices: UInt64Builder,
    /// The unique batch id each row belongs to.
    /// Needed to tell apart rows that carry the same row index but come from
    /// different batches
    pub batch_ids: Vec<usize>,
}
888
889impl JoinedRecordBatches {
890    fn clear(&mut self) {
891        self.batches.clear();
892        self.batch_ids.clear();
893        self.filter_mask = BooleanBuilder::new();
894        self.row_indices = UInt64Builder::new();
895    }
896}
897impl RecordBatchStream for SortMergeJoinStream {
898    fn schema(&self) -> SchemaRef {
899        Arc::clone(&self.schema)
900    }
901}
902
903/// True if next index refers to either:
904/// - another batch id
905/// - another row index within same batch id
906/// - end of row indices
907#[inline(always)]
908fn last_index_for_row(
909    row_index: usize,
910    indices: &UInt64Array,
911    batch_ids: &[usize],
912    indices_len: usize,
913) -> bool {
914    row_index == indices_len - 1
915        || batch_ids[row_index] != batch_ids[row_index + 1]
916        || indices.value(row_index) != indices.value(row_index + 1)
917}
918
919// Returns a corrected boolean bitmask for the given join type
920// Values in the corrected bitmask can be: true, false, null
921// `true` - the row found its match and sent to the output
922// `null` - the row ignored, no output
923// `false` - the row sent as NULL joined row
924fn get_corrected_filter_mask(
925    join_type: JoinType,
926    row_indices: &UInt64Array,
927    batch_ids: &[usize],
928    filter_mask: &BooleanArray,
929    expected_size: usize,
930) -> Option<BooleanArray> {
931    let row_indices_length = row_indices.len();
932    let mut corrected_mask: BooleanBuilder =
933        BooleanBuilder::with_capacity(row_indices_length);
934    let mut seen_true = false;
935
936    match join_type {
937        JoinType::Left | JoinType::Right => {
938            for i in 0..row_indices_length {
939                let last_index =
940                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
941                if filter_mask.value(i) {
942                    seen_true = true;
943                    corrected_mask.append_value(true);
944                } else if seen_true || !filter_mask.value(i) && !last_index {
945                    corrected_mask.append_null(); // to be ignored and not set to output
946                } else {
947                    corrected_mask.append_value(false); // to be converted to null joined row
948                }
949
950                if last_index {
951                    seen_true = false;
952                }
953            }
954
955            // Generate null joined rows for records which have no matching join key
956            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
957            Some(corrected_mask.finish())
958        }
959        JoinType::LeftMark => {
960            for i in 0..row_indices_length {
961                let last_index =
962                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
963                if filter_mask.value(i) && !seen_true {
964                    seen_true = true;
965                    corrected_mask.append_value(true);
966                } else if seen_true || !filter_mask.value(i) && !last_index {
967                    corrected_mask.append_null(); // to be ignored and not set to output
968                } else {
969                    corrected_mask.append_value(false); // to be converted to null joined row
970                }
971
972                if last_index {
973                    seen_true = false;
974                }
975            }
976
977            // Generate null joined rows for records which have no matching join key
978            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
979            Some(corrected_mask.finish())
980        }
981        JoinType::LeftSemi => {
982            for i in 0..row_indices_length {
983                let last_index =
984                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
985                if filter_mask.value(i) && !seen_true {
986                    seen_true = true;
987                    corrected_mask.append_value(true);
988                } else {
989                    corrected_mask.append_null(); // to be ignored and not set to output
990                }
991
992                if last_index {
993                    seen_true = false;
994                }
995            }
996
997            Some(corrected_mask.finish())
998        }
999        JoinType::LeftAnti | JoinType::RightAnti => {
1000            for i in 0..row_indices_length {
1001                let last_index =
1002                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
1003
1004                if filter_mask.value(i) {
1005                    seen_true = true;
1006                }
1007
1008                if last_index {
1009                    if !seen_true {
1010                        corrected_mask.append_value(true);
1011                    } else {
1012                        corrected_mask.append_null();
1013                    }
1014
1015                    seen_true = false;
1016                } else {
1017                    corrected_mask.append_null();
1018                }
1019            }
1020            // Generate null joined rows for records which have no matching join key,
1021            // for LeftAnti non-matched considered as true
1022            corrected_mask.append_n(expected_size - corrected_mask.len(), true);
1023            Some(corrected_mask.finish())
1024        }
1025        JoinType::Full => {
1026            let mut mask: Vec<Option<bool>> = vec![Some(true); row_indices_length];
1027            let mut last_true_idx = 0;
1028            let mut first_row_idx = 0;
1029            let mut seen_false = false;
1030
1031            for i in 0..row_indices_length {
1032                let last_index =
1033                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
1034                let val = filter_mask.value(i);
1035                let is_null = filter_mask.is_null(i);
1036
1037                if val {
1038                    // memoize the first seen matched row
1039                    if !seen_true {
1040                        last_true_idx = i;
1041                    }
1042                    seen_true = true;
1043                }
1044
1045                if is_null || val {
1046                    mask[i] = Some(true);
1047                } else if !is_null && !val && (seen_true || seen_false) {
1048                    mask[i] = None;
1049                } else {
1050                    mask[i] = Some(false);
1051                }
1052
1053                if !is_null && !val {
1054                    seen_false = true;
1055                }
1056
1057                if last_index {
1058                    // If the left row seen as true its needed to output it once
1059                    // To do that we mark all other matches for same row as null to avoid the output
1060                    if seen_true {
1061                        #[allow(clippy::needless_range_loop)]
1062                        for j in first_row_idx..last_true_idx {
1063                            mask[j] = None;
1064                        }
1065                    }
1066
1067                    seen_true = false;
1068                    seen_false = false;
1069                    last_true_idx = 0;
1070                    first_row_idx = i + 1;
1071                }
1072            }
1073
1074            Some(BooleanArray::from(mask))
1075        }
1076        // Only outer joins needs to keep track of processed rows and apply corrected filter mask
1077        _ => None,
1078    }
1079}
1080
impl Stream for SortMergeJoinStream {
    type Item = Result<RecordBatch>;

    /// Drives the join state machine until an output batch is ready, an input
    /// is pending, an error occurs, or both inputs are exhausted.
    ///
    /// States, as handled below:
    /// - `Init`: decides which side has to advance based on `current_ordering`
    ///   and, for filtered joins of the listed types, filters the staged rows
    ///   into `self.output`.
    /// - `Polling`: polls the streamed row and/or the buffered batches, then
    ///   compares them.
    /// - `JoinOutput`: joins the current streamed row with the buffered batches
    ///   and emits a batch once `batch_size` rows have been staged.
    /// - `Exhausted`: flushes whatever is still staged or buffered, then ends
    ///   the stream.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // The timer guard is kept alive for the whole call
        // (presumably it records the elapsed time into `join_time` when dropped).
        let join_time = self.join_metrics.join_time.clone();
        let _timer = join_time.timer();
        loop {
            match &self.state {
                SortMergeJoinState::Init => {
                    let streamed_exhausted =
                        self.streamed_state == StreamedState::Exhausted;
                    let buffered_exhausted =
                        self.buffered_state == BufferedState::Exhausted;
                    self.state = if streamed_exhausted && buffered_exhausted {
                        SortMergeJoinState::Exhausted
                    } else {
                        match self.current_ordering {
                            // The streamed side has to advance
                            Ordering::Less | Ordering::Equal => {
                                if !streamed_exhausted {
                                    if self.filter.is_some()
                                        && matches!(
                                            self.join_type,
                                            JoinType::Left
                                                | JoinType::LeftSemi
                                                | JoinType::LeftMark
                                                | JoinType::Right
                                                | JoinType::LeftAnti
                                                | JoinType::RightAnti
                                                | JoinType::Full
                                        )
                                    {
                                        self.freeze_all()?;

                                        // If the join is filtered and there are joined tuples
                                        // waiting to be filtered
                                        if !self
                                            .staging_output_record_batches
                                            .batches
                                            .is_empty()
                                        {
                                            // Apply filter on joined tuples and get filtered batch
                                            let out_filtered_batch =
                                                self.filter_joined_batch()?;

                                            // Append filtered batch to the output buffer
                                            self.output = concat_batches(
                                                &self.schema(),
                                                vec![&self.output, &out_filtered_batch],
                                            )?;

                                            // Send to output if the output buffer reached the `batch_size`.
                                            // `self.state` has not been reassigned at this point, so the
                                            // next poll re-enters the `Init` arm.
                                            if self.output.num_rows() >= self.batch_size {
                                                let record_batch = std::mem::replace(
                                                    &mut self.output,
                                                    RecordBatch::new_empty(
                                                        out_filtered_batch.schema(),
                                                    ),
                                                );
                                                return Poll::Ready(Some(Ok(
                                                    record_batch,
                                                )));
                                            }
                                        }
                                    }

                                    self.streamed_joined = false;
                                    self.streamed_state = StreamedState::Init;
                                }
                            }
                            // The buffered side has to advance
                            Ordering::Greater => {
                                if !buffered_exhausted {
                                    self.buffered_joined = false;
                                    self.buffered_state = BufferedState::Init;
                                }
                            }
                        }
                        SortMergeJoinState::Polling
                    };
                }
                SortMergeJoinState::Polling => {
                    // Poll the streamed side unless it is already positioned or finished
                    if ![StreamedState::Exhausted, StreamedState::Ready]
                        .contains(&self.streamed_state)
                    {
                        match self.poll_streamed_row(cx)? {
                            Poll::Ready(_) => {}
                            Poll::Pending => return Poll::Pending,
                        }
                    }

                    // Poll the buffered side unless it is already positioned or finished
                    if ![BufferedState::Exhausted, BufferedState::Ready]
                        .contains(&self.buffered_state)
                    {
                        match self.poll_buffered_batches(cx)? {
                            Poll::Ready(_) => {}
                            Poll::Pending => return Poll::Pending,
                        }
                    }
                    let streamed_exhausted =
                        self.streamed_state == StreamedState::Exhausted;
                    let buffered_exhausted =
                        self.buffered_state == BufferedState::Exhausted;
                    if streamed_exhausted && buffered_exhausted {
                        self.state = SortMergeJoinState::Exhausted;
                        continue;
                    }
                    self.current_ordering = self.compare_streamed_buffered()?;
                    self.state = SortMergeJoinState::JoinOutput;
                }
                SortMergeJoinState::JoinOutput => {
                    self.join_partial()?;

                    if self.output_size < self.batch_size {
                        // Not enough rows staged yet; go back to `Init` only once the
                        // buffered data has been fully scanned, otherwise stay in
                        // `JoinOutput` and keep joining.
                        if self.buffered_data.scanning_finished() {
                            self.buffered_data.scanning_reset();
                            self.state = SortMergeJoinState::Init;
                        }
                    } else {
                        self.freeze_all()?;
                        if !self.staging_output_record_batches.batches.is_empty() {
                            let record_batch = self.output_record_batch_and_reset()?;
                            // A non-filtered join outputs whenever the target output batch
                            // size is hit. A filtered join has to output in a later phase,
                            // because the target output batch size can be hit in the middle
                            // of filtering, which would leave the filtering incomplete and
                            // cause correctness issues
                            if self.filter.is_some()
                                && matches!(
                                    self.join_type,
                                    JoinType::Left
                                        | JoinType::LeftSemi
                                        | JoinType::Right
                                        | JoinType::LeftAnti
                                        | JoinType::RightAnti
                                        | JoinType::LeftMark
                                        | JoinType::Full
                                )
                            {
                                continue;
                            }

                            return Poll::Ready(Some(Ok(record_batch)));
                        }
                        // NOTE(review): this returns `Pending` without any input being
                        // polled in this arm, so no waker registration is visible here.
                        // Confirm that the task is guaranteed to be woken again.
                        return Poll::Pending;
                    }
                }
                SortMergeJoinState::Exhausted => {
                    self.freeze_all()?;

                    // if there is still something not processed
                    if !self.staging_output_record_batches.batches.is_empty() {
                        if self.filter.is_some()
                            && matches!(
                                self.join_type,
                                JoinType::Left
                                    | JoinType::LeftSemi
                                    | JoinType::Right
                                    | JoinType::LeftAnti
                                    | JoinType::RightAnti
                                    | JoinType::Full
                                    | JoinType::LeftMark
                            )
                        {
                            let record_batch = self.filter_joined_batch()?;
                            return Poll::Ready(Some(Ok(record_batch)));
                        } else {
                            let record_batch = self.output_record_batch_and_reset()?;
                            return Poll::Ready(Some(Ok(record_batch)));
                        }
                    } else if self.output.num_rows() > 0 {
                        // Rows were processed but not output yet because the output
                        // buffer did not reach the batch size before
                        let schema = self.output.schema();
                        let record_batch = std::mem::replace(
                            &mut self.output,
                            RecordBatch::new_empty(schema),
                        );
                        return Poll::Ready(Some(Ok(record_batch)));
                    } else {
                        // Nothing staged and nothing buffered: end of stream
                        return Poll::Ready(None);
                    }
                }
            }
        }
    }
}
1268
1269impl SortMergeJoinStream {
1270    #[allow(clippy::too_many_arguments)]
1271    pub fn try_new(
1272        schema: SchemaRef,
1273        sort_options: Vec<SortOptions>,
1274        null_equals_null: bool,
1275        streamed: SendableRecordBatchStream,
1276        buffered: SendableRecordBatchStream,
1277        on_streamed: Vec<Arc<dyn PhysicalExpr>>,
1278        on_buffered: Vec<Arc<dyn PhysicalExpr>>,
1279        filter: Option<JoinFilter>,
1280        join_type: JoinType,
1281        batch_size: usize,
1282        join_metrics: SortMergeJoinMetrics,
1283        reservation: MemoryReservation,
1284        runtime_env: Arc<RuntimeEnv>,
1285    ) -> Result<Self> {
1286        let streamed_schema = streamed.schema();
1287        let buffered_schema = buffered.schema();
1288        Ok(Self {
1289            state: SortMergeJoinState::Init,
1290            sort_options,
1291            null_equals_null,
1292            schema: Arc::clone(&schema),
1293            streamed_schema: Arc::clone(&streamed_schema),
1294            buffered_schema,
1295            streamed,
1296            buffered,
1297            streamed_batch: StreamedBatch::new_empty(streamed_schema),
1298            buffered_data: BufferedData::default(),
1299            streamed_joined: false,
1300            buffered_joined: false,
1301            streamed_state: StreamedState::Init,
1302            buffered_state: BufferedState::Init,
1303            current_ordering: Ordering::Equal,
1304            on_streamed,
1305            on_buffered,
1306            filter,
1307            staging_output_record_batches: JoinedRecordBatches {
1308                batches: vec![],
1309                filter_mask: BooleanBuilder::new(),
1310                row_indices: UInt64Builder::new(),
1311                batch_ids: vec![],
1312            },
1313            output: RecordBatch::new_empty(schema),
1314            output_size: 0,
1315            batch_size,
1316            join_type,
1317            join_metrics,
1318            reservation,
1319            runtime_env,
1320            streamed_batch_counter: AtomicUsize::new(0),
1321        })
1322    }
1323
1324    /// Poll next streamed row
1325    fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
1326        loop {
1327            match &self.streamed_state {
1328                StreamedState::Init => {
1329                    if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
1330                    {
1331                        self.streamed_batch.idx += 1;
1332                        self.streamed_state = StreamedState::Ready;
1333                        return Poll::Ready(Some(Ok(())));
1334                    } else {
1335                        self.streamed_state = StreamedState::Polling;
1336                    }
1337                }
1338                StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
1339                    Poll::Pending => {
1340                        return Poll::Pending;
1341                    }
1342                    Poll::Ready(None) => {
1343                        self.streamed_state = StreamedState::Exhausted;
1344                    }
1345                    Poll::Ready(Some(batch)) => {
1346                        if batch.num_rows() > 0 {
1347                            self.freeze_streamed()?;
1348                            self.join_metrics.input_batches.add(1);
1349                            self.join_metrics.input_rows.add(batch.num_rows());
1350                            self.streamed_batch =
1351                                StreamedBatch::new(batch, &self.on_streamed);
1352                            // Every incoming streaming batch should have its unique id
1353                            // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation
1354                            self.streamed_batch_counter
1355                                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1356                            self.streamed_state = StreamedState::Ready;
1357                        }
1358                    }
1359                },
1360                StreamedState::Ready => {
1361                    return Poll::Ready(Some(Ok(())));
1362                }
1363                StreamedState::Exhausted => {
1364                    return Poll::Ready(None);
1365                }
1366            }
1367        }
1368    }
1369
1370    fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> {
1371        // Shrink memory usage for in-memory batches only
1372        if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() {
1373            self.reservation
1374                .try_shrink(buffered_batch.size_estimation)?;
1375        }
1376
1377        Ok(())
1378    }
1379
1380    fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
1381        match self.reservation.try_grow(buffered_batch.size_estimation) {
1382            Ok(_) => {
1383                self.join_metrics
1384                    .peak_mem_used
1385                    .set_max(self.reservation.size());
1386                Ok(())
1387            }
1388            Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
1389                // spill buffered batch to disk
1390                let spill_file = self
1391                    .runtime_env
1392                    .disk_manager
1393                    .create_tmp_file("sort_merge_join_buffered_spill")?;
1394
1395                if let Some(batch) = buffered_batch.batch {
1396                    spill_record_batches(
1397                        &[batch],
1398                        spill_file.path().into(),
1399                        Arc::clone(&self.buffered_schema),
1400                    )?;
1401                    buffered_batch.spill_file = Some(spill_file);
1402                    buffered_batch.batch = None;
1403
1404                    // update metrics to register spill
1405                    self.join_metrics.spill_count.add(1);
1406                    self.join_metrics
1407                        .spilled_bytes
1408                        .add(buffered_batch.size_estimation);
1409                    self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
1410                    Ok(())
1411                } else {
1412                    internal_err!("Buffered batch has empty body")
1413                }
1414            }
1415            Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
1416        }?;
1417
1418        self.buffered_data.batches.push_back(buffered_batch);
1419        Ok(())
1420    }
1421
    /// Poll the next run of buffered rows.
    ///
    /// Each buffered batch carries a `range` of rows; rows are added to the
    /// range while their join keys equal the key at the head batch's
    /// `range.start`, possibly across several batches.
    ///
    /// Returns `Poll::Ready(Some(Ok(())))` once the run is complete (state
    /// `Ready`), `Poll::Ready(None)` when the buffered input is exhausted and
    /// `Poll::Pending` while the buffered input has no batch ready.
    fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
        loop {
            match &self.buffered_state {
                BufferedState::Init => {
                    // pop previous buffered batches
                    while !self.buffered_data.batches.is_empty() {
                        let head_batch = self.buffered_data.head_batch();
                        // If the head batch is fully processed, dequeue it and produce output of it.
                        if head_batch.range.end == head_batch.num_rows {
                            self.freeze_dequeuing_buffered()?;
                            if let Some(mut buffered_batch) =
                                self.buffered_data.batches.pop_front()
                            {
                                self.produce_buffered_not_matched(&mut buffered_batch)?;
                                self.free_reservation(buffered_batch)?;
                            }
                        } else {
                            // If the head batch is not fully processed, break the loop.
                            // Streamed batch will be joined with the head batch in the next step.
                            break;
                        }
                    }
                    if self.buffered_data.batches.is_empty() {
                        // Nothing left in the buffer: a fresh batch has to be polled
                        self.buffered_state = BufferedState::PollingFirst;
                    } else {
                        // Start a new one-row range right after the previous one
                        let tail_batch = self.buffered_data.tail_batch_mut();
                        tail_batch.range.start = tail_batch.range.end;
                        tail_batch.range.end += 1;
                        self.buffered_state = BufferedState::PollingRest;
                    }
                }
                BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                    Poll::Ready(None) => {
                        self.buffered_state = BufferedState::Exhausted;
                        return Poll::Ready(None);
                    }
                    Poll::Ready(Some(batch)) => {
                        // Metrics are updated for every received batch, empty or not
                        self.join_metrics.input_batches.add(1);
                        self.join_metrics.input_rows.add(batch.num_rows());

                        // Empty batches are skipped; the state stays `PollingFirst`
                        if batch.num_rows() > 0 {
                            // The first batch of a run starts with a range covering its first row
                            let buffered_batch =
                                BufferedBatch::new(batch, 0..1, &self.on_buffered);

                            self.allocate_reservation(buffered_batch)?;
                            self.buffered_state = BufferedState::PollingRest;
                        }
                    }
                },
                BufferedState::PollingRest => {
                    if self.buffered_data.tail_batch().range.end
                        < self.buffered_data.tail_batch().num_rows
                    {
                        // Extend the tail range while the join keys still equal the
                        // key at the start of the head batch's range
                        while self.buffered_data.tail_batch().range.end
                            < self.buffered_data.tail_batch().num_rows
                        {
                            if is_join_arrays_equal(
                                &self.buffered_data.head_batch().join_arrays,
                                self.buffered_data.head_batch().range.start,
                                &self.buffered_data.tail_batch().join_arrays,
                                self.buffered_data.tail_batch().range.end,
                            )? {
                                self.buffered_data.tail_batch_mut().range.end += 1;
                            } else {
                                // A different key was found: the run is complete
                                self.buffered_state = BufferedState::Ready;
                                return Poll::Ready(Some(Ok(())));
                            }
                        }
                    } else {
                        // The tail batch is used up to its end; the run may continue
                        // in the next batch
                        match self.buffered.poll_next_unpin(cx)? {
                            Poll::Pending => {
                                return Poll::Pending;
                            }
                            Poll::Ready(None) => {
                                // No more input: what has been collected is the final run
                                self.buffered_state = BufferedState::Ready;
                            }
                            Poll::Ready(Some(batch)) => {
                                // Polling batches coming concurrently as multiple partitions
                                self.join_metrics.input_batches.add(1);
                                self.join_metrics.input_rows.add(batch.num_rows());
                                if batch.num_rows() > 0 {
                                    // A follow-up batch starts with an empty range that the
                                    // next loop iteration extends
                                    let buffered_batch = BufferedBatch::new(
                                        batch,
                                        0..0,
                                        &self.on_buffered,
                                    );
                                    self.allocate_reservation(buffered_batch)?;
                                }
                            }
                        }
                    }
                }
                BufferedState::Ready => {
                    return Poll::Ready(Some(Ok(())));
                }
                BufferedState::Exhausted => {
                    return Poll::Ready(None);
                }
            }
        }
    }
1527
1528    /// Get comparison result of streamed row and buffered batches
1529    fn compare_streamed_buffered(&self) -> Result<Ordering> {
1530        if self.streamed_state == StreamedState::Exhausted {
1531            return Ok(Ordering::Greater);
1532        }
1533        if !self.buffered_data.has_buffered_rows() {
1534            return Ok(Ordering::Less);
1535        }
1536
1537        compare_join_arrays(
1538            &self.streamed_batch.join_arrays,
1539            self.streamed_batch.idx,
1540            &self.buffered_data.head_batch().join_arrays,
1541            self.buffered_data.head_batch().range.start,
1542            &self.sort_options,
1543            self.null_equals_null,
1544        )
1545    }
1546
1547    /// Produce join and fill output buffer until reaching target batch size
1548    /// or the join is finished
1549    fn join_partial(&mut self) -> Result<()> {
1550        // Whether to join streamed rows
1551        let mut join_streamed = false;
1552        // Whether to join buffered rows
1553        let mut join_buffered = false;
1554        // For Mark join we store a dummy id to indicate the the row has a match
1555        let mut mark_row_as_match = false;
1556
1557        // determine whether we need to join streamed/buffered rows
1558        match self.current_ordering {
1559            Ordering::Less => {
1560                if matches!(
1561                    self.join_type,
1562                    JoinType::Left
1563                        | JoinType::Right
1564                        | JoinType::RightSemi
1565                        | JoinType::Full
1566                        | JoinType::LeftAnti
1567                        | JoinType::RightAnti
1568                        | JoinType::LeftMark
1569                ) {
1570                    join_streamed = !self.streamed_joined;
1571                }
1572            }
1573            Ordering::Equal => {
1574                if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) {
1575                    mark_row_as_match = matches!(self.join_type, JoinType::LeftMark);
1576                    // if the join filter is specified then its needed to output the streamed index
1577                    // only if it has not been emitted before
1578                    // the `join_filter_matched_idxs` keeps track on if streamed index has a successful
1579                    // filter match and prevents the same index to go into output more than once
1580                    if self.filter.is_some() {
1581                        join_streamed = !self
1582                            .streamed_batch
1583                            .join_filter_matched_idxs
1584                            .contains(&(self.streamed_batch.idx as u64))
1585                            && !self.streamed_joined;
1586                        // if the join filter specified there can be references to buffered columns
1587                        // so buffered columns are needed to access them
1588                        join_buffered = join_streamed;
1589                    } else {
1590                        join_streamed = !self.streamed_joined;
1591                    }
1592                }
1593                if matches!(
1594                    self.join_type,
1595                    JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
1596                ) {
1597                    join_streamed = true;
1598                    join_buffered = true;
1599                };
1600
1601                if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti)
1602                    && self.filter.is_some()
1603                {
1604                    join_streamed = !self.streamed_joined;
1605                    join_buffered = join_streamed;
1606                }
1607            }
1608            Ordering::Greater => {
1609                if matches!(self.join_type, JoinType::Full) {
1610                    join_buffered = !self.buffered_joined;
1611                };
1612            }
1613        }
1614        if !join_streamed && !join_buffered {
1615            // no joined data
1616            self.buffered_data.scanning_finish();
1617            return Ok(());
1618        }
1619
1620        if join_buffered {
1621            // joining streamed/nulls and buffered
1622            while !self.buffered_data.scanning_finished()
1623                && self.output_size < self.batch_size
1624            {
1625                let scanning_idx = self.buffered_data.scanning_idx();
1626                if join_streamed {
1627                    // Join streamed row and buffered row
1628                    self.streamed_batch.append_output_pair(
1629                        Some(self.buffered_data.scanning_batch_idx),
1630                        Some(scanning_idx),
1631                    );
1632                } else {
1633                    // Join nulls and buffered row for FULL join
1634                    self.buffered_data
1635                        .scanning_batch_mut()
1636                        .null_joined
1637                        .push(scanning_idx);
1638                }
1639                self.output_size += 1;
1640                self.buffered_data.scanning_advance();
1641
1642                if self.buffered_data.scanning_finished() {
1643                    self.streamed_joined = join_streamed;
1644                    self.buffered_joined = true;
1645                }
1646            }
1647        } else {
1648            // joining streamed and nulls
1649            let scanning_batch_idx = if self.buffered_data.scanning_finished() {
1650                None
1651            } else {
1652                Some(self.buffered_data.scanning_batch_idx)
1653            };
1654            // For Mark join we store a dummy id to indicate the the row has a match
1655            let scanning_idx = mark_row_as_match.then_some(0);
1656
1657            self.streamed_batch
1658                .append_output_pair(scanning_batch_idx, scanning_idx);
1659            self.output_size += 1;
1660            self.buffered_data.scanning_finish();
1661            self.streamed_joined = true;
1662        }
1663        Ok(())
1664    }
1665
1666    fn freeze_all(&mut self) -> Result<()> {
1667        self.freeze_buffered(self.buffered_data.batches.len())?;
1668        self.freeze_streamed()?;
1669        Ok(())
1670    }
1671
1672    // Produces and stages record batches to ensure dequeued buffered batch
1673    // no longer needed:
1674    //   1. freezes all indices joined to streamed side
1675    //   2. freezes NULLs joined to dequeued buffered batch to "release" it
1676    fn freeze_dequeuing_buffered(&mut self) -> Result<()> {
1677        self.freeze_streamed()?;
1678        // Only freeze and produce the first batch in buffered_data as the batch is fully processed
1679        self.freeze_buffered(1)?;
1680        Ok(())
1681    }
1682
1683    // Produces and stages record batch from buffered indices with corresponding
1684    // NULLs on streamed side.
1685    //
1686    // Applicable only in case of Full join.
1687    //
1688    fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> {
1689        if !matches!(self.join_type, JoinType::Full) {
1690            return Ok(());
1691        }
1692        for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
1693            let buffered_indices = UInt64Array::from_iter_values(
1694                buffered_batch.null_joined.iter().map(|&index| index as u64),
1695            );
1696            if let Some(record_batch) = produce_buffered_null_batch(
1697                &self.schema,
1698                &self.streamed_schema,
1699                &buffered_indices,
1700                buffered_batch,
1701            )? {
1702                let num_rows = record_batch.num_rows();
1703                self.staging_output_record_batches
1704                    .filter_mask
1705                    .append_nulls(num_rows);
1706                self.staging_output_record_batches
1707                    .row_indices
1708                    .append_nulls(num_rows);
1709                self.staging_output_record_batches.batch_ids.resize(
1710                    self.staging_output_record_batches.batch_ids.len() + num_rows,
1711                    0,
1712                );
1713
1714                self.staging_output_record_batches
1715                    .batches
1716                    .push(record_batch);
1717            }
1718            buffered_batch.null_joined.clear();
1719        }
1720        Ok(())
1721    }
1722
1723    fn produce_buffered_not_matched(
1724        &mut self,
1725        buffered_batch: &mut BufferedBatch,
1726    ) -> Result<()> {
1727        if !matches!(self.join_type, JoinType::Full) {
1728            return Ok(());
1729        }
1730
1731        // For buffered row which is joined with streamed side rows but all joined rows
1732        // don't satisfy the join filter
1733        let not_matched_buffered_indices = buffered_batch
1734            .join_filter_not_matched_map
1735            .iter()
1736            .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
1737            .collect::<Vec<_>>();
1738
1739        let buffered_indices =
1740            UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied());
1741
1742        if let Some(record_batch) = produce_buffered_null_batch(
1743            &self.schema,
1744            &self.streamed_schema,
1745            &buffered_indices,
1746            buffered_batch,
1747        )? {
1748            let num_rows = record_batch.num_rows();
1749
1750            self.staging_output_record_batches
1751                .filter_mask
1752                .append_nulls(num_rows);
1753            self.staging_output_record_batches
1754                .row_indices
1755                .append_nulls(num_rows);
1756            self.staging_output_record_batches.batch_ids.resize(
1757                self.staging_output_record_batches.batch_ids.len() + num_rows,
1758                0,
1759            );
1760            self.staging_output_record_batches
1761                .batches
1762                .push(record_batch);
1763        }
1764        buffered_batch.join_filter_not_matched_map.clear();
1765
1766        Ok(())
1767    }
1768
    // Produces and stages a record batch for every chunk of output indices
    // collected for the current streamed batch, then clears those indices.
    //
    // For each chunk:
    //   1. materialize the streamed columns for the recorded streamed indices,
    //   2. build the buffered-side columns (taken rows, a mark column, nothing,
    //      or null placeholders, depending on the join type),
    //   3. if a join filter exists and the chunk is joined to a buffered batch,
    //      evaluate the filter and record its mask, the streamed row indices
    //      and the streamed batch id alongside the staged batch.
    fn freeze_streamed(&mut self) -> Result<()> {
        for chunk in self.streamed_batch.output_indices.iter_mut() {
            // The row indices of joined streamed batch
            let left_indices = chunk.streamed_indices.finish();

            // Nothing was recorded for this chunk
            if left_indices.is_empty() {
                continue;
            }

            let mut left_columns = self
                .streamed_batch
                .batch
                .columns()
                .iter()
                .map(|column| take(column, &left_indices, None))
                .collect::<Result<Vec<_>, ArrowError>>()?;

            // The row indices of joined buffered batch
            let right_indices: UInt64Array = chunk.buffered_indices.finish();
            let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) {
                // Mark join: a single boolean column, true where a buffered index was stored
                vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
            } else if matches!(
                self.join_type,
                JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti
            ) {
                // Semi/anti joins output no buffered columns
                vec![]
            } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
                fetch_right_columns_by_idxs(
                    &self.buffered_data,
                    buffered_idx,
                    &right_indices,
                )?
            } else {
                // If buffered batch none, meaning it is null joined batch.
                // We need to create null arrays for buffered columns to join with streamed rows.
                create_unmatched_columns(
                    self.join_type,
                    &self.buffered_schema,
                    right_indices.len(),
                )
            };

            // Prepare the columns we apply join filter on later.
            // Only for joined rows between streamed and buffered.
            // NOTE(review): for Right and RightAnti the buffered columns are passed
            // in the first (JoinSide::Left) position, presumably because the buffered
            // input is the join's left side for those types - confirm.
            let filter_columns = if chunk.buffered_batch_idx.is_some() {
                if !matches!(self.join_type, JoinType::Right) {
                    if matches!(
                        self.join_type,
                        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
                    ) {
                        // `right_columns` does not hold the buffered rows for these
                        // join types, so fetch them just for the filter
                        let right_cols = fetch_right_columns_by_idxs(
                            &self.buffered_data,
                            chunk.buffered_batch_idx.unwrap(),
                            &right_indices,
                        )?;

                        get_filter_column(&self.filter, &left_columns, &right_cols)
                    } else if matches!(self.join_type, JoinType::RightAnti) {
                        let right_cols = fetch_right_columns_by_idxs(
                            &self.buffered_data,
                            chunk.buffered_batch_idx.unwrap(),
                            &right_indices,
                        )?;

                        get_filter_column(&self.filter, &right_cols, &left_columns)
                    } else {
                        get_filter_column(&self.filter, &left_columns, &right_columns)
                    }
                } else {
                    get_filter_column(&self.filter, &right_columns, &left_columns)
                }
            } else {
                // This chunk is totally for null joined rows (outer join), we don't need to apply join filter.
                // Any join filter applied only on either streamed or buffered side will be pushed already.
                vec![]
            };

            // Output column order: streamed then buffered, except for Right joins
            // where the buffered columns come first.
            let columns = if !matches!(self.join_type, JoinType::Right) {
                left_columns.extend(right_columns);
                left_columns
            } else {
                right_columns.extend(left_columns);
                right_columns
            };

            let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
            // Apply join filter if any
            if !filter_columns.is_empty() {
                if let Some(f) = &self.filter {
                    // Construct batch with only filter columns
                    let filter_batch =
                        RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?;

                    let filter_result = f
                        .expression()
                        .evaluate(&filter_batch)?
                        .into_array(filter_batch.num_rows())?;

                    // The boolean selection mask of the join filter result
                    let pre_mask =
                        datafusion_common::cast::as_boolean_array(&filter_result)?;

                    // If there are nulls in join filter result, exclude them from selecting
                    // the rows to output.
                    let mask = if pre_mask.null_count() > 0 {
                        compute::prep_null_mask_filter(
                            datafusion_common::cast::as_boolean_array(&filter_result)?,
                        )
                    } else {
                        pre_mask.clone()
                    };

                    // For the join types listed here the unfiltered batch is staged and
                    // the mask is only recorded below; for the remaining join types the
                    // mask is applied right away and the filtered batch is staged.
                    if matches!(
                        self.join_type,
                        JoinType::Left
                            | JoinType::LeftSemi
                            | JoinType::Right
                            | JoinType::LeftAnti
                            | JoinType::RightAnti
                            | JoinType::LeftMark
                            | JoinType::Full
                    ) {
                        self.staging_output_record_batches
                            .batches
                            .push(output_batch);
                    } else {
                        let filtered_batch = filter_record_batch(&output_batch, &mask)?;
                        self.staging_output_record_batches
                            .batches
                            .push(filtered_batch);
                    }

                    // Record the filter result per output row; Full join keeps the
                    // original mask (`pre_mask`) instead of the null-free one.
                    if !matches!(self.join_type, JoinType::Full) {
                        self.staging_output_record_batches.filter_mask.extend(&mask);
                    } else {
                        self.staging_output_record_batches
                            .filter_mask
                            .extend(pre_mask);
                    }
                    // Record which streamed row, and which streamed batch, each
                    // output row came from
                    self.staging_output_record_batches
                        .row_indices
                        .extend(&left_indices);
                    self.staging_output_record_batches.batch_ids.resize(
                        self.staging_output_record_batches.batch_ids.len()
                            + left_indices.len(),
                        self.streamed_batch_counter.load(Relaxed),
                    );

                    // For Full join, track per buffered row whether every join
                    // attempt so far failed the filter: an entry starts as `true`
                    // and stays `true` only while each evaluation is false.
                    if matches!(self.join_type, JoinType::Full) {
                        let buffered_batch = &mut self.buffered_data.batches
                            [chunk.buffered_batch_idx.unwrap()];

                        for i in 0..pre_mask.len() {
                            // If the buffered row is not joined with streamed side,
                            // skip it.
                            if right_indices.is_null(i) {
                                continue;
                            }

                            let buffered_index = right_indices.value(i);

                            buffered_batch.join_filter_not_matched_map.insert(
                                buffered_index,
                                *buffered_batch
                                    .join_filter_not_matched_map
                                    .get(&buffered_index)
                                    .unwrap_or(&true)
                                    && !pre_mask.value(i),
                            );
                        }
                    }
                } else {
                    self.staging_output_record_batches
                        .batches
                        .push(output_batch);
                }
            } else {
                // No filter columns: stage the joined batch as is
                self.staging_output_record_batches
                    .batches
                    .push(output_batch);
            }
        }

        self.streamed_batch.output_indices.clear();

        Ok(())
    }
1963
1964    fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
1965        let record_batch =
1966            concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1967        self.join_metrics.output_batches.add(1);
1968        self.join_metrics.output_rows.add(record_batch.num_rows());
1969        // If join filter exists, `self.output_size` is not accurate as we don't know the exact
1970        // number of rows in the output record batch. If streamed row joined with buffered rows,
1971        // once join filter is applied, the number of output rows may be more than 1.
1972        // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened
1973        // when the join filter is applied and all rows are filtered out.
1974        if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size {
1975            self.output_size = 0;
1976        } else {
1977            self.output_size -= record_batch.num_rows();
1978        }
1979
1980        if !(self.filter.is_some()
1981            && matches!(
1982                self.join_type,
1983                JoinType::Left
1984                    | JoinType::LeftSemi
1985                    | JoinType::Right
1986                    | JoinType::LeftAnti
1987                    | JoinType::RightAnti
1988                    | JoinType::LeftMark
1989                    | JoinType::Full
1990            ))
1991        {
1992            self.staging_output_record_batches.batches.clear();
1993        }
1994        Ok(record_batch)
1995    }
1996
1997    fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
1998        let record_batch =
1999            concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
2000        let mut out_indices = self.staging_output_record_batches.row_indices.finish();
2001        let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
2002        let mut batch_ids = &self.staging_output_record_batches.batch_ids;
2003        let default_batch_ids = vec![0; record_batch.num_rows()];
2004
2005        // If only nulls come in and indices sizes doesn't match with expected record batch count
2006        // generate missing indices
2007        // Happens for null joined batches for Full Join
2008        if out_indices.null_count() == out_indices.len()
2009            && out_indices.len() != record_batch.num_rows()
2010        {
2011            out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]);
2012            out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]);
2013            batch_ids = &default_batch_ids;
2014        }
2015
2016        if out_mask.is_empty() {
2017            self.staging_output_record_batches.batches.clear();
2018            return Ok(record_batch);
2019        }
2020
2021        let maybe_corrected_mask = get_corrected_filter_mask(
2022            self.join_type,
2023            &out_indices,
2024            batch_ids,
2025            &out_mask,
2026            record_batch.num_rows(),
2027        );
2028
2029        let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask {
2030            filtered_join_mask
2031        } else {
2032            &out_mask
2033        };
2034
2035        self.filter_record_batch_by_join_type(record_batch, corrected_mask)
2036    }
2037
2038    fn filter_record_batch_by_join_type(
2039        &mut self,
2040        record_batch: RecordBatch,
2041        corrected_mask: &BooleanArray,
2042    ) -> Result<RecordBatch> {
2043        let mut filtered_record_batch =
2044            filter_record_batch(&record_batch, corrected_mask)?;
2045        let left_columns_length = self.streamed_schema.fields.len();
2046        let right_columns_length = self.buffered_schema.fields.len();
2047
2048        if matches!(
2049            self.join_type,
2050            JoinType::Left | JoinType::LeftMark | JoinType::Right
2051        ) {
2052            let null_mask = compute::not(corrected_mask)?;
2053            let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
2054
2055            let mut right_columns = create_unmatched_columns(
2056                self.join_type,
2057                &self.buffered_schema,
2058                null_joined_batch.num_rows(),
2059            );
2060
2061            let columns = if !matches!(self.join_type, JoinType::Right) {
2062                let mut left_columns = null_joined_batch
2063                    .columns()
2064                    .iter()
2065                    .take(right_columns_length)
2066                    .cloned()
2067                    .collect::<Vec<_>>();
2068
2069                left_columns.extend(right_columns);
2070                left_columns
2071            } else {
2072                let left_columns = null_joined_batch
2073                    .columns()
2074                    .iter()
2075                    .skip(left_columns_length)
2076                    .cloned()
2077                    .collect::<Vec<_>>();
2078
2079                right_columns.extend(left_columns);
2080                right_columns
2081            };
2082
2083            // Push the streamed/buffered batch joined nulls to the output
2084            let null_joined_streamed_batch =
2085                RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
2086
2087            filtered_record_batch = concat_batches(
2088                &self.schema,
2089                &[filtered_record_batch, null_joined_streamed_batch],
2090            )?;
2091        } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
2092            let output_column_indices = (0..left_columns_length).collect::<Vec<_>>();
2093            filtered_record_batch =
2094                filtered_record_batch.project(&output_column_indices)?;
2095        } else if matches!(self.join_type, JoinType::RightAnti) {
2096            let output_column_indices = (0..right_columns_length).collect::<Vec<_>>();
2097            filtered_record_batch =
2098                filtered_record_batch.project(&output_column_indices)?;
2099        } else if matches!(self.join_type, JoinType::Full)
2100            && corrected_mask.false_count() > 0
2101        {
2102            // Find rows which joined by key but Filter predicate evaluated as false
2103            let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
2104            let joined_filter_not_matched_batch =
2105                filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?;
2106
2107            // Add left unmatched rows adding the right side as nulls
2108            let right_null_columns = self
2109                .buffered_schema
2110                .fields()
2111                .iter()
2112                .map(|f| {
2113                    new_null_array(
2114                        f.data_type(),
2115                        joined_filter_not_matched_batch.num_rows(),
2116                    )
2117                })
2118                .collect::<Vec<_>>();
2119
2120            let mut result_joined = joined_filter_not_matched_batch
2121                .columns()
2122                .iter()
2123                .take(left_columns_length)
2124                .cloned()
2125                .collect::<Vec<_>>();
2126
2127            result_joined.extend(right_null_columns);
2128
2129            let left_null_joined_batch =
2130                RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?;
2131
2132            // Add right unmatched rows adding the left side as nulls
2133            let mut result_joined = self
2134                .streamed_schema
2135                .fields()
2136                .iter()
2137                .map(|f| {
2138                    new_null_array(
2139                        f.data_type(),
2140                        joined_filter_not_matched_batch.num_rows(),
2141                    )
2142                })
2143                .collect::<Vec<_>>();
2144
2145            let right_data = joined_filter_not_matched_batch
2146                .columns()
2147                .iter()
2148                .skip(left_columns_length)
2149                .cloned()
2150                .collect::<Vec<_>>();
2151
2152            result_joined.extend(right_data);
2153
2154            filtered_record_batch = concat_batches(
2155                &self.schema,
2156                &[filtered_record_batch, left_null_joined_batch],
2157            )?;
2158        }
2159
2160        self.staging_output_record_batches.clear();
2161
2162        Ok(filtered_record_batch)
2163    }
2164}
2165
2166fn create_unmatched_columns(
2167    join_type: JoinType,
2168    schema: &SchemaRef,
2169    size: usize,
2170) -> Vec<ArrayRef> {
2171    if matches!(join_type, JoinType::LeftMark) {
2172        vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
2173    } else {
2174        schema
2175            .fields()
2176            .iter()
2177            .map(|f| new_null_array(f.data_type(), size))
2178            .collect::<Vec<_>>()
2179    }
2180}
2181
2182/// Gets the arrays which join filters are applied on.
2183fn get_filter_column(
2184    join_filter: &Option<JoinFilter>,
2185    streamed_columns: &[ArrayRef],
2186    buffered_columns: &[ArrayRef],
2187) -> Vec<ArrayRef> {
2188    let mut filter_columns = vec![];
2189
2190    if let Some(f) = join_filter {
2191        let left_columns = f
2192            .column_indices()
2193            .iter()
2194            .filter(|col_index| col_index.side == JoinSide::Left)
2195            .map(|i| Arc::clone(&streamed_columns[i.index]))
2196            .collect::<Vec<_>>();
2197
2198        let right_columns = f
2199            .column_indices()
2200            .iter()
2201            .filter(|col_index| col_index.side == JoinSide::Right)
2202            .map(|i| Arc::clone(&buffered_columns[i.index]))
2203            .collect::<Vec<_>>();
2204
2205        filter_columns.extend(left_columns);
2206        filter_columns.extend(right_columns);
2207    }
2208
2209    filter_columns
2210}
2211
2212fn produce_buffered_null_batch(
2213    schema: &SchemaRef,
2214    streamed_schema: &SchemaRef,
2215    buffered_indices: &PrimitiveArray<UInt64Type>,
2216    buffered_batch: &BufferedBatch,
2217) -> Result<Option<RecordBatch>> {
2218    if buffered_indices.is_empty() {
2219        return Ok(None);
2220    }
2221
2222    // Take buffered (right) columns
2223    let right_columns =
2224        fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?;
2225
2226    // Create null streamed (left) columns
2227    let mut left_columns = streamed_schema
2228        .fields()
2229        .iter()
2230        .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
2231        .collect::<Vec<_>>();
2232
2233    left_columns.extend(right_columns);
2234
2235    Ok(Some(RecordBatch::try_new(
2236        Arc::clone(schema),
2237        left_columns,
2238    )?))
2239}
2240
2241/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices
2242#[inline(always)]
2243fn fetch_right_columns_by_idxs(
2244    buffered_data: &BufferedData,
2245    buffered_batch_idx: usize,
2246    buffered_indices: &UInt64Array,
2247) -> Result<Vec<ArrayRef>> {
2248    fetch_right_columns_from_batch_by_idxs(
2249        &buffered_data.batches[buffered_batch_idx],
2250        buffered_indices,
2251    )
2252}
2253
2254#[inline(always)]
2255fn fetch_right_columns_from_batch_by_idxs(
2256    buffered_batch: &BufferedBatch,
2257    buffered_indices: &UInt64Array,
2258) -> Result<Vec<ArrayRef>> {
2259    match (&buffered_batch.spill_file, &buffered_batch.batch) {
2260        // In memory batch
2261        (None, Some(batch)) => Ok(batch
2262            .columns()
2263            .iter()
2264            .map(|column| take(column, &buffered_indices, None))
2265            .collect::<Result<Vec<_>, ArrowError>>()
2266            .map_err(Into::<DataFusionError>::into)?),
2267        // If the batch was spilled to disk, less likely
2268        (Some(spill_file), None) => {
2269            let mut buffered_cols: Vec<ArrayRef> =
2270                Vec::with_capacity(buffered_indices.len());
2271
2272            let file = BufReader::new(File::open(spill_file.path())?);
2273            let reader = StreamReader::try_new(file, None)?;
2274
2275            for batch in reader {
2276                batch?.columns().iter().for_each(|column| {
2277                    buffered_cols.extend(take(column, &buffered_indices, None))
2278                });
2279            }
2280
2281            Ok(buffered_cols)
2282        }
2283        // Invalid combination
2284        (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()),
2285    }
2286}
2287
2288/// Buffered data contains all buffered batches with one unique join key
2289#[derive(Debug, Default)]
2290struct BufferedData {
2291    /// Buffered batches with the same key
2292    pub batches: VecDeque<BufferedBatch>,
2293    /// current scanning batch index used in join_partial()
2294    pub scanning_batch_idx: usize,
2295    /// current scanning offset used in join_partial()
2296    pub scanning_offset: usize,
2297}
2298
2299impl BufferedData {
2300    pub fn head_batch(&self) -> &BufferedBatch {
2301        self.batches.front().unwrap()
2302    }
2303
2304    pub fn tail_batch(&self) -> &BufferedBatch {
2305        self.batches.back().unwrap()
2306    }
2307
2308    pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
2309        self.batches.back_mut().unwrap()
2310    }
2311
2312    pub fn has_buffered_rows(&self) -> bool {
2313        self.batches.iter().any(|batch| !batch.range.is_empty())
2314    }
2315
2316    pub fn scanning_reset(&mut self) {
2317        self.scanning_batch_idx = 0;
2318        self.scanning_offset = 0;
2319    }
2320
2321    pub fn scanning_advance(&mut self) {
2322        self.scanning_offset += 1;
2323        while !self.scanning_finished() && self.scanning_batch_finished() {
2324            self.scanning_batch_idx += 1;
2325            self.scanning_offset = 0;
2326        }
2327    }
2328
2329    pub fn scanning_batch(&self) -> &BufferedBatch {
2330        &self.batches[self.scanning_batch_idx]
2331    }
2332
2333    pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
2334        &mut self.batches[self.scanning_batch_idx]
2335    }
2336
2337    pub fn scanning_idx(&self) -> usize {
2338        self.scanning_batch().range.start + self.scanning_offset
2339    }
2340
2341    pub fn scanning_batch_finished(&self) -> bool {
2342        self.scanning_offset == self.scanning_batch().range.len()
2343    }
2344
2345    pub fn scanning_finished(&self) -> bool {
2346        self.scanning_batch_idx == self.batches.len()
2347    }
2348
2349    pub fn scanning_finish(&mut self) {
2350        self.scanning_batch_idx = self.batches.len();
2351        self.scanning_offset = 0;
2352    }
2353}
2354
2355/// Get join array refs of given batch and join columns
2356fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> {
2357    on_column
2358        .iter()
2359        .map(|c| {
2360            let num_rows = batch.num_rows();
2361            let c = c.evaluate(batch).unwrap();
2362            c.into_array(num_rows).unwrap()
2363        })
2364        .collect()
2365}
2366
2367/// Get comparison result of two rows of join arrays
2368fn compare_join_arrays(
2369    left_arrays: &[ArrayRef],
2370    left: usize,
2371    right_arrays: &[ArrayRef],
2372    right: usize,
2373    sort_options: &[SortOptions],
2374    null_equals_null: bool,
2375) -> Result<Ordering> {
2376    let mut res = Ordering::Equal;
2377    for ((left_array, right_array), sort_options) in
2378        left_arrays.iter().zip(right_arrays).zip(sort_options)
2379    {
2380        macro_rules! compare_value {
2381            ($T:ty) => {{
2382                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
2383                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
2384                match (left_array.is_null(left), right_array.is_null(right)) {
2385                    (false, false) => {
2386                        let left_value = &left_array.value(left);
2387                        let right_value = &right_array.value(right);
2388                        res = left_value.partial_cmp(right_value).unwrap();
2389                        if sort_options.descending {
2390                            res = res.reverse();
2391                        }
2392                    }
2393                    (true, false) => {
2394                        res = if sort_options.nulls_first {
2395                            Ordering::Less
2396                        } else {
2397                            Ordering::Greater
2398                        };
2399                    }
2400                    (false, true) => {
2401                        res = if sort_options.nulls_first {
2402                            Ordering::Greater
2403                        } else {
2404                            Ordering::Less
2405                        };
2406                    }
2407                    _ => {
2408                        res = if null_equals_null {
2409                            Ordering::Equal
2410                        } else {
2411                            Ordering::Less
2412                        };
2413                    }
2414                }
2415            }};
2416        }
2417
2418        match left_array.data_type() {
2419            DataType::Null => {}
2420            DataType::Boolean => compare_value!(BooleanArray),
2421            DataType::Int8 => compare_value!(Int8Array),
2422            DataType::Int16 => compare_value!(Int16Array),
2423            DataType::Int32 => compare_value!(Int32Array),
2424            DataType::Int64 => compare_value!(Int64Array),
2425            DataType::UInt8 => compare_value!(UInt8Array),
2426            DataType::UInt16 => compare_value!(UInt16Array),
2427            DataType::UInt32 => compare_value!(UInt32Array),
2428            DataType::UInt64 => compare_value!(UInt64Array),
2429            DataType::Float32 => compare_value!(Float32Array),
2430            DataType::Float64 => compare_value!(Float64Array),
2431            DataType::Utf8 => compare_value!(StringArray),
2432            DataType::LargeUtf8 => compare_value!(LargeStringArray),
2433            DataType::Decimal128(..) => compare_value!(Decimal128Array),
2434            DataType::Timestamp(time_unit, None) => match time_unit {
2435                TimeUnit::Second => compare_value!(TimestampSecondArray),
2436                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
2437                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
2438                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
2439            },
2440            DataType::Date32 => compare_value!(Date32Array),
2441            DataType::Date64 => compare_value!(Date64Array),
2442            dt => {
2443                return not_impl_err!(
2444                    "Unsupported data type in sort merge join comparator: {}",
2445                    dt
2446                );
2447            }
2448        }
2449        if !res.is_eq() {
2450            break;
2451        }
2452    }
2453    Ok(res)
2454}
2455
2456/// A faster version of compare_join_arrays() that only output whether
2457/// the given two rows are equal
2458fn is_join_arrays_equal(
2459    left_arrays: &[ArrayRef],
2460    left: usize,
2461    right_arrays: &[ArrayRef],
2462    right: usize,
2463) -> Result<bool> {
2464    let mut is_equal = true;
2465    for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
2466        macro_rules! compare_value {
2467            ($T:ty) => {{
2468                match (left_array.is_null(left), right_array.is_null(right)) {
2469                    (false, false) => {
2470                        let left_array =
2471                            left_array.as_any().downcast_ref::<$T>().unwrap();
2472                        let right_array =
2473                            right_array.as_any().downcast_ref::<$T>().unwrap();
2474                        if left_array.value(left) != right_array.value(right) {
2475                            is_equal = false;
2476                        }
2477                    }
2478                    (true, false) => is_equal = false,
2479                    (false, true) => is_equal = false,
2480                    _ => {}
2481                }
2482            }};
2483        }
2484
2485        match left_array.data_type() {
2486            DataType::Null => {}
2487            DataType::Boolean => compare_value!(BooleanArray),
2488            DataType::Int8 => compare_value!(Int8Array),
2489            DataType::Int16 => compare_value!(Int16Array),
2490            DataType::Int32 => compare_value!(Int32Array),
2491            DataType::Int64 => compare_value!(Int64Array),
2492            DataType::UInt8 => compare_value!(UInt8Array),
2493            DataType::UInt16 => compare_value!(UInt16Array),
2494            DataType::UInt32 => compare_value!(UInt32Array),
2495            DataType::UInt64 => compare_value!(UInt64Array),
2496            DataType::Float32 => compare_value!(Float32Array),
2497            DataType::Float64 => compare_value!(Float64Array),
2498            DataType::Utf8 => compare_value!(StringArray),
2499            DataType::LargeUtf8 => compare_value!(LargeStringArray),
2500            DataType::Decimal128(..) => compare_value!(Decimal128Array),
2501            DataType::Timestamp(time_unit, None) => match time_unit {
2502                TimeUnit::Second => compare_value!(TimestampSecondArray),
2503                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
2504                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
2505                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
2506            },
2507            DataType::Date32 => compare_value!(Date32Array),
2508            DataType::Date64 => compare_value!(Date64Array),
2509            dt => {
2510                return not_impl_err!(
2511                    "Unsupported data type in sort merge join comparator: {}",
2512                    dt
2513                );
2514            }
2515        }
2516        if !is_equal {
2517            return Ok(false);
2518        }
2519    }
2520    Ok(true)
2521}
2522
2523#[cfg(test)]
2524mod tests {
2525    use std::sync::Arc;
2526
2527    use arrow::array::{
2528        builder::{BooleanBuilder, UInt64Builder},
2529        BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array,
2530    };
2531    use arrow::compute::{concat_batches, filter_record_batch, SortOptions};
2532    use arrow::datatypes::{DataType, Field, Schema};
2533
2534    use datafusion_common::JoinSide;
2535    use datafusion_common::JoinType::*;
2536    use datafusion_common::{
2537        assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
2538    };
2539    use datafusion_execution::config::SessionConfig;
2540    use datafusion_execution::disk_manager::DiskManagerConfig;
2541    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2542    use datafusion_execution::TaskContext;
2543    use datafusion_expr::Operator;
2544    use datafusion_physical_expr::expressions::BinaryExpr;
2545
2546    use crate::expressions::Column;
2547    use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches};
2548    use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
2549    use crate::joins::SortMergeJoinExec;
2550    use crate::test::TestMemoryExec;
2551    use crate::test::{build_table_i32, build_table_i32_two_cols};
2552    use crate::{common, ExecutionPlan};
2553
2554    fn build_table(
2555        a: (&str, &Vec<i32>),
2556        b: (&str, &Vec<i32>),
2557        c: (&str, &Vec<i32>),
2558    ) -> Arc<dyn ExecutionPlan> {
2559        let batch = build_table_i32(a, b, c);
2560        let schema = batch.schema();
2561        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2562    }
2563
2564    fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
2565        let schema = batches.first().unwrap().schema();
2566        TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap()
2567    }
2568
2569    fn build_date_table(
2570        a: (&str, &Vec<i32>),
2571        b: (&str, &Vec<i32>),
2572        c: (&str, &Vec<i32>),
2573    ) -> Arc<dyn ExecutionPlan> {
2574        let schema = Schema::new(vec![
2575            Field::new(a.0, DataType::Date32, false),
2576            Field::new(b.0, DataType::Date32, false),
2577            Field::new(c.0, DataType::Date32, false),
2578        ]);
2579
2580        let batch = RecordBatch::try_new(
2581            Arc::new(schema),
2582            vec![
2583                Arc::new(Date32Array::from(a.1.clone())),
2584                Arc::new(Date32Array::from(b.1.clone())),
2585                Arc::new(Date32Array::from(c.1.clone())),
2586            ],
2587        )
2588        .unwrap();
2589
2590        let schema = batch.schema();
2591        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2592    }
2593
2594    fn build_date64_table(
2595        a: (&str, &Vec<i64>),
2596        b: (&str, &Vec<i64>),
2597        c: (&str, &Vec<i64>),
2598    ) -> Arc<dyn ExecutionPlan> {
2599        let schema = Schema::new(vec![
2600            Field::new(a.0, DataType::Date64, false),
2601            Field::new(b.0, DataType::Date64, false),
2602            Field::new(c.0, DataType::Date64, false),
2603        ]);
2604
2605        let batch = RecordBatch::try_new(
2606            Arc::new(schema),
2607            vec![
2608                Arc::new(Date64Array::from(a.1.clone())),
2609                Arc::new(Date64Array::from(b.1.clone())),
2610                Arc::new(Date64Array::from(c.1.clone())),
2611            ],
2612        )
2613        .unwrap();
2614
2615        let schema = batch.schema();
2616        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2617    }
2618
2619    /// returns a table with 3 columns of i32 in memory
2620    pub fn build_table_i32_nullable(
2621        a: (&str, &Vec<Option<i32>>),
2622        b: (&str, &Vec<Option<i32>>),
2623        c: (&str, &Vec<Option<i32>>),
2624    ) -> Arc<dyn ExecutionPlan> {
2625        let schema = Arc::new(Schema::new(vec![
2626            Field::new(a.0, DataType::Int32, true),
2627            Field::new(b.0, DataType::Int32, true),
2628            Field::new(c.0, DataType::Int32, true),
2629        ]));
2630        let batch = RecordBatch::try_new(
2631            Arc::clone(&schema),
2632            vec![
2633                Arc::new(Int32Array::from(a.1.clone())),
2634                Arc::new(Int32Array::from(b.1.clone())),
2635                Arc::new(Int32Array::from(c.1.clone())),
2636            ],
2637        )
2638        .unwrap();
2639        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2640    }
2641
2642    pub fn build_table_two_cols(
2643        a: (&str, &Vec<i32>),
2644        b: (&str, &Vec<i32>),
2645    ) -> Arc<dyn ExecutionPlan> {
2646        let batch = build_table_i32_two_cols(a, b);
2647        let schema = batch.schema();
2648        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
2649    }
2650
2651    fn join(
2652        left: Arc<dyn ExecutionPlan>,
2653        right: Arc<dyn ExecutionPlan>,
2654        on: JoinOn,
2655        join_type: JoinType,
2656    ) -> Result<SortMergeJoinExec> {
2657        let sort_options = vec![SortOptions::default(); on.len()];
2658        SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false)
2659    }
2660
2661    fn join_with_options(
2662        left: Arc<dyn ExecutionPlan>,
2663        right: Arc<dyn ExecutionPlan>,
2664        on: JoinOn,
2665        join_type: JoinType,
2666        sort_options: Vec<SortOptions>,
2667        null_equals_null: bool,
2668    ) -> Result<SortMergeJoinExec> {
2669        SortMergeJoinExec::try_new(
2670            left,
2671            right,
2672            on,
2673            None,
2674            join_type,
2675            sort_options,
2676            null_equals_null,
2677        )
2678    }
2679
2680    fn join_with_filter(
2681        left: Arc<dyn ExecutionPlan>,
2682        right: Arc<dyn ExecutionPlan>,
2683        on: JoinOn,
2684        filter: JoinFilter,
2685        join_type: JoinType,
2686        sort_options: Vec<SortOptions>,
2687        null_equals_null: bool,
2688    ) -> Result<SortMergeJoinExec> {
2689        SortMergeJoinExec::try_new(
2690            left,
2691            right,
2692            on,
2693            Some(filter),
2694            join_type,
2695            sort_options,
2696            null_equals_null,
2697        )
2698    }
2699
2700    async fn join_collect(
2701        left: Arc<dyn ExecutionPlan>,
2702        right: Arc<dyn ExecutionPlan>,
2703        on: JoinOn,
2704        join_type: JoinType,
2705    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2706        let sort_options = vec![SortOptions::default(); on.len()];
2707        join_collect_with_options(left, right, on, join_type, sort_options, false).await
2708    }
2709
2710    async fn join_collect_with_filter(
2711        left: Arc<dyn ExecutionPlan>,
2712        right: Arc<dyn ExecutionPlan>,
2713        on: JoinOn,
2714        filter: JoinFilter,
2715        join_type: JoinType,
2716    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2717        let sort_options = vec![SortOptions::default(); on.len()];
2718
2719        let task_ctx = Arc::new(TaskContext::default());
2720        let join =
2721            join_with_filter(left, right, on, filter, join_type, sort_options, false)?;
2722        let columns = columns(&join.schema());
2723
2724        let stream = join.execute(0, task_ctx)?;
2725        let batches = common::collect(stream).await?;
2726        Ok((columns, batches))
2727    }
2728
2729    async fn join_collect_with_options(
2730        left: Arc<dyn ExecutionPlan>,
2731        right: Arc<dyn ExecutionPlan>,
2732        on: JoinOn,
2733        join_type: JoinType,
2734        sort_options: Vec<SortOptions>,
2735        null_equals_null: bool,
2736    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2737        let task_ctx = Arc::new(TaskContext::default());
2738        let join = join_with_options(
2739            left,
2740            right,
2741            on,
2742            join_type,
2743            sort_options,
2744            null_equals_null,
2745        )?;
2746        let columns = columns(&join.schema());
2747
2748        let stream = join.execute(0, task_ctx)?;
2749        let batches = common::collect(stream).await?;
2750        Ok((columns, batches))
2751    }
2752
2753    async fn join_collect_batch_size_equals_two(
2754        left: Arc<dyn ExecutionPlan>,
2755        right: Arc<dyn ExecutionPlan>,
2756        on: JoinOn,
2757        join_type: JoinType,
2758    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2759        let task_ctx = TaskContext::default()
2760            .with_session_config(SessionConfig::new().with_batch_size(2));
2761        let task_ctx = Arc::new(task_ctx);
2762        let join = join(left, right, on, join_type)?;
2763        let columns = columns(&join.schema());
2764
2765        let stream = join.execute(0, task_ctx)?;
2766        let batches = common::collect(stream).await?;
2767        Ok((columns, batches))
2768    }
2769
2770    #[tokio::test]
2771    async fn join_inner_one() -> Result<()> {
2772        let left = build_table(
2773            ("a1", &vec![1, 2, 3]),
2774            ("b1", &vec![4, 5, 5]), // this has a repetition
2775            ("c1", &vec![7, 8, 9]),
2776        );
2777        let right = build_table(
2778            ("a2", &vec![10, 20, 30]),
2779            ("b1", &vec![4, 5, 6]),
2780            ("c2", &vec![70, 80, 90]),
2781        );
2782
2783        let on = vec![(
2784            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2785            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
2786        )];
2787
2788        let (_, batches) = join_collect(left, right, on, Inner).await?;
2789
2790        let expected = [
2791            "+----+----+----+----+----+----+",
2792            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2793            "+----+----+----+----+----+----+",
2794            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2795            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2796            "| 3  | 5  | 9  | 20 | 5  | 80 |",
2797            "+----+----+----+----+----+----+",
2798        ];
2799        // The output order is important as SMJ preserves sortedness
2800        assert_batches_eq!(expected, &batches);
2801        Ok(())
2802    }
2803
2804    #[tokio::test]
2805    async fn join_inner_two() -> Result<()> {
2806        let left = build_table(
2807            ("a1", &vec![1, 2, 2]),
2808            ("b2", &vec![1, 2, 2]),
2809            ("c1", &vec![7, 8, 9]),
2810        );
2811        let right = build_table(
2812            ("a1", &vec![1, 2, 3]),
2813            ("b2", &vec![1, 2, 2]),
2814            ("c2", &vec![70, 80, 90]),
2815        );
2816        let on = vec![
2817            (
2818                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2819                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2820            ),
2821            (
2822                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2823                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2824            ),
2825        ];
2826
2827        let (_columns, batches) = join_collect(left, right, on, Inner).await?;
2828        let expected = [
2829            "+----+----+----+----+----+----+",
2830            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2831            "+----+----+----+----+----+----+",
2832            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2833            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2834            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2835            "+----+----+----+----+----+----+",
2836        ];
2837        // The output order is important as SMJ preserves sortedness
2838        assert_batches_eq!(expected, &batches);
2839        Ok(())
2840    }
2841
2842    #[tokio::test]
2843    async fn join_inner_two_two() -> Result<()> {
2844        let left = build_table(
2845            ("a1", &vec![1, 1, 2]),
2846            ("b2", &vec![1, 1, 2]),
2847            ("c1", &vec![7, 8, 9]),
2848        );
2849        let right = build_table(
2850            ("a1", &vec![1, 1, 3]),
2851            ("b2", &vec![1, 1, 2]),
2852            ("c2", &vec![70, 80, 90]),
2853        );
2854        let on = vec![
2855            (
2856                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2857                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2858            ),
2859            (
2860                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2861                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2862            ),
2863        ];
2864
2865        let (_columns, batches) = join_collect(left, right, on, Inner).await?;
2866        let expected = [
2867            "+----+----+----+----+----+----+",
2868            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2869            "+----+----+----+----+----+----+",
2870            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2871            "| 1  | 1  | 7  | 1  | 1  | 80 |",
2872            "| 1  | 1  | 8  | 1  | 1  | 70 |",
2873            "| 1  | 1  | 8  | 1  | 1  | 80 |",
2874            "+----+----+----+----+----+----+",
2875        ];
2876        // The output order is important as SMJ preserves sortedness
2877        assert_batches_eq!(expected, &batches);
2878        Ok(())
2879    }
2880
2881    #[tokio::test]
2882    async fn join_inner_with_nulls() -> Result<()> {
2883        let left = build_table_i32_nullable(
2884            ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
2885            ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
2886            ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
2887        );
2888        let right = build_table_i32_nullable(
2889            ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
2890            ("b2", &vec![None, Some(1), Some(2), Some(2)]),
2891            ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
2892        );
2893        let on = vec![
2894            (
2895                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2896                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2897            ),
2898            (
2899                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2900                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2901            ),
2902        ];
2903
2904        let (_, batches) = join_collect(left, right, on, Inner).await?;
2905        let expected = [
2906            "+----+----+----+----+----+----+",
2907            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2908            "+----+----+----+----+----+----+",
2909            "| 1  | 1  |    | 1  | 1  | 70 |",
2910            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2911            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2912            "+----+----+----+----+----+----+",
2913        ];
2914        // The output order is important as SMJ preserves sortedness
2915        assert_batches_eq!(expected, &batches);
2916        Ok(())
2917    }
2918
2919    #[tokio::test]
2920    async fn join_inner_with_nulls_with_options() -> Result<()> {
2921        let left = build_table_i32_nullable(
2922            ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]),
2923            ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field
2924            ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field
2925        );
2926        let right = build_table_i32_nullable(
2927            ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]),
2928            ("b2", &vec![Some(2), Some(2), Some(1), None]),
2929            ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]),
2930        );
2931        let on = vec![
2932            (
2933                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2934                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2935            ),
2936            (
2937                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2938                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2939            ),
2940        ];
2941        let (_, batches) = join_collect_with_options(
2942            left,
2943            right,
2944            on,
2945            Inner,
2946            vec![
2947                SortOptions {
2948                    descending: true,
2949                    nulls_first: false,
2950                };
2951                2
2952            ],
2953            true,
2954        )
2955        .await?;
2956        let expected = [
2957            "+----+----+----+----+----+----+",
2958            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2959            "+----+----+----+----+----+----+",
2960            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2961            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2962            "| 1  | 1  |    | 1  | 1  | 70 |",
2963            "| 1  |    | 1  | 1  |    | 10 |",
2964            "+----+----+----+----+----+----+",
2965        ];
2966        // The output order is important as SMJ preserves sortedness
2967        assert_batches_eq!(expected, &batches);
2968        Ok(())
2969    }
2970
2971    #[tokio::test]
2972    async fn join_inner_output_two_batches() -> Result<()> {
2973        let left = build_table(
2974            ("a1", &vec![1, 2, 2]),
2975            ("b2", &vec![1, 2, 2]),
2976            ("c1", &vec![7, 8, 9]),
2977        );
2978        let right = build_table(
2979            ("a1", &vec![1, 2, 3]),
2980            ("b2", &vec![1, 2, 2]),
2981            ("c2", &vec![70, 80, 90]),
2982        );
2983        let on = vec![
2984            (
2985                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2986                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2987            ),
2988            (
2989                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2990                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2991            ),
2992        ];
2993
2994        let (_, batches) =
2995            join_collect_batch_size_equals_two(left, right, on, Inner).await?;
2996        let expected = [
2997            "+----+----+----+----+----+----+",
2998            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2999            "+----+----+----+----+----+----+",
3000            "| 1  | 1  | 7  | 1  | 1  | 70 |",
3001            "| 2  | 2  | 8  | 2  | 2  | 80 |",
3002            "| 2  | 2  | 9  | 2  | 2  | 80 |",
3003            "+----+----+----+----+----+----+",
3004        ];
3005        assert_eq!(batches.len(), 2);
3006        assert_eq!(batches[0].num_rows(), 2);
3007        assert_eq!(batches[1].num_rows(), 1);
3008        // The output order is important as SMJ preserves sortedness
3009        assert_batches_eq!(expected, &batches);
3010        Ok(())
3011    }
3012
3013    #[tokio::test]
3014    async fn join_left_one() -> Result<()> {
3015        let left = build_table(
3016            ("a1", &vec![1, 2, 3]),
3017            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
3018            ("c1", &vec![7, 8, 9]),
3019        );
3020        let right = build_table(
3021            ("a2", &vec![10, 20, 30]),
3022            ("b1", &vec![4, 5, 6]),
3023            ("c2", &vec![70, 80, 90]),
3024        );
3025        let on = vec![(
3026            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3027            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3028        )];
3029
3030        let (_, batches) = join_collect(left, right, on, Left).await?;
3031        let expected = [
3032            "+----+----+----+----+----+----+",
3033            "| a1 | b1 | c1 | a2 | b1 | c2 |",
3034            "+----+----+----+----+----+----+",
3035            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3036            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3037            "| 3  | 7  | 9  |    |    |    |",
3038            "+----+----+----+----+----+----+",
3039        ];
3040        // The output order is important as SMJ preserves sortedness
3041        assert_batches_eq!(expected, &batches);
3042        Ok(())
3043    }
3044
3045    #[tokio::test]
3046    async fn join_right_one() -> Result<()> {
3047        let left = build_table(
3048            ("a1", &vec![1, 2, 3]),
3049            ("b1", &vec![4, 5, 7]),
3050            ("c1", &vec![7, 8, 9]),
3051        );
3052        let right = build_table(
3053            ("a2", &vec![10, 20, 30]),
3054            ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
3055            ("c2", &vec![70, 80, 90]),
3056        );
3057        let on = vec![(
3058            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3059            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3060        )];
3061
3062        let (_, batches) = join_collect(left, right, on, Right).await?;
3063        let expected = [
3064            "+----+----+----+----+----+----+",
3065            "| a1 | b1 | c1 | a2 | b1 | c2 |",
3066            "+----+----+----+----+----+----+",
3067            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3068            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3069            "|    |    |    | 30 | 6  | 90 |",
3070            "+----+----+----+----+----+----+",
3071        ];
3072        // The output order is important as SMJ preserves sortedness
3073        assert_batches_eq!(expected, &batches);
3074        Ok(())
3075    }
3076
3077    #[tokio::test]
3078    async fn join_full_one() -> Result<()> {
3079        let left = build_table(
3080            ("a1", &vec![1, 2, 3]),
3081            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
3082            ("c1", &vec![7, 8, 9]),
3083        );
3084        let right = build_table(
3085            ("a2", &vec![10, 20, 30]),
3086            ("b2", &vec![4, 5, 6]),
3087            ("c2", &vec![70, 80, 90]),
3088        );
3089        let on = vec![(
3090            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
3091            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
3092        )];
3093
3094        let (_, batches) = join_collect(left, right, on, Full).await?;
3095        let expected = [
3096            "+----+----+----+----+----+----+",
3097            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3098            "+----+----+----+----+----+----+",
3099            "|    |    |    | 30 | 6  | 90 |",
3100            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3101            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3102            "| 3  | 7  | 9  |    |    |    |",
3103            "+----+----+----+----+----+----+",
3104        ];
3105        assert_batches_sorted_eq!(expected, &batches);
3106        Ok(())
3107    }
3108
3109    #[tokio::test]
3110    async fn join_left_anti() -> Result<()> {
3111        let left = build_table(
3112            ("a1", &vec![1, 2, 2, 3, 5]),
3113            ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
3114            ("c1", &vec![7, 8, 8, 9, 11]),
3115        );
3116        let right = build_table(
3117            ("a2", &vec![10, 20, 30]),
3118            ("b1", &vec![4, 5, 6]),
3119            ("c2", &vec![70, 80, 90]),
3120        );
3121        let on = vec![(
3122            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3123            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3124        )];
3125
3126        let (_, batches) = join_collect(left, right, on, LeftAnti).await?;
3127        let expected = [
3128            "+----+----+----+",
3129            "| a1 | b1 | c1 |",
3130            "+----+----+----+",
3131            "| 3  | 7  | 9  |",
3132            "| 5  | 7  | 11 |",
3133            "+----+----+----+",
3134        ];
3135        // The output order is important as SMJ preserves sortedness
3136        assert_batches_eq!(expected, &batches);
3137        Ok(())
3138    }
3139
3140    #[tokio::test]
3141    async fn join_right_anti_one_one() -> Result<()> {
3142        let left = build_table(
3143            ("a1", &vec![1, 2, 2]),
3144            ("b1", &vec![4, 5, 5]),
3145            ("c1", &vec![7, 8, 8]),
3146        );
3147        let right =
3148            build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
3149        let on = vec![(
3150            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3151            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3152        )];
3153
3154        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3155        let expected = [
3156            "+----+----+",
3157            "| a2 | b1 |",
3158            "+----+----+",
3159            "| 30 | 6  |",
3160            "+----+----+",
3161        ];
3162        // The output order is important as SMJ preserves sortedness
3163        assert_batches_eq!(expected, &batches);
3164
3165        let left2 = build_table(
3166            ("a1", &vec![1, 2, 2]),
3167            ("b1", &vec![4, 5, 5]),
3168            ("c1", &vec![7, 8, 8]),
3169        );
3170        let right2 = build_table(
3171            ("a2", &vec![10, 20, 30]),
3172            ("b1", &vec![4, 5, 6]),
3173            ("c2", &vec![70, 80, 90]),
3174        );
3175
3176        let on = vec![(
3177            Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _,
3178            Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _,
3179        )];
3180
3181        let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?;
3182        let expected2 = [
3183            "+----+----+----+",
3184            "| a2 | b1 | c2 |",
3185            "+----+----+----+",
3186            "| 30 | 6  | 90 |",
3187            "+----+----+----+",
3188        ];
3189        // The output order is important as SMJ preserves sortedness
3190        assert_batches_eq!(expected2, &batches2);
3191
3192        Ok(())
3193    }
3194
3195    #[tokio::test]
3196    async fn join_right_anti_two_two() -> Result<()> {
3197        let left = build_table(
3198            ("a1", &vec![1, 2, 2]),
3199            ("b1", &vec![4, 5, 5]),
3200            ("c1", &vec![7, 8, 8]),
3201        );
3202        let right =
3203            build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
3204        let on = vec![
3205            (
3206                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3207                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
3208            ),
3209            (
3210                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3211                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3212            ),
3213        ];
3214
3215        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3216        let expected = [
3217            "+----+----+",
3218            "| a2 | b1 |",
3219            "+----+----+",
3220            "| 10 | 4  |",
3221            "| 20 | 5  |",
3222            "| 30 | 6  |",
3223            "+----+----+",
3224        ];
3225        // The output order is important as SMJ preserves sortedness
3226        assert_batches_eq!(expected, &batches);
3227
3228        let left = build_table(
3229            ("a1", &vec![1, 2, 2]),
3230            ("b1", &vec![4, 5, 5]),
3231            ("c1", &vec![7, 8, 8]),
3232        );
3233        let right = build_table(
3234            ("a2", &vec![10, 20, 30]),
3235            ("b1", &vec![4, 5, 6]),
3236            ("c2", &vec![70, 80, 90]),
3237        );
3238
3239        let on = vec![
3240            (
3241                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3242                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
3243            ),
3244            (
3245                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3246                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3247            ),
3248        ];
3249
3250        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3251        let expected = [
3252            "+----+----+----+",
3253            "| a2 | b1 | c2 |",
3254            "+----+----+----+",
3255            "| 10 | 4  | 70 |",
3256            "| 20 | 5  | 80 |",
3257            "| 30 | 6  | 90 |",
3258            "+----+----+----+",
3259        ];
3260        // The output order is important as SMJ preserves sortedness
3261        assert_batches_eq!(expected, &batches);
3262
3263        Ok(())
3264    }
3265
3266    #[tokio::test]
3267    async fn join_right_anti_two_with_filter() -> Result<()> {
3268        let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30]));
3269        let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20]));
3270        let on = vec![
3271            (
3272                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3273                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
3274            ),
3275            (
3276                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3277                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3278            ),
3279        ];
3280        let filter = JoinFilter::new(
3281            Arc::new(BinaryExpr::new(
3282                Arc::new(Column::new("c2", 1)),
3283                Operator::Gt,
3284                Arc::new(Column::new("c1", 0)),
3285            )),
3286            vec![
3287                ColumnIndex {
3288                    index: 2,
3289                    side: JoinSide::Left,
3290                },
3291                ColumnIndex {
3292                    index: 2,
3293                    side: JoinSide::Right,
3294                },
3295            ],
3296            Arc::new(Schema::new(vec![
3297                Field::new("c1", DataType::Int32, true),
3298                Field::new("c2", DataType::Int32, true),
3299            ])),
3300        );
3301        let (_, batches) =
3302            join_collect_with_filter(left, right, on, filter, RightAnti).await?;
3303        let expected = [
3304            "+----+----+----+",
3305            "| a1 | b1 | c2 |",
3306            "+----+----+----+",
3307            "| 1  | 10 | 20 |",
3308            "+----+----+----+",
3309        ];
3310        assert_batches_eq!(expected, &batches);
3311        Ok(())
3312    }
3313
3314    #[tokio::test]
3315    async fn join_right_anti_with_nulls() -> Result<()> {
3316        let left = build_table_i32_nullable(
3317            ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
3318            ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
3319            ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
3320        );
3321        let right = build_table_i32_nullable(
3322            ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
3323            ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field
3324            ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field
3325        );
3326        let on = vec![
3327            (
3328                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3329                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
3330            ),
3331            (
3332                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3333                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3334            ),
3335        ];
3336
3337        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
3338        let expected = [
3339            "+----+----+----+",
3340            "| a1 | b1 | c2 |",
3341            "+----+----+----+",
3342            "| 2  |    | 8  |",
3343            "+----+----+----+",
3344        ];
3345        // The output order is important as SMJ preserves sortedness
3346        assert_batches_eq!(expected, &batches);
3347        Ok(())
3348    }
3349
3350    #[tokio::test]
3351    async fn join_right_anti_with_nulls_with_options() -> Result<()> {
3352        let left = build_table_i32_nullable(
3353            ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]),
3354            ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]),
3355            ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]),
3356        );
3357        let right = build_table_i32_nullable(
3358            ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
3359            ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field
3360            ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field
3361        );
3362        let on = vec![
3363            (
3364                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3365                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
3366            ),
3367            (
3368                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3369                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3370            ),
3371        ];
3372
3373        let (_, batches) = join_collect_with_options(
3374            left,
3375            right,
3376            on,
3377            RightAnti,
3378            vec![
3379                SortOptions {
3380                    descending: true,
3381                    nulls_first: false,
3382                };
3383                2
3384            ],
3385            true,
3386        )
3387        .await?;
3388
3389        let expected = [
3390            "+----+----+----+",
3391            "| a1 | b1 | c2 |",
3392            "+----+----+----+",
3393            "| 3  |    | 9  |",
3394            "| 2  | 5  |    |",
3395            "| 2  | 5  | 8  |",
3396            "+----+----+----+",
3397        ];
3398        // The output order is important as SMJ preserves sortedness
3399        assert_batches_eq!(expected, &batches);
3400        Ok(())
3401    }
3402
3403    #[tokio::test]
3404    async fn join_right_anti_output_two_batches() -> Result<()> {
3405        let left = build_table(
3406            ("a1", &vec![1, 2, 2]),
3407            ("b1", &vec![4, 5, 5]),
3408            ("c1", &vec![7, 8, 8]),
3409        );
3410        let right = build_table(
3411            ("a2", &vec![10, 20, 30]),
3412            ("b1", &vec![4, 5, 6]),
3413            ("c2", &vec![70, 80, 90]),
3414        );
3415        let on = vec![
3416            (
3417                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
3418                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
3419            ),
3420            (
3421                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3422                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3423            ),
3424        ];
3425
3426        let (_, batches) =
3427            join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?;
3428        let expected = [
3429            "+----+----+----+",
3430            "| a1 | b1 | c1 |",
3431            "+----+----+----+",
3432            "| 1  | 4  | 7  |",
3433            "| 2  | 5  | 8  |",
3434            "| 2  | 5  | 8  |",
3435            "+----+----+----+",
3436        ];
3437        assert_eq!(batches.len(), 2);
3438        assert_eq!(batches[0].num_rows(), 2);
3439        assert_eq!(batches[1].num_rows(), 1);
3440        assert_batches_eq!(expected, &batches);
3441        Ok(())
3442    }
3443
3444    #[tokio::test]
3445    async fn join_semi() -> Result<()> {
3446        let left = build_table(
3447            ("a1", &vec![1, 2, 2, 3]),
3448            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
3449            ("c1", &vec![7, 8, 8, 9]),
3450        );
3451        let right = build_table(
3452            ("a2", &vec![10, 20, 30]),
3453            ("b1", &vec![4, 5, 6]), // 5 is double on the right
3454            ("c2", &vec![70, 80, 90]),
3455        );
3456        let on = vec![(
3457            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3458            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3459        )];
3460
3461        let (_, batches) = join_collect(left, right, on, LeftSemi).await?;
3462        let expected = [
3463            "+----+----+----+",
3464            "| a1 | b1 | c1 |",
3465            "+----+----+----+",
3466            "| 1  | 4  | 7  |",
3467            "| 2  | 5  | 8  |",
3468            "| 2  | 5  | 8  |",
3469            "+----+----+----+",
3470        ];
3471        // The output order is important as SMJ preserves sortedness
3472        assert_batches_eq!(expected, &batches);
3473        Ok(())
3474    }
3475
3476    #[tokio::test]
3477    async fn join_left_mark() -> Result<()> {
3478        let left = build_table(
3479            ("a1", &vec![1, 2, 2, 3]),
3480            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
3481            ("c1", &vec![7, 8, 8, 9]),
3482        );
3483        let right = build_table(
3484            ("a2", &vec![10, 20, 30, 40]),
3485            ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right
3486            ("c2", &vec![60, 70, 80, 90]),
3487        );
3488        let on = vec![(
3489            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3490            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3491        )];
3492
3493        let (_, batches) = join_collect(left, right, on, LeftMark).await?;
3494        let expected = [
3495            "+----+----+----+-------+",
3496            "| a1 | b1 | c1 | mark  |",
3497            "+----+----+----+-------+",
3498            "| 1  | 4  | 7  | true  |",
3499            "| 2  | 5  | 8  | true  |",
3500            "| 2  | 5  | 8  | true  |",
3501            "| 3  | 7  | 9  | false |",
3502            "+----+----+----+-------+",
3503        ];
3504        // The output order is important as SMJ preserves sortedness
3505        assert_batches_eq!(expected, &batches);
3506        Ok(())
3507    }
3508
3509    #[tokio::test]
3510    async fn join_with_duplicated_column_names() -> Result<()> {
3511        let left = build_table(
3512            ("a", &vec![1, 2, 3]),
3513            ("b", &vec![4, 5, 7]),
3514            ("c", &vec![7, 8, 9]),
3515        );
3516        let right = build_table(
3517            ("a", &vec![10, 20, 30]),
3518            ("b", &vec![1, 2, 7]),
3519            ("c", &vec![70, 80, 90]),
3520        );
3521        let on = vec![(
3522            // join on a=b so there are duplicate column names on unjoined columns
3523            Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
3524            Arc::new(Column::new_with_schema("b", &right.schema())?) as _,
3525        )];
3526
3527        let (_, batches) = join_collect(left, right, on, Inner).await?;
3528        let expected = [
3529            "+---+---+---+----+---+----+",
3530            "| a | b | c | a  | b | c  |",
3531            "+---+---+---+----+---+----+",
3532            "| 1 | 4 | 7 | 10 | 1 | 70 |",
3533            "| 2 | 5 | 8 | 20 | 2 | 80 |",
3534            "+---+---+---+----+---+----+",
3535        ];
3536        // The output order is important as SMJ preserves sortedness
3537        assert_batches_eq!(expected, &batches);
3538        Ok(())
3539    }
3540
3541    #[tokio::test]
3542    async fn join_date32() -> Result<()> {
3543        let left = build_date_table(
3544            ("a1", &vec![1, 2, 3]),
3545            ("b1", &vec![19107, 19108, 19108]), // this has a repetition
3546            ("c1", &vec![7, 8, 9]),
3547        );
3548        let right = build_date_table(
3549            ("a2", &vec![10, 20, 30]),
3550            ("b1", &vec![19107, 19108, 19109]),
3551            ("c2", &vec![70, 80, 90]),
3552        );
3553
3554        let on = vec![(
3555            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3556            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3557        )];
3558
3559        let (_, batches) = join_collect(left, right, on, Inner).await?;
3560
3561        let expected = ["+------------+------------+------------+------------+------------+------------+",
3562            "| a1         | b1         | c1         | a2         | b1         | c2         |",
3563            "+------------+------------+------------+------------+------------+------------+",
3564            "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |",
3565            "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
3566            "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
3567            "+------------+------------+------------+------------+------------+------------+"];
3568        // The output order is important as SMJ preserves sortedness
3569        assert_batches_eq!(expected, &batches);
3570        Ok(())
3571    }
3572
3573    #[tokio::test]
3574    async fn join_date64() -> Result<()> {
3575        let left = build_date64_table(
3576            ("a1", &vec![1, 2, 3]),
3577            ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition
3578            ("c1", &vec![7, 8, 9]),
3579        );
3580        let right = build_date64_table(
3581            ("a2", &vec![10, 20, 30]),
3582            ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
3583            ("c2", &vec![70, 80, 90]),
3584        );
3585
3586        let on = vec![(
3587            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3588            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3589        )];
3590
3591        let (_, batches) = join_collect(left, right, on, Inner).await?;
3592
3593        let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
3594            "| a1                      | b1                  | c1                      | a2                      | b1                  | c2                      |",
3595            "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
3596            "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |",
3597            "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
3598            "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
3599            "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+"];
3600        // The output order is important as SMJ preserves sortedness
3601        assert_batches_eq!(expected, &batches);
3602        Ok(())
3603    }
3604
3605    #[tokio::test]
3606    async fn join_left_sort_order() -> Result<()> {
3607        let left = build_table(
3608            ("a1", &vec![0, 1, 2, 3, 4, 5]),
3609            ("b1", &vec![3, 4, 5, 6, 6, 7]),
3610            ("c1", &vec![4, 5, 6, 7, 8, 9]),
3611        );
3612        let right = build_table(
3613            ("a2", &vec![0, 10, 20, 30, 40]),
3614            ("b2", &vec![2, 4, 6, 6, 8]),
3615            ("c2", &vec![50, 60, 70, 80, 90]),
3616        );
3617        let on = vec![(
3618            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3619            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3620        )];
3621
3622        let (_, batches) = join_collect(left, right, on, Left).await?;
3623        let expected = [
3624            "+----+----+----+----+----+----+",
3625            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3626            "+----+----+----+----+----+----+",
3627            "| 0  | 3  | 4  |    |    |    |",
3628            "| 1  | 4  | 5  | 10 | 4  | 60 |",
3629            "| 2  | 5  | 6  |    |    |    |",
3630            "| 3  | 6  | 7  | 20 | 6  | 70 |",
3631            "| 3  | 6  | 7  | 30 | 6  | 80 |",
3632            "| 4  | 6  | 8  | 20 | 6  | 70 |",
3633            "| 4  | 6  | 8  | 30 | 6  | 80 |",
3634            "| 5  | 7  | 9  |    |    |    |",
3635            "+----+----+----+----+----+----+",
3636        ];
3637        assert_batches_eq!(expected, &batches);
3638        Ok(())
3639    }
3640
3641    #[tokio::test]
3642    async fn join_right_sort_order() -> Result<()> {
3643        let left = build_table(
3644            ("a1", &vec![0, 1, 2, 3]),
3645            ("b1", &vec![3, 4, 5, 7]),
3646            ("c1", &vec![6, 7, 8, 9]),
3647        );
3648        let right = build_table(
3649            ("a2", &vec![0, 10, 20, 30]),
3650            ("b2", &vec![2, 4, 5, 6]),
3651            ("c2", &vec![60, 70, 80, 90]),
3652        );
3653        let on = vec![(
3654            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3655            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3656        )];
3657
3658        let (_, batches) = join_collect(left, right, on, Right).await?;
3659        let expected = [
3660            "+----+----+----+----+----+----+",
3661            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3662            "+----+----+----+----+----+----+",
3663            "|    |    |    | 0  | 2  | 60 |",
3664            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3665            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3666            "|    |    |    | 30 | 6  | 90 |",
3667            "+----+----+----+----+----+----+",
3668        ];
3669        assert_batches_eq!(expected, &batches);
3670        Ok(())
3671    }
3672
3673    #[tokio::test]
3674    async fn join_left_multiple_batches() -> Result<()> {
3675        let left_batch_1 = build_table_i32(
3676            ("a1", &vec![0, 1, 2]),
3677            ("b1", &vec![3, 4, 5]),
3678            ("c1", &vec![4, 5, 6]),
3679        );
3680        let left_batch_2 = build_table_i32(
3681            ("a1", &vec![3, 4, 5, 6]),
3682            ("b1", &vec![6, 6, 7, 9]),
3683            ("c1", &vec![7, 8, 9, 9]),
3684        );
3685        let right_batch_1 = build_table_i32(
3686            ("a2", &vec![0, 10, 20]),
3687            ("b2", &vec![2, 4, 6]),
3688            ("c2", &vec![50, 60, 70]),
3689        );
3690        let right_batch_2 = build_table_i32(
3691            ("a2", &vec![30, 40]),
3692            ("b2", &vec![6, 8]),
3693            ("c2", &vec![80, 90]),
3694        );
3695        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
3696        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
3697        let on = vec![(
3698            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3699            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3700        )];
3701
3702        let (_, batches) = join_collect(left, right, on, Left).await?;
3703        let expected = vec![
3704            "+----+----+----+----+----+----+",
3705            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3706            "+----+----+----+----+----+----+",
3707            "| 0  | 3  | 4  |    |    |    |",
3708            "| 1  | 4  | 5  | 10 | 4  | 60 |",
3709            "| 2  | 5  | 6  |    |    |    |",
3710            "| 3  | 6  | 7  | 20 | 6  | 70 |",
3711            "| 3  | 6  | 7  | 30 | 6  | 80 |",
3712            "| 4  | 6  | 8  | 20 | 6  | 70 |",
3713            "| 4  | 6  | 8  | 30 | 6  | 80 |",
3714            "| 5  | 7  | 9  |    |    |    |",
3715            "| 6  | 9  | 9  |    |    |    |",
3716            "+----+----+----+----+----+----+",
3717        ];
3718        assert_batches_eq!(expected, &batches);
3719        Ok(())
3720    }
3721
3722    #[tokio::test]
3723    async fn join_right_multiple_batches() -> Result<()> {
3724        let right_batch_1 = build_table_i32(
3725            ("a2", &vec![0, 1, 2]),
3726            ("b2", &vec![3, 4, 5]),
3727            ("c2", &vec![4, 5, 6]),
3728        );
3729        let right_batch_2 = build_table_i32(
3730            ("a2", &vec![3, 4, 5, 6]),
3731            ("b2", &vec![6, 6, 7, 9]),
3732            ("c2", &vec![7, 8, 9, 9]),
3733        );
3734        let left_batch_1 = build_table_i32(
3735            ("a1", &vec![0, 10, 20]),
3736            ("b1", &vec![2, 4, 6]),
3737            ("c1", &vec![50, 60, 70]),
3738        );
3739        let left_batch_2 = build_table_i32(
3740            ("a1", &vec![30, 40]),
3741            ("b1", &vec![6, 8]),
3742            ("c1", &vec![80, 90]),
3743        );
3744        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
3745        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
3746        let on = vec![(
3747            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3748            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3749        )];
3750
3751        let (_, batches) = join_collect(left, right, on, Right).await?;
3752        let expected = vec![
3753            "+----+----+----+----+----+----+",
3754            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3755            "+----+----+----+----+----+----+",
3756            "|    |    |    | 0  | 3  | 4  |",
3757            "| 10 | 4  | 60 | 1  | 4  | 5  |",
3758            "|    |    |    | 2  | 5  | 6  |",
3759            "| 20 | 6  | 70 | 3  | 6  | 7  |",
3760            "| 30 | 6  | 80 | 3  | 6  | 7  |",
3761            "| 20 | 6  | 70 | 4  | 6  | 8  |",
3762            "| 30 | 6  | 80 | 4  | 6  | 8  |",
3763            "|    |    |    | 5  | 7  | 9  |",
3764            "|    |    |    | 6  | 9  | 9  |",
3765            "+----+----+----+----+----+----+",
3766        ];
3767        assert_batches_eq!(expected, &batches);
3768        Ok(())
3769    }
3770
3771    #[tokio::test]
3772    async fn join_full_multiple_batches() -> Result<()> {
3773        let left_batch_1 = build_table_i32(
3774            ("a1", &vec![0, 1, 2]),
3775            ("b1", &vec![3, 4, 5]),
3776            ("c1", &vec![4, 5, 6]),
3777        );
3778        let left_batch_2 = build_table_i32(
3779            ("a1", &vec![3, 4, 5, 6]),
3780            ("b1", &vec![6, 6, 7, 9]),
3781            ("c1", &vec![7, 8, 9, 9]),
3782        );
3783        let right_batch_1 = build_table_i32(
3784            ("a2", &vec![0, 10, 20]),
3785            ("b2", &vec![2, 4, 6]),
3786            ("c2", &vec![50, 60, 70]),
3787        );
3788        let right_batch_2 = build_table_i32(
3789            ("a2", &vec![30, 40]),
3790            ("b2", &vec![6, 8]),
3791            ("c2", &vec![80, 90]),
3792        );
3793        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
3794        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
3795        let on = vec![(
3796            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3797            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3798        )];
3799
3800        let (_, batches) = join_collect(left, right, on, Full).await?;
3801        let expected = vec![
3802            "+----+----+----+----+----+----+",
3803            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3804            "+----+----+----+----+----+----+",
3805            "|    |    |    | 0  | 2  | 50 |",
3806            "|    |    |    | 40 | 8  | 90 |",
3807            "| 0  | 3  | 4  |    |    |    |",
3808            "| 1  | 4  | 5  | 10 | 4  | 60 |",
3809            "| 2  | 5  | 6  |    |    |    |",
3810            "| 3  | 6  | 7  | 20 | 6  | 70 |",
3811            "| 3  | 6  | 7  | 30 | 6  | 80 |",
3812            "| 4  | 6  | 8  | 20 | 6  | 70 |",
3813            "| 4  | 6  | 8  | 30 | 6  | 80 |",
3814            "| 5  | 7  | 9  |    |    |    |",
3815            "| 6  | 9  | 9  |    |    |    |",
3816            "+----+----+----+----+----+----+",
3817        ];
3818        assert_batches_sorted_eq!(expected, &batches);
3819        Ok(())
3820    }
3821
3822    #[tokio::test]
3823    async fn overallocation_single_batch_no_spill() -> Result<()> {
3824        let left = build_table(
3825            ("a1", &vec![0, 1, 2, 3, 4, 5]),
3826            ("b1", &vec![1, 2, 3, 4, 5, 6]),
3827            ("c1", &vec![4, 5, 6, 7, 8, 9]),
3828        );
3829        let right = build_table(
3830            ("a2", &vec![0, 10, 20, 30, 40]),
3831            ("b2", &vec![1, 3, 4, 6, 8]),
3832            ("c2", &vec![50, 60, 70, 80, 90]),
3833        );
3834        let on = vec![(
3835            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3836            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3837        )];
3838        let sort_options = vec![SortOptions::default(); on.len()];
3839
3840        let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
3841
3842        // Disable DiskManager to prevent spilling
3843        let runtime = RuntimeEnvBuilder::new()
3844            .with_memory_limit(100, 1.0)
3845            .with_disk_manager(DiskManagerConfig::Disabled)
3846            .build_arc()?;
3847        let session_config = SessionConfig::default().with_batch_size(50);
3848
3849        for join_type in join_types {
3850            let task_ctx = TaskContext::default()
3851                .with_session_config(session_config.clone())
3852                .with_runtime(Arc::clone(&runtime));
3853            let task_ctx = Arc::new(task_ctx);
3854
3855            let join = join_with_options(
3856                Arc::clone(&left),
3857                Arc::clone(&right),
3858                on.clone(),
3859                join_type,
3860                sort_options.clone(),
3861                false,
3862            )?;
3863
3864            let stream = join.execute(0, task_ctx)?;
3865            let err = common::collect(stream).await.unwrap_err();
3866
3867            assert_contains!(err.to_string(), "Failed to allocate additional");
3868            assert_contains!(err.to_string(), "SMJStream[0]");
3869            assert_contains!(err.to_string(), "Disk spilling disabled");
3870            assert!(join.metrics().is_some());
3871            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3872            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3873            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3874        }
3875
3876        Ok(())
3877    }
3878
3879    #[tokio::test]
3880    async fn overallocation_multi_batch_no_spill() -> Result<()> {
3881        let left_batch_1 = build_table_i32(
3882            ("a1", &vec![0, 1]),
3883            ("b1", &vec![1, 1]),
3884            ("c1", &vec![4, 5]),
3885        );
3886        let left_batch_2 = build_table_i32(
3887            ("a1", &vec![2, 3]),
3888            ("b1", &vec![1, 1]),
3889            ("c1", &vec![6, 7]),
3890        );
3891        let left_batch_3 = build_table_i32(
3892            ("a1", &vec![4, 5]),
3893            ("b1", &vec![1, 1]),
3894            ("c1", &vec![8, 9]),
3895        );
3896        let right_batch_1 = build_table_i32(
3897            ("a2", &vec![0, 10]),
3898            ("b2", &vec![1, 1]),
3899            ("c2", &vec![50, 60]),
3900        );
3901        let right_batch_2 = build_table_i32(
3902            ("a2", &vec![20, 30]),
3903            ("b2", &vec![1, 1]),
3904            ("c2", &vec![70, 80]),
3905        );
3906        let right_batch_3 =
3907            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
3908        let left =
3909            build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
3910        let right =
3911            build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
3912        let on = vec![(
3913            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3914            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3915        )];
3916        let sort_options = vec![SortOptions::default(); on.len()];
3917
3918        let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
3919
3920        // Disable DiskManager to prevent spilling
3921        let runtime = RuntimeEnvBuilder::new()
3922            .with_memory_limit(100, 1.0)
3923            .with_disk_manager(DiskManagerConfig::Disabled)
3924            .build_arc()?;
3925        let session_config = SessionConfig::default().with_batch_size(50);
3926
3927        for join_type in join_types {
3928            let task_ctx = TaskContext::default()
3929                .with_session_config(session_config.clone())
3930                .with_runtime(Arc::clone(&runtime));
3931            let task_ctx = Arc::new(task_ctx);
3932            let join = join_with_options(
3933                Arc::clone(&left),
3934                Arc::clone(&right),
3935                on.clone(),
3936                join_type,
3937                sort_options.clone(),
3938                false,
3939            )?;
3940
3941            let stream = join.execute(0, task_ctx)?;
3942            let err = common::collect(stream).await.unwrap_err();
3943
3944            assert_contains!(err.to_string(), "Failed to allocate additional");
3945            assert_contains!(err.to_string(), "SMJStream[0]");
3946            assert_contains!(err.to_string(), "Disk spilling disabled");
3947            assert!(join.metrics().is_some());
3948            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3949            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3950            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3951        }
3952
3953        Ok(())
3954    }
3955
3956    #[tokio::test]
3957    async fn overallocation_single_batch_spill() -> Result<()> {
3958        let left = build_table(
3959            ("a1", &vec![0, 1, 2, 3, 4, 5]),
3960            ("b1", &vec![1, 2, 3, 4, 5, 6]),
3961            ("c1", &vec![4, 5, 6, 7, 8, 9]),
3962        );
3963        let right = build_table(
3964            ("a2", &vec![0, 10, 20, 30, 40]),
3965            ("b2", &vec![1, 3, 4, 6, 8]),
3966            ("c2", &vec![50, 60, 70, 80, 90]),
3967        );
3968        let on = vec![(
3969            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3970            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3971        )];
3972        let sort_options = vec![SortOptions::default(); on.len()];
3973
3974        let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
3975
3976        // Enable DiskManager to allow spilling
3977        let runtime = RuntimeEnvBuilder::new()
3978            .with_memory_limit(100, 1.0)
3979            .with_disk_manager(DiskManagerConfig::NewOs)
3980            .build_arc()?;
3981
3982        for batch_size in [1, 50] {
3983            let session_config = SessionConfig::default().with_batch_size(batch_size);
3984
3985            for join_type in &join_types {
3986                let task_ctx = TaskContext::default()
3987                    .with_session_config(session_config.clone())
3988                    .with_runtime(Arc::clone(&runtime));
3989                let task_ctx = Arc::new(task_ctx);
3990
3991                let join = join_with_options(
3992                    Arc::clone(&left),
3993                    Arc::clone(&right),
3994                    on.clone(),
3995                    *join_type,
3996                    sort_options.clone(),
3997                    false,
3998                )?;
3999
4000                let stream = join.execute(0, task_ctx)?;
4001                let spilled_join_result = common::collect(stream).await.unwrap();
4002
4003                assert!(join.metrics().is_some());
4004                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
4005                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
4006                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
4007
4008                // Run the test with no spill configuration as
4009                let task_ctx_no_spill =
4010                    TaskContext::default().with_session_config(session_config.clone());
4011                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
4012
4013                let join = join_with_options(
4014                    Arc::clone(&left),
4015                    Arc::clone(&right),
4016                    on.clone(),
4017                    *join_type,
4018                    sort_options.clone(),
4019                    false,
4020                )?;
4021                let stream = join.execute(0, task_ctx_no_spill)?;
4022                let no_spilled_join_result = common::collect(stream).await.unwrap();
4023
4024                assert!(join.metrics().is_some());
4025                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
4026                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
4027                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
4028                // Compare spilled and non spilled data to check spill logic doesn't corrupt the data
4029                assert_eq!(spilled_join_result, no_spilled_join_result);
4030            }
4031        }
4032
4033        Ok(())
4034    }
4035
4036    #[tokio::test]
4037    async fn overallocation_multi_batch_spill() -> Result<()> {
4038        let left_batch_1 = build_table_i32(
4039            ("a1", &vec![0, 1]),
4040            ("b1", &vec![1, 1]),
4041            ("c1", &vec![4, 5]),
4042        );
4043        let left_batch_2 = build_table_i32(
4044            ("a1", &vec![2, 3]),
4045            ("b1", &vec![1, 1]),
4046            ("c1", &vec![6, 7]),
4047        );
4048        let left_batch_3 = build_table_i32(
4049            ("a1", &vec![4, 5]),
4050            ("b1", &vec![1, 1]),
4051            ("c1", &vec![8, 9]),
4052        );
4053        let right_batch_1 = build_table_i32(
4054            ("a2", &vec![0, 10]),
4055            ("b2", &vec![1, 1]),
4056            ("c2", &vec![50, 60]),
4057        );
4058        let right_batch_2 = build_table_i32(
4059            ("a2", &vec![20, 30]),
4060            ("b2", &vec![1, 1]),
4061            ("c2", &vec![70, 80]),
4062        );
4063        let right_batch_3 =
4064            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
4065        let left =
4066            build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
4067        let right =
4068            build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
4069        let on = vec![(
4070            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
4071            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
4072        )];
4073        let sort_options = vec![SortOptions::default(); on.len()];
4074
4075        let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark];
4076
4077        // Enable DiskManager to allow spilling
4078        let runtime = RuntimeEnvBuilder::new()
4079            .with_memory_limit(500, 1.0)
4080            .with_disk_manager(DiskManagerConfig::NewOs)
4081            .build_arc()?;
4082
4083        for batch_size in [1, 50] {
4084            let session_config = SessionConfig::default().with_batch_size(batch_size);
4085
4086            for join_type in &join_types {
4087                let task_ctx = TaskContext::default()
4088                    .with_session_config(session_config.clone())
4089                    .with_runtime(Arc::clone(&runtime));
4090                let task_ctx = Arc::new(task_ctx);
4091                let join = join_with_options(
4092                    Arc::clone(&left),
4093                    Arc::clone(&right),
4094                    on.clone(),
4095                    *join_type,
4096                    sort_options.clone(),
4097                    false,
4098                )?;
4099
4100                let stream = join.execute(0, task_ctx)?;
4101                let spilled_join_result = common::collect(stream).await.unwrap();
4102                assert!(join.metrics().is_some());
4103                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
4104                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
4105                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
4106
4107                // Run the test with no spill configuration as
4108                let task_ctx_no_spill =
4109                    TaskContext::default().with_session_config(session_config.clone());
4110                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
4111
4112                let join = join_with_options(
4113                    Arc::clone(&left),
4114                    Arc::clone(&right),
4115                    on.clone(),
4116                    *join_type,
4117                    sort_options.clone(),
4118                    false,
4119                )?;
4120                let stream = join.execute(0, task_ctx_no_spill)?;
4121                let no_spilled_join_result = common::collect(stream).await.unwrap();
4122
4123                assert!(join.metrics().is_some());
4124                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
4125                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
4126                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
4127                // Compare spilled and non spilled data to check spill logic doesn't corrupt the data
4128                assert_eq!(spilled_join_result, no_spilled_join_result);
4129            }
4130        }
4131
4132        Ok(())
4133    }
4134
4135    fn build_joined_record_batches() -> Result<JoinedRecordBatches> {
4136        let schema = Arc::new(Schema::new(vec![
4137            Field::new("a", DataType::Int32, true),
4138            Field::new("b", DataType::Int32, true),
4139            Field::new("x", DataType::Int32, true),
4140            Field::new("y", DataType::Int32, true),
4141        ]));
4142
4143        let mut batches = JoinedRecordBatches {
4144            batches: vec![],
4145            filter_mask: BooleanBuilder::new(),
4146            row_indices: UInt64Builder::new(),
4147            batch_ids: vec![],
4148        };
4149
4150        // Insert already prejoined non-filtered rows
4151        batches.batches.push(RecordBatch::try_new(
4152            Arc::clone(&schema),
4153            vec![
4154                Arc::new(Int32Array::from(vec![1, 1])),
4155                Arc::new(Int32Array::from(vec![10, 10])),
4156                Arc::new(Int32Array::from(vec![1, 1])),
4157                Arc::new(Int32Array::from(vec![11, 9])),
4158            ],
4159        )?);
4160
4161        batches.batches.push(RecordBatch::try_new(
4162            Arc::clone(&schema),
4163            vec![
4164                Arc::new(Int32Array::from(vec![1])),
4165                Arc::new(Int32Array::from(vec![11])),
4166                Arc::new(Int32Array::from(vec![1])),
4167                Arc::new(Int32Array::from(vec![12])),
4168            ],
4169        )?);
4170
4171        batches.batches.push(RecordBatch::try_new(
4172            Arc::clone(&schema),
4173            vec![
4174                Arc::new(Int32Array::from(vec![1, 1])),
4175                Arc::new(Int32Array::from(vec![12, 12])),
4176                Arc::new(Int32Array::from(vec![1, 1])),
4177                Arc::new(Int32Array::from(vec![11, 13])),
4178            ],
4179        )?);
4180
4181        batches.batches.push(RecordBatch::try_new(
4182            Arc::clone(&schema),
4183            vec![
4184                Arc::new(Int32Array::from(vec![1])),
4185                Arc::new(Int32Array::from(vec![13])),
4186                Arc::new(Int32Array::from(vec![1])),
4187                Arc::new(Int32Array::from(vec![12])),
4188            ],
4189        )?);
4190
4191        batches.batches.push(RecordBatch::try_new(
4192            Arc::clone(&schema),
4193            vec![
4194                Arc::new(Int32Array::from(vec![1, 1])),
4195                Arc::new(Int32Array::from(vec![14, 14])),
4196                Arc::new(Int32Array::from(vec![1, 1])),
4197                Arc::new(Int32Array::from(vec![12, 11])),
4198            ],
4199        )?);
4200
4201        let streamed_indices = vec![0, 0];
4202        batches.batch_ids.extend(vec![0; streamed_indices.len()]);
4203        batches
4204            .row_indices
4205            .extend(&UInt64Array::from(streamed_indices));
4206
4207        let streamed_indices = vec![1];
4208        batches.batch_ids.extend(vec![0; streamed_indices.len()]);
4209        batches
4210            .row_indices
4211            .extend(&UInt64Array::from(streamed_indices));
4212
4213        let streamed_indices = vec![0, 0];
4214        batches.batch_ids.extend(vec![1; streamed_indices.len()]);
4215        batches
4216            .row_indices
4217            .extend(&UInt64Array::from(streamed_indices));
4218
4219        let streamed_indices = vec![0];
4220        batches.batch_ids.extend(vec![2; streamed_indices.len()]);
4221        batches
4222            .row_indices
4223            .extend(&UInt64Array::from(streamed_indices));
4224
4225        let streamed_indices = vec![0, 0];
4226        batches.batch_ids.extend(vec![3; streamed_indices.len()]);
4227        batches
4228            .row_indices
4229            .extend(&UInt64Array::from(streamed_indices));
4230
4231        batches
4232            .filter_mask
4233            .extend(&BooleanArray::from(vec![true, false]));
4234        batches.filter_mask.extend(&BooleanArray::from(vec![true]));
4235        batches
4236            .filter_mask
4237            .extend(&BooleanArray::from(vec![false, true]));
4238        batches.filter_mask.extend(&BooleanArray::from(vec![false]));
4239        batches
4240            .filter_mask
4241            .extend(&BooleanArray::from(vec![false, false]));
4242
4243        Ok(batches)
4244    }
4245
4246    #[tokio::test]
4247    async fn test_left_outer_join_filtered_mask() -> Result<()> {
4248        let mut joined_batches = build_joined_record_batches()?;
4249        let schema = joined_batches.batches.first().unwrap().schema();
4250
4251        let output = concat_batches(&schema, &joined_batches.batches)?;
4252        let out_mask = joined_batches.filter_mask.finish();
4253        let out_indices = joined_batches.row_indices.finish();
4254
4255        assert_eq!(
4256            get_corrected_filter_mask(
4257                Left,
4258                &UInt64Array::from(vec![0]),
4259                &[0usize],
4260                &BooleanArray::from(vec![true]),
4261                output.num_rows()
4262            )
4263            .unwrap(),
4264            BooleanArray::from(vec![
4265                true, false, false, false, false, false, false, false
4266            ])
4267        );
4268
4269        assert_eq!(
4270            get_corrected_filter_mask(
4271                Left,
4272                &UInt64Array::from(vec![0]),
4273                &[0usize],
4274                &BooleanArray::from(vec![false]),
4275                output.num_rows()
4276            )
4277            .unwrap(),
4278            BooleanArray::from(vec![
4279                false, false, false, false, false, false, false, false
4280            ])
4281        );
4282
4283        assert_eq!(
4284            get_corrected_filter_mask(
4285                Left,
4286                &UInt64Array::from(vec![0, 0]),
4287                &[0usize; 2],
4288                &BooleanArray::from(vec![true, true]),
4289                output.num_rows()
4290            )
4291            .unwrap(),
4292            BooleanArray::from(vec![
4293                true, true, false, false, false, false, false, false
4294            ])
4295        );
4296
4297        assert_eq!(
4298            get_corrected_filter_mask(
4299                Left,
4300                &UInt64Array::from(vec![0, 0, 0]),
4301                &[0usize; 3],
4302                &BooleanArray::from(vec![true, true, true]),
4303                output.num_rows()
4304            )
4305            .unwrap(),
4306            BooleanArray::from(vec![true, true, true, false, false, false, false, false])
4307        );
4308
4309        assert_eq!(
4310            get_corrected_filter_mask(
4311                Left,
4312                &UInt64Array::from(vec![0, 0, 0]),
4313                &[0usize; 3],
4314                &BooleanArray::from(vec![true, false, true]),
4315                output.num_rows()
4316            )
4317            .unwrap(),
4318            BooleanArray::from(vec![
4319                Some(true),
4320                None,
4321                Some(true),
4322                Some(false),
4323                Some(false),
4324                Some(false),
4325                Some(false),
4326                Some(false)
4327            ])
4328        );
4329
4330        assert_eq!(
4331            get_corrected_filter_mask(
4332                Left,
4333                &UInt64Array::from(vec![0, 0, 0]),
4334                &[0usize; 3],
4335                &BooleanArray::from(vec![false, false, true]),
4336                output.num_rows()
4337            )
4338            .unwrap(),
4339            BooleanArray::from(vec![
4340                None,
4341                None,
4342                Some(true),
4343                Some(false),
4344                Some(false),
4345                Some(false),
4346                Some(false),
4347                Some(false)
4348            ])
4349        );
4350
4351        assert_eq!(
4352            get_corrected_filter_mask(
4353                Left,
4354                &UInt64Array::from(vec![0, 0, 0]),
4355                &[0usize; 3],
4356                &BooleanArray::from(vec![false, true, true]),
4357                output.num_rows()
4358            )
4359            .unwrap(),
4360            BooleanArray::from(vec![
4361                None,
4362                Some(true),
4363                Some(true),
4364                Some(false),
4365                Some(false),
4366                Some(false),
4367                Some(false),
4368                Some(false)
4369            ])
4370        );
4371
4372        assert_eq!(
4373            get_corrected_filter_mask(
4374                Left,
4375                &UInt64Array::from(vec![0, 0, 0]),
4376                &[0usize; 3],
4377                &BooleanArray::from(vec![false, false, false]),
4378                output.num_rows()
4379            )
4380            .unwrap(),
4381            BooleanArray::from(vec![
4382                None,
4383                None,
4384                Some(false),
4385                Some(false),
4386                Some(false),
4387                Some(false),
4388                Some(false),
4389                Some(false)
4390            ])
4391        );
4392
4393        let corrected_mask = get_corrected_filter_mask(
4394            Left,
4395            &out_indices,
4396            &joined_batches.batch_ids,
4397            &out_mask,
4398            output.num_rows(),
4399        )
4400        .unwrap();
4401
4402        assert_eq!(
4403            corrected_mask,
4404            BooleanArray::from(vec![
4405                Some(true),
4406                None,
4407                Some(true),
4408                None,
4409                Some(true),
4410                Some(false),
4411                None,
4412                Some(false)
4413            ])
4414        );
4415
4416        let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4417
4418        assert_batches_eq!(
4419            &[
4420                "+---+----+---+----+",
4421                "| a | b  | x | y  |",
4422                "+---+----+---+----+",
4423                "| 1 | 10 | 1 | 11 |",
4424                "| 1 | 11 | 1 | 12 |",
4425                "| 1 | 12 | 1 | 13 |",
4426                "+---+----+---+----+",
4427            ],
4428            &[filtered_rb]
4429        );
4430
4431        // output null rows
4432
4433        let null_mask = arrow::compute::not(&corrected_mask)?;
4434        assert_eq!(
4435            null_mask,
4436            BooleanArray::from(vec![
4437                Some(false),
4438                None,
4439                Some(false),
4440                None,
4441                Some(false),
4442                Some(true),
4443                None,
4444                Some(true)
4445            ])
4446        );
4447
4448        let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4449
4450        assert_batches_eq!(
4451            &[
4452                "+---+----+---+----+",
4453                "| a | b  | x | y  |",
4454                "+---+----+---+----+",
4455                "| 1 | 13 | 1 | 12 |",
4456                "| 1 | 14 | 1 | 11 |",
4457                "+---+----+---+----+",
4458            ],
4459            &[null_joined_batch]
4460        );
4461        Ok(())
4462    }
4463
4464    #[tokio::test]
4465    async fn test_left_semi_join_filtered_mask() -> Result<()> {
4466        let mut joined_batches = build_joined_record_batches()?;
4467        let schema = joined_batches.batches.first().unwrap().schema();
4468
4469        let output = concat_batches(&schema, &joined_batches.batches)?;
4470        let out_mask = joined_batches.filter_mask.finish();
4471        let out_indices = joined_batches.row_indices.finish();
4472
4473        assert_eq!(
4474            get_corrected_filter_mask(
4475                LeftSemi,
4476                &UInt64Array::from(vec![0]),
4477                &[0usize],
4478                &BooleanArray::from(vec![true]),
4479                output.num_rows()
4480            )
4481            .unwrap(),
4482            BooleanArray::from(vec![true])
4483        );
4484
4485        assert_eq!(
4486            get_corrected_filter_mask(
4487                LeftSemi,
4488                &UInt64Array::from(vec![0]),
4489                &[0usize],
4490                &BooleanArray::from(vec![false]),
4491                output.num_rows()
4492            )
4493            .unwrap(),
4494            BooleanArray::from(vec![None])
4495        );
4496
4497        assert_eq!(
4498            get_corrected_filter_mask(
4499                LeftSemi,
4500                &UInt64Array::from(vec![0, 0]),
4501                &[0usize; 2],
4502                &BooleanArray::from(vec![true, true]),
4503                output.num_rows()
4504            )
4505            .unwrap(),
4506            BooleanArray::from(vec![Some(true), None])
4507        );
4508
4509        assert_eq!(
4510            get_corrected_filter_mask(
4511                LeftSemi,
4512                &UInt64Array::from(vec![0, 0, 0]),
4513                &[0usize; 3],
4514                &BooleanArray::from(vec![true, true, true]),
4515                output.num_rows()
4516            )
4517            .unwrap(),
4518            BooleanArray::from(vec![Some(true), None, None])
4519        );
4520
4521        assert_eq!(
4522            get_corrected_filter_mask(
4523                LeftSemi,
4524                &UInt64Array::from(vec![0, 0, 0]),
4525                &[0usize; 3],
4526                &BooleanArray::from(vec![true, false, true]),
4527                output.num_rows()
4528            )
4529            .unwrap(),
4530            BooleanArray::from(vec![Some(true), None, None])
4531        );
4532
4533        assert_eq!(
4534            get_corrected_filter_mask(
4535                LeftSemi,
4536                &UInt64Array::from(vec![0, 0, 0]),
4537                &[0usize; 3],
4538                &BooleanArray::from(vec![false, false, true]),
4539                output.num_rows()
4540            )
4541            .unwrap(),
4542            BooleanArray::from(vec![None, None, Some(true),])
4543        );
4544
4545        assert_eq!(
4546            get_corrected_filter_mask(
4547                LeftSemi,
4548                &UInt64Array::from(vec![0, 0, 0]),
4549                &[0usize; 3],
4550                &BooleanArray::from(vec![false, true, true]),
4551                output.num_rows()
4552            )
4553            .unwrap(),
4554            BooleanArray::from(vec![None, Some(true), None])
4555        );
4556
4557        assert_eq!(
4558            get_corrected_filter_mask(
4559                LeftSemi,
4560                &UInt64Array::from(vec![0, 0, 0]),
4561                &[0usize; 3],
4562                &BooleanArray::from(vec![false, false, false]),
4563                output.num_rows()
4564            )
4565            .unwrap(),
4566            BooleanArray::from(vec![None, None, None])
4567        );
4568
4569        let corrected_mask = get_corrected_filter_mask(
4570            LeftSemi,
4571            &out_indices,
4572            &joined_batches.batch_ids,
4573            &out_mask,
4574            output.num_rows(),
4575        )
4576        .unwrap();
4577
4578        assert_eq!(
4579            corrected_mask,
4580            BooleanArray::from(vec![
4581                Some(true),
4582                None,
4583                Some(true),
4584                None,
4585                Some(true),
4586                None,
4587                None,
4588                None
4589            ])
4590        );
4591
4592        let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4593
4594        assert_batches_eq!(
4595            &[
4596                "+---+----+---+----+",
4597                "| a | b  | x | y  |",
4598                "+---+----+---+----+",
4599                "| 1 | 10 | 1 | 11 |",
4600                "| 1 | 11 | 1 | 12 |",
4601                "| 1 | 12 | 1 | 13 |",
4602                "+---+----+---+----+",
4603            ],
4604            &[filtered_rb]
4605        );
4606
4607        // output null rows
4608        let null_mask = arrow::compute::not(&corrected_mask)?;
4609        assert_eq!(
4610            null_mask,
4611            BooleanArray::from(vec![
4612                Some(false),
4613                None,
4614                Some(false),
4615                None,
4616                Some(false),
4617                None,
4618                None,
4619                None
4620            ])
4621        );
4622
4623        let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4624
4625        assert_batches_eq!(
4626            &[
4627                "+---+---+---+---+",
4628                "| a | b | x | y |",
4629                "+---+---+---+---+",
4630                "+---+---+---+---+",
4631            ],
4632            &[null_joined_batch]
4633        );
4634        Ok(())
4635    }
4636
4637    #[tokio::test]
4638    async fn test_anti_join_filtered_mask() -> Result<()> {
4639        for join_type in [LeftAnti, RightAnti] {
4640            let mut joined_batches = build_joined_record_batches()?;
4641            let schema = joined_batches.batches.first().unwrap().schema();
4642
4643            let output = concat_batches(&schema, &joined_batches.batches)?;
4644            let out_mask = joined_batches.filter_mask.finish();
4645            let out_indices = joined_batches.row_indices.finish();
4646
4647            assert_eq!(
4648                get_corrected_filter_mask(
4649                    join_type,
4650                    &UInt64Array::from(vec![0]),
4651                    &[0usize],
4652                    &BooleanArray::from(vec![true]),
4653                    1
4654                )
4655                .unwrap(),
4656                BooleanArray::from(vec![None])
4657            );
4658
4659            assert_eq!(
4660                get_corrected_filter_mask(
4661                    join_type,
4662                    &UInt64Array::from(vec![0]),
4663                    &[0usize],
4664                    &BooleanArray::from(vec![false]),
4665                    1
4666                )
4667                .unwrap(),
4668                BooleanArray::from(vec![Some(true)])
4669            );
4670
4671            assert_eq!(
4672                get_corrected_filter_mask(
4673                    join_type,
4674                    &UInt64Array::from(vec![0, 0]),
4675                    &[0usize; 2],
4676                    &BooleanArray::from(vec![true, true]),
4677                    2
4678                )
4679                .unwrap(),
4680                BooleanArray::from(vec![None, None])
4681            );
4682
4683            assert_eq!(
4684                get_corrected_filter_mask(
4685                    join_type,
4686                    &UInt64Array::from(vec![0, 0, 0]),
4687                    &[0usize; 3],
4688                    &BooleanArray::from(vec![true, true, true]),
4689                    3
4690                )
4691                .unwrap(),
4692                BooleanArray::from(vec![None, None, None])
4693            );
4694
4695            assert_eq!(
4696                get_corrected_filter_mask(
4697                    join_type,
4698                    &UInt64Array::from(vec![0, 0, 0]),
4699                    &[0usize; 3],
4700                    &BooleanArray::from(vec![true, false, true]),
4701                    3
4702                )
4703                .unwrap(),
4704                BooleanArray::from(vec![None, None, None])
4705            );
4706
4707            assert_eq!(
4708                get_corrected_filter_mask(
4709                    join_type,
4710                    &UInt64Array::from(vec![0, 0, 0]),
4711                    &[0usize; 3],
4712                    &BooleanArray::from(vec![false, false, true]),
4713                    3
4714                )
4715                .unwrap(),
4716                BooleanArray::from(vec![None, None, None])
4717            );
4718
4719            assert_eq!(
4720                get_corrected_filter_mask(
4721                    join_type,
4722                    &UInt64Array::from(vec![0, 0, 0]),
4723                    &[0usize; 3],
4724                    &BooleanArray::from(vec![false, true, true]),
4725                    3
4726                )
4727                .unwrap(),
4728                BooleanArray::from(vec![None, None, None])
4729            );
4730
4731            assert_eq!(
4732                get_corrected_filter_mask(
4733                    join_type,
4734                    &UInt64Array::from(vec![0, 0, 0]),
4735                    &[0usize; 3],
4736                    &BooleanArray::from(vec![false, false, false]),
4737                    3
4738                )
4739                .unwrap(),
4740                BooleanArray::from(vec![None, None, Some(true)])
4741            );
4742
4743            let corrected_mask = get_corrected_filter_mask(
4744                join_type,
4745                &out_indices,
4746                &joined_batches.batch_ids,
4747                &out_mask,
4748                output.num_rows(),
4749            )
4750            .unwrap();
4751
4752            assert_eq!(
4753                corrected_mask,
4754                BooleanArray::from(vec![
4755                    None,
4756                    None,
4757                    None,
4758                    None,
4759                    None,
4760                    Some(true),
4761                    None,
4762                    Some(true)
4763                ])
4764            );
4765
4766            let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4767
4768            assert_batches_eq!(
4769                &[
4770                    "+---+----+---+----+",
4771                    "| a | b  | x | y  |",
4772                    "+---+----+---+----+",
4773                    "| 1 | 13 | 1 | 12 |",
4774                    "| 1 | 14 | 1 | 11 |",
4775                    "+---+----+---+----+",
4776                ],
4777                &[filtered_rb]
4778            );
4779
4780            // output null rows
4781            let null_mask = arrow::compute::not(&corrected_mask)?;
4782            assert_eq!(
4783                null_mask,
4784                BooleanArray::from(vec![
4785                    None,
4786                    None,
4787                    None,
4788                    None,
4789                    None,
4790                    Some(false),
4791                    None,
4792                    Some(false),
4793                ])
4794            );
4795
4796            let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4797
4798            assert_batches_eq!(
4799                &[
4800                    "+---+---+---+---+",
4801                    "| a | b | x | y |",
4802                    "+---+---+---+---+",
4803                    "+---+---+---+---+",
4804                ],
4805                &[null_joined_batch]
4806            );
4807        }
4808        Ok(())
4809    }
4810
4811    /// Returns the column names on the schema
4812    fn columns(schema: &Schema) -> Vec<String> {
4813        schema.fields().iter().map(|f| f.name().clone()).collect()
4814    }
4815}