arrow_csv/reader/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! CSV Reader
19//!
20//! # Basic Usage
21//!
22//! This CSV reader allows CSV files to be read into the Arrow memory model. Records are
23//! loaded in batches and are then converted from row-based data to columnar data.
24//!
25//! Example:
26//!
27//! ```
28//! # use arrow_schema::*;
29//! # use arrow_csv::{Reader, ReaderBuilder};
30//! # use std::fs::File;
31//! # use std::sync::Arc;
32//!
33//! let schema = Schema::new(vec![
34//!     Field::new("city", DataType::Utf8, false),
35//!     Field::new("lat", DataType::Float64, false),
36//!     Field::new("lng", DataType::Float64, false),
37//! ]);
38//!
39//! let file = File::open("test/data/uk_cities.csv").unwrap();
40//!
41//! let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
42//! let batch = csv.next().unwrap().unwrap();
43//! ```
44//!
45//! # Async Usage
46//!
47//! The lower-level [`Decoder`] can be integrated with various forms of async data streams,
48//! and is designed to be agnostic to the various different kinds of async IO primitives found
49//! within the Rust ecosystem.
50//!
51//! For example, see below for how it can be used with an arbitrary `Stream` of `Bytes`
52//!
53//! ```
54//! # use std::task::{Poll, ready};
55//! # use bytes::{Buf, Bytes};
56//! # use arrow_schema::ArrowError;
57//! # use futures::stream::{Stream, StreamExt};
58//! # use arrow_array::RecordBatch;
59//! # use arrow_csv::reader::Decoder;
60//! #
61//! fn decode_stream<S: Stream<Item = Bytes> + Unpin>(
62//!     mut decoder: Decoder,
63//!     mut input: S,
64//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
65//!     let mut buffered = Bytes::new();
66//!     futures::stream::poll_fn(move |cx| {
67//!         loop {
68//!             if buffered.is_empty() {
69//!                 if let Some(b) = ready!(input.poll_next_unpin(cx)) {
70//!                     buffered = b;
71//!                 }
72//!                 // Note: don't break on `None` as the decoder needs
73//!                 // to be called with an empty array to delimit the
74//!                 // final record
75//!             }
76//!             let decoded = match decoder.decode(buffered.as_ref()) {
77//!                 Ok(0) => break,
78//!                 Ok(decoded) => decoded,
79//!                 Err(e) => return Poll::Ready(Some(Err(e))),
80//!             };
81//!             buffered.advance(decoded);
82//!         }
83//!
84//!         Poll::Ready(decoder.flush().transpose())
85//!     })
86//! }
87//!
88//! ```
89//!
90//! In a similar vein, it can also be used with tokio-based IO primitives
91//!
92//! ```
93//! # use std::pin::Pin;
94//! # use std::task::{Poll, ready};
95//! # use futures::Stream;
96//! # use tokio::io::AsyncBufRead;
97//! # use arrow_array::RecordBatch;
98//! # use arrow_csv::reader::Decoder;
99//! # use arrow_schema::ArrowError;
100//! fn decode_stream<R: AsyncBufRead + Unpin>(
101//!     mut decoder: Decoder,
102//!     mut reader: R,
103//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
104//!     futures::stream::poll_fn(move |cx| {
105//!         loop {
106//!             let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) {
107//!                 Ok(b) => b,
108//!                 Err(e) => return Poll::Ready(Some(Err(e.into()))),
109//!             };
110//!             let decoded = match decoder.decode(b) {
111//!                 // Note: the decoder needs to be called with an empty
112//!                 // array to delimit the final record
113//!                 Ok(0) => break,
114//!                 Ok(decoded) => decoded,
115//!                 Err(e) => return Poll::Ready(Some(Err(e))),
116//!             };
117//!             Pin::new(&mut reader).consume(decoded);
118//!         }
119//!
120//!         Poll::Ready(decoder.flush().transpose())
121//!     })
122//! }
123//! ```
124//!
125
126mod records;
127
128use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
129use arrow_array::types::*;
130use arrow_array::*;
131use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
132use arrow_schema::*;
133use chrono::{TimeZone, Utc};
134use csv::StringRecord;
135use lazy_static::lazy_static;
136use regex::{Regex, RegexSet};
137use std::fmt::{self, Debug};
138use std::fs::File;
139use std::io::{BufRead, BufReader as StdBufReader, Read};
140use std::sync::Arc;
141
142use crate::map_csv_error;
143use crate::reader::records::{RecordDecoder, StringRecords};
144use arrow_array::timezone::Tz;
145
lazy_static! {
    /// Patterns used during schema inference to classify a cell's value.
    ///
    /// Order matters: the pattern index is used directly as the bit position in
    /// [`InferredDataType::packed`], so this list must stay in sync with the
    /// bit layout documented on that struct.
    static ref REGEX_SET: RegexSet = RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN (case-insensitive)
        r"^-?(\d+)$", //INTEGER
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", //DECIMAL (requires a '.' or an exponent)
        r"^\d{4}-\d\d-\d\d$", //DATE32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", //Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", //Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", //Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", //Timestamp(Nanosecond)
    ]).unwrap();
}
159
/// A wrapper over `Option<Regex>` to check if the value is `NULL`.
///
/// When no regex is configured (the default), only the empty string is
/// treated as `NULL`.
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);
163
164impl NullRegex {
165    /// Returns true if the value should be considered as `NULL` according to
166    /// the provided regular expression.
167    #[inline]
168    fn is_null(&self, s: &str) -> bool {
169        match &self.0 {
170            Some(r) => r.is_match(s),
171            None => s.is_empty(),
172        }
173    }
174}
175
/// Accumulates the set of candidate Arrow types observed for a single column
/// during schema inference. Bits are set by [`InferredDataType::update`] and
/// resolved to a single [`DataType`] by [`InferredDataType::get`].
#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating type
    ///
    /// Bit positions 0-7 correspond to the pattern indices in [`REGEX_SET`]:
    ///
    /// 0 - Boolean
    /// 1 - Integer
    /// 2 - Float64
    /// 3 - Date32
    /// 4 - Timestamp(Second)
    /// 5 - Timestamp(Millisecond)
    /// 6 - Timestamp(Microsecond)
    /// 7 - Timestamp(Nanosecond)
    /// 8 - Utf8
    packed: u16,
}
191
impl InferredDataType {
    /// Returns the inferred data type
    ///
    /// Resolves the packed bitset to a single type, promoting to the most
    /// general type when multiple candidates were observed.
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,    // no non-null values seen
            1 => DataType::Boolean, // only bit 0 (Boolean) set
            2 => DataType::Int64,   // only bit 1 (Integer) set
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            // Only temporal bits (3-7) set: `!0b11111000` masks out everything
            // except Date32 and the four timestamp bits
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to highest precision temporal type
                // (highest set bit of a u16 determines leading_zeros)
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            // Any other combination of candidates falls back to Utf8
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    ///
    /// Sets the bit for the first matching pattern in [`REGEX_SET`], or the
    /// Utf8 bit (8) when nothing matches.
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8 — quoted values are kept as strings
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            // A 19+ digit integer may not fit in i64 (i64::MAX has 19 digits)
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // if overflow i64, fallback to utf8
                1 << 8
            } else {
                1 << m
            }
        } else {
            1 << 8 // Utf8
        }
    }
}
229
/// The format specification for the CSV file
#[derive(Debug, Clone, Default)]
pub struct Format {
    /// Whether the first row is treated as a header (default: `false`)
    header: bool,
    /// Field delimiter; `None` uses the csv crate default (`,`)
    delimiter: Option<u8>,
    /// Escape character, disabled when `None`
    escape: Option<u8>,
    /// Quote character; `None` uses the csv crate default (`"`)
    quote: Option<u8>,
    /// Record terminator; `None` uses the default (CRLF)
    terminator: Option<u8>,
    /// Lines starting with this character are ignored when set
    comment: Option<u8>,
    /// Pattern determining which values are decoded as `NULL`
    null_regex: NullRegex,
    /// Whether rows with fewer columns than the schema are allowed
    truncated_rows: bool,
}
242
243impl Format {
244    /// Specify whether the CSV file has a header, defaults to `false`
245    ///
246    /// When `true`, the first row of the CSV file is treated as a header row
247    pub fn with_header(mut self, has_header: bool) -> Self {
248        self.header = has_header;
249        self
250    }
251
252    /// Specify a custom delimiter character, defaults to comma `','`
253    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
254        self.delimiter = Some(delimiter);
255        self
256    }
257
258    /// Specify an escape character, defaults to `None`
259    pub fn with_escape(mut self, escape: u8) -> Self {
260        self.escape = Some(escape);
261        self
262    }
263
264    /// Specify a custom quote character, defaults to double quote `'"'`
265    pub fn with_quote(mut self, quote: u8) -> Self {
266        self.quote = Some(quote);
267        self
268    }
269
270    /// Specify a custom terminator character, defaults to CRLF
271    pub fn with_terminator(mut self, terminator: u8) -> Self {
272        self.terminator = Some(terminator);
273        self
274    }
275
276    /// Specify a comment character, defaults to `None`
277    ///
278    /// Lines starting with this character will be ignored
279    pub fn with_comment(mut self, comment: u8) -> Self {
280        self.comment = Some(comment);
281        self
282    }
283
284    /// Provide a regex to match null values, defaults to `^$`
285    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
286        self.null_regex = NullRegex(Some(null_regex));
287        self
288    }
289
290    /// Whether to allow truncated rows when parsing.
291    ///
292    /// By default this is set to `false` and will error if the CSV rows have different lengths.
293    /// When set to true then it will allow records with less than the expected number of columns
294    /// and fill the missing columns with nulls. If the record's schema is not nullable, then it
295    /// will still return an error.
296    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
297        self.truncated_rows = allow;
298        self
299    }
300
301    /// Infer schema of CSV records from the provided `reader`
302    ///
303    /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
304    /// records are read to infer the schema
305    ///
306    /// Returns inferred schema and number of records read
307    pub fn infer_schema<R: Read>(
308        &self,
309        reader: R,
310        max_records: Option<usize>,
311    ) -> Result<(Schema, usize), ArrowError> {
312        let mut csv_reader = self.build_reader(reader);
313
314        // get or create header names
315        // when has_header is false, creates default column names with column_ prefix
316        let headers: Vec<String> = if self.header {
317            let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
318            headers.iter().map(|s| s.to_string()).collect()
319        } else {
320            let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
321            (0..*first_record_count)
322                .map(|i| format!("column_{}", i + 1))
323                .collect()
324        };
325
326        let header_length = headers.len();
327        // keep track of inferred field types
328        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];
329
330        let mut records_count = 0;
331
332        let mut record = StringRecord::new();
333        let max_records = max_records.unwrap_or(usize::MAX);
334        while records_count < max_records {
335            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
336                break;
337            }
338            records_count += 1;
339
340            // Note since we may be looking at a sample of the data, we make the safe assumption that
341            // they could be nullable
342            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
343                if let Some(string) = record.get(i) {
344                    if !self.null_regex.is_null(string) {
345                        column_type.update(string)
346                    }
347                }
348            }
349        }
350
351        // build schema from inference results
352        let fields: Fields = column_types
353            .iter()
354            .zip(&headers)
355            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
356            .collect();
357
358        Ok((Schema::new(fields), records_count))
359    }
360
361    /// Build a [`csv::Reader`] for this [`Format`]
362    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
363        let mut builder = csv::ReaderBuilder::new();
364        builder.has_headers(self.header);
365        builder.flexible(self.truncated_rows);
366
367        if let Some(c) = self.delimiter {
368            builder.delimiter(c);
369        }
370        builder.escape(self.escape);
371        if let Some(c) = self.quote {
372            builder.quote(c);
373        }
374        if let Some(t) = self.terminator {
375            builder.terminator(csv::Terminator::Any(t));
376        }
377        if let Some(comment) = self.comment {
378            builder.comment(Some(comment));
379        }
380        builder.from_reader(reader)
381    }
382
383    /// Build a [`csv_core::Reader`] for this [`Format`]
384    fn build_parser(&self) -> csv_core::Reader {
385        let mut builder = csv_core::ReaderBuilder::new();
386        builder.escape(self.escape);
387        builder.comment(self.comment);
388
389        if let Some(c) = self.delimiter {
390            builder.delimiter(c);
391        }
392        if let Some(c) = self.quote {
393            builder.quote(c);
394        }
395        if let Some(t) = self.terminator {
396            builder.terminator(csv_core::Terminator::Any(t));
397        }
398        builder.build()
399    }
400}
401
402/// Infer schema from a list of CSV files by reading through first n records
403/// with `max_read_records` controlling the maximum number of records to read.
404///
405/// Files will be read in the given order until n records have been reached.
406///
407/// If `max_read_records` is not set, all files will be read fully to infer the schema.
408pub fn infer_schema_from_files(
409    files: &[String],
410    delimiter: u8,
411    max_read_records: Option<usize>,
412    has_header: bool,
413) -> Result<Schema, ArrowError> {
414    let mut schemas = vec![];
415    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
416    let format = Format {
417        delimiter: Some(delimiter),
418        header: has_header,
419        ..Default::default()
420    };
421
422    for fname in files.iter() {
423        let f = File::open(fname)?;
424        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
425        if records_read == 0 {
426            continue;
427        }
428        schemas.push(schema.clone());
429        records_to_read -= records_read;
430        if records_to_read == 0 {
431            break;
432        }
433    }
434
435    Schema::try_merge(schemas)
436}
437
// Optional bounds of the reader, of the form (min line, max line).
type Bounds = Option<(usize, usize)>;

