1use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_schema::{ArrowError, SchemaRef};
25
26use crate::convert::MessageBuffer;
27use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
28use crate::{MessageHeader, CONTINUATION_MARKER};
29
30#[derive(Debug, Default)]
34pub struct StreamDecoder {
35 schema: Option<SchemaRef>,
37 dictionaries: HashMap<i64, ArrayRef>,
39 state: DecoderState,
41 buf: MutableBuffer,
43 require_alignment: bool,
45}
46
47#[derive(Debug)]
48enum DecoderState {
49 Header {
51 buf: [u8; 4],
53 read: u8,
55 continuation: bool,
57 },
58 Message {
60 size: u32,
62 },
63 Body {
65 message: MessageBuffer,
67 },
68 Finished,
70}
71
72impl Default for DecoderState {
73 fn default() -> Self {
74 Self::Header {
75 buf: [0; 4],
76 read: 0,
77 continuation: false,
78 }
79 }
80}
81
82impl StreamDecoder {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
101 self.require_alignment = require_alignment;
102 self
103 }
104
105 pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
132 while !buffer.is_empty() {
133 match &mut self.state {
134 DecoderState::Header {
135 buf,
136 read,
137 continuation,
138 } => {
139 let offset_buf = &mut buf[*read as usize..];
140 let to_read = buffer.len().min(offset_buf.len());
141 offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
142 *read += to_read as u8;
143 buffer.advance(to_read);
144 if *read == 4 {
145 if !*continuation && buf == &CONTINUATION_MARKER {
146 *continuation = true;
147 *read = 0;
148 continue;
149 }
150 let size = u32::from_le_bytes(*buf);
151
152 if size == 0 {
153 self.state = DecoderState::Finished;
154 continue;
155 }
156 self.state = DecoderState::Message { size };
157 }
158 }
159 DecoderState::Message { size } => {
160 let len = *size as usize;
161 if self.buf.is_empty() && buffer.len() > len {
162 let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
163 self.state = DecoderState::Body { message };
164 buffer.advance(len);
165 continue;
166 }
167
168 let to_read = buffer.len().min(len - self.buf.len());
169 self.buf.extend_from_slice(&buffer[..to_read]);
170 buffer.advance(to_read);
171 if self.buf.len() == len {
172 let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
173 self.state = DecoderState::Body { message };
174 }
175 }
176 DecoderState::Body { message } => {
177 let message = message.as_ref();
178 let body_length = message.bodyLength() as usize;
179
180 let body = if self.buf.is_empty() && buffer.len() >= body_length {
181 let body = buffer.slice_with_length(0, body_length);
182 buffer.advance(body_length);
183 body
184 } else {
185 let to_read = buffer.len().min(body_length - self.buf.len());
186 self.buf.extend_from_slice(&buffer[..to_read]);
187 buffer.advance(to_read);
188
189 if self.buf.len() != body_length {
190 continue;
191 }
192 std::mem::take(&mut self.buf).into()
193 };
194
195 let version = message.version();
196 match message.header_type() {
197 MessageHeader::Schema => {
198 if self.schema.is_some() {
199 return Err(ArrowError::IpcError(
200 "Not expecting a schema when messages are read".to_string(),
201 ));
202 }
203
204 let ipc_schema = message.header_as_schema().unwrap();
205 let schema = crate::convert::fb_to_schema(ipc_schema);
206 self.state = DecoderState::default();
207 self.schema = Some(Arc::new(schema));
208 }
209 MessageHeader::RecordBatch => {
210 let batch = message.header_as_record_batch().unwrap();
211 let schema = self.schema.clone().ok_or_else(|| {
212 ArrowError::IpcError("Missing schema".to_string())
213 })?;
214 let batch = RecordBatchDecoder::try_new(
215 &body,
216 batch,
217 schema,
218 &self.dictionaries,
219 &version,
220 )?
221 .with_require_alignment(self.require_alignment)
222 .read_record_batch()?;
223 self.state = DecoderState::default();
224 return Ok(Some(batch));
225 }
226 MessageHeader::DictionaryBatch => {
227 let dictionary = message.header_as_dictionary_batch().unwrap();
228 let schema = self.schema.as_deref().ok_or_else(|| {
229 ArrowError::IpcError("Missing schema".to_string())
230 })?;
231 read_dictionary_impl(
232 &body,
233 dictionary,
234 schema,
235 &mut self.dictionaries,
236 &version,
237 self.require_alignment,
238 )?;
239 self.state = DecoderState::default();
240 }
241 MessageHeader::NONE => {
242 self.state = DecoderState::default();
243 }
244 t => {
245 return Err(ArrowError::IpcError(format!(
246 "Message type unsupported by StreamDecoder: {t:?}"
247 )))
248 }
249 }
250 }
251 DecoderState::Finished => {
252 return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
253 }
254 }
255 }
256 Ok(None)
257 }
258
259 pub fn finish(&mut self) -> Result<(), ArrowError> {
263 match self.state {
264 DecoderState::Finished
265 | DecoderState::Header {
266 read: 0,
267 continuation: false,
268 ..
269 } => Ok(()),
270 _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
271 }
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        types::Int32Type, DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray,
    };
    use arrow_schema::{DataType, Field, Schema};

    /// A stream whose end-of-stream marker is cut short must still yield its
    /// record batch, but `finish` must report the truncation.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Drop the final byte so that the end-of-stream marker is incomplete.
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        // What remains is the 8-byte end-of-stream marker truncated by one byte.
        assert_eq!(b.len(), 7);
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A run-end-encoded column whose values are dictionary-encoded must
    /// round-trip through the stream writer and `StreamDecoder` unchanged.
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                // `Field::new_dict` is flagged deprecated; it is still used here to
                // build a dictionary field with an explicit id.
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                // `with_preserve_dict_id` is flagged deprecated; the test exercises
                // the non-preserving mode explicitly.
                #[allow(deprecated)]
                IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        let mut num_batches = 0;
        // The loop binding must not shadow `batch`, otherwise the comparison
        // below would compare the decoded batch with itself.
        while let Some(decoded) = decoder
            .decode(buf)
            .map_err(|e| {
                ArrowError::ExternalError(format!("Failed to decode record batch: {}", e).into())
            })
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
            num_batches += 1;
        }
        // Exactly one batch was written, so exactly one must come back.
        assert_eq!(num_batches, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}