1use arrow_array::*;
67use arrow_cast::display::*;
68use arrow_schema::*;
69use csv::ByteRecord;
70use std::io::Write;
71
72use crate::map_csv_error;
/// Text emitted for null values when no custom null representation has been configured.
const DEFAULT_NULL_VALUE: &str = "";
74
75#[derive(Debug)]
77pub struct Writer<W: Write> {
78 writer: csv::Writer<W>,
80 has_headers: bool,
82 date_format: Option<String>,
84 datetime_format: Option<String>,
86 timestamp_format: Option<String>,
88 timestamp_tz_format: Option<String>,
90 time_format: Option<String>,
92 beginning: bool,
94 null_value: Option<String>,
96}
97
98impl<W: Write> Writer<W> {
99 pub fn new(writer: W) -> Self {
101 let delimiter = b',';
102 WriterBuilder::new().with_delimiter(delimiter).build(writer)
103 }
104
105 pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
107 let num_columns = batch.num_columns();
108 if self.beginning {
109 if self.has_headers {
110 let mut headers: Vec<String> = Vec::with_capacity(num_columns);
111 batch
112 .schema()
113 .fields()
114 .iter()
115 .for_each(|field| headers.push(field.name().to_string()));
116 self.writer
117 .write_record(&headers[..])
118 .map_err(map_csv_error)?;
119 }
120 self.beginning = false;
121 }
122
123 let options = FormatOptions::default()
124 .with_null(self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE))
125 .with_date_format(self.date_format.as_deref())
126 .with_datetime_format(self.datetime_format.as_deref())
127 .with_timestamp_format(self.timestamp_format.as_deref())
128 .with_timestamp_tz_format(self.timestamp_tz_format.as_deref())
129 .with_time_format(self.time_format.as_deref());
130
131 let converters = batch
132 .columns()
133 .iter()
134 .map(|a| {
135 if a.data_type().is_nested() {
136 Err(ArrowError::CsvError(format!(
137 "Nested type {} is not supported in CSV",
138 a.data_type()
139 )))
140 } else {
141 ArrayFormatter::try_new(a.as_ref(), &options)
142 }
143 })
144 .collect::<Result<Vec<_>, ArrowError>>()?;
145
146 let mut buffer = String::with_capacity(1024);
147 let mut byte_record = ByteRecord::with_capacity(1024, converters.len());
148
149 for row_idx in 0..batch.num_rows() {
150 byte_record.clear();
151 for (col_idx, converter) in converters.iter().enumerate() {
152 buffer.clear();
153 converter.value(row_idx).write(&mut buffer).map_err(|e| {
154 ArrowError::CsvError(format!(
155 "Error processing row {}, col {}: {e}",
156 row_idx + 1,
157 col_idx + 1
158 ))
159 })?;
160 byte_record.push_field(buffer.as_bytes());
161 }
162
163 self.writer
164 .write_byte_record(&byte_record)
165 .map_err(map_csv_error)?;
166 }
167 self.writer.flush()?;
168
169 Ok(())
170 }
171
172 pub fn into_inner(self) -> W {
174 self.writer.into_inner().unwrap()
176 }
177}
178
179impl<W: Write> RecordBatchWriter for Writer<W> {
180 fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
181 self.write(batch)
182 }
183
184 fn close(self) -> Result<(), ArrowError> {
185 Ok(())
186 }
187}
188
/// Collects the configuration for a CSV `Writer` before it is constructed.
#[derive(Clone, Debug)]
pub struct WriterBuilder {
    /// Byte that separates fields within a record
    delimiter: u8,
    /// Whether a header record with the column names is written
    has_header: bool,
    /// Byte used to quote fields
    quote: u8,
    /// Byte used as the escape character
    escape: u8,
    /// Flag forwarded to the CSV writer's double-quote setting
    double_quote: bool,
    /// Optional custom format string for date values
    date_format: Option<String>,
    /// Optional custom format string for datetime values
    datetime_format: Option<String>,
    /// Optional custom format string for timestamp values
    timestamp_format: Option<String>,
    /// Optional custom format string for timestamp values that carry a time zone
    timestamp_tz_format: Option<String>,
    /// Optional custom format string for time values
    time_format: Option<String>,
    /// Optional custom text written for null values
    null_value: Option<String>,
}
215
216impl Default for WriterBuilder {
217 fn default() -> Self {
218 WriterBuilder {
219 delimiter: b',',
220 has_header: true,
221 quote: b'"',
222 escape: b'\\',
223 double_quote: true,
224 date_format: None,
225 datetime_format: None,
226 timestamp_format: None,
227 timestamp_tz_format: None,
228 time_format: None,
229 null_value: None,
230 }
231 }
232}
233
234impl WriterBuilder {
235 pub fn new() -> Self {
256 Self::default()
257 }
258
259 pub fn with_header(mut self, header: bool) -> Self {
261 self.has_header = header;
262 self
263 }
264
265 pub fn header(&self) -> bool {
267 self.has_header
268 }
269
270 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
272 self.delimiter = delimiter;
273 self
274 }
275
276 pub fn delimiter(&self) -> u8 {
278 self.delimiter
279 }
280
281 pub fn with_quote(mut self, quote: u8) -> Self {
283 self.quote = quote;
284 self
285 }
286
287 pub fn quote(&self) -> u8 {
289 self.quote
290 }
291
292 pub fn with_escape(mut self, escape: u8) -> Self {
300 self.escape = escape;
301 self
302 }
303
304 pub fn escape(&self) -> u8 {
306 self.escape
307 }
308
309 pub fn with_double_quote(mut self, double_quote: bool) -> Self {
317 self.double_quote = double_quote;
318 self
319 }
320
321 pub fn double_quote(&self) -> bool {
323 self.double_quote
324 }
325
326 pub fn with_date_format(mut self, format: String) -> Self {
328 self.date_format = Some(format);
329 self
330 }
331
332 pub fn date_format(&self) -> Option<&str> {
334 self.date_format.as_deref()
335 }
336
337 pub fn with_datetime_format(mut self, format: String) -> Self {
339 self.datetime_format = Some(format);
340 self
341 }
342
343 pub fn datetime_format(&self) -> Option<&str> {
345 self.datetime_format.as_deref()
346 }
347
348 pub fn with_time_format(mut self, format: String) -> Self {
350 self.time_format = Some(format);
351 self
352 }
353
354 pub fn time_format(&self) -> Option<&str> {
356 self.time_format.as_deref()
357 }
358
359 pub fn with_timestamp_format(mut self, format: String) -> Self {
361 self.timestamp_format = Some(format);
362 self
363 }
364
365 pub fn timestamp_format(&self) -> Option<&str> {
367 self.timestamp_format.as_deref()
368 }
369
370 pub fn with_timestamp_tz_format(mut self, tz_format: String) -> Self {
372 self.timestamp_tz_format = Some(tz_format);
373 self
374 }
375
376 pub fn timestamp_tz_format(&self) -> Option<&str> {
378 self.timestamp_tz_format.as_deref()
379 }
380
381 pub fn with_null(mut self, null_value: String) -> Self {
383 self.null_value = Some(null_value);
384 self
385 }
386
387 pub fn null(&self) -> &str {
389 self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE)
390 }
391
392 pub fn build<W: Write>(self, writer: W) -> Writer<W> {
394 let mut builder = csv::WriterBuilder::new();
395 let writer = builder
396 .delimiter(self.delimiter)
397 .quote(self.quote)
398 .double_quote(self.double_quote)
399 .escape(self.escape)
400 .from_writer(writer);
401 Writer {
402 writer,
403 beginning: true,
404 has_headers: self.has_header,
405 date_format: self.date_format,
406 datetime_format: self.datetime_format,
407 time_format: self.time_format,
408 timestamp_format: self.timestamp_format,
409 timestamp_tz_format: self.timestamp_tz_format,
410 null_value: self.null_value,
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ReaderBuilder;
    use arrow_array::builder::{
        BinaryBuilder, Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder,
        LargeBinaryBuilder,
    };
    use arrow_array::types::*;
    use arrow_buffer::i256;
    use std::io::{Cursor, Read, Seek};
    use std::sync::Arc;

    /// Rewinds `source` to its start and returns its whole content decoded as UTF-8.
    fn read_back<R: Read + Seek>(source: &mut R) -> String {
        source.rewind().unwrap();
        let mut bytes: Vec<u8> = vec![];
        source.read_to_end(&mut bytes).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    /// Writes the same mixed-type batch twice and checks the header appears only once.
    #[test]
    fn test_write_csv() {
        let schema = Schema::new(vec![
            Field::new("c1", DataType::Utf8, false),
            Field::new("c2", DataType::Float64, true),
            Field::new("c3", DataType::UInt32, false),
            Field::new("c4", DataType::Boolean, true),
            Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("c6", DataType::Time32(TimeUnit::Second), false),
            Field::new_dictionary("c7", DataType::Int32, DataType::Utf8, false),
        ]);

        let c1 = StringArray::from(vec![
            "Lorem ipsum dolor sit amet",
            "consectetur adipiscing elit",
            "sed do eiusmod tempor",
        ]);
        let c2 =
            PrimitiveArray::<Float64Type>::from(vec![Some(123.564532), None, Some(-556132.25)]);
        let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
        let c4 = BooleanArray::from(vec![Some(true), Some(false), None]);
        let c5 =
            TimestampMillisecondArray::from(vec![None, Some(1555584887378), Some(1555555555555)]);
        let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]);
        let c7: DictionaryArray<Int32Type> =
            vec!["cupcakes", "cupcakes", "foo"].into_iter().collect();

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(c1),
                Arc::new(c2),
                Arc::new(c3),
                Arc::new(c4),
                Arc::new(c5),
                Arc::new(c6),
                Arc::new(c7),
            ],
        )
        .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            // Scope the writer so its borrow of `file` ends before reading back.
            let mut writer = Writer::new(&mut file);
            for _ in 0..2 {
                writer.write(&batch).unwrap();
            }
        }

        let expected = r#"c1,c2,c3,c4,c5,c6,c7
Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes
consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes
sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo
Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes
consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes
sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo
"#;
        assert_eq!(expected, read_back(&mut file));
    }

    /// Checks the textual form of 128-bit and 256-bit decimals, including nulls.
    #[test]
    fn test_write_csv_decimal() {
        let schema = Schema::new(vec![
            Field::new("c1", DataType::Decimal128(38, 6), true),
            Field::new("c2", DataType::Decimal256(76, 6), true),
        ]);

        let mut c1_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6));
        c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]);
        let c1 = c1_builder.finish();

        let mut c2_builder = Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6));
        c2_builder.extend(vec![
            Some(i256::from_i128(-3335724)),
            Some(i256::from_i128(2179404)),
            None,
            Some(i256::from_i128(290472)),
        ]);
        let c2 = c2_builder.finish();

        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = Writer::new(&mut file);
            for _ in 0..2 {
                writer.write(&batch).unwrap();
            }
        }

        let expected = r#"c1,c2
