1use std::{borrow::Cow, collections::HashMap, convert::TryFrom, fmt::Display, hash::Hash};
2
3use crate::bytes_buffer::BytesBuffer;
4
5use super::{WireFormat, MAX_LABEL_LENGTH, MAX_NAME_LENGTH};
6
/// Marker bits in a label-length byte that flag a two-byte compression pointer.
const POINTER_MASK: u8 = 0xC0;
/// [`POINTER_MASK`] moved into the high byte of a big-endian two-byte pointer.
const POINTER_MASK_U16: u16 = (POINTER_MASK as u16) << 8;
9
10#[derive(Eq, Clone)]
25pub struct Name<'a> {
26 labels: Vec<Label<'a>>,
27}
28
29impl<'a> Name<'a> {
30 pub fn new(name: &'a str) -> crate::Result<Self> {
32 let labels = LabelsIter::new(name.as_bytes())
33 .map(Label::new)
34 .collect::<Result<Vec<Label>, _>>()?;
35
36 let name = Self { labels };
37
38 if name.len() > MAX_NAME_LENGTH {
39 Err(crate::SimpleDnsError::InvalidServiceName)
40 } else {
41 Ok(name)
42 }
43 }
44
45 pub fn new_unchecked(name: &'a str) -> Self {
47 let labels = LabelsIter::new(name.as_bytes())
48 .map(Label::new_unchecked)
49 .collect();
50
51 Self { labels }
52 }
53
54 pub fn new_with_labels(labels: &[Label<'a>]) -> Self {
58 Self {
59 labels: labels.to_vec(),
60 }
61 }
62
63 pub fn is_link_local(&self) -> bool {
65 match self.iter().last() {
66 Some(label) => b"local".eq_ignore_ascii_case(&label.data),
67 None => false,
68 }
69 }
70
71 pub fn iter(&'a self) -> std::slice::Iter<'a, Label<'a>> {
73 self.labels.iter()
74 }
75
76 pub fn is_subdomain_of(&self, other: &Name) -> bool {
78 self.labels.len() > other.labels.len()
79 && other
80 .iter()
81 .rev()
82 .zip(self.iter().rev())
83 .all(|(o, s)| *o == *s)
84 }
85
86 pub fn without(&self, domain: &Name) -> Option<Name> {
101 if self.is_subdomain_of(domain) {
102 let labels = self.labels[..self.labels.len() - domain.labels.len()].to_vec();
103
104 Some(Name { labels })
105 } else {
106 None
107 }
108 }
109
110 pub fn into_owned<'b>(self) -> Name<'b> {
112 Name {
113 labels: self.labels.into_iter().map(|l| l.into_owned()).collect(),
114 }
115 }
116
117 pub fn get_labels(&'_ self) -> &'_ [Label<'_>] {
119 &self.labels[..]
120 }
121
122 fn plain_append<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
123 for label in self.iter() {
124 out.write_all(&[label.len() as u8])?;
125 out.write_all(&label.data)?;
126 }
127
128 out.write_all(&[0])?;
129 Ok(())
130 }
131
132 fn compress_append<T: std::io::Write + std::io::Seek>(
133 &'a self,
134 out: &mut T,
135 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
136 ) -> crate::Result<()> {
137 for (i, label) in self.iter().enumerate() {
138 match name_refs.entry(&self.labels[i..]) {
139 std::collections::hash_map::Entry::Occupied(e) => {
140 let p = *e.get() as u16;
141 out.write_all(&(p | POINTER_MASK_U16).to_be_bytes())?;
142
143 return Ok(());
144 }
145 std::collections::hash_map::Entry::Vacant(e) => {
146 e.insert(out.stream_position()? as usize);
147 out.write_all(&[label.len() as u8])?;
148 out.write_all(&label.data)?;
149 }
150 }
151 }
152
153 out.write_all(&[0])?;
154 Ok(())
155 }
156
157 pub fn is_valid(&self) -> bool {
159 self.labels.iter().all(|label| label.is_valid())
160 }
161
162 pub fn as_bytes(&self) -> impl Iterator<Item = &[u8]> {
164 self.labels.iter().map(|label| label.as_ref())
165 }
166}
167
168impl<'a> WireFormat<'a> for Name<'a> {
169 const MINIMUM_LEN: usize = 1;
170
171 fn parse(data: &mut BytesBuffer<'a>) -> crate::Result<Self>
172 where
173 Self: Sized,
174 {
175 fn parse_labels<'a>(
178 data: &mut BytesBuffer<'a>,
179 name_len: &mut usize,
180 labels: &mut Vec<Label<'a>>,
181 ) -> crate::Result<Option<usize>> {
182 loop {
183 match data.get_u8()? {
184 0 => break Ok(None),
185 len if len & POINTER_MASK == POINTER_MASK => {
186 let mut pointer = len as u16;
187 pointer <<= 8;
188 pointer += data.get_u8()? as u16;
189 pointer &= !POINTER_MASK_U16;
190
191 break Ok(Some(pointer as usize));
192 }
193 len => {
194 *name_len += 1 + len as usize;
195
196 if *name_len >= MAX_NAME_LENGTH {
198 return Err(crate::SimpleDnsError::InvalidDnsPacket);
199 }
200
201 if len as usize > MAX_LABEL_LENGTH {
202 return Err(crate::SimpleDnsError::InvalidServiceLabel);
203 }
204
205 labels.push(Label::new_unchecked(data.get_slice(len as usize)?));
208 }
209 }
210 }
211 }
212
213 let mut labels = Vec::new();
214 let mut name_len = 0usize;
215
216 let mut pointer = parse_labels(data, &mut name_len, &mut labels)?;
217
218 let mut data = data.clone();
219 while let Some(p) = pointer {
220 data = data.new_at(p)?;
224 pointer = parse_labels(&mut data, &mut name_len, &mut labels)?;
225 }
226
227 Ok(Self { labels })
228 }
229
230 fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
231 self.plain_append(out)
232 }
233
234 fn write_compressed_to<T: std::io::Write + std::io::Seek>(
235 &'a self,
236 out: &mut T,
237 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
238 ) -> crate::Result<()> {
239 self.compress_append(out, name_refs)
240 }
241
242 fn len(&self) -> usize {
243 self.labels
244 .iter()
245 .map(|label| label.len() + 1)
246 .sum::<usize>()
247 + Self::MINIMUM_LEN
248 }
249}
250
251impl<'a> TryFrom<&'a str> for Name<'a> {
252 type Error = crate::SimpleDnsError;
253
254 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
255 Name::new(value)
256 }
257}
258
259impl<'a> From<&'a [Label<'a>]> for Name<'a> {
260 fn from(labels: &'a [Label<'a>]) -> Self {
261 Name::new_with_labels(labels)
262 }
263}
264
265impl<'a, const N: usize> From<[Label<'a>; N]> for Name<'a> {
266 fn from(labels: [Label<'a>; N]) -> Self {
267 Name::new_with_labels(&labels)
268 }
269}
270
271impl Display for Name<'_> {
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 for (i, label) in self.iter().enumerate() {
274 if i != 0 {
275 f.write_str(".")?;
276 }
277
278 f.write_fmt(format_args!("{}", label))?;
279 }
280
281 Ok(())
282 }
283}
284
285impl std::fmt::Debug for Name<'_> {
286 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287 f.debug_tuple("Name")
288 .field(&format!("{}", self))
289 .field(&format!("{}", self.len()))
290 .finish()
291 }
292}
293
294impl PartialEq for Name<'_> {
295 fn eq(&self, other: &Self) -> bool {
296 self.labels == other.labels
297 }
298}
299
300impl Hash for Name<'_> {
301 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
302 self.labels.hash(state);
303 }
304}
305
/// Splits dot-separated bytes into their non-empty labels.
///
/// Empty segments from leading, trailing or repeated dots are skipped, so
/// `"example.....com."` yields `example` and `com`, while `""` and `"."` yield nothing.
/// Every item borrows from the input.
struct LabelsIter<'a> {
    /// Input that has not been consumed yet.
    rest: &'a [u8],
}

