1use std::{
2 collections::HashMap,
3 io::{Cursor, Seek, Write},
4};
5
6use super::{Header, PacketFlag, Question, ResourceRecord, WireFormat, OPCODE};
7use crate::{bytes_buffer::BytesBuffer, rdata::OPT, RCODE};
8
9#[derive(Debug, Clone)]
13pub struct Packet<'a> {
14 header: Header<'a>,
16 pub questions: Vec<Question<'a>>,
18 pub answers: Vec<ResourceRecord<'a>>,
20 pub name_servers: Vec<ResourceRecord<'a>>,
22 pub additional_records: Vec<ResourceRecord<'a>>,
25}
26
27impl<'a> Packet<'a> {
28 pub fn new_query(id: u16) -> Self {
30 Self {
31 header: Header::new_query(id),
32 questions: Vec::new(),
33 answers: Vec::new(),
34 name_servers: Vec::new(),
35 additional_records: Vec::new(),
36 }
37 }
38
39 pub fn new_reply(id: u16) -> Self {
41 Self {
42 header: Header::new_reply(id, OPCODE::StandardQuery),
43 questions: Vec::new(),
44 answers: Vec::new(),
45 name_servers: Vec::new(),
46 additional_records: Vec::new(),
47 }
48 }
49
50 pub fn id(&self) -> u16 {
52 self.header.id
53 }
54
55 pub fn set_id(&mut self, id: u16) {
57 self.header.id = id;
58 }
59
60 pub fn set_flags(&mut self, flags: PacketFlag) {
62 self.header.set_flags(flags);
63 }
64
65 pub fn remove_flags(&mut self, flags: PacketFlag) {
67 self.header.remove_flags(flags)
68 }
69
70 pub fn has_flags(&self, flags: PacketFlag) -> bool {
72 self.header.has_flags(flags)
73 }
74
75 pub fn rcode(&self) -> RCODE {
77 self.header.response_code
78 }
79
80 pub fn rcode_mut(&mut self) -> &mut RCODE {
84 &mut self.header.response_code
85 }
86
87 pub fn opcode(&self) -> OPCODE {
89 self.header.opcode
90 }
91
92 pub fn opcode_mut(&mut self) -> &mut OPCODE {
94 &mut self.header.opcode
95 }
96
97 pub fn opt(&self) -> Option<&OPT<'a>> {
99 self.header.opt.as_ref()
100 }
101
102 pub fn opt_mut(&mut self) -> &mut Option<OPT<'a>> {
104 &mut self.header.opt
105 }
106
107 pub fn into_reply(mut self) -> Self {
109 self.header = Header::new_reply(self.header.id, self.header.opcode);
110 self
111 }
112
113 pub fn parse(data: &'a [u8]) -> crate::Result<Self> {
115 let mut data = BytesBuffer::new(data);
116 let mut header = Header::parse(&mut data)?;
117
118 let questions = Self::parse_section(&mut data, header.questions)?;
119 let answers = Self::parse_section(&mut data, header.answers)?;
120 let name_servers = Self::parse_section(&mut data, header.name_servers)?;
121 let mut additional_records: Vec<ResourceRecord> =
122 Self::parse_section(&mut data, header.additional_records)?;
123
124 header.extract_info_from_opt_rr(
125 additional_records
126 .iter()
127 .position(|rr| rr.rdata.type_code() == crate::TYPE::OPT)
128 .map(|i| additional_records.remove(i)),
129 );
130
131 Ok(Self {
132 header,
133 questions,
134 answers,
135 name_servers,
136 additional_records,
137 })
138 }
139
140 fn parse_section<T: WireFormat<'a>>(
141 data: &mut BytesBuffer<'a>,
142 items_count: u16,
143 ) -> crate::Result<Vec<T>> {
144 let mut section_items = Vec::with_capacity(items_count as usize);
145
146 for _ in 0..items_count {
147 section_items.push(T::parse(data)?);
148 }
149
150 Ok(section_items)
151 }
152
153 pub fn build_bytes_vec(&self) -> crate::Result<Vec<u8>> {
157 let mut out = Cursor::new(Vec::with_capacity(900));
158
159 self.write_to(&mut out)?;
160
161 Ok(out.into_inner())
162 }
163
164 pub fn build_bytes_vec_compressed(&self) -> crate::Result<Vec<u8>> {
169 let mut out = Cursor::new(Vec::with_capacity(900));
170 self.write_compressed_to(&mut out)?;
171
172 Ok(out.into_inner())
173 }
174
175 pub fn write_to<T: Write>(&self, out: &mut T) -> crate::Result<()> {
177 self.write_header(out)?;
178
179 for e in &self.questions {
180 e.write_to(out)?;
181 }
182 for e in &self.answers {
183 e.write_to(out)?;
184 }
185 for e in &self.name_servers {
186 e.write_to(out)?;
187 }
188
189 if let Some(rr) = self.header.opt_rr() {
190 rr.write_to(out)?;
191 }
192
193 for e in &self.additional_records {
194 e.write_to(out)?;
195 }
196
197 out.flush()?;
198 Ok(())
199 }
200
201 pub fn write_compressed_to<T: Write + Seek>(&self, out: &mut T) -> crate::Result<()> {
203 self.write_header(out)?;
204
205 let mut name_refs = HashMap::new();
206 for e in &self.questions {
207 e.write_compressed_to(out, &mut name_refs)?;
208 }
209 for e in &self.answers {
210 e.write_compressed_to(out, &mut name_refs)?;
211 }
212 for e in &self.name_servers {
213 e.write_compressed_to(out, &mut name_refs)?;
214 }
215
216 if let Some(rr) = self.header.opt_rr() {
217 rr.write_to(out)?;
218 }
219
220 for e in &self.additional_records {
221 e.write_compressed_to(out, &mut name_refs)?;
222 }
223 out.flush()?;
224
225 Ok(())
226 }
227
228 fn write_header<T: Write>(&self, out: &mut T) -> crate::Result<()> {
229 self.header.write_to(
230 out,
231 self.questions.len() as u16,
232 self.answers.len() as u16,
233 self.name_servers.len() as u16,
234 self.additional_records.len() as u16 + u16::from(self.header.opt.is_some()),
235 )
236 }
237}
238
#[cfg(test)]
mod tests {
    use crate::{dns::CLASS, dns::TYPE, SimpleDnsError};