-3.335724,-3.335724
2.179404,2.179404
,
0.290472,0.290472
-3.335724,-3.335724
2.179404,2.179404
,
0.290472,0.290472
"#;
        assert_eq!(expected, read_back(&mut file));
    }

    /// Exercises custom delimiter, quote, null text, time format, and escape settings.
    #[test]
    fn test_write_csv_custom_options() {
        let schema = Schema::new(vec![
            Field::new("c1", DataType::Utf8, false),
            Field::new("c2", DataType::Float64, true),
            Field::new("c3", DataType::UInt32, false),
            Field::new("c4", DataType::Boolean, true),
            Field::new("c6", DataType::Time32(TimeUnit::Second), false),
        ]);

        let c1 = StringArray::from(vec![
            "Lorem ipsum \ndolor sit amet",
            "consectetur \"adipiscing\" elit",
            "sed do eiusmod tempor",
        ]);
        let c2 =
            PrimitiveArray::<Float64Type>::from(vec![Some(123.564532), None, Some(-556132.25)]);
        let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
        let c4 = BooleanArray::from(vec![Some(true), Some(false), None]);
        let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(c1),
                Arc::new(c2),
                Arc::new(c3),
                Arc::new(c4),
                Arc::new(c6),
            ],
        )
        .unwrap();

        // First pass: no header, pipe delimiter, single-quote quoting, custom null and time.
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = WriterBuilder::new()
                .with_header(false)
                .with_delimiter(b'|')
                .with_quote(b'\'')
                .with_null("NULL".to_string())
                .with_time_format("%r".to_string())
                .build(&mut file);
            writer.write(&batch).unwrap();
        }

        assert_eq!(
            "'Lorem ipsum \ndolor sit amet'|123.564532|3|true|12:20:34 AM\nconsectetur \"adipiscing\" elit|NULL|2|false|06:51:20 AM\nsed do eiusmod tempor|-556132.25|1|NULL|11:46:03 PM\n",
            read_back(&mut file)
        );

        // Second pass: header on, double-quoting off, `$` as the escape byte.
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = WriterBuilder::new()
                .with_header(true)
                .with_double_quote(false)
                .with_escape(b'$')
                .build(&mut file);
            writer.write(&batch).unwrap();
        }

        assert_eq!(
            "c1,c2,c3,c4,c6\n\"Lorem ipsum \ndolor sit amet\",123.564532,3,true,00:20:34\n\"consectetur $\"adipiscing$\" elit\",,2,false,06:51:20\nsed do eiusmod tempor,-556132.25,1,,23:46:03\n",
            read_back(&mut file)
        );
    }

    /// Round-trips date and timestamp columns through the writer and the CSV reader.
    #[test]
    fn test_conversion_consistency() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Date32, false),
            Field::new("c2", DataType::Date64, false),
            Field::new("c3", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
        ]));

        let nanoseconds = vec![
            1599566300000000000,
            1599566200000000000,
            1599566100000000000,
        ];
        let c1 = Date32Array::from(vec![3, 2, 1]);
        let c2 = Date64Array::from(vec![3, 2, 1]);
        let c3 = TimestampNanosecondArray::from(nanoseconds.clone());

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)],
        )
        .unwrap();

        let mut buf: Cursor<Vec<u8>> = Default::default();
        {
            // The writer borrows `buf` only for the duration of this block.
            let mut writer = WriterBuilder::new().with_header(false).build(&mut buf);
            writer.write(&batch).unwrap();
        }
        buf.set_position(0);

        let mut reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(buf)
            .unwrap();

        let rb = reader.next().unwrap().unwrap();
        let c1 = rb.column(0).as_any().downcast_ref::<Date32Array>().unwrap();
        let c2 = rb.column(1).as_any().downcast_ref::<Date64Array>().unwrap();
        let c3 = rb
            .column(2)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();

        let dates32 = c1.into_iter().collect::<Vec<_>>();
        assert_eq!(dates32, vec![Some(3), Some(2), Some(1)]);

        let dates64 = c2.into_iter().collect::<Vec<_>>();
        assert_eq!(dates64, vec![Some(3), Some(2), Some(1)]);

        let timestamps = c3.into_iter().collect::<Vec<_>>();
        let expected_timestamps = nanoseconds.into_iter().map(Some).collect::<Vec<_>>();
        assert_eq!(timestamps, expected_timestamps);
    }

    /// An unrepresentable `Date64` value must surface as an error naming the row and column.
    #[test]
    fn test_write_csv_invalid_cast() {
        let schema = Schema::new(vec![
            Field::new("c0", DataType::UInt32, false),
            Field::new("c1", DataType::Date64, false),
        ]);

        let c0 = UInt32Array::from(vec![Some(123), Some(234)]);
        let c1 = Date64Array::from(vec![Some(1926632005177), Some(1926632005177685347)]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();
        let mut writer = Writer::new(&mut file);

        // The same failure is expected on every attempt.
        for _ in 0..2 {
            let err = writer.write(&batch).unwrap_err().to_string();
            assert_eq!(err, "Csv error: Error processing row 2, col 2: Cast error: Failed to convert 1926632005177685347 to temporal for Date64");
        }
    }

    /// Default formatting of timestamps (with and without zone), dates and times.
    #[test]
    fn test_write_csv_using_rfc3339() {
        let schema = Schema::new(vec![
            Field::new(
                "c1",
                DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())),
                true,
            ),
            Field::new("c2", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("c3", DataType::Date32, false),
            Field::new("c4", DataType::Time32(TimeUnit::Second), false),
        ]);

        let c1 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)])
            .with_timezone("+00:00".to_string());
        let c2 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]);
        let c3 = Date32Array::from(vec![3, 2]);
        let c4 = Time32SecondArray::from(vec![1234, 24680]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
        )
        .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = WriterBuilder::new().build(&mut file);
            writer.write(&batch).unwrap();
        }

        assert_eq!(
            "c1,c2,c3,c4
2019-04-18T10:54:47.378Z,2019-04-18T10:54:47.378,1970-01-04,00:20:34
2021-10-30T06:59:07Z,2021-10-30T06:59:07,1970-01-03,06:51:20\n",
            read_back(&mut file)
        );
    }

    /// A custom time-zone-aware timestamp format is applied to zoned timestamp columns.
    #[test]
    fn test_write_csv_tz_format() {
        let schema = Schema::new(vec![
            Field::new(
                "c1",
                DataType::Timestamp(TimeUnit::Millisecond, Some("+02:00".into())),
                true,
            ),
            Field::new(
                "c2",
                DataType::Timestamp(TimeUnit::Second, Some("+04:00".into())),
                true,
            ),
        ]);
        let c1 = TimestampMillisecondArray::from(vec![Some(1_000), Some(2_000)])
            .with_timezone("+02:00".to_string());
        let c2 = TimestampSecondArray::from(vec![Some(1_000_000), None])
            .with_timezone("+04:00".to_string());
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = WriterBuilder::new()
                .with_timestamp_tz_format("%M:%H".to_string())
                .build(&mut file);
            writer.write(&batch).unwrap();
        }

        assert_eq!("c1,c2\n00:02,46:17\n00:02,\n", read_back(&mut file));
    }

    /// Binary, fixed-size binary and large binary columns are written as hex text.
    #[test]
    fn test_write_csv_binary() {
        let fixed_size = 8;
        let schema = SchemaRef::new(Schema::new(vec![
            Field::new("c1", DataType::Binary, true),
            Field::new("c2", DataType::FixedSizeBinary(fixed_size), true),
            Field::new("c3", DataType::LargeBinary, true),
        ]));

        let mut c1_builder = BinaryBuilder::new();
        c1_builder.append_value(b"Homer");
        c1_builder.append_value(b"Bart");
        c1_builder.append_null();
        c1_builder.append_value(b"Ned");

        let mut c2_builder = FixedSizeBinaryBuilder::new(fixed_size);
        c2_builder.append_value(b"Simpson ").unwrap();
        c2_builder.append_value(b"Simpson ").unwrap();
        c2_builder.append_null();
        c2_builder.append_value(b"Flanders").unwrap();

        let mut c3_builder = LargeBinaryBuilder::new();
        c3_builder.append_null();
        c3_builder.append_null();
        c3_builder.append_value(b"Comic Book Guy");
        c3_builder.append_null();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(c1_builder.finish()) as ArrayRef,
                Arc::new(c2_builder.finish()) as ArrayRef,
                Arc::new(c3_builder.finish()) as ArrayRef,
            ],
        )
        .unwrap();

        let mut buf = Vec::new();
        {
            let mut writer = WriterBuilder::new().build(&mut buf);
            writer.write(&batch).unwrap();
        }

        assert_eq!(
            "\
            c1,c2,c3\n\
            486f6d6572,53696d70736f6e20,\n\
            42617274,53696d70736f6e20,\n\
            ,,436f6d696320426f6f6b20477579\n\
            4e6564,466c616e64657273,\n\
            ",
            String::from_utf8(buf).unwrap()
        );
    }
}