1mod records;
127
use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::timezone::Tz;
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::{Arc, LazyLock};

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
145
146lazy_static! {
147 static ref REGEX_SET: RegexSet = RegexSet::new([
149 r"(?i)^(true)$|^(false)$(?-i)", r"^-?(\d+)$", r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", r"^\d{4}-\d\d-\d\d$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", ]).unwrap();
158}
159
/// Wraps an optional [`Regex`] used to recognise null values.
///
/// When no regex is configured, only the empty string is treated as null.
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);
163
164impl NullRegex {
165 #[inline]
168 fn is_null(&self, s: &str) -> bool {
169 match &self.0 {
170 Some(r) => r.is_match(s),
171 None => s.is_empty(),
172 }
173 }
174}
175
/// The set of candidate data types inferred for a column so far,
/// encoded as a bitmask.
///
/// Bit `i` (for `i < 8`) is set when a value matched pattern `i` of
/// `REGEX_SET`; bit 8 marks a value that matched nothing (string fallback).
/// See [`InferredDataType::get`] for how the mask resolves to a [`DataType`].
#[derive(Default, Copy, Clone)]
struct InferredDataType {
    // Bitmask of observed candidate types; updated by `update`, decoded by `get`
    packed: u16,
}
191
192impl InferredDataType {
193 fn get(&self) -> DataType {
195 match self.packed {
196 0 => DataType::Null,
197 1 => DataType::Boolean,
198 2 => DataType::Int64,
199 4 | 6 => DataType::Float64, b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
201 8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
203 9 => DataType::Timestamp(TimeUnit::Microsecond, None),
204 10 => DataType::Timestamp(TimeUnit::Millisecond, None),
205 11 => DataType::Timestamp(TimeUnit::Second, None),
206 12 => DataType::Date32,
207 _ => unreachable!(),
208 },
209 _ => DataType::Utf8,
210 }
211 }
212
213 fn update(&mut self, string: &str) {
215 self.packed |= if string.starts_with('"') {
216 1 << 8 } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
218 if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
219 1 << 8
221 } else {
222 1 << m
223 }
224 } else {
225 1 << 8 }
227 }
228}
229
/// The format specification for an Arrow CSV file, used both for schema
/// inference and to configure the underlying CSV parser.
#[derive(Debug, Clone, Default)]
pub struct Format {
    // Whether the first row is a header
    header: bool,
    // Optional custom field delimiter
    delimiter: Option<u8>,
    // Optional escape character
    escape: Option<u8>,
    // Optional custom quote character
    quote: Option<u8>,
    // Optional custom record terminator
    terminator: Option<u8>,
    // Optional comment character, forwarded to the csv parser
    comment: Option<u8>,
    // Matcher deciding which field values decode as null
    null_regex: NullRegex,
    // Whether rows with fewer fields than the schema are permitted
    truncated_rows: bool,
}
242
243impl Format {
244 pub fn with_header(mut self, has_header: bool) -> Self {
248 self.header = has_header;
249 self
250 }
251
252 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
254 self.delimiter = Some(delimiter);
255 self
256 }
257
258 pub fn with_escape(mut self, escape: u8) -> Self {
260 self.escape = Some(escape);
261 self
262 }
263
264 pub fn with_quote(mut self, quote: u8) -> Self {
266 self.quote = Some(quote);
267 self
268 }
269
270 pub fn with_terminator(mut self, terminator: u8) -> Self {
272 self.terminator = Some(terminator);
273 self
274 }
275
276 pub fn with_comment(mut self, comment: u8) -> Self {
280 self.comment = Some(comment);
281 self
282 }
283
284 pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
286 self.null_regex = NullRegex(Some(null_regex));
287 self
288 }
289
290 pub fn with_truncated_rows(mut self, allow: bool) -> Self {
297 self.truncated_rows = allow;
298 self
299 }
300
301 pub fn infer_schema<R: Read>(
308 &self,
309 reader: R,
310 max_records: Option<usize>,
311 ) -> Result<(Schema, usize), ArrowError> {
312 let mut csv_reader = self.build_reader(reader);
313
314 let headers: Vec<String> = if self.header {
317 let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
318 headers.iter().map(|s| s.to_string()).collect()
319 } else {
320 let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
321 (0..*first_record_count)
322 .map(|i| format!("column_{}", i + 1))
323 .collect()
324 };
325
326 let header_length = headers.len();
327 let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];
329
330 let mut records_count = 0;
331
332 let mut record = StringRecord::new();
333 let max_records = max_records.unwrap_or(usize::MAX);
334 while records_count < max_records {
335 if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
336 break;
337 }
338 records_count += 1;
339
340 for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
343 if let Some(string) = record.get(i) {
344 if !self.null_regex.is_null(string) {
345 column_type.update(string)
346 }
347 }
348 }
349 }
350
351 let fields: Fields = column_types
353 .iter()
354 .zip(&headers)
355 .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
356 .collect();
357
358 Ok((Schema::new(fields), records_count))
359 }
360
361 fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
363 let mut builder = csv::ReaderBuilder::new();
364 builder.has_headers(self.header);
365 builder.flexible(self.truncated_rows);
366
367 if let Some(c) = self.delimiter {
368 builder.delimiter(c);
369 }
370 builder.escape(self.escape);
371 if let Some(c) = self.quote {
372 builder.quote(c);
373 }
374 if let Some(t) = self.terminator {
375 builder.terminator(csv::Terminator::Any(t));
376 }
377 if let Some(comment) = self.comment {
378 builder.comment(Some(comment));
379 }
380 builder.from_reader(reader)
381 }
382
383 fn build_parser(&self) -> csv_core::Reader {
385 let mut builder = csv_core::ReaderBuilder::new();
386 builder.escape(self.escape);
387 builder.comment(self.comment);
388
389 if let Some(c) = self.delimiter {
390 builder.delimiter(c);
391 }
392 if let Some(c) = self.quote {
393 builder.quote(c);
394 }
395 if let Some(t) = self.terminator {
396 builder.terminator(csv_core::Terminator::Any(t));
397 }
398 builder.build()
399 }
400}
401
402pub fn infer_schema_from_files(
409 files: &[String],
410 delimiter: u8,
411 max_read_records: Option<usize>,
412 has_header: bool,
413) -> Result<Schema, ArrowError> {
414 let mut schemas = vec![];
415 let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
416 let format = Format {
417 delimiter: Some(delimiter),
418 header: has_header,
419 ..Default::default()
420 };
421
422 for fname in files.iter() {
423 let f = File::open(fname)?;
424 let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
425 if records_read == 0 {
426 continue;
427 }
428 schemas.push(schema.clone());
429 records_to_read -= records_read;
430 if records_to_read == 0 {
431 break;
432 }
433 }
434
435 Schema::try_merge(schemas)
436}
437
/// Optional `(start, end)` row bounds; offsets are applied after the header row
type Bounds = Option<(usize, usize)>;

