1use crate::{
3 encode::{encode, encode_ref, encode_to_vec},
4 schema::Schema,
5 ser::Serializer,
6 types::Value,
7 AvroResult, Codec, Error,
8};
9use rand::random;
10use serde::Serialize;
11use std::{collections::HashMap, io::Write};
12
/// Block size, in bytes, used when the builder is not given an explicit one.
/// It also sizes the initial capacity of the writer's internal buffer.
const DEFAULT_BLOCK_SIZE: usize = 16_000;
/// Leading bytes of every header produced by the writer: ASCII `Obj` followed by `0x01`.
const AVRO_OBJECT_HEADER: &[u8] = &[b'O', b'b', b'j', 0x01];
15
16#[derive(typed_builder::TypedBuilder)]
18pub struct Writer<'a, W> {
19 schema: &'a Schema,
20 writer: W,
21 #[builder(default = Codec::Null)]
22 codec: Codec,
23 #[builder(default = DEFAULT_BLOCK_SIZE)]
24 block_size: usize,
25 #[builder(default = Vec::with_capacity(block_size), setter(skip))]
26 buffer: Vec<u8>,
27 #[builder(default, setter(skip))]
28 serializer: Serializer,
29 #[builder(default = 0, setter(skip))]
30 num_values: usize,
31 #[builder(default = std::iter::repeat_with(random).take(16).collect(), setter(skip))]
32 marker: Vec<u8>,
33 #[builder(default = false, setter(skip))]
34 has_header: bool,
35}
36
37impl<'a, W: Write> Writer<'a, W> {
38 pub fn new(schema: &'a Schema, writer: W) -> Self {
42 Self::builder().schema(schema).writer(writer).build()
43 }
44
45 pub fn with_codec(schema: &'a Schema, writer: W, codec: Codec) -> Self {
48 Self::builder()
49 .schema(schema)
50 .writer(writer)
51 .codec(codec)
52 .build()
53 }
54
55 pub fn schema(&self) -> &'a Schema {
57 self.schema
58 }
59
60 pub fn append<T: Into<Value>>(&mut self, value: T) -> AvroResult<usize> {
69 let n = if !self.has_header {
70 let header = self.header()?;
71 let n = self.append_bytes(header.as_ref())?;
72 self.has_header = true;
73 n
74 } else {
75 0
76 };
77
78 let avro = value.into();
79 write_value_ref(self.schema, &avro, &mut self.buffer)?;
80
81 self.num_values += 1;
82
83 if self.buffer.len() >= self.block_size {
84 return self.flush().map(|b| b + n);
85 }
86
87 Ok(n)
88 }
89
90 pub fn append_value_ref(&mut self, value: &Value) -> AvroResult<usize> {
98 let n = if !self.has_header {
99 let header = self.header()?;
100 let n = self.append_bytes(header.as_ref())?;
101 self.has_header = true;
102 n
103 } else {
104 0
105 };
106
107 write_value_ref(self.schema, value, &mut self.buffer)?;
108
109 self.num_values += 1;
110
111 if self.buffer.len() >= self.block_size {
112 return self.flush().map(|b| b + n);
113 }
114
115 Ok(n)
116 }
117
118 pub fn append_ser<S: Serialize>(&mut self, value: S) -> AvroResult<usize> {
128 let avro_value = value.serialize(&mut self.serializer)?;
129 self.append(avro_value)
130 }
131
132 pub fn extend<I, T: Into<Value>>(&mut self, values: I) -> AvroResult<usize>
140 where
141 I: IntoIterator<Item = T>,
142 {
143 let mut num_bytes = 0;
158 for value in values {
159 num_bytes += self.append(value)?;
160 }
161 num_bytes += self.flush()?;
162
163 Ok(num_bytes)
164 }
165
166 pub fn extend_ser<I, T: Serialize>(&mut self, values: I) -> AvroResult<usize>
175 where
176 I: IntoIterator<Item = T>,
177 {
178 let mut num_bytes = 0;
193 for value in values {
194 num_bytes += self.append_ser(value)?;
195 }
196 num_bytes += self.flush()?;
197
198 Ok(num_bytes)
199 }
200
201 pub fn extend_from_slice(&mut self, values: &[Value]) -> AvroResult<usize> {
209 let mut num_bytes = 0;
210 for value in values {
211 num_bytes += self.append_value_ref(value)?;
212 }
213 num_bytes += self.flush()?;
214
215 Ok(num_bytes)
216 }
217
218 pub fn flush(&mut self) -> AvroResult<usize> {
223 if self.num_values == 0 {
224 return Ok(0);
225 }
226
227 self.codec.compress(&mut self.buffer)?;
228
229 let num_values = self.num_values;
230 let stream_len = self.buffer.len();
231
232 let num_bytes = self.append_raw(&num_values.into(), &Schema::Long)?
233 + self.append_raw(&stream_len.into(), &Schema::Long)?
234 + self
235 .writer
236 .write(self.buffer.as_ref())
237 .map_err(Error::WriteBytes)?
238 + self.append_marker()?;
239
240 self.buffer.clear();
241 self.num_values = 0;
242
243 Ok(num_bytes)
244 }
245
246 pub fn into_inner(mut self) -> AvroResult<W> {
251 self.flush()?;
252 Ok(self.writer)
253 }
254
255 fn append_marker(&mut self) -> AvroResult<usize> {
257 self.writer.write(&self.marker).map_err(Error::WriteMarker)
260 }
261
262 fn append_raw(&mut self, value: &Value, schema: &Schema) -> AvroResult<usize> {
264 self.append_bytes(encode_to_vec(&value, schema).as_ref())
265 }
266
267 fn append_bytes(&mut self, bytes: &[u8]) -> AvroResult<usize> {
269 self.writer.write(bytes).map_err(Error::WriteBytes)
270 }
271
272 fn header(&self) -> Result<Vec<u8>, Error> {
274 let schema_bytes = serde_json::to_string(self.schema)
275 .map_err(Error::ConvertJsonToString)?
276 .into_bytes();
277
278 let mut metadata = HashMap::with_capacity(2);
279 metadata.insert("avro.schema", Value::Bytes(schema_bytes));
280 metadata.insert("avro.codec", self.codec.into());
281
282 let mut header = Vec::new();
283 header.extend_from_slice(AVRO_OBJECT_HEADER);
284 encode(
285 &metadata.into(),
286 &Schema::Map(Box::new(Schema::Bytes)),
287 &mut header,
288 );
289 header.extend_from_slice(&self.marker);
290
291 Ok(header)
292 }
293}
294
295fn write_avro_datum<T: Into<Value>>(
301 schema: &Schema,
302 value: T,
303 buffer: &mut Vec<u8>,
304) -> Result<(), Error> {
305 let avro = value.into();
306 if !avro.validate(schema) {
307 return Err(Error::Validation);
308 }
309 encode(&avro, schema, buffer);
310 Ok(())
311}
312
313fn write_value_ref(schema: &Schema, value: &Value, buffer: &mut Vec<u8>) -> AvroResult<()> {
314 if !value.validate(schema) {
315 return Err(Error::Validation);
316 }
317 encode_ref(value, schema, buffer);
318 Ok(())
319}
320
321pub fn to_avro_datum<T: Into<Value>>(schema: &Schema, value: T) -> AvroResult<Vec<u8>> {
328 let mut buffer = Vec::new();
329 write_avro_datum(schema, value, &mut buffer)?;
330 Ok(buffer)
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        decimal::Decimal,
        duration::{Days, Duration, Millis, Months},
        schema::Name,
        types::Record,
        util::zig_i64,
    };
    use serde::{Deserialize, Serialize};

    type TestResult<T> = Result<T, Box<dyn std::error::Error>>;

    const AVRO_OBJECT_HEADER_LEN: usize = AVRO_OBJECT_HEADER.len();
    /// Length of the marker the writer emits after each block.
    const MARKER_LEN: usize = 16;

    const SCHEMA: &str = r#"
    {
      "type": "record",
      "name": "test",
      "fields": [
        {
          "name": "a",
          "type": "long",
          "default": 42
        },
        {
          "name": "b",
          "type": "string"
        }
      ]
    }
    "#;
    const UNION_SCHEMA: &str = r#"["null", "long"]"#;

    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct TestSerdeSerialize {
        a: i64,
        b: String,
    }

    /// Record `{a: 27, b: "foo"}` for `SCHEMA`.
    fn sample_record(schema: &Schema) -> Record<'_> {
        let mut record = Record::new(schema).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");
        record
    }

    /// Serde counterpart of `sample_record`.
    fn sample_serde_record() -> TestSerdeSerialize {
        TestSerdeSerialize {
            a: 27,
            b: "foo".to_owned(),
        }
    }

    /// Expected encoding of one sample record: long 27, string length 3, "foo".
    fn sample_datum_bytes() -> Vec<u8> {
        let mut bytes = Vec::new();
        zig_i64(27, &mut bytes);
        zig_i64(3, &mut bytes);
        bytes.extend_from_slice(b"foo");
        bytes
    }

    /// Checks that `file` starts with the magic bytes and that `payload` sits
    /// immediately before the trailing marker.
    fn assert_single_block(file: &[u8], payload: &[u8]) {
        assert_eq!(&file[..AVRO_OBJECT_HEADER_LEN], AVRO_OBJECT_HEADER);
        let payload_end = file.len() - MARKER_LEN;
        assert_eq!(&file[payload_end - payload.len()..payload_end], payload);
    }

    #[test]
    fn test_to_avro_datum() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let record = sample_record(&schema);

        let encoded = to_avro_datum(&schema, record).unwrap();

        assert_eq!(encoded, sample_datum_bytes());
    }

    #[test]
    fn test_union_not_null() {
        let schema = Schema::parse_str(UNION_SCHEMA).unwrap();
        let union = Value::Union(Box::new(Value::Long(3)));

        // Branch index 1 ("long") followed by the value itself.
        let mut expected = Vec::new();
        zig_i64(1, &mut expected);
        zig_i64(3, &mut expected);

        assert_eq!(to_avro_datum(&schema, union).unwrap(), expected);
    }

    #[test]
    fn test_union_null() {
        let schema = Schema::parse_str(UNION_SCHEMA).unwrap();
        let union = Value::Union(Box::new(Value::Null));

        // Branch index 0 ("null") and no payload.
        let mut expected = Vec::new();
        zig_i64(0, &mut expected);

        assert_eq!(to_avro_datum(&schema, union).unwrap(), expected);
    }

    /// Parses `schema_str`, checks it equals `expected_schema`, verifies that
    /// `value` encodes to the same bytes as `raw_value` under `raw_schema`,
    /// and that decoding those bytes yields `value` again.
    fn logical_type_test<T: Into<Value> + Clone>(
        schema_str: &'static str,

        expected_schema: &Schema,
        value: Value,

        raw_schema: &Schema,
        raw_value: T,
    ) -> TestResult<()> {
        let parsed = Schema::parse_str(schema_str)?;
        assert_eq!(&parsed, expected_schema);

        // The logical type must serialize exactly like its underlying type.
        let logical_bytes = to_avro_datum(&parsed, value.clone())?;
        let raw_bytes = to_avro_datum(raw_schema, raw_value)?;
        assert_eq!(logical_bytes, raw_bytes);

        // And it must survive a round trip.
        let mut reader = logical_bytes.as_slice();
        let decoded = crate::from_avro_datum(&parsed, &mut reader, None).unwrap();
        assert_eq!(decoded, value);
        Ok(())
    }

    #[test]
    fn date() -> TestResult<()> {
        logical_type_test(
            r#"{"type": "int", "logicalType": "date"}"#,
            &Schema::Date,
            Value::Date(1_i32),
            &Schema::Int,
            1_i32,
        )
    }

    #[test]
    fn time_millis() -> TestResult<()> {
        logical_type_test(
            r#"{"type": "int", "logicalType": "time-millis"}"#,
            &Schema::TimeMillis,
            Value::TimeMillis(1_i32),
            &Schema::Int,
            1_i32,
        )
    }

    #[test]
    fn time_micros() -> TestResult<()> {
        logical_type_test(
            r#"{"type": "long", "logicalType": "time-micros"}"#,
            &Schema::TimeMicros,
            Value::TimeMicros(1_i64),
            &Schema::Long,
            1_i64,
        )
    }

    #[test]
    fn timestamp_millis() -> TestResult<()> {
        logical_type_test(
            r#"{"type": "long", "logicalType": "timestamp-millis"}"#,
            &Schema::TimestampMillis,
            Value::TimestampMillis(1_i64),
            &Schema::Long,
            1_i64,
        )
    }

    #[test]
    fn timestamp_micros() -> TestResult<()> {
        logical_type_test(
            r#"{"type": "long", "logicalType": "timestamp-micros"}"#,
            &Schema::TimestampMicros,
            Value::TimestampMicros(1_i64),
            &Schema::Long,
            1_i64,
        )
    }

    #[test]
    fn decimal_fixed() -> TestResult<()> {
        let size = 30;
        let inner = Schema::Fixed {
            name: Name::new("decimal"),
            size,
        };
        let value = vec![0u8; size];
        logical_type_test(
            r#"{"type": {"type": "fixed", "size": 30, "name": "decimal"}, "logicalType": "decimal", "precision": 20, "scale": 5}"#,
            &Schema::Decimal {
                precision: 20,
                scale: 5,
                inner: Box::new(inner.clone()),
            },
            Value::Decimal(Decimal::from(value.clone())),
            &inner,
            Value::Fixed(size, value),
        )
    }

    #[test]
    fn decimal_bytes() -> TestResult<()> {
        let inner = Schema::Bytes;
        let value = vec![0u8; 10];
        logical_type_test(
            r#"{"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 3}"#,
            &Schema::Decimal {
                precision: 4,
                scale: 3,
                inner: Box::new(inner.clone()),
            },
            Value::Decimal(Decimal::from(value.clone())),
            &inner,
            value,
        )
    }

    #[test]
    fn duration() -> TestResult<()> {
        let inner = Schema::Fixed {
            name: Name::new("duration"),
            size: 12,
        };
        let value = Value::Duration(Duration::new(
            Months::new(256),
            Days::new(512),
            Millis::new(1024),
        ));
        logical_type_test(
            r#"{"type": {"type": "fixed", "name": "duration", "size": 12}, "logicalType": "duration"}"#,
            &Schema::Duration,
            value,
            &inner,
            Value::Fixed(12, vec![0, 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0]),
        )
    }

    #[test]
    fn test_writer_append() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());
        let record = sample_record(&schema);

        let mut written = writer.append(record.clone()).unwrap();
        written += writer.append(record).unwrap();
        written += writer.flush().unwrap();
        let file = writer.into_inner().unwrap();

        assert_eq!(written, file.len());
        assert_single_block(&file, &sample_datum_bytes().repeat(2));
    }

    #[test]
    fn test_writer_extend() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());
        let record = sample_record(&schema);
        let records = vec![record.clone(), record];

        let mut written = writer.extend(records).unwrap();
        written += writer.flush().unwrap();
        let file = writer.into_inner().unwrap();

        assert_eq!(written, file.len());
        assert_single_block(&file, &sample_datum_bytes().repeat(2));
    }

    #[test]
    fn test_writer_append_ser() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());

        let mut written = writer.append_ser(sample_serde_record()).unwrap();
        written += writer.flush().unwrap();
        let file = writer.into_inner().unwrap();

        assert_eq!(written, file.len());
        assert_single_block(&file, &sample_datum_bytes());
    }

    #[test]
    fn test_writer_extend_ser() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());
        let record = sample_serde_record();
        let records = vec![record.clone(), record];

        let mut written = writer.extend_ser(records).unwrap();
        written += writer.flush().unwrap();
        let file = writer.into_inner().unwrap();

        assert_eq!(written, file.len());
        assert_single_block(&file, &sample_datum_bytes().repeat(2));
    }

    fn make_writer_with_codec(schema: &Schema) -> Writer<'_, Vec<u8>> {
        Writer::with_codec(schema, Vec::new(), Codec::Deflate)
    }

    fn make_writer_with_builder(schema: &Schema) -> Writer<'_, Vec<u8>> {
        Writer::builder()
            .writer(Vec::new())
            .schema(schema)
            .codec(Codec::Deflate)
            .block_size(100)
            .build()
    }

    /// Appends the sample record twice to a deflate-configured `writer` and
    /// checks the byte count and the compressed payload.
    fn check_writer(mut writer: Writer<'_, Vec<u8>>, schema: &Schema) {
        let record = sample_record(schema);

        let mut written = writer.append(record.clone()).unwrap();
        written += writer.append(record).unwrap();
        written += writer.flush().unwrap();
        let file = writer.into_inner().unwrap();

        assert_eq!(written, file.len());

        let mut payload = sample_datum_bytes().repeat(2);
        Codec::Deflate.compress(&mut payload).unwrap();
        assert_single_block(&file, &payload);
    }

    #[test]
    fn test_writer_with_codec() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let writer = make_writer_with_codec(&schema);
        check_writer(writer, &schema);
    }

    #[test]
    fn test_writer_with_builder() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let writer = make_writer_with_builder(&schema);
        check_writer(writer, &schema);
    }

    #[test]
    fn test_logical_writer() {
        const LOGICAL_TYPE_SCHEMA: &str = r#"
        {
          "type": "record",
          "name": "logical_type_test",
          "fields": [
            {
              "name": "a",
              "type": [
                "null",
                {
                  "type": "long",
                  "logicalType": "timestamp-micros"
                }
              ]
            }
          ]
        }
        "#;
        let codec = Codec::Deflate;
        let schema = Schema::parse_str(LOGICAL_TYPE_SCHEMA).unwrap();
        let mut writer = Writer::builder()
            .schema(&schema)
            .codec(codec)
            .writer(Vec::new())
            .build();

        let mut with_timestamp = Record::new(&schema).unwrap();
        with_timestamp.put(
            "a",
            Value::Union(Box::new(Value::TimestampMicros(1234_i64))),
        );

        let mut with_null = Record::new(&schema).unwrap();
        with_null.put("a", Value::Union(Box::new(Value::Null)));

        let mut written = writer.append(with_timestamp).unwrap();
        written += writer.append(with_null).unwrap();
        written += writer.flush().unwrap();
        let file = writer.into_inner().unwrap();

        assert_eq!(written, file.len());

        let mut payload = Vec::new();
        // First record: union branch 1, then the timestamp value.
        zig_i64(1, &mut payload);
        zig_i64(1234, &mut payload);
        // Second record: union branch 0 (null), no value bytes.
        zig_i64(0, &mut payload);
        codec.compress(&mut payload).unwrap();

        assert_single_block(&file, &payload);
    }
}