1use crate::{bytes_buffer::BytesBuffer, QCLASS, QTYPE};
2
3use super::{name::Label, rdata::RData, Name, WireFormat, CLASS, TYPE};
4use core::fmt::Debug;
5use std::{collections::HashMap, convert::TryInto, hash::Hash};
6
/// Bit masks applied to the 16-bit CLASS field of a resource record.
mod flag {
    /// Most significant bit of the CLASS field, used as the cache-flush flag
    /// (multicast DNS, RFC 6762 §10.2). The remaining 15 bits hold the class itself.
    pub const CACHE_FLUSH: u16 = 1 << 15;
}
10#[derive(Debug, Eq, Clone)]
12pub struct ResourceRecord<'a> {
13 pub name: Name<'a>,
15 pub class: CLASS,
17 pub ttl: u32,
20 pub rdata: RData<'a>,
22
23 pub cache_flush: bool,
25}
26
27impl<'a> ResourceRecord<'a> {
28 pub fn new(name: Name<'a>, class: CLASS, ttl: u32, rdata: RData<'a>) -> Self {
30 Self {
31 name,
32 class,
33 ttl,
34 rdata,
35 cache_flush: false,
36 }
37 }
38
39 pub fn with_cache_flush(mut self, cache_flush: bool) -> Self {
41 self.cache_flush = cache_flush;
42 self
43 }
44
45 pub fn to_cache_flush_record(&self) -> Self {
47 self.clone().with_cache_flush(true)
48 }
49
50 pub fn match_qclass(&self, qclass: QCLASS) -> bool {
52 match qclass {
53 QCLASS::CLASS(class) => class == self.class,
54 QCLASS::ANY => true,
55 }
56 }
57
58 pub fn match_qtype(&self, qtype: QTYPE) -> bool {
60 let type_code = self.rdata.type_code();
61 match qtype {
62 QTYPE::ANY => true,
63 QTYPE::IXFR => false,
64 QTYPE::AXFR => true, QTYPE::MAILB => type_code == TYPE::MR || type_code == TYPE::MB || type_code == TYPE::MG,
66 QTYPE::MAILA => type_code == TYPE::MX,
67 QTYPE::TYPE(ty) => ty == type_code,
68 }
69 }
70
71 pub fn into_owned<'b>(self) -> ResourceRecord<'b> {
73 ResourceRecord {
74 name: self.name.into_owned(),
75 class: self.class,
76 ttl: self.ttl,
77 rdata: self.rdata.into_owned(),
78 cache_flush: self.cache_flush,
79 }
80 }
81
82 fn write_common<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
83 out.write_all(&u16::from(self.rdata.type_code()).to_be_bytes())?;
84
85 if let RData::OPT(ref opt) = self.rdata {
86 out.write_all(&opt.udp_packet_size.to_be_bytes())?;
87 } else {
88 let class = if self.cache_flush {
89 ((self.class as u16) | flag::CACHE_FLUSH).to_be_bytes()
90 } else {
91 (self.class as u16).to_be_bytes()
92 };
93
94 out.write_all(&class)?;
95 }
96
97 out.write_all(&self.ttl.to_be_bytes())
98 .map_err(crate::SimpleDnsError::from)
99 }
100}
101
102impl<'a> WireFormat<'a> for ResourceRecord<'a> {
103 const MINIMUM_LEN: usize = 10;
104
105 fn parse(data: &mut BytesBuffer<'a>) -> crate::Result<Self>
107 where
108 Self: Sized,
109 {
110 let name = Name::parse(data)?;
111
112 let class_value = data.peek_u16_in(2)?;
113 let ttl = data.peek_u32_in(4)?;
114
115 let rdata = RData::parse(data)?;
116
117 if rdata.type_code() == TYPE::OPT {
118 Ok(Self {
119 name,
120 class: CLASS::IN,
121 ttl,
122 rdata,
123 cache_flush: false,
124 })
125 } else {
126 let cache_flush = class_value & flag::CACHE_FLUSH == flag::CACHE_FLUSH;
127 let class = (class_value & !flag::CACHE_FLUSH).try_into()?;
128
129 Ok(Self {
130 name,
131 class,
132 ttl,
133 rdata,
134 cache_flush,
135 })
136 }
137 }
138
139 fn len(&self) -> usize {
140 self.name.len() + self.rdata.len() + Self::MINIMUM_LEN
141 }
142
143 fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
144 self.name.write_to(out)?;
145 self.write_common(out)?;
146 out.write_all(&(self.rdata.len() as u16).to_be_bytes())?;
147 self.rdata.write_to(out)
148 }
149
150 fn write_compressed_to<T: std::io::Write + std::io::Seek>(
151 &'a self,
152 out: &mut T,
153 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
154 ) -> crate::Result<()> {
155 self.name.write_compressed_to(out, name_refs)?;
156 self.write_common(out)?;
157
158 let len_position = out.stream_position()?;
159 out.write_all(&[0, 0])?;
160
161 self.rdata.write_compressed_to(out, name_refs)?;
162 let end = out.stream_position()?;
163
164 out.seek(std::io::SeekFrom::Start(len_position))?;
165 out.write_all(&((end - len_position - 2) as u16).to_be_bytes())?;
166 out.seek(std::io::SeekFrom::End(0))?;
167 Ok(())
168 }
169}
170
171impl Hash for ResourceRecord<'_> {
172 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
173 self.name.hash(state);
174 self.class.hash(state);
175 self.rdata.hash(state);
176 }
177}
178
179impl PartialEq for ResourceRecord<'_> {
180 fn eq(&self, other: &Self) -> bool {
181 self.name == other.name && self.class == other.class && self.rdata == other.rdata
182 }
183}
184
#[cfg(test)]
mod tests {
    use std::{
        collections::hash_map::DefaultHasher,
        hash::{Hash, Hasher},
        io::Cursor,
    };

