hickory_proto/op/
message.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Basic protocol message for DNS
9
10use std::{fmt, iter, mem, ops::Deref, sync::Arc};
11
12use tracing::{debug, warn};
13
14use crate::{
15    error::*,
16    op::{Edns, Header, MessageType, OpCode, Query, ResponseCode},
17    rr::{Record, RecordType},
18    serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
19    xfer::DnsResponse,
20};
21
22/// The basic request and response data structure, used for all DNS protocols.
23///
24/// [RFC 1035, DOMAIN NAMES - IMPLEMENTATION AND SPECIFICATION, November 1987](https://tools.ietf.org/html/rfc1035)
25///
26/// ```text
27/// 4.1. Format
28///
29/// All communications inside of the domain protocol are carried in a single
30/// format called a message.  The top level format of message is divided
31/// into 5 sections (some of which are empty in certain cases) shown below:
32///
33///     +--------------------------+
34///     |        Header            |
35///     +--------------------------+
36///     |  Question / Zone         | the question for the name server
37///     +--------------------------+
38///     |   Answer  / Prerequisite | RRs answering the question
39///     +--------------------------+
40///     | Authority / Update       | RRs pointing toward an authority
41///     +--------------------------+
42///     |      Additional          | RRs holding additional information
43///     +--------------------------+
44///
45/// The header section is always present.  The header includes fields that
46/// specify which of the remaining sections are present, and also specify
47/// whether the message is a query or a response, a standard query or some
48/// other opcode, etc.
49///
50/// The names of the sections after the header are derived from their use in
51/// standard queries.  The question section contains fields that describe a
52/// question to a name server.  These fields are a query type (QTYPE), a
53/// query class (QCLASS), and a query domain name (QNAME).  The last three
54/// sections have the same format: a possibly empty list of concatenated
55/// resource records (RRs).  The answer section contains RRs that answer the
56/// question; the authority section contains RRs that point toward an
57/// authoritative name server; the additional records section contains RRs
58/// which relate to the query, but are not strictly answers for the
59/// question.
60/// ```
61///
62/// By default Message is a Query. Use the Message::as_update() to create and update, or
63///  Message::new_update()
64#[derive(Clone, Debug, PartialEq, Eq, Default)]
65pub struct Message {
66    header: Header,
67    queries: Vec<Query>,
68    answers: Vec<Record>,
69    name_servers: Vec<Record>,
70    additionals: Vec<Record>,
71    signature: Vec<Record>,
72    edns: Option<Edns>,
73}
74
75/// Returns a new Header with accurate counts for each Message section
76pub fn update_header_counts(
77    current_header: &Header,
78    is_truncated: bool,
79    counts: HeaderCounts,
80) -> Header {
81    assert!(counts.query_count <= u16::MAX as usize);
82    assert!(counts.answer_count <= u16::MAX as usize);
83    assert!(counts.nameserver_count <= u16::MAX as usize);
84    assert!(counts.additional_count <= u16::MAX as usize);
85
86    // TODO: should the function just take by value?
87    let mut header = *current_header;
88    header
89        .set_query_count(counts.query_count as u16)
90        .set_answer_count(counts.answer_count as u16)
91        .set_name_server_count(counts.nameserver_count as u16)
92        .set_additional_count(counts.additional_count as u16)
93        .set_truncated(is_truncated);
94
95    header
96}
97
/// Tracks the counts of the records in the Message.
///
/// This is only used internally during serialization, where it is passed to
/// `update_header_counts` to fill in the header's section counts.
///
/// `Default` yields all-zero counts; `PartialEq`/`Eq` allow counts to be compared directly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct HeaderCounts {
    /// The number of queries in the Message
    pub query_count: usize,
    /// The number of answers in the Message
    pub answer_count: usize,
    /// The number of nameservers or authorities in the Message
    pub nameserver_count: usize,
    /// The number of additional records in the Message
    pub additional_count: usize,
}
112
113impl Message {
114    /// Returns a new "empty" Message
115    pub fn new() -> Self {
116        Self {
117            header: Header::new(),
118            queries: Vec::new(),
119            answers: Vec::new(),
120            name_servers: Vec::new(),
121            additionals: Vec::new(),
122            signature: Vec::new(),
123            edns: None,
124        }
125    }
126
127    /// Returns a Message constructed with error details to return to a client
128    ///
129    /// # Arguments
130    ///
131    /// * `id` - message id should match the request message id
132    /// * `op_code` - operation of the request
133    /// * `response_code` - the error code for the response
134    pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
135        let mut message = Self::new();
136        message
137            .set_message_type(MessageType::Response)
138            .set_id(id)
139            .set_response_code(response_code)
140            .set_op_code(op_code);
141
142        message
143    }
144
145    /// Truncates a Message, this blindly removes all response fields and sets truncated to `true`
146    pub fn truncate(&self) -> Self {
147        let mut truncated = self.clone();
148        truncated.set_truncated(true);
149        // drops additional/answer/queries so len is 0
150        truncated.take_additionals();
151        truncated.take_answers();
152        truncated.take_queries();
153
154        // TODO, perhaps just quickly add a few response records here? that we know would fit?
155        truncated
156    }
157
158    /// Sets the `Header` with provided
159    pub fn set_header(&mut self, header: Header) -> &mut Self {
160        self.header = header;
161        self
162    }
163
164    /// see `Header::set_id`
165    pub fn set_id(&mut self, id: u16) -> &mut Self {
166        self.header.set_id(id);
167        self
168    }
169
170    /// see `Header::set_message_type`
171    pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
172        self.header.set_message_type(message_type);
173        self
174    }
175
176    /// see `Header::set_op_code`
177    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
178        self.header.set_op_code(op_code);
179        self
180    }
181
182    /// see `Header::set_authoritative`
183    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
184        self.header.set_authoritative(authoritative);
185        self
186    }
187
188    /// see `Header::set_truncated`
189    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
190        self.header.set_truncated(truncated);
191        self
192    }
193
194    /// see `Header::set_recursion_desired`
195    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
196        self.header.set_recursion_desired(recursion_desired);
197        self
198    }
199
200    /// see `Header::set_recursion_available`
201    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
202        self.header.set_recursion_available(recursion_available);
203        self
204    }
205
206    /// see `Header::set_authentic_data`
207    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
208        self.header.set_authentic_data(authentic_data);
209        self
210    }
211
212    /// see `Header::set_checking_disabled`
213    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
214        self.header.set_checking_disabled(checking_disabled);
215        self
216    }
217
218    /// see `Header::set_response_code`
219    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
220        self.header.set_response_code(response_code);
221        self
222    }
223
224    /// Add a query to the Message, either the query response from the server, or the request Query.
225    pub fn add_query(&mut self, query: Query) -> &mut Self {
226        self.queries.push(query);
227        self
228    }
229
230    /// Adds an iterator over a set of Queries to be added to the message
231    pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
232    where
233        Q: IntoIterator<Item = Query, IntoIter = I>,
234        I: Iterator<Item = Query>,
235    {
236        for query in queries {
237            self.add_query(query);
238        }
239
240        self
241    }
242
243    /// Add an answer to the Message
244    pub fn add_answer(&mut self, record: Record) -> &mut Self {
245        self.answers.push(record);
246        self
247    }
248
249    /// Add all the records from the iterator to the answers section of the Message
250    pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
251    where
252        R: IntoIterator<Item = Record, IntoIter = I>,
253        I: Iterator<Item = Record>,
254    {
255        for record in records {
256            self.add_answer(record);
257        }
258
259        self
260    }
261
262    /// Sets the answers to the specified set of Records.
263    ///
264    /// # Panics
265    ///
266    /// Will panic if answer records are already associated to the message.
267    pub fn insert_answers(&mut self, records: Vec<Record>) {
268        assert!(self.answers.is_empty());
269        self.answers = records;
270    }
271
272    /// Add a name server record to the Message
273    pub fn add_name_server(&mut self, record: Record) -> &mut Self {
274        self.name_servers.push(record);
275        self
276    }
277
278    /// Add all the records in the Iterator to the name server section of the message
279    pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
280    where
281        R: IntoIterator<Item = Record, IntoIter = I>,
282        I: Iterator<Item = Record>,
283    {
284        for record in records {
285            self.add_name_server(record);
286        }
287
288        self
289    }
290
291    /// Sets the name_servers to the specified set of Records.
292    ///
293    /// # Panics
294    ///
295    /// Will panic if name_servers records are already associated to the message.
296    pub fn insert_name_servers(&mut self, records: Vec<Record>) {
297        assert!(self.name_servers.is_empty());
298        self.name_servers = records;
299    }
300
301    /// Add an additional Record to the message
302    pub fn add_additional(&mut self, record: Record) -> &mut Self {
303        self.additionals.push(record);
304        self
305    }
306
307    /// Add all the records from the iterator to the additionals section of the Message
308    pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
309    where
310        R: IntoIterator<Item = Record, IntoIter = I>,
311        I: Iterator<Item = Record>,
312    {
313        for record in records {
314            self.add_additional(record);
315        }
316
317        self
318    }
319
320    /// Sets the additional to the specified set of Records.
321    ///
322    /// # Panics
323    ///
324    /// Will panic if additional records are already associated to the message.
325    pub fn insert_additionals(&mut self, records: Vec<Record>) {
326        assert!(self.additionals.is_empty());
327        self.additionals = records;
328    }
329
330    /// Add the EDNS section to the Message
331    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
332        self.edns = Some(edns);
333        self
334    }
335
336    /// Add a SIG0 record, i.e. sign this message
337    ///
338    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
339    #[cfg(feature = "dnssec")]
340    #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
341    pub fn add_sig0(&mut self, record: Record) -> &mut Self {
342        assert_eq!(RecordType::SIG, record.record_type());
343        self.signature.push(record);
344        self
345    }
346
347    /// Add a TSIG record, i.e. authenticate this message
348    ///
349    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
350    #[cfg(feature = "dnssec")]
351    #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
352    pub fn add_tsig(&mut self, record: Record) -> &mut Self {
353        assert_eq!(RecordType::TSIG, record.record_type());
354        self.signature.push(record);
355        self
356    }
357
358    /// Gets the header of the Message
359    pub fn header(&self) -> &Header {
360        &self.header
361    }
362
363    /// see `Header::id()`
364    pub fn id(&self) -> u16 {
365        self.header.id()
366    }
367
368    /// see `Header::message_type()`
369    pub fn message_type(&self) -> MessageType {
370        self.header.message_type()
371    }
372
373    /// see `Header::op_code()`
374    pub fn op_code(&self) -> OpCode {
375        self.header.op_code()
376    }
377
378    /// see `Header::authoritative()`
379    pub fn authoritative(&self) -> bool {
380        self.header.authoritative()
381    }
382
383    /// see `Header::truncated()`
384    pub fn truncated(&self) -> bool {
385        self.header.truncated()
386    }
387
388    /// see `Header::recursion_desired()`
389    pub fn recursion_desired(&self) -> bool {
390        self.header.recursion_desired()
391    }
392
393    /// see `Header::recursion_available()`
394    pub fn recursion_available(&self) -> bool {
395        self.header.recursion_available()
396    }
397
398    /// see `Header::authentic_data()`
399    pub fn authentic_data(&self) -> bool {
400        self.header.authentic_data()
401    }
402
403    /// see `Header::checking_disabled()`
404    pub fn checking_disabled(&self) -> bool {
405        self.header.checking_disabled()
406    }
407
408    /// # Return value
409    ///
410    /// The `ResponseCode`, if this is an EDNS message then this will join the section from the OPT
411    ///  record to create the EDNS `ResponseCode`
412    pub fn response_code(&self) -> ResponseCode {
413        self.header.response_code()
414    }
415
416    /// Returns the query from this Message.
417    ///
418    /// In almost all cases, a Message will only contain one query. This is a convenience function to get the single query.
419    /// See the alternative `queries*` methods for the raw set of queries in the Message
420    pub fn query(&self) -> Option<&Query> {
421        self.queries.first()
422    }
423
424    /// ```text
425    /// Question        Carries the query name and other query parameters.
426    /// ```
427    pub fn queries(&self) -> &[Query] {
428        &self.queries
429    }
430
431    /// Provides mutable access to `queries`
432    pub fn queries_mut(&mut self) -> &mut Vec<Query> {
433        &mut self.queries
434    }
435
436    /// Removes all the answers from the Message
437    pub fn take_queries(&mut self) -> Vec<Query> {
438        mem::take(&mut self.queries)
439    }
440
441    /// ```text
442    /// Answer          Carries RRs which directly answer the query.
443    /// ```
444    pub fn answers(&self) -> &[Record] {
445        &self.answers
446    }
447
448    /// Provides mutable access to `answers`
449    pub fn answers_mut(&mut self) -> &mut Vec<Record> {
450        &mut self.answers
451    }
452
453    /// Removes all the answers from the Message
454    pub fn take_answers(&mut self) -> Vec<Record> {
455        mem::take(&mut self.answers)
456    }
457
458    /// ```text
459    /// Authority       Carries RRs which describe other authoritative servers.
460    ///                 May optionally carry the SOA RR for the authoritative
461    ///                 data in the answer section.
462    /// ```
463    pub fn name_servers(&self) -> &[Record] {
464        &self.name_servers
465    }
466
467    /// Provides mutable access to `name_servers`
468    pub fn name_servers_mut(&mut self) -> &mut Vec<Record> {
469        &mut self.name_servers
470    }
471
472    /// Remove the name servers from the Message
473    pub fn take_name_servers(&mut self) -> Vec<Record> {
474        mem::take(&mut self.name_servers)
475    }
476
477    /// ```text
478    /// Additional      Carries RRs which may be helpful in using the RRs in the
479    ///                 other sections.
480    /// ```
481    pub fn additionals(&self) -> &[Record] {
482        &self.additionals
483    }
484
485    /// Provides mutable access to `additionals`
486    pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
487        &mut self.additionals
488    }
489
490    /// Remove the additional Records from the Message
491    pub fn take_additionals(&mut self) -> Vec<Record> {
492        mem::take(&mut self.additionals)
493    }
494
495    /// All sections chained
496    pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
497        self.answers
498            .iter()
499            .chain(self.name_servers().iter())
500            .chain(self.additionals.iter())
501    }
502
503    /// [RFC 6891, EDNS(0) Extensions, April 2013](https://tools.ietf.org/html/rfc6891#section-6.1.1)
504    ///
505    /// ```text
506    /// 6.1.1.  Basic Elements
507    ///
508    ///  An OPT pseudo-RR (sometimes called a meta-RR) MAY be added to the
509    ///  additional data section of a request.
510    ///
511    ///  The OPT RR has RR type 41.
512    ///
513    ///  If an OPT record is present in a received request, compliant
514    ///  responders MUST include an OPT record in their respective responses.
515    ///
516    ///  An OPT record does not carry any DNS data.  It is used only to
517    ///  contain control information pertaining to the question-and-answer
518    ///  sequence of a specific transaction.  OPT RRs MUST NOT be cached,
519    ///  forwarded, or stored in or loaded from Zone Files.
520    ///
521    ///  The OPT RR MAY be placed anywhere within the additional data section.
522    ///  When an OPT RR is included within any DNS message, it MUST be the
523    ///  only OPT RR in that message.  If a query message with more than one
524    ///  OPT RR is received, a FORMERR (RCODE=1) MUST be returned.  The
525    ///  placement flexibility for the OPT RR does not override the need for
526    ///  the TSIG or SIG(0) RRs to be the last in the additional section
527    ///  whenever they are present.
528    /// ```
529    /// # Return value
530    ///
531    /// Optionally returns a reference to EDNS section
532    #[deprecated(note = "Please use `extensions()`")]
533    pub fn edns(&self) -> Option<&Edns> {
534        self.edns.as_ref()
535    }
536
537    /// Optionally returns mutable reference to EDNS section
538    #[deprecated(
539        note = "Please use `extensions_mut()`. You can chain `.get_or_insert_with(Edns::new)` to recover original behavior of adding Edns if not present"
540    )]
541    pub fn edns_mut(&mut self) -> &mut Edns {
542        if self.edns.is_none() {
543            self.set_edns(Edns::new());
544        }
545        self.edns.as_mut().unwrap()
546    }
547
548    /// Returns reference of Edns section
549    pub fn extensions(&self) -> &Option<Edns> {
550        &self.edns
551    }
552
553    /// Returns mutable reference of Edns section
554    pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
555        &mut self.edns
556    }
557
558    /// # Return value
559    ///
560    /// the max payload value as it's defined in the EDNS section.
561    pub fn max_payload(&self) -> u16 {
562        let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
563        if max_size < 512 {
564            512
565        } else {
566            max_size
567        }
568    }
569
570    /// # Return value
571    ///
572    /// the version as defined in the EDNS record
573    pub fn version(&self) -> u8 {
574        self.edns.as_ref().map_or(0, Edns::version)
575    }
576
577    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
578    ///
579    /// ```text
580    /// A DNS request may be optionally signed by including one or more SIGs
581    ///  at the end of the query. Such SIGs are identified by having a "type
582    ///  covered" field of zero. They sign the preceding DNS request message
583    ///  including DNS header but not including the IP header or any request
584    ///  SIGs at the end and before the request RR counts have been adjusted
585    ///  for the inclusions of any request SIG(s).
586    /// ```
587    ///
588    /// # Return value
589    ///
590    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
591    // comportment change: can now return TSIG instead of SIG0. Maybe should get deprecated in
592    // favor of signature() which have more correct naming ?
593    pub fn sig0(&self) -> &[Record] {
594        &self.signature
595    }
596
597    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
598    ///
599    /// ```text
600    /// A DNS request may be optionally signed by including one or more SIGs
601    ///  at the end of the query. Such SIGs are identified by having a "type
602    ///  covered" field of zero. They sign the preceding DNS request message
603    ///  including DNS header but not including the IP header or any request
604    ///  SIGs at the end and before the request RR counts have been adjusted
605    ///  for the inclusions of any request SIG(s).
606    /// ```
607    ///
608    /// # Return value
609    ///
610    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
611    pub fn signature(&self) -> &[Record] {
612        &self.signature
613    }
614
615    /// Remove signatures from the Message
616    pub fn take_signature(&mut self) -> Vec<Record> {
617        mem::take(&mut self.signature)
618    }
619
620    // TODO: only necessary in tests, should it be removed?
621    /// this is necessary to match the counts in the header from the record sections
622    ///  this happens implicitly on write_to, so no need to call before write_to
623    #[cfg(test)]
624    pub fn update_counts(&mut self) -> &mut Self {
625        self.header = update_header_counts(
626            &self.header,
627            self.truncated(),
628            HeaderCounts {
629                query_count: self.queries.len(),
630                answer_count: self.answers.len(),
631                nameserver_count: self.name_servers.len(),
632                additional_count: self.additionals.len(),
633            },
634        );
635        self
636    }
637
638    /// Attempts to read the specified number of `Query`s
639    pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
640        let mut queries = Vec::with_capacity(count);
641        for _ in 0..count {
642            queries.push(Query::read(decoder)?);
643        }
644        Ok(queries)
645    }
646
647    /// Attempts to read the specified number of records
648    ///
649    /// # Returns
650    ///
651    /// This returns a tuple of first standard Records, then a possibly associated Edns, and then finally any optionally associated SIG0 and TSIG records.
652    #[cfg_attr(not(feature = "dnssec"), allow(unused_mut))]
653    pub fn read_records(
654        decoder: &mut BinDecoder<'_>,
655        count: usize,
656        is_additional: bool,
657    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
658        let mut records: Vec<Record> = Vec::with_capacity(count);
659        let mut edns: Option<Edns> = None;
660        let mut sigs: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });
661
662        // sig0 must be last, once this is set, disable.
663        let mut saw_sig0 = false;
664        // tsig must be last, once this is set, disable.
665        let mut saw_tsig = false;
666        for _ in 0..count {
667            let record = Record::read(decoder)?;
668            if saw_tsig {
669                return Err("tsig must be final resource record".into());
670            } // TSIG must be last and multiple TSIG records are not allowed
671            if !is_additional {
672                if saw_sig0 {
673                    return Err("sig0 must be final resource record".into());
674                } // SIG0 must be last
675                records.push(record)
676            } else {
677                match record.record_type() {
678                    #[cfg(feature = "dnssec")]
679                    RecordType::SIG => {
680                        saw_sig0 = true;
681                        sigs.push(record);
682                    }
683                    #[cfg(feature = "dnssec")]
684                    RecordType::TSIG => {
685                        if saw_sig0 {
686                            return Err("sig0 must be final resource record".into());
687                        } // SIG0 must be last
688                        saw_tsig = true;
689                        sigs.push(record);
690                    }
691                    RecordType::OPT => {
692                        if saw_sig0 {
693                            return Err("sig0 must be final resource record".into());
694                        } // SIG0 must be last
695                        if edns.is_some() {
696                            return Err("more than one edns record present".into());
697                        }
698                        edns = Some((&record).into());
699                    }
700                    _ => {
701                        if saw_sig0 {
702                            return Err("sig0 must be final resource record".into());
703                        } // SIG0 must be last
704                        records.push(record);
705                    }
706                }
707            }
708        }
709
710        Ok((records, edns, sigs))
711    }
712
713    /// Decodes a message from the buffer.
714    pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
715        let mut decoder = BinDecoder::new(buffer);
716        Self::read(&mut decoder)
717    }
718
719    /// Encodes the Message into a buffer
720    pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
721        // TODO: this feels like the right place to verify the max packet size of the message,
722        //  will need to update the header for truncation and the lengths if we send less than the
723        //  full response. This needs to conform with the EDNS settings of the server...
724        let mut buffer = Vec::with_capacity(512);
725        {
726            let mut encoder = BinEncoder::new(&mut buffer);
727            self.emit(&mut encoder)?;
728        }
729
730        Ok(buffer)
731    }
732
733    /// Finalize the message prior to sending.
734    ///
735    /// Subsequent to calling this, the Message should not change.
736    #[allow(clippy::match_single_binding)]
737    pub fn finalize<MF: MessageFinalizer>(
738        &mut self,
739        finalizer: &MF,
740        inception_time: u32,
741    ) -> ProtoResult<Option<MessageVerifier>> {
742        debug!("finalizing message: {:?}", self);
743        let (finals, verifier): (Vec<Record>, Option<MessageVerifier>) =
744            finalizer.finalize_message(self, inception_time)?;
745
746        // append all records to message
747        for fin in finals {
748            match fin.record_type() {
749                // SIG0's are special, and come at the very end of the message
750                #[cfg(feature = "dnssec")]
751                RecordType::SIG => self.add_sig0(fin),
752                #[cfg(feature = "dnssec")]
753                RecordType::TSIG => self.add_tsig(fin),
754                _ => self.add_additional(fin),
755            };
756        }
757
758        Ok(verifier)
759    }
760
761    /// Consumes `Message` and returns into components
762    pub fn into_parts(self) -> MessageParts {
763        self.into()
764    }
765}
766
767/// Consumes `Message` giving public access to fields in `Message` so they can be
768/// destructured and taken by value
769/// ```rust
770/// use hickory_proto::op::{Message, MessageParts};
771///
772///  let msg = Message::new();
773///  let MessageParts { queries, .. } = msg.into_parts();
774/// ```
775#[derive(Clone, Debug, PartialEq, Eq, Default)]
776pub struct MessageParts {
777    /// message header
778    pub header: Header,
779    /// message queries
780    pub queries: Vec<Query>,
781    /// message answers
782    pub answers: Vec<Record>,
783    /// message name_servers
784    pub name_servers: Vec<Record>,
785    /// message additional records
786    pub additionals: Vec<Record>,
787    /// sig0 or tsig
788    // this can now contains TSIG too. It should probably be renamed to reflect that, but it's a
789    // breaking change
790    pub sig0: Vec<Record>,
791    /// optional edns records
792    pub edns: Option<Edns>,
793}
794
795impl From<Message> for MessageParts {
796    fn from(msg: Message) -> Self {
797        let Message {
798            header,
799            queries,
800            answers,
801            name_servers,
802            additionals,
803            signature,
804            edns,
805        } = msg;
806        Self {
807            header,
808            queries,
809            answers,
810            name_servers,
811            additionals,
812            sig0: signature,
813            edns,
814        }
815    }
816}
817
818impl From<MessageParts> for Message {
819    fn from(msg: MessageParts) -> Self {
820        let MessageParts {
821            header,
822            queries,
823            answers,
824            name_servers,
825            additionals,
826            sig0,
827            edns,
828        } = msg;
829        Self {
830            header,
831            queries,
832            answers,
833            name_servers,
834            additionals,
835            signature: sig0,
836            edns,
837        }
838    }
839}
840
841impl Deref for Message {
842    type Target = Header;
843
844    fn deref(&self) -> &Self::Target {
845        &self.header
846    }
847}
848
849/// Alias for a function verifying if a message is properly signed
850pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
851
852/// A trait for performing final amendments to a Message before it is sent.
853///
854/// An example of this is a SIG0 signer, which needs the final form of the message,
855///  but then needs to attach additional data to the body of the message.
856pub trait MessageFinalizer: Send + Sync + 'static {
857    /// The message taken in should be processed and then return [`Record`]s which should be
858    ///  appended to the additional section of the message.
859    ///
860    /// # Arguments
861    ///
862    /// * `message` - message to process
863    /// * `current_time` - the current time as specified by the system, it's not recommended to read the current time as that makes testing complicated.
864    ///
865    /// # Return
866    ///
867    /// A vector to append to the additionals section of the message, sorted in the order as they should appear in the message.
868    fn finalize_message(
869        &self,
870        message: &Message,
871        current_time: u32,
872    ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)>;
873
874    /// Return whether the message requires further processing before being sent
875    /// By default, returns true for AXFR and IXFR queries, and Update and Notify messages
876    fn should_finalize_message(&self, message: &Message) -> bool {
877        [OpCode::Update, OpCode::Notify].contains(&message.op_code())
878            || message
879                .queries()
880                .iter()
881                .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
882    }
883}
884
/// A [`MessageFinalizer`] that does nothing.
///
/// *WARNING* This type is only meant to fill the finalizer slot when no finalizer is wanted:
/// its constructor only ever produces `None`. Calling
/// [`MessageFinalizer::finalize_message`] on it panics.
#[derive(Clone, Copy, Debug)]
pub struct NoopMessageFinalizer;

