lance_encoding/encodings/physical/
block_compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_schema::DataType;
5use snafu::location;
6use std::{
7    io::{Cursor, Write},
8    str::FromStr,
9};
10
11use lance_core::{Error, Result};
12
13use crate::{
14    data::{BlockInfo, DataBlock, OpaqueBlock},
15    encoder::{ArrayEncoder, EncodedArray},
16    format::ProtobufUtils,
17};
18
19#[derive(Debug, Clone, Copy, PartialEq)]
20pub struct CompressionConfig {
21    pub(crate) scheme: CompressionScheme,
22    pub(crate) level: Option<i32>,
23}
24
25impl CompressionConfig {
26    pub(crate) fn new(scheme: CompressionScheme, level: Option<i32>) -> Self {
27        Self { scheme, level }
28    }
29}
30
31impl Default for CompressionConfig {
32    fn default() -> Self {
33        Self {
34            scheme: CompressionScheme::Zstd,
35            level: Some(0),
36        }
37    }
38}
39
/// The compression schemes known to this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompressionScheme {
    /// No compression.
    None,
    /// FSST compression.
    Fsst,
    /// Zstandard compression.
    Zstd,
}
46
47impl std::fmt::Display for CompressionScheme {
48    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49        let scheme_str = match self {
50            Self::Fsst => "fsst",
51            Self::Zstd => "zstd",
52            Self::None => "none",
53        };
54        write!(f, "{}", scheme_str)
55    }
56}
57
58impl FromStr for CompressionScheme {
59    type Err = Error;
60
61    fn from_str(s: &str) -> Result<Self> {
62        match s {
63            "none" => Ok(Self::None),
64            "zstd" => Ok(Self::Zstd),
65            _ => Err(Error::invalid_input(
66                format!("Unknown compression scheme: {}", s),
67                location!(),
68            )),
69        }
70    }
71}
72
73pub trait BufferCompressor: std::fmt::Debug + Send + Sync {
74    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
75    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()>;
76}
77
/// A buffer compressor backed by zstd.
#[derive(Debug, Default)]
pub struct ZstdBufferCompressor {
    /// Level handed to the zstd encoder when compressing.
    compression_level: i32,
}
82
83impl ZstdBufferCompressor {
84    pub fn new(compression_level: i32) -> Self {
85        Self { compression_level }
86    }
87}
88
89impl BufferCompressor for ZstdBufferCompressor {
90    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
91        let mut encoder = zstd::Encoder::new(output_buf, self.compression_level)?;
92        encoder.write_all(input_buf)?;
93        match encoder.finish() {
94            Ok(_) => Ok(()),
95            Err(e) => Err(e.into()),
96        }
97    }
98
99    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
100        let source = Cursor::new(input_buf);
101        zstd::stream::copy_decode(source, output_buf)?;
102        Ok(())
103    }
104}
105
/// A buffer compressor that performs no compression at all.
#[derive(Debug, Default)]
pub struct NoopBufferCompressor {}
108
109impl BufferCompressor for NoopBufferCompressor {
110    fn compress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
111        output_buf.extend_from_slice(input_buf);
112        Ok(())
113    }
114
115    fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
116        output_buf.extend_from_slice(input_buf);
117        Ok(())
118    }
119}
120
/// Namespace type whose associated function selects a [`BufferCompressor`]
/// for a given [`CompressionConfig`].
pub struct GeneralBufferCompressor {}
122
123impl GeneralBufferCompressor {
124    pub fn get_compressor(compression_config: CompressionConfig) -> Box<dyn BufferCompressor> {
125        match compression_config.scheme {
126            // FSST has its own compression path and isn't implemented as a generic buffer compressor
127            CompressionScheme::Fsst => unimplemented!(),
128            CompressionScheme::Zstd => Box::new(ZstdBufferCompressor::new(
129                compression_config.level.unwrap_or(0),
130            )),
131            CompressionScheme::None => Box::new(NoopBufferCompressor {}),
132        }
133    }
134}
135
136// An encoder which uses generic compression, such as zstd/lz4 to encode buffers
137#[derive(Debug)]
138pub struct CompressedBufferEncoder {
139    compressor: Box<dyn BufferCompressor>,
140}
141
142impl Default for CompressedBufferEncoder {
143    fn default() -> Self {
144        Self {
145            compressor: GeneralBufferCompressor::get_compressor(CompressionConfig {
146                scheme: CompressionScheme::Zstd,
147                level: Some(0),
148            }),
149        }
150    }
151}
152
153impl CompressedBufferEncoder {
154    pub fn new(compression_config: CompressionConfig) -> Self {
155        let compressor = GeneralBufferCompressor::get_compressor(compression_config);
156        Self { compressor }
157    }
158}
159
160impl ArrayEncoder for CompressedBufferEncoder {
161    fn encode(
162        &self,
163        data: DataBlock,
164        _data_type: &DataType,
165        buffer_index: &mut u32,
166    ) -> Result<EncodedArray> {
167        let uncompressed_data = data.as_fixed_width().unwrap();
168
169        let mut compressed_buf = Vec::with_capacity(uncompressed_data.data.len());
170        self.compressor
171            .compress(&uncompressed_data.data, &mut compressed_buf)?;
172
173        let compressed_data = DataBlock::Opaque(OpaqueBlock {
174            buffers: vec![compressed_buf.into()],
175            num_values: uncompressed_data.num_values,
176            block_info: BlockInfo::new(),
177        });
178
179        let comp_buf_index = *buffer_index;
180        *buffer_index += 1;
181
182        let encoding = ProtobufUtils::flat_encoding(
183            uncompressed_data.bits_per_value,
184            comp_buf_index,
185            Some(CompressionConfig::new(CompressionScheme::Zstd, None)),
186        );
187
188        Ok(EncodedArray {
189            data: compressed_data,
190            encoding,
191        })
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::LanceBuffer;
    use crate::data::FixedWidthDataBlock;
    use arrow_schema::DataType;
    use std::str::FromStr;

    /// Every supported scheme name parses to the matching variant.
    #[test]
    fn test_compression_scheme_from_str() {
        assert_eq!(
            CompressionScheme::from_str("none").unwrap(),
            CompressionScheme::None
        );
        assert_eq!(
            CompressionScheme::from_str("zstd").unwrap(),
            CompressionScheme::Zstd
        );
    }

    /// Unknown scheme names are rejected with an error.
    #[test]
    fn test_compression_scheme_from_str_invalid() {
        assert!(CompressionScheme::from_str("invalid").is_err());
    }

    /// The default encoder produces a single buffer and preserves the value count.
    #[test]
    fn test_compressed_buffer_encoder() {
        let encoder = CompressedBufferEncoder::default();
        let data = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: 64,
            // Explicit i64 literals so the buffer really holds 8 values of 64 bits;
            // unsuffixed literals would default to i32 and yield only 32 bytes.
            data: LanceBuffer::reinterpret_vec(vec![
                0_i64, 1_i64, 2_i64, 3_i64, 4_i64, 5_i64, 6_i64, 7_i64,
            ]),
            num_values: 8,
            block_info: BlockInfo::new(),
        });

        let mut buffer_index = 0;
        let encoded_array_result = encoder.encode(data, &DataType::Int64, &mut buffer_index);
        assert!(encoded_array_result.is_ok(), "{:?}", encoded_array_result);
        let encoded_array = encoded_array_result.unwrap();
        assert_eq!(encoded_array.data.num_values(), 8);
        let buffers = encoded_array.data.into_buffers();
        assert_eq!(buffers.len(), 1);
        // Loose upper bound on the size of the compressed buffer.
        assert!(buffers[0].len() < 64 * 8);
    }
}
239}