    use super::*;
    use std::convert::TryInto;

    /// Builds a TXT/IN question for `name` with the unicast-response bit cleared.
    fn txt_question(name: &'static str) -> Question<'static> {
        Question::new(name.try_into().unwrap(), TYPE::TXT.into(), CLASS::IN.into(), false)
    }

    /// Builds query id 1 with two questions whose names share the `_udp.local` suffix.
    fn two_question_query() -> Packet<'static> {
        let mut query = Packet::new_query(1);
        query.questions.push(txt_question("_srv._udp.local"));
        query.questions.push(txt_question("_srv2._udp.local"));
        query
    }

    /// Asserts that `parsed` carries exactly the questions from `two_question_query`.
    fn assert_two_questions(parsed: &Packet) {
        assert_eq!(2, parsed.questions.len());
        assert_eq!("_srv._udp.local", parsed.questions[0].qname.to_string());
        assert_eq!("_srv2._udp.local", parsed.questions[1].qname.to_string());
    }

    #[test]
    fn parse_without_data_should_not_panic() {
        assert!(matches!(
            Packet::parse(&[]),
            Err(SimpleDnsError::InsufficientData)
        ));
    }

    #[test]
    fn build_query_correct() {
        let query = two_question_query()
            .build_bytes_vec()
            .expect("query should serialize");

        // `expect` keeps the parse error in the failure message, unlike `assert!(is_ok())`.
        let parsed = Packet::parse(&query).expect("serialized query should parse");
        assert_two_questions(&parsed);
    }

    #[test]
    fn build_compressed_query_round_trips() {
        let query = two_question_query()
            .build_bytes_vec_compressed()
            .expect("query should serialize with compression");

        let parsed = Packet::parse(&query).expect("compressed query should parse");
        assert_eq!(1, parsed.id());
        assert_two_questions(&parsed);
    }
}