/// CSV file reader using [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader
///
/// Pairs a buffered byte source with a [`Decoder`], yielding
/// [`RecordBatch`]es via its [`Iterator`] implementation.
pub struct BufReader<R> {
    /// File reader
    reader: R,

    /// The decoder
    decoder: Decoder,
}
452
453impl<R> fmt::Debug for BufReader<R>
454where
455    R: BufRead,
456{
457    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458        f.debug_struct("Reader")
459            .field("decoder", &self.decoder)
460            .finish()
461    }
462}
463
464impl<R: Read> Reader<R> {
465    /// Returns the schema of the reader, useful for getting the schema without reading
466    /// record batches
467    pub fn schema(&self) -> SchemaRef {
468        match &self.decoder.projection {
469            Some(projection) => {
470                let fields = self.decoder.schema.fields();
471                let projected = projection.iter().map(|i| fields[*i].clone());
472                Arc::new(Schema::new(projected.collect::<Fields>()))
473            }
474            None => self.decoder.schema.clone(),
475        }
476    }
477}
478
impl<R: BufRead> BufReader<R> {
    /// Feeds bytes from the underlying reader into the decoder until a full
    /// batch is buffered or the input is exhausted, then flushes the buffered
    /// rows as a [`RecordBatch`].
    ///
    /// Returns `Ok(None)` once the input is fully consumed and nothing
    /// remains buffered in the decoder.
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if decoded no bytes or the decoder is full
            //
            // The capacity check avoids looping around and potentially
            // blocking reading data in fill_buf that isn't needed
            // to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}
498
499impl<R: BufRead> Iterator for BufReader<R> {
500    type Item = Result<RecordBatch, ArrowError>;
501
502    fn next(&mut self) -> Option<Self::Item> {
503        self.read().transpose()
504    }
505}
506
507impl<R: BufRead> RecordBatchReader for BufReader<R> {
508    fn schema(&self) -> SchemaRef {
509        self.decoder.schema.clone()
510    }
511}
512
/// A push-based interface for decoding CSV data from an arbitrary byte stream
///
/// See [`Reader`] for a higher-level interface for interface with [`Read`]
///
/// The push-based interface facilitates integration with sources that yield arbitrarily
/// delimited bytes ranges, such as [`BufRead`], or a chunked byte stream received from
/// object storage
///
/// ```
/// # use std::io::BufRead;
/// # use arrow_array::RecordBatch;
/// # use arrow_csv::ReaderBuilder;
/// # use arrow_schema::{ArrowError, SchemaRef};
/// #
/// fn read_from_csv<R: BufRead>(
///     mut reader: R,
///     schema: SchemaRef,
///     batch_size: usize,
/// ) -> Result<impl Iterator<Item = Result<RecordBatch, ArrowError>>, ArrowError> {
///     let mut decoder = ReaderBuilder::new(schema)
///         .with_batch_size(batch_size)
///         .build_decoder();
///
///     let mut next = move || {
///         loop {
///             let buf = reader.fill_buf()?;
///             let decoded = decoder.decode(buf)?;
///             if decoded == 0 {
///                 break;
///             }
///
///             // Consume the number of bytes read
///             reader.consume(decoded);
///         }
///         decoder.flush()
///     };
///     Ok(std::iter::from_fn(move || next().transpose()))
/// }
/// ```
#[derive(Debug)]
pub struct Decoder {
    /// Explicit schema for the CSV file
    schema: SchemaRef,

    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,

