// simple_dns/dns/packet.rs

1use std::{
2    collections::HashMap,
3    io::{Cursor, Seek, Write},
4};
5
6use super::{Header, PacketFlag, Question, ResourceRecord, WireFormat, OPCODE};
7use crate::{bytes_buffer::BytesBuffer, rdata::OPT, RCODE};
8
9/// Represents a DNS message packet
10///
11/// When working with EDNS packets, use [Packet::opt] and [Packet::opt_mut] to add or access [OPT] packet information
12#[derive(Debug, Clone)]
13pub struct Packet<'a> {
14    /// Packet header
15    header: Header<'a>,
16    /// Questions section
17    pub questions: Vec<Question<'a>>,
18    /// Answers section
19    pub answers: Vec<ResourceRecord<'a>>,
20    /// Name servers section
21    pub name_servers: Vec<ResourceRecord<'a>>,
22    /// Aditional records section.  
23    /// DO NOT use this field to add OPT record, use [`Packet::opt_mut`] instead
24    pub additional_records: Vec<ResourceRecord<'a>>,
25}
26
27impl<'a> Packet<'a> {
28    /// Creates a new empty packet with a query header
29    pub fn new_query(id: u16) -> Self {
30        Self {
31            header: Header::new_query(id),
32            questions: Vec::new(),
33            answers: Vec::new(),
34            name_servers: Vec::new(),
35            additional_records: Vec::new(),
36        }
37    }
38
39    /// Creates a new empty packet with a reply header
40    pub fn new_reply(id: u16) -> Self {
41        Self {
42            header: Header::new_reply(id, OPCODE::StandardQuery),
43            questions: Vec::new(),
44            answers: Vec::new(),
45            name_servers: Vec::new(),
46            additional_records: Vec::new(),
47        }
48    }
49
50    /// Get packet id
51    pub fn id(&self) -> u16 {
52        self.header.id
53    }
54
55    /// Set packet id
56    pub fn set_id(&mut self, id: u16) {
57        self.header.id = id;
58    }
59
60    /// Set flags in the packet
61    pub fn set_flags(&mut self, flags: PacketFlag) {
62        self.header.set_flags(flags);
63    }
64
65    /// Remove flags present in the packet
66    pub fn remove_flags(&mut self, flags: PacketFlag) {
67        self.header.remove_flags(flags)
68    }
69
70    /// Check if the packet has flags set
71    pub fn has_flags(&self, flags: PacketFlag) -> bool {
72        self.header.has_flags(flags)
73    }
74
75    /// Get this packet [RCODE] information
76    pub fn rcode(&self) -> RCODE {
77        self.header.response_code
78    }
79
80    /// Get a mutable reference for  this packet [RCODE] information
81    /// Warning, if the [RCODE] value is greater than 15 (4 bits), you MUST provide an [OPT]
82    /// resource record through the [Packet::opt_mut] function
83    pub fn rcode_mut(&mut self) -> &mut RCODE {
84        &mut self.header.response_code
85    }
86
87    /// Get this packet [OPCODE] information
88    pub fn opcode(&self) -> OPCODE {
89        self.header.opcode
90    }
91
92    /// Get a mutable reference for this packet [OPCODE] information
93    pub fn opcode_mut(&mut self) -> &mut OPCODE {
94        &mut self.header.opcode
95    }
96
97    /// Get the [OPT] resource record for this packet, if present
98    pub fn opt(&self) -> Option<&OPT<'a>> {
99        self.header.opt.as_ref()
100    }
101
102    /// Get a mutable reference for this packet [OPT] resource record.  
103    pub fn opt_mut(&mut self) -> &mut Option<OPT<'a>> {
104        &mut self.header.opt
105    }
106
107    /// Changes this packet into a reply packet by replacing its header
108    pub fn into_reply(mut self) -> Self {
109        self.header = Header::new_reply(self.header.id, self.header.opcode);
110        self
111    }
112
113    /// Parses a packet from a slice of bytes
114    pub fn parse(data: &'a [u8]) -> crate::Result<Self> {
115        let mut data = BytesBuffer::new(data);
116        let mut header = Header::parse(&mut data)?;
117
118        let questions = Self::parse_section(&mut data, header.questions)?;
119        let answers = Self::parse_section(&mut data, header.answers)?;
120        let name_servers = Self::parse_section(&mut data, header.name_servers)?;
121        let mut additional_records: Vec<ResourceRecord> =
122            Self::parse_section(&mut data, header.additional_records)?;
123
124        header.extract_info_from_opt_rr(
125            additional_records
126                .iter()
127                .position(|rr| rr.rdata.type_code() == crate::TYPE::OPT)
128                .map(|i| additional_records.remove(i)),
129        );
130
131        Ok(Self {
132            header,
133            questions,
134            answers,
135            name_servers,
136            additional_records,
137        })
138    }
139
140    fn parse_section<T: WireFormat<'a>>(
141        data: &mut BytesBuffer<'a>,
142        items_count: u16,
143    ) -> crate::Result<Vec<T>> {
144        let mut section_items = Vec::with_capacity(items_count as usize);
145
146        for _ in 0..items_count {
147            section_items.push(T::parse(data)?);
148        }
149
150        Ok(section_items)
151    }
152
153    /// Creates a new [Vec`<u8>`](`Vec<T>`) and write the contents of this package in wire format
154    ///
155    /// This call will allocate a `Vec<u8>` of 900 bytes, which is enough for a jumbo UDP packet
156    pub fn build_bytes_vec(&self) -> crate::Result<Vec<u8>> {
157        let mut out = Cursor::new(Vec::with_capacity(900));
158
159        self.write_to(&mut out)?;
160
161        Ok(out.into_inner())
162    }
163
164    /// Creates a new [Vec`<u8>`](`Vec<T>`) and write the contents of this package in wire format
165    /// with compression enabled
166    ///
167    /// This call will allocate a `Vec<u8>` of 900 bytes, which is enough for a jumbo UDP packet
168    pub fn build_bytes_vec_compressed(&self) -> crate::Result<Vec<u8>> {
169        let mut out = Cursor::new(Vec::with_capacity(900));
170        self.write_compressed_to(&mut out)?;
171
172        Ok(out.into_inner())
173    }
174
175    /// Write the contents of this package in wire format into the provided writer
176    pub fn write_to<T: Write>(&self, out: &mut T) -> crate::Result<()> {
177        self.write_header(out)?;
178
179        for e in &self.questions {
180            e.write_to(out)?;
181        }
182        for e in &self.answers {
183            e.write_to(out)?;
184        }
185        for e in &self.name_servers {
186            e.write_to(out)?;
187        }
188
189        if let Some(rr) = self.header.opt_rr() {
190            rr.write_to(out)?;
191        }
192
193        for e in &self.additional_records {
194            e.write_to(out)?;
195        }
196
197        out.flush()?;
198        Ok(())
199    }
200
201    /// Write the contents of this package in wire format with enabled compression into the provided writer
202    pub fn write_compressed_to<T: Write + Seek>(&self, out: &mut T) -> crate::Result<()> {
203        self.write_header(out)?;
204
205        let mut name_refs = HashMap::new();
206        for e in &self.questions {
207            e.write_compressed_to(out, &mut name_refs)?;
208        }
209        for e in &self.answers {
210            e.write_compressed_to(out, &mut name_refs)?;
211        }
212        for e in &self.name_servers {
213            e.write_compressed_to(out, &mut name_refs)?;
214        }
215
216        if let Some(rr) = self.header.opt_rr() {
217            rr.write_to(out)?;
218        }
219
220        for e in &self.additional_records {
221            e.write_compressed_to(out, &mut name_refs)?;
222        }
223        out.flush()?;
224
225        Ok(())
226    }
227
228    fn write_header<T: Write>(&self, out: &mut T) -> crate::Result<()> {
229        self.header.write_to(
230            out,
231            self.questions.len() as u16,
232            self.answers.len() as u16,
233            self.name_servers.len() as u16,
234            self.additional_records.len() as u16 + u16::from(self.header.opt.is_some()),
235        )
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dns::CLASS, dns::TYPE, SimpleDnsError};
    use std::convert::TryInto;

    /// An empty buffer cannot hold a header, so parsing must fail with an error instead of panicking.
    #[test]
    fn parse_without_data_should_not_panic() {
        let outcome = Packet::parse(&[]);
        assert!(matches!(outcome, Err(SimpleDnsError::InsufficientData)));
    }

    /// A two-question query survives a build/parse round trip with its names and order intact.
    #[test]
    fn build_query_correct() {
        let names = ["_srv._udp.local", "_srv2._udp.local"];

        let mut query = Packet::new_query(1);
        for &name in names.iter() {
            query.questions.push(Question::new(
                name.try_into().unwrap(),
                TYPE::TXT.into(),
                CLASS::IN.into(),
                false,
            ));
        }

        let wire = query.build_bytes_vec().unwrap();
        let parsed = Packet::parse(&wire).expect("a freshly built query should parse back");

        assert_eq!(names.len(), parsed.questions.len());
        for (expected, question) in names.iter().zip(parsed.questions.iter()) {
            assert_eq!(*expected, question.qname.to_string());
        }
    }
}