arrow_ipc/reader/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_schema::{ArrowError, SchemaRef};
25
26use crate::convert::MessageBuffer;
27use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
28use crate::{MessageHeader, CONTINUATION_MARKER};
29
/// A low-level, push-based interface for reading [`RecordBatch`] data from a stream of bytes
///
/// Input is supplied incrementally through [`StreamDecoder::decode`]; the end of the input is
/// signalled with [`StreamDecoder::finish`].
///
/// See [StreamReader](crate::reader::StreamReader) for a higher-level interface
#[derive(Debug, Default)]
pub struct StreamDecoder {
    /// The schema of this decoder, if read
    ///
    /// Set when a schema message is decoded. Record batch and dictionary messages
    /// are rejected with an error while this is `None`.
    schema: Option<SchemaRef>,
    /// Lookup table for dictionaries by ID, filled in as dictionary messages are decoded
    dictionaries: HashMap<i64, ArrayRef>,
    /// The decoder state, i.e. which part of a message is expected next
    state: DecoderState,
    /// A scratch buffer used to accumulate a message or message body when it is
    /// split across multiple input `Buffer`s
    buf: MutableBuffer,
    /// Whether or not array data in input buffers are required to be aligned
    ///
    /// Defaults to `false`; see [`StreamDecoder::with_require_alignment`].
    require_alignment: bool,
}
46
/// The position of a [`StreamDecoder`] within the message currently being read
///
/// A message is consumed as `Header` -> `Message` -> `Body`, after which the decoder
/// returns to `Header` for the next message.
#[derive(Debug)]
enum DecoderState {
    /// Decoding the message header: a 4-byte prefix that is either the continuation
    /// marker or the little-endian length of the message flatbuffer
    Header {
        /// Temporary buffer holding the prefix bytes read so far
        buf: [u8; 4],
        /// Number of bytes read into buf
        read: u8,
        /// If we have read a continuation token, in which case the next
        /// 4 bytes are interpreted as the message length
        continuation: bool,
    },
    /// Decoding the message flatbuffer
    Message {
        /// The size of the message flatbuffer, in bytes
        size: u32,
    },
    /// Decoding the message body
    Body {
        /// The message flatbuffer, which records the length of the body to read
        message: MessageBuffer,
    },
    /// Reached the end of the stream (a message length of zero was read)
    Finished,
}
71
72impl Default for DecoderState {
73    fn default() -> Self {
74        Self::Header {
75            buf: [0; 4],
76            read: 0,
77            continuation: false,
78        }
79    }
80}
81
82impl StreamDecoder {
83    /// Create a new [`StreamDecoder`]
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Specifies whether or not array data in input buffers is required to be properly aligned.
89    ///
90    /// If `require_alignment` is true, this decoder will return an error if any array data in the
91    /// input `buf` is not properly aligned.
92    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
93    /// [`arrow_data::ArrayData`].
94    ///
95    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
96    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
97    /// properly aligned. (Properly aligned array data will remain zero-copy.)
98    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
99    /// [`arrow_data::ArrayData`].
100    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
101        self.require_alignment = require_alignment;
102        self
103    }
104
105    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
106    ///
107    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
108    ///
109    /// The push-based interface facilitates integration with sources that yield arbitrarily
110    /// delimited bytes ranges, such as a chunked byte stream received from object storage
111    ///
112    /// ```
113    /// # use arrow_array::RecordBatch;
114    /// # use arrow_buffer::Buffer;
115    /// # use arrow_ipc::reader::StreamDecoder;
116    /// # use arrow_schema::ArrowError;
117    /// #
118    /// fn print_stream<I>(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
119    ///     let mut decoder = StreamDecoder::new();
120    ///     for mut x in src {
121    ///         while !x.is_empty() {
122    ///             if let Some(x) = decoder.decode(&mut x)? {
123    ///                 println!("{x:?}");
124    ///             }
125    ///         }
126    ///     }
127    ///     decoder.finish().unwrap();
128    ///     Ok(())
129    /// }
130    /// ```
131    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
132        while !buffer.is_empty() {
133            match &mut self.state {
134                DecoderState::Header {
135                    buf,
136                    read,
137                    continuation,
138                } => {
139                    let offset_buf = &mut buf[*read as usize..];
140                    let to_read = buffer.len().min(offset_buf.len());
141                    offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
142                    *read += to_read as u8;
143                    buffer.advance(to_read);
144                    if *read == 4 {
145                        if !*continuation && buf == &CONTINUATION_MARKER {
146                            *continuation = true;
147                            *read = 0;
148                            continue;
149                        }
150                        let size = u32::from_le_bytes(*buf);
151
152                        if size == 0 {
153                            self.state = DecoderState::Finished;
154                            continue;
155                        }
156                        self.state = DecoderState::Message { size };
157                    }
158                }
159                DecoderState::Message { size } => {
160                    let len = *size as usize;
161                    if self.buf.is_empty() && buffer.len() > len {
162                        let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
163                        self.state = DecoderState::Body { message };
164                        buffer.advance(len);
165                        continue;
166                    }
167
168                    let to_read = buffer.len().min(len - self.buf.len());
169                    self.buf.extend_from_slice(&buffer[..to_read]);
170                    buffer.advance(to_read);
171                    if self.buf.len() == len {
172                        let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
173                        self.state = DecoderState::Body { message };
174                    }
175                }
176                DecoderState::Body { message } => {
177                    let message = message.as_ref();
178                    let body_length = message.bodyLength() as usize;
179
180                    let body = if self.buf.is_empty() && buffer.len() >= body_length {
181                        let body = buffer.slice_with_length(0, body_length);
182                        buffer.advance(body_length);
183                        body
184                    } else {
185                        let to_read = buffer.len().min(body_length - self.buf.len());
186                        self.buf.extend_from_slice(&buffer[..to_read]);
187                        buffer.advance(to_read);
188
189                        if self.buf.len() != body_length {
190                            continue;
191                        }
192                        std::mem::take(&mut self.buf).into()
193                    };
194
195                    let version = message.version();
196                    match message.header_type() {
197                        MessageHeader::Schema => {
198                            if self.schema.is_some() {
199                                return Err(ArrowError::IpcError(
200                                    "Not expecting a schema when messages are read".to_string(),
201                                ));
202                            }
203
204                            let ipc_schema = message.header_as_schema().unwrap();
205                            let schema = crate::convert::fb_to_schema(ipc_schema);
206                            self.state = DecoderState::default();
207                            self.schema = Some(Arc::new(schema));
208                        }
209                        MessageHeader::RecordBatch => {
210                            let batch = message.header_as_record_batch().unwrap();
211                            let schema = self.schema.clone().ok_or_else(|| {
212                                ArrowError::IpcError("Missing schema".to_string())
213                            })?;
214                            let batch = RecordBatchDecoder::try_new(
215                                &body,
216                                batch,
217                                schema,
218                                &self.dictionaries,
219                                &version,
220                            )?
221                            .with_require_alignment(self.require_alignment)
222                            .read_record_batch()?;
223                            self.state = DecoderState::default();
224                            return Ok(Some(batch));
225                        }
226                        MessageHeader::DictionaryBatch => {
227                            let dictionary = message.header_as_dictionary_batch().unwrap();
228                            let schema = self.schema.as_deref().ok_or_else(|| {
229                                ArrowError::IpcError("Missing schema".to_string())
230                            })?;
231                            read_dictionary_impl(
232                                &body,
233                                dictionary,
234                                schema,
235                                &mut self.dictionaries,
236                                &version,
237                                self.require_alignment,
238                            )?;
239                            self.state = DecoderState::default();
240                        }
241                        MessageHeader::NONE => {
242                            self.state = DecoderState::default();
243                        }
244                        t => {
245                            return Err(ArrowError::IpcError(format!(
246                                "Message type unsupported by StreamDecoder: {t:?}"
247                            )))
248                        }
249                    }
250                }
251                DecoderState::Finished => {
252                    return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
253                }
254            }
255        }
256        Ok(None)
257    }
258
259    /// Signal the end of stream
260    ///
261    /// Returns an error if any partial data remains in the stream
262    pub fn finish(&mut self) -> Result<(), ArrowError> {
263        match self.state {
264            DecoderState::Finished
265            | DecoderState::Header {
266                read: 0,
267                continuation: false,
268                ..
269            } => Ok(()),
270            _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        types::Int32Type, DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray,
    };
    use arrow_schema::{DataType, Field, Schema};

    // Further tests in arrow-integration-testing/tests/ipc_reader.rs

    /// A stream whose end-of-stream marker is cut short must still yield the batches
    /// before it, but `finish` must report the truncation.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Drop the final byte so the end-of-stream marker is incomplete
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        assert_eq!(b.len(), 7); // 8 byte EOS truncated by 1 byte
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A run-end encoded column with dictionary-encoded values must round-trip
    /// through the stream writer and decoder unchanged.
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                #[allow(deprecated)]
                IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        let mut decoded_batches = 0;
        // Bind the decoded batch under its own name so it is compared against the
        // batch that was written, not against itself
        while let Some(decoded) = decoder
            .decode(buf)
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
            decoded_batches += 1;
        }
        // Exactly one batch was written, so exactly one must come back
        assert_eq!(decoded_batches, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}