    /// Number of records per batch
    batch_size: usize,

    /// Number of leading rows still to be skipped before decoding begins
    to_skip: usize,

    /// Current line number, advanced by [`Decoder::flush`]
    line_number: usize,

    /// End line number; decoding stops once `line_number` reaches this
    end: usize,

    /// A decoder for [`StringRecords`]
    record_decoder: RecordDecoder,

    /// Check if the string matches this pattern for `NULL`.
    null_regex: NullRegex,
}
578
impl Decoder {
    /// Decode records from `buf` returning the number of bytes read
    ///
    /// This method returns once `batch_size` objects have been parsed since the
    /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes
    /// should be included in the next call to [`Self::decode`]
    ///
    /// There is no requirement that `buf` contains a whole number of records, facilitating
    /// integration with arbitrary byte streams, such as that yielded by [`BufRead`] or
    /// network sources such as object storage
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of `to_read` to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            // Discard the skipped records so they are not returned by flush
            self.record_decoder.clear();
            return Ok(bytes);
        }

        // Cap the read at both the batch size and the configured end line,
        // subtracting records already buffered since the last flush
        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered data to a [`RecordBatch`]
    ///
    /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
    /// otherwise may return an error if part way through decoding a record
    ///
    /// Returns `Ok(None)` if no buffered data
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        // Parse the buffered rows into columnar arrays, applying any
        // projection and propagating schema metadata
        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        // Advance the line counter so error messages and the `end` bound
        // reflect rows already emitted
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of records that can be read before requiring a call to [`Self::flush`]
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}
633
634/// Parses a slice of [`StringRecords`] into a [RecordBatch]
635fn parse(
636    rows: &StringRecords<'_>,
637    fields: &Fields,
638    metadata: Option<std::collections::HashMap<String, String>>,
639    projection: Option<&Vec<usize>>,
640    line_number: usize,
641    null_regex: &NullRegex,
642) -> Result<RecordBatch, ArrowError> {
643    let projection: Vec<usize> = match projection {
644        Some(v) => v.clone(),
645        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
646    };
647
648    let arrays: Result<Vec<ArrayRef>, _> = projection
649        .iter()
650        .map(|i| {
651            let i = *i;
652            let field = &fields[i];
653            match field.data_type() {
654                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
655                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
656                    line_number,
657                    rows,
658                    i,
659                    *precision,
660                    *scale,
661                    null_regex,
662                ),
663                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
664                    line_number,
665                    rows,
666                    i,
667                    *precision,
668                    *scale,
669                    null_regex,
670                ),
671                DataType::Int8 => {
672                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
673                }
674                DataType::Int16 => {
675                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
676                }
677                DataType::Int32 => {
678                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
679                }
680                DataType::Int64 => {
681                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
682                }
683                DataType::UInt8 => {
684                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
685                }
686                DataType::UInt16 => {
687                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
688                }
689                DataType::UInt32 => {
690                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
691                }
692                DataType::UInt64 => {
693                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
694                }
695                DataType::Float32 => {
696                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
697                }
698                DataType::Float64 => {
699                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
700                }
701                DataType::Date32 => {
702                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
703                }
704                DataType::Date64 => {
705                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
706                }
707                DataType::Time32(TimeUnit::Second) => {
708                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
709                }
710                DataType::Time32(TimeUnit::Millisecond) => {
711                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
712                }
713                DataType::Time64(TimeUnit::Microsecond) => {
714                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
715                }
716                DataType::Time64(TimeUnit::Nanosecond) => {
717                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
718                }
719                DataType::Timestamp(TimeUnit::Second, tz) => {
720                    build_timestamp_array::<TimestampSecondType>(
721                        line_number,
722                        rows,
723                        i,
724                        tz.as_deref(),
725                        null_regex,
726                    )
727                }
728                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
729                    build_timestamp_array::<TimestampMillisecondType>(
730                        line_number,
731                        rows,
732                        i,
733                        tz.as_deref(),
734                        null_regex,
735                    )
736                }
737                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
738                    build_timestamp_array::<TimestampMicrosecondType>(
739                        line_number,
740                        rows,
741                        i,
742                        tz.as_deref(),
743                        null_regex,
744                    )
745                }
746                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
747                    build_timestamp_array::<TimestampNanosecondType>(
748                        line_number,
749                        rows,
750                        i,
751                        tz.as_deref(),
752                        null_regex,
753                    )
754                }
755                DataType::Null => Ok(Arc::new({
756                    let mut builder = NullBuilder::new();
757                    builder.append_nulls(rows.len());
758                    builder.finish()
759                }) as ArrayRef),
760                DataType::Utf8 => Ok(Arc::new(
761                    rows.iter()
762                        .map(|row| {
763                            let s = row.get(i);
764                            (!null_regex.is_null(s)).then_some(s)
765                        })
766                        .collect::<StringArray>(),
767                ) as ArrayRef),
768                DataType::Utf8View => Ok(Arc::new(
769                    rows.iter()
770                        .map(|row| {
771                            let s = row.get(i);
772                            (!null_regex.is_null(s)).then_some(s)
773                        })
774                        .collect::<StringViewArray>(),
775                ) as ArrayRef),
776                DataType::Dictionary(key_type, value_type)
777                    if value_type.as_ref() == &DataType::Utf8 =>
778                {
779                    match key_type.as_ref() {
780                        DataType::Int8 => Ok(Arc::new(
781                            rows.iter()
782                                .map(|row| {
783                                    let s = row.get(i);
784                                    (!null_regex.is_null(s)).then_some(s)
785                                })
786                                .collect::<DictionaryArray<Int8Type>>(),
787                        ) as ArrayRef),
788                        DataType::Int16 => Ok(Arc::new(
789                            rows.iter()
790                                .map(|row| {
791                                    let s = row.get(i);
792                                    (!null_regex.is_null(s)).then_some(s)
793                                })
794                                .collect::<DictionaryArray<Int16Type>>(),
795                        ) as ArrayRef),
796                        DataType::Int32 => Ok(Arc::new(
797                            rows.iter()
798                                .map(|row| {
799                                    let s = row.get(i);
800                                    (!null_regex.is_null(s)).then_some(s)
801                                })
802                                .collect::<DictionaryArray<Int32Type>>(),
803                        ) as ArrayRef),
804                        DataType::Int64 => Ok(Arc::new(
805                            rows.iter()
806                                .map(|row| {
807                                    let s = row.get(i);
808                                    (!null_regex.is_null(s)).then_some(s)
809                                })
810                                .collect::<DictionaryArray<Int64Type>>(),
811                        ) as ArrayRef),
812                        DataType::UInt8 => Ok(Arc::new(
813                            rows.iter()
814                                .map(|row| {
815                                    let s = row.get(i);
816                                    (!null_regex.is_null(s)).then_some(s)
817                                })
818                                .collect::<DictionaryArray<UInt8Type>>(),
819                        ) as ArrayRef),
820                        DataType::UInt16 => Ok(Arc::new(
821                            rows.iter()
822                                .map(|row| {
823                                    let s = row.get(i);
824                                    (!null_regex.is_null(s)).then_some(s)
825                                })
826                                .collect::<DictionaryArray<UInt16Type>>(),
827                        ) as ArrayRef),
828                        DataType::UInt32 => Ok(Arc::new(
829                            rows.iter()
830                                .map(|row| {
831                                    let s = row.get(i);
832                                    (!null_regex.is_null(s)).then_some(s)
833                                })
834                                .collect::<DictionaryArray<UInt32Type>>(),
835                        ) as ArrayRef),
836                        DataType::UInt64 => Ok(Arc::new(
837                            rows.iter()
838                                .map(|row| {
839                                    let s = row.get(i);
840                                    (!null_regex.is_null(s)).then_some(s)
841                                })
842                                .collect::<DictionaryArray<UInt64Type>>(),
843                        ) as ArrayRef),
844                        _ => Err(ArrowError::ParseError(format!(
845                            "Unsupported dictionary key type {key_type:?}"
846                        ))),
847                    }
848                }
849                other => Err(ArrowError::ParseError(format!(
850                    "Unsupported data type {other:?}"
851                ))),
852            }
853        })
854        .collect();
855
856    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();
857
858    let projected_schema = Arc::new(match metadata {
859        None => Schema::new(projected_fields),
860        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
861    });
862
863    arrays.and_then(|arr| {
864        RecordBatch::try_new_with_options(
865            projected_schema,
866            arr,
867            &RecordBatchOptions::new()
868                .with_match_field_names(true)
869                .with_row_count(Some(rows.len())),
870        )
871    })
872}
873
/// Case-insensitively parses `string` as a boolean.
///
/// Returns `Some(true)` for any casing of "true", `Some(false)` for any
/// casing of "false", and `None` for everything else.
fn parse_bool(string: &str) -> Option<bool> {
    match string {
        s if s.eq_ignore_ascii_case("true") => Some(true),
        s if s.eq_ignore_ascii_case("false") => Some(false),
        _ => None,
    }
}
883
884// parse the column string to an Arrow Array
885fn build_decimal_array<T: DecimalType>(
886    _line_number: usize,
887    rows: &StringRecords<'_>,
888    col_idx: usize,
889    precision: u8,
890    scale: i8,
891    null_regex: &NullRegex,
892) -> Result<ArrayRef, ArrowError> {
893    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
894    for row in rows.iter() {
895        let s = row.get(col_idx);
896        if null_regex.is_null(s) {
897            // append null
898            decimal_builder.append_null();
899        } else {
900            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
901            match decimal_value {
902                Ok(v) => {
903                    decimal_builder.append_value(v);
904                }
905                Err(e) => {
906                    return Err(e);
907                }
908            }
909        }
910    }
911    Ok(Arc::new(
912        decimal_builder
913            .finish()
914            .with_precision_and_scale(precision, scale)?,
915    ))
916}
917
918// parses a specific column (col_idx) into an Arrow Array.
919fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
920    line_number: usize,
921    rows: &StringRecords<'_>,
922    col_idx: usize,
923    null_regex: &NullRegex,
924) -> Result<ArrayRef, ArrowError> {
925    rows.iter()
926        .enumerate()
927        .map(|(row_index, row)| {
928            let s = row.get(col_idx);
929            if null_regex.is_null(s) {
930                return Ok(None);
931            }
932
933            match T::parse(s) {
934                Some(e) => Ok(Some(e)),
935                None => Err(ArrowError::ParseError(format!(
936                    // TODO: we should surface the underlying error here.
937                    "Error while parsing value {} for column {} at line {}",
938                    s,
939                    col_idx,
940                    line_number + row_index
941                ))),
942            }
943        })
944        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
945        .map(|e| Arc::new(e) as ArrayRef)
946}
947
948fn build_timestamp_array<T: ArrowTimestampType>(
949    line_number: usize,
950    rows: &StringRecords<'_>,
951    col_idx: usize,
952    timezone: Option<&str>,
953    null_regex: &NullRegex,
954) -> Result<ArrayRef, ArrowError> {
955    Ok(Arc::new(match timezone {
956        Some(timezone) => {
957            let tz: Tz = timezone.parse()?;
958            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
959                .with_timezone(timezone)
960        }
961        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
962    }))
963}
964
965fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
966    line_number: usize,
967    rows: &StringRecords<'_>,
968    col_idx: usize,
969    timezone: &Tz,
970    null_regex: &NullRegex,
971) -> Result<PrimitiveArray<T>, ArrowError> {
972    rows.iter()
973        .enumerate()
974        .map(|(row_index, row)| {
975            let s = row.get(col_idx);
976            if null_regex.is_null(s) {
977                return Ok(None);
978            }
979
980            let date = string_to_datetime(timezone, s)
981                .and_then(|date| match T::UNIT {
982                    TimeUnit::Second => Ok(date.timestamp()),
983                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
984                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
985                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
986                        ArrowError::ParseError(format!(
987                            "{} would overflow 64-bit signed nanoseconds",
988                            date.to_rfc3339(),
989                        ))
990                    }),
991                })
992                .map_err(|e| {
993                    ArrowError::ParseError(format!(
994                        "Error parsing column {col_idx} at line {}: {}",
995                        line_number + row_index,
996                        e
997                    ))
998                })?;
999            Ok(Some(date))
1000        })
1001        .collect()
1002}
1003
1004// parses a specific column (col_idx) into an Arrow Array.
1005fn build_boolean_array(
1006    line_number: usize,
1007    rows: &StringRecords<'_>,
1008    col_idx: usize,
1009    null_regex: &NullRegex,
1010) -> Result<ArrayRef, ArrowError> {
1011    rows.iter()
1012        .enumerate()
1013        .map(|(row_index, row)| {
1014            let s = row.get(col_idx);
1015            if null_regex.is_null(s) {
1016                return Ok(None);
1017            }
1018            let parsed = parse_bool(s);
1019            match parsed {
1020                Some(e) => Ok(Some(e)),
1021                None => Err(ArrowError::ParseError(format!(
1022                    // TODO: we should surface the underlying error here.
1023                    "Error while parsing value {} for column {} at line {}",
1024                    s,
1025                    col_idx,
1026                    line_number + row_index
1027                ))),
1028            }
1029        })
1030        .collect::<Result<BooleanArray, _>>()
1031        .map(|e| Arc::new(e) as ArrayRef)
1032}
1033
1034/// CSV file reader builder
/// CSV file reader builder
///
/// Construct via [`ReaderBuilder::new`], configure with the `with_*` setters,
/// then convert into a reader with [`ReaderBuilder::build`] /
/// [`ReaderBuilder::build_buffered`] or into a [`Decoder`] with
/// [`ReaderBuilder::build_decoder`].
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load each time)
    ///
    /// The default batch size when using the `ReaderBuilder` is 1024 records
    batch_size: usize,
    /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF.
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}
1050
1051impl ReaderBuilder {
1052    /// Create a new builder for configuring CSV parsing options.
1053    ///
1054    /// To convert a builder into a reader, call `ReaderBuilder::build`
1055    ///
1056    /// # Example
1057    ///
1058    /// ```
1059    /// # use arrow_csv::{Reader, ReaderBuilder};
1060    /// # use std::fs::File;
1061    /// # use std::io::Seek;
1062    /// # use std::sync::Arc;
1063    /// # use arrow_csv::reader::Format;
1064    /// #
1065    /// let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1066    /// // Infer the schema with the first 100 records
1067    /// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap();
1068    /// file.rewind().unwrap();
1069    ///
1070    /// // create a builder
1071    /// ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
1072    /// ```
1073    pub fn new(schema: SchemaRef) -> ReaderBuilder {
1074        Self {
1075            schema,
1076            format: Format::default(),
1077            batch_size: 1024,
1078            bounds: None,
1079            projection: None,
1080        }
1081    }
1082
1083    /// Set whether the CSV file has a header
1084    pub fn with_header(mut self, has_header: bool) -> Self {
1085        self.format.header = has_header;
1086        self
1087    }
1088
1089    /// Overrides the [Format] of this [ReaderBuilder]
1090    pub fn with_format(mut self, format: Format) -> Self {
1091        self.format = format;
1092        self
1093    }
1094
1095    /// Set the CSV file's column delimiter as a byte character
1096    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
1097        self.format.delimiter = Some(delimiter);
1098        self
1099    }
1100
1101    /// Set the given character as the CSV file's escape character
1102    pub fn with_escape(mut self, escape: u8) -> Self {
1103        self.format.escape = Some(escape);
1104        self
1105    }
1106
1107    /// Set the given character as the CSV file's quote character, by default it is double quote
1108    pub fn with_quote(mut self, quote: u8) -> Self {
1109        self.format.quote = Some(quote);
1110        self
1111    }
1112
1113    /// Provide a custom terminator character, defaults to CRLF
1114    pub fn with_terminator(mut self, terminator: u8) -> Self {
1115        self.format.terminator = Some(terminator);
1116        self
1117    }
1118
1119    /// Provide a comment character, lines starting with this character will be ignored
1120    pub fn with_comment(mut self, comment: u8) -> Self {
1121        self.format.comment = Some(comment);
1122        self
1123    }
1124
1125    /// Provide a regex to match null values, defaults to `^$`
1126    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
1127        self.format.null_regex = NullRegex(Some(null_regex));
1128        self
1129    }
1130
1131    /// Set the batch size (number of records to load at one time)
1132    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
1133        self.batch_size = batch_size;
1134        self
1135    }
1136
1137    /// Set the bounds over which to scan the reader.
1138    /// `start` and `end` are line numbers.
1139    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
1140        self.bounds = Some((start, end));
1141        self
1142    }
1143
1144    /// Set the reader's column projection
1145    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
1146        self.projection = Some(projection);
1147        self
1148    }
1149
1150    /// Whether to allow truncated rows when parsing.
1151    ///
1152    /// By default this is set to `false` and will error if the CSV rows have different lengths.
1153    /// When set to true then it will allow records with less than the expected number of columns
1154    /// and fill the missing columns with nulls. If the record's schema is not nullable, then it
1155    /// will still return an error.
1156    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
1157        self.format.truncated_rows = allow;
1158        self
1159    }
1160
1161    /// Create a new `Reader` from a non-buffered reader
1162    ///
1163    /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
1164    /// buffering, as internally this method wraps `reader` in [`std::io::BufReader`]
1165    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
1166        self.build_buffered(StdBufReader::new(reader))
1167    }
1168
1169    /// Create a new `BufReader` from a buffered reader
1170    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
1171        Ok(BufReader {
1172            reader,
1173            decoder: self.build_decoder(),
1174        })
1175    }
1176
1177    /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
1178    pub fn build_decoder(self) -> Decoder {
1179        let delimiter = self.format.build_parser();
1180        let record_decoder = RecordDecoder::new(
1181            delimiter,
1182            self.schema.fields().len(),
1183            self.format.truncated_rows,
1184        );
1185
1186        let header = self.format.header as usize;
1187
1188        let (start, end) = match self.bounds {
1189            Some((start, end)) => (start + header, end + header),
1190            None => (header, usize::MAX),
1191        };
1192
1193        Decoder {
1194            schema: self.schema,
1195            to_skip: start,
1196            record_decoder,
1197            line_number: start,
1198            end,
1199            projection: self.projection,
1200            batch_size: self.batch_size,
1201            null_regex: self.format.null_regex,
1202        }
1203    }
1204}
1205
1206#[cfg(test)]
1207mod tests {
1208    use super::*;
1209
1210    use std::io::{Cursor, Seek, SeekFrom, Write};
1211    use tempfile::NamedTempFile;
1212
1213    use arrow_array::cast::AsArray;
1214
    // Reads uk_cities.csv with an explicit schema and spot-checks batch shape and values.
    #[test]
    fn test_csv() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch.column(1).as_primitive::<Float64Type>();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array (ListArray<u8>)
        let city = batch.column(0).as_string::<i32>();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1239
    // Verifies schema-level metadata survives into the schema of the produced batches.
    #[test]
    fn test_csv_schema_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("foo".to_owned(), "bar".to_owned());
        let schema = Arc::new(Schema::new_with_metadata(
            vec![
                Field::new("city", DataType::Utf8, false),
                Field::new("lat", DataType::Float64, false),
                Field::new("lng", DataType::Float64, false),
            ],
            metadata.clone(),
        ));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        assert_eq!(&metadata, batch.schema().metadata());
    }
