archive_to_parquet/
sink.rs

1use crate::batch::arrow_schema;
2use crate::hasher::HASH_SIZE;
3use crate::ConvertionOptions;
4use arrow::array::{Array, AsArray, BooleanArray};
5use arrow::compute::filter_record_batch;
6use arrow::record_batch::RecordBatch;
7use parquet::arrow::ArrowWriter;
8use parquet::basic::Compression;
9use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};
10use std::collections::HashSet;
11use std::io::Write;
12
13#[derive(Debug, Clone, Copy, Eq, PartialEq, clap::ValueEnum, strum::EnumString, strum::Display)]
14#[strum(serialize_all = "lowercase", ascii_case_insensitive)]
15pub enum IncludeType {
16    All,
17    Text,
18    Binary,
19}
20
21impl Default for IncludeType {
22    fn default() -> Self {
23        Self::All
24    }
25}
26
27pub fn new_parquet_writer<T: Write + Send>(
28    writer: T,
29    compression: Compression,
30) -> parquet::errors::Result<ArrowWriter<T>> {
31    let schema = arrow_schema();
32    let mut props = WriterProperties::builder()
33        .set_compression(compression)
34        .set_writer_version(WriterVersion::PARQUET_2_0)
35        .set_dictionary_enabled(false)
36        .set_bloom_filter_enabled(false)
37        .set_statistics_enabled(EnabledStatistics::None)
38        .set_column_encoding("hash".into(), parquet::basic::Encoding::PLAIN)
39        .set_write_batch_size(1024)
40        .set_data_page_size_limit(1024 * 1024)
41        .set_data_page_row_count_limit(20_00)
42        .set_max_row_group_size(1024 * 1024);
43
44    const BLOOM_FILTER_FIELDS: &[&str] = &["source", "path", "hash"];
45    const STATISTICS_FIELDS: &[&str] = &["source", "path", "size", "hash"];
46    const DICTIONARY_FIELDS: &[&str] = &["source", "path"];
47
48    for field in BLOOM_FILTER_FIELDS {
49        props = props.set_column_bloom_filter_enabled((*field).into(), true);
50    }
51    for field in STATISTICS_FIELDS {
52        props = props.set_column_statistics_enabled((*field).into(), EnabledStatistics::Page);
53    }
54    for field in DICTIONARY_FIELDS {
55        props = props.set_column_dictionary_enabled((*field).into(), true);
56    }
57
58    ArrowWriter::try_new(writer, schema, Some(props.build()))
59}
60
61pub struct ParquetSink<'a, T: Write + Send> {
62    writer: &'a mut ArrowWriter<T>,
63    seen_hashes: Option<HashSet<[u8; HASH_SIZE]>>,
64}
65
66impl<'a, T: Write + Send> ParquetSink<'a, T> {
67    pub fn new(writer: &'a mut ArrowWriter<T>, options: ConvertionOptions) -> Self {
68        let seen_hashes = if options.unique {
69            Some(HashSet::new())
70        } else {
71            None
72        };
73        Self {
74            writer,
75            seen_hashes,
76        }
77    }
78
79    fn deduplicate_batch(
80        record_batch: RecordBatch,
81        seen_hashes: &mut HashSet<[u8; HASH_SIZE]>,
82    ) -> parquet::errors::Result<RecordBatch> {
83        let hashes = record_batch
84            .column_by_name("hash")
85            .expect("hash column not found")
86            .as_fixed_size_binary();
87        let mut unique_indexes = Vec::new();
88        assert_eq!(
89            hashes.value_length(),
90            HASH_SIZE as i32,
91            "Hash column size != {HASH_SIZE}"
92        );
93        assert!(!hashes.is_nullable(), "Hash column is nullable");
94
95        for (idx, hash) in hashes.iter().enumerate() {
96            let hash: [u8; HASH_SIZE] = hash.unwrap().try_into().unwrap();
97            if seen_hashes.insert(hash) {
98                unique_indexes.push(idx);
99            }
100        }
101
102        let select_mask = BooleanArray::from_iter(
103            (0..record_batch.num_rows()).map(|idx| Some(unique_indexes.contains(&idx))),
104        );
105
106        Ok(filter_record_batch(&record_batch, &select_mask)?)
107    }
108
109    pub fn write_batch(&mut self, batch: RecordBatch) -> parquet::errors::Result<WriteBatchOutput> {
110        let batch = match &mut self.seen_hashes {
111            None => batch,
112            Some(seen_hashes) => Self::deduplicate_batch(batch, seen_hashes)?,
113        };
114
115        let output = WriteBatchOutput {
116            num_rows: batch.num_rows() as u64,
117            bytes: batch.get_array_memory_size() as u64,
118        };
119        self.writer.write(&batch)?;
120        Ok(output)
121    }
122
123    pub fn flush(&mut self) -> parquet::errors::Result<()> {
124        self.writer.flush()
125    }
126}
127
/// Summary of a single batch write: how many rows it held and how large it was.
#[derive(Debug)]
pub struct WriteBatchOutput {
    /// Number of rows in the batch.
    pub num_rows: u64,
    /// Size of the batch in bytes.
    pub bytes: u64,
}
133
#[cfg(test)]
mod tests {
    use crate::IncludeType;
    use std::str::FromStr;

    /// Each lowercase name parses to the matching `IncludeType` variant.
    #[test]
    fn test_include_type() {
        let cases = [
            ("all", IncludeType::All),
            ("text", IncludeType::Text),
            ("binary", IncludeType::Binary),
        ];

        for (input, expected) in cases {
            let parsed = IncludeType::from_str(input).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}