/// CSV file reader over any [`Read`], buffered via [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader over a [`BufRead`] source
pub struct BufReader<R> {
    // The underlying buffered byte source
    reader: R,

    // Push-based decoder turning raw CSV bytes into RecordBatches
    decoder: Decoder,
}
452
impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    // The inner reader is omitted from the output: `R` is not required to
    // implement Debug, and only the decoder state is informative
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}
463
464impl<R: Read> Reader<R> {
465 pub fn schema(&self) -> SchemaRef {
468 match &self.decoder.projection {
469 Some(projection) => {
470 let fields = self.decoder.schema.fields();
471 let projected = projection.iter().map(|i| fields[*i].clone());
472 Arc::new(Schema::new(projected.collect::<Fields>()))
473 }
474 None => self.decoder.schema.clone(),
475 }
476 }
477}
478
impl<R: BufRead> BufReader<R> {
    /// Reads the next [`RecordBatch`], feeding the decoder from the
    /// underlying reader until a full batch is buffered or input ends.
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // A zero-byte decode means the input is exhausted; zero remaining
            // capacity means a complete batch is buffered — flush either way
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}
498
impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        // Result<Option<_>> -> Option<Result<_>>: iteration ends on Ok(None)
        self.read().transpose()
    }
}
506
507impl<R: BufRead> RecordBatchReader for BufReader<R> {
508 fn schema(&self) -> SchemaRef {
509 self.decoder.schema.clone()
510 }
511}
512
/// A push-based interface for decoding CSV data from an arbitrary byte stream
#[derive(Debug)]
pub struct Decoder {
    // The schema of the decoded records
    schema: SchemaRef,

    // Optional column indices to project into the output batches
    projection: Option<Vec<usize>>,

    // The maximum number of records per RecordBatch
    batch_size: usize,

    // Number of leading rows still to be skipped (header and/or bounds start)
    to_skip: usize,

    // The line number of the next record to decode, used in error messages
    // and compared against `end`
    line_number: usize,

    // The line at which decoding stops; usize::MAX when unbounded
    end: usize,

    // Low-level decoder that splits raw bytes into complete rows
    record_decoder: RecordDecoder,

