1use arrow_schema::{ArrowError, DataType, Field, Fields, Schema};
19use indexmap::map::IndexMap as HashMap;
20use indexmap::set::IndexSet as HashSet;
21use serde_json::Value;
22use std::borrow::Borrow;
23use std::io::{BufRead, Seek};
24use std::sync::Arc;
25
26#[derive(Debug, Clone)]
27enum InferredType {
28 Scalar(HashSet<DataType>),
29 Array(Box<InferredType>),
30 Object(HashMap<String, InferredType>),
31 Any,
32}
33
34impl InferredType {
35 fn merge(&mut self, other: InferredType) -> Result<(), ArrowError> {
36 match (self, other) {
37 (InferredType::Array(s), InferredType::Array(o)) => {
38 s.merge(*o)?;
39 }
40 (InferredType::Scalar(self_hs), InferredType::Scalar(other_hs)) => {
41 other_hs.into_iter().for_each(|v| {
42 self_hs.insert(v);
43 });
44 }
45 (InferredType::Object(self_map), InferredType::Object(other_map)) => {
46 for (k, v) in other_map {
47 self_map.entry(k).or_insert(InferredType::Any).merge(v)?;
48 }
49 }
50 (s @ InferredType::Any, v) => {
51 *s = v;
52 }
53 (_, InferredType::Any) => {}
54 (InferredType::Array(self_inner_type), other_scalar @ InferredType::Scalar(_)) => {
56 self_inner_type.merge(other_scalar)?;
57 }
58 (s @ InferredType::Scalar(_), InferredType::Array(mut other_inner_type)) => {
59 other_inner_type.merge(s.clone())?;
60 *s = InferredType::Array(other_inner_type);
61 }
62 (s, o) => {
64 return Err(ArrowError::JsonError(format!(
65 "Incompatible type found during schema inference: {s:?} v.s. {o:?}",
66 )));
67 }
68 }
69
70 Ok(())
71 }
72
73 fn is_none_or_any(ty: Option<&Self>) -> bool {
74 matches!(ty, Some(Self::Any) | None)
75 }
76}
77
78fn list_type_of(ty: DataType) -> DataType {
80 DataType::List(Arc::new(Field::new_list_field(ty, true)))
81}
82
83fn coerce_data_type(dt: Vec<&DataType>) -> DataType {
89 let mut dt_iter = dt.into_iter().cloned();
90 let dt_init = dt_iter.next().unwrap_or(DataType::Utf8);
91
92 dt_iter.fold(dt_init, |l, r| match (l, r) {
93 (DataType::Null, o) | (o, DataType::Null) => o,
94 (DataType::Boolean, DataType::Boolean) => DataType::Boolean,
95 (DataType::Int64, DataType::Int64) => DataType::Int64,
96 (DataType::Float64, DataType::Float64)
97 | (DataType::Float64, DataType::Int64)
98 | (DataType::Int64, DataType::Float64) => DataType::Float64,
99 (DataType::List(l), DataType::List(r)) => {
100 list_type_of(coerce_data_type(vec![l.data_type(), r.data_type()]))
101 }
102 (DataType::List(e), not_list) | (not_list, DataType::List(e)) => {
104 list_type_of(coerce_data_type(vec![e.data_type(), ¬_list]))
105 }
106 _ => DataType::Utf8,
107 })
108}
109
110fn generate_datatype(t: &InferredType) -> Result<DataType, ArrowError> {
111 Ok(match t {
112 InferredType::Scalar(hs) => coerce_data_type(hs.iter().collect()),
113 InferredType::Object(spec) => DataType::Struct(generate_fields(spec)?),
114 InferredType::Array(ele_type) => list_type_of(generate_datatype(ele_type)?),
115 InferredType::Any => DataType::Null,
116 })
117}
118
119fn generate_fields(spec: &HashMap<String, InferredType>) -> Result<Fields, ArrowError> {
120 spec.iter()
121 .map(|(k, types)| Ok(Field::new(k, generate_datatype(types)?, true)))
122 .collect()
123}
124
125fn generate_schema(spec: HashMap<String, InferredType>) -> Result<Schema, ArrowError> {
127 Ok(Schema::new(generate_fields(&spec)?))
128}
129
/// Iterator yielding one parsed JSON value per non-blank line of the wrapped reader.
///
/// When `max_read_records` is `Some(n)`, iteration ends once `n` records have been
/// produced.
#[derive(Debug)]
pub struct ValueIter<R: BufRead> {
    /// Source of line-delimited JSON.
    reader: R,
    /// Upper bound on the number of records to yield, if any.
    max_read_records: Option<usize>,
    /// Number of records yielded so far.
    record_count: usize,
    /// Line buffer reused between reads.
    line_buf: String,
}