1263
    // Reads decimal columns as Decimal128/Decimal256 and checks the scaled string rendering.
    #[test]
    fn test_csv_reader_with_decimal() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal128(38, 6), false),
            Field::new("lng", DataType::Decimal256(76, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal256Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }
1311
    // Chains two fixture files (one with headers) plus a blank line and reads
    // them as a single CSV stream: 37 + 37 = 74 rows.
    #[test]
    fn test_csv_from_buf_reader() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);

        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
        let both_files = file_with_headers
            .chain(Cursor::new("\n".to_string()))
            .chain(file_without_headers);
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .build(both_files)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(74, batch.num_rows());
        assert_eq!(3, batch.num_columns());
    }
1333
    // Infers the schema from a headered file; inferred fields are nullable by default.
    #[test]
    fn test_csv_with_schema_inference() {
        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();

        let (schema, _) = Format::default()
            .with_header(true)
            .infer_schema(&mut file, None)
            .unwrap();

        file.rewind().unwrap();
        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);

        let mut csv = builder.build(file).unwrap();
        let expected_schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, true),
            Field::new("lat", DataType::Float64, true),
            Field::new("lng", DataType::Float64, true),
        ]);
        assert_eq!(Arc::new(expected_schema), csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array (ListArray<u8>)
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1374
    // Without headers, inferred columns are named `column_{N}` starting at 1.
    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        // csv field names should be 'column_{number}'
        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array (ListArray<u8>)
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1413
    // `with_bounds` limits reading to the requested line range; rows outside
    // the bounds are not materialized.
    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        // Set the bounds to the lines 0, 1 and 2.
        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        // access data from a string array (ListArray<u8>)
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // The value on line 0 is within the bounds
        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // The value on line 13 is outside of the bounds. Therefore
        // the call to .value() will panic.
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }
1442
    // Projection selects a subset of columns; both the reader schema and the
    // batch schema reflect only the projected fields.
    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }
1468
    // A dictionary-encoded string column reads correctly and casts back to Utf8.
    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }
1501
    // Nullable dictionary columns decode correctly for every supported key type.
    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }
    // Empty fields in a nullable column are decoded as nulls (default `^$` null regex).
    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1560
    // Nulls appearing at the start of a column (before any value) are handled
    // correctly, including an all-null `DataType::Null` column.
    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1585
    // Schema inference on a file with leading nulls yields nullable inferred
    // types (and `Null` for an entirely-empty column).
    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1616
    // A custom null regex ("^nil$") marks matching fields as null instead of
    // the default empty-string rule.
    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // "nil"s should be NULL
        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }
1645
    // Inference over various_types.csv (pipe-delimited): checks inferred
    // types, field names, nullability flags, and null positions.
    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(7, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1703
    #[test]
    fn test_custom_nulls_with_inference() {
        // A custom null regex ("nil") must be honoured during schema
        // inference as well as during decoding.
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        // Inference consumed the file; rewind before decoding
        file.rewind().unwrap();

        // "nil" cells must not force the columns to Utf8
        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }
1739
1740    #[test]
1741    fn test_scientific_notation_with_inference() {
1742        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
1743        let format = Format::default().with_header(false).with_delimiter(b',');
1744
1745        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1746        file.rewind().unwrap();
1747
1748        let builder = ReaderBuilder::new(Arc::new(schema))
1749            .with_format(format)
1750            .with_batch_size(512)
1751            .with_projection(vec![0, 1]);
1752
1753        let mut csv = builder.build(file).unwrap();
1754        let batch = csv.next().unwrap().unwrap();
1755
1756        let schema = batch.schema();
1757
1758        assert_eq!(&DataType::Float64, schema.field(0).data_type());
1759    }
1760
    #[test]
    fn test_parse_invalid_csv() {
        // A value that cannot be parsed to the declared type must surface as
        // a ParseError naming the value, column and line.
        let file = File::open("test/data/various_types_invalid.csv").unwrap();

        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        match csv.next() {
            Some(e) => match e {
                // Exact Debug formatting of the error is part of the contract
                Err(e) => assert_eq!(
                    "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")",
                    format!("{e:?}")
                ),
                Ok(_) => panic!("should have failed"),
            },
            None => panic!("should have failed"),
        }
    }