    // Matcher deciding which field values decode as null
    null_regex: NullRegex,
}
578
impl Decoder {
    /// Decode records from `buf`, returning the number of bytes consumed.
    ///
    /// Decoding stops once `batch_size` rows are buffered (call
    /// [`Self::flush`] to obtain them) or `buf` is exhausted; unconsumed
    /// bytes should be passed to the next call. `buf` need not contain a
    /// whole number of records.
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of at most batch_size so the record decoder
            // does not over-allocate for rows that are thrown away
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        // Read at most enough rows to fill the current batch, additionally
        // bounded by the configured `end` line
        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered rows to a [`RecordBatch`].
    ///
    /// Returns `Ok(None)` if no rows are buffered.
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of rows that can still be buffered before a call
    /// to [`Self::flush`] is required
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}
633
634fn parse(
636 rows: &StringRecords<'_>,
637 fields: &Fields,
638 metadata: Option<std::collections::HashMap<String, String>>,
639 projection: Option<&Vec<usize>>,
640 line_number: usize,
641 null_regex: &NullRegex,
642) -> Result<RecordBatch, ArrowError> {
643 let projection: Vec<usize> = match projection {
644 Some(v) => v.clone(),
645 None => fields.iter().enumerate().map(|(i, _)| i).collect(),
646 };
647
648 let arrays: Result<Vec<ArrayRef>, _> = projection
649 .iter()
650 .map(|i| {
651 let i = *i;
652 let field = &fields[i];
653 match field.data_type() {
654 DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
655 DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
656 line_number,
657 rows,
658 i,
659 *precision,
660 *scale,
661 null_regex,
662 ),
663 DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
664 line_number,
665 rows,
666 i,
667 *precision,
668 *scale,
669 null_regex,
670 ),
671 DataType::Int8 => {
672 build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
673 }
674 DataType::Int16 => {
675 build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
676 }
677 DataType::Int32 => {
678 build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
679 }
680 DataType::Int64 => {
681 build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
682 }
683 DataType::UInt8 => {
684 build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
685 }
686 DataType::UInt16 => {
687 build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
688 }
689 DataType::UInt32 => {
690 build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
691 }
692 DataType::UInt64 => {
693 build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
694 }
695 DataType::Float32 => {
696 build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
697 }
698 DataType::Float64 => {
699 build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
700 }
701 DataType::Date32 => {
702 build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
703 }
704 DataType::Date64 => {
705 build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
706 }
707 DataType::Time32(TimeUnit::Second) => {
708 build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
709 }
710 DataType::Time32(TimeUnit::Millisecond) => {
711 build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
712 }
713 DataType::Time64(TimeUnit::Microsecond) => {
714 build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
715 }
716 DataType::Time64(TimeUnit::Nanosecond) => {
717 build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
718 }
719 DataType::Timestamp(TimeUnit::Second, tz) => {
720 build_timestamp_array::<TimestampSecondType>(
721 line_number,
722 rows,
723 i,
724 tz.as_deref(),
725 null_regex,
726 )
727 }
728 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
729 build_timestamp_array::<TimestampMillisecondType>(
730 line_number,
731 rows,
732 i,
733 tz.as_deref(),
734 null_regex,
735 )
736 }
737 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
738 build_timestamp_array::<TimestampMicrosecondType>(
739 line_number,
740 rows,
741 i,
742 tz.as_deref(),
743 null_regex,
744 )
745 }
746 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
747 build_timestamp_array::<TimestampNanosecondType>(
748 line_number,
749 rows,
750 i,
751 tz.as_deref(),
752 null_regex,
753 )
754 }
755 DataType::Null => Ok(Arc::new({
756 let mut builder = NullBuilder::new();
757 builder.append_nulls(rows.len());
758 builder.finish()
759 }) as ArrayRef),
760 DataType::Utf8 => Ok(Arc::new(
761 rows.iter()
762 .map(|row| {
763 let s = row.get(i);
764 (!null_regex.is_null(s)).then_some(s)
765 })
766 .collect::<StringArray>(),
767 ) as ArrayRef),
768 DataType::Utf8View => Ok(Arc::new(
769 rows.iter()
770 .map(|row| {
771 let s = row.get(i);
772 (!null_regex.is_null(s)).then_some(s)
773 })
774 .collect::<StringViewArray>(),
775 ) as ArrayRef),
776 DataType::Dictionary(key_type, value_type)
777 if value_type.as_ref() == &DataType::Utf8 =>
778 {
779 match key_type.as_ref() {
780 DataType::Int8 => Ok(Arc::new(
781 rows.iter()
782 .map(|row| {
783 let s = row.get(i);
784 (!null_regex.is_null(s)).then_some(s)
785 })
786 .collect::<DictionaryArray<Int8Type>>(),
787 ) as ArrayRef),
788 DataType::Int16 => Ok(Arc::new(
789 rows.iter()
790 .map(|row| {
791 let s = row.get(i);
792 (!null_regex.is_null(s)).then_some(s)
793 })
794 .collect::<DictionaryArray<Int16Type>>(),
795 ) as ArrayRef),
796 DataType::Int32 => Ok(Arc::new(
797 rows.iter()
798 .map(|row| {
799 let s = row.get(i);
800 (!null_regex.is_null(s)).then_some(s)
801 })
802 .collect::<DictionaryArray<Int32Type>>(),
803 ) as ArrayRef),
804 DataType::Int64 => Ok(Arc::new(
805 rows.iter()
806 .map(|row| {
807 let s = row.get(i);
808 (!null_regex.is_null(s)).then_some(s)
809 })
810 .collect::<DictionaryArray<Int64Type>>(),
811 ) as ArrayRef),
812 DataType::UInt8 => Ok(Arc::new(
813 rows.iter()
814 .map(|row| {
815 let s = row.get(i);
816 (!null_regex.is_null(s)).then_some(s)
817 })
818 .collect::<DictionaryArray<UInt8Type>>(),
819 ) as ArrayRef),
820 DataType::UInt16 => Ok(Arc::new(
821 rows.iter()
822 .map(|row| {
823 let s = row.get(i);
824 (!null_regex.is_null(s)).then_some(s)
825 })
826 .collect::<DictionaryArray<UInt16Type>>(),
827 ) as ArrayRef),
828 DataType::UInt32 => Ok(Arc::new(
829 rows.iter()
830 .map(|row| {
831 let s = row.get(i);
832 (!null_regex.is_null(s)).then_some(s)
833 })
834 .collect::<DictionaryArray<UInt32Type>>(),
835 ) as ArrayRef),
836 DataType::UInt64 => Ok(Arc::new(
837 rows.iter()
838 .map(|row| {
839 let s = row.get(i);
840 (!null_regex.is_null(s)).then_some(s)
841 })
842 .collect::<DictionaryArray<UInt64Type>>(),
843 ) as ArrayRef),
844 _ => Err(ArrowError::ParseError(format!(
845 "Unsupported dictionary key type {key_type:?}"
846 ))),
847 }
848 }
849 other => Err(ArrowError::ParseError(format!(
850 "Unsupported data type {other:?}"
851 ))),
852 }
853 })
854 .collect();
855
856 let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();
857
858 let projected_schema = Arc::new(match metadata {
859 None => Schema::new(projected_fields),
860 Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
861 });
862
863 arrays.and_then(|arr| {
864 RecordBatch::try_new_with_options(
865 projected_schema,
866 arr,
867 &RecordBatchOptions::new()
868 .with_match_field_names(true)
869 .with_row_count(Some(rows.len())),
870 )
871 })
872}
873
/// Parses a string as a boolean, matching "true"/"false" case-insensitively;
/// any other value yields `None`.
fn parse_bool(string: &str) -> Option<bool> {
    match string {
        s if s.eq_ignore_ascii_case("true") => Some(true),
        s if s.eq_ignore_ascii_case("false") => Some(false),
        _ => None,
    }
}
883
884fn build_decimal_array<T: DecimalType>(
886 _line_number: usize,
887 rows: &StringRecords<'_>,
888 col_idx: usize,
889 precision: u8,
890 scale: i8,
891 null_regex: &NullRegex,
892) -> Result<ArrayRef, ArrowError> {
893 let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
894 for row in rows.iter() {
895 let s = row.get(col_idx);
896 if null_regex.is_null(s) {
897 decimal_builder.append_null();
899 } else {
900 let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
901 match decimal_value {
902 Ok(v) => {
903 decimal_builder.append_value(v);
904 }
905 Err(e) => {
906 return Err(e);
907 }
908 }
909 }
910 }
911 Ok(Arc::new(
912 decimal_builder
913 .finish()
914 .with_precision_and_scale(precision, scale)?,
915 ))
916}
917
918fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
920 line_number: usize,
921 rows: &StringRecords<'_>,
922 col_idx: usize,
923 null_regex: &NullRegex,
924) -> Result<ArrayRef, ArrowError> {
925 rows.iter()
926 .enumerate()
927 .map(|(row_index, row)| {
928 let s = row.get(col_idx);
929 if null_regex.is_null(s) {
930 return Ok(None);
931 }
932
933 match T::parse(s) {
934 Some(e) => Ok(Some(e)),
935 None => Err(ArrowError::ParseError(format!(
936 "Error while parsing value {} for column {} at line {}",
938 s,
939 col_idx,
940 line_number + row_index
941 ))),
942 }
943 })
944 .collect::<Result<PrimitiveArray<T>, ArrowError>>()
945 .map(|e| Arc::new(e) as ArrayRef)
946}
947
948fn build_timestamp_array<T: ArrowTimestampType>(
949 line_number: usize,
950 rows: &StringRecords<'_>,
951 col_idx: usize,
952 timezone: Option<&str>,
953 null_regex: &NullRegex,
954) -> Result<ArrayRef, ArrowError> {
955 Ok(Arc::new(match timezone {
956 Some(timezone) => {
957 let tz: Tz = timezone.parse()?;
958 build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
959 .with_timezone(timezone)
960 }
961 None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
962 }))
963}
964
/// Parses a CSV column into a timestamp [`PrimitiveArray`] of type `T`,
/// resolving values via `string_to_datetime` with the given `timezone` and
/// treating values matching `null_regex` as null. `line_number` is the line
/// of the first row, used in parse-error messages.
fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            // Convert the parsed datetime to the unit of `T`; only the
            // nanosecond conversion can overflow i64 and is checked
            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                // Wrap any failure with the column and line for context
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}
1003
1004fn build_boolean_array(
1006 line_number: usize,
1007 rows: &StringRecords<'_>,
1008 col_idx: usize,
1009 null_regex: &NullRegex,
1010) -> Result<ArrayRef, ArrowError> {
1011 rows.iter()
1012 .enumerate()
1013 .map(|(row_index, row)| {
1014 let s = row.get(col_idx);
1015 if null_regex.is_null(s) {
1016 return Ok(None);
1017 }
1018 let parsed = parse_bool(s);
1019 match parsed {
1020 Some(e) => Ok(Some(e)),
1021 None => Err(ArrowError::ParseError(format!(
1022 "Error while parsing value {} for column {} at line {}",
1024 s,
1025 col_idx,
1026 line_number + row_index
1027 ))),
1028 }
1029 })
1030 .collect::<Result<BooleanArray, _>>()
1031 .map(|e| Arc::new(e) as ArrayRef)
1032}
1033
/// CSV file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    // Schema of the CSV file being read
    schema: SchemaRef,
    // CSV format configuration (header, delimiter, quoting, nulls, …)
    format: Format,
    // Maximum number of records per RecordBatch
    batch_size: usize,
    // Optional (start, end) row bounds to read
    bounds: Bounds,
    // Optional column indices to project
    projection: Option<Vec<usize>>,
}
1050
impl ReaderBuilder {
    /// Create a new builder for the given schema, with a default batch size
    /// of 1024 and no header, bounds or projection
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header row
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [`Format`] of this builder wholesale
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set a custom field delimiter character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set an escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set a custom quote character
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Set a custom record terminator character
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Set a comment character, forwarded to the underlying CSV parser
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values; fields matching it decode as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records loaded per [`RecordBatch`])
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader: `start` is inclusive,
    /// `end` exclusive, both relative to the data rows (header excluded)
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the column indices to project into the output batches
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow rows with fewer fields than the schema
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Build a [`Reader`] reading from `reader`, wrapping it in a
    /// [`std::io::BufReader`]
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Build a [`BufReader`] reading from an existing [`BufRead`] source
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

    /// Build a push-based [`Decoder`] from this builder's configuration
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        // A header row, if present, shifts both bounds by one and is skipped
        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}
