hickory_proto/serialize/binary/encoder.rs

// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::marker::PhantomData;

use crate::{
    error::{ProtoErrorKind, ProtoResult},
    op::Header,
};

use super::BinEncodable;

// this is private to make sure there is no accidental access to the inner buffer.
mod private {
    use crate::error::{ProtoErrorKind, ProtoResult};

    /// A wrapper for a buffer that guarantees writes never exceed a defined set of bytes
    pub(super) struct MaximalBuf<'a> {
        max_size: usize,
        buffer: &'a mut Vec<u8>,
    }

    impl<'a> MaximalBuf<'a> {
        pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
            MaximalBuf {
                max_size: max_size as usize,
                buffer,
            }
        }

        /// Sets the maximum size to enforce
        pub(super) fn set_max_size(&mut self, max: u16) {
            self.max_size = max as usize;
        }

        pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
            debug_assert!(offset <= self.buffer.len());
            if offset + data.len() > self.max_size {
                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
            }

            if offset == self.buffer.len() {
                self.buffer.extend(data);
                return Ok(());
            }

            let end = offset + data.len();
            if end > self.buffer.len() {
                self.buffer.resize(end, 0);
            }

            self.buffer[offset..end].copy_from_slice(data);
            Ok(())
        }

        pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
            let end = offset + len;
            if end > self.max_size {
                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
            }

            self.buffer.resize(end, 0);
            Ok(())
        }

        /// truncates are always safe
        pub(super) fn truncate(&mut self, len: usize) {
            self.buffer.truncate(len)
        }

        /// returns the length of the underlying buffer
        pub(super) fn len(&self) -> usize {
            self.buffer.len()
        }

        /// Immutable reads are always safe
        pub(super) fn buffer(&'a self) -> &'a [u8] {
            self.buffer as &'a [u8]
        }

        /// Returns a reference to the internal buffer
        pub(super) fn into_bytes(self) -> &'a Vec<u8> {
            self.buffer
        }
    }
}

/// Encode DNS messages and resource record types.
pub struct BinEncoder<'a> {
    offset: usize,
    buffer: private::MaximalBuf<'a>,
    /// start of label pointers with their labels in fully decompressed form for easy comparison, smallvec here?
    name_pointers: Vec<(usize, Vec<u8>)>,
    mode: EncodeMode,
    canonical_names: bool,
}

impl<'a> BinEncoder<'a> {
    /// Create a new encoder with the Vec to fill
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        Self::with_offset(buf, 0, EncodeMode::Normal)
    }

    /// Specify the mode for encoding
    ///
    /// # Arguments
    ///
    /// * `mode` - In `Signing` mode, canonical forms of all data are encoded; otherwise the format matches the source form
    pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
        Self::with_offset(buf, 0, mode)
    }

    /// Begins the encoder at the given offset
    ///
    /// This is used for pointers. If this encoder starts partway into a sequence of bytes,
    ///  the given offset is added to the location of any name pointer that is written, so
    ///  pointers remain correct relative to the start of the message.
    ///
    /// # Arguments
    ///
    /// * `offset` - index at which to start writing into the buffer
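    ///
    /// A minimal sketch (assuming `EncodeMode` is exported alongside `BinEncoder` from
    /// `hickory_proto::serialize::binary`):
    ///
    /// ```
    /// use hickory_proto::serialize::binary::{BinEncoder, EncodeMode};
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// // start encoding as if 12 bytes (e.g. a message header) had already been written
    /// let encoder = BinEncoder::with_offset(&mut bytes, 12, EncodeMode::Normal);
    /// assert_eq!(encoder.offset(), 12);
    /// ```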
    pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
        if buf.capacity() < 512 {
            let reserve = 512 - buf.capacity();
            buf.reserve(reserve);
        }

        BinEncoder {
            offset: offset as usize,
            // TODO: add max_size to signature
            buffer: private::MaximalBuf::new(u16::MAX, buf),
            name_pointers: Vec::new(),
            mode,
            canonical_names: false,
        }
    }

    // TODO: move to constructor (kept for backward compatibility)
    /// Sets the maximum size of the buffer
    ///
    /// DNS message lengths must be smaller than `u16::MAX` due to hard limits in the protocol
    ///
    /// *this method will move to the constructor in a future release*
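    ///
    /// A short sketch of the enforcement (the one-byte limit is only for illustration):
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///
    /// encoder.set_max_size(1);
    /// assert!(encoder.emit(0).is_ok());
    /// // a second byte would exceed the one-byte limit
    /// assert!(encoder.emit(1).is_err());
    /// ```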
    pub fn set_max_size(&mut self, max: u16) {
        self.buffer.set_max_size(max);
    }

    /// Returns a reference to the internal buffer
    pub fn into_bytes(self) -> &'a Vec<u8> {
        self.buffer.into_bytes()
    }

    /// Returns the length of the buffer
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.buffer.buffer().is_empty()
    }

    /// Returns the current offset into the buffer
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// sets the current offset to the new offset
    pub fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }

    /// Returns the current Encoding mode
    pub fn mode(&self) -> EncodeMode {
        self.mode
    }

    /// If set to true, then names will be written into the buffer in canonical form
    pub fn set_canonical_names(&mut self, canonical_names: bool) {
        self.canonical_names = canonical_names;
    }

    /// Returns true if the encoder is writing in canonical form
    pub fn is_canonical_names(&self) -> bool {
        self.canonical_names
    }

    /// Emit all names in canonical form, useful for <https://tools.ietf.org/html/rfc3597>
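    ///
    /// A minimal sketch showing that the flag is restored after the closure runs:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    /// assert!(!encoder.is_canonical_names());
    ///
    /// encoder
    ///     .with_canonical_names(|encoder| {
    ///         assert!(encoder.is_canonical_names());
    ///         encoder.emit_u8(0)
    ///     })
    ///     .unwrap();
    ///
    /// assert!(!encoder.is_canonical_names());
    /// ```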
    pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
        &mut self,
        f: F,
    ) -> ProtoResult<()> {
        let was_canonical = self.is_canonical_names();
        self.set_canonical_names(true);

        let res = f(self);
        self.set_canonical_names(was_canonical);

        res
    }

    // TODO: deprecate this...
    /// Reserve the specified additional length in the internal buffer (currently a no-op).
    pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
        Ok(())
    }

    /// trims to the current offset
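    ///
    /// A minimal sketch: bytes past the offset are dropped after rolling the offset back:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///   encoder.emit_vec(&[1, 2, 3, 4]).unwrap();
    ///
    ///   // roll back to just after the first two bytes and discard the rest
    ///   encoder.set_offset(2);
    ///   encoder.trim();
    /// }
    /// assert_eq!(bytes, vec![1, 2]);
    /// ```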
    pub fn trim(&mut self) {
        let offset = self.offset;
        self.buffer.truncate(offset);
        self.name_pointers.retain(|&(start, _)| start < offset);
    }

    // /// returns an error if the maximum buffer size would be exceeded with the additional number of elements
    // ///
    // /// and reserves the additional space in the buffer
    // fn enforce_size(&mut self, additional: usize) -> ProtoResult<()> {
    //     if (self.buffer.len() + additional) > self.max_size {
    //         Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into())
    //     } else {
    //         self.reserve(additional);
    //         Ok(())
    //     }
    // }

    /// borrow a slice from the encoder
    pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
        assert!(start < self.offset);
        assert!(end <= self.buffer.len());
        &self.buffer.buffer()[start..end]
    }

    /// Stores a label pointer to an already written label
    ///
    /// The location is the current position in the buffer;
    ///  implicitly, it is expected that the name will be written to the stream after the current index.
    pub fn store_label_pointer(&mut self, start: usize, end: usize) {
        assert!(start <= (u16::MAX as usize));
        assert!(end <= (u16::MAX as usize));
        assert!(start <= end);
        if self.offset < 0x3FFF_usize {
            self.name_pointers
                .push((start, self.slice_of(start, end).to_vec())); // the next char will be at the len() location
        }
    }

    /// Looks up the index of an already written label
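    ///
    /// A minimal sketch of storing a label and finding it again for compression (the
    /// single "www" label is only for illustration):
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///
    /// // write a label and remember where it starts
    /// encoder.emit_character_data("www").unwrap();
    /// encoder.store_label_pointer(0, 4);
    ///
    /// // a second, identical label can now be compressed to a pointer at offset 0
    /// encoder.emit_character_data("www").unwrap();
    /// assert_eq!(encoder.get_label_pointer(4, 8), Some(0));
    /// ```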
    pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
        let search = self.slice_of(start, end);

        for (match_start, matcher) in &self.name_pointers {
            if matcher.as_slice() == search {
                assert!(match_start <= &(u16::MAX as usize));
                return Some(*match_start as u16);
            }
        }

        None
    }

    /// Emit one byte into the buffer
    pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
        self.buffer.write(self.offset, &[b])?;
        self.offset += 1;
        Ok(())
    }

    /// Emits a length-prefixed character string: a single length byte followed by at most 255 bytes of data.
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///   encoder.emit_character_data("abc");
    /// }
    /// assert_eq!(bytes, vec![3,b'a',b'b',b'c']);
    /// ```
    pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
        let char_bytes = char_data.as_ref();
        if char_bytes.len() > 255 {
            return Err(ProtoErrorKind::CharacterDataTooLong {
                max: 255,
                len: char_bytes.len(),
            }
            .into());
        }

        self.emit_character_data_unrestricted(char_data)
    }

    /// Emit character data of unrestricted length
    ///
    /// Although character strings are typically restricted to being no longer than 255 characters,
    /// some modern standards allow longer strings to be encoded.
    pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
        // first the length is written
        let data = data.as_ref();
        self.emit(data.len() as u8)?;
        self.write_slice(data)
    }

    /// Emit one byte into the buffer
    pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
        self.emit(data)
    }

    /// Writes a u16 in network byte order to the buffer
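    ///
    /// A minimal sketch of the big-endian encoding:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///   encoder.emit_u16(0x1234).unwrap();
    /// }
    /// assert_eq!(bytes, vec![0x12, 0x34]);
    /// ```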
    pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes an i32 in network byte order to the buffer
    pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes a u32 in network byte order to the buffer
    pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.buffer.write(self.offset, data)?;
        self.offset += data.len();
        Ok(())
    }

    /// Writes the byte slice to the stream
    pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.write_slice(data)
    }

    /// Emits all the elements of an Iterator to the encoder
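    ///
    /// A minimal sketch, relying on `u16` encoding as two big-endian bytes:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///   let values: Vec<u16> = vec![1, 2];
    ///   let count = encoder.emit_all(values.iter()).unwrap();
    ///   assert_eq!(count, 2);
    /// }
    /// assert_eq!(bytes, vec![0, 1, 0, 2]);
    /// ```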
    pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        mut iter: I,
    ) -> ProtoResult<usize> {
        self.emit_iter(&mut iter)
    }

    // TODO: dedup with above emit_all
    /// Emits all the elements of an Iterator to the encoder
    pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
    where
        'e: 'r,
        I: Iterator<Item = &'r &'e E>,
        E: 'r + 'e + BinEncodable,
    {
        let mut iter = iter.cloned();
        self.emit_iter(&mut iter)
    }

    /// emits all items in the iterator, returning the number emitted
    #[allow(clippy::needless_return)]
    pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        iter: &mut I,
    ) -> ProtoResult<usize> {
        let mut count = 0;
        for i in iter {
            let rollback = self.set_rollback();
            i.emit(self).map_err(|e| {
                if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
                    rollback.rollback(self);
                    return ProtoErrorKind::NotAllRecordsWritten { count }.into();
                } else {
                    return e;
                }
            })?;
            count += 1;
        }
        Ok(count)
    }

    /// capture a location to write back to
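    ///
    /// A minimal sketch of reserving space and filling it in later (the two-byte length
    /// prefix here is only for illustration):
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///
    ///   // reserve two bytes for a length field, then write the data
    ///   let place = encoder.place::<u16>().unwrap();
    ///   encoder.emit_vec(b"abcd").unwrap();
    ///
    ///   // write the observed length back into the reserved slot
    ///   let len = encoder.len_since_place(&place) as u16;
    ///   place.replace(&mut encoder, len).unwrap();
    /// }
    /// assert_eq!(bytes, vec![0, 4, b'a', b'b', b'c', b'd']);
    /// ```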
    pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
        let index = self.offset;
        let len = T::size_of();

        // resize the buffer
        self.buffer.reserve(self.offset, len)?;

        // update the offset
        self.offset += len;

        Ok(Place {
            start_index: index,
            phantom: PhantomData,
        })
    }

    /// calculates the length of data written since the place was created
    pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
        (self.offset - place.start_index) - place.size_of()
    }

    /// write back to a previously captured location
    pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
        // preserve current index
        let current_index = self.offset;

        // reset the current index back to place before writing
        //   this is an assert because it's a programming error for it to be wrong.
        assert!(place.start_index < current_index);
        self.offset = place.start_index;

        // emit the data to be written at this place
        let emit_result = data.emit(self);

        // double check that the correct number of bytes were written
        //   this is an assert because it's a programming error for it to be wrong.
        assert!((self.offset - place.start_index) == place.size_of());

        // reset to original location
        self.offset = current_index;

        emit_result
    }

    fn set_rollback(&self) -> Rollback {
        Rollback {
            rollback_index: self.offset(),
        }
    }
}

/// A trait to return the size of a type as it will be encoded in DNS
///
/// It does not necessarily equal `std::mem::size_of`, though it might, especially for primitives
pub trait EncodedSize: BinEncodable {
    /// Returns the size in bytes of the type as it will be encoded in a DNS message
    fn size_of() -> usize;
}

impl EncodedSize for u16 {
    fn size_of() -> usize {
        2
    }
}

impl EncodedSize for Header {
    fn size_of() -> usize {
        Self::len()
    }
}

/// A placeholder for a value that will be written back into the buffer at a captured location
#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    start_index: usize,
    phantom: PhantomData<T>,
}

impl<T: EncodedSize> Place<T> {
    /// Writes `data` back to the captured location, consuming this placeholder
    pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
        encoder.emit_at(self, data)
    }

    /// Returns the encoded size of `T` in bytes
    pub fn size_of(&self) -> usize {
        T::size_of()
    }
}

/// A type representing a rollback point in a stream
pub(crate) struct Rollback {
    rollback_index: usize,
}

impl Rollback {
    pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
        encoder.set_offset(self.rollback_index)
    }
}

/// In the Verify mode there may be some things which are encoded differently, e.g. SIG0 records
///  should not be included in the additional count and should not be in the encoded data when in Verify
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum EncodeMode {
    /// In signing mode records are written in canonical form
    Signing,
    /// Write records in standard format
    Normal,
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{
            rdata::{CNAME, SRV},
            RData, Record, RecordType,
        },
        serialize::binary::BinDecodable,
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    #[test]
    fn test_label_compression_regression() {
        // https://github.com/hickory-dns/hickory-dns/issues/339
        /*
        ;; QUESTION SECTION:
        ;bluedot.is.autonavi.com.gds.alibabadns.com. IN AAAA

        ;; AUTHORITY SECTION:
        gds.alibabadns.com.     1799    IN      SOA     gdsns1.alibabadns.com. none. 2015080610 1800 600 3600 360
        */
        let data: Vec<u8> = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }
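
    // A companion sketch to the u16 case above: the DNS header always encodes to
    // 12 bytes on the wire, regardless of `std::mem::size_of::<Header>()`.
    #[test]
    fn test_size_of_header() {
        assert_eq!(Header::size_of(), 12);
    }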

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();

        assert_eq!(written, 4);
    }

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        let error = encoder.place::<u16>().unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        // name here should use compressed label from target in previous records
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com").unwrap())),
        ));

        let bytes = msg.to_vec().unwrap();
        // label is compressed pointing to target, would be 145 otherwise
        assert_eq!(bytes.len(), 130);
        // check re-serializing
        assert!(Message::from_vec(&bytes).is_ok());
    }

    #[test]
    fn test_fuzzed() {
        const MESSAGE: &[u8] = include_bytes!("../../../tests/test-data/fuzz-long.rdata");
        let msg = Message::from_bytes(MESSAGE).unwrap();
        msg.to_bytes().unwrap();
    }
}