1790
1791    /// Infer the data type of a record
1792    fn infer_field_schema(string: &str) -> DataType {
1793        let mut v = InferredDataType::default();
1794        v.update(string);
1795        v.get()
1796    }
1797
1798    #[test]
1799    fn test_infer_field_schema() {
1800        assert_eq!(infer_field_schema("A"), DataType::Utf8);
1801        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
1802        assert_eq!(infer_field_schema("10"), DataType::Int64);
1803        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
1804        assert_eq!(infer_field_schema(".2"), DataType::Float64);
1805        assert_eq!(infer_field_schema("2."), DataType::Float64);
1806        assert_eq!(infer_field_schema("true"), DataType::Boolean);
1807        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
1808        assert_eq!(infer_field_schema("false"), DataType::Boolean);
1809        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
1810        assert_eq!(
1811            infer_field_schema("2020-11-08T14:20:01"),
1812            DataType::Timestamp(TimeUnit::Second, None)
1813        );
1814        assert_eq!(
1815            infer_field_schema("2020-11-08 14:20:01"),
1816            DataType::Timestamp(TimeUnit::Second, None)
1817        );
1818        assert_eq!(
1819            infer_field_schema("2020-11-08 14:20:01"),
1820            DataType::Timestamp(TimeUnit::Second, None)
1821        );
1822        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
1823        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
1824        assert_eq!(
1825            infer_field_schema("2021-12-19 13:12:30.921"),
1826            DataType::Timestamp(TimeUnit::Millisecond, None)
1827        );
1828        assert_eq!(
1829            infer_field_schema("2021-12-19T13:12:30.123456789"),
1830            DataType::Timestamp(TimeUnit::Nanosecond, None)
1831        );
1832        assert_eq!(infer_field_schema("–9223372036854775809"), DataType::Utf8);
1833        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
1834    }
1835
1836    #[test]
1837    fn parse_date32() {
1838        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
1839        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
1840        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
1841    }
1842
    #[test]
    fn parse_time() {
        // Times parse to units elapsed since midnight, honouring
        // case-insensitive AM/PM designators.
        // 12:10:01 AM == 00:10:01 == 601 seconds after midnight
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        // 2:10:01 PM == 14:10:01 == 51_001 seconds after midnight
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }
1859
    #[test]
    fn parse_date64() {
        // Date64 values are milliseconds since the UNIX epoch
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        // Fractional seconds are preserved at millisecond precision
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        // Pre-epoch dates are negative
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        // parse_formatted accepts an explicit chrono format string
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        // A +00:30 offset shifts the parsed instant 30 minutes earlier
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }
1884
    /// Decodes three timestamp strings (naive, `Z`, and `+02:00`) with a
    /// schema of `Timestamp(T::UNIT, timezone)` and asserts the decoded
    /// values and resulting column type.
    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        // The decoder should consume all input in a single call
        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        // An empty slice signals end of input
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }
1914
    #[test]
    fn test_parse_timestamp() {
        // Expected values are epoch offsets in the respective unit: naive
        // strings are interpreted in the column's timezone, "Z" is UTC, and
        // explicit offsets are applied as written.
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        // A -05:00 column timezone shifts only the naive (first) value
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }
1936
    #[test]
    fn test_infer_schema_from_multiple_files() {
        // Schema inference across several files must merge column types and
        // nullability, skip empty files, and respect the max-records limit.
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        // reading csv2 will set c2 to optional
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        // reading csv4 will set c3 to optional
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // only csv1 and csv2 should be read
            true,
        )
        .unwrap();

        // c4 only appears in csv2 but still contributes a column
        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }
1978
1979    #[test]
1980    fn test_bounded() {
1981        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
1982        let data = [
1983            vec!["0"],
1984            vec!["1"],
1985            vec!["2"],
1986            vec!["3"],
1987            vec!["4"],
1988            vec!["5"],
1989            vec!["6"],
1990        ];
1991
1992        let data = data
1993            .iter()
1994            .map(|x| x.join(","))
1995            .collect::<Vec<_>>()
1996            .join("\n");
1997        let data = data.as_bytes();
1998
1999        let reader = std::io::Cursor::new(data);
2000
2001        let mut csv = ReaderBuilder::new(Arc::new(schema))
2002            .with_batch_size(2)
2003            .with_projection(vec![0])
2004            .with_bounds(2, 6)
2005            .build_buffered(reader)
2006            .unwrap();
2007
2008        let batch = csv.next().unwrap().unwrap();
2009        let a = batch.column(0);
2010        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2011        assert_eq!(a, &UInt32Array::from(vec![2, 3]));
2012
2013        let batch = csv.next().unwrap().unwrap();
2014        let a = batch.column(0);
2015        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2016        assert_eq!(a, &UInt32Array::from(vec![4, 5]));
2017
2018        assert!(csv.next().is_none());
2019    }
2020
2021    #[test]
2022    fn test_empty_projection() {
2023        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
2024        let data = [vec!["0"], vec!["1"]];
2025
2026        let data = data
2027            .iter()
2028            .map(|x| x.join(","))
2029            .collect::<Vec<_>>()
2030            .join("\n");
2031
2032        let mut csv = ReaderBuilder::new(Arc::new(schema))
2033            .with_batch_size(2)
2034            .with_projection(vec![])
2035            .build_buffered(Cursor::new(data.as_bytes()))
2036            .unwrap();
2037
2038        let batch = csv.next().unwrap().unwrap();
2039        assert_eq!(batch.columns().len(), 0);
2040        assert_eq!(batch.num_rows(), 2);
2041
2042        assert!(csv.next().is_none());
2043    }
2044
2045    #[test]
2046    fn test_parsing_bool() {
2047        // Encode the expected behavior of boolean parsing
2048        assert_eq!(Some(true), parse_bool("true"));
2049        assert_eq!(Some(true), parse_bool("tRUe"));
2050        assert_eq!(Some(true), parse_bool("True"));
2051        assert_eq!(Some(true), parse_bool("TRUE"));
2052        assert_eq!(None, parse_bool("t"));
2053        assert_eq!(None, parse_bool("T"));
2054        assert_eq!(None, parse_bool(""));
2055
2056        assert_eq!(Some(false), parse_bool("false"));
2057        assert_eq!(Some(false), parse_bool("fALse"));
2058        assert_eq!(Some(false), parse_bool("False"));
2059        assert_eq!(Some(false), parse_bool("FALSE"));
2060        assert_eq!(None, parse_bool("f"));
2061        assert_eq!(None, parse_bool("F"));
2062        assert_eq!(None, parse_bool(""));
2063    }
2064
2065    #[test]
2066    fn test_parsing_float() {
2067        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
2068        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
2069        assert_eq!(Some(12.0), Float64Type::parse("12"));
2070        assert_eq!(Some(0.0), Float64Type::parse("0"));
2071        assert_eq!(Some(2.0), Float64Type::parse("2."));
2072        assert_eq!(Some(0.2), Float64Type::parse(".2"));
2073        assert!(Float64Type::parse("nan").unwrap().is_nan());
2074        assert!(Float64Type::parse("NaN").unwrap().is_nan());
2075        assert!(Float64Type::parse("inf").unwrap().is_infinite());
2076        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
2077        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
2078        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
2079        assert_eq!(None, Float64Type::parse(""));
2080        assert_eq!(None, Float64Type::parse("dd"));
2081        assert_eq!(None, Float64Type::parse("12.34.56"));
2082    }
2083
2084    #[test]
2085    fn test_non_std_quote() {
2086        let schema = Schema::new(vec![
2087            Field::new("text1", DataType::Utf8, false),
2088            Field::new("text2", DataType::Utf8, false),
2089        ]);
2090        let builder = ReaderBuilder::new(Arc::new(schema))
2091            .with_header(false)
2092            .with_quote(b'~'); // default is ", change to ~
2093
2094        let mut csv_text = Vec::new();
2095        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2096        for index in 0..10 {
2097            let text1 = format!("id{index:}");
2098            let text2 = format!("value{index:}");
2099            csv_writer
2100                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
2101                .unwrap();
2102        }
2103        let mut csv_reader = std::io::Cursor::new(&csv_text);
2104        let mut reader = builder.build(&mut csv_reader).unwrap();
2105        let batch = reader.next().unwrap().unwrap();
2106        let col0 = batch.column(0);
2107        assert_eq!(col0.len(), 10);
2108        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2109        assert_eq!(col0_arr.value(0), "id0");
2110        let col1 = batch.column(1);
2111        assert_eq!(col1.len(), 10);
2112        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2113        assert_eq!(col1_arr.value(5), "value5");
2114    }
2115
2116    #[test]
2117    fn test_non_std_escape() {
2118        let schema = Schema::new(vec![
2119            Field::new("text1", DataType::Utf8, false),
2120            Field::new("text2", DataType::Utf8, false),
2121        ]);
2122        let builder = ReaderBuilder::new(Arc::new(schema))
2123            .with_header(false)
2124            .with_escape(b'\\'); // default is None, change to \
2125
2126        let mut csv_text = Vec::new();
2127        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2128        for index in 0..10 {
2129            let text1 = format!("id{index:}");
2130            let text2 = format!("value\\\"{index:}");
2131            csv_writer
2132                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
2133                .unwrap();
2134        }
2135        let mut csv_reader = std::io::Cursor::new(&csv_text);
2136        let mut reader = builder.build(&mut csv_reader).unwrap();
2137        let batch = reader.next().unwrap().unwrap();
2138        let col0 = batch.column(0);
2139        assert_eq!(col0.len(), 10);
2140        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2141        assert_eq!(col0_arr.value(0), "id0");
2142        let col1 = batch.column(1);
2143        assert_eq!(col1.len(), 10);
2144        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2145        assert_eq!(col1_arr.value(5), "value\"5");
2146    }
2147
2148    #[test]
2149    fn test_non_std_terminator() {
2150        let schema = Schema::new(vec![
2151            Field::new("text1", DataType::Utf8, false),
2152            Field::new("text2", DataType::Utf8, false),
2153        ]);
2154        let builder = ReaderBuilder::new(Arc::new(schema))
2155            .with_header(false)
2156            .with_terminator(b'\n'); // default is CRLF, change to LF
2157
2158        let mut csv_text = Vec::new();
2159        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2160        for index in 0..10 {
2161            let text1 = format!("id{index:}");
2162            let text2 = format!("value{index:}");
2163            csv_writer
2164                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
2165                .unwrap();
2166        }
2167        let mut csv_reader = std::io::Cursor::new(&csv_text);
2168        let mut reader = builder.build(&mut csv_reader).unwrap();
2169        let batch = reader.next().unwrap().unwrap();
2170        let col0 = batch.column(0);
2171        assert_eq!(col0.len(), 10);
2172        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2173        assert_eq!(col0_arr.value(0), "id0");
2174        let col1 = batch.column(1);
2175        assert_eq!(col1.len(), 10);
2176        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2177        assert_eq!(col1_arr.value(5), "value5");
2178    }
2179
    #[test]
    fn test_header_bounds() {
        // Five identical rows; each case is (bounds, has_header, expected row
        // count). When a header is requested, it is skipped before the row
        // bounds are applied.
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        // NOTE(review): both fields share the name "a"; names are irrelevant
        // here since only row counts are asserted
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            // Include the case index so a failure identifies the case
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }
2210
    #[test]
    fn test_null_boolean() {
        // Empty fields in a nullable Boolean column decode as nulls;
        // parsing of the remaining values is case-insensitive.
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        // Column 0: true, False, <empty>, False
        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        // Column 1: false, True, True, <empty>
        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }
2243
    #[test]
    fn test_truncated_rows() {
        // Row "4,5" is short one field; the blank line is skipped entirely
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        // With truncated rows allowed, short rows are padded with nulls
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        // Empty rows are skipped by the underlying csv parser
        assert_eq!(batch.num_rows(), 3);

        // With truncated rows disallowed, the short row is a CSV error
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }
2277
    #[test]
    fn test_truncated_rows_csv() {
        // Decode a fixture containing rows of varying field counts; missing
        // trailing fields must appear as nulls in the nullable columns.
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        // Row 2 of the fixture is entirely empty of values
        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        // Date32 values are days since the UNIX epoch
        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }
2347
    #[test]
    fn test_truncated_rows_not_nullable_error() {
        // Truncated rows pad missing fields with nulls, which must be
        // rejected when the target columns are declared non-nullable.
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }
2370
    #[test]
    fn test_buffered() {
        // Reading through build_buffered must produce batches identical to
        // the unbuffered reader, for every combination of batch size and
        // (deliberately tiny) buffer capacity.
        // Each case is (path, has_header, expected total rows).
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 7),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    // Baseline: unbuffered read of the same file
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    // Small capacities force the decoder to handle partial reads
                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }
2415
    /// Asserts that decoding `csv` fails with the `expected` error message,
    /// for both `Utf8` and `Utf8View` column types.
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            // A two-byte buffer forces the decoder to see the input in chunks
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }
2439
    #[test]
    fn test_invalid_utf8() {
        // The error must identify the 1-based line and field containing the
        // invalid UTF-8 byte (0xFF in each input below).
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }
2462
    /// Wraps a reader and records every call to [`BufRead::fill_buf`],
    /// allowing tests to observe the reader's buffering behaviour.
    struct InstrumentedRead<R> {
        r: R,
        // Total number of `fill_buf` calls observed
        fill_count: usize,
        // Buffered byte count returned by each `fill_buf` call, in order
        fill_sizes: Vec<usize>,
    }
