datafusion_physical_plan/joins/
hash_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`HashJoinExec`] Partitioned Hash Join Operator
19
20use std::fmt;
21use std::mem::size_of;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::task::Poll;
25use std::{any::Any, vec};
26
27use super::utils::{
28    asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap,
29    reorder_output_after_swap, swap_join_projection,
30};
31use super::{
32    utils::{OnceAsync, OnceFut},
33    PartitionMode, SharedBitmapBuilder,
34};
35use crate::execution_plan::{boundedness_from_children, EmissionType};
36use crate::projection::{
37    try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
38    ProjectionExec,
39};
40use crate::spill::get_record_batch_memory_size;
41use crate::ExecutionPlanProperties;
42use crate::{
43    coalesce_partitions::CoalescePartitionsExec,
44    common::can_project,
45    handle_state,
46    hash_utils::create_hashes,
47    joins::utils::{
48        adjust_indices_by_join_type, apply_join_filter_to_indices,
49        build_batch_from_indices, build_join_schema, check_join_is_valid,
50        estimate_join_statistics, need_produce_result_in_final,
51        symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex,
52        JoinFilter, JoinHashMap, JoinHashMapOffset, JoinHashMapType, JoinOn, JoinOnRef,
53        StatefulStreamResult,
54    },
55    metrics::{ExecutionPlanMetricsSet, MetricsSet},
56    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
57    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
58};
59
60use arrow::array::{
61    cast::downcast_array, Array, ArrayRef, BooleanArray, BooleanBufferBuilder,
62    UInt32Array, UInt64Array,
63};
64use arrow::compute::kernels::cmp::{eq, not_distinct};
65use arrow::compute::{and, concat_batches, take, FilterBuilder};
66use arrow::datatypes::{Schema, SchemaRef};
67use arrow::error::ArrowError;
68use arrow::record_batch::RecordBatch;
69use arrow::util::bit_util;
70use datafusion_common::utils::memory::estimate_memory_size;
71use datafusion_common::{
72    internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError,
73    JoinSide, JoinType, Result,
74};
75use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
76use datafusion_execution::TaskContext;
77use datafusion_expr::Operator;
78use datafusion_physical_expr::equivalence::{
79    join_equivalence_properties, ProjectionMapping,
80};
81use datafusion_physical_expr::PhysicalExprRef;
82use datafusion_physical_expr_common::datum::compare_op_for_nested;
83
84use ahash::RandomState;
85use futures::{ready, Stream, StreamExt, TryStreamExt};
86use parking_lot::Mutex;
87
/// HashTable and input data for the left (build side) of a join
///
/// Groups the hash table, the build-side batch it indexes into, the values of
/// the build-side join keys and the bookkeeping shared by the probe side.
struct JoinLeftData {
    /// The hash table with indices into `batch`
    hash_map: JoinHashMap,
    /// The input rows for the build side
    batch: RecordBatch,
    /// The values of the build-side `on` expressions
    values: Vec<ArrayRef>,
    /// Shared bitmap builder for visited left indices
    visited_indices_bitmap: SharedBitmapBuilder,
    /// Counter of running probe-threads, potentially
    /// able to update `visited_indices_bitmap`
    probe_threads_counter: AtomicUsize,
    /// We need to keep this field to maintain accurate memory accounting, even though we don't directly use it.
    /// Without holding onto this reservation, the recorded memory usage would become inconsistent with actual usage.
    /// This could hide potential out-of-memory issues, especially when upstream operators increase their memory consumption.
    /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle.
    _reservation: MemoryReservation,
}
107
108impl JoinLeftData {
109    /// Create a new `JoinLeftData` from its parts
110    fn new(
111        hash_map: JoinHashMap,
112        batch: RecordBatch,
113        values: Vec<ArrayRef>,
114        visited_indices_bitmap: SharedBitmapBuilder,
115        probe_threads_counter: AtomicUsize,
116        reservation: MemoryReservation,
117    ) -> Self {
118        Self {
119            hash_map,
120            batch,
121            values,
122            visited_indices_bitmap,
123            probe_threads_counter,
124            _reservation: reservation,
125        }
126    }
127
128    /// return a reference to the hash map
129    fn hash_map(&self) -> &JoinHashMap {
130        &self.hash_map
131    }
132
133    /// returns a reference to the build side batch
134    fn batch(&self) -> &RecordBatch {
135        &self.batch
136    }
137
138    /// returns a reference to the build side expressions values
139    fn values(&self) -> &[ArrayRef] {
140        &self.values
141    }
142
143    /// returns a reference to the visited indices bitmap
144    fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder {
145        &self.visited_indices_bitmap
146    }
147
148    /// Decrements the counter of running threads, and returns `true`
149    /// if caller is the last running thread
150    fn report_probe_completed(&self) -> bool {
151        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
152    }
153}
154
155#[allow(rustdoc::private_intra_doc_links)]
156/// Join execution plan: Evaluates equijoin predicates in parallel on multiple
157/// partitions using a hash table and an optional filter list to apply post
158/// join.
159///
160/// # Join Expressions
161///
/// This implementation is optimized for evaluating equijoin predicates
/// (`<col1> = <col2>` expressions), which are represented as a list of
/// `Columns` in [`Self::on`].
165///
/// Non-equality predicates, which cannot be pushed down to the join inputs (e.g.
/// `<col1> != <col2>`), are known as "filter expressions" and are evaluated
/// after the equijoin predicates.
169///
170/// # "Build Side" vs "Probe Side"
171///
172/// HashJoin takes two inputs, which are referred to as the "build" and the
173/// "probe". The build side is the first child, and the probe side is the second
174/// child.
175///
176/// The two inputs are treated differently and it is VERY important that the
177/// *smaller* input is placed on the build side to minimize the work of creating
178/// the hash table.
179///
180/// ```text
181///          ┌───────────┐
182///          │ HashJoin  │
183///          │           │
184///          └───────────┘
185///              │   │
186///        ┌─────┘   └─────┐
187///        ▼               ▼
188/// ┌────────────┐  ┌─────────────┐
189/// │   Input    │  │    Input    │
190/// │    [0]     │  │     [1]     │
191/// └────────────┘  └─────────────┘
192///
193///  "build side"    "probe side"
194/// ```
195///
196/// Execution proceeds in 2 stages:
197///
/// 1. the **build phase** creates a hash table from the tuples of the build side,
///    and a single concatenated batch containing data from all fetched record batches.
///    The resulting hash table stores the hashed join-key fields of each row as a key,
///    and the indices of the corresponding rows in the concatenated batch.
202///
/// Hash join uses a LIFO data structure as its hash table. In order to retain the
/// original build-side input order while obtaining data during the probe phase, the
/// hash table is updated by iterating over the batch sequence in reverse order. This
/// keeps rows with smaller indices "on the top" of the hash table, while still
/// maintaining correct indexing into the concatenated build-side data batch.
208///
209/// Example of build phase for 3 record batches:
210///
211///
212/// ```text
213///
214///  Original build-side data   Inserting build-side values into hashmap    Concatenated build-side batch
215///                                                                         ┌───────────────────────────┐
216///                             hashmap.insert(row-hash, row-idx + offset)  │                      idx  │
217///            ┌───────┐                                                    │          ┌───────┐        │
218///            │ Row 1 │        1) update_hash for batch 3 with offset 0    │          │ Row 6 │    0   │
219///   Batch 1  │       │           - hashmap.insert(Row 7, idx 1)           │ Batch 3  │       │        │
220///            │ Row 2 │           - hashmap.insert(Row 6, idx 0)           │          │ Row 7 │    1   │
221///            └───────┘                                                    │          └───────┘        │
222///                                                                         │                           │
223///            ┌───────┐                                                    │          ┌───────┐        │
224///            │ Row 3 │        2) update_hash for batch 2 with offset 2    │          │ Row 3 │    2   │
225///            │       │           - hashmap.insert(Row 5, idx 4)           │          │       │        │
226///   Batch 2  │ Row 4 │           - hashmap.insert(Row 4, idx 3)           │ Batch 2  │ Row 4 │    3   │
227///            │       │           - hashmap.insert(Row 3, idx 2)           │          │       │        │
228///            │ Row 5 │                                                    │          │ Row 5 │    4   │
229///            └───────┘                                                    │          └───────┘        │
230///                                                                         │                           │
231///            ┌───────┐                                                    │          ┌───────┐        │
232///            │ Row 6 │        3) update_hash for batch 1 with offset 5    │          │ Row 1 │    5   │
233///   Batch 3  │       │           - hashmap.insert(Row 2, idx 6)           │ Batch 1  │       │        │
234///            │ Row 7 │           - hashmap.insert(Row 1, idx 5)           │          │ Row 2 │    6   │
235///            └───────┘                                                    │          └───────┘        │
236///                                                                         │                           │
237///                                                                         └───────────────────────────┘
238///
239/// ```
240///
241/// 2. the **probe phase** where the tuples of the probe side are streamed
242///    through, checking for matches of the join keys in the hash table.
243///
244/// ```text
245///                 ┌────────────────┐          ┌────────────────┐
246///                 │ ┌─────────┐    │          │ ┌─────────┐    │
247///                 │ │  Hash   │    │          │ │  Hash   │    │
248///                 │ │  Table  │    │          │ │  Table  │    │
249///                 │ │(keys are│    │          │ │(keys are│    │
250///                 │ │equi join│    │          │ │equi join│    │  Stage 2: batches from
251///  Stage 1: the   │ │columns) │    │          │ │columns) │    │    the probe side are
252/// *entire* build  │ │         │    │          │ │         │    │  streamed through, and
253///  side is read   │ └─────────┘    │          │ └─────────┘    │   checked against the
254/// into the hash   │      ▲         │          │          ▲     │   contents of the hash
255///     table       │       HashJoin │          │  HashJoin      │          table
256///                 └──────┼─────────┘          └──────────┼─────┘
257///             ─ ─ ─ ─ ─ ─                                 ─ ─ ─ ─ ─ ─ ─
258///            │                                                         │
259///
260///            │                                                         │
261///     ┌────────────┐                                            ┌────────────┐
262///     │RecordBatch │                                            │RecordBatch │
263///     └────────────┘                                            └────────────┘
264///     ┌────────────┐                                            ┌────────────┐
265///     │RecordBatch │                                            │RecordBatch │
266///     └────────────┘                                            └────────────┘
267///           ...                                                       ...
268///     ┌────────────┐                                            ┌────────────┐
269///     │RecordBatch │                                            │RecordBatch │
270///     └────────────┘                                            └────────────┘
271///
272///        build side                                                probe side
273///
274/// ```
275///
276/// # Example "Optimal" Plans
277///
/// The differences between the inputs mean that for a classic "Star Schema Query",
/// the optimal plan will be a **"Right Deep Tree"**. A Star Schema Query is
/// one where there is one large table and several smaller "dimension" tables,
/// joined on `Foreign Key = Primary Key` predicates.
///
/// A "Right Deep Tree" looks like this, with the large table as the probe side of
/// the lowest join:
285///
286/// ```text
287///             ┌───────────┐
288///             │ HashJoin  │
289///             │           │
290///             └───────────┘
291///                 │   │
292///         ┌───────┘   └──────────┐
293///         ▼                      ▼
294/// ┌───────────────┐        ┌───────────┐
295/// │ small table 1 │        │ HashJoin  │
296/// │  "dimension"  │        │           │
297/// └───────────────┘        └───┬───┬───┘
298///                   ┌──────────┘   └───────┐
299///                   │                      │
300///                   ▼                      ▼
301///           ┌───────────────┐        ┌───────────┐
302///           │ small table 2 │        │ HashJoin  │
303///           │  "dimension"  │        │           │
304///           └───────────────┘        └───┬───┬───┘
305///                               ┌────────┘   └────────┐
306///                               │                     │
307///                               ▼                     ▼
308///                       ┌───────────────┐     ┌───────────────┐
309///                       │ small table 3 │     │  large table  │
310///                       │  "dimension"  │     │    "fact"     │
311///                       └───────────────┘     └───────────────┘
312/// ```
313///
314/// # Clone / Shared State
315///
316/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
317/// loading of the left side with the processing in each output stream.
318/// Therefore it can not be [`Clone`]
#[derive(Debug)]
pub struct HashJoinExec {
    /// left (build) side which gets hashed
    pub left: Arc<dyn ExecutionPlan>,
    /// right (probe) side which are filtered by the hash table
    pub right: Arc<dyn ExecutionPlan>,
    /// Set of equijoin columns from the relations: `(left_col, right_col)`
    pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Filters which are applied while finding matching rows
    pub filter: Option<JoinFilter>,
    /// How the join is performed (`OUTER`, `INNER`, etc)
    pub join_type: JoinType,
    /// The schema after join. Please be careful when using this schema,
    /// if there is a projection, the schema isn't the same as the output schema.
    join_schema: SchemaRef,
    /// Future that consumes left input and builds the hash table
    ///
    /// For CollectLeft partition mode, this structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the hash table creation.
    left_fut: OnceAsync<JoinLeftData>,
    /// The shared `RandomState` used by the hashing algorithm
    random_state: RandomState,
    /// Partitioning mode to use
    pub mode: PartitionMode,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// The projection indices of the columns in the output schema of join.
    /// `None` means no projection is embedded in this join.
    pub projection: Option<Vec<usize>>,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Null matching behavior: If `null_equals_null` is true, rows that have
    /// `null`s in both left and right equijoin columns will be matched.
    /// Otherwise, rows that have `null`s in the join columns will not be
    /// matched and thus will not appear in the output.
    pub null_equals_null: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
359
360impl HashJoinExec {
361    /// Tries to create a new [HashJoinExec].
362    ///
363    /// # Error
364    /// This function errors when it is not possible to join the left and right sides on keys `on`.
365    #[allow(clippy::too_many_arguments)]
366    pub fn try_new(
367        left: Arc<dyn ExecutionPlan>,
368        right: Arc<dyn ExecutionPlan>,
369        on: JoinOn,
370        filter: Option<JoinFilter>,
371        join_type: &JoinType,
372        projection: Option<Vec<usize>>,
373        partition_mode: PartitionMode,
374        null_equals_null: bool,
375    ) -> Result<Self> {
376        let left_schema = left.schema();
377        let right_schema = right.schema();
378        if on.is_empty() {
379            return plan_err!("On constraints in HashJoinExec should be non-empty");
380        }
381
382        check_join_is_valid(&left_schema, &right_schema, &on)?;
383
384        let (join_schema, column_indices) =
385            build_join_schema(&left_schema, &right_schema, join_type);
386
387        let random_state = RandomState::with_seeds(0, 0, 0, 0);
388
389        let join_schema = Arc::new(join_schema);
390
391        //  check if the projection is valid
392        can_project(&join_schema, projection.as_ref())?;
393
394        let cache = Self::compute_properties(
395            &left,
396            &right,
397            Arc::clone(&join_schema),
398            *join_type,
399            &on,
400            partition_mode,
401            projection.as_ref(),
402        )?;
403
404        Ok(HashJoinExec {
405            left,
406            right,
407            on,
408            filter,
409            join_type: *join_type,
410            join_schema,
411            left_fut: Default::default(),
412            random_state,
413            mode: partition_mode,
414            metrics: ExecutionPlanMetricsSet::new(),
415            projection,
416            column_indices,
417            null_equals_null,
418            cache,
419        })
420    }
421
422    /// left (build) side which gets hashed
423    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
424        &self.left
425    }
426
427    /// right (probe) side which are filtered by the hash table
428    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
429        &self.right
430    }
431
432    /// Set of common columns used to join on
433    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
434        &self.on
435    }
436
437    /// Filters applied before join output
438    pub fn filter(&self) -> Option<&JoinFilter> {
439        self.filter.as_ref()
440    }
441
442    /// How the join is performed
443    pub fn join_type(&self) -> &JoinType {
444        &self.join_type
445    }
446
447    /// The schema after join. Please be careful when using this schema,
448    /// if there is a projection, the schema isn't the same as the output schema.
449    pub fn join_schema(&self) -> &SchemaRef {
450        &self.join_schema
451    }
452
453    /// The partitioning mode of this hash join
454    pub fn partition_mode(&self) -> &PartitionMode {
455        &self.mode
456    }
457
458    /// Get null_equals_null
459    pub fn null_equals_null(&self) -> bool {
460        self.null_equals_null
461    }
462
463    /// Calculate order preservation flags for this hash join.
464    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
465        vec![
466            false,
467            matches!(
468                join_type,
469                JoinType::Inner
470                    | JoinType::Right
471                    | JoinType::RightAnti
472                    | JoinType::RightSemi
473            ),
474        ]
475    }
476
477    /// Get probe side information for the hash join.
478    pub fn probe_side() -> JoinSide {
479        // In current implementation right side is always probe side.
480        JoinSide::Right
481    }
482
483    /// Return whether the join contains a projection
484    pub fn contains_projection(&self) -> bool {
485        self.projection.is_some()
486    }
487
488    /// Return new instance of [HashJoinExec] with the given projection.
489    pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
490        //  check if the projection is valid
491        can_project(&self.schema(), projection.as_ref())?;
492        let projection = match projection {
493            Some(projection) => match &self.projection {
494                Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
495                None => Some(projection),
496            },
497            None => None,
498        };
499        Self::try_new(
500            Arc::clone(&self.left),
501            Arc::clone(&self.right),
502            self.on.clone(),
503            self.filter.clone(),
504            &self.join_type,
505            projection,
506            self.mode,
507            self.null_equals_null,
508        )
509    }
510
511    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
512    fn compute_properties(
513        left: &Arc<dyn ExecutionPlan>,
514        right: &Arc<dyn ExecutionPlan>,
515        schema: SchemaRef,
516        join_type: JoinType,
517        on: JoinOnRef,
518        mode: PartitionMode,
519        projection: Option<&Vec<usize>>,
520    ) -> Result<PlanProperties> {
521        // Calculate equivalence properties:
522        let mut eq_properties = join_equivalence_properties(
523            left.equivalence_properties().clone(),
524            right.equivalence_properties().clone(),
525            &join_type,
526            Arc::clone(&schema),
527            &Self::maintains_input_order(join_type),
528            Some(Self::probe_side()),
529            on,
530        );
531
532        let mut output_partitioning = match mode {
533            PartitionMode::CollectLeft => {
534                asymmetric_join_output_partitioning(left, right, &join_type)
535            }
536            PartitionMode::Auto => Partitioning::UnknownPartitioning(
537                right.output_partitioning().partition_count(),
538            ),
539            PartitionMode::Partitioned => {
540                symmetric_join_output_partitioning(left, right, &join_type)
541            }
542        };
543
544        let emission_type = if left.boundedness().is_unbounded() {
545            EmissionType::Final
546        } else if right.pipeline_behavior() == EmissionType::Incremental {
547            match join_type {
548                // If we only need to generate matched rows from the probe side,
549                // we can emit rows incrementally.
550                JoinType::Inner
551                | JoinType::LeftSemi
552                | JoinType::RightSemi
553                | JoinType::Right
554                | JoinType::RightAnti => EmissionType::Incremental,
555                // If we need to generate unmatched rows from the *build side*,
556                // we need to emit them at the end.
557                JoinType::Left
558                | JoinType::LeftAnti
559                | JoinType::LeftMark
560                | JoinType::Full => EmissionType::Both,
561            }
562        } else {
563            right.pipeline_behavior()
564        };
565
566        // If contains projection, update the PlanProperties.
567        if let Some(projection) = projection {
568            // construct a map from the input expressions to the output expression of the Projection
569            let projection_mapping =
570                ProjectionMapping::from_indices(projection, &schema)?;
571            let out_schema = project_schema(&schema, Some(projection))?;
572            output_partitioning =
573                output_partitioning.project(&projection_mapping, &eq_properties);
574            eq_properties = eq_properties.project(&projection_mapping, out_schema);
575        }
576
577        Ok(PlanProperties::new(
578            eq_properties,
579            output_partitioning,
580            emission_type,
581            boundedness_from_children([left, right]),
582        ))
583    }
584
585    /// Returns a new `ExecutionPlan` that computes the same join as this one,
586    /// with the left and right inputs swapped using the  specified
587    /// `partition_mode`.
588    ///
589    /// # Notes:
590    ///
591    /// This function is public so other downstream projects can use it to
592    /// construct `HashJoinExec` with right side as the build side.
593    pub fn swap_inputs(
594        &self,
595        partition_mode: PartitionMode,
596    ) -> Result<Arc<dyn ExecutionPlan>> {
597        let left = self.left();
598        let right = self.right();
599        let new_join = HashJoinExec::try_new(
600            Arc::clone(right),
601            Arc::clone(left),
602            self.on()
603                .iter()
604                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
605                .collect(),
606            self.filter().map(JoinFilter::swap),
607            &self.join_type().swap(),
608            swap_join_projection(
609                left.schema().fields().len(),
610                right.schema().fields().len(),
611                self.projection.as_ref(),
612                self.join_type(),
613            ),
614            partition_mode,
615            self.null_equals_null(),
616        )?;
617        // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again
618        if matches!(
619            self.join_type(),
620            JoinType::LeftSemi
621                | JoinType::RightSemi
622                | JoinType::LeftAnti
623                | JoinType::RightAnti
624        ) || self.projection.is_some()
625        {
626            Ok(Arc::new(new_join))
627        } else {
628            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
629        }
630    }
631}
632
633impl DisplayAs for HashJoinExec {
634    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
635        match t {
636            DisplayFormatType::Default | DisplayFormatType::Verbose => {
637                let display_filter = self.filter.as_ref().map_or_else(
638                    || "".to_string(),
639                    |f| format!(", filter={}", f.expression()),
640                );
641                let display_projections = if self.contains_projection() {
642                    format!(
643                        ", projection=[{}]",
644                        self.projection
645                            .as_ref()
646                            .unwrap()
647                            .iter()
648                            .map(|index| format!(
649                                "{}@{}",
650                                self.join_schema.fields().get(*index).unwrap().name(),
651                                index
652                            ))
653                            .collect::<Vec<_>>()
654                            .join(", ")
655                    )
656                } else {
657                    "".to_string()
658                };
659                let on = self
660                    .on
661                    .iter()
662                    .map(|(c1, c2)| format!("({}, {})", c1, c2))
663                    .collect::<Vec<String>>()
664                    .join(", ");
665                write!(
666                    f,
667                    "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}",
668                    self.mode, self.join_type, on, display_filter, display_projections
669                )
670            }
671        }
672    }
673}
674
675impl ExecutionPlan for HashJoinExec {
676    fn name(&self) -> &'static str {
677        "HashJoinExec"
678    }
679
680    fn as_any(&self) -> &dyn Any {
681        self
682    }
683
684    fn properties(&self) -> &PlanProperties {
685        &self.cache
686    }
687
688    fn required_input_distribution(&self) -> Vec<Distribution> {
689        match self.mode {
690            PartitionMode::CollectLeft => vec![
691                Distribution::SinglePartition,
692                Distribution::UnspecifiedDistribution,
693            ],
694            PartitionMode::Partitioned => {
695                let (left_expr, right_expr) = self
696                    .on
697                    .iter()
698                    .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
699                    .unzip();
700                vec![
701                    Distribution::HashPartitioned(left_expr),
702                    Distribution::HashPartitioned(right_expr),
703                ]
704            }
705            PartitionMode::Auto => vec![
706                Distribution::UnspecifiedDistribution,
707                Distribution::UnspecifiedDistribution,
708            ],
709        }
710    }
711
712    // For [JoinType::Inner] and [JoinType::RightSemi] in hash joins, the probe phase initiates by
713    // applying the hash function to convert the join key(s) in each row into a hash value from the
714    // probe side table in the order they're arranged. The hash value is used to look up corresponding
715    // entries in the hash table that was constructed from the build side table during the build phase.
716    //
717    // Because of the immediate generation of result rows once a match is found,
718    // the output of the join tends to follow the order in which the rows were read from
719    // the probe side table. This is simply due to the sequence in which the rows were processed.
720    // Hence, it appears that the hash join is preserving the order of the probe side.
721    //
722    // Meanwhile, in the case of a [JoinType::RightAnti] hash join,
723    // the unmatched rows from the probe side are also kept in order.
724    // This is because the **`RightAnti`** join is designed to return rows from the right
725    // (probe side) table that have no match in the left (build side) table. Because the rows
726    // are processed sequentially in the probe phase, and unmatched rows are directly output
727    // as results, these results tend to retain the order of the probe side table.
728    fn maintains_input_order(&self) -> Vec<bool> {
729        Self::maintains_input_order(self.join_type)
730    }
731
732    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
733        vec![&self.left, &self.right]
734    }
735
736    fn with_new_children(
737        self: Arc<Self>,
738        children: Vec<Arc<dyn ExecutionPlan>>,
739    ) -> Result<Arc<dyn ExecutionPlan>> {
740        Ok(Arc::new(HashJoinExec::try_new(
741            Arc::clone(&children[0]),
742            Arc::clone(&children[1]),
743            self.on.clone(),
744            self.filter.clone(),
745            &self.join_type,
746            self.projection.clone(),
747            self.mode,
748            self.null_equals_null,
749        )?))
750    }
751
752    fn execute(
753        &self,
754        partition: usize,
755        context: Arc<TaskContext>,
756    ) -> Result<SendableRecordBatchStream> {
757        let on_left = self
758            .on
759            .iter()
760            .map(|on| Arc::clone(&on.0))
761            .collect::<Vec<_>>();
762        let on_right = self
763            .on
764            .iter()
765            .map(|on| Arc::clone(&on.1))
766            .collect::<Vec<_>>();
767        let left_partitions = self.left.output_partitioning().partition_count();
768        let right_partitions = self.right.output_partitioning().partition_count();
769
770        if self.mode == PartitionMode::Partitioned && left_partitions != right_partitions
771        {
772            return internal_err!(
773                "Invalid HashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
774                 consider using RepartitionExec"
775            );
776        }
777
778        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
779        let left_fut = match self.mode {
780            PartitionMode::CollectLeft => self.left_fut.once(|| {
781                let reservation =
782                    MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
783                collect_left_input(
784                    None,
785                    self.random_state.clone(),
786                    Arc::clone(&self.left),
787                    on_left.clone(),
788                    Arc::clone(&context),
789                    join_metrics.clone(),
790                    reservation,
791                    need_produce_result_in_final(self.join_type),
792                    self.right().output_partitioning().partition_count(),
793                )
794            }),
795            PartitionMode::Partitioned => {
796                let reservation =
797                    MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
798                        .register(context.memory_pool());
799
800                OnceFut::new(collect_left_input(
801                    Some(partition),
802                    self.random_state.clone(),
803                    Arc::clone(&self.left),
804                    on_left.clone(),
805                    Arc::clone(&context),
806                    join_metrics.clone(),
807                    reservation,
808                    need_produce_result_in_final(self.join_type),
809                    1,
810                ))
811            }
812            PartitionMode::Auto => {
813                return plan_err!(
814                    "Invalid HashJoinExec, unsupported PartitionMode {:?} in execute()",
815                    PartitionMode::Auto
816                );
817            }
818        };
819
820        let batch_size = context.session_config().batch_size();
821
822        // we have the batches and the hash map with their keys. We can how create a stream
823        // over the right that uses this information to issue new batches.
824        let right_stream = self.right.execute(partition, context)?;
825
826        // update column indices to reflect the projection
827        let column_indices_after_projection = match &self.projection {
828            Some(projection) => projection
829                .iter()
830                .map(|i| self.column_indices[*i].clone())
831                .collect(),
832            None => self.column_indices.clone(),
833        };
834
835        Ok(Box::pin(HashJoinStream {
836            schema: self.schema(),
837            on_right,
838            filter: self.filter.clone(),
839            join_type: self.join_type,
840            right: right_stream,
841            column_indices: column_indices_after_projection,
842            random_state: self.random_state.clone(),
843            join_metrics,
844            null_equals_null: self.null_equals_null,
845            state: HashJoinStreamState::WaitBuildSide,
846            build_side: BuildSide::Initial(BuildSideInitialState { left_fut }),
847            batch_size,
848            hashes_buffer: vec![],
849            right_side_ordered: self.right.output_ordering().is_some(),
850        }))
851    }
852
853    fn metrics(&self) -> Option<MetricsSet> {
854        Some(self.metrics.clone_inner())
855    }
856
857    fn statistics(&self) -> Result<Statistics> {
858        // TODO stats: it is not possible in general to know the output size of joins
859        // There are some special cases though, for example:
860        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
861        let stats = estimate_join_statistics(
862            Arc::clone(&self.left),
863            Arc::clone(&self.right),
864            self.on.clone(),
865            &self.join_type,
866            &self.join_schema,
867        )?;
868        // Project statistics if there is a projection
869        Ok(stats.project(self.projection.as_ref()))
870    }
871
872    /// Tries to push `projection` down through `hash_join`. If possible, performs the
873    /// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections
874    /// as its children. Otherwise, returns `None`.
875    fn try_swapping_with_projection(
876        &self,
877        projection: &ProjectionExec,
878    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
879        // TODO: currently if there is projection in HashJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later.
880        if self.contains_projection() {
881            return Ok(None);
882        }
883
884        if let Some(JoinData {
885            projected_left_child,
886            projected_right_child,
887            join_filter,
888            join_on,
889        }) = try_pushdown_through_join(
890            projection,
891            self.left(),
892            self.right(),
893            self.on(),
894            self.schema(),
895            self.filter(),
896        )? {
897            Ok(Some(Arc::new(HashJoinExec::try_new(
898                Arc::new(projected_left_child),
899                Arc::new(projected_right_child),
900                join_on,
901                join_filter,
902                self.join_type(),
903                // Returned early if projection is not None
904                None,
905                *self.partition_mode(),
906                self.null_equals_null,
907            )?)))
908        } else {
909            try_embed_projection(projection, self)
910        }
911    }
912}
913
914/// Reads the left (build) side of the input, buffering it in memory, to build a
915/// hash table (`LeftJoinData`)
916#[allow(clippy::too_many_arguments)]
917async fn collect_left_input(
918    partition: Option<usize>,
919    random_state: RandomState,
920    left: Arc<dyn ExecutionPlan>,
921    on_left: Vec<PhysicalExprRef>,
922    context: Arc<TaskContext>,
923    metrics: BuildProbeJoinMetrics,
924    reservation: MemoryReservation,
925    with_visited_indices_bitmap: bool,
926    probe_threads_count: usize,
927) -> Result<JoinLeftData> {
928    let schema = left.schema();
929
930    let (left_input, left_input_partition) = if let Some(partition) = partition {
931        (left, partition)
932    } else if left.output_partitioning().partition_count() != 1 {
933        (Arc::new(CoalescePartitionsExec::new(left)) as _, 0)
934    } else {
935        (left, 0)
936    };
937
938    // Depending on partition argument load single partition or whole left side in memory
939    let stream = left_input.execute(left_input_partition, Arc::clone(&context))?;
940
941    // This operation performs 2 steps at once:
942    // 1. creates a [JoinHashMap] of all batches from the stream
943    // 2. stores the batches in a vector.
944    let initial = (Vec::new(), 0, metrics, reservation);
945    let (batches, num_rows, metrics, mut reservation) = stream
946        .try_fold(initial, |mut acc, batch| async {
947            let batch_size = get_record_batch_memory_size(&batch);
948            // Reserve memory for incoming batch
949            acc.3.try_grow(batch_size)?;
950            // Update metrics
951            acc.2.build_mem_used.add(batch_size);
952            acc.2.build_input_batches.add(1);
953            acc.2.build_input_rows.add(batch.num_rows());
954            // Update row count
955            acc.1 += batch.num_rows();
956            // Push batch to output
957            acc.0.push(batch);
958            Ok(acc)
959        })
960        .await?;
961
962    // Estimation of memory size, required for hashtable, prior to allocation.
963    // Final result can be verified using `RawTable.allocation_info()`
964    let fixed_size = size_of::<JoinHashMap>();
965    let estimated_hashtable_size =
966        estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?;
967
968    reservation.try_grow(estimated_hashtable_size)?;
969    metrics.build_mem_used.add(estimated_hashtable_size);
970
971    let mut hashmap = JoinHashMap::with_capacity(num_rows);
972    let mut hashes_buffer = Vec::new();
973    let mut offset = 0;
974
975    // Updating hashmap starting from the last batch
976    let batches_iter = batches.iter().rev();
977    for batch in batches_iter.clone() {
978        hashes_buffer.clear();
979        hashes_buffer.resize(batch.num_rows(), 0);
980        update_hash(
981            &on_left,
982            batch,
983            &mut hashmap,
984            offset,
985            &random_state,
986            &mut hashes_buffer,
987            0,
988            true,
989        )?;
990        offset += batch.num_rows();
991    }
992    // Merge all batches into a single batch, so we can directly index into the arrays
993    let single_batch = concat_batches(&schema, batches_iter)?;
994
995    // Reserve additional memory for visited indices bitmap and create shared builder
996    let visited_indices_bitmap = if with_visited_indices_bitmap {
997        let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8);
998        reservation.try_grow(bitmap_size)?;
999        metrics.build_mem_used.add(bitmap_size);
1000
1001        let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows());
1002        bitmap_buffer.append_n(num_rows, false);
1003        bitmap_buffer
1004    } else {
1005        BooleanBufferBuilder::new(0)
1006    };
1007
1008    let left_values = on_left
1009        .iter()
1010        .map(|c| {
1011            c.evaluate(&single_batch)?
1012                .into_array(single_batch.num_rows())
1013        })
1014        .collect::<Result<Vec<_>>>()?;
1015
1016    let data = JoinLeftData::new(
1017        hashmap,
1018        single_batch,
1019        left_values,
1020        Mutex::new(visited_indices_bitmap),
1021        AtomicUsize::new(probe_threads_count),
1022        reservation,
1023    );
1024
1025    Ok(data)
1026}
1027
1028/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on`
1029/// using `offset` as a start value for `batch` row indices.
1030///
1031/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap,
1032/// which allows to keep either first (if set to true) or last (if set to false) row index
1033/// as a chain head for rows with equal hash values.
1034#[allow(clippy::too_many_arguments)]
1035pub fn update_hash<T>(
1036    on: &[PhysicalExprRef],
1037    batch: &RecordBatch,
1038    hash_map: &mut T,
1039    offset: usize,
1040    random_state: &RandomState,
1041    hashes_buffer: &mut Vec<u64>,
1042    deleted_offset: usize,
1043    fifo_hashmap: bool,
1044) -> Result<()>
1045where
1046    T: JoinHashMapType,
1047{
1048    // evaluate the keys
1049    let keys_values = on
1050        .iter()
1051        .map(|c| c.evaluate(batch)?.into_array(batch.num_rows()))
1052        .collect::<Result<Vec<_>>>()?;
1053
1054    // calculate the hash values
1055    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1056
1057    // For usual JoinHashmap, the implementation is void.
1058    hash_map.extend_zero(batch.num_rows());
1059
1060    // Updating JoinHashMap from hash values iterator
1061    let hash_values_iter = hash_values
1062        .iter()
1063        .enumerate()
1064        .map(|(i, val)| (i + offset, val));
1065
1066    if fifo_hashmap {
1067        hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset);
1068    } else {
1069        hash_map.update_from_iter(hash_values_iter, deleted_offset);
1070    }
1071
1072    Ok(())
1073}
1074
1075/// Represents build-side of hash join.
1076enum BuildSide {
1077    /// Indicates that build-side not collected yet
1078    Initial(BuildSideInitialState),
1079    /// Indicates that build-side data has been collected
1080    Ready(BuildSideReadyState),
1081}
1082
1083/// Container for BuildSide::Initial related data
1084struct BuildSideInitialState {
1085    /// Future for building hash table from build-side input
1086    left_fut: OnceFut<JoinLeftData>,
1087}
1088
1089/// Container for BuildSide::Ready related data
1090struct BuildSideReadyState {
1091    /// Collected build-side data
1092    left_data: Arc<JoinLeftData>,
1093}
1094
1095impl BuildSide {
1096    /// Tries to extract BuildSideInitialState from BuildSide enum.
1097    /// Returns an error if state is not Initial.
1098    fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
1099        match self {
1100            BuildSide::Initial(state) => Ok(state),
1101            _ => internal_err!("Expected build side in initial state"),
1102        }
1103    }
1104
1105    /// Tries to extract BuildSideReadyState from BuildSide enum.
1106    /// Returns an error if state is not Ready.
1107    fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
1108        match self {
1109            BuildSide::Ready(state) => Ok(state),
1110            _ => internal_err!("Expected build side in ready state"),
1111        }
1112    }
1113
1114    /// Tries to extract BuildSideReadyState from BuildSide enum.
1115    /// Returns an error if state is not Ready.
1116    fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
1117        match self {
1118            BuildSide::Ready(state) => Ok(state),
1119            _ => internal_err!("Expected build side in ready state"),
1120        }
1121    }
1122}
1123
1124/// Represents state of HashJoinStream
1125///
1126/// Expected state transitions performed by HashJoinStream are:
1127///
1128/// ```text
1129///
1130///       WaitBuildSide
1131///             │
1132///             ▼
1133///  ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed
1134///  │          │
1135///  │          ▼
1136///  └─ ProcessProbeBatch
1137///
1138/// ```
1139#[derive(Debug, Clone)]
1140enum HashJoinStreamState {
1141    /// Initial state for HashJoinStream indicating that build-side data not collected yet
1142    WaitBuildSide,
1143    /// Indicates that build-side has been collected, and stream is ready for fetching probe-side
1144    FetchProbeBatch,
1145    /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed
1146    ProcessProbeBatch(ProcessProbeBatchState),
1147    /// Indicates that probe-side has been fully processed
1148    ExhaustedProbeSide,
1149    /// Indicates that HashJoinStream execution is completed
1150    Completed,
1151}
1152
1153impl HashJoinStreamState {
1154    /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum.
1155    /// Returns an error if state is not ProcessProbeBatchState.
1156    fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> {
1157        match self {
1158            HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
1159            _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"),
1160        }
1161    }
1162}
1163
1164/// Container for HashJoinStreamState::ProcessProbeBatch related data
1165#[derive(Debug, Clone)]
1166struct ProcessProbeBatchState {
1167    /// Current probe-side batch
1168    batch: RecordBatch,
1169    /// Probe-side on expressions values
1170    values: Vec<ArrayRef>,
1171    /// Starting offset for JoinHashMap lookups
1172    offset: JoinHashMapOffset,
1173    /// Max joined probe-side index from current batch
1174    joined_probe_idx: Option<usize>,
1175}
1176
1177impl ProcessProbeBatchState {
1178    fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option<usize>) {
1179        self.offset = offset;
1180        if joined_probe_idx.is_some() {
1181            self.joined_probe_idx = joined_probe_idx;
1182        }
1183    }
1184}
1185
1186/// [`Stream`] for [`HashJoinExec`] that does the actual join.
1187///
1188/// This stream:
1189///
1190/// 1. Reads the entire left input (build) and constructs a hash table
1191///
1192/// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins
1193///    them with the contents of the hash table
1194struct HashJoinStream {
1195    /// Input schema
1196    schema: Arc<Schema>,
1197    /// equijoin columns from the right (probe side)
1198    on_right: Vec<PhysicalExprRef>,
1199    /// optional join filter
1200    filter: Option<JoinFilter>,
1201    /// type of the join (left, right, semi, etc)
1202    join_type: JoinType,
1203    /// right (probe) input
1204    right: SendableRecordBatchStream,
1205    /// Random state used for hashing initialization
1206    random_state: RandomState,
1207    /// Metrics
1208    join_metrics: BuildProbeJoinMetrics,
1209    /// Information of index and left / right placement of columns
1210    column_indices: Vec<ColumnIndex>,
1211    /// If null_equals_null is true, null == null else null != null
1212    null_equals_null: bool,
1213    /// State of the stream
1214    state: HashJoinStreamState,
1215    /// Build side
1216    build_side: BuildSide,
1217    /// Maximum output batch size
1218    batch_size: usize,
1219    /// Scratch space for computing hashes
1220    hashes_buffer: Vec<u64>,
1221    /// Specifies whether the right side has an ordering to potentially preserve
1222    right_side_ordered: bool,
1223}
1224
1225impl RecordBatchStream for HashJoinStream {
1226    fn schema(&self) -> SchemaRef {
1227        Arc::clone(&self.schema)
1228    }
1229}
1230
1231/// Executes lookups by hash against JoinHashMap and resolves potential
1232/// hash collisions.
1233/// Returns build/probe indices satisfying the equality condition, along with
1234/// (optional) starting point for next iteration.
1235///
1236/// # Example
1237///
1238/// For `LEFT.b1 = RIGHT.b2`:
1239/// LEFT (build) Table:
1240/// ```text
1241///  a1  b1  c1
1242///  1   1   10
1243///  3   3   30
1244///  5   5   50
1245///  7   7   70
1246///  9   8   90
1247///  11  8   110
1248///  13   10  130
1249/// ```
1250///
1251/// RIGHT (probe) Table:
1252/// ```text
1253///  a2   b2  c2
1254///  2    2   20
1255///  4    4   40
1256///  6    6   60
1257///  8    8   80
1258/// 10   10  100
1259/// 12   10  120
1260/// ```
1261///
1262/// The result is
1263/// ```text
1264/// "+----+----+-----+----+----+-----+",
1265/// "| a1 | b1 | c1  | a2 | b2 | c2  |",
1266/// "+----+----+-----+----+----+-----+",
1267/// "| 9  | 8  | 90  | 8  | 8  | 80  |",
1268/// "| 11 | 8  | 110 | 8  | 8  | 80  |",
1269/// "| 13 | 10 | 130 | 10 | 10 | 100 |",
1270/// "| 13 | 10 | 130 | 12 | 10 | 120 |",
1271/// "+----+----+-----+----+----+-----+"
1272/// ```
1273///
1274/// And the result of build and probe indices are:
1275/// ```text
1276/// Build indices: 4, 5, 6, 6
1277/// Probe indices: 3, 3, 4, 5
1278/// ```
1279#[allow(clippy::too_many_arguments)]
1280fn lookup_join_hashmap(
1281    build_hashmap: &JoinHashMap,
1282    build_side_values: &[ArrayRef],
1283    probe_side_values: &[ArrayRef],
1284    null_equals_null: bool,
1285    hashes_buffer: &[u64],
1286    limit: usize,
1287    offset: JoinHashMapOffset,
1288) -> Result<(UInt64Array, UInt32Array, Option<JoinHashMapOffset>)> {
1289    let (probe_indices, build_indices, next_offset) = build_hashmap
1290        .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset);
1291
1292    let build_indices: UInt64Array = build_indices.into();
1293    let probe_indices: UInt32Array = probe_indices.into();
1294
1295    let (build_indices, probe_indices) = equal_rows_arr(
1296        &build_indices,
1297        &probe_indices,
1298        build_side_values,
1299        probe_side_values,
1300        null_equals_null,
1301    )?;
1302
1303    Ok((build_indices, probe_indices, next_offset))
1304}
1305
1306// version of eq_dyn supporting equality on null arrays
1307fn eq_dyn_null(
1308    left: &dyn Array,
1309    right: &dyn Array,
1310    null_equals_null: bool,
1311) -> Result<BooleanArray, ArrowError> {
1312    // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
1313    // implementation
1314    // <https://github.com/apache/datafusion/issues/10749>
1315    if left.data_type().is_nested() {
1316        let op = if null_equals_null {
1317            Operator::IsNotDistinctFrom
1318        } else {
1319            Operator::Eq
1320        };
1321        return Ok(compare_op_for_nested(op, &left, &right)?);
1322    }
1323    match (left.data_type(), right.data_type()) {
1324        _ if null_equals_null => not_distinct(&left, &right),
1325        _ => eq(&left, &right),
1326    }
1327}
1328
1329pub fn equal_rows_arr(
1330    indices_left: &UInt64Array,
1331    indices_right: &UInt32Array,
1332    left_arrays: &[ArrayRef],
1333    right_arrays: &[ArrayRef],
1334    null_equals_null: bool,
1335) -> Result<(UInt64Array, UInt32Array)> {
1336    let mut iter = left_arrays.iter().zip(right_arrays.iter());
1337
1338    let (first_left, first_right) = iter.next().ok_or_else(|| {
1339        DataFusionError::Internal(
1340            "At least one array should be provided for both left and right".to_string(),
1341        )
1342    })?;
1343
1344    let arr_left = take(first_left.as_ref(), indices_left, None)?;
1345    let arr_right = take(first_right.as_ref(), indices_right, None)?;
1346
1347    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?;
1348
1349    // Use map and try_fold to iterate over the remaining pairs of arrays.
1350    // In each iteration, take is used on the pair of arrays and their equality is determined.
1351    // The results are then folded (combined) using the and function to get a final equality result.
1352    equal = iter
1353        .map(|(left, right)| {
1354            let arr_left = take(left.as_ref(), indices_left, None)?;
1355            let arr_right = take(right.as_ref(), indices_right, None)?;
1356            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null)
1357        })
1358        .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
1359
1360    let filter_builder = FilterBuilder::new(&equal).optimize().build();
1361
1362    let left_filtered = filter_builder.filter(indices_left)?;
1363    let right_filtered = filter_builder.filter(indices_right)?;
1364
1365    Ok((
1366        downcast_array(left_filtered.as_ref()),
1367        downcast_array(right_filtered.as_ref()),
1368    ))
1369}
1370
1371impl HashJoinStream {
1372    /// Separate implementation function that unpins the [`HashJoinStream`] so
1373    /// that partial borrows work correctly
1374    fn poll_next_impl(
1375        &mut self,
1376        cx: &mut std::task::Context<'_>,
1377    ) -> Poll<Option<Result<RecordBatch>>> {
1378        loop {
1379            return match self.state {
1380                HashJoinStreamState::WaitBuildSide => {
1381                    handle_state!(ready!(self.collect_build_side(cx)))
1382                }
1383                HashJoinStreamState::FetchProbeBatch => {
1384                    handle_state!(ready!(self.fetch_probe_batch(cx)))
1385                }
1386                HashJoinStreamState::ProcessProbeBatch(_) => {
1387                    handle_state!(self.process_probe_batch())
1388                }
1389                HashJoinStreamState::ExhaustedProbeSide => {
1390                    handle_state!(self.process_unmatched_build_batch())
1391                }
1392                HashJoinStreamState::Completed => Poll::Ready(None),
1393            };
1394        }
1395    }
1396
1397    /// Collects build-side data by polling `OnceFut` future from initialized build-side
1398    ///
1399    /// Updates build-side to `Ready`, and state to `FetchProbeSide`
1400    fn collect_build_side(
1401        &mut self,
1402        cx: &mut std::task::Context<'_>,
1403    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1404        let build_timer = self.join_metrics.build_time.timer();
1405        // build hash table from left (build) side, if not yet done
1406        let left_data = ready!(self
1407            .build_side
1408            .try_as_initial_mut()?
1409            .left_fut
1410            .get_shared(cx))?;
1411        build_timer.done();
1412
1413        self.state = HashJoinStreamState::FetchProbeBatch;
1414        self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
1415
1416        Poll::Ready(Ok(StatefulStreamResult::Continue))
1417    }
1418
1419    /// Fetches next batch from probe-side
1420    ///
1421    /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`,
1422    /// otherwise updates state to `ExhaustedProbeSide`
1423    fn fetch_probe_batch(
1424        &mut self,
1425        cx: &mut std::task::Context<'_>,
1426    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1427        match ready!(self.right.poll_next_unpin(cx)) {
1428            None => {
1429                self.state = HashJoinStreamState::ExhaustedProbeSide;
1430            }
1431            Some(Ok(batch)) => {
1432                // Precalculate hash values for fetched batch
1433                let keys_values = self
1434                    .on_right
1435                    .iter()
1436                    .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows()))
1437                    .collect::<Result<Vec<_>>>()?;
1438
1439                self.hashes_buffer.clear();
1440                self.hashes_buffer.resize(batch.num_rows(), 0);
1441                create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;
1442
1443                self.join_metrics.input_batches.add(1);
1444                self.join_metrics.input_rows.add(batch.num_rows());
1445
1446                self.state =
1447                    HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
1448                        batch,
1449                        values: keys_values,
1450                        offset: (0, None),
1451                        joined_probe_idx: None,
1452                    });
1453            }
1454            Some(Err(err)) => return Poll::Ready(Err(err)),
1455        };
1456
1457        Poll::Ready(Ok(StatefulStreamResult::Continue))
1458    }
1459
1460    /// Joins current probe batch with build-side data and produces batch with matched output
1461    ///
1462    /// Updates state to `FetchProbeBatch`
1463    fn process_probe_batch(
1464        &mut self,
1465    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1466        let state = self.state.try_as_process_probe_batch_mut()?;
1467        let build_side = self.build_side.try_as_ready_mut()?;
1468
1469        let timer = self.join_metrics.join_time.timer();
1470
1471        // get the matched by join keys indices
1472        let (left_indices, right_indices, next_offset) = lookup_join_hashmap(
1473            build_side.left_data.hash_map(),
1474            build_side.left_data.values(),
1475            &state.values,
1476            self.null_equals_null,
1477            &self.hashes_buffer,
1478            self.batch_size,
1479            state.offset,
1480        )?;
1481
1482        // apply join filter if exists
1483        let (left_indices, right_indices) = if let Some(filter) = &self.filter {
1484            apply_join_filter_to_indices(
1485                build_side.left_data.batch(),
1486                &state.batch,
1487                left_indices,
1488                right_indices,
1489                filter,
1490                JoinSide::Left,
1491            )?
1492        } else {
1493            (left_indices, right_indices)
1494        };
1495
1496        // mark joined left-side indices as visited, if required by join type
1497        if need_produce_result_in_final(self.join_type) {
1498            let mut bitmap = build_side.left_data.visited_indices_bitmap().lock();
1499            left_indices.iter().flatten().for_each(|x| {
1500                bitmap.set_bit(x as usize, true);
1501            });
1502        }
1503
1504        // The goals of index alignment for different join types are:
1505        //
1506        // 1) Right & FullJoin -- to append all missing probe-side indices between
1507        //    previous (excluding) and current joined indices.
1508        // 2) SemiJoin -- deduplicate probe indices in range between previous
1509        //    (excluding) and current joined indices.
1510        // 3) AntiJoin -- return only missing indices in range between
1511        //    previous and current joined indices.
1512        //    Inclusion/exclusion of the indices themselves don't matter
1513        //
1514        // As a summary -- alignment range can be produced based only on
1515        // joined (matched with filters applied) probe side indices, excluding starting one
1516        // (left from previous iteration).
1517
1518        // if any rows have been joined -- get last joined probe-side (right) row
1519        // it's important that index counts as "joined" after hash collisions checks
1520        // and join filters applied.
1521        let last_joined_right_idx = match right_indices.len() {
1522            0 => None,
1523            n => Some(right_indices.value(n - 1) as usize),
1524        };
1525
1526        // Calculate range and perform alignment.
1527        // In case probe batch has been processed -- align all remaining rows.
1528        let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1);
1529        let index_alignment_range_end = if next_offset.is_none() {
1530            state.batch.num_rows()
1531        } else {
1532            last_joined_right_idx.map_or(0, |v| v + 1)
1533        };
1534
1535        let (left_indices, right_indices) = adjust_indices_by_join_type(
1536            left_indices,
1537            right_indices,
1538            index_alignment_range_start..index_alignment_range_end,
1539            self.join_type,
1540            self.right_side_ordered,
1541        )?;
1542
1543        let result = build_batch_from_indices(
1544            &self.schema,
1545            build_side.left_data.batch(),
1546            &state.batch,
1547            &left_indices,
1548            &right_indices,
1549            &self.column_indices,
1550            JoinSide::Left,
1551        )?;
1552
1553        self.join_metrics.output_batches.add(1);
1554        self.join_metrics.output_rows.add(result.num_rows());
1555        timer.done();
1556
1557        if next_offset.is_none() {
1558            self.state = HashJoinStreamState::FetchProbeBatch;
1559        } else {
1560            state.advance(
1561                next_offset
1562                    .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?,
1563                last_joined_right_idx,
1564            )
1565        };
1566
1567        Ok(StatefulStreamResult::Ready(Some(result)))
1568    }
1569
1570    /// Processes unmatched build-side rows for certain join types and produces output batch
1571    ///
1572    /// Updates state to `Completed`
1573    fn process_unmatched_build_batch(
1574        &mut self,
1575    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1576        let timer = self.join_metrics.join_time.timer();
1577
1578        if !need_produce_result_in_final(self.join_type) {
1579            self.state = HashJoinStreamState::Completed;
1580            return Ok(StatefulStreamResult::Continue);
1581        }
1582
1583        let build_side = self.build_side.try_as_ready()?;
1584        if !build_side.left_data.report_probe_completed() {
1585            self.state = HashJoinStreamState::Completed;
1586            return Ok(StatefulStreamResult::Continue);
1587        }
1588
1589        // use the global left bitmap to produce the left indices and right indices
1590        let (left_side, right_side) = get_final_indices_from_shared_bitmap(
1591            build_side.left_data.visited_indices_bitmap(),
1592            self.join_type,
1593        );
1594        let empty_right_batch = RecordBatch::new_empty(self.right.schema());
1595        // use the left and right indices to produce the batch result
1596        let result = build_batch_from_indices(
1597            &self.schema,
1598            build_side.left_data.batch(),
1599            &empty_right_batch,
1600            &left_side,
1601            &right_side,
1602            &self.column_indices,
1603            JoinSide::Left,
1604        );
1605
1606        if let Ok(ref batch) = result {
1607            self.join_metrics.input_batches.add(1);
1608            self.join_metrics.input_rows.add(batch.num_rows());
1609
1610            self.join_metrics.output_batches.add(1);
1611            self.join_metrics.output_rows.add(batch.num_rows());
1612        }
1613        timer.done();
1614
1615        self.state = HashJoinStreamState::Completed;
1616
1617        Ok(StatefulStreamResult::Ready(Some(result?)))
1618    }
1619}
1620
1621impl Stream for HashJoinStream {
1622    type Item = Result<RecordBatch>;
1623
1624    fn poll_next(
1625        mut self: std::pin::Pin<&mut Self>,
1626        cx: &mut std::task::Context<'_>,
1627    ) -> Poll<Option<Self::Item>> {
1628        self.poll_next_impl(cx)
1629    }
1630}
1631
1632impl EmbeddedProjection for HashJoinExec {
1633    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
1634        self.with_projection(projection)
1635    }
1636}
1637
1638#[cfg(test)]
1639mod tests {
1640    use super::*;
1641    use crate::test::TestMemoryExec;
1642    use crate::{
1643        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
1644        test::exec::MockExec,
1645    };
1646
1647    use arrow::array::{Date32Array, Int32Array, StructArray};
1648    use arrow::buffer::NullBuffer;
1649    use arrow::datatypes::{DataType, Field};
1650    use datafusion_common::{
1651        assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
1652        ScalarValue,
1653    };
1654    use datafusion_execution::config::SessionConfig;
1655    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1656    use datafusion_expr::Operator;
1657    use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
1658    use datafusion_physical_expr::PhysicalExpr;
1659    use hashbrown::HashTable;
1660    use rstest::*;
1661    use rstest_reuse::*;
1662
1663    fn div_ceil(a: usize, b: usize) -> usize {
1664        a.div_ceil(b)
1665    }
1666
    // Reusable `rstest` template: a test annotated with `#[apply(batch_sizes)]`
    // is run once for each `batch_size` value listed in `#[values(...)]` below.
    #[template]
    #[rstest]
    fn batch_sizes(#[values(8192, 10, 5, 2, 1)] batch_size: usize) {}
1670
1671    fn prepare_task_ctx(batch_size: usize) -> Arc<TaskContext> {
1672        let session_config = SessionConfig::default().with_batch_size(batch_size);
1673        Arc::new(TaskContext::default().with_session_config(session_config))
1674    }
1675
1676    fn build_table(
1677        a: (&str, &Vec<i32>),
1678        b: (&str, &Vec<i32>),
1679        c: (&str, &Vec<i32>),
1680    ) -> Arc<dyn ExecutionPlan> {
1681        let batch = build_table_i32(a, b, c);
1682        let schema = batch.schema();
1683        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
1684    }
1685
1686    fn join(
1687        left: Arc<dyn ExecutionPlan>,
1688        right: Arc<dyn ExecutionPlan>,
1689        on: JoinOn,
1690        join_type: &JoinType,
1691        null_equals_null: bool,
1692    ) -> Result<HashJoinExec> {
1693        HashJoinExec::try_new(
1694            left,
1695            right,
1696            on,
1697            None,
1698            join_type,
1699            None,
1700            PartitionMode::CollectLeft,
1701            null_equals_null,
1702        )
1703    }
1704
1705    fn join_with_filter(
1706        left: Arc<dyn ExecutionPlan>,
1707        right: Arc<dyn ExecutionPlan>,
1708        on: JoinOn,
1709        filter: JoinFilter,
1710        join_type: &JoinType,
1711        null_equals_null: bool,
1712    ) -> Result<HashJoinExec> {
1713        HashJoinExec::try_new(
1714            left,
1715            right,
1716            on,
1717            Some(filter),
1718            join_type,
1719            None,
1720            PartitionMode::CollectLeft,
1721            null_equals_null,
1722        )
1723    }
1724
1725    async fn join_collect(
1726        left: Arc<dyn ExecutionPlan>,
1727        right: Arc<dyn ExecutionPlan>,
1728        on: JoinOn,
1729        join_type: &JoinType,
1730        null_equals_null: bool,
1731        context: Arc<TaskContext>,
1732    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
1733        let join = join(left, right, on, join_type, null_equals_null)?;
1734        let columns_header = columns(&join.schema());
1735
1736        let stream = join.execute(0, context)?;
1737        let batches = common::collect(stream).await?;
1738
1739        Ok((columns_header, batches))
1740    }
1741
1742    async fn partitioned_join_collect(
1743        left: Arc<dyn ExecutionPlan>,
1744        right: Arc<dyn ExecutionPlan>,
1745        on: JoinOn,
1746        join_type: &JoinType,
1747        null_equals_null: bool,
1748        context: Arc<TaskContext>,
1749    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
1750        join_collect_with_partition_mode(
1751            left,
1752            right,
1753            on,
1754            join_type,
1755            PartitionMode::Partitioned,
1756            null_equals_null,
1757            context,
1758        )
1759        .await
1760    }
1761
1762    async fn join_collect_with_partition_mode(
1763        left: Arc<dyn ExecutionPlan>,
1764        right: Arc<dyn ExecutionPlan>,
1765        on: JoinOn,
1766        join_type: &JoinType,
1767        partition_mode: PartitionMode,
1768        null_equals_null: bool,
1769        context: Arc<TaskContext>,
1770    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
1771        let partition_count = 4;
1772
1773        let (left_expr, right_expr) = on
1774            .iter()
1775            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
1776            .unzip();
1777
1778        let left_repartitioned: Arc<dyn ExecutionPlan> = match partition_mode {
1779            PartitionMode::CollectLeft => Arc::new(CoalescePartitionsExec::new(left)),
1780            PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new(
1781                left,
1782                Partitioning::Hash(left_expr, partition_count),
1783            )?),
1784            PartitionMode::Auto => {
1785                return internal_err!("Unexpected PartitionMode::Auto in join tests")
1786            }
1787        };
1788
1789        let right_repartitioned: Arc<dyn ExecutionPlan> = match partition_mode {
1790            PartitionMode::CollectLeft => {
1791                let partition_column_name = right.schema().field(0).name().clone();
1792                let partition_expr = vec![Arc::new(Column::new_with_schema(
1793                    &partition_column_name,
1794                    &right.schema(),
1795                )?) as _];
1796                Arc::new(RepartitionExec::try_new(
1797                    right,
1798                    Partitioning::Hash(partition_expr, partition_count),
1799                )?) as _
1800            }
1801            PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new(
1802                right,
1803                Partitioning::Hash(right_expr, partition_count),
1804            )?),
1805            PartitionMode::Auto => {
1806                return internal_err!("Unexpected PartitionMode::Auto in join tests")
1807            }
1808        };
1809
1810        let join = HashJoinExec::try_new(
1811            left_repartitioned,
1812            right_repartitioned,
1813            on,
1814            None,
1815            join_type,
1816            None,
1817            partition_mode,
1818            null_equals_null,
1819        )?;
1820
1821        let columns = columns(&join.schema());
1822
1823        let mut batches = vec![];
1824        for i in 0..partition_count {
1825            let stream = join.execute(i, Arc::clone(&context))?;
1826            let more_batches = common::collect(stream).await?;
1827            batches.extend(
1828                more_batches
1829                    .into_iter()
1830                    .filter(|b| b.num_rows() > 0)
1831                    .collect::<Vec<_>>(),
1832            );
1833        }
1834
1835        Ok((columns, batches))
1836    }
1837
1838    #[apply(batch_sizes)]
1839    #[tokio::test]
1840    async fn join_inner_one(batch_size: usize) -> Result<()> {
1841        let task_ctx = prepare_task_ctx(batch_size);
1842        let left = build_table(
1843            ("a1", &vec![1, 2, 3]),
1844            ("b1", &vec![4, 5, 5]), // this has a repetition
1845            ("c1", &vec![7, 8, 9]),
1846        );
1847        let right = build_table(
1848            ("a2", &vec![10, 20, 30]),
1849            ("b1", &vec![4, 5, 6]),
1850            ("c2", &vec![70, 80, 90]),
1851        );
1852
1853        let on = vec![(
1854            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1855            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1856        )];
1857
1858        let (columns, batches) = join_collect(
1859            Arc::clone(&left),
1860            Arc::clone(&right),
1861            on.clone(),
1862            &JoinType::Inner,
1863            false,
1864            task_ctx,
1865        )
1866        .await?;
1867
1868        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
1869
1870        let expected = [
1871            "+----+----+----+----+----+----+",
1872            "| a1 | b1 | c1 | a2 | b1 | c2 |",
1873            "+----+----+----+----+----+----+",
1874            "| 1  | 4  | 7  | 10 | 4  | 70 |",
1875            "| 2  | 5  | 8  | 20 | 5  | 80 |",
1876            "| 3  | 5  | 9  | 20 | 5  | 80 |",
1877            "+----+----+----+----+----+----+",
1878        ];
1879
1880        // Inner join output is expected to preserve both inputs order
1881        assert_batches_eq!(expected, &batches);
1882
1883        Ok(())
1884    }
1885
1886    #[apply(batch_sizes)]
1887    #[tokio::test]
1888    async fn partitioned_join_inner_one(batch_size: usize) -> Result<()> {
1889        let task_ctx = prepare_task_ctx(batch_size);
1890        let left = build_table(
1891            ("a1", &vec![1, 2, 3]),
1892            ("b1", &vec![4, 5, 5]), // this has a repetition
1893            ("c1", &vec![7, 8, 9]),
1894        );
1895        let right = build_table(
1896            ("a2", &vec![10, 20, 30]),
1897            ("b1", &vec![4, 5, 6]),
1898            ("c2", &vec![70, 80, 90]),
1899        );
1900        let on = vec![(
1901            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1902            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1903        )];
1904
1905        let (columns, batches) = partitioned_join_collect(
1906            Arc::clone(&left),
1907            Arc::clone(&right),
1908            on.clone(),
1909            &JoinType::Inner,
1910            false,
1911            task_ctx,
1912        )
1913        .await?;
1914
1915        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
1916
1917        let expected = [
1918            "+----+----+----+----+----+----+",
1919            "| a1 | b1 | c1 | a2 | b1 | c2 |",
1920            "+----+----+----+----+----+----+",
1921            "| 1  | 4  | 7  | 10 | 4  | 70 |",
1922            "| 2  | 5  | 8  | 20 | 5  | 80 |",
1923            "| 3  | 5  | 9  | 20 | 5  | 80 |",
1924            "+----+----+----+----+----+----+",
1925        ];
1926        assert_batches_sorted_eq!(expected, &batches);
1927
1928        Ok(())
1929    }
1930
1931    #[tokio::test]
1932    async fn join_inner_one_no_shared_column_names() -> Result<()> {
1933        let task_ctx = Arc::new(TaskContext::default());
1934        let left = build_table(
1935            ("a1", &vec![1, 2, 3]),
1936            ("b1", &vec![4, 5, 5]), // this has a repetition
1937            ("c1", &vec![7, 8, 9]),
1938        );
1939        let right = build_table(
1940            ("a2", &vec![10, 20, 30]),
1941            ("b2", &vec![4, 5, 6]),
1942            ("c2", &vec![70, 80, 90]),
1943        );
1944        let on = vec![(
1945            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1946            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1947        )];
1948
1949        let (columns, batches) =
1950            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
1951
1952        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
1953
1954        let expected = [
1955            "+----+----+----+----+----+----+",
1956            "| a1 | b1 | c1 | a2 | b2 | c2 |",
1957            "+----+----+----+----+----+----+",
1958            "| 1  | 4  | 7  | 10 | 4  | 70 |",
1959            "| 2  | 5  | 8  | 20 | 5  | 80 |",
1960            "| 3  | 5  | 9  | 20 | 5  | 80 |",
1961            "+----+----+----+----+----+----+",
1962        ];
1963
1964        // Inner join output is expected to preserve both inputs order
1965        assert_batches_eq!(expected, &batches);
1966
1967        Ok(())
1968    }
1969
1970    #[tokio::test]
1971    async fn join_inner_one_randomly_ordered() -> Result<()> {
1972        let task_ctx = Arc::new(TaskContext::default());
1973        let left = build_table(
1974            ("a1", &vec![0, 3, 2, 1]),
1975            ("b1", &vec![4, 5, 5, 4]),
1976            ("c1", &vec![6, 9, 8, 7]),
1977        );
1978        let right = build_table(
1979            ("a2", &vec![20, 30, 10]),
1980            ("b2", &vec![5, 6, 4]),
1981            ("c2", &vec![80, 90, 70]),
1982        );
1983        let on = vec![(
1984            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1985            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1986        )];
1987
1988        let (columns, batches) =
1989            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
1990
1991        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
1992
1993        let expected = [
1994            "+----+----+----+----+----+----+",
1995            "| a1 | b1 | c1 | a2 | b2 | c2 |",
1996            "+----+----+----+----+----+----+",
1997            "| 3  | 5  | 9  | 20 | 5  | 80 |",
1998            "| 2  | 5  | 8  | 20 | 5  | 80 |",
1999            "| 0  | 4  | 6  | 10 | 4  | 70 |",
2000            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2001            "+----+----+----+----+----+----+",
2002        ];
2003
2004        // Inner join output is expected to preserve both inputs order
2005        assert_batches_eq!(expected, &batches);
2006
2007        Ok(())
2008    }
2009
2010    #[apply(batch_sizes)]
2011    #[tokio::test]
2012    async fn join_inner_two(batch_size: usize) -> Result<()> {
2013        let task_ctx = prepare_task_ctx(batch_size);
2014        let left = build_table(
2015            ("a1", &vec![1, 2, 2]),
2016            ("b2", &vec![1, 2, 2]),
2017            ("c1", &vec![7, 8, 9]),
2018        );
2019        let right = build_table(
2020            ("a1", &vec![1, 2, 3]),
2021            ("b2", &vec![1, 2, 2]),
2022            ("c2", &vec![70, 80, 90]),
2023        );
2024        let on = vec![
2025            (
2026                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2027                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2028            ),
2029            (
2030                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2031                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2032            ),
2033        ];
2034
2035        let (columns, batches) =
2036            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
2037
2038        assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
2039
2040        let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) {
2041            // Expected number of hash table matches = 3
2042            // in case batch_size is 1 - additional empty batch for remaining 3-2 row
2043            let mut expected_batch_count = div_ceil(3, batch_size);
2044            if batch_size == 1 {
2045                expected_batch_count += 1;
2046            }
2047            expected_batch_count
2048        } else {
2049            // With hash collisions enabled, all records will match each other
2050            // and filtered later.
2051            div_ceil(9, batch_size)
2052        };
2053
2054        assert_eq!(batches.len(), expected_batch_count);
2055
2056        let expected = [
2057            "+----+----+----+----+----+----+",
2058            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2059            "+----+----+----+----+----+----+",
2060            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2061            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2062            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2063            "+----+----+----+----+----+----+",
2064        ];
2065
2066        // Inner join output is expected to preserve both inputs order
2067        assert_batches_eq!(expected, &batches);
2068
2069        Ok(())
2070    }
2071
2072    /// Test where the left has 2 parts, the right with 1 part => 1 part
2073    #[apply(batch_sizes)]
2074    #[tokio::test]
2075    async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> {
2076        let task_ctx = prepare_task_ctx(batch_size);
2077        let batch1 = build_table_i32(
2078            ("a1", &vec![1, 2]),
2079            ("b2", &vec![1, 2]),
2080            ("c1", &vec![7, 8]),
2081        );
2082        let batch2 =
2083            build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9]));
2084        let schema = batch1.schema();
2085        let left =
2086            TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
2087                .unwrap();
2088
2089        let right = build_table(
2090            ("a1", &vec![1, 2, 3]),
2091            ("b2", &vec![1, 2, 2]),
2092            ("c2", &vec![70, 80, 90]),
2093        );
2094        let on = vec![
2095            (
2096                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2097                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2098            ),
2099            (
2100                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
2101                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2102            ),
2103        ];
2104
2105        let (columns, batches) =
2106            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
2107
2108        assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
2109
2110        let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) {
2111            // Expected number of hash table matches = 3
2112            // in case batch_size is 1 - additional empty batch for remaining 3-2 row
2113            let mut expected_batch_count = div_ceil(3, batch_size);
2114            if batch_size == 1 {
2115                expected_batch_count += 1;
2116            }
2117            expected_batch_count
2118        } else {
2119            // With hash collisions enabled, all records will match each other
2120            // and filtered later.
2121            div_ceil(9, batch_size)
2122        };
2123
2124        assert_eq!(batches.len(), expected_batch_count);
2125
2126        let expected = [
2127            "+----+----+----+----+----+----+",
2128            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2129            "+----+----+----+----+----+----+",
2130            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2131            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2132            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2133            "+----+----+----+----+----+----+",
2134        ];
2135
2136        // Inner join output is expected to preserve both inputs order
2137        assert_batches_eq!(expected, &batches);
2138
2139        Ok(())
2140    }
2141
2142    #[tokio::test]
2143    async fn join_inner_one_two_parts_left_randomly_ordered() -> Result<()> {
2144        let task_ctx = Arc::new(TaskContext::default());
2145        let batch1 = build_table_i32(
2146            ("a1", &vec![0, 3]),
2147            ("b1", &vec![4, 5]),
2148            ("c1", &vec![6, 9]),
2149        );
2150        let batch2 = build_table_i32(
2151            ("a1", &vec![2, 1]),
2152            ("b1", &vec![5, 4]),
2153            ("c1", &vec![8, 7]),
2154        );
2155        let schema = batch1.schema();
2156
2157        let left =
2158            TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
2159                .unwrap();
2160        let right = build_table(
2161            ("a2", &vec![20, 30, 10]),
2162            ("b2", &vec![5, 6, 4]),
2163            ("c2", &vec![80, 90, 70]),
2164        );
2165        let on = vec![(
2166            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2167            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2168        )];
2169
2170        let (columns, batches) =
2171            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
2172
2173        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2174
2175        let expected = [
2176            "+----+----+----+----+----+----+",
2177            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2178            "+----+----+----+----+----+----+",
2179            "| 3  | 5  | 9  | 20 | 5  | 80 |",
2180            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2181            "| 0  | 4  | 6  | 10 | 4  | 70 |",
2182            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2183            "+----+----+----+----+----+----+",
2184        ];
2185
2186        // Inner join output is expected to preserve both inputs order
2187        assert_batches_eq!(expected, &batches);
2188
2189        Ok(())
2190    }
2191
2192    /// Test where the left has 1 part, the right has 2 parts => 2 parts
2193    #[apply(batch_sizes)]
2194    #[tokio::test]
2195    async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> {
2196        let task_ctx = prepare_task_ctx(batch_size);
2197        let left = build_table(
2198            ("a1", &vec![1, 2, 3]),
2199            ("b1", &vec![4, 5, 5]), // this has a repetition
2200            ("c1", &vec![7, 8, 9]),
2201        );
2202
2203        let batch1 = build_table_i32(
2204            ("a2", &vec![10, 20]),
2205            ("b1", &vec![4, 6]),
2206            ("c2", &vec![70, 80]),
2207        );
2208        let batch2 =
2209            build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90]));
2210        let schema = batch1.schema();
2211        let right =
2212            TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
2213                .unwrap();
2214
2215        let on = vec![(
2216            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2217            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
2218        )];
2219
2220        let join = join(left, right, on, &JoinType::Inner, false)?;
2221
2222        let columns = columns(&join.schema());
2223        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
2224
2225        // first part
2226        let stream = join.execute(0, Arc::clone(&task_ctx))?;
2227        let batches = common::collect(stream).await?;
2228
2229        let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) {
2230            // Expected number of hash table matches for first right batch = 1
2231            // and additional empty batch for non-joined 20-6-80
2232            let mut expected_batch_count = div_ceil(1, batch_size);
2233            if batch_size == 1 {
2234                expected_batch_count += 1;
2235            }
2236            expected_batch_count
2237        } else {
2238            // With hash collisions enabled, all records will match each other
2239            // and filtered later.
2240            div_ceil(6, batch_size)
2241        };
2242        assert_eq!(batches.len(), expected_batch_count);
2243
2244        let expected = [
2245            "+----+----+----+----+----+----+",
2246            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2247            "+----+----+----+----+----+----+",
2248            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2249            "+----+----+----+----+----+----+",
2250        ];
2251
2252        // Inner join output is expected to preserve both inputs order
2253        assert_batches_eq!(expected, &batches);
2254
2255        // second part
2256        let stream = join.execute(1, Arc::clone(&task_ctx))?;
2257        let batches = common::collect(stream).await?;
2258
2259        let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) {
2260            // Expected number of hash table matches for second right batch = 2
2261            div_ceil(2, batch_size)
2262        } else {
2263            // With hash collisions enabled, all records will match each other
2264            // and filtered later.
2265            div_ceil(3, batch_size)
2266        };
2267        assert_eq!(batches.len(), expected_batch_count);
2268
2269        let expected = [
2270            "+----+----+----+----+----+----+",
2271            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2272            "+----+----+----+----+----+----+",
2273            "| 2  | 5  | 8  | 30 | 5  | 90 |",
2274            "| 3  | 5  | 9  | 30 | 5  | 90 |",
2275            "+----+----+----+----+----+----+",
2276        ];
2277
2278        // Inner join output is expected to preserve both inputs order
2279        assert_batches_eq!(expected, &batches);
2280
2281        Ok(())
2282    }
2283
2284    fn build_table_two_batches(
2285        a: (&str, &Vec<i32>),
2286        b: (&str, &Vec<i32>),
2287        c: (&str, &Vec<i32>),
2288    ) -> Arc<dyn ExecutionPlan> {
2289        let batch = build_table_i32(a, b, c);
2290        let schema = batch.schema();
2291        TestMemoryExec::try_new_exec(&[vec![batch.clone(), batch]], schema, None).unwrap()
2292    }
2293
2294    #[apply(batch_sizes)]
2295    #[tokio::test]
2296    async fn join_left_multi_batch(batch_size: usize) {
2297        let task_ctx = prepare_task_ctx(batch_size);
2298        let left = build_table(
2299            ("a1", &vec![1, 2, 3]),
2300            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
2301            ("c1", &vec![7, 8, 9]),
2302        );
2303        let right = build_table_two_batches(
2304            ("a2", &vec![10, 20, 30]),
2305            ("b1", &vec![4, 5, 6]),
2306            ("c2", &vec![70, 80, 90]),
2307        );
2308        let on = vec![(
2309            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
2310            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _,
2311        )];
2312
2313        let join = join(left, right, on, &JoinType::Left, false).unwrap();
2314
2315        let columns = columns(&join.schema());
2316        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
2317
2318        let stream = join.execute(0, task_ctx).unwrap();
2319        let batches = common::collect(stream).await.unwrap();
2320
2321        let expected = [
2322            "+----+----+----+----+----+----+",
2323            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2324            "+----+----+----+----+----+----+",
2325            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2326            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2327            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2328            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2329            "| 3  | 7  | 9  |    |    |    |",
2330            "+----+----+----+----+----+----+",
2331        ];
2332
2333        assert_batches_sorted_eq!(expected, &batches);
2334    }
2335
2336    #[apply(batch_sizes)]
2337    #[tokio::test]
2338    async fn join_full_multi_batch(batch_size: usize) {
2339        let task_ctx = prepare_task_ctx(batch_size);
2340        let left = build_table(
2341            ("a1", &vec![1, 2, 3]),
2342            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
2343            ("c1", &vec![7, 8, 9]),
2344        );
2345        // create two identical batches for the right side
2346        let right = build_table_two_batches(
2347            ("a2", &vec![10, 20, 30]),
2348            ("b2", &vec![4, 5, 6]),
2349            ("c2", &vec![70, 80, 90]),
2350        );
2351        let on = vec![(
2352            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
2353            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
2354        )];
2355
2356        let join = join(left, right, on, &JoinType::Full, false).unwrap();
2357
2358        let columns = columns(&join.schema());
2359        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2360
2361        let stream = join.execute(0, task_ctx).unwrap();
2362        let batches = common::collect(stream).await.unwrap();
2363
2364        let expected = [
2365            "+----+----+----+----+----+----+",
2366            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2367            "+----+----+----+----+----+----+",
2368            "|    |    |    | 30 | 6  | 90 |",
2369            "|    |    |    | 30 | 6  | 90 |",
2370            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2371            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2372            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2373            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2374            "| 3  | 7  | 9  |    |    |    |",
2375            "+----+----+----+----+----+----+",
2376        ];
2377
2378        assert_batches_sorted_eq!(expected, &batches);
2379    }
2380
2381    #[apply(batch_sizes)]
2382    #[tokio::test]
2383    async fn join_left_empty_right(batch_size: usize) {
2384        let task_ctx = prepare_task_ctx(batch_size);
2385        let left = build_table(
2386            ("a1", &vec![1, 2, 3]),
2387            ("b1", &vec![4, 5, 7]),
2388            ("c1", &vec![7, 8, 9]),
2389        );
2390        let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![]));
2391        let on = vec![(
2392            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
2393            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _,
2394        )];
2395        let schema = right.schema();
2396        let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap();
2397        let join = join(left, right, on, &JoinType::Left, false).unwrap();
2398
2399        let columns = columns(&join.schema());
2400        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
2401
2402        let stream = join.execute(0, task_ctx).unwrap();
2403        let batches = common::collect(stream).await.unwrap();
2404
2405        let expected = [
2406            "+----+----+----+----+----+----+",
2407            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2408            "+----+----+----+----+----+----+",
2409            "| 1  | 4  | 7  |    |    |    |",
2410            "| 2  | 5  | 8  |    |    |    |",
2411            "| 3  | 7  | 9  |    |    |    |",
2412            "+----+----+----+----+----+----+",
2413        ];
2414
2415        assert_batches_sorted_eq!(expected, &batches);
2416    }
2417
2418    #[apply(batch_sizes)]
2419    #[tokio::test]
2420    async fn join_full_empty_right(batch_size: usize) {
2421        let task_ctx = prepare_task_ctx(batch_size);
2422        let left = build_table(
2423            ("a1", &vec![1, 2, 3]),
2424            ("b1", &vec![4, 5, 7]),
2425            ("c1", &vec![7, 8, 9]),
2426        );
2427        let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![]));
2428        let on = vec![(
2429            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
2430            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
2431        )];
2432        let schema = right.schema();
2433        let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap();
2434        let join = join(left, right, on, &JoinType::Full, false).unwrap();
2435
2436        let columns = columns(&join.schema());
2437        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2438
2439        let stream = join.execute(0, task_ctx).unwrap();
2440        let batches = common::collect(stream).await.unwrap();
2441
2442        let expected = [
2443            "+----+----+----+----+----+----+",
2444            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2445            "+----+----+----+----+----+----+",
2446            "| 1  | 4  | 7  |    |    |    |",
2447            "| 2  | 5  | 8  |    |    |    |",
2448            "| 3  | 7  | 9  |    |    |    |",
2449            "+----+----+----+----+----+----+",
2450        ];
2451
2452        assert_batches_sorted_eq!(expected, &batches);
2453    }
2454
2455    #[apply(batch_sizes)]
2456    #[tokio::test]
2457    async fn join_left_one(batch_size: usize) -> Result<()> {
2458        let task_ctx = prepare_task_ctx(batch_size);
2459        let left = build_table(
2460            ("a1", &vec![1, 2, 3]),
2461            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
2462            ("c1", &vec![7, 8, 9]),
2463        );
2464        let right = build_table(
2465            ("a2", &vec![10, 20, 30]),
2466            ("b1", &vec![4, 5, 6]),
2467            ("c2", &vec![70, 80, 90]),
2468        );
2469        let on = vec![(
2470            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2471            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
2472        )];
2473
2474        let (columns, batches) = join_collect(
2475            Arc::clone(&left),
2476            Arc::clone(&right),
2477            on.clone(),
2478            &JoinType::Left,
2479            false,
2480            task_ctx,
2481        )
2482        .await?;
2483        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
2484
2485        let expected = [
2486            "+----+----+----+----+----+----+",
2487            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2488            "+----+----+----+----+----+----+",
2489            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2490            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2491            "| 3  | 7  | 9  |    |    |    |",
2492            "+----+----+----+----+----+----+",
2493        ];
2494        assert_batches_sorted_eq!(expected, &batches);
2495
2496        Ok(())
2497    }
2498
2499    #[apply(batch_sizes)]
2500    #[tokio::test]
2501    async fn partitioned_join_left_one(batch_size: usize) -> Result<()> {
2502        let task_ctx = prepare_task_ctx(batch_size);
2503        let left = build_table(
2504            ("a1", &vec![1, 2, 3]),
2505            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
2506            ("c1", &vec![7, 8, 9]),
2507        );
2508        let right = build_table(
2509            ("a2", &vec![10, 20, 30]),
2510            ("b1", &vec![4, 5, 6]),
2511            ("c2", &vec![70, 80, 90]),
2512        );
2513        let on = vec![(
2514            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2515            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
2516        )];
2517
2518        let (columns, batches) = partitioned_join_collect(
2519            Arc::clone(&left),
2520            Arc::clone(&right),
2521            on.clone(),
2522            &JoinType::Left,
2523            false,
2524            task_ctx,
2525        )
2526        .await?;
2527        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
2528
2529        let expected = [
2530            "+----+----+----+----+----+----+",
2531            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2532            "+----+----+----+----+----+----+",
2533            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2534            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2535            "| 3  | 7  | 9  |    |    |    |",
2536            "+----+----+----+----+----+----+",
2537        ];
2538        assert_batches_sorted_eq!(expected, &batches);
2539
2540        Ok(())
2541    }
2542
2543    fn build_semi_anti_left_table() -> Arc<dyn ExecutionPlan> {
2544        // just two line match
2545        // b1 = 10
2546        build_table(
2547            ("a1", &vec![1, 3, 5, 7, 9, 11, 13]),
2548            ("b1", &vec![1, 3, 5, 7, 8, 8, 10]),
2549            ("c1", &vec![10, 30, 50, 70, 90, 110, 130]),
2550        )
2551    }
2552
2553    fn build_semi_anti_right_table() -> Arc<dyn ExecutionPlan> {
2554        // just two line match
2555        // b2 = 10
2556        build_table(
2557            ("a2", &vec![8, 12, 6, 2, 10, 4]),
2558            ("b2", &vec![8, 10, 6, 2, 10, 4]),
2559            ("c2", &vec![20, 40, 60, 80, 100, 120]),
2560        )
2561    }
2562
2563    #[apply(batch_sizes)]
2564    #[tokio::test]
2565    async fn join_left_semi(batch_size: usize) -> Result<()> {
2566        let task_ctx = prepare_task_ctx(batch_size);
2567        let left = build_semi_anti_left_table();
2568        let right = build_semi_anti_right_table();
2569        // left_table left semi join right_table on left_table.b1 = right_table.b2
2570        let on = vec![(
2571            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2572            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2573        )];
2574
2575        let join = join(left, right, on, &JoinType::LeftSemi, false)?;
2576
2577        let columns = columns(&join.schema());
2578        assert_eq!(columns, vec!["a1", "b1", "c1"]);
2579
2580        let stream = join.execute(0, task_ctx)?;
2581        let batches = common::collect(stream).await?;
2582
2583        // ignore the order
2584        let expected = [
2585            "+----+----+-----+",
2586            "| a1 | b1 | c1  |",
2587            "+----+----+-----+",
2588            "| 11 | 8  | 110 |",
2589            "| 13 | 10 | 130 |",
2590            "| 9  | 8  | 90  |",
2591            "+----+----+-----+",
2592        ];
2593        assert_batches_sorted_eq!(expected, &batches);
2594
2595        Ok(())
2596    }
2597
2598    #[apply(batch_sizes)]
2599    #[tokio::test]
2600    async fn join_left_semi_with_filter(batch_size: usize) -> Result<()> {
2601        let task_ctx = prepare_task_ctx(batch_size);
2602        let left = build_semi_anti_left_table();
2603        let right = build_semi_anti_right_table();
2604
2605        // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 10
2606        let on = vec![(
2607            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2608            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2609        )];
2610
2611        let column_indices = vec![ColumnIndex {
2612            index: 0,
2613            side: JoinSide::Right,
2614        }];
2615        let intermediate_schema =
2616            Schema::new(vec![Field::new("x", DataType::Int32, true)]);
2617
2618        let filter_expression = Arc::new(BinaryExpr::new(
2619            Arc::new(Column::new("x", 0)),
2620            Operator::NotEq,
2621            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2622        )) as Arc<dyn PhysicalExpr>;
2623
2624        let filter = JoinFilter::new(
2625            filter_expression,
2626            column_indices.clone(),
2627            Arc::new(intermediate_schema.clone()),
2628        );
2629
2630        let join = join_with_filter(
2631            Arc::clone(&left),
2632            Arc::clone(&right),
2633            on.clone(),
2634            filter,
2635            &JoinType::LeftSemi,
2636            false,
2637        )?;
2638
2639        let columns_header = columns(&join.schema());
2640        assert_eq!(columns_header.clone(), vec!["a1", "b1", "c1"]);
2641
2642        let stream = join.execute(0, Arc::clone(&task_ctx))?;
2643        let batches = common::collect(stream).await?;
2644
2645        let expected = [
2646            "+----+----+-----+",
2647            "| a1 | b1 | c1  |",
2648            "+----+----+-----+",
2649            "| 11 | 8  | 110 |",
2650            "| 13 | 10 | 130 |",
2651            "| 9  | 8  | 90  |",
2652            "+----+----+-----+",
2653        ];
2654        assert_batches_sorted_eq!(expected, &batches);
2655
2656        // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 > 10
2657        let filter_expression = Arc::new(BinaryExpr::new(
2658            Arc::new(Column::new("x", 0)),
2659            Operator::Gt,
2660            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2661        )) as Arc<dyn PhysicalExpr>;
2662        let filter = JoinFilter::new(
2663            filter_expression,
2664            column_indices,
2665            Arc::new(intermediate_schema),
2666        );
2667
2668        let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?;
2669
2670        let columns_header = columns(&join.schema());
2671        assert_eq!(columns_header, vec!["a1", "b1", "c1"]);
2672
2673        let stream = join.execute(0, task_ctx)?;
2674        let batches = common::collect(stream).await?;
2675
2676        let expected = [
2677            "+----+----+-----+",
2678            "| a1 | b1 | c1  |",
2679            "+----+----+-----+",
2680            "| 13 | 10 | 130 |",
2681            "+----+----+-----+",
2682        ];
2683        assert_batches_sorted_eq!(expected, &batches);
2684
2685        Ok(())
2686    }
2687
2688    #[apply(batch_sizes)]
2689    #[tokio::test]
2690    async fn join_right_semi(batch_size: usize) -> Result<()> {
2691        let task_ctx = prepare_task_ctx(batch_size);
2692        let left = build_semi_anti_left_table();
2693        let right = build_semi_anti_right_table();
2694
2695        // left_table right semi join right_table on left_table.b1 = right_table.b2
2696        let on = vec![(
2697            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2698            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2699        )];
2700
2701        let join = join(left, right, on, &JoinType::RightSemi, false)?;
2702
2703        let columns = columns(&join.schema());
2704        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2705
2706        let stream = join.execute(0, task_ctx)?;
2707        let batches = common::collect(stream).await?;
2708
2709        let expected = [
2710            "+----+----+-----+",
2711            "| a2 | b2 | c2  |",
2712            "+----+----+-----+",
2713            "| 8  | 8  | 20  |",
2714            "| 12 | 10 | 40  |",
2715            "| 10 | 10 | 100 |",
2716            "+----+----+-----+",
2717        ];
2718
2719        // RightSemi join output is expected to preserve right input order
2720        assert_batches_eq!(expected, &batches);
2721
2722        Ok(())
2723    }
2724
2725    #[apply(batch_sizes)]
2726    #[tokio::test]
2727    async fn join_right_semi_with_filter(batch_size: usize) -> Result<()> {
2728        let task_ctx = prepare_task_ctx(batch_size);
2729        let left = build_semi_anti_left_table();
2730        let right = build_semi_anti_right_table();
2731
2732        // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9
2733        let on = vec![(
2734            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2735            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2736        )];
2737
2738        let column_indices = vec![ColumnIndex {
2739            index: 0,
2740            side: JoinSide::Left,
2741        }];
2742        let intermediate_schema =
2743            Schema::new(vec![Field::new("x", DataType::Int32, true)]);
2744
2745        let filter_expression = Arc::new(BinaryExpr::new(
2746            Arc::new(Column::new("x", 0)),
2747            Operator::NotEq,
2748            Arc::new(Literal::new(ScalarValue::Int32(Some(9)))),
2749        )) as Arc<dyn PhysicalExpr>;
2750
2751        let filter = JoinFilter::new(
2752            filter_expression,
2753            column_indices.clone(),
2754            Arc::new(intermediate_schema.clone()),
2755        );
2756
2757        let join = join_with_filter(
2758            Arc::clone(&left),
2759            Arc::clone(&right),
2760            on.clone(),
2761            filter,
2762            &JoinType::RightSemi,
2763            false,
2764        )?;
2765
2766        let columns = columns(&join.schema());
2767        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2768
2769        let stream = join.execute(0, Arc::clone(&task_ctx))?;
2770        let batches = common::collect(stream).await?;
2771
2772        let expected = [
2773            "+----+----+-----+",
2774            "| a2 | b2 | c2  |",
2775            "+----+----+-----+",
2776            "| 8  | 8  | 20  |",
2777            "| 12 | 10 | 40  |",
2778            "| 10 | 10 | 100 |",
2779            "+----+----+-----+",
2780        ];
2781
2782        // RightSemi join output is expected to preserve right input order
2783        assert_batches_eq!(expected, &batches);
2784
2785        // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9
2786        let filter_expression = Arc::new(BinaryExpr::new(
2787            Arc::new(Column::new("x", 0)),
2788            Operator::Gt,
2789            Arc::new(Literal::new(ScalarValue::Int32(Some(11)))),
2790        )) as Arc<dyn PhysicalExpr>;
2791
2792        let filter = JoinFilter::new(
2793            filter_expression,
2794            column_indices,
2795            Arc::new(intermediate_schema.clone()),
2796        );
2797
2798        let join =
2799            join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?;
2800        let stream = join.execute(0, task_ctx)?;
2801        let batches = common::collect(stream).await?;
2802
2803        let expected = [
2804            "+----+----+-----+",
2805            "| a2 | b2 | c2  |",
2806            "+----+----+-----+",
2807            "| 12 | 10 | 40  |",
2808            "| 10 | 10 | 100 |",
2809            "+----+----+-----+",
2810        ];
2811
2812        // RightSemi join output is expected to preserve right input order
2813        assert_batches_eq!(expected, &batches);
2814
2815        Ok(())
2816    }
2817
2818    #[apply(batch_sizes)]
2819    #[tokio::test]
2820    async fn join_left_anti(batch_size: usize) -> Result<()> {
2821        let task_ctx = prepare_task_ctx(batch_size);
2822        let left = build_semi_anti_left_table();
2823        let right = build_semi_anti_right_table();
2824        // left_table left anti join right_table on left_table.b1 = right_table.b2
2825        let on = vec![(
2826            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2827            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2828        )];
2829
2830        let join = join(left, right, on, &JoinType::LeftAnti, false)?;
2831
2832        let columns = columns(&join.schema());
2833        assert_eq!(columns, vec!["a1", "b1", "c1"]);
2834
2835        let stream = join.execute(0, task_ctx)?;
2836        let batches = common::collect(stream).await?;
2837
2838        let expected = [
2839            "+----+----+----+",
2840            "| a1 | b1 | c1 |",
2841            "+----+----+----+",
2842            "| 1  | 1  | 10 |",
2843            "| 3  | 3  | 30 |",
2844            "| 5  | 5  | 50 |",
2845            "| 7  | 7  | 70 |",
2846            "+----+----+----+",
2847        ];
2848        assert_batches_sorted_eq!(expected, &batches);
2849        Ok(())
2850    }
2851
2852    #[apply(batch_sizes)]
2853    #[tokio::test]
2854    async fn join_left_anti_with_filter(batch_size: usize) -> Result<()> {
2855        let task_ctx = prepare_task_ctx(batch_size);
2856        let left = build_semi_anti_left_table();
2857        let right = build_semi_anti_right_table();
2858        // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8
2859        let on = vec![(
2860            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2861            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2862        )];
2863
2864        let column_indices = vec![ColumnIndex {
2865            index: 0,
2866            side: JoinSide::Right,
2867        }];
2868        let intermediate_schema =
2869            Schema::new(vec![Field::new("x", DataType::Int32, true)]);
2870        let filter_expression = Arc::new(BinaryExpr::new(
2871            Arc::new(Column::new("x", 0)),
2872            Operator::NotEq,
2873            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
2874        )) as Arc<dyn PhysicalExpr>;
2875
2876        let filter = JoinFilter::new(
2877            filter_expression,
2878            column_indices.clone(),
2879            Arc::new(intermediate_schema.clone()),
2880        );
2881
2882        let join = join_with_filter(
2883            Arc::clone(&left),
2884            Arc::clone(&right),
2885            on.clone(),
2886            filter,
2887            &JoinType::LeftAnti,
2888            false,
2889        )?;
2890
2891        let columns_header = columns(&join.schema());
2892        assert_eq!(columns_header, vec!["a1", "b1", "c1"]);
2893
2894        let stream = join.execute(0, Arc::clone(&task_ctx))?;
2895        let batches = common::collect(stream).await?;
2896
2897        let expected = [
2898            "+----+----+-----+",
2899            "| a1 | b1 | c1  |",
2900            "+----+----+-----+",
2901            "| 1  | 1  | 10  |",
2902            "| 11 | 8  | 110 |",
2903            "| 3  | 3  | 30  |",
2904            "| 5  | 5  | 50  |",
2905            "| 7  | 7  | 70  |",
2906            "| 9  | 8  | 90  |",
2907            "+----+----+-----+",
2908        ];
2909        assert_batches_sorted_eq!(expected, &batches);
2910
2911        // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 13
2912        let filter_expression = Arc::new(BinaryExpr::new(
2913            Arc::new(Column::new("x", 0)),
2914            Operator::NotEq,
2915            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
2916        )) as Arc<dyn PhysicalExpr>;
2917
2918        let filter = JoinFilter::new(
2919            filter_expression,
2920            column_indices,
2921            Arc::new(intermediate_schema),
2922        );
2923
2924        let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?;
2925
2926        let columns_header = columns(&join.schema());
2927        assert_eq!(columns_header, vec!["a1", "b1", "c1"]);
2928
2929        let stream = join.execute(0, task_ctx)?;
2930        let batches = common::collect(stream).await?;
2931
2932        let expected = [
2933            "+----+----+-----+",
2934            "| a1 | b1 | c1  |",
2935            "+----+----+-----+",
2936            "| 1  | 1  | 10  |",
2937            "| 11 | 8  | 110 |",
2938            "| 3  | 3  | 30  |",
2939            "| 5  | 5  | 50  |",
2940            "| 7  | 7  | 70  |",
2941            "| 9  | 8  | 90  |",
2942            "+----+----+-----+",
2943        ];
2944        assert_batches_sorted_eq!(expected, &batches);
2945
2946        Ok(())
2947    }
2948
2949    #[apply(batch_sizes)]
2950    #[tokio::test]
2951    async fn join_right_anti(batch_size: usize) -> Result<()> {
2952        let task_ctx = prepare_task_ctx(batch_size);
2953        let left = build_semi_anti_left_table();
2954        let right = build_semi_anti_right_table();
2955        let on = vec![(
2956            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2957            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2958        )];
2959
2960        let join = join(left, right, on, &JoinType::RightAnti, false)?;
2961
2962        let columns = columns(&join.schema());
2963        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2964
2965        let stream = join.execute(0, task_ctx)?;
2966        let batches = common::collect(stream).await?;
2967
2968        let expected = [
2969            "+----+----+-----+",
2970            "| a2 | b2 | c2  |",
2971            "+----+----+-----+",
2972            "| 6  | 6  | 60  |",
2973            "| 2  | 2  | 80  |",
2974            "| 4  | 4  | 120 |",
2975            "+----+----+-----+",
2976        ];
2977
2978        // RightAnti join output is expected to preserve right input order
2979        assert_batches_eq!(expected, &batches);
2980        Ok(())
2981    }
2982
2983    #[apply(batch_sizes)]
2984    #[tokio::test]
2985    async fn join_right_anti_with_filter(batch_size: usize) -> Result<()> {
2986        let task_ctx = prepare_task_ctx(batch_size);
2987        let left = build_semi_anti_left_table();
2988        let right = build_semi_anti_right_table();
2989        // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13
2990        let on = vec![(
2991            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2992            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2993        )];
2994
2995        let column_indices = vec![ColumnIndex {
2996            index: 0,
2997            side: JoinSide::Left,
2998        }];
2999        let intermediate_schema =
3000            Schema::new(vec![Field::new("x", DataType::Int32, true)]);
3001
3002        let filter_expression = Arc::new(BinaryExpr::new(
3003            Arc::new(Column::new("x", 0)),
3004            Operator::NotEq,
3005            Arc::new(Literal::new(ScalarValue::Int32(Some(13)))),
3006        )) as Arc<dyn PhysicalExpr>;
3007
3008        let filter = JoinFilter::new(
3009            filter_expression,
3010            column_indices,
3011            Arc::new(intermediate_schema.clone()),
3012        );
3013
3014        let join = join_with_filter(
3015            Arc::clone(&left),
3016            Arc::clone(&right),
3017            on.clone(),
3018            filter,
3019            &JoinType::RightAnti,
3020            false,
3021        )?;
3022
3023        let columns_header = columns(&join.schema());
3024        assert_eq!(columns_header, vec!["a2", "b2", "c2"]);
3025
3026        let stream = join.execute(0, Arc::clone(&task_ctx))?;
3027        let batches = common::collect(stream).await?;
3028
3029        let expected = [
3030            "+----+----+-----+",
3031            "| a2 | b2 | c2  |",
3032            "+----+----+-----+",
3033            "| 12 | 10 | 40  |",
3034            "| 6  | 6  | 60  |",
3035            "| 2  | 2  | 80  |",
3036            "| 10 | 10 | 100 |",
3037            "| 4  | 4  | 120 |",
3038            "+----+----+-----+",
3039        ];
3040
3041        // RightAnti join output is expected to preserve right input order
3042        assert_batches_eq!(expected, &batches);
3043
3044        // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8
3045        let column_indices = vec![ColumnIndex {
3046            index: 1,
3047            side: JoinSide::Right,
3048        }];
3049        let filter_expression = Arc::new(BinaryExpr::new(
3050            Arc::new(Column::new("x", 0)),
3051            Operator::NotEq,
3052            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
3053        )) as Arc<dyn PhysicalExpr>;
3054
3055        let filter = JoinFilter::new(
3056            filter_expression,
3057            column_indices,
3058            Arc::new(intermediate_schema),
3059        );
3060
3061        let join =
3062            join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?;
3063
3064        let columns_header = columns(&join.schema());
3065        assert_eq!(columns_header, vec!["a2", "b2", "c2"]);
3066
3067        let stream = join.execute(0, task_ctx)?;
3068        let batches = common::collect(stream).await?;
3069
3070        let expected = [
3071            "+----+----+-----+",
3072            "| a2 | b2 | c2  |",
3073            "+----+----+-----+",
3074            "| 8  | 8  | 20  |",
3075            "| 6  | 6  | 60  |",
3076            "| 2  | 2  | 80  |",
3077            "| 4  | 4  | 120 |",
3078            "+----+----+-----+",
3079        ];
3080
3081        // RightAnti join output is expected to preserve right input order
3082        assert_batches_eq!(expected, &batches);
3083
3084        Ok(())
3085    }
3086
3087    #[apply(batch_sizes)]
3088    #[tokio::test]
3089    async fn join_right_one(batch_size: usize) -> Result<()> {
3090        let task_ctx = prepare_task_ctx(batch_size);
3091        let left = build_table(
3092            ("a1", &vec![1, 2, 3]),
3093            ("b1", &vec![4, 5, 7]),
3094            ("c1", &vec![7, 8, 9]),
3095        );
3096        let right = build_table(
3097            ("a2", &vec![10, 20, 30]),
3098            ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
3099            ("c2", &vec![70, 80, 90]),
3100        );
3101        let on = vec![(
3102            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3103            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3104        )];
3105
3106        let (columns, batches) =
3107            join_collect(left, right, on, &JoinType::Right, false, task_ctx).await?;
3108
3109        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
3110
3111        let expected = [
3112            "+----+----+----+----+----+----+",
3113            "| a1 | b1 | c1 | a2 | b1 | c2 |",
3114            "+----+----+----+----+----+----+",
3115            "|    |    |    | 30 | 6  | 90 |",
3116            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3117            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3118            "+----+----+----+----+----+----+",
3119        ];
3120
3121        assert_batches_sorted_eq!(expected, &batches);
3122
3123        Ok(())
3124    }
3125
3126    #[apply(batch_sizes)]
3127    #[tokio::test]
3128    async fn partitioned_join_right_one(batch_size: usize) -> Result<()> {
3129        let task_ctx = prepare_task_ctx(batch_size);
3130        let left = build_table(
3131            ("a1", &vec![1, 2, 3]),
3132            ("b1", &vec![4, 5, 7]),
3133            ("c1", &vec![7, 8, 9]),
3134        );
3135        let right = build_table(
3136            ("a2", &vec![10, 20, 30]),
3137            ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
3138            ("c2", &vec![70, 80, 90]),
3139        );
3140        let on = vec![(
3141            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3142            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3143        )];
3144
3145        let (columns, batches) =
3146            partitioned_join_collect(left, right, on, &JoinType::Right, false, task_ctx)
3147                .await?;
3148
3149        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
3150
3151        let expected = [
3152            "+----+----+----+----+----+----+",
3153            "| a1 | b1 | c1 | a2 | b1 | c2 |",
3154            "+----+----+----+----+----+----+",
3155            "|    |    |    | 30 | 6  | 90 |",
3156            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3157            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3158            "+----+----+----+----+----+----+",
3159        ];
3160
3161        assert_batches_sorted_eq!(expected, &batches);
3162
3163        Ok(())
3164    }
3165
3166    #[apply(batch_sizes)]
3167    #[tokio::test]
3168    async fn join_full_one(batch_size: usize) -> Result<()> {
3169        let task_ctx = prepare_task_ctx(batch_size);
3170        let left = build_table(
3171            ("a1", &vec![1, 2, 3]),
3172            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
3173            ("c1", &vec![7, 8, 9]),
3174        );
3175        let right = build_table(
3176            ("a2", &vec![10, 20, 30]),
3177            ("b2", &vec![4, 5, 6]),
3178            ("c2", &vec![70, 80, 90]),
3179        );
3180        let on = vec![(
3181            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
3182            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
3183        )];
3184
3185        let join = join(left, right, on, &JoinType::Full, false)?;
3186
3187        let columns = columns(&join.schema());
3188        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3189
3190        let stream = join.execute(0, task_ctx)?;
3191        let batches = common::collect(stream).await?;
3192
3193        let expected = [
3194            "+----+----+----+----+----+----+",
3195            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3196            "+----+----+----+----+----+----+",
3197            "|    |    |    | 30 | 6  | 90 |",
3198            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3199            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3200            "| 3  | 7  | 9  |    |    |    |",
3201            "+----+----+----+----+----+----+",
3202        ];
3203        assert_batches_sorted_eq!(expected, &batches);
3204
3205        Ok(())
3206    }
3207
3208    #[apply(batch_sizes)]
3209    #[tokio::test]
3210    async fn join_left_mark(batch_size: usize) -> Result<()> {
3211        let task_ctx = prepare_task_ctx(batch_size);
3212        let left = build_table(
3213            ("a1", &vec![1, 2, 3]),
3214            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
3215            ("c1", &vec![7, 8, 9]),
3216        );
3217        let right = build_table(
3218            ("a2", &vec![10, 20, 30]),
3219            ("b1", &vec![4, 5, 6]),
3220            ("c2", &vec![70, 80, 90]),
3221        );
3222        let on = vec![(
3223            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3224            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3225        )];
3226
3227        let (columns, batches) = join_collect(
3228            Arc::clone(&left),
3229            Arc::clone(&right),
3230            on.clone(),
3231            &JoinType::LeftMark,
3232            false,
3233            task_ctx,
3234        )
3235        .await?;
3236        assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
3237
3238        let expected = [
3239            "+----+----+----+-------+",
3240            "| a1 | b1 | c1 | mark  |",
3241            "+----+----+----+-------+",
3242            "| 1  | 4  | 7  | true  |",
3243            "| 2  | 5  | 8  | true  |",
3244            "| 3  | 7  | 9  | false |",
3245            "+----+----+----+-------+",
3246        ];
3247        assert_batches_sorted_eq!(expected, &batches);
3248
3249        Ok(())
3250    }
3251
3252    #[apply(batch_sizes)]
3253    #[tokio::test]
3254    async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> {
3255        let task_ctx = prepare_task_ctx(batch_size);
3256        let left = build_table(
3257            ("a1", &vec![1, 2, 3]),
3258            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
3259            ("c1", &vec![7, 8, 9]),
3260        );
3261        let right = build_table(
3262            ("a2", &vec![10, 20, 30, 40]),
3263            ("b1", &vec![4, 4, 5, 6]),
3264            ("c2", &vec![60, 70, 80, 90]),
3265        );
3266        let on = vec![(
3267            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3268            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
3269        )];
3270
3271        let (columns, batches) = partitioned_join_collect(
3272            Arc::clone(&left),
3273            Arc::clone(&right),
3274            on.clone(),
3275            &JoinType::LeftMark,
3276            false,
3277            task_ctx,
3278        )
3279        .await?;
3280        assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
3281
3282        let expected = [
3283            "+----+----+----+-------+",
3284            "| a1 | b1 | c1 | mark  |",
3285            "+----+----+----+-------+",
3286            "| 1  | 4  | 7  | true  |",
3287            "| 2  | 5  | 8  | true  |",
3288            "| 3  | 7  | 9  | false |",
3289            "+----+----+----+-------+",
3290        ];
3291        assert_batches_sorted_eq!(expected, &batches);
3292
3293        Ok(())
3294    }
3295
3296    #[test]
3297    fn join_with_hash_collision() -> Result<()> {
3298        let mut hashmap_left = HashTable::with_capacity(2);
3299        let left = build_table_i32(
3300            ("a", &vec![10, 20]),
3301            ("x", &vec![100, 200]),
3302            ("y", &vec![200, 300]),
3303        );
3304
3305        let random_state = RandomState::with_seeds(0, 0, 0, 0);
3306        let hashes_buff = &mut vec![0; left.num_rows()];
3307        let hashes = create_hashes(
3308            &[Arc::clone(&left.columns()[0])],
3309            &random_state,
3310            hashes_buff,
3311        )?;
3312
3313        // Create hash collisions (same hashes)
3314        hashmap_left.insert_unique(hashes[0], (hashes[0], 1), |(h, _)| *h);
3315        hashmap_left.insert_unique(hashes[1], (hashes[1], 1), |(h, _)| *h);
3316
3317        let next = vec![2, 0];
3318
3319        let right = build_table_i32(
3320            ("a", &vec![10, 20]),
3321            ("b", &vec![0, 0]),
3322            ("c", &vec![30, 40]),
3323        );
3324
3325        // Join key column for both join sides
3326        let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _;
3327
3328        let join_hash_map = JoinHashMap::new(hashmap_left, next);
3329
3330        let left_keys_values = key_column.evaluate(&left)?.into_array(left.num_rows())?;
3331        let right_keys_values =
3332            key_column.evaluate(&right)?.into_array(right.num_rows())?;
3333        let mut hashes_buffer = vec![0; right.num_rows()];
3334        create_hashes(
3335            &[Arc::clone(&right_keys_values)],
3336            &random_state,
3337            &mut hashes_buffer,
3338        )?;
3339
3340        let (l, r, _) = lookup_join_hashmap(
3341            &join_hash_map,
3342            &[left_keys_values],
3343            &[right_keys_values],
3344            false,
3345            &hashes_buffer,
3346            8192,
3347            (0, None),
3348        )?;
3349
3350        let left_ids: UInt64Array = vec![0, 1].into();
3351
3352        let right_ids: UInt32Array = vec![0, 1].into();
3353
3354        assert_eq!(left_ids, l);
3355
3356        assert_eq!(right_ids, r);
3357
3358        Ok(())
3359    }
3360
3361    #[tokio::test]
3362    async fn join_with_duplicated_column_names() -> Result<()> {
3363        let task_ctx = Arc::new(TaskContext::default());
3364        let left = build_table(
3365            ("a", &vec![1, 2, 3]),
3366            ("b", &vec![4, 5, 7]),
3367            ("c", &vec![7, 8, 9]),
3368        );
3369        let right = build_table(
3370            ("a", &vec![10, 20, 30]),
3371            ("b", &vec![1, 2, 7]),
3372            ("c", &vec![70, 80, 90]),
3373        );
3374        let on = vec![(
3375            // join on a=b so there are duplicate column names on unjoined columns
3376            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _,
3377            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _,
3378        )];
3379
3380        let join = join(left, right, on, &JoinType::Inner, false)?;
3381
3382        let columns = columns(&join.schema());
3383        assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
3384
3385        let stream = join.execute(0, task_ctx)?;
3386        let batches = common::collect(stream).await?;
3387
3388        let expected = [
3389            "+---+---+---+----+---+----+",
3390            "| a | b | c | a  | b | c  |",
3391            "+---+---+---+----+---+----+",
3392            "| 1 | 4 | 7 | 10 | 1 | 70 |",
3393            "| 2 | 5 | 8 | 20 | 2 | 80 |",
3394            "+---+---+---+----+---+----+",
3395        ];
3396        assert_batches_sorted_eq!(expected, &batches);
3397
3398        Ok(())
3399    }
3400
3401    fn prepare_join_filter() -> JoinFilter {
3402        let column_indices = vec![
3403            ColumnIndex {
3404                index: 2,
3405                side: JoinSide::Left,
3406            },
3407            ColumnIndex {
3408                index: 2,
3409                side: JoinSide::Right,
3410            },
3411        ];
3412        let intermediate_schema = Schema::new(vec![
3413            Field::new("c", DataType::Int32, true),
3414            Field::new("c", DataType::Int32, true),
3415        ]);
3416        let filter_expression = Arc::new(BinaryExpr::new(
3417            Arc::new(Column::new("c", 0)),
3418            Operator::Gt,
3419            Arc::new(Column::new("c", 1)),
3420        )) as Arc<dyn PhysicalExpr>;
3421
3422        JoinFilter::new(
3423            filter_expression,
3424            column_indices,
3425            Arc::new(intermediate_schema),
3426        )
3427    }
3428
3429    #[apply(batch_sizes)]
3430    #[tokio::test]
3431    async fn join_inner_with_filter(batch_size: usize) -> Result<()> {
3432        let task_ctx = prepare_task_ctx(batch_size);
3433        let left = build_table(
3434            ("a", &vec![0, 1, 2, 2]),
3435            ("b", &vec![4, 5, 7, 8]),
3436            ("c", &vec![7, 8, 9, 1]),
3437        );
3438        let right = build_table(
3439            ("a", &vec![10, 20, 30, 40]),
3440            ("b", &vec![2, 2, 3, 4]),
3441            ("c", &vec![7, 5, 6, 4]),
3442        );
3443        let on = vec![(
3444            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _,
3445            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _,
3446        )];
3447        let filter = prepare_join_filter();
3448
3449        let join = join_with_filter(left, right, on, filter, &JoinType::Inner, false)?;
3450
3451        let columns = columns(&join.schema());
3452        assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
3453
3454        let stream = join.execute(0, task_ctx)?;
3455        let batches = common::collect(stream).await?;
3456
3457        let expected = [
3458            "+---+---+---+----+---+---+",
3459            "| a | b | c | a  | b | c |",
3460            "+---+---+---+----+---+---+",
3461            "| 2 | 7 | 9 | 10 | 2 | 7 |",
3462            "| 2 | 7 | 9 | 20 | 2 | 5 |",
3463            "+---+---+---+----+---+---+",
3464        ];
3465        assert_batches_sorted_eq!(expected, &batches);
3466
3467        Ok(())
3468    }
3469
3470    #[apply(batch_sizes)]
3471    #[tokio::test]
3472    async fn join_left_with_filter(batch_size: usize) -> Result<()> {
3473        let task_ctx = prepare_task_ctx(batch_size);
3474        let left = build_table(
3475            ("a", &vec![0, 1, 2, 2]),
3476            ("b", &vec![4, 5, 7, 8]),
3477            ("c", &vec![7, 8, 9, 1]),
3478        );
3479        let right = build_table(
3480            ("a", &vec![10, 20, 30, 40]),
3481            ("b", &vec![2, 2, 3, 4]),
3482            ("c", &vec![7, 5, 6, 4]),
3483        );
3484        let on = vec![(
3485            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _,
3486            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _,
3487        )];
3488        let filter = prepare_join_filter();
3489
3490        let join = join_with_filter(left, right, on, filter, &JoinType::Left, false)?;
3491
3492        let columns = columns(&join.schema());
3493        assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
3494
3495        let stream = join.execute(0, task_ctx)?;
3496        let batches = common::collect(stream).await?;
3497
3498        let expected = [
3499            "+---+---+---+----+---+---+",
3500            "| a | b | c | a  | b | c |",
3501            "+---+---+---+----+---+---+",
3502            "| 0 | 4 | 7 |    |   |   |",
3503            "| 1 | 5 | 8 |    |   |   |",
3504            "| 2 | 7 | 9 | 10 | 2 | 7 |",
3505            "| 2 | 7 | 9 | 20 | 2 | 5 |",
3506            "| 2 | 8 | 1 |    |   |   |",
3507            "+---+---+---+----+---+---+",
3508        ];
3509        assert_batches_sorted_eq!(expected, &batches);
3510
3511        Ok(())
3512    }
3513
3514    #[apply(batch_sizes)]
3515    #[tokio::test]
3516    async fn join_right_with_filter(batch_size: usize) -> Result<()> {
3517        let task_ctx = prepare_task_ctx(batch_size);
3518        let left = build_table(
3519            ("a", &vec![0, 1, 2, 2]),
3520            ("b", &vec![4, 5, 7, 8]),
3521            ("c", &vec![7, 8, 9, 1]),
3522        );
3523        let right = build_table(
3524            ("a", &vec![10, 20, 30, 40]),
3525            ("b", &vec![2, 2, 3, 4]),
3526            ("c", &vec![7, 5, 6, 4]),
3527        );
3528        let on = vec![(
3529            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _,
3530            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _,
3531        )];
3532        let filter = prepare_join_filter();
3533
3534        let join = join_with_filter(left, right, on, filter, &JoinType::Right, false)?;
3535
3536        let columns = columns(&join.schema());
3537        assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
3538
3539        let stream = join.execute(0, task_ctx)?;
3540        let batches = common::collect(stream).await?;
3541
3542        let expected = [
3543            "+---+---+---+----+---+---+",
3544            "| a | b | c | a  | b | c |",
3545            "+---+---+---+----+---+---+",
3546            "|   |   |   | 30 | 3 | 6 |",
3547            "|   |   |   | 40 | 4 | 4 |",
3548            "| 2 | 7 | 9 | 10 | 2 | 7 |",
3549            "| 2 | 7 | 9 | 20 | 2 | 5 |",
3550            "+---+---+---+----+---+---+",
3551        ];
3552        assert_batches_sorted_eq!(expected, &batches);
3553
3554        Ok(())
3555    }
3556
3557    #[apply(batch_sizes)]
3558    #[tokio::test]
3559    async fn join_full_with_filter(batch_size: usize) -> Result<()> {
3560        let task_ctx = prepare_task_ctx(batch_size);
3561        let left = build_table(
3562            ("a", &vec![0, 1, 2, 2]),
3563            ("b", &vec![4, 5, 7, 8]),
3564            ("c", &vec![7, 8, 9, 1]),
3565        );
3566        let right = build_table(
3567            ("a", &vec![10, 20, 30, 40]),
3568            ("b", &vec![2, 2, 3, 4]),
3569            ("c", &vec![7, 5, 6, 4]),
3570        );
3571        let on = vec![(
3572            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _,
3573            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _,
3574        )];
3575        let filter = prepare_join_filter();
3576
3577        let join = join_with_filter(left, right, on, filter, &JoinType::Full, false)?;
3578
3579        let columns = columns(&join.schema());
3580        assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
3581
3582        let stream = join.execute(0, task_ctx)?;
3583        let batches = common::collect(stream).await?;
3584
3585        let expected = [
3586            "+---+---+---+----+---+---+",
3587            "| a | b | c | a  | b | c |",
3588            "+---+---+---+----+---+---+",
3589            "|   |   |   | 30 | 3 | 6 |",
3590            "|   |   |   | 40 | 4 | 4 |",
3591            "| 2 | 7 | 9 | 10 | 2 | 7 |",
3592            "| 2 | 7 | 9 | 20 | 2 | 5 |",
3593            "| 0 | 4 | 7 |    |   |   |",
3594            "| 1 | 5 | 8 |    |   |   |",
3595            "| 2 | 8 | 1 |    |   |   |",
3596            "+---+---+---+----+---+---+",
3597        ];
3598        assert_batches_sorted_eq!(expected, &batches);
3599
3600        Ok(())
3601    }
3602
3603    /// Test for parallelized HashJoinExec with PartitionMode::CollectLeft
3604    #[tokio::test]
3605    async fn test_collect_left_multiple_partitions_join() -> Result<()> {
3606        let task_ctx = Arc::new(TaskContext::default());
3607        let left = build_table(
3608            ("a1", &vec![1, 2, 3]),
3609            ("b1", &vec![4, 5, 7]),
3610            ("c1", &vec![7, 8, 9]),
3611        );
3612        let right = build_table(
3613            ("a2", &vec![10, 20, 30]),
3614            ("b2", &vec![4, 5, 6]),
3615            ("c2", &vec![70, 80, 90]),
3616        );
3617        let on = vec![(
3618            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
3619            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
3620        )];
3621
3622        let expected_inner = vec![
3623            "+----+----+----+----+----+----+",
3624            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3625            "+----+----+----+----+----+----+",
3626            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3627            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3628            "+----+----+----+----+----+----+",
3629        ];
3630        let expected_left = vec![
3631            "+----+----+----+----+----+----+",
3632            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3633            "+----+----+----+----+----+----+",
3634            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3635            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3636            "| 3  | 7  | 9  |    |    |    |",
3637            "+----+----+----+----+----+----+",
3638        ];
3639        let expected_right = vec![
3640            "+----+----+----+----+----+----+",
3641            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3642            "+----+----+----+----+----+----+",
3643            "|    |    |    | 30 | 6  | 90 |",
3644            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3645            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3646            "+----+----+----+----+----+----+",
3647        ];
3648        let expected_full = vec![
3649            "+----+----+----+----+----+----+",
3650            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3651            "+----+----+----+----+----+----+",
3652            "|    |    |    | 30 | 6  | 90 |",
3653            "| 1  | 4  | 7  | 10 | 4  | 70 |",
3654            "| 2  | 5  | 8  | 20 | 5  | 80 |",
3655            "| 3  | 7  | 9  |    |    |    |",
3656            "+----+----+----+----+----+----+",
3657        ];
3658        let expected_left_semi = vec![
3659            "+----+----+----+",
3660            "| a1 | b1 | c1 |",
3661            "+----+----+----+",
3662            "| 1  | 4  | 7  |",
3663            "| 2  | 5  | 8  |",
3664            "+----+----+----+",
3665        ];
3666        let expected_left_anti = vec![
3667            "+----+----+----+",
3668            "| a1 | b1 | c1 |",
3669            "+----+----+----+",
3670            "| 3  | 7  | 9  |",
3671            "+----+----+----+",
3672        ];
3673        let expected_right_semi = vec![
3674            "+----+----+----+",
3675            "| a2 | b2 | c2 |",
3676            "+----+----+----+",
3677            "| 10 | 4  | 70 |",
3678            "| 20 | 5  | 80 |",
3679            "+----+----+----+",
3680        ];
3681        let expected_right_anti = vec![
3682            "+----+----+----+",
3683            "| a2 | b2 | c2 |",
3684            "+----+----+----+",
3685            "| 30 | 6  | 90 |",
3686            "+----+----+----+",
3687        ];
3688        let expected_left_mark = vec![
3689            "+----+----+----+-------+",
3690            "| a1 | b1 | c1 | mark  |",
3691            "+----+----+----+-------+",
3692            "| 1  | 4  | 7  | true  |",
3693            "| 2  | 5  | 8  | true  |",
3694            "| 3  | 7  | 9  | false |",
3695            "+----+----+----+-------+",
3696        ];
3697
3698        let test_cases = vec![
3699            (JoinType::Inner, expected_inner),
3700            (JoinType::Left, expected_left),
3701            (JoinType::Right, expected_right),
3702            (JoinType::Full, expected_full),
3703            (JoinType::LeftSemi, expected_left_semi),
3704            (JoinType::LeftAnti, expected_left_anti),
3705            (JoinType::RightSemi, expected_right_semi),
3706            (JoinType::RightAnti, expected_right_anti),
3707            (JoinType::LeftMark, expected_left_mark),
3708        ];
3709
3710        for (join_type, expected) in test_cases {
3711            let (_, batches) = join_collect_with_partition_mode(
3712                Arc::clone(&left),
3713                Arc::clone(&right),
3714                on.clone(),
3715                &join_type,
3716                PartitionMode::CollectLeft,
3717                false,
3718                Arc::clone(&task_ctx),
3719            )
3720            .await?;
3721            assert_batches_sorted_eq!(expected, &batches);
3722        }
3723
3724        Ok(())
3725    }
3726
3727    #[tokio::test]
3728    async fn join_date32() -> Result<()> {
3729        let schema = Arc::new(Schema::new(vec![
3730            Field::new("date", DataType::Date32, false),
3731            Field::new("n", DataType::Int32, false),
3732        ]));
3733
3734        let dates: ArrayRef = Arc::new(Date32Array::from(vec![19107, 19108, 19109]));
3735        let n: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
3736        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?;
3737        let left =
3738            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)
3739                .unwrap();
3740        let dates: ArrayRef = Arc::new(Date32Array::from(vec![19108, 19108, 19109]));
3741        let n: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6]));
3742        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?;
3743        let right = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
3744        let on = vec![(
3745            Arc::new(Column::new_with_schema("date", &left.schema()).unwrap()) as _,
3746            Arc::new(Column::new_with_schema("date", &right.schema()).unwrap()) as _,
3747        )];
3748
3749        let join = join(left, right, on, &JoinType::Inner, false)?;
3750
3751        let task_ctx = Arc::new(TaskContext::default());
3752        let stream = join.execute(0, task_ctx)?;
3753        let batches = common::collect(stream).await?;
3754
3755        let expected = [
3756            "+------------+---+------------+---+",
3757            "| date       | n | date       | n |",
3758            "+------------+---+------------+---+",
3759            "| 2022-04-26 | 2 | 2022-04-26 | 4 |",
3760            "| 2022-04-26 | 2 | 2022-04-26 | 5 |",
3761            "| 2022-04-27 | 3 | 2022-04-27 | 6 |",
3762            "+------------+---+------------+---+",
3763        ];
3764        assert_batches_sorted_eq!(expected, &batches);
3765
3766        Ok(())
3767    }
3768
3769    #[tokio::test]
3770    async fn join_with_error_right() {
3771        let left = build_table(
3772            ("a1", &vec![1, 2, 3]),
3773            ("b1", &vec![4, 5, 7]),
3774            ("c1", &vec![7, 8, 9]),
3775        );
3776
3777        // right input stream returns one good batch and then one error.
3778        // The error should be returned.
3779        let err = exec_err!("bad data error");
3780        let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![]));
3781
3782        let on = vec![(
3783            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
3784            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _,
3785        )];
3786        let schema = right.schema();
3787        let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![]));
3788        let right_input = Arc::new(MockExec::new(vec![Ok(right), err], schema));
3789
3790        let join_types = vec![
3791            JoinType::Inner,
3792            JoinType::Left,
3793            JoinType::Right,
3794            JoinType::Full,
3795            JoinType::LeftSemi,
3796            JoinType::LeftAnti,
3797            JoinType::RightSemi,
3798            JoinType::RightAnti,
3799        ];
3800
3801        for join_type in join_types {
3802            let join = join(
3803                Arc::clone(&left),
3804                Arc::clone(&right_input) as Arc<dyn ExecutionPlan>,
3805                on.clone(),
3806                &join_type,
3807                false,
3808            )
3809            .unwrap();
3810            let task_ctx = Arc::new(TaskContext::default());
3811
3812            let stream = join.execute(0, task_ctx).unwrap();
3813
3814            // Expect that an error is returned
3815            let result_string = common::collect(stream).await.unwrap_err().to_string();
3816            assert!(
3817                result_string.contains("bad data error"),
3818                "actual: {result_string}"
3819            );
3820        }
3821    }
3822
3823    #[tokio::test]
3824    async fn join_splitted_batch() {
3825        let left = build_table(
3826            ("a1", &vec![1, 2, 3, 4]),
3827            ("b1", &vec![1, 1, 1, 1]),
3828            ("c1", &vec![0, 0, 0, 0]),
3829        );
3830        let right = build_table(
3831            ("a2", &vec![10, 20, 30, 40, 50]),
3832            ("b2", &vec![1, 1, 1, 1, 1]),
3833            ("c2", &vec![0, 0, 0, 0, 0]),
3834        );
3835        let on = vec![(
3836            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
3837            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
3838        )];
3839
3840        let join_types = vec![
3841            JoinType::Inner,
3842            JoinType::Left,
3843            JoinType::Right,
3844            JoinType::Full,
3845            JoinType::RightSemi,
3846            JoinType::RightAnti,
3847            JoinType::LeftSemi,
3848            JoinType::LeftAnti,
3849        ];
3850        let expected_resultset_records = 20;
3851        let common_result = [
3852            "+----+----+----+----+----+----+",
3853            "| a1 | b1 | c1 | a2 | b2 | c2 |",
3854            "+----+----+----+----+----+----+",
3855            "| 1  | 1  | 0  | 10 | 1  | 0  |",
3856            "| 2  | 1  | 0  | 10 | 1  | 0  |",
3857            "| 3  | 1  | 0  | 10 | 1  | 0  |",
3858            "| 4  | 1  | 0  | 10 | 1  | 0  |",
3859            "| 1  | 1  | 0  | 20 | 1  | 0  |",
3860            "| 2  | 1  | 0  | 20 | 1  | 0  |",
3861            "| 3  | 1  | 0  | 20 | 1  | 0  |",
3862            "| 4  | 1  | 0  | 20 | 1  | 0  |",
3863            "| 1  | 1  | 0  | 30 | 1  | 0  |",
3864            "| 2  | 1  | 0  | 30 | 1  | 0  |",
3865            "| 3  | 1  | 0  | 30 | 1  | 0  |",
3866            "| 4  | 1  | 0  | 30 | 1  | 0  |",
3867            "| 1  | 1  | 0  | 40 | 1  | 0  |",
3868            "| 2  | 1  | 0  | 40 | 1  | 0  |",
3869            "| 3  | 1  | 0  | 40 | 1  | 0  |",
3870            "| 4  | 1  | 0  | 40 | 1  | 0  |",
3871            "| 1  | 1  | 0  | 50 | 1  | 0  |",
3872            "| 2  | 1  | 0  | 50 | 1  | 0  |",
3873            "| 3  | 1  | 0  | 50 | 1  | 0  |",
3874            "| 4  | 1  | 0  | 50 | 1  | 0  |",
3875            "+----+----+----+----+----+----+",
3876        ];
3877        let left_batch = [
3878            "+----+----+----+",
3879            "| a1 | b1 | c1 |",
3880            "+----+----+----+",
3881            "| 1  | 1  | 0  |",
3882            "| 2  | 1  | 0  |",
3883            "| 3  | 1  | 0  |",
3884            "| 4  | 1  | 0  |",
3885            "+----+----+----+",
3886        ];
3887        let right_batch = [
3888            "+----+----+----+",
3889            "| a2 | b2 | c2 |",
3890            "+----+----+----+",
3891            "| 10 | 1  | 0  |",
3892            "| 20 | 1  | 0  |",
3893            "| 30 | 1  | 0  |",
3894            "| 40 | 1  | 0  |",
3895            "| 50 | 1  | 0  |",
3896            "+----+----+----+",
3897        ];
3898        let right_empty = [
3899            "+----+----+----+",
3900            "| a2 | b2 | c2 |",
3901            "+----+----+----+",
3902            "+----+----+----+",
3903        ];
3904        let left_empty = [
3905            "+----+----+----+",
3906            "| a1 | b1 | c1 |",
3907            "+----+----+----+",
3908            "+----+----+----+",
3909        ];
3910
3911        // validation of partial join results output for different batch_size setting
3912        for join_type in join_types {
3913            for batch_size in (1..21).rev() {
3914                let task_ctx = prepare_task_ctx(batch_size);
3915
3916                let join = join(
3917                    Arc::clone(&left),
3918                    Arc::clone(&right),
3919                    on.clone(),
3920                    &join_type,
3921                    false,
3922                )
3923                .unwrap();
3924
3925                let stream = join.execute(0, task_ctx).unwrap();
3926                let batches = common::collect(stream).await.unwrap();
3927
3928                // For inner/right join expected batch count equals dev_ceil result,
3929                // as there is no need to append non-joined build side data.
3930                // For other join types it'll be div_ceil + 1 -- for additional batch
3931                // containing not visited build side rows (empty in this test case).
3932                let expected_batch_count = match join_type {
3933                    JoinType::Inner
3934                    | JoinType::Right
3935                    | JoinType::RightSemi
3936                    | JoinType::RightAnti => {
3937                        div_ceil(expected_resultset_records, batch_size)
3938                    }
3939                    _ => div_ceil(expected_resultset_records, batch_size) + 1,
3940                };
3941                assert_eq!(
3942                    batches.len(),
3943                    expected_batch_count,
3944                    "expected {} output batches for {} join with batch_size = {}",
3945                    expected_batch_count,
3946                    join_type,
3947                    batch_size
3948                );
3949
3950                let expected = match join_type {
3951                    JoinType::RightSemi => right_batch.to_vec(),
3952                    JoinType::RightAnti => right_empty.to_vec(),
3953                    JoinType::LeftSemi => left_batch.to_vec(),
3954                    JoinType::LeftAnti => left_empty.to_vec(),
3955                    _ => common_result.to_vec(),
3956                };
3957                assert_batches_eq!(expected, &batches);
3958            }
3959        }
3960    }
3961
3962    #[tokio::test]
3963    async fn single_partition_join_overallocation() -> Result<()> {
3964        let left = build_table(
3965            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
3966            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
3967            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
3968        );
3969        let right = build_table(
3970            ("a2", &vec![10, 11]),
3971            ("b2", &vec![12, 13]),
3972            ("c2", &vec![14, 15]),
3973        );
3974        let on = vec![(
3975            Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _,
3976            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
3977        )];
3978
3979        let join_types = vec![
3980            JoinType::Inner,
3981            JoinType::Left,
3982            JoinType::Right,
3983            JoinType::Full,
3984            JoinType::LeftSemi,
3985            JoinType::LeftAnti,
3986            JoinType::RightSemi,
3987            JoinType::RightAnti,
3988            JoinType::LeftMark,
3989        ];
3990
3991        for join_type in join_types {
3992            let runtime = RuntimeEnvBuilder::new()
3993                .with_memory_limit(100, 1.0)
3994                .build_arc()?;
3995            let task_ctx = TaskContext::default().with_runtime(runtime);
3996            let task_ctx = Arc::new(task_ctx);
3997
3998            let join = join(
3999                Arc::clone(&left),
4000                Arc::clone(&right),
4001                on.clone(),
4002                &join_type,
4003                false,
4004            )?;
4005
4006            let stream = join.execute(0, task_ctx)?;
4007            let err = common::collect(stream).await.unwrap_err();
4008
4009            // Asserting that operator-level reservation attempting to overallocate
4010            assert_contains!(
4011                err.to_string(),
4012                "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput"
4013            );
4014
4015            assert_contains!(
4016                err.to_string(),
4017                "Failed to allocate additional 120 bytes for HashJoinInput"
4018            );
4019        }
4020
4021        Ok(())
4022    }
4023
4024    #[tokio::test]
4025    async fn partitioned_join_overallocation() -> Result<()> {
4026        // Prepare partitioned inputs for HashJoinExec
4027        // No need to adjust partitioning, as execution should fail with `Resources exhausted` error
4028        let left_batch = build_table_i32(
4029            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
4030            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
4031            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
4032        );
4033        let left = TestMemoryExec::try_new_exec(
4034            &[vec![left_batch.clone()], vec![left_batch.clone()]],
4035            left_batch.schema(),
4036            None,
4037        )
4038        .unwrap();
4039        let right_batch = build_table_i32(
4040            ("a2", &vec![10, 11]),
4041            ("b2", &vec![12, 13]),
4042            ("c2", &vec![14, 15]),
4043        );
4044        let right = TestMemoryExec::try_new_exec(
4045            &[vec![right_batch.clone()], vec![right_batch.clone()]],
4046            right_batch.schema(),
4047            None,
4048        )
4049        .unwrap();
4050        let on = vec![(
4051            Arc::new(Column::new_with_schema("b1", &left_batch.schema())?) as _,
4052            Arc::new(Column::new_with_schema("b2", &right_batch.schema())?) as _,
4053        )];
4054
4055        let join_types = vec![
4056            JoinType::Inner,
4057            JoinType::Left,
4058            JoinType::Right,
4059            JoinType::Full,
4060            JoinType::LeftSemi,
4061            JoinType::LeftAnti,
4062            JoinType::RightSemi,
4063            JoinType::RightAnti,
4064        ];
4065
4066        for join_type in join_types {
4067            let runtime = RuntimeEnvBuilder::new()
4068                .with_memory_limit(100, 1.0)
4069                .build_arc()?;
4070            let session_config = SessionConfig::default().with_batch_size(50);
4071            let task_ctx = TaskContext::default()
4072                .with_session_config(session_config)
4073                .with_runtime(runtime);
4074            let task_ctx = Arc::new(task_ctx);
4075
4076            let join = HashJoinExec::try_new(
4077                Arc::clone(&left) as Arc<dyn ExecutionPlan>,
4078                Arc::clone(&right) as Arc<dyn ExecutionPlan>,
4079                on.clone(),
4080                None,
4081                &join_type,
4082                None,
4083                PartitionMode::Partitioned,
4084                false,
4085            )?;
4086
4087            let stream = join.execute(1, task_ctx)?;
4088            let err = common::collect(stream).await.unwrap_err();
4089
4090            // Asserting that stream-level reservation attempting to overallocate
4091            assert_contains!(
4092                err.to_string(),
4093                "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]"
4094
4095            );
4096
4097            assert_contains!(
4098                err.to_string(),
4099                "Failed to allocate additional 120 bytes for HashJoinInput[1]"
4100            );
4101        }
4102
4103        Ok(())
4104    }
4105
4106    fn build_table_struct(
4107        struct_name: &str,
4108        field_name_and_values: (&str, &Vec<Option<i32>>),
4109        nulls: Option<NullBuffer>,
4110    ) -> Arc<dyn ExecutionPlan> {
4111        let (field_name, values) = field_name_and_values;
4112        let inner_fields = vec![Field::new(field_name, DataType::Int32, true)];
4113        let schema = Schema::new(vec![Field::new(
4114            struct_name,
4115            DataType::Struct(inner_fields.clone().into()),
4116            nulls.is_some(),
4117        )]);
4118
4119        let batch = RecordBatch::try_new(
4120            Arc::new(schema),
4121            vec![Arc::new(StructArray::new(
4122                inner_fields.into(),
4123                vec![Arc::new(Int32Array::from(values.clone()))],
4124                nulls,
4125            ))],
4126        )
4127        .unwrap();
4128        let schema_ref = batch.schema();
4129        TestMemoryExec::try_new_exec(&[vec![batch]], schema_ref, None).unwrap()
4130    }
4131
4132    #[tokio::test]
4133    async fn join_on_struct() -> Result<()> {
4134        let task_ctx = Arc::new(TaskContext::default());
4135        let left =
4136            build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None);
4137        let right =
4138            build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None);
4139        let on = vec![(
4140            Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
4141            Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
4142        )];
4143
4144        let (columns, batches) =
4145            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
4146
4147        assert_eq!(columns, vec!["n1", "n2"]);
4148
4149        let expected = [
4150            "+--------+--------+",
4151            "| n1     | n2     |",
4152            "+--------+--------+",
4153            "| {a: }  | {a: }  |",
4154            "| {a: 1} | {a: 1} |",
4155            "| {a: 2} | {a: 2} |",
4156            "+--------+--------+",
4157        ];
4158        assert_batches_eq!(expected, &batches);
4159
4160        Ok(())
4161    }
4162
4163    #[tokio::test]
4164    async fn join_on_struct_with_nulls() -> Result<()> {
4165        let task_ctx = Arc::new(TaskContext::default());
4166        let left =
4167            build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
4168        let right =
4169            build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
4170        let on = vec![(
4171            Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
4172            Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
4173        )];
4174
4175        let (_, batches_null_eq) = join_collect(
4176            Arc::clone(&left),
4177            Arc::clone(&right),
4178            on.clone(),
4179            &JoinType::Inner,
4180            true,
4181            Arc::clone(&task_ctx),
4182        )
4183        .await?;
4184
4185        let expected_null_eq = [
4186            "+----+----+",
4187            "| n1 | n2 |",
4188            "+----+----+",
4189            "|    |    |",
4190            "+----+----+",
4191        ];
4192        assert_batches_eq!(expected_null_eq, &batches_null_eq);
4193
4194        let (_, batches_null_neq) =
4195            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
4196
4197        let expected_null_neq =
4198            ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"];
4199        assert_batches_eq!(expected_null_neq, &batches_null_neq);
4200
4201        Ok(())
4202    }
4203
4204    /// Returns the column names on the schema
4205    fn columns(schema: &Schema) -> Vec<String> {
4206        schema.fields().iter().map(|f| f.name().clone()).collect()
4207    }
4208}