1205
1206#[cfg(test)]
1207mod tests {
1208 use super::*;
1209
1210 use std::io::{Cursor, Seek, SeekFrom, Write};
1211 use tempfile::NamedTempFile;
1212
1213 use arrow_array::cast::AsArray;
1214
    // Reads uk_cities.csv with an explicit schema and spot-checks values
    #[test]
    fn test_csv() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // Access data from the primitive and string columns
        let lat = batch.column(1).as_primitive::<Float64Type>();
        assert_eq!(57.653484, lat.value(0));

        let city = batch.column(0).as_string::<i32>();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1239
    // Verifies schema metadata is propagated to the batches produced
    #[test]
    fn test_csv_schema_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("foo".to_owned(), "bar".to_owned());
        let schema = Arc::new(Schema::new_with_metadata(
            vec![
                Field::new("city", DataType::Utf8, false),
                Field::new("lat", DataType::Float64, false),
                Field::new("lng", DataType::Float64, false),
            ],
            metadata.clone(),
        ));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        assert_eq!(&metadata, batch.schema().metadata());
    }
1263
    // Reads Decimal128 and Decimal256 columns and checks parsed string forms
    #[test]
    fn test_csv_reader_with_decimal() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal128(38, 6), false),
            Field::new("lng", DataType::Decimal256(76, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal256Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }
1311
    // Reads two concatenated CSV sources (one with a header) as one stream
    #[test]
    fn test_csv_from_buf_reader() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);

        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
        let both_files = file_with_headers
            .chain(Cursor::new("\n".to_string()))
            .chain(file_without_headers);
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .build(both_files)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();
        // 37 rows per file; the single header row is skipped
        assert_eq!(74, batch.num_rows());
        assert_eq!(3, batch.num_columns());
    }
1333
    // Infers a schema from a headered file and reads using it
    #[test]
    fn test_csv_with_schema_inference() {
        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();

        let (schema, _) = Format::default()
            .with_header(true)
            .infer_schema(&mut file, None)
            .unwrap();

        // Rewind so the reader sees the data inference already consumed
        file.rewind().unwrap();
        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);

        let mut csv = builder.build(file).unwrap();
        // Inferred fields are always nullable
        let expected_schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, true),
            Field::new("lat", DataType::Float64, true),
            Field::new("lng", DataType::Float64, true),
        ]);
        assert_eq!(Arc::new(expected_schema), csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1374
    // Infers a schema from a headerless file; columns get synthetic names
    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        // Columns are named column_1, column_2, … when no header exists
        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1413
    // Verifies row bounds restrict the rows read into the batch
    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        // The bounded batch contains the first rows of the file
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // Row 13 lies outside the [0, 2) bounds, so indexing it must panic
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }
1442
    // Verifies a projection restricts both schema and columns of the batches
    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }
1468
    // Reads a string column as a dictionary-encoded array with projection
    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        // Cast back to plain strings to spot-check dictionary values
        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }
1501
1502 #[test]
1503 fn test_csv_with_nullable_dictionary() {
1504 let offset_type = vec![
1505 DataType::Int8,
1506 DataType::Int16,
1507 DataType::Int32,
1508 DataType::Int64,
1509 DataType::UInt8,
1510 DataType::UInt16,
1511 DataType::UInt32,
1512 DataType::UInt64,
1513 ];
1514 for data_type in offset_type {
1515 let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
1516 let dictionary_type =
1517 DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
1518 let schema = Arc::new(Schema::new(vec![
1519 Field::new("id", DataType::Utf8, false),
1520 Field::new("name", dictionary_type.clone(), true),
1521 ]));
1522
1523 let mut csv = ReaderBuilder::new(schema)
1524 .build(file.try_clone().unwrap())
1525 .unwrap();
1526
1527 let batch = csv.next().unwrap().unwrap();
1528 assert_eq!(3, batch.num_rows());
1529 assert_eq!(2, batch.num_columns());
1530
1531 let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
1532 assert!(!names.is_null(2));
1533 assert!(names.is_null(1));
1534 }
1535 }
1536 #[test]
1537 fn test_nulls() {
1538 let schema = Arc::new(Schema::new(vec![
1539 Field::new("c_int", DataType::UInt64, false),
1540 Field::new("c_float", DataType::Float32, true),
1541 Field::new("c_string", DataType::Utf8, true),
1542 Field::new("c_bool", DataType::Boolean, false),
1543 ]));
1544
1545 let file = File::open("test/data/null_test.csv").unwrap();
1546
1547 let mut csv = ReaderBuilder::new(schema)
1548 .with_header(true)
1549 .build(file)
1550 .unwrap();
1551
1552 let batch = csv.next().unwrap().unwrap();
1553
1554 assert!(!batch.column(1).is_null(0));
1555 assert!(!batch.column(1).is_null(1));
1556 assert!(batch.column(1).is_null(2));
1557 assert!(!batch.column(1).is_null(3));
1558 assert!(!batch.column(1).is_null(4));
1559 }
1560
1561 #[test]
1562 fn test_init_nulls() {
1563 let schema = Arc::new(Schema::new(vec![
1564 Field::new("c_int", DataType::UInt64, true),
1565 Field::new("c_float", DataType::Float32, true),
1566 Field::new("c_string", DataType::Utf8, true),
1567 Field::new("c_bool", DataType::Boolean, true),
1568 Field::new("c_null", DataType::Null, true),
1569 ]));
1570 let file = File::open("test/data/init_null_test.csv").unwrap();
1571
1572 let mut csv = ReaderBuilder::new(schema)
1573 .with_header(true)
1574 .build(file)
1575 .unwrap();
1576
1577 let batch = csv.next().unwrap().unwrap();
1578
1579 assert!(batch.column(1).is_null(0));
1580 assert!(!batch.column(1).is_null(1));
1581 assert!(batch.column(1).is_null(2));
1582 assert!(!batch.column(1).is_null(3));
1583 assert!(!batch.column(1).is_null(4));
1584 }
1585
1586 #[test]
1587 fn test_init_nulls_with_inference() {
1588 let format = Format::default().with_header(true).with_delimiter(b',');
1589
1590 let mut file = File::open("test/data/init_null_test.csv").unwrap();
1591 let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1592 file.rewind().unwrap();
1593
1594 let expected_schema = Schema::new(vec![
1595 Field::new("c_int", DataType::Int64, true),
1596 Field::new("c_float", DataType::Float64, true),
1597 Field::new("c_string", DataType::Utf8, true),
1598 Field::new("c_bool", DataType::Boolean, true),
1599 Field::new("c_null", DataType::Null, true),
1600 ]);
1601 assert_eq!(schema, expected_schema);
1602
1603 let mut csv = ReaderBuilder::new(Arc::new(schema))
1604 .with_format(format)
1605 .build(file)
1606 .unwrap();
1607
1608 let batch = csv.next().unwrap().unwrap();
1609
1610 assert!(batch.column(1).is_null(0));
1611 assert!(!batch.column(1).is_null(1));
1612 assert!(batch.column(1).is_null(2));
1613 assert!(!batch.column(1).is_null(3));
1614 assert!(!batch.column(1).is_null(4));
1615 }
1616
1617 #[test]
1618 fn test_custom_nulls() {
1619 let schema = Arc::new(Schema::new(vec![
1620 Field::new("c_int", DataType::UInt64, true),
1621 Field::new("c_float", DataType::Float32, true),
1622 Field::new("c_string", DataType::Utf8, true),
1623 Field::new("c_bool", DataType::Boolean, true),
1624 ]));
1625
1626 let file = File::open("test/data/custom_null_test.csv").unwrap();
1627
1628 let null_regex = Regex::new("^nil$").unwrap();
1629
1630 let mut csv = ReaderBuilder::new(schema)
1631 .with_header(true)
1632 .with_null_regex(null_regex)
1633 .build(file)
1634 .unwrap();
1635
1636 let batch = csv.next().unwrap().unwrap();
1637
1638 assert!(batch.column(0).is_null(1));
1640 assert!(batch.column(1).is_null(2));
1641 assert!(batch.column(3).is_null(4));
1642 assert!(batch.column(2).is_null(3));
1643 assert!(!batch.column(2).is_null(4));
1644 }
1645
    // Infers a schema from the '|'-delimited various_types.csv fixture and
    // checks the inferred types, field names, nullability, and the null
    // placement in the float column.
    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        // Inference consumed the file; rewind before building the reader.
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(7, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        // Column names come from the header row.
        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        // Every inferred field is nullable.
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        // Only row 2 of the float column is empty in the fixture.
        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1703
1704 #[test]
1705 fn test_custom_nulls_with_inference() {
1706 let mut file = File::open("test/data/custom_null_test.csv").unwrap();
1707
1708 let null_regex = Regex::new("^nil$").unwrap();
1709
1710 let format = Format::default()
1711 .with_header(true)
1712 .with_null_regex(null_regex);
1713
1714 let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1715 file.rewind().unwrap();
1716
1717 let expected_schema = Schema::new(vec![
1718 Field::new("c_int", DataType::Int64, true),
1719 Field::new("c_float", DataType::Float64, true),
1720 Field::new("c_string", DataType::Utf8, true),
1721 Field::new("c_bool", DataType::Boolean, true),
1722 ]);
1723
1724 assert_eq!(schema, expected_schema);
1725
1726 let builder = ReaderBuilder::new(Arc::new(schema))
1727 .with_format(format)
1728 .with_batch_size(512)
1729 .with_projection(vec![0, 1, 2, 3]);
1730
1731 let mut csv = builder.build(file).unwrap();
1732 let batch = csv.next().unwrap().unwrap();
1733
1734 assert_eq!(5, batch.num_rows());
1735 assert_eq!(4, batch.num_columns());
1736
1737 assert_eq!(batch.schema().as_ref(), &expected_schema);
1738 }
1739
1740 #[test]
1741 fn test_scientific_notation_with_inference() {
1742 let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
1743 let format = Format::default().with_header(false).with_delimiter(b',');
1744
1745 let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1746 file.rewind().unwrap();
1747
1748 let builder = ReaderBuilder::new(Arc::new(schema))
1749 .with_format(format)
1750 .with_batch_size(512)
1751 .with_projection(vec![0, 1]);
1752
1753 let mut csv = builder.build(file).unwrap();
1754 let batch = csv.next().unwrap().unwrap();
1755
1756 let schema = batch.schema();
1757
1758 assert_eq!(&DataType::Float64, schema.field(0).data_type());
1759 }
1760
    // Reading a file whose values do not match the declared schema must
    // surface a ParseError identifying the offending value, column and line.
    #[test]
    fn test_parse_invalid_csv() {
        let file = File::open("test/data/various_types_invalid.csv").unwrap();

        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        match csv.next() {
            Some(e) => match e {
                // The exact Debug rendering of the error is part of the contract.
                Err(e) => assert_eq!(
                    "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")",
                    format!("{e:?}")
                ),
                Ok(_) => panic!("should have failed"),
            },
            None => panic!("should have failed"),
        }
    }
1790
1791 fn infer_field_schema(string: &str) -> DataType {
1793 let mut v = InferredDataType::default();
1794 v.update(string);
1795 v.get()
1796 }
1797
1798 #[test]
1799 fn test_infer_field_schema() {
1800 assert_eq!(infer_field_schema("A"), DataType::Utf8);
1801 assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
1802 assert_eq!(infer_field_schema("10"), DataType::Int64);
1803 assert_eq!(infer_field_schema("10.2"), DataType::Float64);
1804 assert_eq!(infer_field_schema(".2"), DataType::Float64);
1805 assert_eq!(infer_field_schema("2."), DataType::Float64);
1806 assert_eq!(infer_field_schema("true"), DataType::Boolean);
1807 assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
1808 assert_eq!(infer_field_schema("false"), DataType::Boolean);
1809 assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
1810 assert_eq!(
1811 infer_field_schema("2020-11-08T14:20:01"),
1812 DataType::Timestamp(TimeUnit::Second, None)
1813 );
1814 assert_eq!(
1815 infer_field_schema("2020-11-08 14:20:01"),
1816 DataType::Timestamp(TimeUnit::Second, None)
1817 );
1818 assert_eq!(
1819 infer_field_schema("2020-11-08 14:20:01"),
1820 DataType::Timestamp(TimeUnit::Second, None)
1821 );
1822 assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
1823 assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
1824 assert_eq!(
1825 infer_field_schema("2021-12-19 13:12:30.921"),
1826 DataType::Timestamp(TimeUnit::Millisecond, None)
1827 );
1828 assert_eq!(
1829 infer_field_schema("2021-12-19T13:12:30.123456789"),
1830 DataType::Timestamp(TimeUnit::Nanosecond, None)
1831 );
1832 assert_eq!(infer_field_schema("–9223372036854775809"), DataType::Utf8);
1833 assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
1834 }
1835
1836 #[test]
1837 fn parse_date32() {
1838 assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
1839 assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
1840 assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
1841 }
1842
    // 12-hour-clock time parsing with case-insensitive AM/PM markers; expected
    // values are offsets from midnight in each type's unit.
    #[test]
    fn parse_time() {
        // 12:10:01 AM is 00:10:01, i.e. 601 seconds after midnight.
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        // 2:10:01 PM is 14:10:01, i.e. 51_001 seconds after midnight.
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }
1859
    // Date64 values are milliseconds since the Unix epoch (1970-01-01 maps to 0).
    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        // Fractional seconds are preserved at millisecond precision.
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        // Pre-epoch dates yield negative values.
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        // A %z offset shifts the result relative to the naive parse above.
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }
1884
    /// Decodes three timestamp literals (naive, explicit UTC, and +02:00) into a
    /// single Timestamp column of unit `T::UNIT` with the given timezone, then
    /// asserts the raw stored values (`expected`) and the resulting data type.
    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        // The whole payload should be consumed in a single call...
        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        // ...and an empty chunk signals end-of-input to the decoder.
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }
1914
    // Expected values are epoch offsets in the target unit: the naive literal
    // is interpreted in the column's timezone (UTC when none is given), while
    // the "Z" and "+02:00" literals carry their own offsets.
    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }
1936
1937 #[test]
1938 fn test_infer_schema_from_multiple_files() {
1939 let mut csv1 = NamedTempFile::new().unwrap();
1940 let mut csv2 = NamedTempFile::new().unwrap();
1941 let csv3 = NamedTempFile::new().unwrap(); let mut csv4 = NamedTempFile::new().unwrap();
1943 writeln!(csv1, "c1,c2,c3").unwrap();
1944 writeln!(csv1, "1,\"foo\",0.5").unwrap();
1945 writeln!(csv1, "3,\"bar\",1").unwrap();
1946 writeln!(csv1, "3,\"bar\",2e-06").unwrap();
1947 writeln!(csv2, "c1,c2,c3,c4").unwrap();
1949 writeln!(csv2, "10,,3.14,true").unwrap();
1950 writeln!(csv4, "c1,c2,c3").unwrap();
1952 writeln!(csv4, "10,\"foo\",").unwrap();
1953
1954 let schema = infer_schema_from_files(
1955 &[
1956 csv3.path().to_str().unwrap().to_string(),
1957 csv1.path().to_str().unwrap().to_string(),
1958 csv2.path().to_str().unwrap().to_string(),
1959 csv4.path().to_str().unwrap().to_string(),
1960 ],
1961 b',',
1962 Some(4), true,
1964 )
1965 .unwrap();
1966
1967 assert_eq!(schema.fields().len(), 4);
1968 assert!(schema.field(0).is_nullable());
1969 assert!(schema.field(1).is_nullable());
1970 assert!(schema.field(2).is_nullable());
1971 assert!(schema.field(3).is_nullable());
1972
1973 assert_eq!(&DataType::Int64, schema.field(0).data_type());
1974 assert_eq!(&DataType::Utf8, schema.field(1).data_type());
1975 assert_eq!(&DataType::Float64, schema.field(2).data_type());
1976 assert_eq!(&DataType::Boolean, schema.field(3).data_type());
1977 }
1978
1979 #[test]
1980 fn test_bounded() {
1981 let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
1982 let data = [
1983 vec!["0"],
1984 vec!["1"],
1985 vec!["2"],
1986 vec!["3"],
1987 vec!["4"],
1988 vec!["5"],
1989 vec!["6"],
1990 ];
1991
1992 let data = data
1993 .iter()
1994 .map(|x| x.join(","))
1995 .collect::<Vec<_>>()
1996 .join("\n");
1997 let data = data.as_bytes();
1998
1999 let reader = std::io::Cursor::new(data);
2000
2001 let mut csv = ReaderBuilder::new(Arc::new(schema))
2002 .with_batch_size(2)
2003 .with_projection(vec![0])
2004 .with_bounds(2, 6)
2005 .build_buffered(reader)
2006 .unwrap();
2007
2008 let batch = csv.next().unwrap().unwrap();
2009 let a = batch.column(0);
2010 let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2011 assert_eq!(a, &UInt32Array::from(vec![2, 3]));
2012
2013 let batch = csv.next().unwrap().unwrap();
2014 let a = batch.column(0);
2015 let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2016 assert_eq!(a, &UInt32Array::from(vec![4, 5]));
2017
2018 assert!(csv.next().is_none());
2019 }
2020
2021 #[test]
2022 fn test_empty_projection() {
2023 let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
2024 let data = [vec!["0"], vec!["1"]];
2025
2026 let data = data
2027 .iter()
2028 .map(|x| x.join(","))
2029 .collect::<Vec<_>>()
2030 .join("\n");
2031
2032 let mut csv = ReaderBuilder::new(Arc::new(schema))
2033 .with_batch_size(2)
2034 .with_projection(vec![])
2035 .build_buffered(Cursor::new(data.as_bytes()))
2036 .unwrap();
2037
2038 let batch = csv.next().unwrap().unwrap();
2039 assert_eq!(batch.columns().len(), 0);
2040 assert_eq!(batch.num_rows(), 2);
2041
2042 assert!(csv.next().is_none());
2043 }
2044
2045 #[test]
2046 fn test_parsing_bool() {
2047 assert_eq!(Some(true), parse_bool("true"));
2049 assert_eq!(Some(true), parse_bool("tRUe"));
2050 assert_eq!(Some(true), parse_bool("True"));
2051 assert_eq!(Some(true), parse_bool("TRUE"));
2052 assert_eq!(None, parse_bool("t"));
2053 assert_eq!(None, parse_bool("T"));
2054 assert_eq!(None, parse_bool(""));
2055
2056 assert_eq!(Some(false), parse_bool("false"));
2057 assert_eq!(Some(false), parse_bool("fALse"));
2058 assert_eq!(Some(false), parse_bool("False"));
2059 assert_eq!(Some(false), parse_bool("FALSE"));
2060 assert_eq!(None, parse_bool("f"));
2061 assert_eq!(None, parse_bool("F"));
2062 assert_eq!(None, parse_bool(""));
2063 }
2064
2065 #[test]
2066 fn test_parsing_float() {
2067 assert_eq!(Some(12.34), Float64Type::parse("12.34"));
2068 assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
2069 assert_eq!(Some(12.0), Float64Type::parse("12"));
2070 assert_eq!(Some(0.0), Float64Type::parse("0"));
2071 assert_eq!(Some(2.0), Float64Type::parse("2."));
2072 assert_eq!(Some(0.2), Float64Type::parse(".2"));
2073 assert!(Float64Type::parse("nan").unwrap().is_nan());
2074 assert!(Float64Type::parse("NaN").unwrap().is_nan());
2075 assert!(Float64Type::parse("inf").unwrap().is_infinite());
2076 assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
2077 assert!(Float64Type::parse("-inf").unwrap().is_infinite());
2078 assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
2079 assert_eq!(None, Float64Type::parse(""));
2080 assert_eq!(None, Float64Type::parse("dd"));
2081 assert_eq!(None, Float64Type::parse("12.34.56"));
2082 }
2083
2084 #[test]
2085 fn test_non_std_quote() {
2086 let schema = Schema::new(vec![
2087 Field::new("text1", DataType::Utf8, false),
2088 Field::new("text2", DataType::Utf8, false),
2089 ]);
2090 let builder = ReaderBuilder::new(Arc::new(schema))
2091 .with_header(false)
2092 .with_quote(b'~'); let mut csv_text = Vec::new();
2095 let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2096 for index in 0..10 {
2097 let text1 = format!("id{index:}");
2098 let text2 = format!("value{index:}");
2099 csv_writer
2100 .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
2101 .unwrap();
2102 }
2103 let mut csv_reader = std::io::Cursor::new(&csv_text);
2104 let mut reader = builder.build(&mut csv_reader).unwrap();
2105 let batch = reader.next().unwrap().unwrap();
2106 let col0 = batch.column(0);
2107 assert_eq!(col0.len(), 10);
2108 let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2109 assert_eq!(col0_arr.value(0), "id0");
2110 let col1 = batch.column(1);
2111 assert_eq!(col1.len(), 10);
2112 let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2113 assert_eq!(col1_arr.value(5), "value5");
2114 }
2115
    // Verifies a custom escape character: the writer emits `value\"N` inside a
    // quoted field, and the reader (configured with '\' as escape) must yield
    // the unescaped `value"N`.
    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\');
        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            // Each value embeds an escaped quote: value\"N on the wire.
            let text2 = format!("value\\\"{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        // The escape has been resolved: a literal double quote remains.
        assert_eq!(col1_arr.value(5), "value\"5");
    }