2468
2469    impl<R> InstrumentedRead<R> {
2470        fn new(r: R) -> Self {
2471            Self {
2472                r,
2473                fill_count: 0,
2474                fill_sizes: vec![],
2475            }
2476        }
2477    }
2478
    // Seeking is delegated untouched; only `fill_buf` is instrumented.
    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }
2484
    // Reads are delegated untouched; only `fill_buf` is instrumented.
    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }
2490
    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        /// Counts the call, delegates to the inner reader, and records how
        /// many bytes the inner buffer made available.
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            // Count before delegating so failed fills are still counted
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }
2503
2504    #[test]
2505    fn test_io() {
2506        let schema = Arc::new(Schema::new(vec![
2507            Field::new("a", DataType::Utf8, false),
2508            Field::new("b", DataType::Utf8, false),
2509        ]));
2510        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
2511        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
2512        let reader = ReaderBuilder::new(schema)
2513            .with_batch_size(3)
2514            .build_buffered(&mut read)
2515            .unwrap();
2516
2517        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2518        assert_eq!(batches.len(), 2);
2519        assert_eq!(batches[0].num_rows(), 3);
2520        assert_eq!(batches[1].num_rows(), 1);
2521
2522        // Expect 4 calls to fill_buf
2523        // 1. Read first 3 rows
2524        // 2. Read final row
2525        // 3. Delimit and flush final row
2526        // 4. Iterator finished
2527        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
2528        assert_eq!(read.fill_count, 4);
2529    }
2530
2531    #[test]
2532    fn test_inference() {
2533        let cases: &[(&[&str], DataType)] = &[
2534            (&[], DataType::Null),
2535            (&["false", "12"], DataType::Utf8),
2536            (&["12", "cupcakes"], DataType::Utf8),
2537            (&["12", "12.4"], DataType::Float64),
2538            (&["14050", "24332"], DataType::Int64),
2539            (&["14050.0", "true"], DataType::Utf8),
2540            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
2541            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
2542            (
2543                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2544                DataType::Timestamp(TimeUnit::Second, None),
2545            ),
2546            (&["2020-03-19", "2020-03-20"], DataType::Date32),
2547            (
2548                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2549                DataType::Timestamp(TimeUnit::Second, None),
2550            ),
2551            (
2552                &[
2553                    "2020-03-19",
2554                    "2020-03-19 02:00:00",
2555                    "2020-03-19 00:00:00.000",
2556                ],
2557                DataType::Timestamp(TimeUnit::Millisecond, None),
2558            ),
2559            (
2560                &[
2561                    "2020-03-19",
2562                    "2020-03-19 02:00:00",
2563                    "2020-03-19 00:00:00.000000",
2564                ],
2565                DataType::Timestamp(TimeUnit::Microsecond, None),
2566            ),
2567            (
2568                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
2569                DataType::Timestamp(TimeUnit::Second, None),
2570            ),
2571            (
2572                &[
2573                    "2020-03-19",
2574                    "2020-03-19 02:00:00+02:00",
2575                    "2020-03-19 02:00:00Z",
2576                    "2020-03-19 02:00:00.12Z",
2577                ],
2578                DataType::Timestamp(TimeUnit::Millisecond, None),
2579            ),
2580            (
2581                &[
2582                    "2020-03-19",
2583                    "2020-03-19 02:00:00.000000000",
2584                    "2020-03-19 00:00:00.000000",
2585                ],
2586                DataType::Timestamp(TimeUnit::Nanosecond, None),
2587            ),
2588        ];
2589
2590        for (values, expected) in cases {
2591            let mut t = InferredDataType::default();
2592            for v in *values {
2593                t.update(v)
2594            }
2595            assert_eq!(&t.get(), expected, "{values:?}")
2596        }
2597    }
2598
2599    #[test]
2600    fn test_record_length_mismatch() {
2601        let csv = "\
2602        a,b,c\n\
2603        1,2,3\n\
2604        4,5\n\
2605        6,7,8";
2606        let mut read = Cursor::new(csv.as_bytes());
2607        let result = Format::default()
2608            .with_header(true)
2609            .infer_schema(&mut read, None);
2610        assert!(result.is_err());
2611        // Include line number in the error message to help locate and fix the issue
2612        assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 2 records, found 3 records at line 3");
2613    }
2614
2615    #[test]
2616    fn test_comment() {
2617        let schema = Schema::new(vec![
2618            Field::new("a", DataType::Int8, false),
2619            Field::new("b", DataType::Int8, false),
2620        ]);
2621
2622        let csv = "# comment1 \n1,2\n#comment2\n11,22";
2623        let mut read = Cursor::new(csv.as_bytes());
2624        let reader = ReaderBuilder::new(Arc::new(schema))
2625            .with_comment(b'#')
2626            .build(&mut read)
2627            .unwrap();
2628
2629        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2630        assert_eq!(batches.len(), 1);
2631        let b = batches.first().unwrap();
2632        assert_eq!(b.num_columns(), 2);
2633        assert_eq!(
2634            b.column(0)
2635                .as_any()
2636                .downcast_ref::<Int8Array>()
2637                .unwrap()
2638                .values(),
2639            &vec![1, 11]
2640        );
2641        assert_eq!(
2642            b.column(1)
2643                .as_any()
2644                .downcast_ref::<Int8Array>()
2645                .unwrap()
2646                .values(),
2647            &vec![2, 22]
2648        );
2649    }
2650
2651    #[test]
2652    fn test_parse_string_view_single_column() {
2653        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
2654        let schema = Arc::new(Schema::new(vec![Field::new(
2655            "c1",
2656            DataType::Utf8View,
2657            true,
2658        )]));
2659
2660        let mut decoder = ReaderBuilder::new(schema).build_decoder();
2661
2662        let decoded = decoder.decode(csv.as_bytes()).unwrap();
2663        assert_eq!(decoded, csv.len());
2664        decoder.decode(&[]).unwrap();
2665
2666        let batch = decoder.flush().unwrap().unwrap();
2667        assert_eq!(batch.num_columns(), 1);
2668        assert_eq!(batch.num_rows(), 3);
2669        let col = batch.column(0).as_string_view();
2670        assert_eq!(col.data_type(), &DataType::Utf8View);
2671        assert_eq!(col.value(0), "foo");
2672        assert_eq!(col.value(1), "something_cannot_be_inlined");
2673        assert_eq!(col.value(2), "foobar");
2674    }
2675
2676    #[test]
2677    fn test_parse_string_view_multi_column() {
2678        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
2679        let schema = Arc::new(Schema::new(vec![
2680            Field::new("c1", DataType::Utf8View, true),
2681            Field::new("c2", DataType::Utf8View, true),
2682        ]));
2683
2684        let mut decoder = ReaderBuilder::new(schema).build_decoder();
2685
2686        let decoded = decoder.decode(csv.as_bytes()).unwrap();
2687        assert_eq!(decoded, csv.len());
2688        decoder.decode(&[]).unwrap();
2689
2690        let batch = decoder.flush().unwrap().unwrap();
2691        assert_eq!(batch.num_columns(), 2);
2692        assert_eq!(batch.num_rows(), 3);
2693        let c1 = batch.column(0).as_string_view();
2694        let c2 = batch.column(1).as_string_view();
2695        assert_eq!(c1.data_type(), &DataType::Utf8View);
2696        assert_eq!(c2.data_type(), &DataType::Utf8View);
2697
2698        assert!(!c1.is_null(0));
2699        assert!(c1.is_null(1));
2700        assert!(!c1.is_null(2));
2701        assert_eq!(c1.value(0), "foo");
2702        assert_eq!(c1.value(2), "foobarfoobar");
2703
2704        assert!(c2.is_null(0));
2705        assert!(!c2.is_null(1));
2706        assert!(!c2.is_null(2));
2707        assert_eq!(c2.value(1), "something_cannot_be_inlined");
2708        assert_eq!(c2.value(2), "bar");
2709    }
2710}