1use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_data::UnsafeFlag;
25use arrow_schema::{ArrowError, SchemaRef};
26
27use crate::convert::MessageBuffer;
28use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
29use crate::{MessageHeader, CONTINUATION_MARKER};
30
31#[derive(Debug, Default)]
35pub struct StreamDecoder {
36 schema: Option<SchemaRef>,
38 dictionaries: HashMap<i64, ArrayRef>,
40 state: DecoderState,
42 buf: MutableBuffer,
44 require_alignment: bool,
46 skip_validation: UnsafeFlag,
52}
53
54#[derive(Debug)]
55enum DecoderState {
56 Header {
58 buf: [u8; 4],
60 read: u8,
62 continuation: bool,
64 },
65 Message {
67 size: u32,
69 },
70 Body {
72 message: MessageBuffer,
74 },
75 Finished,
77}
78
79impl Default for DecoderState {
80 fn default() -> Self {
81 Self::Header {
82 buf: [0; 4],
83 read: 0,
84 continuation: false,
85 }
86 }
87}
88
89impl StreamDecoder {
90 pub fn new() -> Self {
92 Self::default()
93 }
94
95 pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
108 self.require_alignment = require_alignment;
109 self
110 }
111
112 pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
139 while !buffer.is_empty() {
140 match &mut self.state {
141 DecoderState::Header {
142 buf,
143 read,
144 continuation,
145 } => {
146 let offset_buf = &mut buf[*read as usize..];
147 let to_read = buffer.len().min(offset_buf.len());
148 offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
149 *read += to_read as u8;
150 buffer.advance(to_read);
151 if *read == 4 {
152 if !*continuation && buf == &CONTINUATION_MARKER {
153 *continuation = true;
154 *read = 0;
155 continue;
156 }
157 let size = u32::from_le_bytes(*buf);
158
159 if size == 0 {
160 self.state = DecoderState::Finished;
161 continue;
162 }
163 self.state = DecoderState::Message { size };
164 }
165 }
166 DecoderState::Message { size } => {
167 let len = *size as usize;
168 if self.buf.is_empty() && buffer.len() > len {
169 let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
170 self.state = DecoderState::Body { message };
171 buffer.advance(len);
172 continue;
173 }
174
175 let to_read = buffer.len().min(len - self.buf.len());
176 self.buf.extend_from_slice(&buffer[..to_read]);
177 buffer.advance(to_read);
178 if self.buf.len() == len {
179 let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
180 self.state = DecoderState::Body { message };
181 }
182 }
183 DecoderState::Body { message } => {
184 let message = message.as_ref();
185 let body_length = message.bodyLength() as usize;
186
187 let body = if self.buf.is_empty() && buffer.len() >= body_length {
188 let body = buffer.slice_with_length(0, body_length);
189 buffer.advance(body_length);
190 body
191 } else {
192 let to_read = buffer.len().min(body_length - self.buf.len());
193 self.buf.extend_from_slice(&buffer[..to_read]);
194 buffer.advance(to_read);
195
196 if self.buf.len() != body_length {
197 continue;
198 }
199 std::mem::take(&mut self.buf).into()
200 };
201
202 let version = message.version();
203 match message.header_type() {
204 MessageHeader::Schema => {
205 if self.schema.is_some() {
206 return Err(ArrowError::IpcError(
207 "Not expecting a schema when messages are read".to_string(),
208 ));
209 }
210
211 let ipc_schema = message.header_as_schema().unwrap();
212 let schema = crate::convert::fb_to_schema(ipc_schema);
213 self.state = DecoderState::default();
214 self.schema = Some(Arc::new(schema));
215 }
216 MessageHeader::RecordBatch => {
217 let batch = message.header_as_record_batch().unwrap();
218 let schema = self.schema.clone().ok_or_else(|| {
219 ArrowError::IpcError("Missing schema".to_string())
220 })?;
221 let batch = RecordBatchDecoder::try_new(
222 &body,
223 batch,
224 schema,
225 &self.dictionaries,
226 &version,
227 )?
228 .with_require_alignment(self.require_alignment)
229 .read_record_batch()?;
230 self.state = DecoderState::default();
231 return Ok(Some(batch));
232 }
233 MessageHeader::DictionaryBatch => {
234 let dictionary = message.header_as_dictionary_batch().unwrap();
235 let schema = self.schema.as_deref().ok_or_else(|| {
236 ArrowError::IpcError("Missing schema".to_string())
237 })?;
238 read_dictionary_impl(
239 &body,
240 dictionary,
241 schema,
242 &mut self.dictionaries,
243 &version,
244 self.require_alignment,
245 self.skip_validation.clone(),
246 )?;
247 self.state = DecoderState::default();
248 }
249 MessageHeader::NONE => {
250 self.state = DecoderState::default();
251 }
252 t => {
253 return Err(ArrowError::IpcError(format!(
254 "Message type unsupported by StreamDecoder: {t:?}"
255 )))
256 }
257 }
258 }
259 DecoderState::Finished => {
260 return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
261 }
262 }
263 }
264 Ok(None)
265 }
266
267 pub fn finish(&mut self) -> Result<(), ArrowError> {
271 match self.state {
272 DecoderState::Finished
273 | DecoderState::Header {
274 read: 0,
275 continuation: false,
276 ..
277 } => Ok(()),
278 _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
279 }
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        types::Int32Type, DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray,
    };
    use arrow_schema::{DataType, Field, Schema};

    /// A stream whose final byte is missing still yields its record batch, but
    /// `finish` must report the truncation.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Drop the last byte of the encoded stream.
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        // What is left after the batch is presumably the end-of-stream marker
        // minus the byte dropped above.
        assert_eq!(b.len(), 7);
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// Round-trips a run-end encoded column whose values are dictionary
    /// encoded through `StreamWriter` and `StreamDecoder`.
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                #[allow(deprecated)]
                IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        let mut decoded_batches = 0;
        // The loop binding must not be called `batch`: that would shadow the
        // input and reduce the comparison below to a tautology.
        while let Some(decoded) = decoder
            .decode(buf)
            .map_err(|e| {
                ArrowError::ExternalError(format!("Failed to decode record batch: {}", e).into())
            })
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
            decoded_batches += 1;
        }
        // Exactly one batch was written, so exactly one must come back.
        assert_eq!(decoded_batches, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}