1use core::marker::PhantomData;
9
10use alloc::vec::Vec;
11
12use crate::{
13 error::{ProtoErrorKind, ProtoResult},
14 op::Header,
15};
16
17use super::BinEncodable;
18
19mod private {
21 use alloc::vec::Vec;
22
23 use crate::error::{ProtoErrorKind, ProtoResult};
24
25 pub(super) struct MaximalBuf<'a> {
27 max_size: usize,
28 buffer: &'a mut Vec<u8>,
29 }
30
31 impl<'a> MaximalBuf<'a> {
32 pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
33 MaximalBuf {
34 max_size: max_size as usize,
35 buffer,
36 }
37 }
38
39 pub(super) fn set_max_size(&mut self, max: u16) {
41 self.max_size = max as usize;
42 }
43
44 pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
45 debug_assert!(offset <= self.buffer.len());
46 if offset + data.len() > self.max_size {
47 return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
48 }
49
50 if offset == self.buffer.len() {
51 self.buffer.extend(data);
52 return Ok(());
53 }
54
55 let end = offset + data.len();
56 if end > self.buffer.len() {
57 self.buffer.resize(end, 0);
58 }
59
60 self.buffer[offset..end].copy_from_slice(data);
61 Ok(())
62 }
63
64 pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
65 let end = offset + len;
66 if end > self.max_size {
67 return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
68 }
69
70 self.buffer.resize(end, 0);
71 Ok(())
72 }
73
74 pub(super) fn truncate(&mut self, len: usize) {
76 self.buffer.truncate(len)
77 }
78
79 pub(super) fn len(&self) -> usize {
81 self.buffer.len()
82 }
83
84 pub(super) fn buffer(&'a self) -> &'a [u8] {
86 self.buffer as &'a [u8]
87 }
88
89 pub(super) fn into_bytes(self) -> &'a Vec<u8> {
91 self.buffer
92 }
93 }
94}
95
96pub struct BinEncoder<'a> {
98 offset: usize,
99 buffer: private::MaximalBuf<'a>,
100 name_pointers: Vec<(usize, Vec<u8>)>,
102 mode: EncodeMode,
103 canonical_names: bool,
104}
105
106impl<'a> BinEncoder<'a> {
107 pub fn new(buf: &'a mut Vec<u8>) -> Self {
109 Self::with_offset(buf, 0, EncodeMode::Normal)
110 }
111
112 pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
118 Self::with_offset(buf, 0, mode)
119 }
120
121 pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
131 if buf.capacity() < 512 {
132 let reserve = 512 - buf.capacity();
133 buf.reserve(reserve);
134 }
135
136 BinEncoder {
137 offset: offset as usize,
138 buffer: private::MaximalBuf::new(u16::MAX, buf),
140 name_pointers: Vec::new(),
141 mode,
142 canonical_names: false,
143 }
144 }
145
146 pub fn set_max_size(&mut self, max: u16) {
153 self.buffer.set_max_size(max);
154 }
155
156 pub fn into_bytes(self) -> &'a Vec<u8> {
158 self.buffer.into_bytes()
159 }
160
161 pub fn len(&self) -> usize {
163 self.buffer.len()
164 }
165
166 pub fn is_empty(&self) -> bool {
168 self.buffer.buffer().is_empty()
169 }
170
171 pub fn offset(&self) -> usize {
173 self.offset
174 }
175
176 pub fn set_offset(&mut self, offset: usize) {
178 self.offset = offset;
179 }
180
181 pub fn mode(&self) -> EncodeMode {
183 self.mode
184 }
185
186 pub fn set_canonical_names(&mut self, canonical_names: bool) {
188 self.canonical_names = canonical_names;
189 }
190
191 pub fn is_canonical_names(&self) -> bool {
193 self.canonical_names
194 }
195
196 pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
198 &mut self,
199 f: F,
200 ) -> ProtoResult<()> {
201 let was_canonical = self.is_canonical_names();
202 self.set_canonical_names(true);
203
204 let res = f(self);
205 self.set_canonical_names(was_canonical);
206
207 res
208 }
209
210 pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
213 Ok(())
214 }
215
216 pub fn trim(&mut self) {
218 let offset = self.offset;
219 self.buffer.truncate(offset);
220 self.name_pointers.retain(|&(start, _)| start < offset);
221 }
222
223 pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
237 assert!(start < self.offset);
238 assert!(end <= self.buffer.len());
239 &self.buffer.buffer()[start..end]
240 }
241
242 pub fn store_label_pointer(&mut self, start: usize, end: usize) {
247 assert!(start <= (u16::MAX as usize));
248 assert!(end <= (u16::MAX as usize));
249 assert!(start <= end);
250 if self.offset < 0x3FFF_usize {
251 self.name_pointers
252 .push((start, self.slice_of(start, end).to_vec())); }
254 }
255
256 pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
258 let search = self.slice_of(start, end);
259
260 for (match_start, matcher) in &self.name_pointers {
261 if matcher.as_slice() == search {
262 assert!(match_start <= &(u16::MAX as usize));
263 return Some(*match_start as u16);
264 }
265 }
266
267 None
268 }
269
270 pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
272 self.buffer.write(self.offset, &[b])?;
273 self.offset += 1;
274 Ok(())
275 }
276
277 pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
290 let char_bytes = char_data.as_ref();
291 if char_bytes.len() > 255 {
292 return Err(ProtoErrorKind::CharacterDataTooLong {
293 max: 255,
294 len: char_bytes.len(),
295 }
296 .into());
297 }
298
299 self.emit_character_data_unrestricted(char_data)
300 }
301
302 pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
307 let data = data.as_ref();
309 self.emit(data.len() as u8)?;
310 self.write_slice(data)
311 }
312
313 pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
315 self.emit(data)
316 }
317
318 pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
320 self.write_slice(&data.to_be_bytes())
321 }
322
323 pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
325 self.write_slice(&data.to_be_bytes())
326 }
327
328 pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
330 self.write_slice(&data.to_be_bytes())
331 }
332
333 fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
334 self.buffer.write(self.offset, data)?;
335 self.offset += data.len();
336 Ok(())
337 }
338
339 pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
341 self.write_slice(data)
342 }
343
344 pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
346 &mut self,
347 mut iter: I,
348 ) -> ProtoResult<usize> {
349 self.emit_iter(&mut iter)
350 }
351
352 pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
355 where
356 'e: 'r,
357 I: Iterator<Item = &'r &'e E>,
358 E: 'r + 'e + BinEncodable,
359 {
360 let mut iter = iter.cloned();
361 self.emit_iter(&mut iter)
362 }
363
364 #[allow(clippy::needless_return)]
366 pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
367 &mut self,
368 iter: &mut I,
369 ) -> ProtoResult<usize> {
370 let mut count = 0;
371 for i in iter {
372 let rollback = self.set_rollback();
373 i.emit(self).map_err(|e| {
374 if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
375 rollback.rollback(self);
376 return ProtoErrorKind::NotAllRecordsWritten { count }.into();
377 } else {
378 return e;
379 }
380 })?;
381 count += 1;
382 }
383 Ok(count)
384 }
385
386 pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
388 let index = self.offset;
389 let len = T::size_of();
390
391 self.buffer.reserve(self.offset, len)?;
393
394 self.offset += len;
396
397 Ok(Place {
398 start_index: index,
399 phantom: PhantomData,
400 })
401 }
402
403 pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
405 (self.offset - place.start_index) - place.size_of()
406 }
407
408 pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
410 let current_index = self.offset;
412
413 assert!(place.start_index < current_index);
416 self.offset = place.start_index;
417
418 let emit_result = data.emit(self);
420
421 assert!((self.offset - place.start_index) == place.size_of());
424
425 self.offset = current_index;
427
428 emit_result
429 }
430
431 fn set_rollback(&self) -> Rollback {
432 Rollback {
433 rollback_index: self.offset(),
434 }
435 }
436}
437
438pub trait EncodedSize: BinEncodable {
442 fn size_of() -> usize;
444}
445
446impl EncodedSize for u16 {
447 fn size_of() -> usize {
448 2
449 }
450}
451
452impl EncodedSize for Header {
453 fn size_of() -> usize {
454 Self::len()
455 }
456}
457
458#[derive(Debug)]
459#[must_use = "data must be written back to the place"]
460pub struct Place<T: EncodedSize> {
461 start_index: usize,
462 phantom: PhantomData<T>,
463}
464
465impl<T: EncodedSize> Place<T> {
466 pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
467 encoder.emit_at(self, data)
468 }
469
470 pub fn size_of(&self) -> usize {
471 T::size_of()
472 }
473}
474
/// A saved encoder position used to undo a partially emitted item.
pub(crate) struct Rollback {
    /// Encoder offset at the moment the rollback point was taken.
    rollback_index: usize,
}
479
480impl Rollback {
481 pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
482 encoder.set_offset(self.rollback_index)
483 }
484}
485
/// Chooses how a `BinEncoder` encodes data; read back via `BinEncoder::mode`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum EncodeMode {
    /// Encoding used for signing. NOTE(review): presumably some content is encoded
    /// differently in this mode — confirm against callers of `BinEncoder::mode`.
    Signing,
    /// Ordinary encoding; the default used by `BinEncoder::new`.
    Normal,
}
495
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{
            Name, RData, Record, RecordType,
            rdata::{CNAME, SRV},
        },
        serialize::binary::{BinDecodable, BinDecoder},
    };

    #[test]
    fn test_label_compression_regression() {
        let data: &[u8] = &[
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(data).unwrap();
        msg.to_bytes().unwrap();
    }

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            // Scoped so the encoder's mutable borrow of `buf` ends before `buf` is inspected.
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();

        assert_eq!(written, 4);
    }

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        assert!(
            matches!(error.kind(), ProtoErrorKind::MaxBufferSizeExceeded(_)),
            "unexpected error: {error:?}"
        );
    }

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        assert!(
            matches!(error.kind(), ProtoErrorKind::MaxBufferSizeExceeded(_)),
            "unexpected error: {error:?}"
        );
    }

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        let error = encoder.place::<u16>().unwrap_err();

        assert!(
            matches!(error.kind(), ProtoErrorKind::MaxBufferSizeExceeded(_)),
            "unexpected error: {error:?}"
        );
    }

    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com.").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com.").unwrap())),
        ));

        let bytes = msg.to_vec().unwrap();
        assert_eq!(bytes.len(), 130);
        assert!(Message::from_vec(&bytes).is_ok());
    }

    #[test]
    fn test_fuzzed() {
        const MESSAGE: &[u8] = include_bytes!("../../../tests/test-data/fuzz-long.rdata");
        let msg = Message::from_bytes(MESSAGE).unwrap();
        msg.to_bytes().unwrap();
    }
}