datafusion_physical_plan/joins/symmetric_hash_join.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file implements the symmetric hash join algorithm with range-based
//! data pruning to join two (potentially infinite) streams.
//!
//! A [`SymmetricHashJoinExec`] plan takes two child plans (with appropriate
//! output ordering) and produces the join output according to the given join
//! type and other options.
//!
//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations
//! for both its children.
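//!
//! As a rough sketch (child plans, join keys and options elided), constructing
//! the plan goes through [`SymmetricHashJoinExec::try_new`]:
//!
//! ```text
//! let join = SymmetricHashJoinExec::try_new(
//!     left,                                  // Arc<dyn ExecutionPlan>
//!     right,                                 // Arc<dyn ExecutionPlan>
//!     on,                                    // Vec<(PhysicalExprRef, PhysicalExprRef)>, non-empty
//!     filter,                                // Option<JoinFilter> carrying the range conditions
//!     &JoinType::Inner,
//!     false,                                 // null_equals_null
//!     left_sort_exprs,                       // Option<LexOrdering>
//!     right_sort_exprs,                      // Option<LexOrdering>
//!     StreamJoinPartitionMode::Partitioned,
//! )?;
//! ```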

use std::any::Any;
use std::fmt::{self, Debug};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;

use crate::common::SharedMemoryReservation;
use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
use crate::joins::hash_join::{equal_rows_arr, update_hash};
use crate::joins::stream_join_utils::{
    calculate_filter_expr_intervals, combine_two_batches,
    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
};
use crate::joins::utils::{
    apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
    check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter,
    BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef,
    NoopBatchTransformer, StatefulStreamResult,
};
use crate::projection::{
    join_allows_pushdown, join_table_borders, new_join_children,
    physical_to_column_exprs, update_join_filter, update_join_on, ProjectionExec,
};
use crate::{
    joins::StreamJoinPartitionMode,
    metrics::{ExecutionPlanMetricsSet, MetricsSet},
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};

use arrow::array::{
    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
    UInt64Array,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::bisect;
use datafusion_common::{internal_err, plan_err, HashSet, JoinSide, JoinType, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr::PhysicalExprRef;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};

use ahash::RandomState;
use futures::{ready, Stream, StreamExt};
use parking_lot::Mutex;

const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;

/// In a symmetric hash join with range conditions, both streams are hashed on the
/// join key and the resulting hash tables are used to join the streams.
/// The join is considered symmetric because the hash table is built on the join keys from both
/// streams, and the matching of rows is based on the values of the join keys in both streams.
/// This type of join is efficient in a streaming context as it allows for fast lookups in the hash
/// table, rather than having to scan through one or both of the streams to find matching rows. It also
/// only considers elements from each stream that fall within a certain sliding window (w/ range conditions),
/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
/// data without any memory issues.
///
/// For each input stream, create a hash table. Then, for each new [RecordBatch]
/// on the build side (a sketch of this flow appears after the diagram below):
///   - Hash the batch and insert it into that side's hash table; update the offsets.
///   - Probe the other side's hash table for rows with matching join keys.
///   - Record the visited rows. If matched row results must be produced (INNER, LEFT), output the [RecordBatch].
///   - Try to prune the other (probe) side with the new [RecordBatch].
///   - If the join type requires unmatched row results (LEFT, FULL etc.),
///     output the [RecordBatch] when a pruning happens or at the end of the data.
///
///
/// ``` text
///                        +-------------------------+
///                        |                         |
///   left stream ---------|  Left OneSideHashJoiner |---+
///                        |                         |   |
///                        +-------------------------+   |
///                                                      |
///                                                      |--------- Joined output
///                                                      |
///                        +-------------------------+   |
///                        |                         |   |
///  right stream ---------| Right OneSideHashJoiner |---+
///                        |                         |
///                        +-------------------------+
///
/// Prune the build side when a new RecordBatch arrives at the probe side. We utilize interval arithmetic
/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
///
///
///               PROBE SIDE          BUILD SIDE
///                 BUFFER              BUFFER
///             +-------------+     +------------+
///             |             |     |            |    Unjoinable
///             |             |     |            |    Range
///             |             |     |            |
///             |             |  |---------------------------------
///             |             |  |  |            |
///             |             |  |  |            |
///             |             | /   |            |
///             |             | |   |            |
///             |             | |   |            |
///             |             | |   |            |
///             |             | |   |            |
///             |             | |   |            |    Joinable
///             |             |/    |            |    Range
///             |             ||    |            |
///             |+-----------+||    |            |
///             || Record    ||     |            |
///             || Batch     ||     |            |
///             |+-----------+||    |            |
///             +-------------+\    +------------+
///                             |
///                             \
///                              |---------------------------------
///
///  This happens when range conditions are provided on sorted columns. E.g.
///
///        SELECT * FROM left_table, right_table
///        ON
///          left_key = right_key AND
///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
///
/// or
///        SELECT * FROM left_table, right_table
///        ON
///          left_key = right_key AND
///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
///
/// In the second scenario, when new data arrives at the probe side, the conditions can be used to
/// determine a specific threshold for discarding rows from the inner buffer. For example, suppose the sort order of
/// the two columns ("left_sorted" and "right_sorted") is ascending (it can differ in other scenarios)
/// and the join condition is "left_sorted > right_sorted - 3". If the latest value on the right input is 1234,
/// the left side buffer only needs to keep rows where "left_sorted > 1234 - 3", i.e. "left_sorted > 1231".
/// Since the column is ascending, any buffered rows with "left_sorted" at or below 1231
/// can be dropped from the inner buffer.
/// ```
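///
/// As an illustrative sketch (the actual control flow lives in
/// [`SymmetricHashJoinStream`] and the helper functions below), processing a
/// new probe side batch roughly amounts to:
///
/// ```text
/// // Abridged; `build`, `probe_expr` etc. stand in for the real stream state.
/// // 1. Join the new probe batch against the build side buffer:
/// let joined = join_with_probe_batch(build, probe, /* ... */)?;
/// // 2. Use interval arithmetic over the sorted filter exprs to find how many
/// //    leading build side rows can no longer match any future probe row:
/// let prune_length =
///     build.calculate_prune_length_with_probe_batch(build_expr, probe_expr, graph)?;
/// // 3. Emit any unmatched-row results those rows imply (LEFT, FULL, ...):
/// let unmatched =
///     build_side_determined_results(build, /* ... */ prune_length, /* ... */)?;
/// // 4. Drop the pruned rows from the buffer, hash map and visited set:
/// build.prune_internal_state(prune_length)?;
/// ```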
#[derive(Debug, Clone)]
pub struct SymmetricHashJoinExec {
    /// Left side stream
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right side stream
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Set of common columns used to join on
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Filters applied when finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// Shares the `RandomState` for the hashing algorithm
    random_state: RandomState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// If null_equals_null is true, null == null, else null != null
    pub(crate) null_equals_null: bool,
    /// Left side sort expression(s)
    pub(crate) left_sort_exprs: Option<LexOrdering>,
    /// Right side sort expression(s)
    pub(crate) right_sort_exprs: Option<LexOrdering>,
    /// Partition mode
    mode: StreamJoinPartitionMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl SymmetricHashJoinExec {
    /// Tries to create a new [SymmetricHashJoinExec].
    /// # Errors
    /// This function errors when:
    /// - It is not possible to join the left and right sides on keys `on`, or
    /// - It fails to construct `SortedFilterExpr`s, or
    /// - It fails to create the [ExprIntervalGraph].
    #[allow(clippy::too_many_arguments)]
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        null_equals_null: bool,
        left_sort_exprs: Option<LexOrdering>,
        right_sort_exprs: Option<LexOrdering>,
        mode: StreamJoinPartitionMode,
    ) -> Result<Self> {
        let left_schema = left.schema();
        let right_schema = right.schema();

        // Error out if no "on" constraints are given:
        if on.is_empty() {
            return plan_err!(
                "On constraints in SymmetricHashJoinExec should be non-empty"
            );
        }

        // Check if the join is valid with the given on constraints:
        check_join_is_valid(&left_schema, &right_schema, &on)?;

        // Build the join schema from the left and right schemas:
        let (schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, join_type);

        // Initialize the random state for the join operation:
        let random_state = RandomState::with_seeds(0, 0, 0, 0);
        let schema = Arc::new(schema);
        let cache =
            Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on);
        Ok(SymmetricHashJoinExec {
            left,
            right,
            on,
            filter,
            join_type: *join_type,
            random_state,
            metrics: ExecutionPlanMetricsSet::new(),
            column_indices,
            null_equals_null,
            left_sort_exprs,
            right_sort_exprs,
            mode,
            cache,
        })
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        join_type: JoinType,
        join_on: JoinOnRef,
    ) -> PlanProperties {
        // Calculate equivalence properties:
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            schema,
            &[false, false],
            // Has alternating probe side
            None,
            join_on,
        );

        let output_partitioning =
            symmetric_join_output_partitioning(left, right, &join_type);

        PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children([left, right]),
            boundedness_from_children([left, right]),
        )
    }

    /// Get the left input stream
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// Get the right input stream
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// Set of common columns used to join on
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }

    /// Filters applied before join output
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }

    /// How the join is performed
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// Get null_equals_null
    pub fn null_equals_null(&self) -> bool {
        self.null_equals_null
    }

    /// Get partition mode
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
        self.mode
    }

    /// Get left_sort_exprs
    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
        self.left_sort_exprs.as_ref()
    }

    /// Get right_sort_exprs
    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
        self.right_sort_exprs.as_ref()
    }

    /// Check if order information covers every column in the filter expression.
    pub fn check_if_order_information_available(&self) -> Result<bool> {
        if let Some(filter) = self.filter() {
            let left = self.left();
            if let Some(left_ordering) = left.output_ordering() {
                let right = self.right();
                if let Some(right_ordering) = right.output_ordering() {
                    let left_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Left,
                        filter,
                        &left.schema(),
                        &left_ordering[0],
                    )?
                    .is_some();
                    let right_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Right,
                        filter,
                        &right.schema(),
                        &right_ordering[0],
                    )?
                    .is_some();
                    return Ok(left_convertible && right_convertible);
                }
            }
        }
        Ok(false)
    }
}

impl DisplayAs for SymmetricHashJoinExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let display_filter = self.filter.as_ref().map_or_else(
                    || "".to_string(),
                    |f| format!(", filter={}", f.expression()),
                );
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| format!("({}, {})", c1, c2))
                    .collect::<Vec<String>>()
                    .join(", ");
                write!(
                    f,
                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
                    self.mode, self.join_type, on, display_filter
                )
            }
        }
    }
}

impl ExecutionPlan for SymmetricHashJoinExec {
    fn name(&self) -> &'static str {
        "SymmetricHashJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        match self.mode {
            StreamJoinPartitionMode::Partitioned => {
                let (left_expr, right_expr) = self
                    .on
                    .iter()
                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
                    .unzip();
                vec![
                    Distribution::HashPartitioned(left_expr),
                    Distribution::HashPartitioned(right_expr),
                ]
            }
            StreamJoinPartitionMode::SinglePartition => {
                vec![Distribution::SinglePartition, Distribution::SinglePartition]
            }
        }
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        vec![
            self.left_sort_exprs
                .as_ref()
                .cloned()
                .map(LexRequirement::from),
            self.right_sort_exprs
                .as_ref()
                .cloned()
                .map(LexRequirement::from),
        ]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(SymmetricHashJoinExec::try_new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.on.clone(),
            self.filter.clone(),
            &self.join_type,
            self.null_equals_null,
            self.left_sort_exprs.clone(),
            self.right_sort_exprs.clone(),
            self.mode,
        )?))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        // TODO stats: it is not possible in general to know the output size of joins
        Ok(Statistics::new_unknown(&self.schema()))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let left_partitions = self.left.output_partitioning().partition_count();
        let right_partitions = self.right.output_partitioning().partition_count();
        if left_partitions != right_partitions {
            return internal_err!(
                "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
                 consider using RepartitionExec"
            );
        }
        // If both sort expressions and the filter are present, then calculate
        // sorted filter expressions for both sides and build an expression graph:
        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
            self.left_sort_exprs(),
            self.right_sort_exprs(),
            &self.filter,
        ) {
            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
                let (left, right, graph) = prepare_sorted_exprs(
                    filter,
                    &self.left,
                    &self.right,
                    left_sort_exprs,
                    right_sort_exprs,
                )?;
                (Some(left), Some(right), Some(graph))
            }
            // If any of them is absent, then return None for all three values:
            _ => (None, None, None),
        };

        let (on_left, on_right) = self.on.iter().cloned().unzip();

        let left_side_joiner =
            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
        let right_side_joiner =
            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());

        let left_stream = self.left.execute(partition, Arc::clone(&context))?;

        let right_stream = self.right.execute(partition, Arc::clone(&context))?;

        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();

        let reservation = Arc::new(Mutex::new(
            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
                .register(context.memory_pool()),
        ));
        if let Some(g) = graph.as_ref() {
            reservation.lock().try_grow(g.size())?;
        }

        if enforce_batch_size_in_joins {
            Ok(Box::pin(SymmetricHashJoinStream {
                left_stream,
                right_stream,
                schema: self.schema(),
                filter: self.filter.clone(),
                join_type: self.join_type,
                random_state: self.random_state.clone(),
                left: left_side_joiner,
                right: right_side_joiner,
                column_indices: self.column_indices.clone(),
                metrics: StreamJoinMetrics::new(partition, &self.metrics),
                graph,
                left_sorted_filter_expr,
                right_sorted_filter_expr,
                null_equals_null: self.null_equals_null,
                state: SHJStreamState::PullRight,
                reservation,
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(SymmetricHashJoinStream {
                left_stream,
                right_stream,
                schema: self.schema(),
                filter: self.filter.clone(),
                join_type: self.join_type,
                random_state: self.random_state.clone(),
                left: left_side_joiner,
                right: right_side_joiner,
                column_indices: self.column_indices.clone(),
                metrics: StreamJoinMetrics::new(partition, &self.metrics),
                graph,
                left_sorted_filter_expr,
                right_sorted_filter_expr,
                null_equals_null: self.null_equals_null,
                state: SHJStreamState::PullRight,
                reservation,
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    /// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done,
    /// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan.
    /// Otherwise, it returns None.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };

        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );

        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }

        let Some(new_on) = update_join_on(
            &projection_as_columns[0..=far_right_left_col_ind as _],
            &projection_as_columns[far_left_right_col_ind as _..],
            self.on(),
            self.left().schema().fields().len(),
        ) else {
            return Ok(None);
        };

        let new_filter = if let Some(filter) = self.filter() {
            match update_join_filter(
                &projection_as_columns[0..=far_right_left_col_ind as _],
                &projection_as_columns[far_left_right_col_ind as _..],
                filter,
                self.left().schema().fields().len(),
            ) {
                Some(updated_filter) => Some(updated_filter),
                None => return Ok(None),
            }
        } else {
            None
        };

        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;

        Ok(Some(Arc::new(SymmetricHashJoinExec::try_new(
            Arc::new(new_left),
            Arc::new(new_right),
            new_on,
            new_filter,
            self.join_type(),
            self.null_equals_null(),
            self.right()
                .output_ordering()
                .map(|p| LexOrdering::new(p.to_vec())),
            self.left()
                .output_ordering()
                .map(|p| LexOrdering::new(p.to_vec())),
            self.partition_mode(),
        )?)))
    }
}

/// A stream that issues [RecordBatch]es as they arrive from either side of the join.
struct SymmetricHashJoinStream<T> {
    /// Input streams
    left_stream: SendableRecordBatchStream,
    right_stream: SendableRecordBatchStream,
    /// Input schema
    schema: Arc<Schema>,
    /// Join filter
    filter: Option<JoinFilter>,
    /// Type of the join
    join_type: JoinType,
    /// Left hash joiner
    left: OneSideHashJoiner,
    /// Right hash joiner
    right: OneSideHashJoiner,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Expression graph for range pruning
    graph: Option<ExprIntervalGraph>,
    /// Left globally sorted filter expr
    left_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Right globally sorted filter expr
    right_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Random state used for hashing initialization
    random_state: RandomState,
    /// If null_equals_null is true, null == null, else null != null
    null_equals_null: bool,
    /// Metrics
    metrics: StreamJoinMetrics,
    /// Memory reservation
    reservation: SharedMemoryReservation,
    /// State machine for input execution
    state: SHJStreamState,
    /// Transforms the output batch before returning.
    batch_transformer: T,
}

impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
    for SymmetricHashJoinStream<T>
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}

/// Determine the pruning length for `buffer`.
///
/// This function evaluates the build side filter expression, converts the
/// result into an array and determines the pruning length by performing a
/// binary search on the array.
///
/// # Arguments
///
/// * `buffer`: The record batch to be pruned.
/// * `build_side_filter_expr`: The filter expression on the build side used
///   to determine the pruning length.
///
/// # Returns
///
/// A [Result] object that contains the pruning length. The function will return
/// an error if
/// - there is an issue evaluating the build side filter expression;
/// - there is an issue converting the build side filter expression into an array.
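///
/// As intuition for the `bisect` call below, here is a minimal, hypothetical
/// example of a left bisect over one ascending column (values chosen purely
/// for illustration):
///
/// ```text
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int32Array};
/// use arrow::compute::SortOptions;
/// use datafusion_common::utils::bisect;
/// use datafusion_common::ScalarValue;
///
/// let arr = Arc::new(Int32Array::from(vec![1, 3, 5, 7, 9])) as ArrayRef;
/// let target = ScalarValue::Int32(Some(5));
/// // `bisect::<true>` returns the leftmost insertion point for `target`;
/// // the two rows before it (values 1 and 3) lie below the joinable range
/// // and would be pruned.
/// let n = bisect::<true>(&[arr], &[target], &[SortOptions::default()]).unwrap();
/// assert_eq!(n, 2);
/// ```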
fn determine_prune_length(
    buffer: &RecordBatch,
    build_side_filter_expr: &SortedFilterExpr,
) -> Result<usize> {
    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
    let interval = build_side_filter_expr.interval();
    // Evaluate the build side filter expression and convert it into an array:
    let batch_arr = origin_sorted_expr
        .expr
        .evaluate(buffer)?
        .into_array(buffer.num_rows())?;

    // Get the lower or upper interval based on the sort direction:
    let target = if origin_sorted_expr.options.descending {
        interval.upper().clone()
    } else {
        interval.lower().clone()
    };

    // Perform binary search on the array to determine the length of the record batch to be pruned:
    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
}

/// This method determines if the result of the join should be produced in the final step or not.
///
/// # Arguments
///
/// * `build_side` - Enum indicating the side of the join used as the build side.
/// * `join_type` - Enum indicating the type of join to be performed.
///
/// # Returns
///
/// A boolean indicating whether the result of the join should be produced in the final step or not.
/// The result will be true if the build side is JoinSide::Left and the join type is one of
/// JoinType::Left, JoinType::LeftAnti, JoinType::Full, JoinType::LeftSemi or JoinType::LeftMark.
/// If the build side is JoinSide::Right, the result will be true if the join type
/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi.
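///
/// For example (illustrative calls):
///
/// ```text
/// assert!(need_to_produce_result_in_final(JoinSide::Left, JoinType::Full));
/// assert!(!need_to_produce_result_in_final(JoinSide::Left, JoinType::Right));
/// ```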
fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
    if build_side == JoinSide::Left {
        matches!(
            join_type,
            JoinType::Left
                | JoinType::LeftAnti
                | JoinType::Full
                | JoinType::LeftSemi
                | JoinType::LeftMark
        )
    } else {
        matches!(
            join_type,
            JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi
        )
    }
}

/// Calculate indices by join type.
///
/// This method returns a tuple of two arrays: build and probe indices.
/// The length of both arrays will be the same.
///
/// # Arguments
///
/// * `build_side`: Join side which defines the build side.
/// * `prune_length`: Length of the prune data.
/// * `visited_rows`: Hash set of visited rows of the build side.
/// * `deleted_offset`: Deleted offset of the build side.
/// * `join_type`: The type of join to be performed.
///
/// # Returns
///
/// A tuple of two arrays of primitive types representing the build and probe indices.
fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
    build_side: JoinSide,
    prune_length: usize,
    visited_rows: &HashSet<usize>,
    deleted_offset: usize,
    join_type: JoinType,
) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
where
    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
{
    // Store the result in a tuple
    let result = match (build_side, join_type) {
        (JoinSide::Left, JoinType::LeftMark) => {
            let build_indices = (0..prune_length)
                .map(L::Native::from_usize)
                .collect::<PrimitiveArray<L>>();
            let probe_indices = (0..prune_length)
                .map(|idx| {
                    // For mark join we output a dummy index 0 to indicate the row had a match
                    visited_rows
                        .contains(&(idx + deleted_offset))
                        .then_some(R::Native::from_usize(0).unwrap())
                })
                .collect();
            (build_indices, probe_indices)
        }
        // In the case of `Left`/`LeftAnti`, `Right`/`RightAnti` or `Full` joins, get the anti indices
        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
        | (_, JoinType::Full) => {
            let build_unmatched_indices =
                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
            let mut builder =
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
            builder.append_nulls(build_unmatched_indices.len());
            let probe_indices = builder.finish();
            (build_unmatched_indices, probe_indices)
        }
        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
            let build_unmatched_indices =
                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
            let mut builder =
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
            builder.append_nulls(build_unmatched_indices.len());
            let probe_indices = builder.finish();
            (build_unmatched_indices, probe_indices)
        }
        // The case of other join types is not considered
        _ => unreachable!(),
    };
    Ok(result)
}

/// This function produces unmatched record results based on the build side,
/// join type and other parameters.
///
/// The method uses the first `prune_length` rows from the build side input buffer
/// to produce results.
///
/// # Arguments
///
/// * `build_hash_joiner` - The build side hash joiner whose buffer is used.
/// * `output_schema` - The schema of the final output record batch.
/// * `prune_length` - The number of rows to prune from the build side buffer.
/// * `probe_schema` - The schema of the probe [RecordBatch].
/// * `join_type` - The type of join to be performed.
/// * `column_indices` - Indices of columns that are being joined.
///
/// # Returns
///
/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
pub(crate) fn build_side_determined_results(
    build_hash_joiner: &OneSideHashJoiner,
    output_schema: &SchemaRef,
    prune_length: usize,
    probe_schema: SchemaRef,
    join_type: JoinType,
    column_indices: &[ColumnIndex],
) -> Result<Option<RecordBatch>> {
    // Check if we need to produce a result in the final output:
    if prune_length > 0
        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
    {
        // Calculate the indices for build and probe sides based on join type and build side:
        let (build_indices, probe_indices) = calculate_indices_by_join_type(
            build_hash_joiner.build_side,
            prune_length,
            &build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            join_type,
        )?;

        // Create an empty probe record batch:
        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
        // Build the final result from the indices of build and probe sides:
        build_batch_from_indices(
            output_schema.as_ref(),
            &build_hash_joiner.input_buffer,
            &empty_probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
        )
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    } else {
        // If we don't need to produce a result, return None:
        Ok(None)
    }
}

/// This method performs a join between the build side input buffer and the probe side batch.
///
/// # Arguments
///
/// * `build_hash_joiner` - Build side hash joiner.
/// * `probe_hash_joiner` - Probe side hash joiner.
/// * `schema` - A reference to the schema of the output record batch.
/// * `join_type` - The type of join to be performed.
/// * `filter` - An optional filter on the join condition.
/// * `probe_batch` - The second record batch to be joined.
/// * `column_indices` - An array of columns to be selected for the result of the join.
/// * `random_state` - The random state for the join.
/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining.
///
/// # Returns
///
/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi`, `RightSemi` or `LeftMark`.
/// If the join type is one of the above five, the function will return [None].
#[allow(clippy::too_many_arguments)]
pub(crate) fn join_with_probe_batch(
    build_hash_joiner: &mut OneSideHashJoiner,
    probe_hash_joiner: &mut OneSideHashJoiner,
    schema: &SchemaRef,
    join_type: JoinType,
    filter: Option<&JoinFilter>,
    probe_batch: &RecordBatch,
    column_indices: &[ColumnIndex],
    random_state: &RandomState,
    null_equals_null: bool,
) -> Result<Option<RecordBatch>> {
    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
        return Ok(None);
    }
    let (build_indices, probe_indices) = lookup_join_hashmap(
        &build_hash_joiner.hashmap,
        &build_hash_joiner.input_buffer,
        probe_batch,
        &build_hash_joiner.on,
        &probe_hash_joiner.on,
        random_state,
        null_equals_null,
        &mut build_hash_joiner.hashes_buffer,
        Some(build_hash_joiner.deleted_offset),
    )?;

    let (build_indices, probe_indices) = if let Some(filter) = filter {
        apply_join_filter_to_indices(
            &build_hash_joiner.input_buffer,
            probe_batch,
            build_indices,
            probe_indices,
            filter,
            build_hash_joiner.build_side,
        )?
    } else {
        (build_indices, probe_indices)
    };

    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
        record_visited_indices(
            &mut build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            &build_indices,
        );
    }
    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
        record_visited_indices(
            &mut probe_hash_joiner.visited_rows,
            probe_hash_joiner.offset,
            &probe_indices,
        );
    }
    if matches!(
        join_type,
        JoinType::LeftAnti
            | JoinType::RightAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::RightSemi
    ) {
        Ok(None)
    } else {
        build_batch_from_indices(
            schema,
            &build_hash_joiner.input_buffer,
            probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
        )
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    }
}

/// This method performs lookups against the JoinHashMap by the hash values of the join-key columns,
/// and handles potential hash collisions.
///
/// # Arguments
///
/// * `build_hashmap` - Hashmap collected from build side data.
/// * `build_batch` - Build side record batch.
/// * `probe_batch` - Probe side record batch.
/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
/// * `random_state` - The random state for the join.
/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining.
/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
/// * `deleted_offset` - Deleted offset for build side data.
///
/// # Returns
///
/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side,
/// matched by join key columns.
#[allow(clippy::too_many_arguments)]
fn lookup_join_hashmap(
    build_hashmap: &PruningJoinHashMap,
    build_batch: &RecordBatch,
    probe_batch: &RecordBatch,
    build_on: &[PhysicalExprRef],
    probe_on: &[PhysicalExprRef],
    random_state: &RandomState,
    null_equals_null: bool,
    hashes_buffer: &mut Vec<u64>,
    deleted_offset: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
    let keys_values = probe_on
        .iter()
        .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))
        .collect::<Result<Vec<_>>>()?;
    let build_join_values = build_on
        .iter()
        .map(|c| c.evaluate(build_batch)?.into_array(build_batch.num_rows()))
        .collect::<Result<Vec<_>>>()?;

    hashes_buffer.clear();
    hashes_buffer.resize(probe_batch.num_rows(), 0);
    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;

    // As SymmetricHashJoin uses a LIFO JoinHashMap, the chained-list algorithm
    // will return build indices for each probe row in reverse order, e.g.:
    // Build Indices: [5, 4, 3]
    // Probe Indices: [1, 1, 1]
    //
    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
    // Let's consider probe rows [0,1] as an example:
    //
    // When the probe iteration sequence is reversed, the following pairings can be derived:
    //
    // For probe row 1:
    //     (5, 1)
    //     (4, 1)
    //     (3, 1)
    //
    // For probe row 0:
    //     (5, 0)
    //     (4, 0)
    //     (3, 0)
    //
    // After reversing both sets of indices, we obtain reversed indices:
    //
    //     (3,0)
    //     (4,0)
    //     (5,0)
    //     (3,1)
    //     (4,1)
    //     (5,1)
    //
    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
    let (mut matched_probe, mut matched_build) = build_hashmap
        .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset);

    matched_probe.reverse();
    matched_build.reverse();

    let build_indices: UInt64Array = matched_build.into();
    let probe_indices: UInt32Array = matched_probe.into();

    let (build_indices, probe_indices) = equal_rows_arr(
        &build_indices,
        &probe_indices,
        &build_join_values,
        &keys_values,
        null_equals_null,
    )?;

    Ok((build_indices, probe_indices))
}

pub struct OneSideHashJoiner {
    /// Build side
    build_side: JoinSide,
    /// Input record batch buffer
    pub input_buffer: RecordBatch,
    /// On-condition (join key) expressions from this side
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hashmap
    pub(crate) hashmap: PruningJoinHashMap,
    /// Reusable buffer for hash values
    pub(crate) hashes_buffer: Vec<u64>,
    /// Global indices of matched (visited) rows
    pub(crate) visited_rows: HashSet<usize>,
    /// Global offset of incoming rows, used when inserting into the hash map
    pub(crate) offset: usize,
    /// Number of rows pruned (deleted) from the front of the input buffer
    pub(crate) deleted_offset: usize,
}

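// A small worked illustration of the offset bookkeeping above (values are
// hypothetical): suppose this side has buffered rows with global ids 0..100
// and `prune_internal_state(30)` runs. `deleted_offset` then grows from 0 to
// 30 and `input_buffer` keeps global rows 30..100 at buffer indices 0..70, so
// buffer index `i` corresponds to global row `i + deleted_offset`. This is why
// `visited_rows` (which stores global ids) is queried with
// `idx + deleted_offset` when computing pruning indices.
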
impl OneSideHashJoiner {
    pub fn size(&self) -> usize {
        let mut size = 0;
        size += size_of_val(self);
        size += size_of_val(&self.build_side);
        size += self.input_buffer.get_array_memory_size();
        size += size_of_val(&self.on);
        size += self.hashmap.size();
        size += self.hashes_buffer.capacity() * size_of::<u64>();
        size += self.visited_rows.capacity() * size_of::<usize>();
        size += size_of_val(&self.offset);
        size += size_of_val(&self.deleted_offset);
        size
    }

    pub fn new(
        build_side: JoinSide,
        on: Vec<PhysicalExprRef>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            build_side,
            input_buffer: RecordBatch::new_empty(schema),
            on,
            hashmap: PruningJoinHashMap::with_capacity(0),
            hashes_buffer: vec![],
            visited_rows: HashSet::new(),
            offset: 0,
            deleted_offset: 0,
        }
    }

    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
    ///
    /// # Arguments
    ///
    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
    /// * `random_state` - The random state used to hash values
    ///
    /// # Returns
    ///
    /// Returns a [Result] encapsulating any intermediate errors.
    pub(crate) fn update_internal_state(
        &mut self,
        batch: &RecordBatch,
        random_state: &RandomState,
    ) -> Result<()> {
        // Merge the incoming batch with the existing input buffer:
        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
        // Resize the hashes buffer to the number of rows in the incoming batch:
        self.hashes_buffer.resize(batch.num_rows(), 0);
        // Update the hashmap with the join key values and hashes of the incoming batch:
        update_hash(
            &self.on,
            batch,
            &mut self.hashmap,
            self.offset,
            random_state,
            &mut self.hashes_buffer,
            self.deleted_offset,
            false,
        )?;
        Ok(())
    }

    /// Calculate prune length.
    ///
    /// # Arguments
    ///
    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression.
    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
    /// * `graph` - A mutable reference to the physical expression graph.
    ///
    /// # Returns
    ///
    /// A Result object that contains the pruning length.
    pub(crate) fn calculate_prune_length_with_probe_batch(
        &mut self,
        build_side_sorted_filter_expr: &mut SortedFilterExpr,
        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
        graph: &mut ExprIntervalGraph,
    ) -> Result<usize> {
        // Return early if the input buffer is empty:
        if self.input_buffer.num_rows() == 0 {
            return Ok(0);
        }
        // Collect the sorted filter expressions into a vector of
        // (node_index, interval) tuples:
        let mut filter_intervals = vec![];
        for expr in [
            &build_side_sorted_filter_expr,
            &probe_side_sorted_filter_expr,
        ] {
            filter_intervals.push((expr.node_index(), expr.interval().clone()))
        }
        // Update the physical expression graph using the join filter intervals:
        graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?;
        // Extract the new join filter interval for the build side:
        let calculated_build_side_interval = filter_intervals.remove(0).1;
        // If the interval has not changed, return early without pruning:
        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
            return Ok(0);
        }
        // Update the build side interval and determine the pruning length:
        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);

        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
    }

    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
        // Prune the hash values:
        self.hashmap.prune_hash_values(
            prune_length,
            self.deleted_offset as u64,
            HASHMAP_SHRINK_SCALE_FACTOR,
        );
        // Remove pruned rows from the visited rows set:
        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
            self.visited_rows.remove(&row);
        }
        // Update the input buffer after pruning:
        self.input_buffer = self
            .input_buffer
            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
        // Increment the deleted offset:
        self.deleted_offset += prune_length;
        Ok(())
    }
}

/// `SymmetricHashJoinStream` manages incremental join operations between two
/// streams. Unlike traditional join approaches that need to scan one side of
/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
/// more dynamic join operations by working with streams as they emit data. This
/// approach allows for more efficient processing, particularly in scenarios
/// where waiting for complete data materialization is not feasible or optimal.
/// It provides a framework for handling various states of such a join
/// process, ensuring that join logic is efficiently executed as data becomes
/// available from either stream.
///
/// This implementation performs eager joins of data from two different asynchronous
/// streams, typically referred to as the left and right streams. It provides
/// a comprehensive set of methods to control and execute the join process,
/// leveraging the states defined in `SHJStreamState`. Methods are
/// primarily focused on asynchronously fetching data batches from each stream,
/// processing them, and managing transitions between various states of the join.
///
/// This implementation uses a state machine approach to navigate different
/// stages of the join operation, handling data from both streams and determining
/// when the join completes.
///
/// State transitions:
/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
///       processing the batch.
///     - On error (`Some(Err(e))`), the error is returned, and the state remains
///       unchanged.
///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
///       to proceed with the join process.
/// - From `PullRight` to `PullLeft` or `RightExhausted`:
///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
///     - If a batch is available, state changes to `PullLeft` for processing.
///     - On error, the error is returned without changing the state.
///     - If the right stream is exhausted (`None`), state transitions to `RightExhausted`,
///       with a `Continue` result.
/// - Handling `RightExhausted` and `LeftExhausted`:
///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
///     when streams are exhausted:
///     - They attempt to continue processing with the other stream.
///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
/// - Transition to `BothExhausted { final_result: true }`:
///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
///     exhausted, indicating completion of processing and availability of final results.
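///
/// For reference, a sketch of the state enum these methods drive (the actual
/// definition lives elsewhere in this module; variants are shown as used below):
///
/// ```text
/// enum SHJStreamState {
///     PullRight,
///     PullLeft,
///     RightExhausted,
///     LeftExhausted,
///     BothExhausted { final_result: bool },
/// }
/// ```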
impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
    /// Implements the main polling logic for the join stream.
    ///
    /// This method continuously checks the state of the join stream and
    /// acts accordingly by delegating the handling to appropriate sub-methods
    /// depending on the current state.
    ///
    /// # Arguments
    ///
    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
    ///
    /// # Returns
    ///
    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
    fn poll_next_impl(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            match self.batch_transformer.next() {
                None => {
                    let result = match self.state() {
                        SHJStreamState::PullRight => {
                            ready!(self.fetch_next_from_right_stream(cx))
                        }
                        SHJStreamState::PullLeft => {
                            ready!(self.fetch_next_from_left_stream(cx))
                        }
                        SHJStreamState::RightExhausted => {
                            ready!(self.handle_right_stream_end(cx))
                        }
                        SHJStreamState::LeftExhausted => {
                            ready!(self.handle_left_stream_end(cx))
                        }
                        SHJStreamState::BothExhausted {
                            final_result: false,
                        } => self.prepare_for_final_results_after_exhaustion(),
                        SHJStreamState::BothExhausted { final_result: true } => {
                            return Poll::Ready(None);
                        }
                    };

                    match result? {
                        StatefulStreamResult::Ready(None) => {
                            return Poll::Ready(None);
                        }
                        StatefulStreamResult::Ready(Some(batch)) => {
                            self.batch_transformer.set_batch(batch);
                        }
                        _ => {}
                    }
                }
                Some((batch, _)) => {
                    self.metrics.output_batches.add(1);
                    self.metrics.output_rows.add(batch.num_rows());
                    return Poll::Ready(Some(Ok(batch)));
                }
            }
        }
    }


    /// Asynchronously pulls the next batch from the right stream.
    ///
    /// This method polls the right stream for the next batch; empty batches are
    /// skipped. When a non-empty batch arrives, the state switches to `PullLeft`
    /// and handling is delegated to `process_batch_from_right`. If the stream
    /// ends, the state is set to `RightExhausted`.
1344    ///
1345    /// # Returns
1346    ///
1347    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1348    fn fetch_next_from_right_stream(
1349        &mut self,
1350        cx: &mut Context<'_>,
1351    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1352        match ready!(self.right_stream().poll_next_unpin(cx)) {
1353            Some(Ok(batch)) => {
1354                if batch.num_rows() == 0 {
1355                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1356                }
1357                self.set_state(SHJStreamState::PullLeft);
1358                Poll::Ready(self.process_batch_from_right(batch))
1359            }
1360            Some(Err(e)) => Poll::Ready(Err(e)),
1361            None => {
1362                self.set_state(SHJStreamState::RightExhausted);
1363                Poll::Ready(Ok(StatefulStreamResult::Continue))
1364            }
1365        }
1366    }
1367
1368    /// Asynchronously pulls the next batch from the left stream.
1369    ///
    /// This method polls the left stream for the next batch; empty batches are
    /// skipped. When a non-empty batch arrives, the state switches to `PullRight`
    /// and handling is delegated to `process_batch_from_left`. If the stream
    /// ends, the state is set to `LeftExhausted`.
1373    ///
1374    /// # Returns
1375    ///
1376    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1377    fn fetch_next_from_left_stream(
1378        &mut self,
1379        cx: &mut Context<'_>,
1380    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1381        match ready!(self.left_stream().poll_next_unpin(cx)) {
1382            Some(Ok(batch)) => {
1383                if batch.num_rows() == 0 {
1384                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1385                }
1386                self.set_state(SHJStreamState::PullRight);
1387                Poll::Ready(self.process_batch_from_left(batch))
1388            }
1389            Some(Err(e)) => Poll::Ready(Err(e)),
1390            None => {
1391                self.set_state(SHJStreamState::LeftExhausted);
1392                Poll::Ready(Ok(StatefulStreamResult::Continue))
1393            }
1394        }
1395    }
1396
1397    /// Asynchronously handles the scenario when the right stream is exhausted.
1398    ///
    /// When the right stream is exhausted, this method pulls from the left
    /// stream instead; empty batches are skipped. Non-empty batches are handled
    /// by `process_batch_after_right_end`. If the left stream is also exhausted,
    /// the state is set to `BothExhausted { final_result: false }`.
1403    ///
1404    /// # Returns
1405    ///
1406    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1407    fn handle_right_stream_end(
1408        &mut self,
1409        cx: &mut Context<'_>,
1410    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1411        match ready!(self.left_stream().poll_next_unpin(cx)) {
1412            Some(Ok(batch)) => {
1413                if batch.num_rows() == 0 {
1414                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1415                }
1416                Poll::Ready(self.process_batch_after_right_end(batch))
1417            }
1418            Some(Err(e)) => Poll::Ready(Err(e)),
1419            None => {
1420                self.set_state(SHJStreamState::BothExhausted {
1421                    final_result: false,
1422                });
1423                Poll::Ready(Ok(StatefulStreamResult::Continue))
1424            }
1425        }
1426    }
1427
1428    /// Asynchronously handles the scenario when the left stream is exhausted.
1429    ///
    /// When the left stream is exhausted, this method pulls from the right
    /// stream instead; empty batches are skipped. Non-empty batches are handled
    /// by `process_batch_after_left_end`. If the right stream is also exhausted,
    /// the state is set to `BothExhausted { final_result: false }`.
1434    ///
1435    /// # Returns
1436    ///
1437    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1438    fn handle_left_stream_end(
1439        &mut self,
1440        cx: &mut Context<'_>,
1441    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1442        match ready!(self.right_stream().poll_next_unpin(cx)) {
1443            Some(Ok(batch)) => {
1444                if batch.num_rows() == 0 {
1445                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1446                }
1447                Poll::Ready(self.process_batch_after_left_end(batch))
1448            }
1449            Some(Err(e)) => Poll::Ready(Err(e)),
1450            None => {
1451                self.set_state(SHJStreamState::BothExhausted {
1452                    final_result: false,
1453                });
1454                Poll::Ready(Ok(StatefulStreamResult::Continue))
1455            }
1456        }
1457    }
1458
1459    /// Handles the state when both streams are exhausted and final results are yet to be produced.
1460    ///
    /// This method switches the state to `BothExhausted { final_result: true }`
    /// and then produces the final results via
    /// `process_batches_before_finalization`.
1464    ///
1465    /// # Returns
1466    ///
1467    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
1468    fn prepare_for_final_results_after_exhaustion(
1469        &mut self,
1470    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1471        self.set_state(SHJStreamState::BothExhausted { final_result: true });
1472        self.process_batches_before_finalization()
1473    }
1474
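    /// Joins a batch pulled from the right stream against the buffered left
    /// side, wrapping the optional output into a `StatefulStreamResult`.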
1475    fn process_batch_from_right(
1476        &mut self,
1477        batch: RecordBatch,
1478    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1479        self.perform_join_for_given_side(batch, JoinSide::Right)
1480            .map(|maybe_batch| {
1481                if maybe_batch.is_some() {
1482                    StatefulStreamResult::Ready(maybe_batch)
1483                } else {
1484                    StatefulStreamResult::Continue
1485                }
1486            })
1487    }
1488
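    /// Joins a batch pulled from the left stream against the buffered right
    /// side; the mirror image of `process_batch_from_right`.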
1489    fn process_batch_from_left(
1490        &mut self,
1491        batch: RecordBatch,
1492    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1493        self.perform_join_for_given_side(batch, JoinSide::Left)
1494            .map(|maybe_batch| {
1495                if maybe_batch.is_some() {
1496                    StatefulStreamResult::Ready(maybe_batch)
1497                } else {
1498                    StatefulStreamResult::Continue
1499                }
1500            })
1501    }
1502
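    /// Once the left stream has ended, remaining right-side batches are
    /// processed exactly like regular right-side batches.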
1503    fn process_batch_after_left_end(
1504        &mut self,
1505        right_batch: RecordBatch,
1506    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1507        self.process_batch_from_right(right_batch)
1508    }
1509
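    /// Once the right stream has ended, remaining left-side batches are
    /// processed exactly like regular left-side batches.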
1510    fn process_batch_after_right_end(
1511        &mut self,
1512        left_batch: RecordBatch,
1513    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1514        self.process_batch_from_left(left_batch)
1515    }
1516
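    /// Builds the final results from the rows still buffered on both sides
    /// once both input streams are exhausted (e.g. unmatched build-side rows
    /// for outer and anti joins), combining them into a single batch if any
    /// exist.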
1517    fn process_batches_before_finalization(
1518        &mut self,
1519    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1520        // Get the left side results:
1521        let left_result = build_side_determined_results(
1522            &self.left,
1523            &self.schema,
1524            self.left.input_buffer.num_rows(),
1525            self.right.input_buffer.schema(),
1526            self.join_type,
1527            &self.column_indices,
1528        )?;
1529        // Get the right side results:
1530        let right_result = build_side_determined_results(
1531            &self.right,
1532            &self.schema,
1533            self.right.input_buffer.num_rows(),
1534            self.left.input_buffer.schema(),
1535            self.join_type,
1536            &self.column_indices,
1537        )?;
1538
1539        // Combine the left and right results:
1540        let result = combine_two_batches(&self.schema, left_result, right_result)?;
1541
1542        // Return the result:
1543        if result.is_some() {
1544            return Ok(StatefulStreamResult::Ready(result));
1545        }
1546        Ok(StatefulStreamResult::Continue)
1547    }
1548
1549    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
1550        &mut self.right_stream
1551    }
1552
1553    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
1554        &mut self.left_stream
1555    }
1556
1557    fn set_state(&mut self, state: SHJStreamState) {
1558        self.state = state;
1559    }
1560
1561    fn state(&mut self) -> SHJStreamState {
1562        self.state.clone()
1563    }
1564
1565    fn size(&self) -> usize {
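        // Note: `size_of_val` on pointer-like fields (`Arc`s, `Vec` headers)
        // measures only the handle, not the pointed-to data, so this is an
        // estimate; the buffered join state is accounted for separately via
        // `self.left.size()` and `self.right.size()`.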
1566        let mut size = 0;
1567        size += size_of_val(&self.schema);
1568        size += size_of_val(&self.filter);
1569        size += size_of_val(&self.join_type);
1570        size += self.left.size();
1571        size += self.right.size();
1572        size += size_of_val(&self.column_indices);
1573        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
1574        size += size_of_val(&self.left_sorted_filter_expr);
1575        size += size_of_val(&self.right_sorted_filter_expr);
1576        size += size_of_val(&self.random_state);
1577        size += size_of_val(&self.null_equals_null);
1578        size += size_of_val(&self.metrics);
1579        size
1580    }
1581
1582    /// Performs a join operation for the specified `probe_side` (either left or right).
1583    /// This function:
1584    /// 1. Determines which side is the probe and which is the build side.
1585    /// 2. Updates metrics based on the batch that was polled.
1586    /// 3. Executes the join with the given `probe_batch`.
1587    /// 4. Optionally computes anti-join results if all conditions are met.
1588    /// 5. Combines the results and returns a combined batch or `None` if no batch was produced.
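    ///
    /// The probe/build assignment in step 1 is symmetric (sketch):
    ///
    /// ```text
    /// probe_side == Left  =>  probe = self.left,  build = self.right
    /// probe_side == Right =>  probe = self.right, build = self.left
    /// ```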
1589    fn perform_join_for_given_side(
1590        &mut self,
1591        probe_batch: RecordBatch,
1592        probe_side: JoinSide,
1593    ) -> Result<Option<RecordBatch>> {
1594        let (
1595            probe_hash_joiner,
1596            build_hash_joiner,
1597            probe_side_sorted_filter_expr,
1598            build_side_sorted_filter_expr,
1599            probe_side_metrics,
        ) = if probe_side == JoinSide::Left {
1601            (
1602                &mut self.left,
1603                &mut self.right,
1604                &mut self.left_sorted_filter_expr,
1605                &mut self.right_sorted_filter_expr,
1606                &mut self.metrics.left,
1607            )
1608        } else {
1609            (
1610                &mut self.right,
1611                &mut self.left,
1612                &mut self.right_sorted_filter_expr,
1613                &mut self.left_sorted_filter_expr,
1614                &mut self.metrics.right,
1615            )
1616        };
1617        // Update the metrics for the stream that was polled:
1618        probe_side_metrics.input_batches.add(1);
1619        probe_side_metrics.input_rows.add(probe_batch.num_rows());
        // Add the new batch to the probe side's own (build-side) state; in a
        // symmetric hash join, each side also acts as the build side for
        // batches arriving from the other stream:
1621        probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?;
1622        // Join the two sides:
1623        let equal_result = join_with_probe_batch(
1624            build_hash_joiner,
1625            probe_hash_joiner,
1626            &self.schema,
1627            self.join_type,
1628            self.filter.as_ref(),
1629            &probe_batch,
1630            &self.column_indices,
1631            &self.random_state,
1632            self.null_equals_null,
1633        )?;
1634        // Increment the offset for the probe hash joiner:
1635        probe_hash_joiner.offset += probe_batch.num_rows();
1636
1637        let anti_result = if let (
1638            Some(build_side_sorted_filter_expr),
1639            Some(probe_side_sorted_filter_expr),
1640            Some(graph),
1641        ) = (
1642            build_side_sorted_filter_expr.as_mut(),
1643            probe_side_sorted_filter_expr.as_mut(),
1644            self.graph.as_mut(),
1645        ) {
1646            // Calculate filter intervals:
1647            calculate_filter_expr_intervals(
1648                &build_hash_joiner.input_buffer,
1649                build_side_sorted_filter_expr,
1650                &probe_batch,
1651                probe_side_sorted_filter_expr,
1652            )?;
1653            let prune_length = build_hash_joiner
1654                .calculate_prune_length_with_probe_batch(
1655                    build_side_sorted_filter_expr,
1656                    probe_side_sorted_filter_expr,
1657                    graph,
1658                )?;
1659            let result = build_side_determined_results(
1660                build_hash_joiner,
1661                &self.schema,
1662                prune_length,
1663                probe_batch.schema(),
1664                self.join_type,
1665                &self.column_indices,
1666            )?;
1667            build_hash_joiner.prune_internal_state(prune_length)?;
1668            result
1669        } else {
1670            None
1671        };
1672
1673        // Combine results:
1674        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
1675        let capacity = self.size();
1676        self.metrics.stream_memory_usage.set(capacity);
1677        self.reservation.lock().try_resize(capacity)?;
1678        Ok(result)
1679    }
1680}
1681
/// Represents the various states of a symmetric hash join stream operation.
///
/// This enum tracks the current state of streaming during a join operation. It
/// indicates which side of the join should be pulled next, and whether one (or
/// both) sides have been exhausted, enabling efficient management of resources
/// during the join process.
1689#[derive(Clone, Debug)]
1690pub enum SHJStreamState {
1691    /// Indicates that the next step should pull from the right side of the join.
1692    PullRight,
1693
1694    /// Indicates that the next step should pull from the left side of the join.
1695    PullLeft,
1696
1697    /// State representing that the right side of the join has been fully processed.
1698    RightExhausted,
1699
1700    /// State representing that the left side of the join has been fully processed.
1701    LeftExhausted,
1702
1703    /// Represents a state where both sides of the join are exhausted.
1704    ///
1705    /// The `final_result` field indicates whether the join operation has
1706    /// produced a final result or not.
1707    BothExhausted { final_result: bool },
1708}
1709
1710#[cfg(test)]
1711mod tests {
1712    use std::collections::HashMap;
1713    use std::sync::{LazyLock, Mutex};
1714
1715    use super::*;
1716    use crate::joins::test_utils::{
1717        build_sides_record_batches, compare_batches, complicated_filter,
1718        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1719        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1720        partitioned_sym_join_with_filter, split_record_batches,
1721    };
1722
1723    use arrow::compute::SortOptions;
1724    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1725    use datafusion_common::ScalarValue;
1726    use datafusion_execution::config::SessionConfig;
1727    use datafusion_expr::Operator;
1728    use datafusion_physical_expr::expressions::{binary, col, lit, Column};
1729    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
1730
1731    use rstest::*;
1732
1733    const TABLE_SIZE: i32 = 30;
1734
1735    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
1736    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)
1737
1738    // Cache for storing tables
1739    static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
1740        LazyLock::new(|| Mutex::new(HashMap::new()));
1741
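    /// Returns the (left, right) partitioned test tables for the given
    /// cardinality and batch size, building them on first use and caching them
    /// so that repeated test cases share the same input data.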
1742    fn get_or_create_table(
1743        cardinality: (i32, i32),
1744        batch_size: usize,
1745    ) -> Result<TableValue> {
1746        {
1747            let cache = TABLE_CACHE.lock().unwrap();
1748            if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1749                return Ok(table.clone());
1750            }
1751        }
1752
1753        // If not, create the table
1754        let (left_batch, right_batch) =
1755            build_sides_record_batches(TABLE_SIZE, cardinality)?;
1756
1757        let (left_partition, right_partition) = (
1758            split_record_batches(&left_batch, batch_size)?,
1759            split_record_batches(&right_batch, batch_size)?,
1760        );
1761
        // Lock the cache again and store the table for reuse by later calls:
        let mut cache = TABLE_CACHE.lock().unwrap();
1766        cache.insert(
1767            (cardinality.0, cardinality.1, batch_size),
1768            (left_partition.clone(), right_partition.clone()),
1769        );
1770
1771        Ok((left_partition, right_partition))
1772    }
1773
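    /// Runs the same join as a partitioned symmetric hash join and as a
    /// partitioned hash join, then asserts that the two produce identical
    /// result batches; this serves as the correctness oracle for the tests
    /// below.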
1774    pub async fn experiment(
1775        left: Arc<dyn ExecutionPlan>,
1776        right: Arc<dyn ExecutionPlan>,
1777        filter: Option<JoinFilter>,
1778        join_type: JoinType,
1779        on: JoinOn,
1780        task_ctx: Arc<TaskContext>,
1781    ) -> Result<()> {
1782        let first_batches = partitioned_sym_join_with_filter(
1783            Arc::clone(&left),
1784            Arc::clone(&right),
1785            on.clone(),
1786            filter.clone(),
1787            &join_type,
1788            false,
1789            Arc::clone(&task_ctx),
1790        )
1791        .await?;
1792        let second_batches = partitioned_hash_join_with_filter(
1793            left, right, on, filter, &join_type, false, task_ctx,
1794        )
1795        .await?;
1796        compare_batches(&first_batches, &second_batches);
1797        Ok(())
1798    }
1799
1800    #[rstest]
1801    #[tokio::test(flavor = "multi_thread")]
1802    async fn complex_join_all_one_ascending_numeric(
1803        #[values(
1804            JoinType::Inner,
1805            JoinType::Left,
1806            JoinType::Right,
1807            JoinType::RightSemi,
1808            JoinType::LeftSemi,
1809            JoinType::LeftAnti,
1810            JoinType::LeftMark,
1811            JoinType::RightAnti,
1812            JoinType::Full
1813        )]
1814        join_type: JoinType,
1815        #[values(
            (4, 5),
            (12, 17),
        )]
1819        cardinality: (i32, i32),
1820    ) -> Result<()> {
1821        // a + b > c + 10 AND a + b < c + 100
1822        let task_ctx = Arc::new(TaskContext::default());
1823
1824        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
1825
1826        let left_schema = &left_partition[0].schema();
1827        let right_schema = &right_partition[0].schema();
1828
1829        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
1830            expr: binary(
1831                col("la1", left_schema)?,
1832                Operator::Plus,
1833                col("la2", left_schema)?,
1834                left_schema,
1835            )?,
1836            options: SortOptions::default(),
1837        }]);
1838        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
1839            expr: col("ra1", right_schema)?,
1840            options: SortOptions::default(),
1841        }]);
1842        let (left, right) = create_memory_table(
1843            left_partition,
1844            right_partition,
1845            vec![left_sorted],
1846            vec![right_sorted],
1847        )?;
1848
1849        let on = vec![(
1850            binary(
1851                col("lc1", left_schema)?,
1852                Operator::Plus,
1853                lit(ScalarValue::Int32(Some(1))),
1854                left_schema,
1855            )?,
1856            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1857        )];
1858
1859        let intermediate_schema = Schema::new(vec![
1860            Field::new("0", DataType::Int32, true),
1861            Field::new("1", DataType::Int32, true),
1862            Field::new("2", DataType::Int32, true),
1863        ]);
1864        let filter_expr = complicated_filter(&intermediate_schema)?;
1865        let column_indices = vec![
1866            ColumnIndex {
1867                index: left_schema.index_of("la1")?,
1868                side: JoinSide::Left,
1869            },
1870            ColumnIndex {
1871                index: left_schema.index_of("la2")?,
1872                side: JoinSide::Left,
1873            },
1874            ColumnIndex {
1875                index: right_schema.index_of("ra1")?,
1876                side: JoinSide::Right,
1877            },
1878        ];
1879        let filter =
1880            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1881
1882        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1883        Ok(())
1884    }
1885
1886    #[rstest]
1887    #[tokio::test(flavor = "multi_thread")]
1888    async fn join_all_one_ascending_numeric(
1889        #[values(
1890            JoinType::Inner,
1891            JoinType::Left,
1892            JoinType::Right,
1893            JoinType::RightSemi,
1894            JoinType::LeftSemi,
1895            JoinType::LeftAnti,
1896            JoinType::LeftMark,
1897            JoinType::RightAnti,
1898            JoinType::Full
1899        )]
1900        join_type: JoinType,
1901        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1902    ) -> Result<()> {
1903        let task_ctx = Arc::new(TaskContext::default());
1904        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1905
1906        let left_schema = &left_partition[0].schema();
1907        let right_schema = &right_partition[0].schema();
1908
1909        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
1910            expr: col("la1", left_schema)?,
1911            options: SortOptions::default(),
1912        }]);
1913        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
1914            expr: col("ra1", right_schema)?,
1915            options: SortOptions::default(),
1916        }]);
1917        let (left, right) = create_memory_table(
1918            left_partition,
1919            right_partition,
1920            vec![left_sorted],
1921            vec![right_sorted],
1922        )?;
1923
1924        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
1925
1926        let intermediate_schema = Schema::new(vec![
1927            Field::new("left", DataType::Int32, true),
1928            Field::new("right", DataType::Int32, true),
1929        ]);
1930        let filter_expr = join_expr_tests_fixture_i32(
1931            case_expr,
1932            col("left", &intermediate_schema)?,
1933            col("right", &intermediate_schema)?,
1934        );
1935        let column_indices = vec![
1936            ColumnIndex {
1937                index: 0,
1938                side: JoinSide::Left,
1939            },
1940            ColumnIndex {
1941                index: 0,
1942                side: JoinSide::Right,
1943            },
1944        ];
1945        let filter =
1946            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1947
1948        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1949        Ok(())
1950    }
1951
1952    #[rstest]
1953    #[tokio::test(flavor = "multi_thread")]
1954    async fn join_without_sort_information(
1955        #[values(
1956            JoinType::Inner,
1957            JoinType::Left,
1958            JoinType::Right,
1959            JoinType::RightSemi,
1960            JoinType::LeftSemi,
1961            JoinType::LeftAnti,
1962            JoinType::LeftMark,
1963            JoinType::RightAnti,
1964            JoinType::Full
1965        )]
1966        join_type: JoinType,
1967        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1968    ) -> Result<()> {
1969        let task_ctx = Arc::new(TaskContext::default());
1970        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1971
1972        let left_schema = &left_partition[0].schema();
1973        let right_schema = &right_partition[0].schema();
1974        let (left, right) =
1975            create_memory_table(left_partition, right_partition, vec![], vec![])?;
1976
1977        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
1978
1979        let intermediate_schema = Schema::new(vec![
1980            Field::new("left", DataType::Int32, true),
1981            Field::new("right", DataType::Int32, true),
1982        ]);
1983        let filter_expr = join_expr_tests_fixture_i32(
1984            case_expr,
1985            col("left", &intermediate_schema)?,
1986            col("right", &intermediate_schema)?,
1987        );
1988        let column_indices = vec![
1989            ColumnIndex {
1990                index: 5,
1991                side: JoinSide::Left,
1992            },
1993            ColumnIndex {
1994                index: 5,
1995                side: JoinSide::Right,
1996            },
1997        ];
1998        let filter =
1999            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2000
2001        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2002        Ok(())
2003    }
2004
2005    #[rstest]
2006    #[tokio::test(flavor = "multi_thread")]
2007    async fn join_without_filter(
2008        #[values(
2009            JoinType::Inner,
2010            JoinType::Left,
2011            JoinType::Right,
2012            JoinType::RightSemi,
2013            JoinType::LeftSemi,
2014            JoinType::LeftAnti,
2015            JoinType::LeftMark,
2016            JoinType::RightAnti,
2017            JoinType::Full
2018        )]
2019        join_type: JoinType,
2020    ) -> Result<()> {
2021        let task_ctx = Arc::new(TaskContext::default());
2022        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2023        let left_schema = &left_partition[0].schema();
2024        let right_schema = &right_partition[0].schema();
2025        let (left, right) =
2026            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2027
2028        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2029        experiment(left, right, None, join_type, on, task_ctx).await?;
2030        Ok(())
2031    }
2032
2033    #[rstest]
2034    #[tokio::test(flavor = "multi_thread")]
2035    async fn join_all_one_descending_numeric_particular(
2036        #[values(
2037            JoinType::Inner,
2038            JoinType::Left,
2039            JoinType::Right,
2040            JoinType::RightSemi,
2041            JoinType::LeftSemi,
2042            JoinType::LeftAnti,
2043            JoinType::LeftMark,
2044            JoinType::RightAnti,
2045            JoinType::Full
2046        )]
2047        join_type: JoinType,
2048        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2049    ) -> Result<()> {
2050        let task_ctx = Arc::new(TaskContext::default());
2051        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2052        let left_schema = &left_partition[0].schema();
2053        let right_schema = &right_partition[0].schema();
2054        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2055            expr: col("la1_des", left_schema)?,
2056            options: SortOptions {
2057                descending: true,
2058                nulls_first: true,
2059            },
2060        }]);
2061        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2062            expr: col("ra1_des", right_schema)?,
2063            options: SortOptions {
2064                descending: true,
2065                nulls_first: true,
2066            },
2067        }]);
2068        let (left, right) = create_memory_table(
2069            left_partition,
2070            right_partition,
2071            vec![left_sorted],
2072            vec![right_sorted],
2073        )?;
2074
2075        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2076
2077        let intermediate_schema = Schema::new(vec![
2078            Field::new("left", DataType::Int32, true),
2079            Field::new("right", DataType::Int32, true),
2080        ]);
2081        let filter_expr = join_expr_tests_fixture_i32(
2082            case_expr,
2083            col("left", &intermediate_schema)?,
2084            col("right", &intermediate_schema)?,
2085        );
2086        let column_indices = vec![
2087            ColumnIndex {
2088                index: 5,
2089                side: JoinSide::Left,
2090            },
2091            ColumnIndex {
2092                index: 5,
2093                side: JoinSide::Right,
2094            },
2095        ];
2096        let filter =
2097            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2098
2099        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2100        Ok(())
2101    }
2102
2103    #[tokio::test(flavor = "multi_thread")]
2104    async fn build_null_columns_first() -> Result<()> {
2105        let join_type = JoinType::Full;
2106        let case_expr = 1;
2107        let session_config = SessionConfig::new().with_repartition_joins(false);
2108        let task_ctx = TaskContext::default().with_session_config(session_config);
2109        let task_ctx = Arc::new(task_ctx);
2110        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2111        let left_schema = &left_partition[0].schema();
2112        let right_schema = &right_partition[0].schema();
2113        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2114            expr: col("l_asc_null_first", left_schema)?,
2115            options: SortOptions {
2116                descending: false,
2117                nulls_first: true,
2118            },
2119        }]);
2120        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2121            expr: col("r_asc_null_first", right_schema)?,
2122            options: SortOptions {
2123                descending: false,
2124                nulls_first: true,
2125            },
2126        }]);
2127        let (left, right) = create_memory_table(
2128            left_partition,
2129            right_partition,
2130            vec![left_sorted],
2131            vec![right_sorted],
2132        )?;
2133
2134        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2135
2136        let intermediate_schema = Schema::new(vec![
2137            Field::new("left", DataType::Int32, true),
2138            Field::new("right", DataType::Int32, true),
2139        ]);
2140        let filter_expr = join_expr_tests_fixture_i32(
2141            case_expr,
2142            col("left", &intermediate_schema)?,
2143            col("right", &intermediate_schema)?,
2144        );
2145        let column_indices = vec![
2146            ColumnIndex {
2147                index: 6,
2148                side: JoinSide::Left,
2149            },
2150            ColumnIndex {
2151                index: 6,
2152                side: JoinSide::Right,
2153            },
2154        ];
2155        let filter =
2156            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2157        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2158        Ok(())
2159    }
2160
2161    #[tokio::test(flavor = "multi_thread")]
2162    async fn build_null_columns_last() -> Result<()> {
2163        let join_type = JoinType::Full;
2164        let case_expr = 1;
2165        let session_config = SessionConfig::new().with_repartition_joins(false);
2166        let task_ctx = TaskContext::default().with_session_config(session_config);
2167        let task_ctx = Arc::new(task_ctx);
2168        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2169
2170        let left_schema = &left_partition[0].schema();
2171        let right_schema = &right_partition[0].schema();
2172        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2173            expr: col("l_asc_null_last", left_schema)?,
2174            options: SortOptions {
2175                descending: false,
2176                nulls_first: false,
2177            },
2178        }]);
2179        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2180            expr: col("r_asc_null_last", right_schema)?,
2181            options: SortOptions {
2182                descending: false,
2183                nulls_first: false,
2184            },
2185        }]);
2186        let (left, right) = create_memory_table(
2187            left_partition,
2188            right_partition,
2189            vec![left_sorted],
2190            vec![right_sorted],
2191        )?;
2192
2193        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2194
2195        let intermediate_schema = Schema::new(vec![
2196            Field::new("left", DataType::Int32, true),
2197            Field::new("right", DataType::Int32, true),
2198        ]);
2199        let filter_expr = join_expr_tests_fixture_i32(
2200            case_expr,
2201            col("left", &intermediate_schema)?,
2202            col("right", &intermediate_schema)?,
2203        );
2204        let column_indices = vec![
2205            ColumnIndex {
2206                index: 7,
2207                side: JoinSide::Left,
2208            },
2209            ColumnIndex {
2210                index: 7,
2211                side: JoinSide::Right,
2212            },
2213        ];
2214        let filter =
2215            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2216
2217        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2218        Ok(())
2219    }
2220
2221    #[tokio::test(flavor = "multi_thread")]
2222    async fn build_null_columns_first_descending() -> Result<()> {
2223        let join_type = JoinType::Full;
2224        let cardinality = (10, 11);
2225        let case_expr = 1;
2226        let session_config = SessionConfig::new().with_repartition_joins(false);
2227        let task_ctx = TaskContext::default().with_session_config(session_config);
2228        let task_ctx = Arc::new(task_ctx);
2229        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2230
2231        let left_schema = &left_partition[0].schema();
2232        let right_schema = &right_partition[0].schema();
2233        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2234            expr: col("l_desc_null_first", left_schema)?,
2235            options: SortOptions {
2236                descending: true,
2237                nulls_first: true,
2238            },
2239        }]);
2240        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2241            expr: col("r_desc_null_first", right_schema)?,
2242            options: SortOptions {
2243                descending: true,
2244                nulls_first: true,
2245            },
2246        }]);
2247        let (left, right) = create_memory_table(
2248            left_partition,
2249            right_partition,
2250            vec![left_sorted],
2251            vec![right_sorted],
2252        )?;
2253
2254        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2255
2256        let intermediate_schema = Schema::new(vec![
2257            Field::new("left", DataType::Int32, true),
2258            Field::new("right", DataType::Int32, true),
2259        ]);
2260        let filter_expr = join_expr_tests_fixture_i32(
2261            case_expr,
2262            col("left", &intermediate_schema)?,
2263            col("right", &intermediate_schema)?,
2264        );
2265        let column_indices = vec![
2266            ColumnIndex {
2267                index: 8,
2268                side: JoinSide::Left,
2269            },
2270            ColumnIndex {
2271                index: 8,
2272                side: JoinSide::Right,
2273            },
2274        ];
2275        let filter =
2276            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2277
2278        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2279        Ok(())
2280    }
2281
2282    #[tokio::test(flavor = "multi_thread")]
2283    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2284        let cardinality = (3, 4);
2285        let join_type = JoinType::Full;
2286
2287        // a + b > c + 10 AND a + b < c + 100
2288        let session_config = SessionConfig::new().with_repartition_joins(false);
2289        let task_ctx = TaskContext::default().with_session_config(session_config);
2290        let task_ctx = Arc::new(task_ctx);
2291        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2292
2293        let left_schema = &left_partition[0].schema();
2294        let right_schema = &right_partition[0].schema();
2295        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2296            expr: col("la1", left_schema)?,
2297            options: SortOptions::default(),
2298        }]);
2299
2300        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2301            expr: col("ra1", right_schema)?,
2302            options: SortOptions::default(),
2303        }]);
2304        let (left, right) = create_memory_table(
2305            left_partition,
2306            right_partition,
2307            vec![left_sorted],
2308            vec![right_sorted],
2309        )?;
2310
2311        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2312
2313        let intermediate_schema = Schema::new(vec![
2314            Field::new("0", DataType::Int32, true),
2315            Field::new("1", DataType::Int32, true),
2316            Field::new("2", DataType::Int32, true),
2317        ]);
2318        let filter_expr = complicated_filter(&intermediate_schema)?;
2319        let column_indices = vec![
2320            ColumnIndex {
2321                index: 0,
2322                side: JoinSide::Left,
2323            },
2324            ColumnIndex {
2325                index: 4,
2326                side: JoinSide::Left,
2327            },
2328            ColumnIndex {
2329                index: 0,
2330                side: JoinSide::Right,
2331            },
2332        ];
2333        let filter =
2334            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2335
2336        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2337        Ok(())
2338    }
2339
2340    #[tokio::test(flavor = "multi_thread")]
2341    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2342        let cardinality = (3, 4);
2343        let join_type = JoinType::Full;
2344
2345        // a + b > c + 10 AND a + b < c + 100
2346        let config = SessionConfig::new().with_repartition_joins(false);
2349        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2350        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2351        let left_schema = &left_partition[0].schema();
2352        let right_schema = &right_partition[0].schema();
2353        let left_sorted = vec![
2354            LexOrdering::new(vec![PhysicalSortExpr {
2355                expr: col("la1", left_schema)?,
2356                options: SortOptions::default(),
2357            }]),
2358            LexOrdering::new(vec![PhysicalSortExpr {
2359                expr: col("la2", left_schema)?,
2360                options: SortOptions::default(),
2361            }]),
2362        ];
2363
2364        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2365            expr: col("ra1", right_schema)?,
2366            options: SortOptions::default(),
2367        }]);
2368
2369        let (left, right) = create_memory_table(
2370            left_partition,
2371            right_partition,
2372            left_sorted,
2373            vec![right_sorted],
2374        )?;
2375
2376        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2377
2378        let intermediate_schema = Schema::new(vec![
2379            Field::new("0", DataType::Int32, true),
2380            Field::new("1", DataType::Int32, true),
2381            Field::new("2", DataType::Int32, true),
2382        ]);
2383        let filter_expr = complicated_filter(&intermediate_schema)?;
2384        let column_indices = vec![
2385            ColumnIndex {
2386                index: 0,
2387                side: JoinSide::Left,
2388            },
2389            ColumnIndex {
2390                index: 4,
2391                side: JoinSide::Left,
2392            },
2393            ColumnIndex {
2394                index: 0,
2395                side: JoinSide::Right,
2396            },
2397        ];
2398        let filter =
2399            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2400
2401        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2402        Ok(())
2403    }
2404
2405    #[rstest]
2406    #[tokio::test(flavor = "multi_thread")]
2407    async fn testing_with_temporal_columns(
2408        #[values(
2409            JoinType::Inner,
2410            JoinType::Left,
2411            JoinType::Right,
2412            JoinType::RightSemi,
2413            JoinType::LeftSemi,
2414            JoinType::LeftAnti,
2415            JoinType::LeftMark,
2416            JoinType::RightAnti,
2417            JoinType::Full
2418        )]
2419        join_type: JoinType,
2420        #[values(
2421            (4, 5),
2422            (12, 17),
2423        )]
2424        cardinality: (i32, i32),
2425        #[values(0, 1, 2)] case_expr: usize,
2426    ) -> Result<()> {
2427        let session_config = SessionConfig::new().with_repartition_joins(false);
2428        let task_ctx = TaskContext::default().with_session_config(session_config);
2429        let task_ctx = Arc::new(task_ctx);
2430        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2431
2432        let left_schema = &left_partition[0].schema();
2433        let right_schema = &right_partition[0].schema();
2434        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2435        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2436            expr: col("lt1", left_schema)?,
2437            options: SortOptions {
2438                descending: false,
2439                nulls_first: true,
2440            },
2441        }]);
2442        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2443            expr: col("rt1", right_schema)?,
2444            options: SortOptions {
2445                descending: false,
2446                nulls_first: true,
2447            },
2448        }]);
2449        let (left, right) = create_memory_table(
2450            left_partition,
2451            right_partition,
2452            vec![left_sorted],
2453            vec![right_sorted],
2454        )?;
2455        let intermediate_schema = Schema::new(vec![
2456            Field::new(
2457                "left",
2458                DataType::Timestamp(TimeUnit::Millisecond, None),
2459                false,
2460            ),
2461            Field::new(
2462                "right",
2463                DataType::Timestamp(TimeUnit::Millisecond, None),
2464                false,
2465            ),
2466        ]);
2467        let filter_expr = join_expr_tests_fixture_temporal(
2468            case_expr,
2469            col("left", &intermediate_schema)?,
2470            col("right", &intermediate_schema)?,
2471            &intermediate_schema,
2472        )?;
2473        let column_indices = vec![
2474            ColumnIndex {
2475                index: 3,
2476                side: JoinSide::Left,
2477            },
2478            ColumnIndex {
2479                index: 3,
2480                side: JoinSide::Right,
2481            },
2482        ];
2483        let filter =
2484            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2485        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2486        Ok(())
2487    }
2488
2489    #[rstest]
2490    #[tokio::test(flavor = "multi_thread")]
2491    async fn test_with_interval_columns(
2492        #[values(
2493            JoinType::Inner,
2494            JoinType::Left,
2495            JoinType::Right,
2496            JoinType::RightSemi,
2497            JoinType::LeftSemi,
2498            JoinType::LeftAnti,
2499            JoinType::LeftMark,
2500            JoinType::RightAnti,
2501            JoinType::Full
2502        )]
2503        join_type: JoinType,
2504        #[values(
2505            (4, 5),
2506            (12, 17),
2507        )]
2508        cardinality: (i32, i32),
2509    ) -> Result<()> {
2510        let session_config = SessionConfig::new().with_repartition_joins(false);
2511        let task_ctx = TaskContext::default().with_session_config(session_config);
2512        let task_ctx = Arc::new(task_ctx);
2513        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2514
2515        let left_schema = &left_partition[0].schema();
2516        let right_schema = &right_partition[0].schema();
2517        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2518        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2519            expr: col("li1", left_schema)?,
2520            options: SortOptions {
2521                descending: false,
2522                nulls_first: true,
2523            },
2524        }]);
2525        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2526            expr: col("ri1", right_schema)?,
2527            options: SortOptions {
2528                descending: false,
2529                nulls_first: true,
2530            },
2531        }]);
2532        let (left, right) = create_memory_table(
2533            left_partition,
2534            right_partition,
2535            vec![left_sorted],
2536            vec![right_sorted],
2537        )?;
2538        let intermediate_schema = Schema::new(vec![
2539            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2540            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2541        ]);
2542        let filter_expr = join_expr_tests_fixture_temporal(
2543            0,
2544            col("left", &intermediate_schema)?,
2545            col("right", &intermediate_schema)?,
2546            &intermediate_schema,
2547        )?;
2548        let column_indices = vec![
2549            ColumnIndex {
2550                index: 9,
2551                side: JoinSide::Left,
2552            },
2553            ColumnIndex {
2554                index: 9,
2555                side: JoinSide::Right,
2556            },
2557        ];
2558        let filter =
2559            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2560        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2561
2562        Ok(())
2563    }
2564
2565    #[rstest]
2566    #[tokio::test(flavor = "multi_thread")]
2567    async fn testing_ascending_float_pruning(
2568        #[values(
2569            JoinType::Inner,
2570            JoinType::Left,
2571            JoinType::Right,
2572            JoinType::RightSemi,
2573            JoinType::LeftSemi,
2574            JoinType::LeftAnti,
2575            JoinType::LeftMark,
2576            JoinType::RightAnti,
2577            JoinType::Full
2578        )]
2579        join_type: JoinType,
2580        #[values(
2581            (4, 5),
2582            (12, 17),
2583        )]
2584        cardinality: (i32, i32),
2585        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2586    ) -> Result<()> {
2587        let session_config = SessionConfig::new().with_repartition_joins(false);
2588        let task_ctx = TaskContext::default().with_session_config(session_config);
2589        let task_ctx = Arc::new(task_ctx);
2590        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2591
2592        let left_schema = &left_partition[0].schema();
2593        let right_schema = &right_partition[0].schema();
2594        let left_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2595            expr: col("l_float", left_schema)?,
2596            options: SortOptions::default(),
2597        }]);
2598        let right_sorted = LexOrdering::new(vec![PhysicalSortExpr {
2599            expr: col("r_float", right_schema)?,
2600            options: SortOptions::default(),
2601        }]);
2602        let (left, right) = create_memory_table(
2603            left_partition,
2604            right_partition,
2605            vec![left_sorted],
2606            vec![right_sorted],
2607        )?;
2608
2609        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2610
2611        let intermediate_schema = Schema::new(vec![
2612            Field::new("left", DataType::Float64, true),
2613            Field::new("right", DataType::Float64, true),
2614        ]);
2615        let filter_expr = join_expr_tests_fixture_f64(
2616            case_expr,
2617            col("left", &intermediate_schema)?,
2618            col("right", &intermediate_schema)?,
2619        );
2620        let column_indices = vec![
2621            ColumnIndex {
2622                index: 10, // l_float
2623                side: JoinSide::Left,
2624            },
2625            ColumnIndex {
2626                index: 10, // r_float
2627                side: JoinSide::Right,
2628            },
2629        ];
2630        let filter =
2631            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2632
2633        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2634        Ok(())
2635    }
2636}