use std::any::Any;
use std::sync::Arc;
use crate::common::spawn_buffered;
use crate::expressions::PhysicalSortExpr;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::sorts::streaming_merge;
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
};
use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalSortRequirement;
use log::{debug, trace};
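/// Sort preserving merge execution plan
///
/// This operator implements a K-way merge: provided that every partition of
/// `input` is sorted with respect to `expr`, it yields a single output
/// partition that is also sorted with respect to `expr`, without re-sorting
/// any data. If the input partitions are not sorted on these expressions,
/// the output order is undefined.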
#[derive(Debug)]
pub struct SortPreservingMergeExec {
    /// Input plan whose partitions are each sorted on `expr`
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: Vec<PhysicalSortExpr>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch. Stops producing output once this
    /// many rows have been yielded
    fetch: Option<usize>,
    /// Cache holding plan properties like output partitioning and ordering
    cache: PlanProperties,
}
impl SortPreservingMergeExec {
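    /// Create a new sort execution plan that merges the already-sorted
    /// partitions of `input` according to `expr`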
pub fn new(expr: Vec<PhysicalSortExpr>, input: Arc<dyn ExecutionPlan>) -> Self {
let cache = Self::compute_properties(&input, expr.clone());
Self {
input,
expr,
metrics: ExecutionPlanMetricsSet::new(),
fetch: None,
cache,
}
}
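    /// Sets the number of rows to fetch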
pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
self.fetch = fetch;
self
}
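    /// Input plan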
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
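    /// Sort expressions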
pub fn expr(&self) -> &[PhysicalSortExpr] {
&self.expr
}
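    /// Maximum number of rows to fetch, if any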
pub fn fetch(&self) -> Option<usize> {
self.fetch
}
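    /// Creates the cache object that stores the plan properties, such as
    /// equivalence properties, output partitioning, and execution mode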
fn compute_properties(
input: &Arc<dyn ExecutionPlan>,
ordering: Vec<PhysicalSortExpr>,
) -> PlanProperties {
let mut eq_properties = input.equivalence_properties().clone();
eq_properties.clear_per_partition_constants();
eq_properties.add_new_orderings(vec![ordering]);
PlanProperties::new(
            eq_properties,                        // Equivalence Properties
            Partitioning::UnknownPartitioning(1), // Output Partitioning
            input.execution_mode(),               // Execution Mode
        )
}
}
impl DisplayAs for SortPreservingMergeExec {
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
"SortPreservingMergeExec: [{}]",
PhysicalSortExpr::format_list(&self.expr)
)?;
if let Some(fetch) = self.fetch {
write!(f, ", fetch={fetch}")?;
}
Ok(())
}
}
}
}
impl ExecutionPlan for SortPreservingMergeExec {
fn name(&self) -> &'static str {
"SortPreservingMergeExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.cache
}
fn required_input_distribution(&self) -> Vec<Distribution> {
vec![Distribution::UnspecifiedDistribution]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![true]
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(
SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
.with_fetch(self.fetch),
))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
trace!(
"Start SortPreservingMergeExec::execute for partition: {}",
partition
);
if 0 != partition {
return internal_err!(
"SortPreservingMergeExec invalid partition {partition}"
);
}
let input_partitions = self.input.output_partitioning().partition_count();
trace!(
"Number of input partitions of SortPreservingMergeExec::execute: {}",
input_partitions
);
let schema = self.schema();
let reservation =
MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
.register(&context.runtime_env().memory_pool);
match input_partitions {
0 => internal_err!(
"SortPreservingMergeExec requires at least one input partition"
),
1 => {
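                // bypass the merge entirely if there is only one partition to
                // read (no merge metrics are recorded in this case)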
let result = self.input.execute(0, context);
debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input");
result
}
_ => {
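                // eagerly start each input partition on its own task,
                // buffering up to one batch ahead of the merge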
let receivers = (0..input_partitions)
.map(|partition| {
let stream =
self.input.execute(partition, Arc::clone(&context))?;
Ok(spawn_buffered(stream, 1))
})
.collect::<Result<_>>()?;
debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");
let result = streaming_merge(
receivers,
schema,
&self.expr,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
self.fetch,
reservation,
)?;
debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
Ok(result)
}
}
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Result<Statistics> {
self.input.statistics()
}
fn supports_limit_pushdown(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::expressions::col;
use crate::memory::MemoryExec;
use crate::metrics::{MetricValue, Timestamp};
use crate::sorts::sort::SortExec;
use crate::stream::RecordBatchReceiverStream;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
use crate::test::{self, assert_is_pending, make_partition};
use crate::{collect, common};
use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{assert_batches_eq, assert_contains};
use datafusion_execution::config::SessionConfig;
use futures::{FutureExt, StreamExt};
#[tokio::test]
async fn test_merge_interleave() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("a"),
Some("c"),
Some("e"),
Some("g"),
Some("j"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("b"),
Some("d"),
Some("f"),
Some("h"),
Some("j"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
_test_merge(
&[vec![b1], vec![b2]],
&[
"+----+---+-------------------------------+",
"| a | b | c |",
"+----+---+-------------------------------+",
"| 1 | a | 1970-01-01T00:00:00.000000008 |",
"| 10 | b | 1970-01-01T00:00:00.000000004 |",
"| 2 | c | 1970-01-01T00:00:00.000000007 |",
"| 20 | d | 1970-01-01T00:00:00.000000006 |",
"| 7 | e | 1970-01-01T00:00:00.000000006 |",
"| 70 | f | 1970-01-01T00:00:00.000000002 |",
"| 9 | g | 1970-01-01T00:00:00.000000005 |",
"| 90 | h | 1970-01-01T00:00:00.000000002 |",
"| 30 | j | 1970-01-01T00:00:00.000000006 |", "| 3 | j | 1970-01-01T00:00:00.000000008 |",
"+----+---+-------------------------------+",
],
task_ctx,
)
.await;
}
#[tokio::test]
async fn test_merge_no_exprs() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap();
let schema = batch.schema();
        let sort = vec![];
        let exec = MemoryExec::try_new(&[vec![batch.clone()], vec![batch]], schema, None)
            .unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));
let res = collect(merge, task_ctx).await.unwrap_err();
assert_contains!(
res.to_string(),
"Internal error: Sort expressions cannot be empty for streaming merge"
);
}
#[tokio::test]
async fn test_merge_some_overlap() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("a"),
Some("b"),
Some("c"),
Some("d"),
Some("e"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("c"),
Some("d"),
Some("e"),
Some("f"),
Some("g"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
_test_merge(
&[vec![b1], vec![b2]],
&[
"+-----+---+-------------------------------+",
"| a | b | c |",
"+-----+---+-------------------------------+",
"| 1 | a | 1970-01-01T00:00:00.000000008 |",
"| 2 | b | 1970-01-01T00:00:00.000000007 |",
"| 70 | c | 1970-01-01T00:00:00.000000004 |",
"| 7 | c | 1970-01-01T00:00:00.000000006 |",
"| 9 | d | 1970-01-01T00:00:00.000000005 |",
"| 90 | d | 1970-01-01T00:00:00.000000006 |",
"| 30 | e | 1970-01-01T00:00:00.000000002 |",
"| 3 | e | 1970-01-01T00:00:00.000000008 |",
"| 100 | f | 1970-01-01T00:00:00.000000002 |",
"| 110 | g | 1970-01-01T00:00:00.000000006 |",
"+-----+---+-------------------------------+",
],
task_ctx,
)
.await;
}
#[tokio::test]
async fn test_merge_no_overlap() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("a"),
Some("b"),
Some("c"),
Some("d"),
Some("e"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("f"),
Some("g"),
Some("h"),
Some("i"),
Some("j"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
_test_merge(
&[vec![b1], vec![b2]],
&[
"+----+---+-------------------------------+",
"| a | b | c |",
"+----+---+-------------------------------+",
"| 1 | a | 1970-01-01T00:00:00.000000008 |",
"| 2 | b | 1970-01-01T00:00:00.000000007 |",
"| 7 | c | 1970-01-01T00:00:00.000000006 |",
"| 9 | d | 1970-01-01T00:00:00.000000005 |",
"| 3 | e | 1970-01-01T00:00:00.000000008 |",
"| 10 | f | 1970-01-01T00:00:00.000000004 |",
"| 20 | g | 1970-01-01T00:00:00.000000006 |",
"| 70 | h | 1970-01-01T00:00:00.000000002 |",
"| 90 | i | 1970-01-01T00:00:00.000000002 |",
"| 30 | j | 1970-01-01T00:00:00.000000006 |",
"+----+---+-------------------------------+",
],
task_ctx,
)
.await;
}
#[tokio::test]
async fn test_merge_three_partitions() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("a"),
Some("b"),
Some("c"),
Some("d"),
Some("f"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("e"),
Some("g"),
Some("h"),
Some("i"),
Some("j"),
]));
let c: ArrayRef =
Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
Some("f"),
Some("g"),
Some("h"),
Some("i"),
Some("j"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
_test_merge(
&[vec![b1], vec![b2], vec![b3]],
&[
"+-----+---+-------------------------------+",
"| a | b | c |",
"+-----+---+-------------------------------+",
"| 1 | a | 1970-01-01T00:00:00.000000008 |",
"| 2 | b | 1970-01-01T00:00:00.000000007 |",
"| 7 | c | 1970-01-01T00:00:00.000000006 |",
"| 9 | d | 1970-01-01T00:00:00.000000005 |",
"| 10 | e | 1970-01-01T00:00:00.000000040 |",
"| 100 | f | 1970-01-01T00:00:00.000000004 |",
"| 3 | f | 1970-01-01T00:00:00.000000008 |",
"| 200 | g | 1970-01-01T00:00:00.000000006 |",
"| 20 | g | 1970-01-01T00:00:00.000000060 |",
"| 700 | h | 1970-01-01T00:00:00.000000002 |",
"| 70 | h | 1970-01-01T00:00:00.000000020 |",
"| 900 | i | 1970-01-01T00:00:00.000000002 |",
"| 90 | i | 1970-01-01T00:00:00.000000020 |",
"| 300 | j | 1970-01-01T00:00:00.000000006 |",
"| 30 | j | 1970-01-01T00:00:00.000000060 |",
"+-----+---+-------------------------------+",
],
task_ctx,
)
.await;
}
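    /// Merges `partitions` with a `SortPreservingMergeExec` ordered on
    /// columns `b` then `c` (default options) and asserts that the collected
    /// output matches the expected pretty-printed table `exp`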
async fn _test_merge(
partitions: &[Vec<RecordBatch>],
exp: &[&str],
context: Arc<TaskContext>,
) {
let schema = partitions[0][0].schema();
let sort = vec![
PhysicalSortExpr {
expr: col("b", &schema).unwrap(),
options: Default::default(),
},
PhysicalSortExpr {
expr: col("c", &schema).unwrap(),
options: Default::default(),
},
];
let exec = MemoryExec::try_new(partitions, schema, None).unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));
let collected = collect(merge, context).await.unwrap();
assert_batches_eq!(exp, collected.as_slice());
}
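    /// Merges `input` with a `SortPreservingMergeExec` on `sort` and asserts
    /// that the result is a single record batch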
async fn sorted_merge(
input: Arc<dyn ExecutionPlan>,
sort: Vec<PhysicalSortExpr>,
context: Arc<TaskContext>,
) -> RecordBatch {
let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
let mut result = collect(merge, context).await.unwrap();
assert_eq!(result.len(), 1);
result.remove(0)
}
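    /// Sorts each partition of `input` independently, then merges the sorted
    /// partitions with `sorted_merge`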
async fn partition_sort(
input: Arc<dyn ExecutionPlan>,
sort: Vec<PhysicalSortExpr>,
context: Arc<TaskContext>,
) -> RecordBatch {
let sort_exec =
Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
sorted_merge(sort_exec, sort, context).await
}
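    /// Coalesces all partitions of `src` into one, then applies a single
    /// global sort; used as the reference result in these tests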
async fn basic_sort(
src: Arc<dyn ExecutionPlan>,
sort: Vec<PhysicalSortExpr>,
context: Arc<TaskContext>,
) -> RecordBatch {
let merge = Arc::new(CoalescePartitionsExec::new(src));
let sort_exec = Arc::new(SortExec::new(sort, merge));
let mut result = collect(sort_exec, context).await.unwrap();
assert_eq!(result.len(), 1);
result.remove(0)
}
#[tokio::test]
async fn test_partition_sort() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let partitions = 4;
let csv = test::scan_partitioned(partitions);
let schema = csv.schema();
let sort = vec![PhysicalSortExpr {
expr: col("i", &schema).unwrap(),
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let basic =
basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;
let basic = arrow::util::pretty::pretty_format_batches(&[basic])
.unwrap()
.to_string();
let partition = arrow::util::pretty::pretty_format_batches(&[partition])
.unwrap()
.to_string();
assert_eq!(
basic, partition,
"basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
);
Ok(())
}
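    /// Splits `sorted` into batches of at most `batch_size` rows using
    /// zero-copy array slices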
fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
let batches = (sorted.num_rows() + batch_size - 1) / batch_size;
(0..batches)
.map(|batch_idx| {
let columns = (0..sorted.num_columns())
.map(|column_idx| {
let length =
batch_size.min(sorted.num_rows() - batch_idx * batch_size);
sorted
.column(column_idx)
.slice(batch_idx * batch_size, length)
})
.collect();
RecordBatch::try_new(sorted.schema(), columns).unwrap()
})
.collect()
}
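    /// Globally sorts the test data, then exposes the full sorted result once
    /// per entry in `sizes`, as a `MemoryExec` partition split into batches of
    /// that size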
async fn sorted_partitioned_input(
sort: Vec<PhysicalSortExpr>,
sizes: &[usize],
context: Arc<TaskContext>,
) -> Result<Arc<dyn ExecutionPlan>> {
let partitions = 4;
let csv = test::scan_partitioned(partitions);
let sorted = basic_sort(csv, sort, context).await;
let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
Ok(Arc::new(
MemoryExec::try_new(&split, sorted.schema(), None).unwrap(),
))
}
#[tokio::test]
async fn test_partition_sort_streaming_input() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let schema = make_partition(11).schema();
let sort = vec![PhysicalSortExpr {
expr: col("i", &schema).unwrap(),
options: Default::default(),
}];
let input =
sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
.await?;
let basic =
basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;
assert_eq!(basic.num_rows(), 1200);
assert_eq!(partition.num_rows(), 1200);
let basic = arrow::util::pretty::pretty_format_batches(&[basic])
.unwrap()
.to_string();
let partition = arrow::util::pretty::pretty_format_batches(&[partition])
.unwrap()
.to_string();
assert_eq!(basic, partition);
Ok(())
}
#[tokio::test]
async fn test_partition_sort_streaming_input_output() -> Result<()> {
let schema = make_partition(11).schema();
let sort = vec![PhysicalSortExpr {
expr: col("i", &schema).unwrap(),
options: Default::default(),
}];
let task_ctx = Arc::new(TaskContext::default());
let input =
sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
.await?;
let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;
let task_ctx = TaskContext::default()
.with_session_config(SessionConfig::new().with_batch_size(23));
let task_ctx = Arc::new(task_ctx);
let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
let merged = collect(merge, task_ctx).await.unwrap();
assert_eq!(merged.len(), 53);
assert_eq!(basic.num_rows(), 1200);
assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);
let basic = arrow::util::pretty::pretty_format_batches(&[basic])
.unwrap()
.to_string();
let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice())
.unwrap()
.to_string();
assert_eq!(basic, partition);
Ok(())
}
#[tokio::test]
async fn test_nulls() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
None,
Some("a"),
Some("b"),
Some("d"),
Some("e"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
Some(8),
None,
Some(6),
None,
Some(4),
]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
None,
Some("b"),
Some("g"),
Some("h"),
Some("i"),
]));
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
Some(8),
None,
Some(5),
None,
Some(4),
]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let schema = b1.schema();
let sort = vec![
PhysicalSortExpr {
expr: col("b", &schema).unwrap(),
options: SortOptions {
descending: false,
nulls_first: true,
},
},
PhysicalSortExpr {
expr: col("c", &schema).unwrap(),
options: SortOptions {
descending: false,
nulls_first: false,
},
},
];
let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));
let collected = collect(merge, task_ctx).await.unwrap();
assert_eq!(collected.len(), 1);
assert_batches_eq!(
&[
"+---+---+-------------------------------+",
"| a | b | c |",
"+---+---+-------------------------------+",
"| 1 | | 1970-01-01T00:00:00.000000008 |",
"| 1 | | 1970-01-01T00:00:00.000000008 |",
"| 2 | a | |",
"| 7 | b | 1970-01-01T00:00:00.000000006 |",
"| 2 | b | |",
"| 9 | d | |",
"| 3 | e | 1970-01-01T00:00:00.000000004 |",
"| 3 | g | 1970-01-01T00:00:00.000000005 |",
"| 4 | h | |",
"| 5 | i | 1970-01-01T00:00:00.000000004 |",
"+---+---+-------------------------------+",
],
collected.as_slice()
);
}
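    // Drives `streaming_merge` directly over input streams that sleep between
    // batches, exercising the merge when inputs yield asynchronously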
#[tokio::test]
async fn test_async() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let schema = make_partition(11).schema();
let sort = vec![PhysicalSortExpr {
expr: col("i", &schema).unwrap(),
options: SortOptions::default(),
}];
let batches =
sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
.await?;
let partition_count = batches.output_partitioning().partition_count();
let mut streams = Vec::with_capacity(partition_count);
for partition in 0..partition_count {
let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);
let sender = builder.tx();
let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
builder.spawn(async move {
while let Some(batch) = stream.next().await {
sender.send(batch).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
Ok(())
});
streams.push(builder.build());
}
let metrics = ExecutionPlanMetricsSet::new();
let reservation =
MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);
let fetch = None;
let merge_stream = streaming_merge(
streams,
batches.schema(),
sort.as_slice(),
BaselineMetrics::new(&metrics, 0),
task_ctx.session_config().batch_size(),
fetch,
reservation,
)
.unwrap();
let mut merged = common::collect(merge_stream).await.unwrap();
assert_eq!(merged.len(), 1);
let merged = merged.remove(0);
let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;
let basic = arrow::util::pretty::pretty_format_batches(&[basic])
.unwrap()
.to_string();
let partition = arrow::util::pretty::pretty_format_batches(&[merged])
.unwrap()
.to_string();
assert_eq!(
basic, partition,
"basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
);
Ok(())
}
#[tokio::test]
async fn test_merge_metrics() {
let task_ctx = Arc::new(TaskContext::default());
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
let schema = b1.schema();
let sort = vec![PhysicalSortExpr {
expr: col("b", &schema).unwrap(),
options: Default::default(),
}];
let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));
let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
.await
.unwrap();
let expected = [
"+----+---+",
"| a | b |",
"+----+---+",
"| 1 | a |",
"| 10 | b |",
"| 2 | c |",
"| 20 | d |",
"+----+---+",
];
assert_batches_eq!(expected, collected.as_slice());
let metrics = merge.metrics().unwrap();
assert_eq!(metrics.output_rows().unwrap(), 4);
assert!(metrics.elapsed_compute().unwrap() > 0);
let mut saw_start = false;
let mut saw_end = false;
metrics.iter().for_each(|m| match m.value() {
MetricValue::StartTimestamp(ts) => {
saw_start = true;
assert!(nanos_from_timestamp(ts) > 0);
}
MetricValue::EndTimestamp(ts) => {
saw_end = true;
assert!(nanos_from_timestamp(ts) > 0);
}
_ => {}
});
assert!(saw_start);
assert!(saw_end);
}
fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
ts.value().unwrap().timestamp_nanos_opt().unwrap()
}
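    // Dropping the pending output future must cancel execution and release
    // all references to the blocking input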
#[tokio::test]
async fn test_drop_cancel() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
let refs = blocking_exec.refs();
let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
vec![PhysicalSortExpr {
expr: col("a", &schema)?,
options: SortOptions::default(),
}],
blocking_exec,
));
let fut = collect(sort_preserving_merge_exec, task_ctx);
let mut fut = fut.boxed();
assert_is_pending(&mut fut);
drop(fut);
assert_strong_count_converges_to_zero(refs).await;
Ok(())
}
#[tokio::test]
async fn test_stable_sort() {
let task_ctx = Arc::new(TaskContext::default());
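        // Create ten single-batch partitions whose `value` column ties across
        // partitions; a stable merge must break ties in input (partition)
        // order, which the `batch_number` sequence below verifies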
let partitions: Vec<Vec<RecordBatch>> = (0..10)
.map(|batch_number| {
let batch_number: Int32Array =
vec![Some(batch_number), Some(batch_number)]
.into_iter()
.collect();
let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![
("batch_number", Arc::new(batch_number) as ArrayRef),
("value", Arc::new(value) as ArrayRef),
])
.unwrap();
vec![batch]
})
.collect();
let schema = partitions[0][0].schema();
let sort = vec![PhysicalSortExpr {
expr: col("value", &schema).unwrap(),
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let exec = MemoryExec::try_new(&partitions, schema, None).unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));
let collected = collect(merge, task_ctx).await.unwrap();
assert_eq!(collected.len(), 1);
assert_batches_eq!(
&[
"+--------------+-------+",
"| batch_number | value |",
"+--------------+-------+",
"| 0 | A |",
"| 1 | A |",
"| 2 | A |",
"| 3 | A |",
"| 4 | A |",
"| 5 | A |",
"| 6 | A |",
"| 7 | A |",
"| 8 | A |",
"| 9 | A |",
"| 0 | B |",
"| 1 | B |",
"| 2 | B |",
"| 3 | B |",
"| 4 | B |",
"| 5 | B |",
"| 6 | B |",
"| 7 | B |",
"| 8 | B |",
"| 9 | B |",
"+--------------+-------+",
],
collected.as_slice()
);
}
}