lance_encoding/encodings/physical/
block_compress.rs1use arrow_schema::DataType;
5use snafu::location;
6use std::{
7 io::{Cursor, Write},
8 str::FromStr,
9};
10
11use lance_core::{Error, Result};
12
13use crate::{
14 data::{BlockInfo, DataBlock, OpaqueBlock},
15 encoder::{ArrayEncoder, EncodedArray},
16 format::ProtobufUtils,
17};
18
19#[derive(Debug, Clone, Copy, PartialEq)]
20pub struct CompressionConfig {
21 pub(crate) scheme: CompressionScheme,
22 pub(crate) level: Option<i32>,
23}
24
25impl CompressionConfig {
26 pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
27 Self { scheme, level }
28 }
29}
30
31impl Default for CompressionConfig {
32 fn default() -> Self {
33 Self {
34 scheme: CompressionScheme::Zstd,
35 level: Some(0),
36 }
37 }
38}
39
/// The compression algorithms this module knows about.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompressionScheme {
    /// No compression; bytes are stored as-is.
    None,
    /// FSST compression.
    Fsst,
    /// Zstandard compression.
    Zstd,
}
46
47impl std::fmt::Display for CompressionScheme {
48 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49 let scheme_str = match self {
50 Self::Fsst => "fsst",
51 Self::Zstd => "zstd",
52 Self::None => "none",
53 };
54 write!(f, "{}", scheme_str)
55 }
56}
57
58impl FromStr for CompressionScheme {
59 type Err = Error;
60
61 fn from_str(s: &str) -> Result<Self> {
62 match s {
63 "none" => Ok(Self::None),
64 "zstd" => Ok(Self::Zstd),
65 _ => Err(Error::invalid_input(
66 format!("Unknown compression scheme: {}", s),
67 location!(),
68 )),
69 }
70 }
71}
72
73pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
74 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
75 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
76}
77
/// Buffer compressor backed by zstd.
#[derive(Default, Debug)]
pub struct ZstdBufferCompressor {
    /// Level handed to the zstd encoder when compressing.
    compression_level: i32,
}
82
83impl ZstdBufferCompressor {
84 pub fn new(compression_level: i32) -> Self {
85 Self { compression_level }
86 }
87}
88
89impl BufferCompressor for ZstdBufferCompressor {
90 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
91 let mut encoder = zstd::Encoder::new(output_buf, self.compression_level)?;
92 encoder.write_all(input_buf)?;
93 match encoder.finish() {
94 Ok(_) => Ok(()),
95 Err(e) => Err(e.into()),
96 }
97 }
98
99 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
100 let source = Cursor::new(input_buf);
101 zstd::stream::copy_decode(source, output_buf)?;
102 Ok(())
103 }
104}
105
/// Pass-through compressor that leaves the data unchanged.
#[derive(Default, Debug)]
pub struct NoopBufferCompressor {}
108
109impl BufferCompressor for NoopBufferCompressor {
110 fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
111 output_buf.extend_from_slice(input_buf);
112 Ok(())
113 }
114
115 fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
116 output_buf.extend_from_slice(input_buf);
117 Ok(())
118 }
119}
120
/// Namespace type for building a [`BufferCompressor`] from a [`CompressionConfig`].
pub struct GeneralBufferCompressor {}
122
123impl GeneralBufferCompressor {
124 pub fn get_compressor(compression_config: CompressionConfig) -> Box<dyn BufferCompressor> {
125 match compression_config.scheme {
126 CompressionScheme::Fsst => unimplemented!(),
128 CompressionScheme::Zstd => Box::new(ZstdBufferCompressor::new(
129 compression_config.level.unwrap_or(0),
130 )),
131 CompressionScheme::None => Box::new(NoopBufferCompressor {}),
132 }
133 }
134}
135
136#[derive(Debug)]
138pub struct CompressedBufferEncoder {
139 compressor: Box<dyn BufferCompressor>,
140}
141
142impl Default for CompressedBufferEncoder {
143 fn default() -> Self {
144 Self {
145 compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
146 scheme: CompressionScheme::Zstd,
147 level: Some(0),
148 }),
149 }
150 }
151}
152
153impl CompressedBufferEncoder {
154 pub fn new(compression_config: CompressionConfig) -> Self {
155 let compressor = GeneralBufferCompressor::get_compressor(compression_config);
156 Self { compressor }
157 }
158}
159
160impl ArrayEncoder for CompressedBufferEncoder {
161 fn encode(
162 &self,
163 data: DataBlock,
164 _data_type: &DataType,
165 buffer_index: &mut u32,
166 ) -> Result<EncodedArray> {
167 let uncompressed_data = data.as_fixed_width().unwrap();
168
169 let mut compressed_buf = Vec::with_capacity(uncompressed_data.data.len());
170 self.compressor
171 .compress(&uncompressed_data.data, &mut compressed_buf)?;
172
173 let compressed_data = DataBlock::Opaque(OpaqueBlock {
174 buffers: vec![compressed_buf.into()],
175 num_values: uncompressed_data.num_values,
176 block_info: BlockInfo::new(),
177 });
178
179 let comp_buf_index = *buffer_index;
180 *buffer_index += 1;
181
182 let encoding = ProtobufUtils::flat_encoding(
183 uncompressed_data.bits_per_value,
184 comp_buf_index,
185 Some(CompressionConfig::new(CompressionScheme::Zstd, None)),
186 );
187
188 Ok(EncodedArray {
189 data: compressed_data,
190 encoding,
191 })
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::LanceBuffer;
    use crate::data::FixedWidthDataBlock;
    use arrow_schema::DataType;
    use std::str::FromStr;

    /// The supported scheme names parse to their variants.
    #[test]
    fn test_compression_scheme_from_str() {
        assert_eq!(
            CompressionScheme::from_str("none").unwrap(),
            CompressionScheme::None
        );
        assert_eq!(
            CompressionScheme::from_str("zstd").unwrap(),
            CompressionScheme::Zstd
        );
    }

    /// Unknown scheme names are rejected with an error.
    #[test]
    fn test_compression_scheme_from_str_invalid() {
        assert!(CompressionScheme::from_str("invalid").is_err());
    }

    /// Data compressed with zstd decompresses back to the original bytes.
    #[test]
    fn test_zstd_buffer_compressor_round_trip() {
        let compressor = ZstdBufferCompressor::new(0);
        let input = b"lance block compression ".repeat(32);

        let mut compressed = Vec::new();
        compressor.compress(&input, &mut compressed).unwrap();

        let mut decompressed = Vec::new();
        compressor.decompress(&compressed, &mut decompressed).unwrap();
        assert_eq!(decompressed, input);
    }

    /// The no-op compressor copies bytes through unchanged in both directions.
    #[test]
    fn test_noop_buffer_compressor_round_trip() {
        let compressor = NoopBufferCompressor {};
        let input = vec![1_u8, 2, 3, 4, 5];

        let mut compressed = Vec::new();
        compressor.compress(&input, &mut compressed).unwrap();
        assert_eq!(compressed, input);

        let mut decompressed = Vec::new();
        compressor.decompress(&compressed, &mut decompressed).unwrap();
        assert_eq!(decompressed, input);
    }

    /// Encoding a fixed-width block yields one opaque buffer with the same value count.
    #[test]
    fn test_compressed_buffer_encoder() {
        let encoder = CompressedBufferEncoder::default();
        let data = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: 64,
            // Typed as i64 so the buffer really holds 8 values of 64 bits each;
            // untyped literals would be inferred as i32 and give only 32 bytes.
            data: LanceBuffer::reinterpret_vec(vec![0_i64, 1, 2, 3, 4, 5, 6, 7]),
            num_values: 8,
            block_info: BlockInfo::new(),
        });

        let mut buffer_index = 0;
        let encoded_array_result = encoder.encode(data, &DataType::Int64, &mut buffer_index);
        assert!(encoded_array_result.is_ok(), "{:?}", encoded_array_result);
        let encoded_array = encoded_array_result.unwrap();
        assert_eq!(encoded_array.data.num_values(), 8);
        let buffers = encoded_array.data.into_buffers();
        assert_eq!(buffers.len(), 1);
        assert!(buffers[0].len() < 64 * 8);
    }
}