impl NoopMessageFinalizer {
    /// Returns `None`; there is never a real instance to hand out.
    pub fn new() -> Option<Arc<Self>> {
        Option::None
    }
}
896}
897
898impl MessageFinalizer for NoopMessageFinalizer {
899    fn finalize_message(
900        &self,
901        _: &Message,
902        _: u32,
903    ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)> {
904        panic!("Misused NoopMessageFinalizer, None should be used instead")
905    }
906
907    fn should_finalize_message(&self, _: &Message) -> bool {
908        true
909    }
910}
911
912/// Returns the count written and a boolean if it was truncated
913pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
914    result.map(|count| (count, false)).or_else(|e| {
915        if let ProtoErrorKind::NotAllRecordsWritten { count } = e.kind() {
916            return Ok((*count, true));
917        }
918
919        Err(e)
920    })
921}
922
923/// A trait that defines types which can be emitted as a set, with the associated count returned.
924pub trait EmitAndCount {
925    /// Emit self to the encoder and return the count of items
926    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
927}
928
929impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
930    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
931        encoder.emit_all(self)
932    }
933}
934
935/// Emits the different sections of a message properly
936///
937/// # Return
938///
939/// In the case of a successful emit, the final header (updated counts, etc) is returned for help with logging, etc.
940#[allow(clippy::too_many_arguments)]
941pub fn emit_message_parts<Q, A, N, D>(
942    header: &Header,
943    queries: &mut Q,
944    answers: &mut A,
945    name_servers: &mut N,
946    additionals: &mut D,
947    edns: Option<&Edns>,
948    signature: &[Record],
949    encoder: &mut BinEncoder<'_>,
950) -> ProtoResult<Header>
951where
952    Q: EmitAndCount,
953    A: EmitAndCount,
954    N: EmitAndCount,
955    D: EmitAndCount,
956{
957    let include_signature: bool = encoder.mode() != EncodeMode::Signing;
958    let place = encoder.place::<Header>()?;
959
960    let query_count = queries.emit(encoder)?;
961    // TODO: need to do something on max records
962    //  return offset of last emitted record.
963    let answer_count = count_was_truncated(answers.emit(encoder))?;
964    let nameserver_count = count_was_truncated(name_servers.emit(encoder))?;
965    let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
966
967    if let Some(mut edns) = edns.cloned() {
968        // need to commit the error code
969        edns.set_rcode_high(header.response_code().high());
970
971        let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
972        additional_count.0 += count.0;
973        additional_count.1 |= count.1;
974    } else if header.response_code().high() > 0 {
975        warn!(
976            "response code: {} for request: {} requires EDNS but none available",
977            header.response_code(),
978            header.id()
979        );
980    }
981
982    // this is a little hacky, but if we are Verifying a signature, i.e. the original Message
983    //  then the SIG0 records should not be encoded and the edns record (if it exists) is already
984    //  part of the additionals section.
985    if include_signature {
986        let count = count_was_truncated(encoder.emit_all(signature.iter()))?;
987        additional_count.0 += count.0;
988        additional_count.1 |= count.1;
989    }
990
991    let counts = HeaderCounts {
992        query_count,
993        answer_count: answer_count.0,
994        nameserver_count: nameserver_count.0,
995        additional_count: additional_count.0,
996    };
997    let was_truncated =
998        header.truncated() || answer_count.1 || nameserver_count.1 || additional_count.1;
999
1000    let final_header = update_header_counts(header, was_truncated, counts);
1001    place.replace(encoder, final_header)?;
1002    Ok(final_header)
1003}
1004
1005impl BinEncodable for Message {
1006    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1007        emit_message_parts(
1008            &self.header,
1009            &mut self.queries.iter(),
1010            &mut self.answers.iter(),
1011            &mut self.name_servers.iter(),
1012            &mut self.additionals.iter(),
1013            self.edns.as_ref(),
1014            &self.signature,
1015            encoder,
1016        )?;
1017
1018        Ok(())
1019    }
1020}
1021
1022impl<'r> BinDecodable<'r> for Message {
1023    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1024        let mut header = Header::read(decoder)?;
1025
1026        // TODO: return just header, and in the case of the rest of message getting an error.
1027        //  this could improve error detection while decoding.
1028
1029        // get the questions
1030        let count = header.query_count() as usize;
1031        let mut queries = Vec::with_capacity(count);
1032        for _ in 0..count {
1033            queries.push(Query::read(decoder)?);
1034        }
1035
1036        // get all counts before header moves
1037        let answer_count = header.answer_count() as usize;
1038        let name_server_count = header.name_server_count() as usize;
1039        let additional_count = header.additional_count() as usize;
1040
1041        let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1042        let (name_servers, _, _) = Self::read_records(decoder, name_server_count, false)?;
1043        let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1044
1045        // need to grab error code from EDNS (which might have a higher value)
1046        if let Some(edns) = &edns {
1047            let high_response_code = edns.rcode_high();
1048            header.merge_response_code(high_response_code);
1049        }
1050
1051        Ok(Self {
1052            header,
1053            queries,
1054            answers,
1055            name_servers,
1056            additionals,
1057            signature,
1058            edns,
1059        })
1060    }
1061}
1062
1063impl fmt::Display for Message {
1064    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1065        let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1066            for d in slice {
1067                writeln!(f, ";; {d}")?;
1068            }
1069
1070            Ok(())
1071        };
1072
1073        let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1074            for d in slice {
1075                writeln!(f, "{d}")?;
1076            }
1077
1078            Ok(())
1079        };
1080
1081        writeln!(f, "; header {header}", header = self.header())?;
1082
1083        if let Some(edns) = self.extensions() {
1084            writeln!(f, "; edns {edns}")?;
1085        }
1086
1087        writeln!(f, "; query")?;
1088        write_query(self.queries(), f)?;
1089
1090        if self.header().message_type() == MessageType::Response
1091            || self.header().op_code() == OpCode::Update
1092        {
1093            writeln!(f, "; answers {}", self.answer_count())?;
1094            write_slice(self.answers(), f)?;
1095            writeln!(f, "; nameservers {}", self.name_server_count())?;
1096            write_slice(self.name_servers(), f)?;
1097            writeln!(f, "; additionals {}", self.additional_count())?;
1098            write_slice(self.additionals(), f)?;
1099        }
1100
1101        Ok(())
1102    }
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(false)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail);

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail)
            .add_query(Query::new())
            .update_counts(); // we're not testing the query parsing, just message

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::new());
        message.add_name_server(Record::new());
        message.add_additional(Record::new());
        message.update_counts(); // needed for the comparison...

        test_emit_and_read(message);
    }

    /// Encodes `message`, decodes the bytes again and asserts the round trip is lossless.
    // The enclosing module is already `#[cfg(test)]`, so no extra attribute is needed here.
    fn test_emit_and_read(message: Message) {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&byte_vec);
        let got = Message::read(&mut decoder).unwrap();

        assert_eq!(got, message);
    }

    #[test]
    fn test_legit_message() {
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81, 0x80, // id = 4096, response, op=query, recursion_desired, recursion_available, no_error
            0x00, 0x01, 0x00, 0x01, // 1 query, 1 answer,
            0x00, 0x00, 0x00, 0x00, // 0 nameservers, 0 additional record
            0x03, b'w', b'w', b'w', // query --- www.example.com
            0x07, b'e', b'x', b'a', //
            b'm', b'p', b'l', b'e', //
            0x03, b'c', b'o', b'm', //
            0x00,                   // 0 = endname
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
            0xC0, 0x0C,             // name pointer to www.example.com
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
            0x00, 0x00, 0x00, 0x02, // TTL = 2 seconds
            0x00, 0x04,             // record length = 4 (ipv4 address)
            0x5D, 0xB8, 0xD7, 0x0E, // address = 93.184.215.14
        ];

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4096);

        let mut buf: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4096);
    }

    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];

        assert!(Message::from_bytes(buf).is_err());
    }

    #[test]
    fn nsec_deserialization() {
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];

        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }
}
1238        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
1239    }
1240}