1use std::marker::PhantomData;
9
10use crate::{
11 error::{ProtoErrorKind, ProtoResult},
12 op::Header,
13};
14
15use super::BinEncodable;
16
17mod private {
19 use crate::error::{ProtoErrorKind, ProtoResult};
20
21 pub(super) struct MaximalBuf<'a> {
23 max_size: usize,
24 buffer: &'a mut Vec<u8>,
25 }
26
27 impl<'a> MaximalBuf<'a> {
28 pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
29 MaximalBuf {
30 max_size: max_size as usize,
31 buffer,
32 }
33 }
34
35 pub(super) fn set_max_size(&mut self, max: u16) {
37 self.max_size = max as usize;
38 }
39
40 pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
41 debug_assert!(offset <= self.buffer.len());
42 if offset + data.len() > self.max_size {
43 return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
44 }
45
46 if offset == self.buffer.len() {
47 self.buffer.extend(data);
48 return Ok(());
49 }
50
51 let end = offset + data.len();
52 if end > self.buffer.len() {
53 self.buffer.resize(end, 0);
54 }
55
56 self.buffer[offset..end].copy_from_slice(data);
57 Ok(())
58 }
59
60 pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
61 let end = offset + len;
62 if end > self.max_size {
63 return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
64 }
65
66 self.buffer.resize(end, 0);
67 Ok(())
68 }
69
70 pub(super) fn truncate(&mut self, len: usize) {
72 self.buffer.truncate(len)
73 }
74
75 pub(super) fn len(&self) -> usize {
77 self.buffer.len()
78 }
79
80 pub(super) fn buffer(&'a self) -> &'a [u8] {
82 self.buffer as &'a [u8]
83 }
84
85 pub(super) fn into_bytes(self) -> &'a Vec<u8> {
87 self.buffer
88 }
89 }
90}
91
92pub struct BinEncoder<'a> {
94 offset: usize,
95 buffer: private::MaximalBuf<'a>,
96 name_pointers: Vec<(usize, Vec<u8>)>,
98 mode: EncodeMode,
99 canonical_names: bool,
100}
101
102impl<'a> BinEncoder<'a> {
103 pub fn new(buf: &'a mut Vec<u8>) -> Self {
105 Self::with_offset(buf, 0, EncodeMode::Normal)
106 }
107
108 pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
114 Self::with_offset(buf, 0, mode)
115 }
116
117 pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
127 if buf.capacity() < 512 {
128 let reserve = 512 - buf.capacity();
129 buf.reserve(reserve);
130 }
131
132 BinEncoder {
133 offset: offset as usize,
134 buffer: private::MaximalBuf::new(u16::MAX, buf),
136 name_pointers: Vec::new(),
137 mode,
138 canonical_names: false,
139 }
140 }
141
142 pub fn set_max_size(&mut self, max: u16) {
149 self.buffer.set_max_size(max);
150 }
151
152 pub fn into_bytes(self) -> &'a Vec<u8> {
154 self.buffer.into_bytes()
155 }
156
157 pub fn len(&self) -> usize {
159 self.buffer.len()
160 }
161
162 pub fn is_empty(&self) -> bool {
164 self.buffer.buffer().is_empty()
165 }
166
167 pub fn offset(&self) -> usize {
169 self.offset
170 }
171
172 pub fn set_offset(&mut self, offset: usize) {
174 self.offset = offset;
175 }
176
177 pub fn mode(&self) -> EncodeMode {
179 self.mode
180 }
181
182 pub fn set_canonical_names(&mut self, canonical_names: bool) {
184 self.canonical_names = canonical_names;
185 }
186
187 pub fn is_canonical_names(&self) -> bool {
189 self.canonical_names
190 }
191
192 pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
194 &mut self,
195 f: F,
196 ) -> ProtoResult<()> {
197 let was_canonical = self.is_canonical_names();
198 self.set_canonical_names(true);
199
200 let res = f(self);
201 self.set_canonical_names(was_canonical);
202
203 res
204 }
205
206 pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
209 Ok(())
210 }
211
212 pub fn trim(&mut self) {
214 let offset = self.offset;
215 self.buffer.truncate(offset);
216 self.name_pointers.retain(|&(start, _)| start < offset);
217 }
218
219 pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
233 assert!(start < self.offset);
234 assert!(end <= self.buffer.len());
235 &self.buffer.buffer()[start..end]
236 }
237
238 pub fn store_label_pointer(&mut self, start: usize, end: usize) {
243 assert!(start <= (u16::MAX as usize));
244 assert!(end <= (u16::MAX as usize));
245 assert!(start <= end);
246 if self.offset < 0x3FFF_usize {
247 self.name_pointers
248 .push((start, self.slice_of(start, end).to_vec())); }
250 }
251
252 pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
254 let search = self.slice_of(start, end);
255
256 for (match_start, matcher) in &self.name_pointers {
257 if matcher.as_slice() == search {
258 assert!(match_start <= &(u16::MAX as usize));
259 return Some(*match_start as u16);
260 }
261 }
262
263 None
264 }
265
266 pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
268 self.buffer.write(self.offset, &[b])?;
269 self.offset += 1;
270 Ok(())
271 }
272
273 pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
286 let char_bytes = char_data.as_ref();
287 if char_bytes.len() > 255 {
288 return Err(ProtoErrorKind::CharacterDataTooLong {
289 max: 255,
290 len: char_bytes.len(),
291 }
292 .into());
293 }
294
295 self.emit_character_data_unrestricted(char_data)
296 }
297
298 pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
303 let data = data.as_ref();
305 self.emit(data.len() as u8)?;
306 self.write_slice(data)
307 }
308
309 pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
311 self.emit(data)
312 }
313
314 pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
316 self.write_slice(&data.to_be_bytes())
317 }
318
319 pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
321 self.write_slice(&data.to_be_bytes())
322 }
323
324 pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
326 self.write_slice(&data.to_be_bytes())
327 }
328
329 fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
330 self.buffer.write(self.offset, data)?;
331 self.offset += data.len();
332 Ok(())
333 }
334
335 pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
337 self.write_slice(data)
338 }
339
340 pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
342 &mut self,
343 mut iter: I,
344 ) -> ProtoResult<usize> {
345 self.emit_iter(&mut iter)
346 }
347
348 pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
351 where
352 'e: 'r,
353 I: Iterator<Item = &'r &'e E>,
354 E: 'r + 'e + BinEncodable,
355 {
356 let mut iter = iter.cloned();
357 self.emit_iter(&mut iter)
358 }
359
360 #[allow(clippy::needless_return)]
362 pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
363 &mut self,
364 iter: &mut I,
365 ) -> ProtoResult<usize> {
366 let mut count = 0;
367 for i in iter {
368 let rollback = self.set_rollback();
369 i.emit(self).map_err(|e| {
370 if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
371 rollback.rollback(self);
372 return ProtoErrorKind::NotAllRecordsWritten { count }.into();
373 } else {
374 return e;
375 }
376 })?;
377 count += 1;
378 }
379 Ok(count)
380 }
381
382 pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
384 let index = self.offset;
385 let len = T::size_of();
386
387 self.buffer.reserve(self.offset, len)?;
389
390 self.offset += len;
392
393 Ok(Place {
394 start_index: index,
395 phantom: PhantomData,
396 })
397 }
398
399 pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
401 (self.offset - place.start_index) - place.size_of()
402 }
403
404 pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
406 let current_index = self.offset;
408
409 assert!(place.start_index < current_index);
412 self.offset = place.start_index;
413
414 let emit_result = data.emit(self);
416
417 assert!((self.offset - place.start_index) == place.size_of());
420
421 self.offset = current_index;
423
424 emit_result
425 }
426
427 fn set_rollback(&self) -> Rollback {
428 Rollback {
429 rollback_index: self.offset(),
430 }
431 }
432}
433
434pub trait EncodedSize: BinEncodable {
438 fn size_of() -> usize;
440}
441
442impl EncodedSize for u16 {
443 fn size_of() -> usize {
444 2
445 }
446}
447
448impl EncodedSize for Header {
449 fn size_of() -> usize {
450 Self::len()
451 }
452}
453
454#[derive(Debug)]
455#[must_use = "data must be written back to the place"]
456pub struct Place<T: EncodedSize> {
457 start_index: usize,
458 phantom: PhantomData<T>,
459}
460
461impl<T: EncodedSize> Place<T> {
462 pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
463 encoder.emit_at(self, data)
464 }
465
466 pub fn size_of(&self) -> usize {
467 T::size_of()
468 }
469}
470
471pub(crate) struct Rollback {
473 rollback_index: usize,
474}
475
476impl Rollback {
477 pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
478 encoder.set_offset(self.rollback_index)
479 }
480}
481
/// How a `BinEncoder` should encode data.
///
/// The encoder stores the mode it was created with and reports it via
/// `BinEncoder::mode`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum EncodeMode {
    /// Encoding for signature generation. Presumably this selects canonical
    /// forms; confirm against callers that check `mode()`.
    Signing,
    /// Ordinary encoding.
    Normal,
}
491
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{
            rdata::{CNAME, SRV},
            RData, Record, RecordType,
        },
        serialize::binary::BinDecodable,
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    /// Asserts that `result` failed specifically with `MaxBufferSizeExceeded`,
    /// and reports the actual error otherwise.
    fn assert_max_size_exceeded<T: std::fmt::Debug>(result: ProtoResult<T>) {
        let error = result.expect_err("expected the maximum buffer size to be exceeded");
        assert!(
            matches!(*error.kind(), ProtoErrorKind::MaxBufferSizeExceeded(_)),
            "unexpected error: {:?}",
            error
        );
    }

    /// Round-trips a captured message. Per the test name, this message once
    /// triggered a label-compression regression.
    #[test]
    fn test_label_compression_regression() {
        let data: Vec<u8> = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();

        assert_eq!(written, 4);
    }

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        assert_max_size_exceeded(encoder.emit(5));
    }

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        assert_max_size_exceeded(encoder.emit(0));
    }

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        assert_max_size_exceeded(encoder.place::<u16>());
    }

    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com").unwrap())),
        ));

        let bytes = msg.to_vec().unwrap();
        assert_eq!(bytes.len(), 130);
        assert!(Message::from_vec(&bytes).is_ok());
    }

    #[test]
    fn test_fuzzed() {
        const MESSAGE: &[u8] = include_bytes!("../../../tests/test-data/fuzz-long.rdata");
        let msg = Message::from_bytes(MESSAGE).unwrap();
        msg.to_bytes().unwrap();
    }
}