2147
2148 #[test]
2149 fn test_non_std_terminator() {
2150 let schema = Schema::new(vec![
2151 Field::new("text1", DataType::Utf8, false),
2152 Field::new("text2", DataType::Utf8, false),
2153 ]);
2154 let builder = ReaderBuilder::new(Arc::new(schema))
2155 .with_header(false)
2156 .with_terminator(b'\n'); let mut csv_text = Vec::new();
2159 let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2160 for index in 0..10 {
2161 let text1 = format!("id{index:}");
2162 let text2 = format!("value{index:}");
2163 csv_writer
2164 .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
2165 .unwrap();
2166 }
2167 let mut csv_reader = std::io::Cursor::new(&csv_text);
2168 let mut reader = builder.build(&mut csv_reader).unwrap();
2169 let batch = reader.next().unwrap().unwrap();
2170 let col0 = batch.column(0);
2171 assert_eq!(col0.len(), 10);
2172 let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2173 assert_eq!(col0_arr.value(0), "id0");
2174 let col1 = batch.column(1);
2175 assert_eq!(col1.len(), 10);
2176 let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2177 assert_eq!(col1_arr.value(5), "value5");
2178 }
2179
    // Interaction of `with_bounds(start, end)` with header skipping: each case
    // is (bounds, has_header, expected number of rows in the first batch).
    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        // Duplicate field names are irrelevant here; only row counts are checked.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            // idx identifies the failing case in the assertion message.
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }
2210
    // Boolean parsing is case-insensitive, and an empty field becomes null.
    #[test]
    fn test_null_boolean() {
        // Row 2 has an empty first field; row 3 an empty second field.
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }
2243
    // `with_truncated_rows(true)` accepts rows with fewer fields than the
    // schema (and the blank line does not produce a row); `false` makes the
    // same input an error.
    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        // Three data rows survive: "1,2,3", the truncated "4,5", and "6,7,8".
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }
2277
    // End-to-end truncated-rows test against a fixture file: missing trailing
    // fields must come through as nulls across several column types.
    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        // Everything fits in a single batch at this batch size.
        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        // Date32 values are days relative to the Unix epoch.
        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }
2347
2348 #[test]
2349 fn test_truncated_rows_not_nullable_error() {
2350 let data = "a,b,c\n1,2,3\n4,5";
2351 let schema = Arc::new(Schema::new(vec![
2352 Field::new("a", DataType::Int32, false),
2353 Field::new("b", DataType::Int32, false),
2354 Field::new("c", DataType::Int32, false),
2355 ]));
2356
2357 let reader = ReaderBuilder::new(schema.clone())
2358 .with_header(true)
2359 .with_truncated_rows(true)
2360 .build(Cursor::new(data))
2361 .unwrap();
2362
2363 let batches = reader.collect::<Result<Vec<_>, _>>();
2364 assert!(match batches {
2365 Err(ArrowError::InvalidArgumentError(e)) =>
2366 e.to_string().contains("contains null values"),
2367 _ => false,
2368 });
2369 }
2370
    // Reading with the plain reader (`build`) and with `build_buffered` over a
    // BufReader of varying capacity must produce identical batches.
    #[test]
    fn test_buffered() {
        // (path, has_header, expected total row count)
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 7),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                // Tiny capacities force the buffered path to refill often.
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }
2415
    /// Asserts that reading `csv` fails with exactly the message `expected`,
    /// for both `Utf8` and `Utf8View` string columns.
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            // Capacity 2 feeds the decoder tiny fragments, exercising the
            // incremental decoding path.
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }
2439
    // Invalid UTF-8 (\xFF is never valid in UTF-8) must be reported with
    // one-based line and field numbers pointing at the offending value.
    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        // An invalid byte in the very first field of the very first line.
        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }
2462
    /// Wraps a reader and records how it is polled via `BufRead::fill_buf`.
    struct InstrumentedRead<R> {
        r: R,
        // Number of fill_buf calls observed.
        fill_count: usize,
        // Size of the buffer returned by each fill_buf call, in order.
        fill_sizes: Vec<usize>,
    }
2468
2469 impl<R> InstrumentedRead<R> {
2470 fn new(r: R) -> Self {
2471 Self {
2472 r,
2473 fill_count: 0,
2474 fill_sizes: vec![],
2475 }
2476 }
2477 }
2478
    // Seeking is forwarded to the wrapped reader unchanged (not instrumented).
    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }
2484
    // Plain reads are forwarded untouched; only fill_buf is instrumented.
    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }
2490
    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            // Record every fill_buf call and the number of buffered bytes it
            // exposed, so tests can assert on the reader's I/O pattern.
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }
2503
2504 #[test]
2505 fn test_io() {
2506 let schema = Arc::new(Schema::new(vec![
2507 Field::new("a", DataType::Utf8, false),
2508 Field::new("b", DataType::Utf8, false),
2509 ]));
2510 let csv = "foo,bar\nbaz,foo\na,b\nc,d";
2511 let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
2512 let reader = ReaderBuilder::new(schema)
2513 .with_batch_size(3)
2514 .build_buffered(&mut read)
2515 .unwrap();
2516
2517 let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2518 assert_eq!(batches.len(), 2);
2519 assert_eq!(batches[0].num_rows(), 3);
2520 assert_eq!(batches[1].num_rows(), 1);
2521
2522 assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
2528 assert_eq!(read.fill_count, 4);
2529 }
2530
2531 #[test]
2532 fn test_inference() {
2533 let cases: &[(&[&str], DataType)] = &[
2534 (&[], DataType::Null),
2535 (&["false", "12"], DataType::Utf8),
2536 (&["12", "cupcakes"], DataType::Utf8),
2537 (&["12", "12.4"], DataType::Float64),
2538 (&["14050", "24332"], DataType::Int64),
2539 (&["14050.0", "true"], DataType::Utf8),
2540 (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
2541 (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
2542 (
2543 &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2544 DataType::Timestamp(TimeUnit::Second, None),
2545 ),
2546 (&["2020-03-19", "2020-03-20"], DataType::Date32),
2547 (
2548 &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2549 DataType::Timestamp(TimeUnit::Second, None),
2550 ),
2551 (
2552 &[
2553 "2020-03-19",
2554 "2020-03-19 02:00:00",
2555 "2020-03-19 00:00:00.000",
2556 ],
2557 DataType::Timestamp(TimeUnit::Millisecond, None),
2558 ),
2559 (
2560 &[
2561 "2020-03-19",
2562 "2020-03-19 02:00:00",
2563 "2020-03-19 00:00:00.000000",
2564 ],
2565 DataType::Timestamp(TimeUnit::Microsecond, None),
2566 ),
2567 (
2568 &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
2569 DataType::Timestamp(TimeUnit::Second, None),
2570 ),
2571 (
2572 &[
2573 "2020-03-19",
2574 "2020-03-19 02:00:00+02:00",
2575 "2020-03-19 02:00:00Z",
2576 "2020-03-19 02:00:00.12Z",
2577 ],
2578 DataType::Timestamp(TimeUnit::Millisecond, None),
2579 ),
2580 (
2581 &[
2582 "2020-03-19",
2583 "2020-03-19 02:00:00.000000000",
2584 "2020-03-19 00:00:00.000000",
2585 ],
2586 DataType::Timestamp(TimeUnit::Nanosecond, None),
2587 ),
2588 ];
2589
2590 for (values, expected) in cases {
2591 let mut t = InferredDataType::default();
2592 for v in *values {
2593 t.update(v)
2594 }
2595 assert_eq!(&t.get(), expected, "{values:?}")
2596 }
2597 }
2598
2599 #[test]
2600 fn test_record_length_mismatch() {
2601 let csv = "\
2602 a,b,c\n\
2603 1,2,3\n\
2604 4,5\n\
2605 6,7,8";
2606 let mut read = Cursor::new(csv.as_bytes());
2607 let result = Format::default()
2608 .with_header(true)
2609 .infer_schema(&mut read, None);
2610 assert!(result.is_err());
2611 assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 2 records, found 3 records at line 3");
2613 }
2614
2615 #[test]
2616 fn test_comment() {
2617 let schema = Schema::new(vec![
2618 Field::new("a", DataType::Int8, false),
2619 Field::new("b", DataType::Int8, false),
2620 ]);
2621
2622 let csv = "# comment1 \n1,2\n#comment2\n11,22";
2623 let mut read = Cursor::new(csv.as_bytes());
2624 let reader = ReaderBuilder::new(Arc::new(schema))
2625 .with_comment(b'#')
2626 .build(&mut read)
2627 .unwrap();
2628
2629 let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2630 assert_eq!(batches.len(), 1);
2631 let b = batches.first().unwrap();
2632 assert_eq!(b.num_columns(), 2);
2633 assert_eq!(
2634 b.column(0)
2635 .as_any()
2636 .downcast_ref::<Int8Array>()
2637 .unwrap()
2638 .values(),
2639 &vec![1, 11]
2640 );
2641 assert_eq!(
2642 b.column(1)
2643 .as_any()
2644 .downcast_ref::<Int8Array>()
2645 .unwrap()
2646 .values(),
2647 &vec![2, 22]
2648 );
2649 }
2650
2651 #[test]
2652 fn test_parse_string_view_single_column() {
2653 let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
2654 let schema = Arc::new(Schema::new(vec![Field::new(
2655 "c1",
2656 DataType::Utf8View,
2657 true,
2658 )]));
2659
2660 let mut decoder = ReaderBuilder::new(schema).build_decoder();
2661
2662 let decoded = decoder.decode(csv.as_bytes()).unwrap();
2663 assert_eq!(decoded, csv.len());
2664 decoder.decode(&[]).unwrap();
2665
2666 let batch = decoder.flush().unwrap().unwrap();
2667 assert_eq!(batch.num_columns(), 1);
2668 assert_eq!(batch.num_rows(), 3);
2669 let col = batch.column(0).as_string_view();
2670 assert_eq!(col.data_type(), &DataType::Utf8View);
2671 assert_eq!(col.value(0), "foo");
2672 assert_eq!(col.value(1), "something_cannot_be_inlined");
2673 assert_eq!(col.value(2), "foobar");
2674 }
2675
2676 #[test]
2677 fn test_parse_string_view_multi_column() {
2678 let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
2679 let schema = Arc::new(Schema::new(vec![
2680 Field::new("c1", DataType::Utf8View, true),
2681 Field::new("c2", DataType::Utf8View, true),
2682 ]));
2683
2684 let mut decoder = ReaderBuilder::new(schema).build_decoder();
2685
2686 let decoded = decoder.decode(csv.as_bytes()).unwrap();
2687 assert_eq!(decoded, csv.len());
2688 decoder.decode(&[]).unwrap();
2689
2690 let batch = decoder.flush().unwrap().unwrap();
2691 assert_eq!(batch.num_columns(), 2);
2692 assert_eq!(batch.num_rows(), 3);
2693 let c1 = batch.column(0).as_string_view();
2694 let c2 = batch.column(1).as_string_view();
2695 assert_eq!(c1.data_type(), &DataType::Utf8View);
2696 assert_eq!(c2.data_type(), &DataType::Utf8View);
2697
2698 assert!(!c1.is_null(0));
2699 assert!(c1.is_null(1));
2700 assert!(!c1.is_null(2));
2701 assert_eq!(c1.value(0), "foo");
2702 assert_eq!(c1.value(2), "foobarfoobar");
2703
2704 assert!(c2.is_null(0));
2705 assert!(!c2.is_null(1));
2706 assert!(!c2.is_null(2));
2707 assert_eq!(c2.value(1), "something_cannot_be_inlined");
2708 assert_eq!(c2.value(2), "bar");
2709 }
2710}