datafusion_physical_plan/joins/
nested_loop_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates).
19
20use std::any::Any;
21use std::fmt::Formatter;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::task::Poll;
25
26use super::utils::{
27    asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap,
28    need_produce_result_in_final, reorder_output_after_swap, swap_join_projection,
29    BatchSplitter, BatchTransformer, NoopBatchTransformer, StatefulStreamResult,
30};
31use crate::coalesce_partitions::CoalescePartitionsExec;
32use crate::common::can_project;
33use crate::execution_plan::{boundedness_from_children, EmissionType};
34use crate::joins::utils::{
35    adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices,
36    build_join_schema, check_join_is_valid, estimate_join_statistics,
37    BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut,
38};
39use crate::joins::SharedBitmapBuilder;
40use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
41use crate::projection::{
42    try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
43    ProjectionExec,
44};
45use crate::{
46    handle_state, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
47    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
48    SendableRecordBatchStream,
49};
50
51use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array};
52use arrow::compute::concat_batches;
53use arrow::datatypes::{Schema, SchemaRef};
54use arrow::record_batch::RecordBatch;
55use datafusion_common::{
56    exec_datafusion_err, internal_err, project_schema, JoinSide, Result, Statistics,
57};
58use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
59use datafusion_execution::TaskContext;
60use datafusion_expr::JoinType;
61use datafusion_physical_expr::equivalence::{
62    join_equivalence_properties, ProjectionMapping,
63};
64
65use futures::{ready, Stream, StreamExt, TryStreamExt};
66use parking_lot::Mutex;
67
/// Left (build-side) data
///
/// Produced once by `collect_left_input` and shared by the probe streams.
struct JoinLeftData {
    /// Build-side data collected to single batch
    /// (all build-side input batches concatenated together)
    batch: RecordBatch,
    /// Shared bitmap builder for visited left indices
    ///
    /// `collect_left_input` pre-fills it with one unset bit per build-side row
    /// when the join type requires tracking visited rows, and leaves it empty
    /// otherwise.
    bitmap: SharedBitmapBuilder,
    /// Counter of running probe-threads, potentially able to update `bitmap`
    ///
    /// Initialized with the partition count of the probe-side (right) input.
    probe_threads_counter: AtomicUsize,
    /// Memory reservation for tracking batch and bitmap
    /// Cleared on `JoinLeftData` drop
    _reservation: MemoryReservation,
}
80
81impl JoinLeftData {
82    fn new(
83        batch: RecordBatch,
84        bitmap: SharedBitmapBuilder,
85        probe_threads_counter: AtomicUsize,
86        _reservation: MemoryReservation,
87    ) -> Self {
88        Self {
89            batch,
90            bitmap,
91            probe_threads_counter,
92            _reservation,
93        }
94    }
95
96    fn batch(&self) -> &RecordBatch {
97        &self.batch
98    }
99
100    fn bitmap(&self) -> &SharedBitmapBuilder {
101        &self.bitmap
102    }
103
104    /// Decrements counter of running threads, and returns `true`
105    /// if caller is the last running thread
106    fn report_probe_completed(&self) -> bool {
107        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
108    }
109}
110
#[allow(rustdoc::private_intra_doc_links)]
/// NestedLoopJoinExec is a build-probe join operator, whose main task is to
/// perform joins without any equijoin conditions in `ON` clause.
///
/// Execution consists of following phases:
///
/// #### 1. Build phase
/// Collecting build-side data in memory, by polling all available data from build-side input.
/// Due to the absence of equijoin conditions, it's not possible to partition build-side data
/// across multiple threads of the operator, so build-side is always collected in a single
/// batch shared across all threads.
/// The operator always considers LEFT input as build-side input, so it's crucial to adjust
/// smaller input to be the LEFT one. Normally this selection is handled by physical optimizer.
///
/// #### 2. Probe phase
/// Sequentially polling batches from the probe-side input and processing them according to the
/// following logic:
/// - apply join filter (`ON` clause) to Cartesian product of probe batch and build side data
///   -- filter evaluation is executed once per build-side data row
/// - update shared bitmap of joined ("visited") build-side row indices, if required -- allows
///   to produce unmatched build-side data in case of e.g. LEFT/FULL JOIN after probing phase
///   completed
/// - perform join index alignment if required -- depending on `JoinType`
/// - produce output join batch
///
/// Probing phase is executed in parallel, according to probe-side input partitioning -- one
/// thread per partition. After probe input is exhausted, each thread **ATTEMPTS** to produce
/// unmatched build-side data.
///
/// #### 3. Producing unmatched build-side data
/// Producing unmatched build-side data as an output batch, after probe input is exhausted.
/// This step is also executed in parallel (once per probe input partition), and to avoid
/// duplicate output of unmatched data (due to the shared nature of build-side data), each thread
/// "reports" about probe phase completion (which means that "visited" bitmap won't be
/// updated anymore), and only the last thread, reporting about completion, will return output.
///
/// # Clone / Shared State
///
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it can not be [`Clone`]
#[derive(Debug)]
pub struct NestedLoopJoinExec {
    /// left side (build side, required to be a single partition)
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// right side (probe side)
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Filters which are applied while finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// The schema once the join is applied (before `projection` is applied)
    join_schema: SchemaRef,
    /// Future that consumes left input and buffers it in memory
    ///
    /// This structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the build-side (left) data collection.
    inner_table: OnceAsync<JoinLeftData>,
    /// Information of index and left / right placement of columns
    /// (one entry per `join_schema` column)
    column_indices: Vec<ColumnIndex>,
    /// Projection to apply to the output of the join
    /// (indices into `join_schema`; `None` keeps every column)
    projection: Option<Vec<usize>>,

    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
181
182impl NestedLoopJoinExec {
183    /// Try to create a new [`NestedLoopJoinExec`]
184    pub fn try_new(
185        left: Arc<dyn ExecutionPlan>,
186        right: Arc<dyn ExecutionPlan>,
187        filter: Option<JoinFilter>,
188        join_type: &JoinType,
189        projection: Option<Vec<usize>>,
190    ) -> Result<Self> {
191        let left_schema = left.schema();
192        let right_schema = right.schema();
193        check_join_is_valid(&left_schema, &right_schema, &[])?;
194        let (join_schema, column_indices) =
195            build_join_schema(&left_schema, &right_schema, join_type);
196        let join_schema = Arc::new(join_schema);
197        let cache = Self::compute_properties(
198            &left,
199            &right,
200            Arc::clone(&join_schema),
201            *join_type,
202            projection.as_ref(),
203        )?;
204
205        Ok(NestedLoopJoinExec {
206            left,
207            right,
208            filter,
209            join_type: *join_type,
210            join_schema,
211            inner_table: Default::default(),
212            column_indices,
213            projection,
214            metrics: Default::default(),
215            cache,
216        })
217    }
218
219    /// left side
220    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
221        &self.left
222    }
223
224    /// right side
225    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
226        &self.right
227    }
228
229    /// Filters applied before join output
230    pub fn filter(&self) -> Option<&JoinFilter> {
231        self.filter.as_ref()
232    }
233
234    /// How the join is performed
235    pub fn join_type(&self) -> &JoinType {
236        &self.join_type
237    }
238
239    pub fn projection(&self) -> Option<&Vec<usize>> {
240        self.projection.as_ref()
241    }
242
243    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
244    fn compute_properties(
245        left: &Arc<dyn ExecutionPlan>,
246        right: &Arc<dyn ExecutionPlan>,
247        schema: SchemaRef,
248        join_type: JoinType,
249        projection: Option<&Vec<usize>>,
250    ) -> Result<PlanProperties> {
251        // Calculate equivalence properties:
252        let mut eq_properties = join_equivalence_properties(
253            left.equivalence_properties().clone(),
254            right.equivalence_properties().clone(),
255            &join_type,
256            Arc::clone(&schema),
257            &Self::maintains_input_order(join_type),
258            None,
259            // No on columns in nested loop join
260            &[],
261        );
262
263        let mut output_partitioning =
264            asymmetric_join_output_partitioning(left, right, &join_type);
265
266        let emission_type = if left.boundedness().is_unbounded() {
267            EmissionType::Final
268        } else if right.pipeline_behavior() == EmissionType::Incremental {
269            match join_type {
270                // If we only need to generate matched rows from the probe side,
271                // we can emit rows incrementally.
272                JoinType::Inner
273                | JoinType::LeftSemi
274                | JoinType::RightSemi
275                | JoinType::Right
276                | JoinType::RightAnti => EmissionType::Incremental,
277                // If we need to generate unmatched rows from the *build side*,
278                // we need to emit them at the end.
279                JoinType::Left
280                | JoinType::LeftAnti
281                | JoinType::LeftMark
282                | JoinType::Full => EmissionType::Both,
283            }
284        } else {
285            right.pipeline_behavior()
286        };
287
288        if let Some(projection) = projection {
289            // construct a map from the input expressions to the output expression of the Projection
290            let projection_mapping =
291                ProjectionMapping::from_indices(projection, &schema)?;
292            let out_schema = project_schema(&schema, Some(projection))?;
293            output_partitioning =
294                output_partitioning.project(&projection_mapping, &eq_properties);
295            eq_properties = eq_properties.project(&projection_mapping, out_schema);
296        }
297
298        Ok(PlanProperties::new(
299            eq_properties,
300            output_partitioning,
301            emission_type,
302            boundedness_from_children([left, right]),
303        ))
304    }
305
306    /// Returns a vector indicating whether the left and right inputs maintain their order.
307    /// The first element corresponds to the left input, and the second to the right.
308    ///
309    /// The left (build-side) input's order may change, but the right (probe-side) input's
310    /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
311    ///
312    /// Maintaining the right input's order helps optimize the nodes down the pipeline
313    /// (See [`ExecutionPlan::maintains_input_order`]).
314    ///
315    /// This is a separate method because it is also called when computing properties, before
316    /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
317    /// opposed to `Self`, for the same reason.
318    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
319        vec![
320            false,
321            matches!(
322                join_type,
323                JoinType::Inner
324                    | JoinType::Right
325                    | JoinType::RightAnti
326                    | JoinType::RightSemi
327            ),
328        ]
329    }
330
331    pub fn contains_projection(&self) -> bool {
332        self.projection.is_some()
333    }
334
335    pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
336        // check if the projection is valid
337        can_project(&self.schema(), projection.as_ref())?;
338        let projection = match projection {
339            Some(projection) => match &self.projection {
340                Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
341                None => Some(projection),
342            },
343            None => None,
344        };
345        Self::try_new(
346            Arc::clone(&self.left),
347            Arc::clone(&self.right),
348            self.filter.clone(),
349            &self.join_type,
350            projection,
351        )
352    }
353
354    /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left
355    /// and right inputs swapped.
356    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
357        let left = self.left();
358        let right = self.right();
359        let new_join = NestedLoopJoinExec::try_new(
360            Arc::clone(right),
361            Arc::clone(left),
362            self.filter().map(JoinFilter::swap),
363            &self.join_type().swap(),
364            swap_join_projection(
365                left.schema().fields().len(),
366                right.schema().fields().len(),
367                self.projection.as_ref(),
368                self.join_type(),
369            ),
370        )?;
371
372        // For Semi/Anti joins, swap result will produce same output schema,
373        // no need to wrap them into additional projection
374        let plan: Arc<dyn ExecutionPlan> = if matches!(
375            self.join_type(),
376            JoinType::LeftSemi
377                | JoinType::RightSemi
378                | JoinType::LeftAnti
379                | JoinType::RightAnti
380        ) || self.projection.is_some()
381        {
382            Arc::new(new_join)
383        } else {
384            reorder_output_after_swap(
385                Arc::new(new_join),
386                &self.left().schema(),
387                &self.right().schema(),
388            )?
389        };
390
391        Ok(plan)
392    }
393}
394
395impl DisplayAs for NestedLoopJoinExec {
396    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
397        match t {
398            DisplayFormatType::Default | DisplayFormatType::Verbose => {
399                let display_filter = self.filter.as_ref().map_or_else(
400                    || "".to_string(),
401                    |f| format!(", filter={}", f.expression()),
402                );
403                let display_projections = if self.contains_projection() {
404                    format!(
405                        ", projection=[{}]",
406                        self.projection
407                            .as_ref()
408                            .unwrap()
409                            .iter()
410                            .map(|index| format!(
411                                "{}@{}",
412                                self.join_schema.fields().get(*index).unwrap().name(),
413                                index
414                            ))
415                            .collect::<Vec<_>>()
416                            .join(", ")
417                    )
418                } else {
419                    "".to_string()
420                };
421                write!(
422                    f,
423                    "NestedLoopJoinExec: join_type={:?}{}{}",
424                    self.join_type, display_filter, display_projections
425                )
426            }
427        }
428    }
429}
430
431impl ExecutionPlan for NestedLoopJoinExec {
432    fn name(&self) -> &'static str {
433        "NestedLoopJoinExec"
434    }
435
436    fn as_any(&self) -> &dyn Any {
437        self
438    }
439
440    fn properties(&self) -> &PlanProperties {
441        &self.cache
442    }
443
444    fn required_input_distribution(&self) -> Vec<Distribution> {
445        vec![
446            Distribution::SinglePartition,
447            Distribution::UnspecifiedDistribution,
448        ]
449    }
450
451    fn maintains_input_order(&self) -> Vec<bool> {
452        Self::maintains_input_order(self.join_type)
453    }
454
455    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
456        vec![&self.left, &self.right]
457    }
458
459    fn with_new_children(
460        self: Arc<Self>,
461        children: Vec<Arc<dyn ExecutionPlan>>,
462    ) -> Result<Arc<dyn ExecutionPlan>> {
463        Ok(Arc::new(NestedLoopJoinExec::try_new(
464            Arc::clone(&children[0]),
465            Arc::clone(&children[1]),
466            self.filter.clone(),
467            &self.join_type,
468            self.projection.clone(),
469        )?))
470    }
471
472    fn execute(
473        &self,
474        partition: usize,
475        context: Arc<TaskContext>,
476    ) -> Result<SendableRecordBatchStream> {
477        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
478
479        // Initialization reservation for load of inner table
480        let load_reservation =
481            MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
482                .register(context.memory_pool());
483
484        let inner_table = self.inner_table.once(|| {
485            collect_left_input(
486                Arc::clone(&self.left),
487                Arc::clone(&context),
488                join_metrics.clone(),
489                load_reservation,
490                need_produce_result_in_final(self.join_type),
491                self.right().output_partitioning().partition_count(),
492            )
493        });
494
495        let batch_size = context.session_config().batch_size();
496        let enforce_batch_size_in_joins =
497            context.session_config().enforce_batch_size_in_joins();
498
499        let outer_table = self.right.execute(partition, context)?;
500
501        let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0));
502
503        // Right side has an order and it is maintained during operation.
504        let right_side_ordered =
505            self.maintains_input_order()[1] && self.right.output_ordering().is_some();
506
507        // update column indices to reflect the projection
508        let column_indices_after_projection = match &self.projection {
509            Some(projection) => projection
510                .iter()
511                .map(|i| self.column_indices[*i].clone())
512                .collect(),
513            None => self.column_indices.clone(),
514        };
515
516        if enforce_batch_size_in_joins {
517            Ok(Box::pin(NestedLoopJoinStream {
518                schema: self.schema(),
519                filter: self.filter.clone(),
520                join_type: self.join_type,
521                outer_table,
522                inner_table,
523                column_indices: column_indices_after_projection,
524                join_metrics,
525                indices_cache,
526                right_side_ordered,
527                state: NestedLoopJoinStreamState::WaitBuildSide,
528                batch_transformer: BatchSplitter::new(batch_size),
529                left_data: None,
530            }))
531        } else {
532            Ok(Box::pin(NestedLoopJoinStream {
533                schema: self.schema(),
534                filter: self.filter.clone(),
535                join_type: self.join_type,
536                outer_table,
537                inner_table,
538                column_indices: column_indices_after_projection,
539                join_metrics,
540                indices_cache,
541                right_side_ordered,
542                state: NestedLoopJoinStreamState::WaitBuildSide,
543                batch_transformer: NoopBatchTransformer::new(),
544                left_data: None,
545            }))
546        }
547    }
548
549    fn metrics(&self) -> Option<MetricsSet> {
550        Some(self.metrics.clone_inner())
551    }
552
553    fn statistics(&self) -> Result<Statistics> {
554        estimate_join_statistics(
555            Arc::clone(&self.left),
556            Arc::clone(&self.right),
557            vec![],
558            &self.join_type,
559            &self.join_schema,
560        )
561    }
562
563    /// Tries to push `projection` down through `nested_loop_join`. If possible, performs the
564    /// pushdown and returns a new [`NestedLoopJoinExec`] as the top plan which has projections
565    /// as its children. Otherwise, returns `None`.
566    fn try_swapping_with_projection(
567        &self,
568        projection: &ProjectionExec,
569    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
570        // TODO: currently if there is projection in NestedLoopJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later.
571        if self.contains_projection() {
572            return Ok(None);
573        }
574
575        if let Some(JoinData {
576            projected_left_child,
577            projected_right_child,
578            join_filter,
579            ..
580        }) = try_pushdown_through_join(
581            projection,
582            self.left(),
583            self.right(),
584            &[],
585            self.schema(),
586            self.filter(),
587        )? {
588            Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
589                Arc::new(projected_left_child),
590                Arc::new(projected_right_child),
591                join_filter,
592                self.join_type(),
593                // Returned early if projection is not None
594                None,
595            )?)))
596        } else {
597            try_embed_projection(projection, self)
598        }
599    }
600}
601
602/// Asynchronously collect input into a single batch, and creates `JoinLeftData` from it
603async fn collect_left_input(
604    input: Arc<dyn ExecutionPlan>,
605    context: Arc<TaskContext>,
606    join_metrics: BuildProbeJoinMetrics,
607    reservation: MemoryReservation,
608    with_visited_left_side: bool,
609    probe_threads_count: usize,
610) -> Result<JoinLeftData> {
611    let schema = input.schema();
612    let merge = if input.output_partitioning().partition_count() != 1 {
613        Arc::new(CoalescePartitionsExec::new(input))
614    } else {
615        input
616    };
617    let stream = merge.execute(0, context)?;
618
619    // Load all batches and count the rows
620    let (batches, metrics, mut reservation) = stream
621        .try_fold(
622            (Vec::new(), join_metrics, reservation),
623            |(mut batches, metrics, mut reservation), batch| async {
624                let batch_size = batch.get_array_memory_size();
625                // Reserve memory for incoming batch
626                reservation.try_grow(batch_size)?;
627                // Update metrics
628                metrics.build_mem_used.add(batch_size);
629                metrics.build_input_batches.add(1);
630                metrics.build_input_rows.add(batch.num_rows());
631                // Push batch to output
632                batches.push(batch);
633                Ok((batches, metrics, reservation))
634            },
635        )
636        .await?;
637
638    let merged_batch = concat_batches(&schema, &batches)?;
639
640    // Reserve memory for visited_left_side bitmap if required by join type
641    let visited_left_side = if with_visited_left_side {
642        let n_rows = merged_batch.num_rows();
643        let buffer_size = n_rows.div_ceil(8);
644        reservation.try_grow(buffer_size)?;
645        metrics.build_mem_used.add(buffer_size);
646
647        let mut buffer = BooleanBufferBuilder::new(n_rows);
648        buffer.append_n(n_rows, false);
649        buffer
650    } else {
651        BooleanBufferBuilder::new(0)
652    };
653
654    Ok(JoinLeftData::new(
655        merged_batch,
656        Mutex::new(visited_left_side),
657        AtomicUsize::new(probe_threads_count),
658        reservation,
659    ))
660}
661
/// This enumeration represents various states of the nested loop join algorithm.
///
/// A stream is created in `WaitBuildSide` and reports end-of-stream once it
/// reaches `Completed`.
#[derive(Debug, Clone)]
enum NestedLoopJoinStreamState {
    /// The initial state, indicating that build-side data not collected yet
    WaitBuildSide,
    /// Indicates that build-side has been collected, and stream is ready for
    /// fetching probe-side
    FetchProbeBatch,
    /// Indicates that a batch has been fetched from probe-side, and
    /// is ready to be processed (the batch is carried in the variant)
    ProcessProbeBatch(RecordBatch),
    /// Indicates that probe-side has been fully processed
    ExhaustedProbeSide,
    /// Indicates that NestedLoopJoinStream execution is completed
    Completed,
}
678
679impl NestedLoopJoinStreamState {
680    /// Tries to extract a `ProcessProbeBatchState` from the
681    /// `NestedLoopJoinStreamState` enum. Returns an error if state is not
682    /// `ProcessProbeBatchState`.
683    fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> {
684        match self {
685            NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state),
686            _ => internal_err!("Expected join stream in ProcessProbeBatch state"),
687        }
688    }
689}
690
/// A stream that issues [RecordBatch]es as they arrive from the right of the join.
///
/// `T` is the batch transformer applied to output batches
/// (`BatchSplitter` or `NoopBatchTransformer`, chosen in `execute`).
struct NestedLoopJoinStream<T> {
    /// Schema of the operator that created the stream
    /// (set from `NestedLoopJoinExec::schema()`)
    schema: Arc<Schema>,
    /// join filter
    filter: Option<JoinFilter>,
    /// type of the join
    join_type: JoinType,
    /// the outer table data of the nested loop join
    /// (stream of the right / probe-side partition)
    outer_table: SendableRecordBatchStream,
    /// the inner table data of the nested loop join
    /// (future resolving to the shared left / build-side data)
    inner_table: OnceFut<JoinLeftData>,
    /// Information of index and left / right placement of columns
    /// (already adjusted for the join's projection, if any)
    column_indices: Vec<ColumnIndex>,
    // TODO: support null aware equal
    // null_equals_null: bool
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// Cache for join indices calculations
    /// (left indices, right indices); starts out empty
    indices_cache: (UInt64Array, UInt32Array),
    /// Whether the right side is ordered
    /// (the join type maintains right order and the right input has an output ordering)
    right_side_ordered: bool,
    /// Current state of the stream
    state: NestedLoopJoinStreamState,
    /// Transforms the output batch before returning.
    batch_transformer: T,
    /// Result of the left data future
    /// (`None` until the build side has been collected)
    left_data: Option<Arc<JoinLeftData>>,
}
720
721/// Creates a Cartesian product of two input batches, preserving the order of the right batch,
722/// and applying a join filter if provided.
723///
724/// # Example
725/// Input:
726/// left = [0, 1], right = [0, 1, 2]
727///
728/// Output:
729/// left_indices = [0, 1, 0, 1, 0, 1], right_indices = [0, 0, 1, 1, 2, 2]
730///
731/// Input:
732/// left = [0, 1, 2], right = [0, 1, 2, 3], filter = left.a != right.a
733///
734/// Output:
735/// left_indices = [1, 2, 0, 2, 0, 1, 0, 1, 2], right_indices = [0, 0, 1, 1, 2, 2, 3, 3, 3]
736fn build_join_indices(
737    left_batch: &RecordBatch,
738    right_batch: &RecordBatch,
739    filter: Option<&JoinFilter>,
740    indices_cache: &mut (UInt64Array, UInt32Array),
741) -> Result<(UInt64Array, UInt32Array)> {
742    let left_row_count = left_batch.num_rows();
743    let right_row_count = right_batch.num_rows();
744    let output_row_count = left_row_count * right_row_count;
745
746    // We always use the same indices before applying the filter, so we can cache them
747    let (left_indices_cache, right_indices_cache) = indices_cache;
748    let cached_output_row_count = left_indices_cache.len();
749
750    let (left_indices, right_indices) =
751        match output_row_count.cmp(&cached_output_row_count) {
752            std::cmp::Ordering::Equal => {
753                // Reuse the cached indices
754                (left_indices_cache.clone(), right_indices_cache.clone())
755            }
756            std::cmp::Ordering::Less => {
757                // Left_row_count never changes because it's the build side. The changes to the
758                // right_row_count can be handled trivially by taking the first output_row_count
759                // elements of the cache because of how the indices are generated.
760                // (See the Ordering::Greater match arm)
761                (
762                    left_indices_cache.slice(0, output_row_count),
763                    right_indices_cache.slice(0, output_row_count),
764                )
765            }
766            std::cmp::Ordering::Greater => {
767                // Rebuild the indices cache
768
769                // Produces 0, 1, 2, 0, 1, 2, 0, 1, 2, ...
770                *left_indices_cache = UInt64Array::from_iter_values(
771                    (0..output_row_count as u64).map(|i| i % left_row_count as u64),
772                );
773
774                // Produces 0, 0, 0, 1, 1, 1, 2, 2, 2, ...
775                *right_indices_cache = UInt32Array::from_iter_values(
776                    (0..output_row_count as u32).map(|i| i / left_row_count as u32),
777                );
778
779                (left_indices_cache.clone(), right_indices_cache.clone())
780            }
781        };
782
783    if let Some(filter) = filter {
784        apply_join_filter_to_indices(
785            left_batch,
786            right_batch,
787            left_indices,
788            right_indices,
789            filter,
790            JoinSide::Left,
791        )
792    } else {
793        Ok((left_indices, right_indices))
794    }
795}
796
797impl<T: BatchTransformer> NestedLoopJoinStream<T> {
    /// Drives the stream's state machine: dispatches to the handler of the
    /// current state, and reports end-of-stream once the state is `Completed`.
    ///
    /// The two polling handlers (`collect_build_side`, `fetch_probe_batch`)
    /// are wrapped in `ready!`, so a pending build-side future or probe-side
    /// stream makes this function return `Poll::Pending`.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // NOTE(review): `handle_state!` is defined outside this file section;
        // presumably it `continue`s this loop when a handler asks to continue
        // and otherwise evaluates to the value to return -- confirm against
        // the macro definition.
        loop {
            return match self.state {
                NestedLoopJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                NestedLoopJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                NestedLoopJoinStreamState::ProcessProbeBatch(_) => {
                    handle_state!(self.process_probe_batch())
                }
                NestedLoopJoinStreamState::ExhaustedProbeSide => {
                    handle_state!(self.process_unmatched_build_batch())
                }
                NestedLoopJoinStreamState::Completed => Poll::Ready(None),
            };
        }
    }
