simple_dns/dns/
question.rs1use std::{collections::HashMap, convert::TryFrom};
2
3use crate::bytes_buffer::BytesBuffer;
4
5use super::{name::Label, Name, WireFormat, QCLASS, QTYPE};
6
7#[derive(Debug, Clone)]
9pub struct Question<'a> {
10 pub qname: Name<'a>,
12 pub qtype: QTYPE,
14 pub qclass: QCLASS,
16 pub unicast_response: bool,
19}
20
21impl<'a> Question<'a> {
22 pub fn new(qname: Name<'a>, qtype: QTYPE, qclass: QCLASS, unicast_response: bool) -> Self {
24 Self {
25 qname,
26 qtype,
27 qclass,
28 unicast_response,
29 }
30 }
31
32 pub fn into_owned<'b>(self) -> Question<'b> {
34 Question {
35 qname: self.qname.into_owned(),
36 qtype: self.qtype,
37 qclass: self.qclass,
38 unicast_response: self.unicast_response,
39 }
40 }
41
42 fn write_common<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
43 let qclass: u16 = match self.unicast_response {
44 true => Into::<u16>::into(self.qclass) | 0x8000,
45 false => self.qclass.into(),
46 };
47
48 out.write_all(&Into::<u16>::into(self.qtype).to_be_bytes())?;
49 out.write_all(&qclass.to_be_bytes())
50 .map_err(crate::SimpleDnsError::from)
51 }
52}
53
54impl<'a> WireFormat<'a> for Question<'a> {
55 const MINIMUM_LEN: usize = 4;
56
57 fn parse(data: &mut BytesBuffer<'a>) -> crate::Result<Self> {
59 let qname = Name::parse(data)?;
60
61 let qtype = data.get_u16()?;
62 let qclass = data.get_u16()?;
63
64 Ok(Self {
65 qname,
66 qtype: QTYPE::try_from(qtype)?,
67 qclass: QCLASS::try_from(qclass & 0x7FFF)?,
68 unicast_response: qclass & 0x8000 == 0x8000,
69 })
70 }
71
72 fn len(&self) -> usize {
73 self.qname.len() + Self::MINIMUM_LEN
74 }
75
76 fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
77 self.qname.write_to(out)?;
78 self.write_common(out)
79 }
80
81 fn write_compressed_to<T: std::io::Write + std::io::Seek>(
82 &'a self,
83 out: &mut T,
84 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
85 ) -> crate::Result<()> {
86 self.qname.write_compressed_to(out, name_refs)?;
87 self.write_common(out)
88 }
89}
90
#[cfg(test)]
mod tests {
    use crate::{CLASS, TYPE};

    use super::*;
    use std::convert::TryInto;

    #[test]
    fn parse_question() {
        // Skip the two leading zero bytes so parsing starts mid-buffer.
        let mut bytes = BytesBuffer::new(b"\x00\x00\x04_srv\x04_udp\x05local\x00\x00\x10\x00\x01");
        bytes.advance(2).unwrap();
        let question = Question::parse(&mut bytes).expect("question should parse");

        assert_eq!(QCLASS::CLASS(CLASS::IN), question.qclass);
        assert_eq!(QTYPE::TYPE(TYPE::TXT), question.qtype);
        assert!(!question.unicast_response);
    }

    #[test]
    fn append_to_vec() {
        let question = Question::new(
            "_srv._udp.local".try_into().unwrap(),
            TYPE::TXT.into(),
            CLASS::IN.into(),
            false,
        );
        let mut bytes = Vec::new();
        question.write_to(&mut bytes).unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x05local\x00\x00\x10\x00\x01", &bytes[..]);
        assert_eq!(bytes.len(), question.len());
    }

    #[test]
    fn unicast_response() {
        let mut bytes = Vec::new();
        Question::new(
            "x.local".try_into().unwrap(),
            TYPE::TXT.into(),
            CLASS::IN.into(),
            true,
        )
        .write_to(&mut bytes)
        .unwrap();

        // The flag is carried in the top bit of the class field: IN (0x0001) | 0x8000.
        assert_eq!(b"\x01x\x05local\x00\x00\x10\x80\x01", &bytes[..]);

        let parsed = Question::parse(&mut BytesBuffer::new(&bytes)).unwrap();

        assert!(parsed.unicast_response);
        // The flag bit must be masked off before the class is interpreted.
        assert_eq!(QCLASS::CLASS(CLASS::IN), parsed.qclass);
        assert_eq!(QTYPE::TYPE(TYPE::TXT), parsed.qtype);
    }
}