impl<R: BufRead> ValueIter<R> {
    /// Creates an iterator over `reader` that yields at most `max_read_records`
    /// records, or every record when the limit is `None`.
    pub fn new(reader: R, max_read_records: Option<usize>) -> Self {
        let line_buf = String::new();
        Self {
            line_buf,
            record_count: 0,
            max_read_records,
            reader,
        }
    }
}
166
167impl<R: BufRead> Iterator for ValueIter<R> {
168 type Item = Result<Value, ArrowError>;
169
170 fn next(&mut self) -> Option<Self::Item> {
171 if let Some(max) = self.max_read_records {
172 if self.record_count >= max {
173 return None;
174 }
175 }
176
177 loop {
178 self.line_buf.truncate(0);
179 match self.reader.read_line(&mut self.line_buf) {
180 Ok(0) => {
181 return None;
183 }
184 Err(e) => {
185 return Some(Err(ArrowError::JsonError(format!(
186 "Failed to read JSON record: {e}"
187 ))));
188 }
189 _ => {
190 let trimmed_s = self.line_buf.trim();
191 if trimmed_s.is_empty() {
192 continue;
194 }
195
196 self.record_count += 1;
197 return Some(
198 serde_json::from_str(trimmed_s)
199 .map_err(|e| ArrowError::JsonError(format!("Not valid JSON: {e}"))),
200 );
201 }
202 }
203 }
204 }
205}
206
207pub fn infer_json_schema_from_seekable<R: BufRead + Seek>(
232 mut reader: R,
233 max_read_records: Option<usize>,
234) -> Result<(Schema, usize), ArrowError> {
235 let schema = infer_json_schema(&mut reader, max_read_records);
236 reader.rewind()?;
238
239 schema
240}
241
242pub fn infer_json_schema<R: BufRead>(
271 reader: R,
272 max_read_records: Option<usize>,
273) -> Result<(Schema, usize), ArrowError> {
274 let mut values = ValueIter::new(reader, max_read_records);
275 let schema = infer_json_schema_from_iterator(&mut values)?;
276 Ok((schema, values.record_count))
277}
278
279fn set_object_scalar_field_type(
280 field_types: &mut HashMap<String, InferredType>,
281 key: &str,
282 ftype: DataType,
283) -> Result<(), ArrowError> {
284 if InferredType::is_none_or_any(field_types.get(key)) {
285 field_types.insert(key.to_string(), InferredType::Scalar(HashSet::new()));
286 }
287
288 match field_types.get_mut(key).unwrap() {
289 InferredType::Scalar(hs) => {
290 hs.insert(ftype);
291 Ok(())
292 }
293 scalar_array @ InferredType::Array(_) => {
296 let mut hs = HashSet::new();
297 hs.insert(ftype);
298 scalar_array.merge(InferredType::Scalar(hs))?;
299 Ok(())
300 }
301 t => Err(ArrowError::JsonError(format!(
302 "Expected scalar or scalar array JSON type, found: {t:?}",
303 ))),
304 }
305}
306
307fn infer_scalar_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
308 let mut hs = HashSet::new();
309
310 for v in array {
311 match v {
312 Value::Null => {}
313 Value::Number(n) => {
314 if n.is_i64() {
315 hs.insert(DataType::Int64);
316 } else {
317 hs.insert(DataType::Float64);
318 }
319 }
320 Value::Bool(_) => {
321 hs.insert(DataType::Boolean);
322 }
323 Value::String(_) => {
324 hs.insert(DataType::Utf8);
325 }
326 Value::Array(_) | Value::Object(_) => {
327 return Err(ArrowError::JsonError(format!(
328 "Expected scalar value for scalar array, got: {v:?}"
329 )));
330 }
331 }
332 }
333
334 Ok(InferredType::Scalar(hs))
335}
336
337fn infer_nested_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
338 let mut inner_ele_type = InferredType::Any;
339
340 for v in array {
341 match v {
342 Value::Array(inner_array) => {
343 inner_ele_type.merge(infer_array_element_type(inner_array)?)?;
344 }
345 x => {
346 return Err(ArrowError::JsonError(format!(
347 "Got non array element in nested array: {x:?}"
348 )));
349 }
350 }
351 }
352
353 Ok(InferredType::Array(Box::new(inner_ele_type)))
354}
355
356fn infer_struct_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
357 let mut field_types = HashMap::new();
358
359 for v in array {
360 match v {
361 Value::Object(map) => {
362 collect_field_types_from_object(&mut field_types, map)?;
363 }
364 _ => {
365 return Err(ArrowError::JsonError(format!(
366 "Expected struct value for struct array, got: {v:?}"
367 )));
368 }
369 }
370 }
371
372 Ok(InferredType::Object(field_types))
373}
374
375fn infer_array_element_type(array: &[Value]) -> Result<InferredType, ArrowError> {
376 match array.iter().take(1).next() {
377 None => Ok(InferredType::Any), Some(a) => match a {
379 Value::Array(_) => infer_nested_array_type(array),
380 Value::Object(_) => infer_struct_array_type(array),
381 _ => infer_scalar_array_type(array),
382 },
383 }
384}
385
386fn collect_field_types_from_object(
387 field_types: &mut HashMap<String, InferredType>,
388 map: &serde_json::map::Map<String, Value>,
389) -> Result<(), ArrowError> {
390 for (k, v) in map {
391 match v {
392 Value::Array(array) => {
393 let ele_type = infer_array_element_type(array)?;
394
395 if InferredType::is_none_or_any(field_types.get(k)) {
396 match ele_type {
397 InferredType::Scalar(_) => {
398 field_types.insert(
399 k.to_string(),
400 InferredType::Array(Box::new(InferredType::Scalar(HashSet::new()))),
401 );
402 }
403 InferredType::Object(_) => {
404 field_types.insert(
405 k.to_string(),
406 InferredType::Array(Box::new(InferredType::Object(HashMap::new()))),
407 );
408 }
409 InferredType::Any | InferredType::Array(_) => {
410 field_types.insert(
413 k.to_string(),
414 InferredType::Array(Box::new(InferredType::Any)),
415 );
416 }
417 }
418 }
419
420 match field_types.get_mut(k).unwrap() {
421 InferredType::Array(inner_type) => {
422 inner_type.merge(ele_type)?;
423 }
424 field_type @ InferredType::Scalar(_) => {
427 field_type.merge(ele_type)?;
428 *field_type = InferredType::Array(Box::new(field_type.clone()));
429 }
430 t => {
431 return Err(ArrowError::JsonError(format!(
432 "Expected array json type, found: {t:?}",
433 )));
434 }
435 }
436 }
437 Value::Bool(_) => {
438 set_object_scalar_field_type(field_types, k, DataType::Boolean)?;
439 }
440 Value::Null => {
441 if !field_types.contains_key(k) {
444 field_types.insert(k.to_string(), InferredType::Any);
445 }
446 }
447 Value::Number(n) => {
448 if n.is_i64() {
449 set_object_scalar_field_type(field_types, k, DataType::Int64)?;
450 } else {
451 set_object_scalar_field_type(field_types, k, DataType::Float64)?;
452 }
453 }
454 Value::String(_) => {
455 set_object_scalar_field_type(field_types, k, DataType::Utf8)?;
456 }
457 Value::Object(inner_map) => {
458 if let InferredType::Any = field_types.get(k).unwrap_or(&InferredType::Any) {
459 field_types.insert(k.to_string(), InferredType::Object(HashMap::new()));
460 }
461 match field_types.get_mut(k).unwrap() {
462 InferredType::Object(inner_field_types) => {
463 collect_field_types_from_object(inner_field_types, inner_map)?;
464 }
465 t => {
466 return Err(ArrowError::JsonError(format!(
467 "Expected object json type, found: {t:?}",
468 )));
469 }
470 }
471 }
472 }
473 }
474
475 Ok(())
476}
477
478pub fn infer_json_schema_from_iterator<I, V>(value_iter: I) -> Result<Schema, ArrowError>
492where
493 I: Iterator<Item = Result<V, ArrowError>>,
494 V: Borrow<Value>,
495{
496 let mut field_types: HashMap<String, InferredType> = HashMap::new();
497
498 for record in value_iter {
499 match record?.borrow() {
500 Value::Object(map) => {
501 collect_field_types_from_object(&mut field_types, map)?;
502 }
503 value => {
504 return Err(ArrowError::JsonError(format!(
505 "Expected JSON record to be an object, found {value:?}"
506 )));
507 }
508 };
509 }
510
511 generate_schema(field_types)
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::GzDecoder;
    use std::fs::File;
    use std::io::{BufReader, Cursor};

    /// Mixed scalar/array columns are inferred identically from plain and gzipped input.
    #[test]
    fn test_json_infer_schema() {
        let expected = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", list_type_of(DataType::Float64), true),
            Field::new("c", list_type_of(DataType::Boolean), true),
            Field::new("d", list_type_of(DataType::Utf8), true),
        ]);

        let mut plain = BufReader::new(File::open("test/data/mixed_arrays.json").unwrap());
        let (plain_schema, plain_rows) =
            infer_json_schema_from_seekable(&mut plain, None).unwrap();

        assert_eq!(plain_schema, expected);
        assert_eq!(plain_rows, 4);

        let gz_file = File::open("test/data/mixed_arrays.json.gz").unwrap();
        let mut gzipped = BufReader::new(GzDecoder::new(&gz_file));
        let (gz_schema, gz_rows) = infer_json_schema(&mut gzipped, None).unwrap();

        assert_eq!(gz_schema, expected);
        assert_eq!(gz_rows, 4);
    }

    /// The record limit caps the number of rows read; the reader is rewound between calls.
    #[test]
    fn test_row_limit() {
        let mut reader = BufReader::new(File::open("test/data/basic.json").unwrap());

        let (_, all_rows) = infer_json_schema_from_seekable(&mut reader, None).unwrap();
        assert_eq!(all_rows, 12);

        let (_, limited_rows) = infer_json_schema_from_seekable(&mut reader, Some(5)).unwrap();
        assert_eq!(limited_rows, 5);
    }

    /// Nested objects become nested structs; a null does not disturb a known struct.
    #[test]
    fn test_json_infer_schema_nested_structs() {
        let innermost = DataType::Struct(vec![Field::new("c", DataType::Utf8, true)].into());
        let c1 = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("b", innermost, true),
        ]));
        let expected = Schema::new(vec![
            Field::new("c1", c1, true),
            Field::new("c2", DataType::Int64, true),
            Field::new("c3", DataType::Utf8, true),
        ]);

        let records: Vec<Result<Value, ArrowError>> = vec![
            Ok(serde_json::json!({"c1": {"a": true, "b": {"c": "text"}}, "c2": 1})),
            Ok(serde_json::json!({"c1": {"a": false, "b": null}, "c2": 0})),
            Ok(serde_json::json!({"c1": {"a": true, "b": {"c": "text"}}, "c3": "ok"})),
        ];
        let inferred_schema = infer_json_schema_from_iterator(records.into_iter()).unwrap();

        assert_eq!(inferred_schema, expected);
    }

    /// Lists of objects become lists of structs; always-empty lists become lists of null.
    #[test]
    fn test_json_infer_schema_struct_in_list() {
        let c1_element = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Boolean, true),
        ]));
        let expected = Schema::new(vec![
            Field::new("c1", list_type_of(c1_element), true),
            Field::new("c2", DataType::Float64, true),
            Field::new("c3", list_type_of(DataType::Null), true),
        ]);

        let records: Vec<Result<Value, ArrowError>> = vec![
            Ok(serde_json::json!({
                "c1": [{"a": "foo", "b": 100}], "c2": 1, "c3": [],
            })),
            Ok(serde_json::json!({
                "c1": [{"a": "bar", "b": 2}, {"a": "foo", "c": true}], "c2": 0, "c3": [],
            })),
            Ok(serde_json::json!({"c1": [], "c2": 0.5, "c3": []})),
        ];
        let inferred_schema = infer_json_schema_from_iterator(records.into_iter()).unwrap();

        assert_eq!(inferred_schema, expected);
    }

    /// An empty list followed by nested lists is inferred as a list of lists.
    #[test]
    fn test_json_infer_schema_nested_list() {
        let expected = Schema::new(vec![
            Field::new("c1", list_type_of(list_type_of(DataType::Utf8)), true),
            Field::new("c2", DataType::Float64, true),
        ]);

        let records: Vec<Result<Value, ArrowError>> = vec![
            Ok(serde_json::json!({
                "c1": [],
                "c2": 12,
            })),
            Ok(serde_json::json!({
                "c1": [["a", "b"], ["c"]],
            })),
            Ok(serde_json::json!({
                "c1": [["foo"]],
                "c2": 0.11,
            })),
        ];
        let inferred_schema = infer_json_schema_from_iterator(records.into_iter()).unwrap();

        assert_eq!(inferred_schema, expected);
    }

    /// Integers outside the `i64` range are inferred as `Float64`.
    #[test]
    fn test_infer_json_schema_bigger_than_i64_max() {
        let bigger_than_i64_max = (i64::MAX as i128) + 1;
        let smaller_than_i64_min = (i64::MIN as i128) - 1;
        let json = format!(
            "{{ \"bigger_than_i64_max\": {}, \"smaller_than_i64_min\": {} }}",
            bigger_than_i64_max, smaller_than_i64_min
        );
        let mut buf_reader = BufReader::new(json.as_bytes());
        let (inferred_schema, _) = infer_json_schema(&mut buf_reader, Some(1)).unwrap();
        let fields = inferred_schema.fields();

        for name in ["bigger_than_i64_max", "smaller_than_i64_min"] {
            let (_, field) = fields.find(name).unwrap();
            assert_eq!(field.data_type(), &DataType::Float64);
        }
    }

    /// A scalar combined with a list coerces to a list of the combined element type.
    #[test]
    fn test_coercion_scalar_and_list() {
        // (scalar, list element, expected list element)
        let cases = [
            (DataType::Float64, DataType::Float64, DataType::Float64),
            (DataType::Float64, DataType::Int64, DataType::Float64),
            (DataType::Int64, DataType::Int64, DataType::Int64),
            // Boolean and Float64 have nothing in common and fall back to Utf8.
            (DataType::Boolean, DataType::Float64, DataType::Utf8),
        ];

        for (scalar, element, expected) in cases {
            assert_eq!(
                list_type_of(expected),
                coerce_data_type(vec![&scalar, &list_type_of(element)])
            );
        }
    }

    /// Invalid JSON surfaces as a `JsonError` carrying the parser's message.
    #[test]
    fn test_invalid_json_infer_schema() {
        let re = infer_json_schema_from_seekable(Cursor::new(b"}"), None);
        let message = re.err().unwrap().to_string();
        assert_eq!(
            message,
            "Json error: Not valid JSON: expected value at line 1 column 1",
        );
    }

    /// Nulls never override a known type; columns that are only null stay `Null`.
    #[test]
    fn test_null_field_inferred_as_null() {
        let data = r#"
            {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":null}
            {"in":null, "ni":2, "ns":"3", "sn":null, "n":null, "an":null, "na": [], "nas":["8"]}
            {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":[]}
        "#;
        let (inferred_schema, _) =
            infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");
        let expected = Schema::new(vec![
            Field::new("an", list_type_of(DataType::Null), true),
            Field::new("in", DataType::Int64, true),
            Field::new("n", DataType::Null, true),
            Field::new("na", list_type_of(DataType::Null), true),
            Field::new("nas", list_type_of(DataType::Utf8), true),
            Field::new("ni", DataType::Int64, true),
            Field::new("ns", DataType::Utf8, true),
            Field::new("sn", DataType::Utf8, true),
        ]);
        assert_eq!(inferred_schema, expected);
    }

    /// A field first seen as null and then as an object is inferred as a struct.
    #[test]
    fn test_infer_from_null_then_object() {
        let data = r#"
            {"obj":null}
            {"obj":{"foo":1}}
        "#;
        let (inferred_schema, _) =
            infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");
        let obj_type = DataType::Struct(Fields::from(vec![Field::new(
            "foo",
            DataType::Int64,
            true,
        )]));
        let expected = Schema::new(vec![Field::new("obj", obj_type, true)]);
        assert_eq!(inferred_schema, expected);
    }
}