datafusion_physical_plan/
common.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines common code used in execution plans

use std::fs;
use std::fs::{metadata, File};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use super::SendableRecordBatchStream;
use crate::stream::RecordBatchReceiverStream;
use crate::{ColumnStatistics, Statistics};

use arrow::array::Array;
use arrow::datatypes::Schema;
use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
use arrow::record_batch::RecordBatch;
use datafusion_common::stats::Precision;
use datafusion_common::{plan_err, DataFusionError, Result};
use datafusion_execution::memory_pool::MemoryReservation;

use futures::{StreamExt, TryStreamExt};
use parking_lot::Mutex;

/// [`MemoryReservation`] used across query execution streams
pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;

/// Create a vector of record batches from a stream
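///
/// # Example
///
/// A minimal sketch, assuming a [`SendableRecordBatchStream`] obtained
/// elsewhere (e.g. from `ExecutionPlan::execute`); `stream` is hypothetical:
///
/// ```ignore
/// use datafusion_physical_plan::common::collect;
///
/// // `stream` is a SendableRecordBatchStream produced elsewhere,
/// // e.g. `plan.execute(0, task_ctx)?`.
/// let batches = collect(stream).await?;
/// println!("collected {} batches", batches.len());
/// ```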
pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
    stream.try_collect::<Vec<_>>().await
}

/// Recursively builds a list of files in a directory with a given extension,
/// returning an error if no matching files are found
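///
/// # Example
///
/// A minimal sketch; the directory path below is hypothetical:
///
/// ```ignore
/// // Collect every `.csv` file under `/data/tables`; errors if none match.
/// let files = build_checked_file_list("/data/tables", ".csv")?;
/// for file in &files {
///     println!("found: {file}");
/// }
/// ```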
pub fn build_checked_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
    let mut filenames: Vec<String> = Vec::new();
    build_file_list_recurse(dir, &mut filenames, ext)?;
    if filenames.is_empty() {
        return plan_err!("No files found at {dir} with file extension {ext}");
    }
    Ok(filenames)
}

/// Recursively builds a list of files in a directory with a given extension
pub fn build_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
    let mut filenames: Vec<String> = Vec::new();
    build_file_list_recurse(dir, &mut filenames, ext)?;
    Ok(filenames)
}

/// Recursively builds a list of files in a directory with a given extension,
/// appending matches to the accumulator list `filenames`
fn build_file_list_recurse(
    dir: &str,
    filenames: &mut Vec<String>,
    ext: &str,
) -> Result<()> {
    let metadata = metadata(dir)?;
    if metadata.is_file() {
        if dir.ends_with(ext) {
            filenames.push(dir.to_string());
        }
    } else {
        for entry in fs::read_dir(dir)? {
            let entry = entry?;
            let path = entry.path();
            if let Some(path_name) = path.to_str() {
                if path.is_dir() {
                    build_file_list_recurse(path_name, filenames, ext)?;
                } else if path_name.ends_with(ext) {
                    filenames.push(path_name.to_string());
                }
            } else {
                return plan_err!("Invalid path");
            }
        }
    }
    Ok(())
}

/// If running in a multi-thread tokio context, spawns the execution of `stream`
/// to a separate task, allowing it to execute in parallel with an intermediate
/// buffer of size `buffer`
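///
/// # Example
///
/// A minimal, crate-internal sketch; `input_stream` is a hypothetical
/// [`SendableRecordBatchStream`]:
///
/// ```ignore
/// // Decouple a slow producer from its consumer with a buffer of 2 batches.
/// // Outside a multi-thread tokio runtime this returns the input unchanged.
/// let buffered = spawn_buffered(input_stream, 2);
/// ```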
pub(crate) fn spawn_buffered(
    mut input: SendableRecordBatchStream,
    buffer: usize,
) -> SendableRecordBatchStream {
    // Use tokio only if running from a multi-thread tokio context
    match tokio::runtime::Handle::try_current() {
        Ok(handle)
            if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread =>
        {
            let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer);

            let sender = builder.tx();

            builder.spawn(async move {
                while let Some(item) = input.next().await {
                    if sender.send(item).await.is_err() {
                        // The receiver is dropped when the query shuts down early
                        // (e.g., a limit) or errors; no need to propagate the send error.
                        return Ok(());
                    }
                }

                Ok(())
            });

            builder.build()
        }
        _ => input,
    }
}

/// Computes the statistics for in-memory record batches
///
/// Only computes statistics available from Arrow's metadata (num rows, byte
/// size, and null counts) and does not apply any kernels to the actual data.
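///
/// # Example
///
/// A minimal sketch; `batch` and `schema` are hypothetical and must agree:
///
/// ```ignore
/// // One partition containing a single batch; project only column 0.
/// let stats =
///     compute_record_batch_statistics(&[vec![batch.clone()]], &schema, Some(vec![0]));
/// assert_eq!(stats.num_rows, Precision::Exact(batch.num_rows()));
/// ```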
pub fn compute_record_batch_statistics(
    batches: &[Vec<RecordBatch>],
    schema: &Schema,
    projection: Option<Vec<usize>>,
) -> Statistics {
    let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum();

    let projection = match projection {
        Some(p) => p,
        None => (0..schema.fields().len()).collect(),
    };

    let total_byte_size = batches
        .iter()
        .flatten()
        .map(|b| {
            projection
                .iter()
                .map(|index| b.column(*index).get_array_memory_size())
                .sum::<usize>()
        })
        .sum();

    let mut null_counts = vec![0; projection.len()];

    for partition in batches.iter() {
        for batch in partition {
            for (stat_index, col_index) in projection.iter().enumerate() {
                null_counts[stat_index] += batch
                    .column(*col_index)
                    .logical_nulls()
                    .map(|nulls| nulls.null_count())
                    .unwrap_or_default();
            }
        }
    }
    let column_statistics = null_counts
        .into_iter()
        .map(|null_count| {
            let mut s = ColumnStatistics::new_unknown();
            s.null_count = Precision::Exact(null_count);
            s
        })
        .collect();

    Statistics {
        num_rows: Precision::Exact(nb_rows),
        total_byte_size: Precision::Exact(total_byte_size),
        column_statistics,
    }
}

/// Writes in the Arrow IPC file format, tracking the batches, rows, and bytes written.
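///
/// # Example
///
/// A minimal sketch; the path is hypothetical and `schema`/`batch` are
/// assumed to exist and agree:
///
/// ```ignore
/// let mut writer = IPCWriter::new(Path::new("/tmp/part-0.arrow"), &schema)?;
/// writer.write(&batch)?;
/// writer.finish()?;
/// println!("wrote {} rows in {} batches", writer.num_rows, writer.num_batches);
/// ```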
pub struct IPCWriter {
    /// Path
    pub path: PathBuf,
    /// Inner writer
    pub writer: FileWriter<File>,
    /// Batches written
    pub num_batches: usize,
    /// Rows written
    pub num_rows: usize,
    /// Bytes written
    pub num_bytes: usize,
}

impl IPCWriter {
    /// Create new writer
    pub fn new(path: &Path, schema: &Schema) -> Result<Self> {
        let file = File::create(path).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to create partition file at {path:?}: {e:?}"
            ))
        })?;
        Ok(Self {
            num_batches: 0,
            num_rows: 0,
            num_bytes: 0,
            path: path.into(),
            writer: FileWriter::try_new(file, schema)?,
        })
    }

    /// Create new writer with IPC write options
    pub fn new_with_options(
        path: &Path,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self> {
        let file = File::create(path).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to create partition file at {path:?}: {e:?}"
            ))
        })?;
        Ok(Self {
            num_batches: 0,
            num_rows: 0,
            num_bytes: 0,
            path: path.into(),
            writer: FileWriter::try_new_with_options(file, schema, write_options)?,
        })
    }

    /// Write a single batch
    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.writer.write(batch)?;
        self.num_batches += 1;
        self.num_rows += batch.num_rows();
        let num_bytes: usize = batch.get_array_memory_size();
        self.num_bytes += num_bytes;
        Ok(())
    }

    /// Finish the writer
    pub fn finish(&mut self) -> Result<()> {
        self.writer.finish().map_err(Into::into)
    }

    /// Path written to
    pub fn path(&self) -> &Path {
        &self.path
    }
}

/// Checks if the given projection is valid for the given schema.
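///
/// # Example
///
/// A minimal sketch, assuming a two-field `schema` (a [`SchemaRef`]):
///
/// ```ignore
/// // Valid: both indices are within the schema's field count.
/// can_project(&schema, Some(&vec![0, 1]))?;
/// // Invalid: index 99 is out of bounds, so this returns an error.
/// assert!(can_project(&schema, Some(&vec![99])).is_err());
/// ```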
pub fn can_project(
    schema: &arrow::datatypes::SchemaRef,
    projection: Option<&Vec<usize>>,
) -> Result<()> {
    match projection {
        Some(columns) => {
            if columns
                .iter()
                .max()
                .is_some_and(|&i| i >= schema.fields().len())
            {
                Err(arrow::error::ArrowError::SchemaError(format!(
                    "project index {} out of bounds, max field {}",
                    columns.iter().max().unwrap(),
                    schema.fields().len()
                ))
                .into())
            } else {
                Ok(())
            }
        }
        None => Ok(()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use arrow::{
        array::{Float32Array, Float64Array, UInt64Array},
        datatypes::{DataType, Field},
    };

    #[test]
    fn test_compute_record_batch_statistics_empty() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("f32", DataType::Float32, false),
            Field::new("f64", DataType::Float64, false),
        ]));
        let stats = compute_record_batch_statistics(&[], &schema, Some(vec![0, 1]));

        assert_eq!(stats.num_rows, Precision::Exact(0));
        assert_eq!(stats.total_byte_size, Precision::Exact(0));
        Ok(())
    }

    #[test]
    fn test_compute_record_batch_statistics() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("f32", DataType::Float32, false),
            Field::new("f64", DataType::Float64, false),
            Field::new("u64", DataType::UInt64, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![1., 2., 3.])),
                Arc::new(Float64Array::from(vec![9., 8., 7.])),
                Arc::new(UInt64Array::from(vec![4, 5, 6])),
            ],
        )?;

        // Just select f32,f64
        let select_projection = Some(vec![0, 1]);
        let byte_size = batch
            .project(&select_projection.clone().unwrap())
            .unwrap()
            .get_array_memory_size();

        let actual =
            compute_record_batch_statistics(&[vec![batch]], &schema, select_projection);

        let expected = Statistics {
            num_rows: Precision::Exact(3),
            total_byte_size: Precision::Exact(byte_size),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(0),
                },
            ],
        };

        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn test_compute_record_batch_statistics_null() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("u64", DataType::UInt64, true)]));
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt64Array::from(vec![Some(1), None, None]))],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt64Array::from(vec![Some(1), Some(2), None]))],
        )?;
        let byte_size = batch1.get_array_memory_size() + batch2.get_array_memory_size();
        let actual =
            compute_record_batch_statistics(&[vec![batch1], vec![batch2]], &schema, None);

        let expected = Statistics {
            num_rows: Precision::Exact(6),
            total_byte_size: Precision::Exact(byte_size),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Absent,
                max_value: Precision::Absent,
                min_value: Precision::Absent,
                sum_value: Precision::Absent,
                null_count: Precision::Exact(3),
            }],
        };

        assert_eq!(actual, expected);
        Ok(())
    }
}