hickory_proto/serialize/binary/encoder.rs

// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use core::marker::PhantomData;

use alloc::vec::Vec;

use crate::{
    error::{ProtoErrorKind, ProtoResult},
    op::Header,
};

use super::BinEncodable;

// this is private to make sure there is no accidental access to the inner buffer.
mod private {
    use alloc::vec::Vec;

    use crate::error::{ProtoErrorKind, ProtoResult};

    /// A wrapper for a buffer that guarantees writes never exceed a defined set of bytes
    pub(super) struct MaximalBuf<'a> {
        max_size: usize,
        buffer: &'a mut Vec<u8>,
    }

    impl<'a> MaximalBuf<'a> {
        pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
            MaximalBuf {
                max_size: max_size as usize,
                buffer,
            }
        }

        /// Sets the maximum size to enforce
        pub(super) fn set_max_size(&mut self, max: u16) {
            self.max_size = max as usize;
        }

        pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
            debug_assert!(offset <= self.buffer.len());
            if offset + data.len() > self.max_size {
                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
            }

            // appending at the end of the buffer
            if offset == self.buffer.len() {
                self.buffer.extend(data);
                return Ok(());
            }

            // overwriting a previously written region, extending the buffer if necessary
            let end = offset + data.len();
            if end > self.buffer.len() {
                self.buffer.resize(end, 0);
            }

            self.buffer[offset..end].copy_from_slice(data);
            Ok(())
        }

        /// Resizes the buffer to `offset + len` bytes (zero-filling any extension), if that stays within the maximum size
        pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
            let end = offset + len;
            if end > self.max_size {
                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
            }

            self.buffer.resize(end, 0);
            Ok(())
        }

        /// truncates are always safe
        pub(super) fn truncate(&mut self, len: usize) {
            self.buffer.truncate(len)
        }

        /// returns the length of the underlying buffer
        pub(super) fn len(&self) -> usize {
            self.buffer.len()
        }

        /// Immutable reads are always safe
        pub(super) fn buffer(&'a self) -> &'a [u8] {
            self.buffer as &'a [u8]
        }

        /// Returns a reference to the internal buffer
        pub(super) fn into_bytes(self) -> &'a Vec<u8> {
            self.buffer
        }
    }
}

/// Encode DNS messages and resource record types.
pub struct BinEncoder<'a> {
    offset: usize,
    buffer: private::MaximalBuf<'a>,
    /// start of label pointers with their labels in fully decompressed form for easy comparison, smallvec here?
    name_pointers: Vec<(usize, Vec<u8>)>,
    mode: EncodeMode,
    canonical_names: bool,
}

impl<'a> BinEncoder<'a> {
    /// Create a new encoder with the Vec to fill
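    ///
    /// A minimal usage sketch:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes = Vec::new();
    /// let mut encoder = BinEncoder::new(&mut bytes);
    /// encoder.emit_u8(0x01).unwrap();
    /// assert_eq!(encoder.into_bytes(), &vec![0x01_u8]);
    /// ```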
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        Self::with_offset(buf, 0, EncodeMode::Normal)
    }

    /// Specify the mode for encoding
    ///
    /// # Arguments
    ///
    /// * `mode` - In Signing mode, canonical forms of all data are encoded; otherwise the format matches the source form
    pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
        Self::with_offset(buf, 0, mode)
    }

    /// Begins the encoder at the given offset
    ///
    /// This is used for pointers. If this encoder starts at some point further into
    ///  the sequence of bytes, the given offset is added to the location of any pointer
    ///  being written so that pointers stay correct relative to the start of the message.
    ///
    /// # Arguments
    ///
    /// * `offset` - index at which to start writing into the buffer
    pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
        if buf.capacity() < 512 {
            let reserve = 512 - buf.capacity();
            buf.reserve(reserve);
        }

        BinEncoder {
            offset: offset as usize,
            // TODO: add max_size to signature
            buffer: private::MaximalBuf::new(u16::MAX, buf),
            name_pointers: Vec::new(),
            mode,
            canonical_names: false,
        }
    }

    // TODO: move to constructor (kept for backward compatibility)
    /// Sets the maximum size of the buffer
    ///
    /// DNS message lengths must be smaller than `u16::MAX` due to hard limits in the protocol
    ///
    /// *this method will move to the constructor in a future release*
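    ///
    /// A minimal sketch of how the limit is enforced (see also the `test_max_size` test below):
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes = Vec::new();
    /// let mut encoder = BinEncoder::new(&mut bytes);
    /// encoder.set_max_size(1);
    /// assert!(encoder.emit(0).is_ok());
    /// assert!(encoder.emit(1).is_err());
    /// ```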
    pub fn set_max_size(&mut self, max: u16) {
        self.buffer.set_max_size(max);
    }

    /// Returns a reference to the internal buffer
    pub fn into_bytes(self) -> &'a Vec<u8> {
        self.buffer.into_bytes()
    }

    /// Returns the length of the buffer
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.buffer.buffer().is_empty()
    }

    /// Returns the current offset into the buffer
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// sets the current offset to the new offset
    pub fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }

    /// Returns the current Encoding mode
    pub fn mode(&self) -> EncodeMode {
        self.mode
    }

    /// If set to true, then names will be written into the buffer in canonical form
    pub fn set_canonical_names(&mut self, canonical_names: bool) {
        self.canonical_names = canonical_names;
    }

    /// Returns true if the encoder is writing in canonical form
    pub fn is_canonical_names(&self) -> bool {
        self.canonical_names
    }

    /// Emit all names in canonical form, useful for <https://tools.ietf.org/html/rfc3597>
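    ///
    /// A minimal usage sketch; the previous canonical-names setting is restored afterwards:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes = Vec::new();
    /// let mut encoder = BinEncoder::new(&mut bytes);
    /// encoder.with_canonical_names(|encoder| {
    ///     assert!(encoder.is_canonical_names());
    ///     encoder.emit_u16(0)
    /// }).unwrap();
    /// assert!(!encoder.is_canonical_names());
    /// ```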
    pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
        &mut self,
        f: F,
    ) -> ProtoResult<()> {
        let was_canonical = self.is_canonical_names();
        self.set_canonical_names(true);

        let res = f(self);
        self.set_canonical_names(was_canonical);

        res
    }

    // TODO: deprecate this...
    /// Reserve specified additional length in the internal buffer.
    ///
    /// This is currently a no-op; the underlying buffer grows as needed.
    pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
        Ok(())
    }

    /// trims to the current offset
    pub fn trim(&mut self) {
        let offset = self.offset;
        self.buffer.truncate(offset);
        self.name_pointers.retain(|&(start, _)| start < offset);
    }

    // /// returns an error if the maximum buffer size would be exceeded with the addition number of elements
    // ///
    // /// and reserves the additional space in the buffer
    // fn enforce_size(&mut self, additional: usize) -> ProtoResult<()> {
    //     if (self.buffer.len() + additional) > self.max_size {
    //         Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into())
    //     } else {
    //         self.reserve(additional);
    //         Ok(())
    //     }
    // }

    /// borrow a slice from the encoder
    pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
        assert!(start < self.offset);
        assert!(end <= self.buffer.len());
        &self.buffer.buffer()[start..end]
    }

    /// Stores a label pointer to an already written label
    ///
    /// The location is the current position in the buffer;
    ///  implicitly, it is expected that the name will be written to the stream after the current index.
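    ///
    /// Note: DNS compression pointers can only encode 14-bit offsets (RFC 1035, section 4.1.4),
    ///  so pointers are only recorded while the current offset is below 0x3FFF.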
    pub fn store_label_pointer(&mut self, start: usize, end: usize) {
        assert!(start <= (u16::MAX as usize));
        assert!(end <= (u16::MAX as usize));
        assert!(start <= end);
        if self.offset < 0x3FFF_usize {
            self.name_pointers
                .push((start, self.slice_of(start, end).to_vec())); // the next char will be at the len() location
        }
    }

    /// Looks up the index of an already written label
    pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
        let search = self.slice_of(start, end);

        for (match_start, matcher) in &self.name_pointers {
            if matcher.as_slice() == search {
                assert!(match_start <= &(u16::MAX as usize));
                return Some(*match_start as u16);
            }
        }

        None
    }

    /// Emit one byte into the buffer
    pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
        self.buffer.write(self.offset, &[b])?;
        self.offset += 1;
        Ok(())
    }

    /// Emits length-prefixed character data: a single length byte followed by the data itself.
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///   encoder.emit_character_data("abc").unwrap();
    /// }
    /// assert_eq!(bytes, vec![3,b'a',b'b',b'c']);
    /// ```
    pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
        let char_bytes = char_data.as_ref();
        if char_bytes.len() > 255 {
            return Err(ProtoErrorKind::CharacterDataTooLong {
                max: 255,
                len: char_bytes.len(),
            }
            .into());
        }

        self.emit_character_data_unrestricted(char_data)
    }

    /// Emit character data of unrestricted length
    ///
    /// Although character strings are typically restricted to being no longer than 255 characters,
    /// some modern standards allow longer strings to be encoded.
    pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
        // first the length is written
        // NOTE: the `as u8` cast truncates lengths above 255; `emit_character_data` validates the length first
        let data = data.as_ref();
        self.emit(data.len() as u8)?;
        self.write_slice(data)
    }

    /// Emit one byte into the buffer
    pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
        self.emit(data)
    }

    /// Writes a u16 in network byte order to the buffer
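    ///
    /// For example, `0x1234` is written as its big-endian bytes:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes = Vec::new();
    /// BinEncoder::new(&mut bytes).emit_u16(0x1234).unwrap();
    /// assert_eq!(bytes, vec![0x12, 0x34]);
    /// ```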
    pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes an i32 in network byte order to the buffer
    pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes a u32 in network byte order to the buffer
    pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.buffer.write(self.offset, data)?;
        self.offset += data.len();
        Ok(())
    }

    /// Writes the byte slice to the stream
    pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.write_slice(data)
    }

    /// Emits all the elements of an Iterator to the encoder
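    ///
    /// A minimal sketch using `Name`, which implements `BinEncodable`:
    ///
    /// ```
    /// use core::str::FromStr;
    /// use hickory_proto::rr::Name;
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let names = [Name::from_str("example.com.").unwrap()];
    /// let mut bytes = Vec::new();
    /// {
    ///   let mut encoder = BinEncoder::new(&mut bytes);
    ///   assert_eq!(encoder.emit_all(names.iter()).unwrap(), 1);
    /// }
    /// assert!(!bytes.is_empty());
    /// ```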
    pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        mut iter: I,
    ) -> ProtoResult<usize> {
        self.emit_iter(&mut iter)
    }

    // TODO: dedup with above emit_all
    /// Emits all the elements of an Iterator to the encoder
    pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
    where
        'e: 'r,
        I: Iterator<Item = &'r &'e E>,
        E: 'r + 'e + BinEncodable,
    {
        let mut iter = iter.cloned();
        self.emit_iter(&mut iter)
    }

    /// Emits all items in the iterator, returning the number emitted
    pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        iter: &mut I,
    ) -> ProtoResult<usize> {
        let mut count = 0;
        for i in iter {
            let rollback = self.set_rollback();
            i.emit(self).map_err(|e| {
                if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
                    // roll the offset back so the partially written record is dropped
                    rollback.rollback(self);
                    ProtoErrorKind::NotAllRecordsWritten { count }.into()
                } else {
                    e
                }
            })?;
            count += 1;
        }
        Ok(count)
    }

    /// Captures a location in the buffer to write back to later
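    ///
    /// A minimal sketch of reserving a `u16` length prefix and patching it afterwards:
    ///
    /// ```
    /// use hickory_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes = Vec::new();
    /// {
    ///   let mut encoder = BinEncoder::new(&mut bytes);
    ///   let place = encoder.place::<u16>().unwrap();
    ///   encoder.emit_vec(b"ab").unwrap();
    ///
    ///   let len = encoder.len_since_place(&place) as u16;
    ///   place.replace(&mut encoder, len).unwrap();
    /// }
    /// assert_eq!(bytes, vec![0, 2, b'a', b'b']);
    /// ```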
    pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
        let index = self.offset;
        let len = T::size_of();

        // resize the buffer
        self.buffer.reserve(self.offset, len)?;

        // update the offset
        self.offset += len;

        Ok(Place {
            start_index: index,
            phantom: PhantomData,
        })
    }

    /// calculates the length of data written since the place was created
    pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
        (self.offset - place.start_index) - place.size_of()
    }

    /// write back to a previously captured location
    pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
        // preserve current index
        let current_index = self.offset;

        // reset the current index back to place before writing
        //   this is an assert because it's a programming error for it to be wrong.
        assert!(place.start_index < current_index);
        self.offset = place.start_index;

        // emit the data to be written at this place
        let emit_result = data.emit(self);

        // double check that the correct number of bytes was written
        //   this is an assert because it's a programming error for it to be wrong.
        assert!((self.offset - place.start_index) == place.size_of());

        // reset to original location
        self.offset = current_index;

        emit_result
    }

    fn set_rollback(&self) -> Rollback {
        Rollback {
            rollback_index: self.offset(),
        }
    }
}

/// A trait to return the size of a type as it will be encoded in DNS
///
/// It does not necessarily equal `core::mem::size_of`, though it might, especially for primitives
pub trait EncodedSize: BinEncodable {
    /// Returns the size in bytes of the type as it is encoded in DNS
    fn size_of() -> usize;
}

impl EncodedSize for u16 {
    fn size_of() -> usize {
        2
    }
}

impl EncodedSize for Header {
    fn size_of() -> usize {
        Self::len()
    }
}

/// A reserved location in the buffer, created by [`BinEncoder::place`], that must be written back to
#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    start_index: usize,
    phantom: PhantomData<T>,
}

impl<T: EncodedSize> Place<T> {
    /// Writes `data` back to the reserved location, consuming this placeholder
    pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
        encoder.emit_at(self, data)
    }

    /// Returns the encoded size of `T` in bytes
    pub fn size_of(&self) -> usize {
        T::size_of()
    }
}

/// A type representing a rollback point in a stream
pub(crate) struct Rollback {
    rollback_index: usize,
}

impl Rollback {
    pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
        encoder.set_offset(self.rollback_index)
    }
}

/// In Verify mode, some items may be encoded differently; e.g. SIG0 records
///  should not be included in the additional count nor in the encoded data when verifying
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum EncodeMode {
    /// In signing mode records are written in canonical form
    Signing,
    /// Write records in standard format
    Normal,
}

#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{
            RData, Record, RecordType,
            rdata::{CNAME, SRV},
        },
        serialize::binary::BinDecodable,
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    #[test]
    fn test_label_compression_regression() {
        // https://github.com/hickory-dns/hickory-dns/issues/339
        /*
        ;; QUESTION SECTION:
        ;bluedot.is.autonavi.com.gds.alibabadns.com. IN AAAA

        ;; AUTHORITY SECTION:
        gds.alibabadns.com.     1799    IN      SOA     gdsns1.alibabadns.com. none. 2015080610 1800 600 3600 360
        */
        let data = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();

        assert_eq!(written, 4);
    }

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        let error = encoder.place::<u16>().unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        // name here should use compressed label from target in previous records
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com.").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com.").unwrap())),
        ));

        let bytes = msg.to_vec().unwrap();
        // label is compressed pointing to target, would be 145 otherwise
        assert_eq!(bytes.len(), 130);
        // check re-serializing
        assert!(Message::from_vec(&bytes).is_ok());
    }

    #[test]
    fn test_fuzzed() {
        const MESSAGE: &[u8] = include_bytes!("../../../tests/test-data/fuzz-long.rdata");
        let msg = Message::from_bytes(MESSAGE).unwrap();
        msg.to_bytes().unwrap();
    }
}