lance_io/
utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::cmp::min;
5
6use arrow_array::{
7    types::{BinaryType, LargeBinaryType, LargeUtf8Type, Utf8Type},
8    ArrayRef,
9};
10use arrow_schema::DataType;
11use byteorder::{ByteOrder, LittleEndian};
12use bytes::Bytes;
13use lance_arrow::*;
14use prost::Message;
15use snafu::location;
16
17use crate::{
18    encodings::{binary::BinaryDecoder, plain::PlainDecoder, AsyncIndex, Decoder},
19    traits::ProtoStruct,
20};
21use crate::{traits::Reader, ReadBatchParams};
22use lance_core::{Error, Result};
23
24/// Read a binary array from a [Reader].
25///
26pub async fn read_binary_array(
27    reader: &dyn Reader,
28    data_type: &DataType,
29    nullable: bool,
30    position: usize,
31    length: usize,
32    params: impl Into<ReadBatchParams>,
33) -> Result<ArrayRef> {
34    use arrow_schema::DataType::*;
35    let decoder: Box<dyn Decoder<Output = Result<ArrayRef>> + Send> = match data_type {
36        Utf8 => Box::new(BinaryDecoder::<Utf8Type>::new(
37            reader, position, length, nullable,
38        )),
39        Binary => Box::new(BinaryDecoder::<BinaryType>::new(
40            reader, position, length, nullable,
41        )),
42        LargeUtf8 => Box::new(BinaryDecoder::<LargeUtf8Type>::new(
43            reader, position, length, nullable,
44        )),
45        LargeBinary => Box::new(BinaryDecoder::<LargeBinaryType>::new(
46            reader, position, length, nullable,
47        )),
48        _ => {
49            return Err(Error::io(
50                format!("Unsupported binary type: {}", data_type),
51                location!(),
52            ));
53        }
54    };
55    let fut = decoder.as_ref().get(params.into());
56    fut.await
57}
58
59/// Read a fixed stride array from disk.
60///
61pub async fn read_fixed_stride_array(
62    reader: &dyn Reader,
63    data_type: &DataType,
64    position: usize,
65    length: usize,
66    params: impl Into<ReadBatchParams>,
67) -> Result<ArrayRef> {
68    if !data_type.is_fixed_stride() {
69        return Err(Error::Schema {
70            message: format!("{data_type} is not a fixed stride type"),
71            location: location!(),
72        });
73    }
74    // TODO: support more than plain encoding here.
75    let decoder = PlainDecoder::new(reader, data_type, position, length)?;
76    decoder.get(params.into()).await
77}
78
79/// Read a protobuf message at file position 'pos'.
80///
81/// We write protobuf by first writing the length of the message as a u32,
82/// followed by the message itself.
83pub async fn read_message<M: Message + Default>(reader: &dyn Reader, pos: usize) -> Result<M> {
84    let file_size = reader.size().await?;
85    if pos > file_size {
86        return Err(Error::io("file size is too small".to_string(), location!()));
87    }
88
89    let range = pos..min(pos + reader.block_size(), file_size);
90    let buf = reader.get_range(range.clone()).await?;
91    let msg_len = LittleEndian::read_u32(&buf) as usize;
92
93    if msg_len + 4 > buf.len() {
94        let remaining_range = range.end..min(4 + pos + msg_len, file_size);
95        let remaining_bytes = reader.get_range(remaining_range).await?;
96        let buf = [buf, remaining_bytes].concat();
97        assert!(buf.len() >= msg_len + 4);
98        Ok(M::decode(&buf[4..4 + msg_len])?)
99    } else {
100        Ok(M::decode(&buf[4..4 + msg_len])?)
101    }
102}
103
104/// Read a Protobuf-backed struct at file position: `pos`.
105// TODO: pub(crate)
106pub async fn read_struct<
107    M: Message + Default + 'static,
108    T: ProtoStruct<Proto = M> + TryFrom<M, Error = Error>,
109>(
110    reader: &dyn Reader,
111    pos: usize,
112) -> Result<T> {
113    let msg = read_message::<M>(reader, pos).await?;
114    T::try_from(msg)
115}
116
117pub async fn read_last_block(reader: &dyn Reader) -> object_store::Result<Bytes> {
118    let file_size = reader.size().await?;
119    let block_size = reader.block_size();
120    let begin = file_size.saturating_sub(block_size);
121    reader.get_range(begin..file_size).await
122}
123
124pub fn read_metadata_offset(bytes: &Bytes) -> Result<usize> {
125    let len = bytes.len();
126    if len < 16 {
127        return Err(Error::io(
128            format!(
129                "does not have sufficient data, len: {}, bytes: {:?}",
130                len, bytes
131            ),
132            location!(),
133        ));
134    }
135    let offset_bytes = bytes.slice(len - 16..len - 8);
136    Ok(LittleEndian::read_u64(offset_bytes.as_ref()) as usize)
137}
138
139/// Read the version from the footer bytes
140pub fn read_version(bytes: &Bytes) -> Result<(u16, u16)> {
141    let len = bytes.len();
142    if len < 8 {
143        return Err(Error::io(
144            format!(
145                "does not have sufficient data, len: {}, bytes: {:?}",
146                len, bytes
147            ),
148            location!(),
149        ));
150    }
151
152    let major_version = LittleEndian::read_u16(bytes.slice(len - 8..len - 6).as_ref());
153    let minor_version = LittleEndian::read_u16(bytes.slice(len - 6..len - 4).as_ref());
154    Ok((major_version, minor_version))
155}
156
157/// Read protobuf from a buffer.
158pub fn read_message_from_buf<M: Message + Default>(buf: &Bytes) -> Result<M> {
159    let msg_len = LittleEndian::read_u32(buf) as usize;
160    Ok(M::decode(&buf[4..4 + msg_len])?)
161}
162
163/// Read a Protobuf-backed struct from a buffer.
164pub fn read_struct_from_buf<
165    M: Message + Default,
166    T: ProtoStruct<Proto = M> + TryFrom<M, Error = Error>,
167>(
168    buf: &Bytes,
169) -> Result<T> {
170    let msg: M = read_message_from_buf(buf)?;
171    T::try_from(msg)
172}
173
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use object_store::path::Path;

    use crate::{
        object_reader::CloudObjectReader,
        object_store::{ObjectStore, DEFAULT_DOWNLOAD_RETRY_COUNT},
        object_writer::ObjectWriter,
        traits::{ProtoStruct, WriteExt, Writer},
        utils::read_struct,
        Error, Result,
    };

    /// Stand-in for a real protobuf-backed struct.
    ///
    /// This crate has no .proto files, but `Bytes` is itself a
    /// `prost::Message`, so wrapping it is enough to simulate a message
    /// object.
    #[derive(Debug, PartialEq)]
    struct BytesWrapper(Bytes);

    impl ProtoStruct for BytesWrapper {
        type Proto = Bytes;
    }

    impl From<&BytesWrapper> for Bytes {
        fn from(wrapper: &BytesWrapper) -> Self {
            let BytesWrapper(inner) = wrapper;
            inner.clone()
        }
    }

    impl TryFrom<Bytes> for BytesWrapper {
        type Error = Error;

        fn try_from(proto: Bytes) -> Result<Self> {
            Ok(BytesWrapper(proto))
        }
    }

    /// Write a struct to an in-memory store and read it back with `read_struct`.
    #[tokio::test]
    async fn test_write_proto_structs() {
        let store = ObjectStore::memory();
        let path = Path::from("/foo");

        // A fresh writer starts at offset zero.
        let mut writer = ObjectWriter::new(&store, &path).await.unwrap();
        assert_eq!(writer.tell().await.unwrap(), 0);

        let some_message = BytesWrapper(Bytes::from(vec![10, 20, 30]));

        // `write_struct` reports the position the struct was written at.
        let pos = writer.write_struct(&some_message).await.unwrap();
        assert_eq!(pos, 0);
        writer.shutdown().await.unwrap();

        let reader =
            CloudObjectReader::new(store.inner, path, 1024, None, DEFAULT_DOWNLOAD_RETRY_COUNT)
                .unwrap();
        let actual: BytesWrapper = read_struct(&reader, pos).await.unwrap();
        assert_eq!(some_message, actual);
    }
}
230}