    use crate::{dns::rdata::NULL, rdata::TXT};

    use super::*;

    /// Owner name shared by the wire fixtures and hand-built records below.
    const OWNER: &str = "_srv._udp.local";

    /// Four 0xff bytes used as opaque NULL RDATA.
    const PAYLOAD: &[u8] = &[0xff; 4];

    /// Builds a class-IN, TTL-10 NULL record (type code 0) over `payload`.
    fn null_record(payload: &[u8], cache_flush: bool) -> ResourceRecord<'_> {
        ResourceRecord {
            name: OWNER.try_into().unwrap(),
            class: CLASS::IN,
            ttl: 10,
            rdata: RData::NULL(0, NULL::new(payload).unwrap()),
            cache_flush,
        }
    }

    /// Builds the TXT record used by the equality and hashing tests.
    fn txt_record() -> ResourceRecord<'static> {
        ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        )
    }

    /// Hashes `rr` with a fresh `DefaultHasher`.
    fn hash_of(rr: &ResourceRecord) -> u64 {
        let mut hasher = DefaultHasher::default();
        rr.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_parse() {
        // name | TYPE=A (1) | CLASS=IN (1) | TTL=10 | RDLENGTH=4 | 255.255.255.255
        let wire = b"\x04_srv\x04_udp\x05local\x00\x00\x01\x00\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff";
        let rr = ResourceRecord::parse(&mut BytesBuffer::new(wire)).unwrap();

        assert_eq!(rr.name.to_string(), OWNER);
        assert_eq!(rr.class, CLASS::IN);
        assert_eq!(rr.ttl, 10);
        assert_eq!(rr.rdata.len(), 4);
        assert!(!rr.cache_flush);

        if let RData::A(a) = rr.rdata {
            assert_eq!(a.address, 0xffff_ffff);
        } else {
            panic!("invalid rdata");
        }
    }

    #[test]
    fn test_empty_rdata() {
        let rr = ResourceRecord {
            name: OWNER.try_into().unwrap(),
            class: CLASS::NONE,
            ttl: 0,
            rdata: RData::Empty(TYPE::A),
            cache_flush: false,
        };
        assert_eq!(rr.rdata.type_code(), TYPE::A);
        assert_eq!(rr.rdata.len(), 0);

        let mut wire = Vec::new();
        rr.write_to(&mut wire).expect("failed to write");

        let parsed =
            ResourceRecord::parse(&mut BytesBuffer::new(&wire)).expect("failed to parse");
        assert_eq!(parsed.rdata.type_code(), TYPE::A);
        assert_eq!(parsed.rdata.len(), 0);
        assert!(matches!(parsed.rdata, RData::Empty(TYPE::A)));
    }

    #[test]
    fn test_cache_flush_parse() {
        // Same as the A fixture, but CLASS is 0x8001: cache-flush bit + IN.
        let wire = b"\x04_srv\x04_udp\x05local\x00\x00\x01\x80\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff";
        let rr = ResourceRecord::parse(&mut BytesBuffer::new(wire)).unwrap();

        assert_eq!(rr.class, CLASS::IN);
        assert!(rr.cache_flush);
    }

    #[test]
    fn test_write() {
        let rr = null_record(PAYLOAD, false);
        let mut out = Cursor::new(Vec::new());

        assert!(rr.write_to(&mut out).is_ok());

        let written = out.into_inner();
        assert_eq!(
            &b"\x04_srv\x04_udp\x05local\x00\x00\x00\x00\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff"[..],
            &written[..]
        );
        assert_eq!(written.len(), rr.len());
    }

    #[test]
    fn test_append_to_vec_cache_flush() {
        let rr = null_record(PAYLOAD, true);
        let mut out = Cursor::new(Vec::new());

        assert!(rr.write_to(&mut out).is_ok());

        let written = out.into_inner();
        assert_eq!(
            &b"\x04_srv\x04_udp\x05local\x00\x00\x00\x80\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff"[..],
            &written[..]
        );
        assert_eq!(written.len(), rr.len());
    }

    #[test]
    fn test_match_qclass() {
        let rr = null_record(PAYLOAD, false);

        assert!(rr.match_qclass(QCLASS::ANY));
        assert!(rr.match_qclass(CLASS::IN.into()));
        assert!(!rr.match_qclass(CLASS::CS.into()));
    }

    #[test]
    fn test_match_qtype() {
        let rr = ResourceRecord {
            name: OWNER.try_into().unwrap(),
            class: CLASS::IN,
            ttl: 10,
            rdata: RData::A(crate::rdata::A { address: 0 }),
            cache_flush: false,
        };

        assert!(rr.match_qtype(QTYPE::ANY));
        assert!(rr.match_qtype(TYPE::A.into()));
        assert!(!rr.match_qtype(TYPE::WKS.into()));
    }

    #[test]
    fn test_eq() {
        let (a, b) = (txt_record(), txt_record());

        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn test_hash_ignore_ttl() {
        let a = txt_record();
        let mut b = txt_record();
        assert_eq!(hash_of(&a), hash_of(&b));

        b.ttl = 50;
        assert_eq!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn parse_sample_files() -> Result<(), Box<dyn std::error::Error>> {
        for entry in std::fs::read_dir("samples/zonefile")? {
            let path = entry?.path();
            let bytes = std::fs::read(path)?;

            let mut data = BytesBuffer::new(&bytes);
            while data.has_remaining() {
                crate::ResourceRecord::parse(&mut data)?;
            }
        }

        Ok(())
    }
}