820
821    fn collect_build_side(
822        &mut self,
823        cx: &mut std::task::Context<'_>,
824    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
825        let build_timer = self.join_metrics.build_time.timer();
826        // build hash table from left (build) side, if not yet done
827        self.left_data = Some(ready!(self.inner_table.get_shared(cx))?);
828        build_timer.done();
829
830        self.state = NestedLoopJoinStreamState::FetchProbeBatch;
831
832        Poll::Ready(Ok(StatefulStreamResult::Continue))
833    }
834
835    /// Fetches next batch from probe-side
836    ///
837    /// If a non-empty batch has been fetched, updates state to
838    /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`.
839    fn fetch_probe_batch(
840        &mut self,
841        cx: &mut std::task::Context<'_>,
842    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
843        match ready!(self.outer_table.poll_next_unpin(cx)) {
844            None => {
845                self.state = NestedLoopJoinStreamState::ExhaustedProbeSide;
846            }
847            Some(Ok(right_batch)) => {
848                self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch);
849            }
850            Some(Err(err)) => return Poll::Ready(Err(err)),
851        };
852
853        Poll::Ready(Ok(StatefulStreamResult::Continue))
854    }
855
856    /// Joins current probe batch with build-side data and produces batch with
857    /// matched output, updates state to `FetchProbeBatch`.
858    fn process_probe_batch(
859        &mut self,
860    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
861        let Some(left_data) = self.left_data.clone() else {
862            return internal_err!(
863                "Expected left_data to be Some in ProcessProbeBatch state"
864            );
865        };
866        let visited_left_side = left_data.bitmap();
867        let batch = self.state.try_as_process_probe_batch()?;
868
869        match self.batch_transformer.next() {
870            None => {
871                // Setting up timer & updating input metrics
872                self.join_metrics.input_batches.add(1);
873                self.join_metrics.input_rows.add(batch.num_rows());
874                let timer = self.join_metrics.join_time.timer();
875
876                let result = join_left_and_right_batch(
877                    left_data.batch(),
878                    batch,
879                    self.join_type,
880                    self.filter.as_ref(),
881                    &self.column_indices,
882                    &self.schema,
883                    visited_left_side,
884                    &mut self.indices_cache,
885                    self.right_side_ordered,
886                );
887                timer.done();
888
889                self.batch_transformer.set_batch(result?);
890                Ok(StatefulStreamResult::Continue)
891            }
892            Some((batch, last)) => {
893                if last {
894                    self.state = NestedLoopJoinStreamState::FetchProbeBatch;
895                }
896
897                self.join_metrics.output_batches.add(1);
898                self.join_metrics.output_rows.add(batch.num_rows());
899                Ok(StatefulStreamResult::Ready(Some(batch)))
900            }
901        }
902    }
903
904    /// Processes unmatched build-side rows for certain join types and produces
905    /// output batch, updates state to `Completed`.
906    fn process_unmatched_build_batch(
907        &mut self,
908    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
909        let Some(left_data) = self.left_data.clone() else {
910            return internal_err!(
911                "Expected left_data to be Some in ExhaustedProbeSide state"
912            );
913        };
914        let visited_left_side = left_data.bitmap();
915        if need_produce_result_in_final(self.join_type) {
916            // At this stage `visited_left_side` won't be updated, so it's
917            // safe to report about probe completion.
918            //
919            // Setting `is_exhausted` / returning None will prevent from
920            // multiple calls of `report_probe_completed()`
921            if !left_data.report_probe_completed() {
922                self.state = NestedLoopJoinStreamState::Completed;
923                return Ok(StatefulStreamResult::Ready(None));
924            };
925
926            // Only setting up timer, input is exhausted
927            let timer = self.join_metrics.join_time.timer();
928            // use the global left bitmap to produce the left indices and right indices
929            let (left_side, right_side) =
930                get_final_indices_from_shared_bitmap(visited_left_side, self.join_type);
931            let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema());
932            // use the left and right indices to produce the batch result
933            let result = build_batch_from_indices(
934                &self.schema,
935                left_data.batch(),
936                &empty_right_batch,
937                &left_side,
938                &right_side,
939                &self.column_indices,
940                JoinSide::Left,
941            );
942            self.state = NestedLoopJoinStreamState::Completed;
943
944            // Recording time
945            if result.is_ok() {
946                timer.done();
947            }
948
949            Ok(StatefulStreamResult::Ready(Some(result?)))
950        } else {
951            // end of the join loop
952            self.state = NestedLoopJoinStreamState::Completed;
953            Ok(StatefulStreamResult::Ready(None))
954        }
955    }
956}
957
958#[allow(clippy::too_many_arguments)]
959fn join_left_and_right_batch(
960    left_batch: &RecordBatch,
961    right_batch: &RecordBatch,
962    join_type: JoinType,
963    filter: Option<&JoinFilter>,
964    column_indices: &[ColumnIndex],
965    schema: &Schema,
966    visited_left_side: &SharedBitmapBuilder,
967    indices_cache: &mut (UInt64Array, UInt32Array),
968    right_side_ordered: bool,
969) -> Result<RecordBatch> {
970    let (left_side, right_side) =
971        build_join_indices(left_batch, right_batch, filter, indices_cache).map_err(
972            |e| {
973                exec_datafusion_err!(
974                    "Fail to build join indices in NestedLoopJoinExec, error: {e}"
975                )
976            },
977        )?;
978
979    // set the left bitmap
980    // and only full join need the left bitmap
981    if need_produce_result_in_final(join_type) {
982        let mut bitmap = visited_left_side.lock();
983        left_side.values().iter().for_each(|x| {
984            bitmap.set_bit(*x as usize, true);
985        });
986    }
987    // adjust the two side indices base on the join type
988    let (left_side, right_side) = adjust_indices_by_join_type(
989        left_side,
990        right_side,
991        0..right_batch.num_rows(),
992        join_type,
993        right_side_ordered,
994    )?;
995
996    build_batch_from_indices(
997        schema,
998        left_batch,
999        right_batch,
1000        &left_side,
1001        &right_side,
1002        column_indices,
1003        JoinSide::Left,
1004    )
1005}
1006
1007impl<T: BatchTransformer + Unpin + Send> Stream for NestedLoopJoinStream<T> {
1008    type Item = Result<RecordBatch>;
1009
1010    fn poll_next(
1011        mut self: std::pin::Pin<&mut Self>,
1012        cx: &mut std::task::Context<'_>,
1013    ) -> Poll<Option<Self::Item>> {
1014        self.poll_next_impl(cx)
1015    }
1016}
1017
1018impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for NestedLoopJoinStream<T> {
1019    fn schema(&self) -> SchemaRef {
1020        Arc::clone(&self.schema)
1021    }
1022}
1023
/// Implements [`EmbeddedProjection`] by forwarding to a `with_projection`
/// method on [`NestedLoopJoinExec`].
impl EmbeddedProjection for NestedLoopJoinExec {
    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
        // NOTE(review): this call presumably resolves to the inherent
        // `NestedLoopJoinExec::with_projection` (defined outside this view)
        // rather than to this trait method; otherwise it would recurse — confirm.
        self.with_projection(projection)
    }
}
1029
1030#[cfg(test)]
1031pub(crate) mod tests {
1032    use super::*;
1033    use crate::test::TestMemoryExec;
1034    use crate::{
1035        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
1036    };
1037
1038    use arrow::array::Int32Array;
1039    use arrow::compute::SortOptions;
1040    use arrow::datatypes::{DataType, Field};
1041    use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
1042    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1043    use datafusion_expr::Operator;
1044    use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
1045    use datafusion_physical_expr::{Partitioning, PhysicalExpr};
1046    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
1047
1048    use rstest::rstest;
1049
1050    fn build_table(
1051        a: (&str, &Vec<i32>),
1052        b: (&str, &Vec<i32>),
1053        c: (&str, &Vec<i32>),
1054        batch_size: Option<usize>,
1055        sorted_column_names: Vec<&str>,
1056    ) -> Arc<dyn ExecutionPlan> {
1057        let batch = build_table_i32(a, b, c);
1058        let schema = batch.schema();
1059
1060        let batches = if let Some(batch_size) = batch_size {
1061            let num_batches = batch.num_rows().div_ceil(batch_size);
1062            (0..num_batches)
1063                .map(|i| {
1064                    let start = i * batch_size;
1065                    let remaining_rows = batch.num_rows() - start;
1066                    batch.slice(start, batch_size.min(remaining_rows))
1067                })
1068                .collect::<Vec<_>>()
1069        } else {
1070            vec![batch]
1071        };
1072
1073        let mut source =
1074            TestMemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap();
1075        if !sorted_column_names.is_empty() {
1076            let mut sort_info = LexOrdering::default();
1077            for name in sorted_column_names {
1078                let index = schema.index_of(name).unwrap();
1079                let sort_expr = PhysicalSortExpr {
1080                    expr: Arc::new(Column::new(name, index)),
1081                    options: SortOptions {
1082                        descending: false,
1083                        nulls_first: false,
1084                    },
1085                };
1086                sort_info.push(sort_expr);
1087            }
1088            source = source.try_with_sort_information(vec![sort_info]).unwrap();
1089        }
1090
1091        Arc::new(TestMemoryExec::update_cache(Arc::new(source)))
1092    }
1093
1094    fn build_left_table() -> Arc<dyn ExecutionPlan> {
1095        build_table(
1096            ("a1", &vec![5, 9, 11]),
1097            ("b1", &vec![5, 8, 8]),
1098            ("c1", &vec![50, 90, 110]),
1099            None,
1100            Vec::new(),
1101        )
1102    }
1103
1104    fn build_right_table() -> Arc<dyn ExecutionPlan> {
1105        build_table(
1106            ("a2", &vec![12, 2, 10]),
1107            ("b2", &vec![10, 2, 10]),
1108            ("c2", &vec![40, 80, 100]),
1109            None,
1110            Vec::new(),
1111        )
1112    }
1113
1114    fn prepare_join_filter() -> JoinFilter {
1115        let column_indices = vec![
1116            ColumnIndex {
1117                index: 1,
1118                side: JoinSide::Left,
1119            },
1120            ColumnIndex {
1121                index: 1,
1122                side: JoinSide::Right,
1123            },
1124        ];
1125        let intermediate_schema = Schema::new(vec![
1126            Field::new("x", DataType::Int32, true),
1127            Field::new("x", DataType::Int32, true),
1128        ]);
1129        // left.b1!=8
1130        let left_filter = Arc::new(BinaryExpr::new(
1131            Arc::new(Column::new("x", 0)),
1132            Operator::NotEq,
1133            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
1134        )) as Arc<dyn PhysicalExpr>;
1135        // right.b2!=10
1136        let right_filter = Arc::new(BinaryExpr::new(
1137            Arc::new(Column::new("x", 1)),
1138            Operator::NotEq,
1139            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
1140        )) as Arc<dyn PhysicalExpr>;
1141        // filter = left.b1!=8 and right.b2!=10
1142        // after filter:
1143        // left table:
1144        // ("a1", &vec![5]),
1145        // ("b1", &vec![5]),
1146        // ("c1", &vec![50]),
1147        // right table:
1148        // ("a2", &vec![12, 2]),
1149        // ("b2", &vec![10, 2]),
1150        // ("c2", &vec![40, 80]),
1151        let filter_expression =
1152            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
1153                as Arc<dyn PhysicalExpr>;
1154
1155        JoinFilter::new(
1156            filter_expression,
1157            column_indices,
1158            Arc::new(intermediate_schema),
1159        )
1160    }
1161
1162    pub(crate) async fn multi_partitioned_join_collect(
1163        left: Arc<dyn ExecutionPlan>,
1164        right: Arc<dyn ExecutionPlan>,
1165        join_type: &JoinType,
1166        join_filter: Option<JoinFilter>,
1167        context: Arc<TaskContext>,
1168    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
1169        let partition_count = 4;
1170
1171        // Redistributing right input
1172        let right = Arc::new(RepartitionExec::try_new(
1173            right,
1174            Partitioning::RoundRobinBatch(partition_count),
1175        )?) as Arc<dyn ExecutionPlan>;
1176
1177        // Use the required distribution for nested loop join to test partition data
1178        let nested_loop_join =
1179            NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
1180        let columns = columns(&nested_loop_join.schema());
1181        let mut batches = vec![];
1182        for i in 0..partition_count {
1183            let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
1184            let more_batches = common::collect(stream).await?;
1185            batches.extend(
1186                more_batches
1187                    .into_iter()
1188                    .filter(|b| b.num_rows() > 0)
1189                    .collect::<Vec<_>>(),
1190            );
1191        }
1192        Ok((columns, batches))
1193    }
1194
1195    #[tokio::test]
1196    async fn join_inner_with_filter() -> Result<()> {
1197        let task_ctx = Arc::new(TaskContext::default());
1198        let left = build_left_table();
1199        let right = build_right_table();
1200        let filter = prepare_join_filter();
1201        let (columns, batches) = multi_partitioned_join_collect(
1202            left,
1203            right,
1204            &JoinType::Inner,
1205            Some(filter),
1206            task_ctx,
1207        )
1208        .await?;
1209        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
1210        let expected = [
1211            "+----+----+----+----+----+----+",
1212            "| a1 | b1 | c1 | a2 | b2 | c2 |",
1213            "+----+----+----+----+----+----+",
1214            "| 5  | 5  | 50 | 2  | 2  | 80 |",
1215            "+----+----+----+----+----+----+",
1216        ];
1217
1218        assert_batches_sorted_eq!(expected, &batches);
1219
1220        Ok(())
1221    }
1222
1223    #[tokio::test]
1224    async fn join_left_with_filter() -> Result<()> {
1225        let task_ctx = Arc::new(TaskContext::default());
1226        let left = build_left_table();
1227        let right = build_right_table();
1228
1229        let filter = prepare_join_filter();
1230        let (columns, batches) = multi_partitioned_join_collect(
1231            left,
1232            right,
1233            &JoinType::Left,
1234            Some(filter),
1235            task_ctx,
1236        )
1237        .await?;
1238        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
1239        let expected = [
1240            "+----+----+-----+----+----+----+",
1241            "| a1 | b1 | c1  | a2 | b2 | c2 |",
1242            "+----+----+-----+----+----+----+",
1243            "| 11 | 8  | 110 |    |    |    |",
1244            "| 5  | 5  | 50  | 2  | 2  | 80 |",
1245            "| 9  | 8  | 90  |    |    |    |",
1246            "+----+----+-----+----+----+----+",
1247        ];
1248
1249        assert_batches_sorted_eq!(expected, &batches);
1250
1251        Ok(())
1252    }
1253
1254    #[tokio::test]
1255    async fn join_right_with_filter() -> Result<()> {
1256        let task_ctx = Arc::new(TaskContext::default());
1257        let left = build_left_table();
1258        let right = build_right_table();
1259
1260        let filter = prepare_join_filter();
1261        let (columns, batches) = multi_partitioned_join_collect(
1262            left,
1263            right,
1264            &JoinType::Right,
1265            Some(filter),
1266            task_ctx,
1267        )
1268        .await?;
1269        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
1270        let expected = [
1271            "+----+----+----+----+----+-----+",
1272            "| a1 | b1 | c1 | a2 | b2 | c2  |",
1273            "+----+----+----+----+----+-----+",
1274            "|    |    |    | 10 | 10 | 100 |",
1275            "|    |    |    | 12 | 10 | 40  |",
1276            "| 5  | 5  | 50 | 2  | 2  | 80  |",
1277            "+----+----+----+----+----+-----+",
1278        ];
1279
1280        assert_batches_sorted_eq!(expected, &batches);
1281
1282        Ok(())
1283    }
1284
1285    #[tokio::test]
1286    async fn join_full_with_filter() -> Result<()> {
1287        let task_ctx = Arc::new(TaskContext::default());
1288        let left = build_left_table();
1289        let right = build_right_table();
1290
1291        let filter = prepare_join_filter();
1292        let (columns, batches) = multi_partitioned_join_collect(
1293            left,
1294            right,
1295            &JoinType::Full,
1296            Some(filter),
1297            task_ctx,
1298        )
1299        .await?;
1300        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
1301        let expected = [
1302            "+----+----+-----+----+----+-----+",
1303            "| a1 | b1 | c1  | a2 | b2 | c2  |",
1304            "+----+----+-----+----+----+-----+",
1305            "|    |    |     | 10 | 10 | 100 |",
1306            "|    |    |     | 12 | 10 | 40  |",
1307            "| 11 | 8  | 110 |    |    |     |",
1308            "| 5  | 5  | 50  | 2  | 2  | 80  |",
1309            "| 9  | 8  | 90  |    |    |     |",
1310            "+----+----+-----+----+----+-----+",
1311        ];
1312
1313        assert_batches_sorted_eq!(expected, &batches);
1314
1315        Ok(())
1316    }
1317
1318    #[tokio::test]
1319    async fn join_left_semi_with_filter() -> Result<()> {
1320        let task_ctx = Arc::new(TaskContext::default());
1321        let left = build_left_table();
1322        let right = build_right_table();
1323
1324        let filter = prepare_join_filter();
1325        let (columns, batches) = multi_partitioned_join_collect(
1326            left,
1327            right,
1328            &JoinType::LeftSemi,
1329            Some(filter),
1330            task_ctx,
1331        )
1332        .await?;
1333        assert_eq!(columns, vec!["a1", "b1", "c1"]);
1334        let expected = [
1335            "+----+----+----+",
1336            "| a1 | b1 | c1 |",
1337            "+----+----+----+",
1338            "| 5  | 5  | 50 |",
1339            "+----+----+----+",
1340        ];
1341
1342        assert_batches_sorted_eq!(expected, &batches);
1343
1344        Ok(())
1345    }
1346
1347    #[tokio::test]
1348    async fn join_left_anti_with_filter() -> Result<()> {
1349        let task_ctx = Arc::new(TaskContext::default());
1350        let left = build_left_table();
1351        let right = build_right_table();
1352
1353        let filter = prepare_join_filter();
1354        let (columns, batches) = multi_partitioned_join_collect(
1355            left,
1356            right,
1357            &JoinType::LeftAnti,
1358            Some(filter),
1359            task_ctx,
1360        )
1361        .await?;
1362        assert_eq!(columns, vec!["a1", "b1", "c1"]);
1363        let expected = [
1364            "+----+----+-----+",
1365            "| a1 | b1 | c1  |",
1366            "+----+----+-----+",
1367            "| 11 | 8  | 110 |",
1368            "| 9  | 8  | 90  |",
1369            "+----+----+-----+",
1370        ];
1371
1372        assert_batches_sorted_eq!(expected, &batches);
1373
1374        Ok(())
1375    }
1376
1377    #[tokio::test]
1378    async fn join_right_semi_with_filter() -> Result<()> {
1379        let task_ctx = Arc::new(TaskContext::default());
1380        let left = build_left_table();
1381        let right = build_right_table();
1382
1383        let filter = prepare_join_filter();
1384        let (columns, batches) = multi_partitioned_join_collect(
1385            left,
1386            right,
1387            &JoinType::RightSemi,
1388            Some(filter),
1389            task_ctx,
1390        )
1391        .await?;
1392        assert_eq!(columns, vec!["a2", "b2", "c2"]);
1393        let expected = [
1394            "+----+----+----+",
1395            "| a2 | b2 | c2 |",
1396            "+----+----+----+",
1397            "| 2  | 2  | 80 |",
1398            "+----+----+----+",
1399        ];
1400
1401        assert_batches_sorted_eq!(expected, &batches);
1402
1403        Ok(())
1404    }
1405
1406    #[tokio::test]
1407    async fn join_right_anti_with_filter() -> Result<()> {
1408        let task_ctx = Arc::new(TaskContext::default());
1409        let left = build_left_table();
1410        let right = build_right_table();
1411
1412        let filter = prepare_join_filter();
1413        let (columns, batches) = multi_partitioned_join_collect(
1414            left,
1415            right,
1416            &JoinType::RightAnti,
1417            Some(filter),
1418            task_ctx,
1419        )
1420        .await?;
1421        assert_eq!(columns, vec!["a2", "b2", "c2"]);
1422        let expected = [
1423            "+----+----+-----+",
1424            "| a2 | b2 | c2  |",
1425            "+----+----+-----+",
1426            "| 10 | 10 | 100 |",
1427            "| 12 | 10 | 40  |",
1428            "+----+----+-----+",
1429        ];
1430
1431        assert_batches_sorted_eq!(expected, &batches);
1432
1433        Ok(())
1434    }
1435
1436    #[tokio::test]
1437    async fn join_left_mark_with_filter() -> Result<()> {
1438        let task_ctx = Arc::new(TaskContext::default());
1439        let left = build_left_table();
1440        let right = build_right_table();
1441
1442        let filter = prepare_join_filter();
1443        let (columns, batches) = multi_partitioned_join_collect(
1444            left,
1445            right,
1446            &JoinType::LeftMark,
1447            Some(filter),
1448            task_ctx,
1449        )
1450        .await?;
1451        assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
1452        let expected = [
1453            "+----+----+-----+-------+",
1454            "| a1 | b1 | c1  | mark  |",
1455            "+----+----+-----+-------+",
1456            "| 11 | 8  | 110 | false |",
1457            "| 5  | 5  | 50  | true  |",
1458            "| 9  | 8  | 90  | false |",
1459            "+----+----+-----+-------+",
1460        ];
1461
1462        assert_batches_sorted_eq!(expected, &batches);
1463
1464        Ok(())
1465    }
1466
1467    #[tokio::test]
1468    async fn test_overallocation() -> Result<()> {
1469        let left = build_table(
1470            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1471            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1472            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1473            None,
1474            Vec::new(),
1475        );
1476        let right = build_table(
1477            ("a2", &vec![10, 11]),
1478            ("b2", &vec![12, 13]),
1479            ("c2", &vec![14, 15]),
1480            None,
1481            Vec::new(),
1482        );
1483        let filter = prepare_join_filter();
1484
1485        let join_types = vec![
1486            JoinType::Inner,
1487            JoinType::Left,
1488            JoinType::Right,
1489            JoinType::Full,
1490            JoinType::LeftSemi,
1491            JoinType::LeftAnti,
1492            JoinType::LeftMark,
1493            JoinType::RightSemi,
1494            JoinType::RightAnti,
1495        ];
1496
1497        for join_type in join_types {
1498            let runtime = RuntimeEnvBuilder::new()
1499                .with_memory_limit(100, 1.0)
1500                .build_arc()?;
1501            let task_ctx = TaskContext::default().with_runtime(runtime);
1502            let task_ctx = Arc::new(task_ctx);
1503
1504            let err = multi_partitioned_join_collect(
1505                Arc::clone(&left),
1506                Arc::clone(&right),
1507                &join_type,
1508                Some(filter.clone()),
1509                task_ctx,
1510            )
1511            .await
1512            .unwrap_err();
1513
1514            assert_contains!(
1515                err.to_string(),
1516                "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]"
1517            );
1518        }
1519
1520        Ok(())
1521    }
1522
1523    fn prepare_mod_join_filter() -> JoinFilter {
1524        let column_indices = vec![
1525            ColumnIndex {
1526                index: 1,
1527                side: JoinSide::Left,
1528            },
1529            ColumnIndex {
1530                index: 1,
1531                side: JoinSide::Right,
1532            },
1533        ];
1534        let intermediate_schema = Schema::new(vec![
1535            Field::new("x", DataType::Int32, true),
1536            Field::new("x", DataType::Int32, true),
1537        ]);
1538
1539        // left.b1 % 3
1540        let left_mod = Arc::new(BinaryExpr::new(
1541            Arc::new(Column::new("x", 0)),
1542            Operator::Modulo,
1543            Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
1544        )) as Arc<dyn PhysicalExpr>;
1545        // left.b1 % 3 != 0
1546        let left_filter = Arc::new(BinaryExpr::new(
1547            left_mod,
1548            Operator::NotEq,
1549            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1550        )) as Arc<dyn PhysicalExpr>;
1551
1552        // right.b2 % 5
1553        let right_mod = Arc::new(BinaryExpr::new(
1554            Arc::new(Column::new("x", 1)),
1555            Operator::Modulo,
1556            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
1557        )) as Arc<dyn PhysicalExpr>;
1558        // right.b2 % 5 != 0
1559        let right_filter = Arc::new(BinaryExpr::new(
1560            right_mod,
1561            Operator::NotEq,
1562            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1563        )) as Arc<dyn PhysicalExpr>;
1564        // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
1565        let filter_expression =
1566            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
1567                as Arc<dyn PhysicalExpr>;
1568
1569        JoinFilter::new(
1570            filter_expression,
1571            column_indices,
1572            Arc::new(intermediate_schema),
1573        )
1574    }
1575
1576    fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
1577        let column = (1..=num_rows).map(|x| x as i32).collect();
1578        vec![column; num_columns]
1579    }
1580
1581    #[rstest]
1582    #[tokio::test]
1583    async fn join_maintains_right_order(
1584        #[values(
1585            JoinType::Inner,
1586            JoinType::Right,
1587            JoinType::RightAnti,
1588            JoinType::RightSemi
1589        )]
1590        join_type: JoinType,
1591        #[values(1, 100, 1000)] left_batch_size: usize,
1592        #[values(1, 100, 1000)] right_batch_size: usize,
1593    ) -> Result<()> {
1594        let left_columns = generate_columns(3, 1000);
1595        let left = build_table(
1596            ("a1", &left_columns[0]),
1597            ("b1", &left_columns[1]),
1598            ("c1", &left_columns[2]),
1599            Some(left_batch_size),
1600            Vec::new(),
1601        );
1602
1603        let right_columns = generate_columns(3, 1000);
1604        let right = build_table(
1605            ("a2", &right_columns[0]),
1606            ("b2", &right_columns[1]),
1607            ("c2", &right_columns[2]),
1608            Some(right_batch_size),
1609            vec!["a2", "b2", "c2"],
1610        );
1611
1612        let filter = prepare_mod_join_filter();
1613
1614        let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
1615            left,
1616            Arc::clone(&right),
1617            Some(filter),
1618            &join_type,
1619            None,
1620        )?) as Arc<dyn ExecutionPlan>;
1621        assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);
1622
1623        let right_column_indices = match join_type {
1624            JoinType::Inner | JoinType::Right => vec![3, 4, 5],
1625            JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
1626            _ => unreachable!(),
1627        };
1628
1629        let right_ordering = right.output_ordering().unwrap();
1630        let join_ordering = nested_loop_join.output_ordering().unwrap();
1631        for (right, join) in right_ordering.iter().zip(join_ordering.iter()) {
1632            let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap();
1633            let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap();
1634            assert_eq!(join_column.name(), join_column.name());
1635            assert_eq!(
1636                right_column_indices[right_column.index()],
1637                join_column.index()
1638            );
1639            assert_eq!(right.options, join.options);
1640        }
1641
1642        let batches = nested_loop_join
1643            .execute(0, Arc::new(TaskContext::default()))?
1644            .try_collect::<Vec<_>>()
1645            .await?;
1646
1647        // Make sure that the order of the right side is maintained
1648        let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];
1649
1650        for (batch_index, batch) in batches.iter().enumerate() {
1651            let columns: Vec<_> = right_column_indices
1652                .iter()
1653                .map(|&i| {
1654                    batch
1655                        .column(i)
1656                        .as_any()
1657                        .downcast_ref::<Int32Array>()
1658                        .unwrap()
1659                })
1660                .collect();
1661
1662            for row in 0..batch.num_rows() {
1663                let current_values = [
1664                    columns[0].value(row),
1665                    columns[1].value(row),
1666                    columns[2].value(row),
1667                ];
1668                assert!(
1669                    current_values
1670                        .into_iter()
1671                        .zip(prev_values)
1672                        .all(|(current, prev)| current >= prev),
1673                    "batch_index: {} row: {} current: {:?}, prev: {:?}",
1674                    batch_index,
1675                    row,
1676                    current_values,
1677                    prev_values
1678                );
1679                prev_values = current_values;
1680            }
1681        }
1682
1683        Ok(())
1684    }
1685
1686    /// Returns the column names on the schema
1687    fn columns(schema: &Schema) -> Vec<String> {
1688        schema.fields().iter().map(|f| f.name().clone()).collect()
1689    }
1690}