arrow_ipc/reader/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_data::UnsafeFlag;
25use arrow_schema::{ArrowError, SchemaRef};
26
27use crate::convert::MessageBuffer;
28use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
29use crate::{MessageHeader, CONTINUATION_MARKER};
30
31/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
32///
33/// See [StreamReader](crate::reader::StreamReader) for a higher-level interface
34#[derive(Debug, Default)]
35pub struct StreamDecoder {
36    /// The schema of this decoder, if read
37    schema: Option<SchemaRef>,
38    /// Lookup table for dictionaries by ID
39    dictionaries: HashMap<i64, ArrayRef>,
40    /// The decoder state
41    state: DecoderState,
42    /// A scratch buffer when a read is split across multiple `Buffer`
43    buf: MutableBuffer,
44    /// Whether or not array data in input buffers are required to be aligned
45    require_alignment: bool,
46    /// Should validation be skipped when reading data? Defaults to false.
47    ///
48    /// See [`FileDecoder::with_skip_validation`] for details.
49    ///
50    /// [`FileDecoder::with_skip_validation`]: crate::reader::FileDecoder::with_skip_validation
51    skip_validation: UnsafeFlag,
52}
53
54#[derive(Debug)]
55enum DecoderState {
56    /// Decoding the message header
57    Header {
58        /// Temporary buffer
59        buf: [u8; 4],
60        /// Number of bytes read into buf
61        read: u8,
62        /// If we have read a continuation token
63        continuation: bool,
64    },
65    /// Decoding the message flatbuffer
66    Message {
67        /// The size of the message flatbuffer
68        size: u32,
69    },
70    /// Decoding the message body
71    Body {
72        /// The message flatbuffer
73        message: MessageBuffer,
74    },
75    /// Reached the end of the stream
76    Finished,
77}
78
79impl Default for DecoderState {
80    fn default() -> Self {
81        Self::Header {
82            buf: [0; 4],
83            read: 0,
84            continuation: false,
85        }
86    }
87}
88
89impl StreamDecoder {
90    /// Create a new [`StreamDecoder`]
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Specifies whether or not array data in input buffers is required to be properly aligned.
96    ///
97    /// If `require_alignment` is true, this decoder will return an error if any array data in the
98    /// input `buf` is not properly aligned.
99    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
100    /// [`arrow_data::ArrayData`].
101    ///
102    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
103    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
104    /// properly aligned. (Properly aligned array data will remain zero-copy.)
105    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
106    /// [`arrow_data::ArrayData`].
107    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
108        self.require_alignment = require_alignment;
109        self
110    }
111
112    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
113    ///
114    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
115    ///
116    /// The push-based interface facilitates integration with sources that yield arbitrarily
117    /// delimited bytes ranges, such as a chunked byte stream received from object storage
118    ///
119    /// ```
120    /// # use arrow_array::RecordBatch;
121    /// # use arrow_buffer::Buffer;
122    /// # use arrow_ipc::reader::StreamDecoder;
123    /// # use arrow_schema::ArrowError;
124    /// #
125    /// fn print_stream<I>(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
126    ///     let mut decoder = StreamDecoder::new();
127    ///     for mut x in src {
128    ///         while !x.is_empty() {
129    ///             if let Some(x) = decoder.decode(&mut x)? {
130    ///                 println!("{x:?}");
131    ///             }
132    ///         }
133    ///     }
134    ///     decoder.finish().unwrap();
135    ///     Ok(())
136    /// }
137    /// ```
138    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
139        while !buffer.is_empty() {
140            match &mut self.state {
141                DecoderState::Header {
142                    buf,
143                    read,
144                    continuation,
145                } => {
146                    let offset_buf = &mut buf[*read as usize..];
147                    let to_read = buffer.len().min(offset_buf.len());
148                    offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
149                    *read += to_read as u8;
150                    buffer.advance(to_read);
151                    if *read == 4 {
152                        if !*continuation && buf == &CONTINUATION_MARKER {
153                            *continuation = true;
154                            *read = 0;
155                            continue;
156                        }
157                        let size = u32::from_le_bytes(*buf);
158
159                        if size == 0 {
160                            self.state = DecoderState::Finished;
161                            continue;
162                        }
163                        self.state = DecoderState::Message { size };
164                    }
165                }
166                DecoderState::Message { size } => {
167                    let len = *size as usize;
168                    if self.buf.is_empty() && buffer.len() > len {
169                        let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
170                        self.state = DecoderState::Body { message };
171                        buffer.advance(len);
172                        continue;
173                    }
174
175                    let to_read = buffer.len().min(len - self.buf.len());
176                    self.buf.extend_from_slice(&buffer[..to_read]);
177                    buffer.advance(to_read);
178                    if self.buf.len() == len {
179                        let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
180                        self.state = DecoderState::Body { message };
181                    }
182                }
183                DecoderState::Body { message } => {
184                    let message = message.as_ref();
185                    let body_length = message.bodyLength() as usize;
186
187                    let body = if self.buf.is_empty() && buffer.len() >= body_length {
188                        let body = buffer.slice_with_length(0, body_length);
189                        buffer.advance(body_length);
190                        body
191                    } else {
192                        let to_read = buffer.len().min(body_length - self.buf.len());
193                        self.buf.extend_from_slice(&buffer[..to_read]);
194                        buffer.advance(to_read);
195
196                        if self.buf.len() != body_length {
197                            continue;
198                        }
199                        std::mem::take(&mut self.buf).into()
200                    };
201
202                    let version = message.version();
203                    match message.header_type() {
204                        MessageHeader::Schema => {
205                            if self.schema.is_some() {
206                                return Err(ArrowError::IpcError(
207                                    "Not expecting a schema when messages are read".to_string(),
208                                ));
209                            }
210
211                            let ipc_schema = message.header_as_schema().unwrap();
212                            let schema = crate::convert::fb_to_schema(ipc_schema);
213                            self.state = DecoderState::default();
214                            self.schema = Some(Arc::new(schema));
215                        }
216                        MessageHeader::RecordBatch => {
217                            let batch = message.header_as_record_batch().unwrap();
218                            let schema = self.schema.clone().ok_or_else(|| {
219                                ArrowError::IpcError("Missing schema".to_string())
220                            })?;
221                            let batch = RecordBatchDecoder::try_new(
222                                &body,
223                                batch,
224                                schema,
225                                &self.dictionaries,
226                                &version,
227                            )?
228                            .with_require_alignment(self.require_alignment)
229                            .read_record_batch()?;
230                            self.state = DecoderState::default();
231                            return Ok(Some(batch));
232                        }
233                        MessageHeader::DictionaryBatch => {
234                            let dictionary = message.header_as_dictionary_batch().unwrap();
235                            let schema = self.schema.as_deref().ok_or_else(|| {
236                                ArrowError::IpcError("Missing schema".to_string())
237                            })?;
238                            read_dictionary_impl(
239                                &body,
240                                dictionary,
241                                schema,
242                                &mut self.dictionaries,
243                                &version,
244                                self.require_alignment,
245                                self.skip_validation.clone(),
246                            )?;
247                            self.state = DecoderState::default();
248                        }
249                        MessageHeader::NONE => {
250                            self.state = DecoderState::default();
251                        }
252                        t => {
253                            return Err(ArrowError::IpcError(format!(
254                                "Message type unsupported by StreamDecoder: {t:?}"
255                            )))
256                        }
257                    }
258                }
259                DecoderState::Finished => {
260                    return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
261                }
262            }
263        }
264        Ok(None)
265    }
266
267    /// Signal the end of stream
268    ///
269    /// Returns an error if any partial data remains in the stream
270    pub fn finish(&mut self) -> Result<(), ArrowError> {
271        match self.state {
272            DecoderState::Finished
273            | DecoderState::Header {
274                read: 0,
275                continuation: false,
276                ..
277            } => Ok(()),
278            _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
279        }
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        types::Int32Type, DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray,
    };
    use arrow_schema::{DataType, Field, Schema};

    // Further tests in arrow-integration-testing/tests/ipc_reader.rs

    /// A stream whose end-of-stream marker is truncated must still yield the
    /// batch that precedes it, but `finish` must report the truncation.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        assert_eq!(b.len(), 7); // 8 byte EOS truncated by 1 byte
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A run-end encoded column whose values are dictionary encoded must
    /// round-trip through the stream writer and `StreamDecoder`.
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                #[allow(deprecated)]
                IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        // Use a distinct name for the decoded batch: reusing `batch` would shadow the
        // input and turn the comparison below into a tautology
        let mut decoded_batches = 0;
        while let Some(decoded) = decoder
            .decode(buf)
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
            decoded_batches += 1;
        }
        // Exactly the one batch that was written must come back out
        assert_eq!(decoded_batches, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}