arrow_ipc/
convert.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities for converting between IPC types and native Arrow types
19
20use arrow_buffer::Buffer;
21use arrow_schema::*;
22use flatbuffers::{
23    FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier,
24    VerifierOptions, WIPOffset,
25};
26use std::collections::HashMap;
27use std::fmt::{Debug, Formatter};
28use std::sync::Arc;
29
30use crate::writer::DictionaryTracker;
31use crate::{size_prefixed_root_as_message, KeyValue, Message, CONTINUATION_MARKER};
32use DataType::*;
33
34/// Low level Arrow [Schema] to IPC bytes converter
35///
36/// See also [`fb_to_schema`] for the reverse operation
37///
38/// # Example
39/// ```
40/// # use arrow_ipc::convert::{fb_to_schema, IpcSchemaEncoder};
41/// # use arrow_ipc::root_as_schema;
42/// # use arrow_ipc::writer::DictionaryTracker;
43/// # use arrow_schema::{DataType, Field, Schema};
44/// // given an arrow schema to serialize
45/// let schema = Schema::new(vec![
46///    Field::new("a", DataType::Int32, false),
47/// ]);
48///
49/// // Use a dictionary tracker to track dictionary id if needed
50///  let mut dictionary_tracker = DictionaryTracker::new(true);
51/// // create a FlatBuffersBuilder that contains the encoded bytes
52///  let fb = IpcSchemaEncoder::new()
53///    .with_dictionary_tracker(&mut dictionary_tracker)
54///    .schema_to_fb(&schema);
55///
56/// // the bytes are in `fb.finished_data()`
57/// let ipc_bytes = fb.finished_data();
58///
59///  // convert the IPC bytes back to an Arrow schema
60///  let ipc_schema = root_as_schema(ipc_bytes).unwrap();
61///  let schema2 = fb_to_schema(ipc_schema);
62/// assert_eq!(schema, schema2);
63/// ```
64#[derive(Debug)]
65pub struct IpcSchemaEncoder<'a> {
66    dictionary_tracker: Option<&'a mut DictionaryTracker>,
67}
68
69impl Default for IpcSchemaEncoder<'_> {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl<'a> IpcSchemaEncoder<'a> {
76    /// Create a new schema encoder
77    pub fn new() -> IpcSchemaEncoder<'a> {
78        IpcSchemaEncoder {
79            dictionary_tracker: None,
80        }
81    }
82
83    /// Specify a dictionary tracker to use
84    pub fn with_dictionary_tracker(
85        mut self,
86        dictionary_tracker: &'a mut DictionaryTracker,
87    ) -> Self {
88        self.dictionary_tracker = Some(dictionary_tracker);
89        self
90    }
91
92    /// Serialize a schema in IPC format, returning a completed [`FlatBufferBuilder`]
93    ///
94    /// Note: Call [`FlatBufferBuilder::finished_data`] to get the serialized bytes
95    pub fn schema_to_fb<'b>(&mut self, schema: &Schema) -> FlatBufferBuilder<'b> {
96        let mut fbb = FlatBufferBuilder::new();
97
98        let root = self.schema_to_fb_offset(&mut fbb, schema);
99
100        fbb.finish(root, None);
101
102        fbb
103    }
104
105    /// Serialize a schema to an in progress [`FlatBufferBuilder`], returning the in progress offset.
106    pub fn schema_to_fb_offset<'b>(
107        &mut self,
108        fbb: &mut FlatBufferBuilder<'b>,
109        schema: &Schema,
110    ) -> WIPOffset<crate::Schema<'b>> {
111        let fields = schema
112            .fields()
113            .iter()
114            .map(|field| build_field(fbb, &mut self.dictionary_tracker, field))
115            .collect::<Vec<_>>();
116        let fb_field_list = fbb.create_vector(&fields);
117
118        let fb_metadata_list =
119            (!schema.metadata().is_empty()).then(|| metadata_to_fb(fbb, schema.metadata()));
120
121        let mut builder = crate::SchemaBuilder::new(fbb);
122        builder.add_fields(fb_field_list);
123        if let Some(fb_metadata_list) = fb_metadata_list {
124            builder.add_custom_metadata(fb_metadata_list);
125        }
126        builder.finish()
127    }
128}
129
130/// Serialize a schema in IPC format
131#[deprecated(since = "54.0.0", note = "Use `IpcSchemaConverter`.")]
132pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder<'_> {
133    IpcSchemaEncoder::new().schema_to_fb(schema)
134}
135
136/// Push a key-value metadata into a FlatBufferBuilder and return [WIPOffset]
137pub fn metadata_to_fb<'a>(
138    fbb: &mut FlatBufferBuilder<'a>,
139    metadata: &HashMap<String, String>,
140) -> WIPOffset<Vector<'a, ForwardsUOffset<KeyValue<'a>>>> {
141    let custom_metadata = metadata
142        .iter()
143        .map(|(k, v)| {
144            let fb_key_name = fbb.create_string(k);
145            let fb_val_name = fbb.create_string(v);
146
147            let mut kv_builder = crate::KeyValueBuilder::new(fbb);
148            kv_builder.add_key(fb_key_name);
149            kv_builder.add_value(fb_val_name);
150            kv_builder.finish()
151        })
152        .collect::<Vec<_>>();
153    fbb.create_vector(&custom_metadata)
154}
155
156/// Adds a [Schema] to a flatbuffer and returns the offset
157pub fn schema_to_fb_offset<'a>(
158    fbb: &mut FlatBufferBuilder<'a>,
159    schema: &Schema,
160) -> WIPOffset<crate::Schema<'a>> {
161    IpcSchemaEncoder::new().schema_to_fb_offset(fbb, schema)
162}
163
164/// Convert an IPC Field to Arrow Field
165impl From<crate::Field<'_>> for Field {
166    fn from(field: crate::Field) -> Field {
167        let arrow_field = if let Some(dictionary) = field.dictionary() {
168            Field::new_dict(
169                field.name().unwrap(),
170                get_data_type(field, true),
171                field.nullable(),
172                dictionary.id(),
173                dictionary.isOrdered(),
174            )
175        } else {
176            Field::new(
177                field.name().unwrap(),
178                get_data_type(field, true),
179                field.nullable(),
180            )
181        };
182
183        let mut metadata_map = HashMap::default();
184        if let Some(list) = field.custom_metadata() {
185            for kv in list {
186                if let (Some(k), Some(v)) = (kv.key(), kv.value()) {
187                    metadata_map.insert(k.to_string(), v.to_string());
188                }
189            }
190        }
191
192        arrow_field.with_metadata(metadata_map)
193    }
194}
195
196/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema].
197pub fn fb_to_schema(fb: crate::Schema) -> Schema {
198    let mut fields: Vec<Field> = vec![];
199    let c_fields = fb.fields().unwrap();
200    let len = c_fields.len();
201    for i in 0..len {
202        let c_field: crate::Field = c_fields.get(i);
203        match c_field.type_type() {
204            crate::Type::Decimal if fb.endianness() == crate::Endianness::Big => {
205                unimplemented!("Big Endian is not supported for Decimal!")
206            }
207            _ => (),
208        };
209        fields.push(c_field.into());
210    }
211
212    let mut metadata: HashMap<String, String> = HashMap::default();
213    if let Some(md_fields) = fb.custom_metadata() {
214        let len = md_fields.len();
215        for i in 0..len {
216            let kv = md_fields.get(i);
217            let k_str = kv.key();
218            let v_str = kv.value();
219            if let Some(k) = k_str {
220                if let Some(v) = v_str {
221                    metadata.insert(k.to_string(), v.to_string());
222                }
223            }
224        }
225    }
226    Schema::new_with_metadata(fields, metadata)
227}
228
229/// Try deserialize flat buffer format bytes into a schema
230pub fn try_schema_from_flatbuffer_bytes(bytes: &[u8]) -> Result<Schema, ArrowError> {
231    if let Ok(ipc) = crate::root_as_message(bytes) {
232        if let Some(schema) = ipc.header_as_schema().map(fb_to_schema) {
233            Ok(schema)
234        } else {
235            Err(ArrowError::ParseError(
236                "Unable to get head as schema".to_string(),
237            ))
238        }
239    } else {
240        Err(ArrowError::ParseError(
241            "Unable to get root as message".to_string(),
242        ))
243    }
244}
245
246/// Try deserialize the IPC format bytes into a schema
247pub fn try_schema_from_ipc_buffer(buffer: &[u8]) -> Result<Schema, ArrowError> {
248    // There are two protocol types: https://issues.apache.org/jira/browse/ARROW-6313
249    // The original protocol is:
250    //   4 bytes - the byte length of the payload
251    //   a flatbuffer Message whose header is the Schema
252    // The latest version of protocol is:
253    // The schema of the dataset in its IPC form:
254    //   4 bytes - an optional IPC_CONTINUATION_TOKEN prefix
255    //   4 bytes - the byte length of the payload
256    //   a flatbuffer Message whose header is the Schema
257    if buffer.len() >= 4 {
258        // check continuation marker
259        let continuation_marker = &buffer[0..4];
260        let begin_offset: usize = if continuation_marker.eq(&CONTINUATION_MARKER) {
261            // 4 bytes: CONTINUATION_MARKER
262            // 4 bytes: length
263            // buffer
264            4
265        } else {
266            // backward compatibility for buffer without the continuation marker
267            // 4 bytes: length
268            // buffer
269            0
270        };
271        let msg = size_prefixed_root_as_message(&buffer[begin_offset..]).map_err(|err| {
272            ArrowError::ParseError(format!("Unable to convert flight info to a message: {err}"))
273        })?;
274        let ipc_schema = msg.header_as_schema().ok_or_else(|| {
275            ArrowError::ParseError("Unable to convert flight info to a schema".to_string())
276        })?;
277        Ok(fb_to_schema(ipc_schema))
278    } else {
279        Err(ArrowError::ParseError(
280            "The buffer length is less than 4 and missing the continuation marker or length of buffer".to_string()
281        ))
282    }
283}
284
285/// Get the Arrow data type from the flatbuffer Field table
286pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> DataType {
287    if let Some(dictionary) = field.dictionary() {
288        if may_be_dictionary {
289            let int = dictionary.indexType().unwrap();
290            let index_type = match (int.bitWidth(), int.is_signed()) {
291                (8, true) => DataType::Int8,
292                (8, false) => DataType::UInt8,
293                (16, true) => DataType::Int16,
294                (16, false) => DataType::UInt16,
295                (32, true) => DataType::Int32,
296                (32, false) => DataType::UInt32,
297                (64, true) => DataType::Int64,
298                (64, false) => DataType::UInt64,
299                _ => panic!("Unexpected bitwidth and signed"),
300            };
301            return DataType::Dictionary(
302                Box::new(index_type),
303                Box::new(get_data_type(field, false)),
304            );
305        }
306    }
307
308    match field.type_type() {
309        crate::Type::Null => DataType::Null,
310        crate::Type::Bool => DataType::Boolean,
311        crate::Type::Int => {
312            let int = field.type_as_int().unwrap();
313            match (int.bitWidth(), int.is_signed()) {
314                (8, true) => DataType::Int8,
315                (8, false) => DataType::UInt8,
316                (16, true) => DataType::Int16,
317                (16, false) => DataType::UInt16,
318                (32, true) => DataType::Int32,
319                (32, false) => DataType::UInt32,
320                (64, true) => DataType::Int64,
321                (64, false) => DataType::UInt64,
322                z => panic!(
323                    "Int type with bit width of {} and signed of {} not supported",
324                    z.0, z.1
325                ),
326            }
327        }
328        crate::Type::Binary => DataType::Binary,
329        crate::Type::BinaryView => DataType::BinaryView,
330        crate::Type::LargeBinary => DataType::LargeBinary,
331        crate::Type::Utf8 => DataType::Utf8,
332        crate::Type::Utf8View => DataType::Utf8View,
333        crate::Type::LargeUtf8 => DataType::LargeUtf8,
334        crate::Type::FixedSizeBinary => {
335            let fsb = field.type_as_fixed_size_binary().unwrap();
336            DataType::FixedSizeBinary(fsb.byteWidth())
337        }
338        crate::Type::FloatingPoint => {
339            let float = field.type_as_floating_point().unwrap();
340            match float.precision() {
341                crate::Precision::HALF => DataType::Float16,
342                crate::Precision::SINGLE => DataType::Float32,
343                crate::Precision::DOUBLE => DataType::Float64,
344                z => panic!("FloatingPoint type with precision of {z:?} not supported"),
345            }
346        }
347        crate::Type::Date => {
348            let date = field.type_as_date().unwrap();
349            match date.unit() {
350                crate::DateUnit::DAY => DataType::Date32,
351                crate::DateUnit::MILLISECOND => DataType::Date64,
352                z => panic!("Date type with unit of {z:?} not supported"),
353            }
354        }
355        crate::Type::Time => {
356            let time = field.type_as_time().unwrap();
357            match (time.bitWidth(), time.unit()) {
358                (32, crate::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second),
359                (32, crate::TimeUnit::MILLISECOND) => DataType::Time32(TimeUnit::Millisecond),
360                (64, crate::TimeUnit::MICROSECOND) => DataType::Time64(TimeUnit::Microsecond),
361                (64, crate::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond),
362                z => panic!(
363                    "Time type with bit width of {} and unit of {:?} not supported",
364                    z.0, z.1
365                ),
366            }
367        }
368        crate::Type::Timestamp => {
369            let timestamp = field.type_as_timestamp().unwrap();
370            let timezone: Option<_> = timestamp.timezone().map(|tz| tz.into());
371            match timestamp.unit() {
372                crate::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone),
373                crate::TimeUnit::MILLISECOND => {
374                    DataType::Timestamp(TimeUnit::Millisecond, timezone)
375                }
376                crate::TimeUnit::MICROSECOND => {
377                    DataType::Timestamp(TimeUnit::Microsecond, timezone)
378                }
379                crate::TimeUnit::NANOSECOND => DataType::Timestamp(TimeUnit::Nanosecond, timezone),
380                z => panic!("Timestamp type with unit of {z:?} not supported"),
381            }
382        }
383        crate::Type::Interval => {
384            let interval = field.type_as_interval().unwrap();
385            match interval.unit() {
386                crate::IntervalUnit::YEAR_MONTH => DataType::Interval(IntervalUnit::YearMonth),
387                crate::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime),
388                crate::IntervalUnit::MONTH_DAY_NANO => {
389                    DataType::Interval(IntervalUnit::MonthDayNano)
390                }
391                z => panic!("Interval type with unit of {z:?} unsupported"),
392            }
393        }
394        crate::Type::Duration => {
395            let duration = field.type_as_duration().unwrap();
396            match duration.unit() {
397                crate::TimeUnit::SECOND => DataType::Duration(TimeUnit::Second),
398                crate::TimeUnit::MILLISECOND => DataType::Duration(TimeUnit::Millisecond),
399                crate::TimeUnit::MICROSECOND => DataType::Duration(TimeUnit::Microsecond),
400                crate::TimeUnit::NANOSECOND => DataType::Duration(TimeUnit::Nanosecond),
401                z => panic!("Duration type with unit of {z:?} unsupported"),
402            }
403        }
404        crate::Type::List => {
405            let children = field.children().unwrap();
406            if children.len() != 1 {
407                panic!("expect a list to have one child")
408            }
409            DataType::List(Arc::new(children.get(0).into()))
410        }
411        crate::Type::LargeList => {
412            let children = field.children().unwrap();
413            if children.len() != 1 {
414                panic!("expect a large list to have one child")
415            }
416            DataType::LargeList(Arc::new(children.get(0).into()))
417        }
418        crate::Type::FixedSizeList => {
419            let children = field.children().unwrap();
420            if children.len() != 1 {
421                panic!("expect a list to have one child")
422            }
423            let fsl = field.type_as_fixed_size_list().unwrap();
424            DataType::FixedSizeList(Arc::new(children.get(0).into()), fsl.listSize())
425        }
426        crate::Type::Struct_ => {
427            let fields = match field.children() {
428                Some(children) => children.iter().map(Field::from).collect(),
429                None => Fields::empty(),
430            };
431            DataType::Struct(fields)
432        }
433        crate::Type::RunEndEncoded => {
434            let children = field.children().unwrap();
435            if children.len() != 2 {
436                panic!(
437                    "RunEndEncoded type should have exactly two children. Found {}",
438                    children.len()
439                )
440            }
441            let run_ends_field = children.get(0).into();
442            let values_field = children.get(1).into();
443            DataType::RunEndEncoded(Arc::new(run_ends_field), Arc::new(values_field))
444        }
445        crate::Type::Map => {
446            let map = field.type_as_map().unwrap();
447            let children = field.children().unwrap();
448            if children.len() != 1 {
449                panic!("expect a map to have one child")
450            }
451            DataType::Map(Arc::new(children.get(0).into()), map.keysSorted())
452        }
453        crate::Type::Decimal => {
454            let fsb = field.type_as_decimal().unwrap();
455            let bit_width = fsb.bitWidth();
456            if bit_width == 128 {
457                DataType::Decimal128(
458                    fsb.precision().try_into().unwrap(),
459                    fsb.scale().try_into().unwrap(),
460                )
461            } else if bit_width == 256 {
462                DataType::Decimal256(
463                    fsb.precision().try_into().unwrap(),
464                    fsb.scale().try_into().unwrap(),
465                )
466            } else {
467                panic!("Unexpected decimal bit width {bit_width}")
468            }
469        }
470        crate::Type::Union => {
471            let union = field.type_as_union().unwrap();
472
473            let union_mode = match union.mode() {
474                crate::UnionMode::Dense => UnionMode::Dense,
475                crate::UnionMode::Sparse => UnionMode::Sparse,
476                mode => panic!("Unexpected union mode: {mode:?}"),
477            };
478
479            let mut fields = vec![];
480            if let Some(children) = field.children() {
481                for i in 0..children.len() {
482                    fields.push(Field::from(children.get(i)));
483                }
484            };
485
486            let fields = match union.typeIds() {
487                None => UnionFields::new(0_i8..fields.len() as i8, fields),
488                Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8), fields),
489            };
490
491            DataType::Union(fields, union_mode)
492        }
493        t => unimplemented!("Type {:?} not supported", t),
494    }
495}
496
497pub(crate) struct FBFieldType<'b> {
498    pub(crate) type_type: crate::Type,
499    pub(crate) type_: WIPOffset<UnionWIPOffset>,
500    pub(crate) children: Option<WIPOffset<Vector<'b, ForwardsUOffset<crate::Field<'b>>>>>,
501}
502
503/// Create an IPC Field from an Arrow Field
504pub(crate) fn build_field<'a>(
505    fbb: &mut FlatBufferBuilder<'a>,
506    dictionary_tracker: &mut Option<&mut DictionaryTracker>,
507    field: &Field,
508) -> WIPOffset<crate::Field<'a>> {
509    // Optional custom metadata.
510    let mut fb_metadata = None;
511    if !field.metadata().is_empty() {
512        fb_metadata = Some(metadata_to_fb(fbb, field.metadata()));
513    };
514
515    let fb_field_name = fbb.create_string(field.name().as_str());
516    let field_type = get_fb_field_type(field.data_type(), dictionary_tracker, fbb);
517
518    let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() {
519        match dictionary_tracker {
520            Some(tracker) => Some(get_fb_dictionary(
521                index_type,
522                tracker.set_dict_id(field),
523                field
524                    .dict_is_ordered()
525                    .expect("All Dictionary types have `dict_is_ordered`"),
526                fbb,
527            )),
528            None => Some(get_fb_dictionary(
529                index_type,
530                field
531                    .dict_id()
532                    .expect("Dictionary type must have a dictionary id"),
533                field
534                    .dict_is_ordered()
535                    .expect("All Dictionary types have `dict_is_ordered`"),
536                fbb,
537            )),
538        }
539    } else {
540        None
541    };
542
543    let mut field_builder = crate::FieldBuilder::new(fbb);
544    field_builder.add_name(fb_field_name);
545    if let Some(dictionary) = fb_dictionary {
546        field_builder.add_dictionary(dictionary)
547    }
548    field_builder.add_type_type(field_type.type_type);
549    field_builder.add_nullable(field.is_nullable());
550    match field_type.children {
551        None => {}
552        Some(children) => field_builder.add_children(children),
553    };
554    field_builder.add_type_(field_type.type_);
555
556    if let Some(fb_metadata) = fb_metadata {
557        field_builder.add_custom_metadata(fb_metadata);
558    }
559
560    field_builder.finish()
561}
562
563/// Get the IPC type of a data type
564pub(crate) fn get_fb_field_type<'a>(
565    data_type: &DataType,
566    dictionary_tracker: &mut Option<&mut DictionaryTracker>,
567    fbb: &mut FlatBufferBuilder<'a>,
568) -> FBFieldType<'a> {
569    // some IPC implementations expect an empty list for child data, instead of a null value.
570    // An empty field list is thus returned for primitive types
571    let empty_fields: Vec<WIPOffset<crate::Field>> = vec![];
572    match data_type {
573        Null => FBFieldType {
574            type_type: crate::Type::Null,
575            type_: crate::NullBuilder::new(fbb).finish().as_union_value(),
576            children: Some(fbb.create_vector(&empty_fields[..])),
577        },
578        Boolean => FBFieldType {
579            type_type: crate::Type::Bool,
580            type_: crate::BoolBuilder::new(fbb).finish().as_union_value(),
581            children: Some(fbb.create_vector(&empty_fields[..])),
582        },
583        UInt8 | UInt16 | UInt32 | UInt64 => {
584            let children = fbb.create_vector(&empty_fields[..]);
585            let mut builder = crate::IntBuilder::new(fbb);
586            builder.add_is_signed(false);
587            match data_type {
588                UInt8 => builder.add_bitWidth(8),
589                UInt16 => builder.add_bitWidth(16),
590                UInt32 => builder.add_bitWidth(32),
591                UInt64 => builder.add_bitWidth(64),
592                _ => {}
593            };
594            FBFieldType {
595                type_type: crate::Type::Int,
596                type_: builder.finish().as_union_value(),
597                children: Some(children),
598            }
599        }
600        Int8 | Int16 | Int32 | Int64 => {
601            let children = fbb.create_vector(&empty_fields[..]);
602            let mut builder = crate::IntBuilder::new(fbb);
603            builder.add_is_signed(true);
604            match data_type {
605                Int8 => builder.add_bitWidth(8),
606                Int16 => builder.add_bitWidth(16),
607                Int32 => builder.add_bitWidth(32),
608                Int64 => builder.add_bitWidth(64),
609                _ => {}
610            };
611            FBFieldType {
612                type_type: crate::Type::Int,
613                type_: builder.finish().as_union_value(),
614                children: Some(children),
615            }
616        }
617        Float16 | Float32 | Float64 => {
618            let children = fbb.create_vector(&empty_fields[..]);
619            let mut builder = crate::FloatingPointBuilder::new(fbb);
620            match data_type {
621                Float16 => builder.add_precision(crate::Precision::HALF),
622                Float32 => builder.add_precision(crate::Precision::SINGLE),
623                Float64 => builder.add_precision(crate::Precision::DOUBLE),
624                _ => {}
625            };
626            FBFieldType {
627                type_type: crate::Type::FloatingPoint,
628                type_: builder.finish().as_union_value(),
629                children: Some(children),
630            }
631        }
632        Binary => FBFieldType {
633            type_type: crate::Type::Binary,
634            type_: crate::BinaryBuilder::new(fbb).finish().as_union_value(),
635            children: Some(fbb.create_vector(&empty_fields[..])),
636        },
637        LargeBinary => FBFieldType {
638            type_type: crate::Type::LargeBinary,
639            type_: crate::LargeBinaryBuilder::new(fbb)
640                .finish()
641                .as_union_value(),
642            children: Some(fbb.create_vector(&empty_fields[..])),
643        },
644        BinaryView => FBFieldType {
645            type_type: crate::Type::BinaryView,
646            type_: crate::BinaryViewBuilder::new(fbb).finish().as_union_value(),
647            children: Some(fbb.create_vector(&empty_fields[..])),
648        },
649        Utf8View => FBFieldType {
650            type_type: crate::Type::Utf8View,
651            type_: crate::Utf8ViewBuilder::new(fbb).finish().as_union_value(),
652            children: Some(fbb.create_vector(&empty_fields[..])),
653        },
654        Utf8 => FBFieldType {
655            type_type: crate::Type::Utf8,
656            type_: crate::Utf8Builder::new(fbb).finish().as_union_value(),
657            children: Some(fbb.create_vector(&empty_fields[..])),
658        },
659        LargeUtf8 => FBFieldType {
660            type_type: crate::Type::LargeUtf8,
661            type_: crate::LargeUtf8Builder::new(fbb).finish().as_union_value(),
662            children: Some(fbb.create_vector(&empty_fields[..])),
663        },
664        FixedSizeBinary(len) => {
665            let mut builder = crate::FixedSizeBinaryBuilder::new(fbb);
666            builder.add_byteWidth(*len);
667            FBFieldType {
668                type_type: crate::Type::FixedSizeBinary,
669                type_: builder.finish().as_union_value(),
670                children: Some(fbb.create_vector(&empty_fields[..])),
671            }
672        }
673        Date32 => {
674            let mut builder = crate::DateBuilder::new(fbb);
675            builder.add_unit(crate::DateUnit::DAY);
676            FBFieldType {
677                type_type: crate::Type::Date,
678                type_: builder.finish().as_union_value(),
679                children: Some(fbb.create_vector(&empty_fields[..])),
680            }
681        }
682        Date64 => {
683            let mut builder = crate::DateBuilder::new(fbb);
684            builder.add_unit(crate::DateUnit::MILLISECOND);
685            FBFieldType {
686                type_type: crate::Type::Date,
687                type_: builder.finish().as_union_value(),
688                children: Some(fbb.create_vector(&empty_fields[..])),
689            }
690        }
691        Time32(unit) | Time64(unit) => {
692            let mut builder = crate::TimeBuilder::new(fbb);
693            match unit {
694                TimeUnit::Second => {
695                    builder.add_bitWidth(32);
696                    builder.add_unit(crate::TimeUnit::SECOND);
697                }
698                TimeUnit::Millisecond => {
699                    builder.add_bitWidth(32);
700                    builder.add_unit(crate::TimeUnit::MILLISECOND);
701                }
702                TimeUnit::Microsecond => {
703                    builder.add_bitWidth(64);
704                    builder.add_unit(crate::TimeUnit::MICROSECOND);
705                }
706                TimeUnit::Nanosecond => {
707                    builder.add_bitWidth(64);
708                    builder.add_unit(crate::TimeUnit::NANOSECOND);
709                }
710            }
711            FBFieldType {
712                type_type: crate::Type::Time,
713                type_: builder.finish().as_union_value(),
714                children: Some(fbb.create_vector(&empty_fields[..])),
715            }
716        }
717        Timestamp(unit, tz) => {
718            let tz = tz.as_deref().unwrap_or_default();
719            let tz_str = fbb.create_string(tz);
720            let mut builder = crate::TimestampBuilder::new(fbb);
721            let time_unit = match unit {
722                TimeUnit::Second => crate::TimeUnit::SECOND,
723                TimeUnit::Millisecond => crate::TimeUnit::MILLISECOND,
724                TimeUnit::Microsecond => crate::TimeUnit::MICROSECOND,
725                TimeUnit::Nanosecond => crate::TimeUnit::NANOSECOND,
726            };
727            builder.add_unit(time_unit);
728            if !tz.is_empty() {
729                builder.add_timezone(tz_str);
730            }
731            FBFieldType {
732                type_type: crate::Type::Timestamp,
733                type_: builder.finish().as_union_value(),
734                children: Some(fbb.create_vector(&empty_fields[..])),
735            }
736        }
737        Interval(unit) => {
738            let mut builder = crate::IntervalBuilder::new(fbb);
739            let interval_unit = match unit {
740                IntervalUnit::YearMonth => crate::IntervalUnit::YEAR_MONTH,
741                IntervalUnit::DayTime => crate::IntervalUnit::DAY_TIME,
742                IntervalUnit::MonthDayNano => crate::IntervalUnit::MONTH_DAY_NANO,
743            };
744            builder.add_unit(interval_unit);
745            FBFieldType {
746                type_type: crate::Type::Interval,
747                type_: builder.finish().as_union_value(),
748                children: Some(fbb.create_vector(&empty_fields[..])),
749            }
750        }
751        Duration(unit) => {
752            let mut builder = crate::DurationBuilder::new(fbb);
753            let time_unit = match unit {
754                TimeUnit::Second => crate::TimeUnit::SECOND,
755                TimeUnit::Millisecond => crate::TimeUnit::MILLISECOND,
756                TimeUnit::Microsecond => crate::TimeUnit::MICROSECOND,
757                TimeUnit::Nanosecond => crate::TimeUnit::NANOSECOND,
758            };
759            builder.add_unit(time_unit);
760            FBFieldType {
761                type_type: crate::Type::Duration,
762                type_: builder.finish().as_union_value(),
763                children: Some(fbb.create_vector(&empty_fields[..])),
764            }
765        }
766        List(ref list_type) => {
767            let child = build_field(fbb, dictionary_tracker, list_type);
768            FBFieldType {
769                type_type: crate::Type::List,
770                type_: crate::ListBuilder::new(fbb).finish().as_union_value(),
771                children: Some(fbb.create_vector(&[child])),
772            }
773        }
774        ListView(_) | LargeListView(_) => unimplemented!("ListView/LargeListView not implemented"),
775        LargeList(ref list_type) => {
776            let child = build_field(fbb, dictionary_tracker, list_type);
777            FBFieldType {
778                type_type: crate::Type::LargeList,
779                type_: crate::LargeListBuilder::new(fbb).finish().as_union_value(),
780                children: Some(fbb.create_vector(&[child])),
781            }
782        }
783        FixedSizeList(ref list_type, len) => {
784            let child = build_field(fbb, dictionary_tracker, list_type);
785            let mut builder = crate::FixedSizeListBuilder::new(fbb);
786            builder.add_listSize(*len);
787            FBFieldType {
788                type_type: crate::Type::FixedSizeList,
789                type_: builder.finish().as_union_value(),
790                children: Some(fbb.create_vector(&[child])),
791            }
792        }
793        Struct(fields) => {
794            // struct's fields are children
795            let mut children = vec![];
796            for field in fields {
797                children.push(build_field(fbb, dictionary_tracker, field));
798            }
799            FBFieldType {
800                type_type: crate::Type::Struct_,
801                type_: crate::Struct_Builder::new(fbb).finish().as_union_value(),
802                children: Some(fbb.create_vector(&children[..])),
803            }
804        }
805        RunEndEncoded(run_ends, values) => {
806            let run_ends_field = build_field(fbb, dictionary_tracker, run_ends);
807            let values_field = build_field(fbb, dictionary_tracker, values);
808            let children = [run_ends_field, values_field];
809            FBFieldType {
810                type_type: crate::Type::RunEndEncoded,
811                type_: crate::RunEndEncodedBuilder::new(fbb)
812                    .finish()
813                    .as_union_value(),
814                children: Some(fbb.create_vector(&children[..])),
815            }
816        }
817        Map(map_field, keys_sorted) => {
818            let child = build_field(fbb, dictionary_tracker, map_field);
819            let mut field_type = crate::MapBuilder::new(fbb);
820            field_type.add_keysSorted(*keys_sorted);
821            FBFieldType {
822                type_type: crate::Type::Map,
823                type_: field_type.finish().as_union_value(),
824                children: Some(fbb.create_vector(&[child])),
825            }
826        }
827        Dictionary(_, value_type) => {
828            // In this library, the dictionary "type" is a logical construct. Here we
829            // pass through to the value type, as we've already captured the index
830            // type in the DictionaryEncoding metadata in the parent field
831            get_fb_field_type(value_type, dictionary_tracker, fbb)
832        }
833        Decimal128(precision, scale) => {
834            let mut builder = crate::DecimalBuilder::new(fbb);
835            builder.add_precision(*precision as i32);
836            builder.add_scale(*scale as i32);
837            builder.add_bitWidth(128);
838            FBFieldType {
839                type_type: crate::Type::Decimal,
840                type_: builder.finish().as_union_value(),
841                children: Some(fbb.create_vector(&empty_fields[..])),
842            }
843        }
844        Decimal256(precision, scale) => {
845            let mut builder = crate::DecimalBuilder::new(fbb);
846            builder.add_precision(*precision as i32);
847            builder.add_scale(*scale as i32);
848            builder.add_bitWidth(256);
849            FBFieldType {
850                type_type: crate::Type::Decimal,
851                type_: builder.finish().as_union_value(),
852                children: Some(fbb.create_vector(&empty_fields[..])),
853            }
854        }
855        Union(fields, mode) => {
856            let mut children = vec![];
857            for (_, field) in fields.iter() {
858                children.push(build_field(fbb, dictionary_tracker, field));
859            }
860
861            let union_mode = match mode {
862                UnionMode::Sparse => crate::UnionMode::Sparse,
863                UnionMode::Dense => crate::UnionMode::Dense,
864            };
865
866            let fbb_type_ids =
867                fbb.create_vector(&fields.iter().map(|(t, _)| t as i32).collect::<Vec<_>>());
868            let mut builder = crate::UnionBuilder::new(fbb);
869            builder.add_mode(union_mode);
870            builder.add_typeIds(fbb_type_ids);
871
872            FBFieldType {
873                type_type: crate::Type::Union,
874                type_: builder.finish().as_union_value(),
875                children: Some(fbb.create_vector(&children[..])),
876            }
877        }
878    }
879}
880
881/// Create an IPC dictionary encoding
882pub(crate) fn get_fb_dictionary<'a>(
883    index_type: &DataType,
884    dict_id: i64,
885    dict_is_ordered: bool,
886    fbb: &mut FlatBufferBuilder<'a>,
887) -> WIPOffset<crate::DictionaryEncoding<'a>> {
888    // We assume that the dictionary index type (as an integer) has already been
889    // validated elsewhere, and can safely assume we are dealing with integers
890    let mut index_builder = crate::IntBuilder::new(fbb);
891
892    match *index_type {
893        Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true),
894        UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false),
895        _ => {}
896    }
897
898    match *index_type {
899        Int8 | UInt8 => index_builder.add_bitWidth(8),
900        Int16 | UInt16 => index_builder.add_bitWidth(16),
901        Int32 | UInt32 => index_builder.add_bitWidth(32),
902        Int64 | UInt64 => index_builder.add_bitWidth(64),
903        _ => {}
904    }
905
906    let index_builder = index_builder.finish();
907
908    let mut builder = crate::DictionaryEncodingBuilder::new(fbb);
909    builder.add_id(dict_id);
910    builder.add_indexType(index_builder);
911    builder.add_isOrdered(dict_is_ordered);
912
913    builder.finish()
914}
915
916/// An owned container for a validated [`Message`]
917///
918/// Safely decoding a flatbuffer requires validating the various embedded offsets,
919/// see [`Verifier`]. This is a potentially expensive operation, and it is therefore desirable
920/// to only do this once. [`crate::root_as_message`] performs this validation on construction,
921/// however, it returns a [`Message`] borrowing the provided byte slice. This prevents
922/// storing this [`Message`] in the same data structure that owns the buffer, as this
923/// would require self-referential borrows.
924///
925/// [`MessageBuffer`] solves this problem by providing a safe API for a [`Message`]
926/// without a lifetime bound.
927#[derive(Clone)]
928pub struct MessageBuffer(Buffer);
929
930impl Debug for MessageBuffer {
931    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
932        self.as_ref().fmt(f)
933    }
934}
935
936impl MessageBuffer {
937    /// Try to create a [`MessageBuffer`] from the provided [`Buffer`]
938    pub fn try_new(buf: Buffer) -> Result<Self, ArrowError> {
939        let opts = VerifierOptions::default();
940        let mut v = Verifier::new(&opts, &buf);
941        <ForwardsUOffset<Message>>::run_verifier(&mut v, 0).map_err(|err| {
942            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
943        })?;
944        Ok(Self(buf))
945    }
946
947    /// Return the [`Message`]
948    #[inline]
949    pub fn as_ref(&self) -> Message<'_> {
950        // SAFETY: Run verifier on construction
951        unsafe { crate::root_as_message_unchecked(&self.0) }
952    }
953}
954
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a schema covering a wide range of data types to IPC flatbuffers and checks
    /// that decoding it again yields an identical [`Schema`].
    #[test]
    fn convert_schema_round_trip() {
        let md: HashMap<String, String> =
            HashMap::from([("Key".to_string(), "value".to_string())]);
        let field_md: HashMap<String, String> = HashMap::from([("k".to_string(), "v".to_string())]);
        let schema = Schema::new_with_metadata(
            vec![
                Field::new("uint8", DataType::UInt8, false).with_metadata(field_md),
                Field::new("uint16", DataType::UInt16, true),
                Field::new("uint32", DataType::UInt32, false),
                Field::new("uint64", DataType::UInt64, true),
                Field::new("int8", DataType::Int8, true),
                Field::new("int16", DataType::Int16, false),
                Field::new("int32", DataType::Int32, true),
                Field::new("int64", DataType::Int64, false),
                Field::new("float16", DataType::Float16, true),
                Field::new("float32", DataType::Float32, false),
                Field::new("float64", DataType::Float64, true),
                Field::new("null", DataType::Null, false),
                Field::new("bool", DataType::Boolean, false),
                Field::new("date32", DataType::Date32, false),
                Field::new("date64", DataType::Date64, true),
                Field::new("time32[s]", DataType::Time32(TimeUnit::Second), true),
                Field::new("time32[ms]", DataType::Time32(TimeUnit::Millisecond), false),
                Field::new("time64[us]", DataType::Time64(TimeUnit::Microsecond), false),
                Field::new("time64[ns]", DataType::Time64(TimeUnit::Nanosecond), true),
                Field::new(
                    "timestamp[s]",
                    DataType::Timestamp(TimeUnit::Second, None),
                    false,
                ),
                Field::new(
                    "timestamp[ms]",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    true,
                ),
                Field::new(
                    "timestamp[us]",
                    DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())),
                    false,
                ),
                Field::new(
                    "timestamp[ns]",
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    true,
                ),
                Field::new(
                    "interval[ym]",
                    DataType::Interval(IntervalUnit::YearMonth),
                    true,
                ),
                Field::new(
                    "interval[dt]",
                    DataType::Interval(IntervalUnit::DayTime),
                    true,
                ),
                Field::new(
                    "interval[mdn]",
                    DataType::Interval(IntervalUnit::MonthDayNano),
                    true,
                ),
                Field::new("utf8", DataType::Utf8, false),
                Field::new("utf8_view", DataType::Utf8View, false),
                Field::new("binary", DataType::Binary, false),
                Field::new("binary_view", DataType::BinaryView, false),
                Field::new_list("list[u8]", Field::new("item", DataType::UInt8, false), true),
                Field::new_fixed_size_list(
                    "fixed_size_list[u8]",
                    Field::new("item", DataType::UInt8, false),
                    2,
                    true,
                ),
                Field::new_list(
                    "list[struct<float32, int32, bool>]",
                    Field::new_struct(
                        "struct",
                        vec![
                            // Previously declared as UInt8, contradicting both this field's
                            // name and the enclosing list's name.
                            Field::new("float32", Float32, false),
                            Field::new("int32", Int32, true),
                            Field::new("bool", Boolean, true),
                        ],
                        true,
                    ),
                    false,
                ),
                Field::new_struct(
                    "struct<dictionary<int32, utf8>>",
                    vec![Field::new(
                        "dictionary<int32, utf8>",
                        Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                        false,
                    )],
                    false,
                ),
                Field::new_struct(
                    "struct<int64, list[struct<date32, list[struct<>]>]>",
                    vec![
                        Field::new("int64", DataType::Int64, true),
                        Field::new_list(
                            "list[struct<date32, list[struct<>]>]",
                            Field::new_struct(
                                "struct",
                                vec![
                                    Field::new("date32", DataType::Date32, true),
                                    Field::new_list(
                                        "list[struct<>]",
                                        Field::new(
                                            "struct",
                                            DataType::Struct(Fields::empty()),
                                            false,
                                        ),
                                        false,
                                    ),
                                ],
                                false,
                            ),
                            false,
                        ),
                    ],
                    false,
                ),
                Field::new_union(
                    "union<int64, list[union<date32, list[union<>]>]>",
                    vec![0, 1],
                    vec![
                        Field::new("int64", DataType::Int64, true),
                        Field::new_list(
                            "list[union<date32, list[union<>]>]",
                            Field::new_union(
                                "union<date32, list[union<>]>",
                                vec![0, 1],
                                vec![
                                    Field::new("date32", DataType::Date32, true),
                                    Field::new_list(
                                        "list[union<>]",
                                        Field::new(
                                            "union",
                                            DataType::Union(
                                                UnionFields::empty(),
                                                UnionMode::Sparse,
                                            ),
                                            false,
                                        ),
                                        false,
                                    ),
                                ],
                                UnionMode::Dense,
                            ),
                            false,
                        ),
                    ],
                    UnionMode::Sparse,
                ),
                Field::new("struct<>", DataType::Struct(Fields::empty()), true),
                Field::new(
                    "union<>",
                    DataType::Union(UnionFields::empty(), UnionMode::Dense),
                    true,
                ),
                Field::new(
                    "union<>",
                    DataType::Union(UnionFields::empty(), UnionMode::Sparse),
                    true,
                ),
                Field::new(
                    "union<int32, utf8>",
                    DataType::Union(
                        UnionFields::new(
                            vec![2, 3], // non-default type ids
                            vec![
                                Field::new("int32", DataType::Int32, true),
                                Field::new("utf8", DataType::Utf8, true),
                            ],
                        ),
                        UnionMode::Dense,
                    ),
                    true,
                ),
                Field::new_dict(
                    "dictionary<int32, utf8>",
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    123,
                    true,
                ),
                Field::new_dict(
                    "dictionary<uint8, uint32>",
                    DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)),
                    true,
                    123,
                    true,
                ),
                Field::new("decimal<usize, usize>", DataType::Decimal128(10, 6), false),
            ],
            md,
        );

        let mut dictionary_tracker = DictionaryTracker::new(true);
        let fb = IpcSchemaEncoder::new()
            .with_dictionary_tracker(&mut dictionary_tracker)
            .schema_to_fb(&schema);

        // read back fields
        let ipc = crate::root_as_schema(fb.finished_data()).unwrap();
        let schema2 = fb_to_schema(ipc);
        assert_eq!(schema, schema2);
    }

    /// Checks that a schema message produced by pyarrow decodes to the same logical
    /// schema and message header values as the equivalent message produced here.
    #[test]
    fn schema_from_bytes() {
        // Bytes of a schema generated via following python code, using pyarrow 10.0.1:
        //
        // import pyarrow as pa
        // schema = pa.schema([pa.field('field1', pa.uint32(), nullable=False)])
        // sink = pa.BufferOutputStream()
        // with pa.ipc.new_stream(sink, schema) as writer:
        //     pass
        // # stripping continuation & length prefix & suffix bytes to get only schema bytes
        // [x for x in sink.getvalue().to_pybytes()][8:-8]
        let bytes: Vec<u8> = vec![
            16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 4, 0, 12, 0, 0,
            0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20, 0, 0, 0, 16, 0, 20,
            0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0, 0, 0, 2, 16, 0, 0, 0, 32, 0,
            0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, 0, 0, 0, 0, 6,
            0, 8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0,
        ];
        let ipc = crate::root_as_message(&bytes).unwrap();
        let schema = ipc.header_as_schema().unwrap();

        // generate same message with Rust
        let data_gen = crate::writer::IpcDataGenerator::default();
        let mut dictionary_tracker = DictionaryTracker::new(true);
        let arrow_schema = Schema::new(vec![Field::new("field1", DataType::UInt32, false)]);
        let bytes = data_gen
            .schema_to_bytes_with_dictionary_tracker(
                &arrow_schema,
                &mut dictionary_tracker,
                &crate::writer::IpcWriteOptions::default(),
            )
            .ipc_message;

        let ipc2 = crate::root_as_message(&bytes).unwrap();
        let schema2 = ipc2.header_as_schema().unwrap();

        // can't compare schema directly as it compares the underlying bytes, which can differ
        assert!(schema.custom_metadata().is_none());
        assert!(schema2.custom_metadata().is_none());
        assert_eq!(schema.endianness(), schema2.endianness());
        assert!(schema.features().is_none());
        assert!(schema2.features().is_none());
        assert_eq!(fb_to_schema(schema), fb_to_schema(schema2));

        assert_eq!(ipc.version(), ipc2.version());
        assert_eq!(ipc.header_type(), ipc2.header_type());
        assert_eq!(ipc.bodyLength(), ipc2.bodyLength());
        assert!(ipc.custom_metadata().is_none());
        assert!(ipc2.custom_metadata().is_none());
    }
}