hickory_proto/op/
message.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Basic protocol message for DNS
9
10use alloc::{boxed::Box, fmt, vec::Vec};
11use core::{iter, mem, ops::Deref};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use tracing::{debug, warn};
16
17use crate::{
18    error::*,
19    op::{Edns, Header, MessageType, OpCode, Query, ResponseCode},
20    rr::{Record, RecordType},
21    serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
22    xfer::DnsResponse,
23};
24
25/// The basic request and response data structure, used for all DNS protocols.
26///
27/// [RFC 1035, DOMAIN NAMES - IMPLEMENTATION AND SPECIFICATION, November 1987](https://tools.ietf.org/html/rfc1035)
28///
29/// ```text
30/// 4.1. Format
31///
32/// All communications inside of the domain protocol are carried in a single
33/// format called a message.  The top level format of message is divided
34/// into 5 sections (some of which are empty in certain cases) shown below:
35///
36///     +--------------------------+
37///     |        Header            |
38///     +--------------------------+
39///     |  Question / Zone         | the question for the name server
40///     +--------------------------+
41///     |   Answer  / Prerequisite | RRs answering the question
42///     +--------------------------+
43///     | Authority / Update       | RRs pointing toward an authority
44///     +--------------------------+
45///     |      Additional          | RRs holding additional information
46///     +--------------------------+
47///
48/// The header section is always present.  The header includes fields that
49/// specify which of the remaining sections are present, and also specify
50/// whether the message is a query or a response, a standard query or some
51/// other opcode, etc.
52///
53/// The names of the sections after the header are derived from their use in
54/// standard queries.  The question section contains fields that describe a
55/// question to a name server.  These fields are a query type (QTYPE), a
56/// query class (QCLASS), and a query domain name (QNAME).  The last three
57/// sections have the same format: a possibly empty list of concatenated
58/// resource records (RRs).  The answer section contains RRs that answer the
59/// question; the authority section contains RRs that point toward an
60/// authoritative name server; the additional records section contains RRs
61/// which relate to the query, but are not strictly answers for the
62/// question.
63/// ```
64///
65/// By default Message is a Query. Use the Message::as_update() to create and update, or
66///  Message::new_update()
67#[derive(Clone, Debug, PartialEq, Eq, Default)]
68#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
69pub struct Message {
70    header: Header,
71    queries: Vec<Query>,
72    answers: Vec<Record>,
73    name_servers: Vec<Record>,
74    additionals: Vec<Record>,
75    signature: Vec<Record>,
76    edns: Option<Edns>,
77}
78
79/// Returns a new Header with accurate counts for each Message section
80pub fn update_header_counts(
81    current_header: &Header,
82    is_truncated: bool,
83    counts: HeaderCounts,
84) -> Header {
85    assert!(counts.query_count <= u16::MAX as usize);
86    assert!(counts.answer_count <= u16::MAX as usize);
87    assert!(counts.nameserver_count <= u16::MAX as usize);
88    assert!(counts.additional_count <= u16::MAX as usize);
89
90    // TODO: should the function just take by value?
91    let mut header = *current_header;
92    header
93        .set_query_count(counts.query_count as u16)
94        .set_answer_count(counts.answer_count as u16)
95        .set_name_server_count(counts.nameserver_count as u16)
96        .set_additional_count(counts.additional_count as u16)
97        .set_truncated(is_truncated);
98
99    header
100}
101
/// Tracks the counts of the records in the Message.
///
/// This is only used internally during serialization. Each field is the length of one message
/// section. `update_header_counts` checks that each value fits in the header's 16-bit count
/// fields.
///
/// `Default` yields all-zero counts.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct HeaderCounts {
    /// The number of queries in the Message
    pub query_count: usize,
    /// The number of answers in the Message
    pub answer_count: usize,
    /// The number of nameservers or authorities in the Message
    pub nameserver_count: usize,
    /// The number of additional records in the Message
    pub additional_count: usize,
}
116
117impl Message {
118    /// Returns a new "empty" Message
119    pub fn new() -> Self {
120        Self {
121            header: Header::new(),
122            queries: Vec::new(),
123            answers: Vec::new(),
124            name_servers: Vec::new(),
125            additionals: Vec::new(),
126            signature: Vec::new(),
127            edns: None,
128        }
129    }
130
131    /// Returns a Message constructed with error details to return to a client
132    ///
133    /// # Arguments
134    ///
135    /// * `id` - message id should match the request message id
136    /// * `op_code` - operation of the request
137    /// * `response_code` - the error code for the response
138    pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
139        let mut message = Self::new();
140        message
141            .set_message_type(MessageType::Response)
142            .set_id(id)
143            .set_response_code(response_code)
144            .set_op_code(op_code);
145
146        message
147    }
148
149    /// Truncates a Message, this blindly removes all response fields and sets truncated to `true`
150    pub fn truncate(&self) -> Self {
151        // copy header
152        let mut header = self.header;
153        header.set_truncated(true);
154        header
155            .set_additional_count(0)
156            .set_answer_count(0)
157            .set_name_server_count(0);
158
159        let mut msg = Self::new();
160        // drops additional/answer/nameservers/signature
161        // adds query/OPT
162        msg.add_queries(self.queries().iter().cloned());
163        if let Some(edns) = self.extensions().clone() {
164            msg.set_edns(edns);
165        }
166        // set header
167        msg.set_header(header);
168
169        // TODO, perhaps just quickly add a few response records here? that we know would fit?
170        msg
171    }
172
173    /// Sets the `Header` with provided
174    pub fn set_header(&mut self, header: Header) -> &mut Self {
175        self.header = header;
176        self
177    }
178
179    /// see `Header::set_id`
180    pub fn set_id(&mut self, id: u16) -> &mut Self {
181        self.header.set_id(id);
182        self
183    }
184
185    /// see `Header::set_message_type`
186    pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
187        self.header.set_message_type(message_type);
188        self
189    }
190
191    /// see `Header::set_op_code`
192    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
193        self.header.set_op_code(op_code);
194        self
195    }
196
197    /// see `Header::set_authoritative`
198    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
199        self.header.set_authoritative(authoritative);
200        self
201    }
202
203    /// see `Header::set_truncated`
204    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
205        self.header.set_truncated(truncated);
206        self
207    }
208
209    /// see `Header::set_recursion_desired`
210    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
211        self.header.set_recursion_desired(recursion_desired);
212        self
213    }
214
215    /// see `Header::set_recursion_available`
216    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
217        self.header.set_recursion_available(recursion_available);
218        self
219    }
220
221    /// see `Header::set_authentic_data`
222    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
223        self.header.set_authentic_data(authentic_data);
224        self
225    }
226
227    /// see `Header::set_checking_disabled`
228    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
229        self.header.set_checking_disabled(checking_disabled);
230        self
231    }
232
233    /// see `Header::set_response_code`
234    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
235        self.header.set_response_code(response_code);
236        self
237    }
238
239    /// see `Header::set_query_count`
240    ///
241    /// this count will be ignored during serialization,
242    /// where the length of the associated records will be used instead.
243    pub fn set_query_count(&mut self, query_count: u16) -> &mut Self {
244        self.header.set_query_count(query_count);
245        self
246    }
247
248    /// see `Header::set_answer_count`
249    ///
250    /// this count will be ignored during serialization,
251    /// where the length of the associated records will be used instead.
252    pub fn set_answer_count(&mut self, answer_count: u16) -> &mut Self {
253        self.header.set_answer_count(answer_count);
254        self
255    }
256
257    /// see `Header::set_name_server_count`
258    ///
259    /// this count will be ignored during serialization,
260    /// where the length of the associated records will be used instead.
261    pub fn set_name_server_count(&mut self, name_server_count: u16) -> &mut Self {
262        self.header.set_name_server_count(name_server_count);
263        self
264    }
265
266    /// see `Header::set_additional_count`
267    ///
268    /// this count will be ignored during serialization,
269    /// where the length of the associated records will be used instead.
270    pub fn set_additional_count(&mut self, additional_count: u16) -> &mut Self {
271        self.header.set_additional_count(additional_count);
272        self
273    }
274
275    /// Add a query to the Message, either the query response from the server, or the request Query.
276    pub fn add_query(&mut self, query: Query) -> &mut Self {
277        self.queries.push(query);
278        self
279    }
280
281    /// Adds an iterator over a set of Queries to be added to the message
282    pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
283    where
284        Q: IntoIterator<Item = Query, IntoIter = I>,
285        I: Iterator<Item = Query>,
286    {
287        for query in queries {
288            self.add_query(query);
289        }
290
291        self
292    }
293
294    /// Add an answer to the Message
295    pub fn add_answer(&mut self, record: Record) -> &mut Self {
296        self.answers.push(record);
297        self
298    }
299
300    /// Add all the records from the iterator to the answers section of the Message
301    pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
302    where
303        R: IntoIterator<Item = Record, IntoIter = I>,
304        I: Iterator<Item = Record>,
305    {
306        for record in records {
307            self.add_answer(record);
308        }
309
310        self
311    }
312
313    /// Sets the answers to the specified set of Records.
314    ///
315    /// # Panics
316    ///
317    /// Will panic if answer records are already associated to the message.
318    pub fn insert_answers(&mut self, records: Vec<Record>) {
319        assert!(self.answers.is_empty());
320        self.answers = records;
321    }
322
323    /// Add a name server record to the Message
324    pub fn add_name_server(&mut self, record: Record) -> &mut Self {
325        self.name_servers.push(record);
326        self
327    }
328
329    /// Add all the records in the Iterator to the name server section of the message
330    pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
331    where
332        R: IntoIterator<Item = Record, IntoIter = I>,
333        I: Iterator<Item = Record>,
334    {
335        for record in records {
336            self.add_name_server(record);
337        }
338
339        self
340    }
341
342    /// Sets the name_servers to the specified set of Records.
343    ///
344    /// # Panics
345    ///
346    /// Will panic if name_servers records are already associated to the message.
347    pub fn insert_name_servers(&mut self, records: Vec<Record>) {
348        assert!(self.name_servers.is_empty());
349        self.name_servers = records;
350    }
351
352    /// Add an additional Record to the message
353    pub fn add_additional(&mut self, record: Record) -> &mut Self {
354        self.additionals.push(record);
355        self
356    }
357
358    /// Add all the records from the iterator to the additionals section of the Message
359    pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
360    where
361        R: IntoIterator<Item = Record, IntoIter = I>,
362        I: Iterator<Item = Record>,
363    {
364        for record in records {
365            self.add_additional(record);
366        }
367
368        self
369    }
370
371    /// Sets the additional to the specified set of Records.
372    ///
373    /// # Panics
374    ///
375    /// Will panic if additional records are already associated to the message.
376    pub fn insert_additionals(&mut self, records: Vec<Record>) {
377        assert!(self.additionals.is_empty());
378        self.additionals = records;
379    }
380
381    /// Add the EDNS section to the Message
382    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
383        self.edns = Some(edns);
384        self
385    }
386
387    /// Add a SIG0 record, i.e. sign this message
388    ///
389    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
390    #[cfg(feature = "__dnssec")]
391    pub fn add_sig0(&mut self, record: Record) -> &mut Self {
392        assert_eq!(RecordType::SIG, record.record_type());
393        self.signature.push(record);
394        self
395    }
396
397    /// Add a TSIG record, i.e. authenticate this message
398    ///
399    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
400    #[cfg(feature = "__dnssec")]
401    pub fn add_tsig(&mut self, record: Record) -> &mut Self {
402        assert_eq!(RecordType::TSIG, record.record_type());
403        self.signature.push(record);
404        self
405    }
406
407    /// Gets the header of the Message
408    pub fn header(&self) -> &Header {
409        &self.header
410    }
411
412    /// see `Header::id()`
413    pub fn id(&self) -> u16 {
414        self.header.id()
415    }
416
417    /// see `Header::message_type()`
418    pub fn message_type(&self) -> MessageType {
419        self.header.message_type()
420    }
421
422    /// see `Header::op_code()`
423    pub fn op_code(&self) -> OpCode {
424        self.header.op_code()
425    }
426
427    /// see `Header::authoritative()`
428    pub fn authoritative(&self) -> bool {
429        self.header.authoritative()
430    }
431
432    /// see `Header::truncated()`
433    pub fn truncated(&self) -> bool {
434        self.header.truncated()
435    }
436
437    /// see `Header::recursion_desired()`
438    pub fn recursion_desired(&self) -> bool {
439        self.header.recursion_desired()
440    }
441
442    /// see `Header::recursion_available()`
443    pub fn recursion_available(&self) -> bool {
444        self.header.recursion_available()
445    }
446
447    /// see `Header::authentic_data()`
448    pub fn authentic_data(&self) -> bool {
449        self.header.authentic_data()
450    }
451
452    /// see `Header::checking_disabled()`
453    pub fn checking_disabled(&self) -> bool {
454        self.header.checking_disabled()
455    }
456
457    /// # Return value
458    ///
459    /// The `ResponseCode`, if this is an EDNS message then this will join the section from the OPT
460    ///  record to create the EDNS `ResponseCode`
461    pub fn response_code(&self) -> ResponseCode {
462        self.header.response_code()
463    }
464
465    /// Returns the query from this Message.
466    ///
467    /// In almost all cases, a Message will only contain one query. This is a convenience function to get the single query.
468    /// See the alternative `queries*` methods for the raw set of queries in the Message
469    pub fn query(&self) -> Option<&Query> {
470        self.queries.first()
471    }
472
473    /// ```text
474    /// Question        Carries the query name and other query parameters.
475    /// ```
476    pub fn queries(&self) -> &[Query] {
477        &self.queries
478    }
479
480    /// Provides mutable access to `queries`
481    pub fn queries_mut(&mut self) -> &mut Vec<Query> {
482        &mut self.queries
483    }
484
485    /// Removes all the answers from the Message
486    pub fn take_queries(&mut self) -> Vec<Query> {
487        mem::take(&mut self.queries)
488    }
489
490    /// ```text
491    /// Answer          Carries RRs which directly answer the query.
492    /// ```
493    pub fn answers(&self) -> &[Record] {
494        &self.answers
495    }
496
497    /// Provides mutable access to `answers`
498    pub fn answers_mut(&mut self) -> &mut Vec<Record> {
499        &mut self.answers
500    }
501
502    /// Removes all the answers from the Message
503    pub fn take_answers(&mut self) -> Vec<Record> {
504        mem::take(&mut self.answers)
505    }
506
507    /// ```text
508    /// Authority       Carries RRs which describe other authoritative servers.
509    ///                 May optionally carry the SOA RR for the authoritative
510    ///                 data in the answer section.
511    /// ```
512    pub fn name_servers(&self) -> &[Record] {
513        &self.name_servers
514    }
515
516    /// Provides mutable access to `name_servers`
517    pub fn name_servers_mut(&mut self) -> &mut Vec<Record> {
518        &mut self.name_servers
519    }
520
521    /// Remove the name servers from the Message
522    pub fn take_name_servers(&mut self) -> Vec<Record> {
523        mem::take(&mut self.name_servers)
524    }
525
526    /// ```text
527    /// Additional      Carries RRs which may be helpful in using the RRs in the
528    ///                 other sections.
529    /// ```
530    pub fn additionals(&self) -> &[Record] {
531        &self.additionals
532    }
533
534    /// Provides mutable access to `additionals`
535    pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
536        &mut self.additionals
537    }
538
539    /// Remove the additional Records from the Message
540    pub fn take_additionals(&mut self) -> Vec<Record> {
541        mem::take(&mut self.additionals)
542    }
543
544    /// All sections chained
545    pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
546        self.answers
547            .iter()
548            .chain(self.name_servers().iter())
549            .chain(self.additionals.iter())
550    }
551
552    /// [RFC 6891, EDNS(0) Extensions, April 2013](https://tools.ietf.org/html/rfc6891#section-6.1.1)
553    ///
554    /// ```text
555    /// 6.1.1.  Basic Elements
556    ///
557    ///  An OPT pseudo-RR (sometimes called a meta-RR) MAY be added to the
558    ///  additional data section of a request.
559    ///
560    ///  The OPT RR has RR type 41.
561    ///
562    ///  If an OPT record is present in a received request, compliant
563    ///  responders MUST include an OPT record in their respective responses.
564    ///
565    ///  An OPT record does not carry any DNS data.  It is used only to
566    ///  contain control information pertaining to the question-and-answer
567    ///  sequence of a specific transaction.  OPT RRs MUST NOT be cached,
568    ///  forwarded, or stored in or loaded from Zone Files.
569    ///
570    ///  The OPT RR MAY be placed anywhere within the additional data section.
571    ///  When an OPT RR is included within any DNS message, it MUST be the
572    ///  only OPT RR in that message.  If a query message with more than one
573    ///  OPT RR is received, a FORMERR (RCODE=1) MUST be returned.  The
574    ///  placement flexibility for the OPT RR does not override the need for
575    ///  the TSIG or SIG(0) RRs to be the last in the additional section
576    ///  whenever they are present.
577    /// ```
578    /// # Return value
579    ///
580    /// Optionally returns a reference to EDNS section
581    #[deprecated(note = "Please use `extensions()`")]
582    pub fn edns(&self) -> Option<&Edns> {
583        self.edns.as_ref()
584    }
585
586    /// Optionally returns mutable reference to EDNS section
587    #[deprecated(
588        note = "Please use `extensions_mut()`. You can chain `.get_or_insert_with(Edns::new)` to recover original behavior of adding Edns if not present"
589    )]
590    pub fn edns_mut(&mut self) -> &mut Edns {
591        if self.edns.is_none() {
592            self.set_edns(Edns::new());
593        }
594        self.edns.as_mut().unwrap()
595    }
596
597    /// Returns reference of Edns section
598    pub fn extensions(&self) -> &Option<Edns> {
599        &self.edns
600    }
601
602    /// Returns mutable reference of Edns section
603    pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
604        &mut self.edns
605    }
606
607    /// # Return value
608    ///
609    /// the max payload value as it's defined in the EDNS section.
610    pub fn max_payload(&self) -> u16 {
611        let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
612        if max_size < 512 { 512 } else { max_size }
613    }
614
615    /// # Return value
616    ///
617    /// the version as defined in the EDNS record
618    pub fn version(&self) -> u8 {
619        self.edns.as_ref().map_or(0, Edns::version)
620    }
621
622    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
623    ///
624    /// ```text
625    /// A DNS request may be optionally signed by including one or more SIGs
626    ///  at the end of the query. Such SIGs are identified by having a "type
627    ///  covered" field of zero. They sign the preceding DNS request message
628    ///  including DNS header but not including the IP header or any request
629    ///  SIGs at the end and before the request RR counts have been adjusted
630    ///  for the inclusions of any request SIG(s).
631    /// ```
632    ///
633    /// # Return value
634    ///
635    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
636    // comportment change: can now return TSIG instead of SIG0. Maybe should get deprecated in
637    // favor of signature() which have more correct naming ?
638    pub fn sig0(&self) -> &[Record] {
639        &self.signature
640    }
641
642    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
643    ///
644    /// ```text
645    /// A DNS request may be optionally signed by including one or more SIGs
646    ///  at the end of the query. Such SIGs are identified by having a "type
647    ///  covered" field of zero. They sign the preceding DNS request message
648    ///  including DNS header but not including the IP header or any request
649    ///  SIGs at the end and before the request RR counts have been adjusted
650    ///  for the inclusions of any request SIG(s).
651    /// ```
652    ///
653    /// # Return value
654    ///
655    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
656    pub fn signature(&self) -> &[Record] {
657        &self.signature
658    }
659
660    /// Remove signatures from the Message
661    pub fn take_signature(&mut self) -> Vec<Record> {
662        mem::take(&mut self.signature)
663    }
664
665    // TODO: only necessary in tests, should it be removed?
666    /// this is necessary to match the counts in the header from the record sections
667    ///  this happens implicitly on write_to, so no need to call before write_to
668    #[cfg(test)]
669    pub fn update_counts(&mut self) -> &mut Self {
670        self.header = update_header_counts(
671            &self.header,
672            self.truncated(),
673            HeaderCounts {
674                query_count: self.queries.len(),
675                answer_count: self.answers.len(),
676                nameserver_count: self.name_servers.len(),
677                additional_count: self.additionals.len(),
678            },
679        );
680        self
681    }
682
683    /// Attempts to read the specified number of `Query`s
684    pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
685        let mut queries = Vec::with_capacity(count);
686        for _ in 0..count {
687            queries.push(Query::read(decoder)?);
688        }
689        Ok(queries)
690    }
691
692    /// Attempts to read the specified number of records
693    ///
694    /// # Returns
695    ///
696    /// This returns a tuple of first standard Records, then a possibly associated Edns, and then finally any optionally associated SIG0 and TSIG records.
697    #[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
698    pub fn read_records(
699        decoder: &mut BinDecoder<'_>,
700        count: usize,
701        is_additional: bool,
702    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
703        let mut records: Vec<Record> = Vec::with_capacity(count);
704        let mut edns: Option<Edns> = None;
705        let mut sigs: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });
706
707        // sig0 must be last, once this is set, disable.
708        let mut saw_sig0 = false;
709        // tsig must be last, once this is set, disable.
710        let mut saw_tsig = false;
711        for _ in 0..count {
712            let record = Record::read(decoder)?;
713            if saw_tsig {
714                return Err("tsig must be final resource record".into());
715            } // TSIG must be last and multiple TSIG records are not allowed
716            if !is_additional {
717                if saw_sig0 {
718                    return Err("sig0 must be final resource record".into());
719                } // SIG0 must be last
720                records.push(record)
721            } else {
722                match record.record_type() {
723                    #[cfg(feature = "__dnssec")]
724                    RecordType::SIG => {
725                        saw_sig0 = true;
726                        sigs.push(record);
727                    }
728                    #[cfg(feature = "__dnssec")]
729                    RecordType::TSIG => {
730                        if saw_sig0 {
731                            return Err("sig0 must be final resource record".into());
732                        } // SIG0 must be last
733                        saw_tsig = true;
734                        sigs.push(record);
735                    }
736                    RecordType::OPT => {
737                        if saw_sig0 {
738                            return Err("sig0 must be final resource record".into());
739                        } // SIG0 must be last
740                        if edns.is_some() {
741                            return Err("more than one edns record present".into());
742                        }
743                        edns = Some((&record).into());
744                    }
745                    _ => {
746                        if saw_sig0 {
747                            return Err("sig0 must be final resource record".into());
748                        } // SIG0 must be last
749                        records.push(record);
750                    }
751                }
752            }
753        }
754
755        Ok((records, edns, sigs))
756    }
757
758    /// Decodes a message from the buffer.
759    pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
760        let mut decoder = BinDecoder::new(buffer);
761        Self::read(&mut decoder)
762    }
763
764    /// Encodes the Message into a buffer
765    pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
766        // TODO: this feels like the right place to verify the max packet size of the message,
767        //  will need to update the header for truncation and the lengths if we send less than the
768        //  full response. This needs to conform with the EDNS settings of the server...
769        let mut buffer = Vec::with_capacity(512);
770        {
771            let mut encoder = BinEncoder::new(&mut buffer);
772            self.emit(&mut encoder)?;
773        }
774
775        Ok(buffer)
776    }
777
778    /// Finalize the message prior to sending.
779    ///
780    /// Subsequent to calling this, the Message should not change.
781    #[allow(clippy::match_single_binding)]
782    pub fn finalize(
783        &mut self,
784        finalizer: &dyn MessageFinalizer,
785        inception_time: u32,
786    ) -> ProtoResult<Option<MessageVerifier>> {
787        debug!("finalizing message: {:?}", self);
788        let (finals, verifier): (Vec<Record>, Option<MessageVerifier>) =
789            finalizer.finalize_message(self, inception_time)?;
790
791        // append all records to message
792        for fin in finals {
793            match fin.record_type() {
794                // SIG0's are special, and come at the very end of the message
795                #[cfg(feature = "__dnssec")]
796                RecordType::SIG => self.add_sig0(fin),
797                #[cfg(feature = "__dnssec")]
798                RecordType::TSIG => self.add_tsig(fin),
799                _ => self.add_additional(fin),
800            };
801        }
802
803        Ok(verifier)
804    }
805
806    /// Consumes `Message` and returns into components
807    pub fn into_parts(self) -> MessageParts {
808        self.into()
809    }
810}
811
812/// Consumes `Message` giving public access to fields in `Message` so they can be
813/// destructured and taken by value
814/// ```rust
815/// use hickory_proto::op::{Message, MessageParts};
816///
817///  let msg = Message::new();
818///  let MessageParts { queries, .. } = msg.into_parts();
819/// ```
820#[derive(Clone, Debug, PartialEq, Eq, Default)]
821pub struct MessageParts {
822    /// message header
823    pub header: Header,
824    /// message queries
825    pub queries: Vec<Query>,
826    /// message answers
827    pub answers: Vec<Record>,
828    /// message name_servers
829    pub name_servers: Vec<Record>,
830    /// message additional records
831    pub additionals: Vec<Record>,
832    /// sig0 or tsig
833    // this can now contains TSIG too. It should probably be renamed to reflect that, but it's a
834    // breaking change
835    pub sig0: Vec<Record>,
836    /// optional edns records
837    pub edns: Option<Edns>,
838}
839
840impl From<Message> for MessageParts {
841    fn from(msg: Message) -> Self {
842        let Message {
843            header,
844            queries,
845            answers,
846            name_servers,
847            additionals,
848            signature,
849            edns,
850        } = msg;
851        Self {
852            header,
853            queries,
854            answers,
855            name_servers,
856            additionals,
857            sig0: signature,
858            edns,
859        }
860    }
861}
862
863impl From<MessageParts> for Message {
864    fn from(msg: MessageParts) -> Self {
865        let MessageParts {
866            header,
867            queries,
868            answers,
869            name_servers,
870            additionals,
871            sig0,
872            edns,
873        } = msg;
874        Self {
875            header,
876            queries,
877            answers,
878            name_servers,
879            additionals,
880            signature: sig0,
881            edns,
882        }
883    }
884}
885
886impl Deref for Message {
887    type Target = Header;
888
889    fn deref(&self) -> &Self::Target {
890        &self.header
891    }
892}
893
/// Alias for a function verifying if a message is properly signed
///
/// The boxed callback receives raw message bytes and returns the decoded [`DnsResponse`]. It
/// returns an error otherwise, presumably when verification fails; confirm against the
/// implementations. Being `FnMut`, it may keep state between calls. The `Send` bound lets it be
/// moved to another thread.
pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
896
897/// A trait for performing final amendments to a Message before it is sent.
898///
899/// An example of this is a SIG0 signer, which needs the final form of the message,
900///  but then needs to attach additional data to the body of the message.
901pub trait MessageFinalizer: Send + Sync + 'static {
902    /// The message taken in should be processed and then return [`Record`]s which should be
903    ///  appended to the additional section of the message.
904    ///
905    /// # Arguments
906    ///
907    /// * `message` - message to process
908    /// * `current_time` - the current time as specified by the system, it's not recommended to read the current time as that makes testing complicated.
909    ///
910    /// # Return
911    ///
912    /// A vector to append to the additionals section of the message, sorted in the order as they should appear in the message.
913    fn finalize_message(
914        &self,
915        message: &Message,
916        current_time: u32,
917    ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)>;
918
919    /// Return whether the message requires further processing before being sent
920    /// By default, returns true for AXFR and IXFR queries, and Update and Notify messages
921    fn should_finalize_message(&self, message: &Message) -> bool {
922        [OpCode::Update, OpCode::Notify].contains(&message.op_code())
923            || message
924                .queries()
925                .iter()
926                .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
927    }
928}
929
930/// Returns the count written and a boolean if it was truncated
931pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
932    result.map(|count| (count, false)).or_else(|e| {
933        if let ProtoErrorKind::NotAllRecordsWritten { count } = e.kind() {
934            return Ok((*count, true));
935        }
936
937        Err(e)
938    })
939}
940
/// A trait that defines types which can be emitted as a set, with the associated count returned.
///
/// Used by `emit_message_parts` for each message section. Any iterator over references to
/// [`BinEncodable`] items implements it through the blanket impl below.
pub trait EmitAndCount {
    /// Emit self to the encoder and return the count of items
    ///
    /// A partial write may be reported as a `ProtoErrorKind::NotAllRecordsWritten` error,
    /// which `count_was_truncated` turns into a `(count, true)` pair.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
}
946
947impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
948    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
949        encoder.emit_all(self)
950    }
951}
952
953/// Emits the different sections of a message properly
954///
955/// # Return
956///
957/// In the case of a successful emit, the final header (updated counts, etc) is returned for help with logging, etc.
958#[allow(clippy::too_many_arguments)]
959pub fn emit_message_parts<Q, A, N, D>(
960    header: &Header,
961    queries: &mut Q,
962    answers: &mut A,
963    name_servers: &mut N,
964    additionals: &mut D,
965    edns: Option<&Edns>,
966    signature: &[Record],
967    encoder: &mut BinEncoder<'_>,
968) -> ProtoResult<Header>
969where
970    Q: EmitAndCount,
971    A: EmitAndCount,
972    N: EmitAndCount,
973    D: EmitAndCount,
974{
975    let include_signature: bool = encoder.mode() != EncodeMode::Signing;
976    let place = encoder.place::<Header>()?;
977
978    let query_count = queries.emit(encoder)?;
979    // TODO: need to do something on max records
980    //  return offset of last emitted record.
981    let answer_count = count_was_truncated(answers.emit(encoder))?;
982    let nameserver_count = count_was_truncated(name_servers.emit(encoder))?;
983    let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
984
985    if let Some(mut edns) = edns.cloned() {
986        // need to commit the error code
987        edns.set_rcode_high(header.response_code().high());
988
989        let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
990        additional_count.0 += count.0;
991        additional_count.1 |= count.1;
992    } else if header.response_code().high() > 0 {
993        warn!(
994            "response code: {} for request: {} requires EDNS but none available",
995            header.response_code(),
996            header.id()
997        );
998    }
999
1000    // this is a little hacky, but if we are Verifying a signature, i.e. the original Message
1001    //  then the SIG0 records should not be encoded and the edns record (if it exists) is already
1002    //  part of the additionals section.
1003    if include_signature {
1004        let count = count_was_truncated(encoder.emit_all(signature.iter()))?;
1005        additional_count.0 += count.0;
1006        additional_count.1 |= count.1;
1007    }
1008
1009    let counts = HeaderCounts {
1010        query_count,
1011        answer_count: answer_count.0,
1012        nameserver_count: nameserver_count.0,
1013        additional_count: additional_count.0,
1014    };
1015    let was_truncated =
1016        header.truncated() || answer_count.1 || nameserver_count.1 || additional_count.1;
1017
1018    let final_header = update_header_counts(header, was_truncated, counts);
1019    place.replace(encoder, final_header)?;
1020    Ok(final_header)
1021}
1022
1023impl BinEncodable for Message {
1024    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1025        emit_message_parts(
1026            &self.header,
1027            &mut self.queries.iter(),
1028            &mut self.answers.iter(),
1029            &mut self.name_servers.iter(),
1030            &mut self.additionals.iter(),
1031            self.edns.as_ref(),
1032            &self.signature,
1033            encoder,
1034        )?;
1035
1036        Ok(())
1037    }
1038}
1039
1040impl<'r> BinDecodable<'r> for Message {
1041    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1042        let mut header = Header::read(decoder)?;
1043
1044        // TODO: return just header, and in the case of the rest of message getting an error.
1045        //  this could improve error detection while decoding.
1046
1047        // get the questions
1048        let count = header.query_count() as usize;
1049        let mut queries = Vec::with_capacity(count);
1050        for _ in 0..count {
1051            queries.push(Query::read(decoder)?);
1052        }
1053
1054        // get all counts before header moves
1055        let answer_count = header.answer_count() as usize;
1056        let name_server_count = header.name_server_count() as usize;
1057        let additional_count = header.additional_count() as usize;
1058
1059        let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1060        let (name_servers, _, _) = Self::read_records(decoder, name_server_count, false)?;
1061        let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1062
1063        // need to grab error code from EDNS (which might have a higher value)
1064        if let Some(edns) = &edns {
1065            let high_response_code = edns.rcode_high();
1066            header.merge_response_code(high_response_code);
1067        }
1068
1069        Ok(Self {
1070            header,
1071            queries,
1072            answers,
1073            name_servers,
1074            additionals,
1075            signature,
1076            edns,
1077        })
1078    }
1079}
1080
1081impl fmt::Display for Message {
1082    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1083        let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1084            for d in slice {
1085                writeln!(f, ";; {d}")?;
1086            }
1087
1088            Ok(())
1089        };
1090
1091        let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1092            for d in slice {
1093                writeln!(f, "{d}")?;
1094            }
1095
1096            Ok(())
1097        };
1098
1099        writeln!(f, "; header {header}", header = self.header())?;
1100
1101        if let Some(edns) = self.extensions() {
1102            writeln!(f, "; edns {edns}")?;
1103        }
1104
1105        writeln!(f, "; query")?;
1106        write_query(self.queries(), f)?;
1107
1108        if self.header().message_type() == MessageType::Response
1109            || self.header().op_code() == OpCode::Update
1110        {
1111            writeln!(f, "; answers {}", self.answer_count())?;
1112            write_slice(self.answers(), f)?;
1113            writeln!(f, "; nameservers {}", self.name_server_count())?;
1114            write_slice(self.name_servers(), f)?;
1115            writeln!(f, "; additionals {}", self.additional_count())?;
1116            write_slice(self.additionals(), f)?;
1117        }
1118
1119        Ok(())
1120    }
1121}
1122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(false)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail);

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail)
            .add_query(Query::new())
            .update_counts(); // we're not testing the query parsing, just message

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());
        message.update_counts();

        test_emit_and_read(message);
    }

    /// Asserts that `message` is unchanged by an emit/read round-trip.
    fn test_emit_and_read(message: Message) {
        let got = get_message_after_emitting_and_reading(&message);
        assert_eq!(got, message);
    }

    #[test]
    fn test_header_counts_correction_after_emit_read() {
        let mut message = Message::new();

        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());

        // Deliberately skip `update_counts()` and store wrong counts. This checks that the
        // header counts are corrected when the message is emitted and read back.
        message.set_query_count(1);
        message.set_answer_count(5);
        message.set_name_server_count(5);

        let got = get_message_after_emitting_and_reading(&message);

        // make comparison
        assert_eq!(got.query_count(), 0);
        assert_eq!(got.answer_count(), 1);
        assert_eq!(got.name_server_count(), 1);
        assert_eq!(got.additional_count(), 1);
    }

    /// Encodes `message` to wire format and decodes the bytes back into a new `Message`.
    fn get_message_after_emitting_and_reading(message: &Message) -> Message {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).expect("failed to emit message");
        }

        let mut decoder = BinDecoder::new(&byte_vec);
        Message::read(&mut decoder).expect("failed to read message")
    }

    #[test]
    fn test_legit_message() {
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81, 0x80, // id = 4096, response, op=query, recursion_desired, recursion_available, no_error
            0x00, 0x01, 0x00, 0x01, // 1 query, 1 answer,
            0x00, 0x00, 0x00, 0x00, // 0 nameservers, 0 additional record
            0x03, b'w', b'w', b'w', // query --- www.example.com
            0x07, b'e', b'x', b'a', //
            b'm', b'p', b'l', b'e', //
            0x03, b'c', b'o', b'm', //
            0x00,                   // 0 = endname
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
            0xC0, 0x0C,             // name pointer to www.example.com
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
            0x00, 0x00, 0x00, 0x02, // TTL = 2 seconds
            0x00, 0x04,             // record length = 4 (ipv4 address)
            0x5D, 0xB8, 0xD7, 0x0E, // address = 93.184.215.14
        ];

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4_096);

        // re-encode and decode again: the id must survive the round-trip
        let message = get_message_after_emitting_and_reading(&message);

        assert_eq!(message.id(), 4_096);
    }

    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];

        assert!(Message::from_bytes(buf).is_err());
    }

    #[test]
    fn nsec_deserialization() {
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];

        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }
}