impl<'a> LabelsIter<'a> {
    /// Creates an iterator over the labels contained in `bytes`.
    fn new(bytes: &'a [u8]) -> Self {
        LabelsIter { rest: bytes }
    }
}

impl<'a> Iterator for LabelsIter<'a> {
    type Item = Cow<'a, [u8]>;

    fn next(&mut self) -> Option<Self::Item> {
        while !self.rest.is_empty() {
            let dot = self
                .rest
                .iter()
                .position(|&byte| byte == b'.')
                .unwrap_or(self.rest.len());
            let (segment, tail) = self.rest.split_at(dot);

            // Skip the dot itself, if there was one.
            self.rest = tail.get(1..).unwrap_or(&[]);

            if !segment.is_empty() {
                return Some(Cow::Borrowed(segment));
            }
        }

        None
    }
}
340
/// A single label of a [`Name`]: the bytes between two dots, without a length prefix.
///
/// The bytes may be borrowed from the input or owned (see [`Label::into_owned`]);
/// equality and hashing depend only on the byte content.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct Label<'a> {
    /// Raw label bytes; not guaranteed to be UTF-8 or to pass `is_valid`.
    data: Cow<'a, [u8]>,
}
356
357impl<'a> Label<'a> {
358 pub fn new<T: Into<Cow<'a, [u8]>>>(data: T) -> crate::Result<Self> {
360 let label = Self::new_unchecked(data);
361 if !label.is_valid() {
362 return Err(crate::SimpleDnsError::InvalidServiceLabel);
363 }
364
365 Ok(label)
366 }
367
368 pub fn new_unchecked<T: Into<Cow<'a, [u8]>>>(data: T) -> Self {
371 Self { data: data.into() }
372 }
373
374 pub fn len(&self) -> usize {
376 self.data.len()
377 }
378
379 pub fn is_empty(&self) -> bool {
381 self.data.is_empty()
382 }
383
384 pub fn into_owned<'b>(self) -> Label<'b> {
386 Label {
387 data: self.data.into_owned().into(),
388 }
389 }
390
391 pub fn is_valid(&self) -> bool {
393 if self.data.is_empty() || self.data.len() > MAX_LABEL_LENGTH {
394 return false;
395 }
396
397 if let Some(first) = self.data.first() {
398 if !first.is_ascii_alphanumeric() && *first != b'_' {
399 return false;
400 }
401 }
402
403 if !self
404 .data
405 .iter()
406 .skip(1)
407 .all(|c| c.is_ascii_alphanumeric() || *c == b'-' || *c == b'_')
408 {
409 return false;
410 }
411
412 if let Some(last) = self.data.last() {
413 if !last.is_ascii_alphanumeric() {
414 return false;
415 }
416 }
417
418 true
419 }
420}
421
422impl Display for Label<'_> {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 let s = std::string::String::from_utf8_lossy(&self.data);
425 f.write_str(&s)
426 }
427}
428
429impl std::fmt::Debug for Label<'_> {
430 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431 f.debug_struct("Label")
432 .field("data", &self.to_string())
433 .finish()
434 }
435}
436
437impl AsRef<[u8]> for Label<'_> {
438 fn as_ref(&self) -> &[u8] {
439 self.data.as_ref()
440 }
441}
442
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::{collections::hash_map::DefaultHasher, hash::Hasher};

    use super::*;
    use crate::SimpleDnsError;

    #[test]
    fn construct_valid_names() {
        assert!(Name::new("some").is_ok());
        assert!(Name::new("some.local").is_ok());
        assert!(Name::new("some.local.").is_ok());
        assert!(Name::new("some-dash.local.").is_ok());
        assert!(Name::new("_sync_miss._tcp.local").is_ok());
        assert!(Name::new("1sync_miss._tcp.local").is_ok());

        assert_eq!(Name::new_unchecked("\u{1F600}.local.").labels.len(), 2);
    }

    #[test]
    fn label_validate() {
        assert!(Name::new("\u{1F600}.local.").is_err());
        assert!(Name::new("@.local.").is_err());
        assert!(Name::new("\\.local.").is_err());
    }

    #[test]
    fn label_is_valid_edge_cases() {
        assert!(Label::new(&b"a"[..]).is_ok());
        assert!(Label::new(&b"a_b-c"[..]).is_ok());
        assert!(Label::new(&b"_tcp"[..]).is_ok());
        assert!(Label::new(vec![b'a'; MAX_LABEL_LENGTH]).is_ok());

        assert!(Label::new(&b""[..]).is_err());
        assert!(Label::new(&b"-a"[..]).is_err());
        assert!(Label::new(&b"a-"[..]).is_err());
        assert!(Label::new(&b"_"[..]).is_err());
        assert!(Label::new(vec![b'a'; MAX_LABEL_LENGTH + 1]).is_err());
    }

    #[test]
    fn is_link_local() {
        assert!(!Name::new("some.example.com").unwrap().is_link_local());
        assert!(Name::new("some.example.local.").unwrap().is_link_local());
        // The comparison ignores ASCII case.
        assert!(Name::new("some.example.LOCAL").unwrap().is_link_local());
    }

    #[test]
    fn parse_without_compression() {
        let mut data = BytesBuffer::new(
            b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\x01F\x03ISI\x04ARPA\x00\x04ARPA\x00",
        );
        data.advance(3).unwrap();
        let name = Name::parse(&mut data).unwrap();
        assert_eq!("F.ISI.ARPA", name.to_string());

        let name = Name::parse(&mut data).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", name.to_string());
    }

    #[test]
    fn parse_with_compression() {
        let mut data = BytesBuffer::new(
            b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03\x07INVALID\xc0\x1b",
        );
        data.advance(3).unwrap();

        let name = Name::parse(&mut data).unwrap();
        assert_eq!("F.ISI.ARPA", name.to_string());

        let name = Name::parse(&mut data).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", name.to_string());

        let name = Name::parse(&mut data).unwrap();
        assert_eq!("BAR.F.ISI.ARPA", name.to_string());

        assert!(Name::parse(&mut data).is_err());
    }

    #[test]
    fn parse_handle_circular_pointers() {
        let mut data = BytesBuffer::new(&[249, 0, 37, 1, 1, 139, 192, 6, 1, 1, 1, 139, 192, 6]);
        data.advance(12).unwrap();

        assert_eq!(
            Name::parse(&mut data),
            Err(SimpleDnsError::InvalidDnsPacket)
        );
    }

    #[test]
    fn test_write() {
        let mut bytes = Cursor::new(Vec::with_capacity(30));
        Name::new_unchecked("_srv._udp.local")
            .write_to(&mut bytes)
            .unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x05local\x00", &bytes.get_ref()[..]);

        let mut bytes = Cursor::new(Vec::with_capacity(30));
        Name::new_unchecked("_srv._udp.local2.")
            .write_to(&mut bytes)
            .unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x06local2\x00", &bytes.get_ref()[..]);
    }

    #[test]
    fn root_name_should_generate_no_labels() {
        assert_eq!(Name::new_unchecked("").labels.len(), 0);
        assert_eq!(Name::new_unchecked(".").labels.len(), 0);
    }

    #[test]
    fn dot_sequence_should_generate_no_labels() {
        assert_eq!(Name::new_unchecked(".....").labels.len(), 0);
        assert_eq!(Name::new_unchecked("example.....com").labels.len(), 2);
    }

    #[test]
    fn root_name_should_write_zero() {
        let mut bytes = Cursor::new(Vec::with_capacity(30));
        Name::new_unchecked(".").write_to(&mut bytes).unwrap();

        assert_eq!(b"\x00", &bytes.get_ref()[..]);
    }

    #[test]
    fn from_labels() {
        let labels = [
            Label::new_unchecked(&b"example"[..]),
            Label::new_unchecked(&b"com"[..]),
        ];
        let expected = Name::new_unchecked("example.com");

        assert_eq!(Name::from(&labels[..]), expected);
        assert_eq!(Name::from(labels), expected);
    }

    #[test]
    fn append_to_vec_with_compression() {
        let mut buf = Cursor::new(vec![0, 0, 0]);
        buf.set_position(3);

        let mut name_refs = HashMap::new();

        let f_isi_arpa = Name::new_unchecked("F.ISI.ARPA");
        f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        let foo_f_isi_arpa = Name::new_unchecked("FOO.F.ISI.ARPA");
        foo_f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.F.ISI.ARPA");

        Name::new_unchecked("BAR.F.ISI.ARPA")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add BAR.F.ISI.ARPA");

        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";
        assert_eq!(data[..], buf.get_ref()[..]);
    }

    #[test]
    fn append_to_vec_with_compression_mult_names() {
        let mut buf = Cursor::new(vec![]);
        let mut name_refs = HashMap::new();

        let isi_arpa = Name::new_unchecked("ISI.ARPA");
        isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add ISI.ARPA");

        let f_isi_arpa = Name::new_unchecked("F.ISI.ARPA");
        f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        let foo_f_isi_arpa = Name::new_unchecked("FOO.F.ISI.ARPA");
        foo_f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.F.ISI.ARPA");
        Name::new_unchecked("BAR.F.ISI.ARPA")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add BAR.F.ISI.ARPA");

        let expected = b"\x03ISI\x04ARPA\x00\x01F\xc0\x00\x03FOO\xc0\x0a\x03BAR\xc0\x0a";
        assert_eq!(expected[..], buf.get_ref()[..]);

        let mut data = BytesBuffer::new(buf.get_ref());

        let first = Name::parse(&mut data).unwrap();
        assert_eq!("ISI.ARPA", first.to_string());
        let second = Name::parse(&mut data).unwrap();
        assert_eq!("F.ISI.ARPA", second.to_string());
        let third = Name::parse(&mut data).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", third.to_string());
        let fourth = Name::parse(&mut data).unwrap();
        assert_eq!("BAR.F.ISI.ARPA", fourth.to_string());
    }

    #[test]
    fn ensure_different_domains_are_not_compressed() {
        let mut buf = Cursor::new(vec![]);
        let mut name_refs = HashMap::new();

        let foo_bar_baz = Name::new_unchecked("FOO.BAR.BAZ");
        foo_bar_baz
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR.BAZ");

        let foo_bar_buz = Name::new_unchecked("FOO.BAR.BUZ");
        foo_bar_buz
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR.BUZ");

        Name::new_unchecked("FOO.BAR")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR");

        let expected = b"\x03FOO\x03BAR\x03BAZ\x00\x03FOO\x03BAR\x03BUZ\x00\x03FOO\x03BAR\x00";
        assert_eq!(expected[..], buf.get_ref()[..]);
    }

    #[test]
    fn eq_other_name() -> Result<(), SimpleDnsError> {
        assert_eq!(Name::new("example.com")?, Name::new("example.com")?);
        assert_ne!(Name::new("some.example.com")?, Name::new("example.com")?);
        assert_ne!(Name::new("example.co")?, Name::new("example.com")?);
        assert_ne!(Name::new("example.com.org")?, Name::new("example.com")?);

        let mut data =
            BytesBuffer::new(b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03");
        data.advance(3)?;
        assert_eq!(Name::new("F.ISI.ARPA")?, Name::parse(&mut data)?);
        assert_eq!(Name::new("FOO.F.ISI.ARPA")?, Name::parse(&mut data)?);
        Ok(())
    }

    #[test]
    fn len() -> crate::Result<()> {
        let mut bytes = Cursor::new(Vec::new());
        let name_one = Name::new_unchecked("ex.com.");
        name_one.write_to(&mut bytes)?;

        assert_eq!(8, bytes.get_ref().len());
        assert_eq!(bytes.get_ref().len(), name_one.len());
        assert_eq!(
            8,
            Name::parse(&mut BytesBuffer::new(bytes.get_ref()))?.len()
        );

        let mut name_refs = HashMap::new();
        let mut bytes = Cursor::new(Vec::new());
        name_one.write_compressed_to(&mut bytes, &mut name_refs)?;
        name_one.write_compressed_to(&mut bytes, &mut name_refs)?;

        assert_eq!(10, bytes.get_ref().len());
        Ok(())
    }

    #[test]
    fn hash() -> crate::Result<()> {
        let mut data =
            BytesBuffer::new(b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03");
        data.advance(3)?;

        assert_eq!(
            get_hash(&Name::new("F.ISI.ARPA")?),
            get_hash(&Name::parse(&mut data)?)
        );

        assert_eq!(
            get_hash(&Name::new("FOO.F.ISI.ARPA")?),
            get_hash(&Name::parse(&mut data)?)
        );

        Ok(())
    }

    fn get_hash(name: &Name) -> u64 {
        let mut hasher = DefaultHasher::default();
        name.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn is_subdomain_of() {
        assert!(Name::new_unchecked("sub.example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(!Name::new_unchecked("example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(Name::new_unchecked("foo.sub.example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(!Name::new_unchecked("example.com")
            .is_subdomain_of(&Name::new_unchecked("example.xom")));

        assert!(!Name::new_unchecked("domain.com")
            .is_subdomain_of(&Name::new_unchecked("other.domain")));

        assert!(!Name::new_unchecked("domain.com")
            .is_subdomain_of(&Name::new_unchecked("domain.com.br")));
    }

    #[test]
    fn subtract_domain() {
        let domain = Name::new_unchecked("_srv3._tcp.local");
        assert_eq!(
            Name::new_unchecked("a._srv3._tcp.local")
                .without(&domain)
                .unwrap()
                .to_string(),
            "a"
        );

        assert!(Name::new_unchecked("unrelated").without(&domain).is_none());

        assert_eq!(
            Name::new_unchecked("some.longer.domain._srv3._tcp.local")
                .without(&domain)
                .unwrap()
                .to_string(),
            "some.longer.domain"
        );
    }

    #[test]
    fn display_invalid_label() {
        let input = b"invalid\xF0\x90\x80label";
        let label = Label::new_unchecked(input);

        // The truncated 4-byte sequence is replaced by a single U+FFFD.
        assert_eq!(label.to_string(), "invalid\u{FFFD}label");
    }
}