archive_to_parquet/sink.rs

use crate::batch::arrow_schema;
use crate::hasher::HASH_SIZE;
use crate::ConvertionOptions;
use arrow::array::{Array, AsArray, BooleanArray};
use arrow::compute::filter_record_batch;
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::basic::Compression;
use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};
use std::collections::HashSet;
use std::io::Write;

#[derive(Debug, Clone, Copy, Eq, PartialEq, clap::ValueEnum, strum::EnumString, strum::Display)]
#[strum(serialize_all = "lowercase", ascii_case_insensitive)]
pub enum IncludeType {
    All,
    Text,
    Binary,
}

impl Default for IncludeType {
    fn default() -> Self {
        Self::All
    }
}

pub fn new_parquet_writer<T: Write + Send>(
    writer: T,
    compression: Compression,
) -> parquet::errors::Result<ArrowWriter<T>> {
    let schema = arrow_schema();
    // Start with dictionaries, bloom filters and statistics disabled
    // globally, then re-enable them below only for the columns that
    // benefit from them.
    let mut props = WriterProperties::builder()
        .set_compression(compression)
        .set_writer_version(WriterVersion::PARQUET_2_0)
        .set_dictionary_enabled(false)
        .set_bloom_filter_enabled(false)
        .set_statistics_enabled(EnabledStatistics::None)
        .set_column_encoding("hash".into(), parquet::basic::Encoding::PLAIN)
        .set_write_batch_size(1024)
        .set_data_page_size_limit(1024 * 1024)
        .set_data_page_row_count_limit(20_000)
        .set_max_row_group_size(1024 * 1024);

    const BLOOM_FILTER_FIELDS: &[&str] = &["source", "path", "hash"];
    const STATISTICS_FIELDS: &[&str] = &["source", "path", "size", "hash"];
    const DICTIONARY_FIELDS: &[&str] = &["source", "path"];

    for field in BLOOM_FILTER_FIELDS {
        props = props.set_column_bloom_filter_enabled((*field).into(), true);
    }
    for field in STATISTICS_FIELDS {
        props = props.set_column_statistics_enabled((*field).into(), EnabledStatistics::Page);
    }
    for field in DICTIONARY_FIELDS {
        props = props.set_column_dictionary_enabled((*field).into(), true);
    }

    ArrowWriter::try_new(writer, schema, Some(props.build()))
}

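// A minimal end-to-end sketch of driving `new_parquet_writer` (an added
// example, not part of the library's API surface): write into an in-memory
// buffer, feed it a batch matching `arrow_schema()`, and `close()` to emit
// the Parquet footer. The zero-row batch is a stand-in for real data, and
// `UNCOMPRESSED` is chosen so no codec feature flags are assumed.
#[cfg(test)]
mod writer_usage_example {
    use super::*;

    #[test]
    fn write_empty_batch_to_buffer() -> parquet::errors::Result<()> {
        let mut buffer = Vec::new();
        let mut writer = new_parquet_writer(&mut buffer, Compression::UNCOMPRESSED)?;
        // A zero-row batch with the crate's schema is enough to exercise
        // the writer end to end; `close` finalizes the file footer.
        writer.write(&RecordBatch::new_empty(arrow_schema()))?;
        writer.close()?;
        assert!(!buffer.is_empty(), "footer and metadata should be written");
        Ok(())
    }
}
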
pub struct ParquetSink<'a, T: Write + Send> {
    writer: &'a mut ArrowWriter<T>,
    // `Some` only when unique (deduplicated) output was requested; see `new`.
    seen_hashes: Option<HashSet<[u8; HASH_SIZE]>>,
}

impl<'a, T: Write + Send> ParquetSink<'a, T> {
    pub fn new(writer: &'a mut ArrowWriter<T>, options: ConvertionOptions) -> Self {
        let seen_hashes = if options.unique {
            Some(HashSet::new())
        } else {
            None
        };
        Self {
            writer,
            seen_hashes,
        }
    }

    // Drop rows whose hash has already been seen, recording the hashes from
    // this batch in `seen_hashes` as a side effect.
    fn deduplicate_batch(
        record_batch: RecordBatch,
        seen_hashes: &mut HashSet<[u8; HASH_SIZE]>,
    ) -> parquet::errors::Result<RecordBatch> {
        let hashes = record_batch
            .column_by_name("hash")
            .expect("hash column not found")
            .as_fixed_size_binary();
        assert_eq!(
            hashes.value_length(),
            HASH_SIZE as i32,
            "Hash column size != {HASH_SIZE}"
        );
        assert!(!hashes.is_nullable(), "Hash column is nullable");

        // `HashSet::insert` returns true only for a hash's first occurrence,
        // so collecting its results yields a keep/drop mask for the batch.
        let select_mask: BooleanArray = hashes
            .iter()
            .map(|hash| {
                let hash: [u8; HASH_SIZE] = hash
                    .expect("hash value is null")
                    .try_into()
                    .expect("hash length mismatch");
                Some(seen_hashes.insert(hash))
            })
            .collect();

        Ok(filter_record_batch(&record_batch, &select_mask)?)
    }

    pub fn write_batch(&mut self, batch: RecordBatch) -> parquet::errors::Result<WriteBatchOutput> {
        let batch = match &mut self.seen_hashes {
            None => batch,
            Some(seen_hashes) => Self::deduplicate_batch(batch, seen_hashes)?,
        };

        let output = WriteBatchOutput {
            num_rows: batch.num_rows() as u64,
            // In-memory Arrow size of the batch, not the compressed size
            // that ends up in the Parquet file.
            bytes: batch.get_array_memory_size() as u64,
        };
        self.writer.write(&batch)?;
        Ok(output)
    }

    pub fn flush(&mut self) -> parquet::errors::Result<()> {
        self.writer.flush()
    }
}

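// The selection-mask trick used by `deduplicate_batch`, in miniature (an
// added illustration, independent of the crate's schema): `HashSet::insert`
// returns `true` only for a value's first occurrence, so collecting those
// results yields a keep/drop mask for Arrow's `filter` kernel.
#[cfg(test)]
mod dedup_mask_example {
    use arrow::array::{Array, BooleanArray, Int32Array};
    use arrow::compute::filter;
    use std::collections::HashSet;

    #[test]
    fn first_occurrence_wins() {
        let values = Int32Array::from(vec![1, 2, 1, 3, 2]);
        let mut seen = HashSet::new();
        // true for first sightings, false for repeats.
        let mask: BooleanArray = values.iter().map(|v| Some(seen.insert(v))).collect();
        let unique = filter(&values, &mask).unwrap();
        assert_eq!(unique.len(), 3); // values 1, 2 and 3 survive; repeats drop
    }
}
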
#[derive(Debug)]
pub struct WriteBatchOutput {
    pub num_rows: u64,
    pub bytes: u64,
}

#[cfg(test)]
mod tests {
    use crate::IncludeType;
    use std::str::FromStr;

    #[test]
    fn test_include_type() {
        let include_type = IncludeType::from_str("all").unwrap();
        assert_eq!(include_type, IncludeType::All);
        let include_type = IncludeType::from_str("text").unwrap();
        assert_eq!(include_type, IncludeType::Text);
        let include_type = IncludeType::from_str("binary").unwrap();
        assert_eq!(include_type, IncludeType::Binary);
    }
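
    // Added check (assumed strum behavior from the attributes on the enum):
    // `serialize_all = "lowercase"` lowercases `Display` output, and
    // `ascii_case_insensitive` accepts mixed-case input when parsing.
    #[test]
    fn test_include_type_display_and_case() {
        assert_eq!(IncludeType::All.to_string(), "all");
        let include_type = IncludeType::from_str("BiNaRy").unwrap();
        assert_eq!(include_type, IncludeType::Binary);
    }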
}