1use alloc::{boxed::Box, fmt, vec::Vec};
11use core::{iter, mem, ops::Deref};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use tracing::{debug, warn};
16
17use crate::{
18 error::*,
19 op::{Edns, Header, MessageType, OpCode, Query, ResponseCode},
20 rr::{Record, RecordType},
21 serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
22 xfer::DnsResponse,
23};
24
25#[derive(Clone, Debug, PartialEq, Eq, Default)]
68#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
69pub struct Message {
70 header: Header,
71 queries: Vec<Query>,
72 answers: Vec<Record>,
73 name_servers: Vec<Record>,
74 additionals: Vec<Record>,
75 signature: Vec<Record>,
76 edns: Option<Edns>,
77}
78
79pub fn update_header_counts(
81 current_header: &Header,
82 is_truncated: bool,
83 counts: HeaderCounts,
84) -> Header {
85 assert!(counts.query_count <= u16::MAX as usize);
86 assert!(counts.answer_count <= u16::MAX as usize);
87 assert!(counts.nameserver_count <= u16::MAX as usize);
88 assert!(counts.additional_count <= u16::MAX as usize);
89
90 let mut header = *current_header;
92 header
93 .set_query_count(counts.query_count as u16)
94 .set_answer_count(counts.answer_count as u16)
95 .set_name_server_count(counts.nameserver_count as u16)
96 .set_additional_count(counts.additional_count as u16)
97 .set_truncated(is_truncated);
98
99 header
100}
101
/// Per-section entry counts handed to `update_header_counts`.
///
/// The counts are `usize` here; `update_header_counts` asserts that each one
/// fits in a `u16` before storing it in the header.
#[derive(Debug, Clone, Copy)]
pub struct HeaderCounts {
    /// Number of queries in the message.
    pub query_count: usize,
    /// Number of answer records in the message.
    pub answer_count: usize,
    /// Number of name server records in the message.
    pub nameserver_count: usize,
    /// Number of additional records in the message.
    pub additional_count: usize,
}
116
117impl Message {
118 pub fn new() -> Self {
120 Self {
121 header: Header::new(),
122 queries: Vec::new(),
123 answers: Vec::new(),
124 name_servers: Vec::new(),
125 additionals: Vec::new(),
126 signature: Vec::new(),
127 edns: None,
128 }
129 }
130
131 pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
139 let mut message = Self::new();
140 message
141 .set_message_type(MessageType::Response)
142 .set_id(id)
143 .set_response_code(response_code)
144 .set_op_code(op_code);
145
146 message
147 }
148
149 pub fn truncate(&self) -> Self {
151 let mut header = self.header;
153 header.set_truncated(true);
154 header
155 .set_additional_count(0)
156 .set_answer_count(0)
157 .set_name_server_count(0);
158
159 let mut msg = Self::new();
160 msg.add_queries(self.queries().iter().cloned());
163 if let Some(edns) = self.extensions().clone() {
164 msg.set_edns(edns);
165 }
166 msg.set_header(header);
168
169 msg
171 }
172
173 pub fn set_header(&mut self, header: Header) -> &mut Self {
175 self.header = header;
176 self
177 }
178
179 pub fn set_id(&mut self, id: u16) -> &mut Self {
181 self.header.set_id(id);
182 self
183 }
184
185 pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
187 self.header.set_message_type(message_type);
188 self
189 }
190
191 pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
193 self.header.set_op_code(op_code);
194 self
195 }
196
197 pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
199 self.header.set_authoritative(authoritative);
200 self
201 }
202
203 pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
205 self.header.set_truncated(truncated);
206 self
207 }
208
209 pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
211 self.header.set_recursion_desired(recursion_desired);
212 self
213 }
214
215 pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
217 self.header.set_recursion_available(recursion_available);
218 self
219 }
220
221 pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
223 self.header.set_authentic_data(authentic_data);
224 self
225 }
226
227 pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
229 self.header.set_checking_disabled(checking_disabled);
230 self
231 }
232
233 pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
235 self.header.set_response_code(response_code);
236 self
237 }
238
239 pub fn set_query_count(&mut self, query_count: u16) -> &mut Self {
244 self.header.set_query_count(query_count);
245 self
246 }
247
248 pub fn set_answer_count(&mut self, answer_count: u16) -> &mut Self {
253 self.header.set_answer_count(answer_count);
254 self
255 }
256
257 pub fn set_name_server_count(&mut self, name_server_count: u16) -> &mut Self {
262 self.header.set_name_server_count(name_server_count);
263 self
264 }
265
266 pub fn set_additional_count(&mut self, additional_count: u16) -> &mut Self {
271 self.header.set_additional_count(additional_count);
272 self
273 }
274
275 pub fn add_query(&mut self, query: Query) -> &mut Self {
277 self.queries.push(query);
278 self
279 }
280
281 pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
283 where
284 Q: IntoIterator<Item = Query, IntoIter = I>,
285 I: Iterator<Item = Query>,
286 {
287 for query in queries {
288 self.add_query(query);
289 }
290
291 self
292 }
293
294 pub fn add_answer(&mut self, record: Record) -> &mut Self {
296 self.answers.push(record);
297 self
298 }
299
300 pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
302 where
303 R: IntoIterator<Item = Record, IntoIter = I>,
304 I: Iterator<Item = Record>,
305 {
306 for record in records {
307 self.add_answer(record);
308 }
309
310 self
311 }
312
313 pub fn insert_answers(&mut self, records: Vec<Record>) {
319 assert!(self.answers.is_empty());
320 self.answers = records;
321 }
322
323 pub fn add_name_server(&mut self, record: Record) -> &mut Self {
325 self.name_servers.push(record);
326 self
327 }
328
329 pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
331 where
332 R: IntoIterator<Item = Record, IntoIter = I>,
333 I: Iterator<Item = Record>,
334 {
335 for record in records {
336 self.add_name_server(record);
337 }
338
339 self
340 }
341
342 pub fn insert_name_servers(&mut self, records: Vec<Record>) {
348 assert!(self.name_servers.is_empty());
349 self.name_servers = records;
350 }
351
352 pub fn add_additional(&mut self, record: Record) -> &mut Self {
354 self.additionals.push(record);
355 self
356 }
357
358 pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
360 where
361 R: IntoIterator<Item = Record, IntoIter = I>,
362 I: Iterator<Item = Record>,
363 {
364 for record in records {
365 self.add_additional(record);
366 }
367
368 self
369 }
370
371 pub fn insert_additionals(&mut self, records: Vec<Record>) {
377 assert!(self.additionals.is_empty());
378 self.additionals = records;
379 }
380
381 pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
383 self.edns = Some(edns);
384 self
385 }
386
387 #[cfg(feature = "__dnssec")]
391 pub fn add_sig0(&mut self, record: Record) -> &mut Self {
392 assert_eq!(RecordType::SIG, record.record_type());
393 self.signature.push(record);
394 self
395 }
396
397 #[cfg(feature = "__dnssec")]
401 pub fn add_tsig(&mut self, record: Record) -> &mut Self {
402 assert_eq!(RecordType::TSIG, record.record_type());
403 self.signature.push(record);
404 self
405 }
406
407 pub fn header(&self) -> &Header {
409 &self.header
410 }
411
412 pub fn id(&self) -> u16 {
414 self.header.id()
415 }
416
417 pub fn message_type(&self) -> MessageType {
419 self.header.message_type()
420 }
421
422 pub fn op_code(&self) -> OpCode {
424 self.header.op_code()
425 }
426
427 pub fn authoritative(&self) -> bool {
429 self.header.authoritative()
430 }
431
432 pub fn truncated(&self) -> bool {
434 self.header.truncated()
435 }
436
437 pub fn recursion_desired(&self) -> bool {
439 self.header.recursion_desired()
440 }
441
442 pub fn recursion_available(&self) -> bool {
444 self.header.recursion_available()
445 }
446
447 pub fn authentic_data(&self) -> bool {
449 self.header.authentic_data()
450 }
451
452 pub fn checking_disabled(&self) -> bool {
454 self.header.checking_disabled()
455 }
456
457 pub fn response_code(&self) -> ResponseCode {
462 self.header.response_code()
463 }
464
465 pub fn query(&self) -> Option<&Query> {
470 self.queries.first()
471 }
472
473 pub fn queries(&self) -> &[Query] {
477 &self.queries
478 }
479
480 pub fn queries_mut(&mut self) -> &mut Vec<Query> {
482 &mut self.queries
483 }
484
485 pub fn take_queries(&mut self) -> Vec<Query> {
487 mem::take(&mut self.queries)
488 }
489
490 pub fn answers(&self) -> &[Record] {
494 &self.answers
495 }
496
497 pub fn answers_mut(&mut self) -> &mut Vec<Record> {
499 &mut self.answers
500 }
501
502 pub fn take_answers(&mut self) -> Vec<Record> {
504 mem::take(&mut self.answers)
505 }
506
507 pub fn name_servers(&self) -> &[Record] {
513 &self.name_servers
514 }
515
516 pub fn name_servers_mut(&mut self) -> &mut Vec<Record> {
518 &mut self.name_servers
519 }
520
521 pub fn take_name_servers(&mut self) -> Vec<Record> {
523 mem::take(&mut self.name_servers)
524 }
525
526 pub fn additionals(&self) -> &[Record] {
531 &self.additionals
532 }
533
534 pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
536 &mut self.additionals
537 }
538
539 pub fn take_additionals(&mut self) -> Vec<Record> {
541 mem::take(&mut self.additionals)
542 }
543
544 pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
546 self.answers
547 .iter()
548 .chain(self.name_servers().iter())
549 .chain(self.additionals.iter())
550 }
551
552 #[deprecated(note = "Please use `extensions()`")]
582 pub fn edns(&self) -> Option<&Edns> {
583 self.edns.as_ref()
584 }
585
586 #[deprecated(
588 note = "Please use `extensions_mut()`. You can chain `.get_or_insert_with(Edns::new)` to recover original behavior of adding Edns if not present"
589 )]
590 pub fn edns_mut(&mut self) -> &mut Edns {
591 if self.edns.is_none() {
592 self.set_edns(Edns::new());
593 }
594 self.edns.as_mut().unwrap()
595 }
596
597 pub fn extensions(&self) -> &Option<Edns> {
599 &self.edns
600 }
601
602 pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
604 &mut self.edns
605 }
606
607 pub fn max_payload(&self) -> u16 {
611 let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
612 if max_size < 512 { 512 } else { max_size }
613 }
614
615 pub fn version(&self) -> u8 {
619 self.edns.as_ref().map_or(0, Edns::version)
620 }
621
622 pub fn sig0(&self) -> &[Record] {
639 &self.signature
640 }
641
642 pub fn signature(&self) -> &[Record] {
657 &self.signature
658 }
659
660 pub fn take_signature(&mut self) -> Vec<Record> {
662 mem::take(&mut self.signature)
663 }
664
665 #[cfg(test)]
669 pub fn update_counts(&mut self) -> &mut Self {
670 self.header = update_header_counts(
671 &self.header,
672 self.truncated(),
673 HeaderCounts {
674 query_count: self.queries.len(),
675 answer_count: self.answers.len(),
676 nameserver_count: self.name_servers.len(),
677 additional_count: self.additionals.len(),
678 },
679 );
680 self
681 }
682
683 pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
685 let mut queries = Vec::with_capacity(count);
686 for _ in 0..count {
687 queries.push(Query::read(decoder)?);
688 }
689 Ok(queries)
690 }
691
692 #[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
698 pub fn read_records(
699 decoder: &mut BinDecoder<'_>,
700 count: usize,
701 is_additional: bool,
702 ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
703 let mut records: Vec<Record> = Vec::with_capacity(count);
704 let mut edns: Option<Edns> = None;
705 let mut sigs: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });
706
707 let mut saw_sig0 = false;
709 let mut saw_tsig = false;
711 for _ in 0..count {
712 let record = Record::read(decoder)?;
713 if saw_tsig {
714 return Err("tsig must be final resource record".into());
715 } if !is_additional {
717 if saw_sig0 {
718 return Err("sig0 must be final resource record".into());
719 } records.push(record)
721 } else {
722 match record.record_type() {
723 #[cfg(feature = "__dnssec")]
724 RecordType::SIG => {
725 saw_sig0 = true;
726 sigs.push(record);
727 }
728 #[cfg(feature = "__dnssec")]
729 RecordType::TSIG => {
730 if saw_sig0 {
731 return Err("sig0 must be final resource record".into());
732 } saw_tsig = true;
734 sigs.push(record);
735 }
736 RecordType::OPT => {
737 if saw_sig0 {
738 return Err("sig0 must be final resource record".into());
739 } if edns.is_some() {
741 return Err("more than one edns record present".into());
742 }
743 edns = Some((&record).into());
744 }
745 _ => {
746 if saw_sig0 {
747 return Err("sig0 must be final resource record".into());
748 } records.push(record);
750 }
751 }
752 }
753 }
754
755 Ok((records, edns, sigs))
756 }
757
758 pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
760 let mut decoder = BinDecoder::new(buffer);
761 Self::read(&mut decoder)
762 }
763
764 pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
766 let mut buffer = Vec::with_capacity(512);
770 {
771 let mut encoder = BinEncoder::new(&mut buffer);
772 self.emit(&mut encoder)?;
773 }
774
775 Ok(buffer)
776 }
777
778 #[allow(clippy::match_single_binding)]
782 pub fn finalize(
783 &mut self,
784 finalizer: &dyn MessageFinalizer,
785 inception_time: u32,
786 ) -> ProtoResult<Option<MessageVerifier>> {
787 debug!("finalizing message: {:?}", self);
788 let (finals, verifier): (Vec<Record>, Option<MessageVerifier>) =
789 finalizer.finalize_message(self, inception_time)?;
790
791 for fin in finals {
793 match fin.record_type() {
794 #[cfg(feature = "__dnssec")]
796 RecordType::SIG => self.add_sig0(fin),
797 #[cfg(feature = "__dnssec")]
798 RecordType::TSIG => self.add_tsig(fin),
799 _ => self.add_additional(fin),
800 };
801 }
802
803 Ok(verifier)
804 }
805
806 pub fn into_parts(self) -> MessageParts {
808 self.into()
809 }
810}
811
812#[derive(Clone, Debug, PartialEq, Eq, Default)]
821pub struct MessageParts {
822 pub header: Header,
824 pub queries: Vec<Query>,
826 pub answers: Vec<Record>,
828 pub name_servers: Vec<Record>,
830 pub additionals: Vec<Record>,
832 pub sig0: Vec<Record>,
836 pub edns: Option<Edns>,
838}
839
840impl From<Message> for MessageParts {
841 fn from(msg: Message) -> Self {
842 let Message {
843 header,
844 queries,
845 answers,
846 name_servers,
847 additionals,
848 signature,
849 edns,
850 } = msg;
851 Self {
852 header,
853 queries,
854 answers,
855 name_servers,
856 additionals,
857 sig0: signature,
858 edns,
859 }
860 }
861}
862
863impl From<MessageParts> for Message {
864 fn from(msg: MessageParts) -> Self {
865 let MessageParts {
866 header,
867 queries,
868 answers,
869 name_servers,
870 additionals,
871 sig0,
872 edns,
873 } = msg;
874 Self {
875 header,
876 queries,
877 answers,
878 name_servers,
879 additionals,
880 signature: sig0,
881 edns,
882 }
883 }
884}
885
886impl Deref for Message {
887 type Target = Header;
888
889 fn deref(&self) -> &Self::Target {
890 &self.header
891 }
892}
893
894pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
896
897pub trait MessageFinalizer: Send + Sync + 'static {
902 fn finalize_message(
914 &self,
915 message: &Message,
916 current_time: u32,
917 ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)>;
918
919 fn should_finalize_message(&self, message: &Message) -> bool {
922 [OpCode::Update, OpCode::Notify].contains(&message.op_code())
923 || message
924 .queries()
925 .iter()
926 .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
927 }
928}
929
930pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
932 result.map(|count| (count, false)).or_else(|e| {
933 if let ProtoErrorKind::NotAllRecordsWritten { count } = e.kind() {
934 return Ok((*count, true));
935 }
936
937 Err(e)
938 })
939}
940
941pub trait EmitAndCount {
943 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
945}
946
947impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
948 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
949 encoder.emit_all(self)
950 }
951}
952
953#[allow(clippy::too_many_arguments)]
959pub fn emit_message_parts<Q, A, N, D>(
960 header: &Header,
961 queries: &mut Q,
962 answers: &mut A,
963 name_servers: &mut N,
964 additionals: &mut D,
965 edns: Option<&Edns>,
966 signature: &[Record],
967 encoder: &mut BinEncoder<'_>,
968) -> ProtoResult<Header>
969where
970 Q: EmitAndCount,
971 A: EmitAndCount,
972 N: EmitAndCount,
973 D: EmitAndCount,
974{
975 let include_signature: bool = encoder.mode() != EncodeMode::Signing;
976 let place = encoder.place::<Header>()?;
977
978 let query_count = queries.emit(encoder)?;
979 let answer_count = count_was_truncated(answers.emit(encoder))?;
982 let nameserver_count = count_was_truncated(name_servers.emit(encoder))?;
983 let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
984
985 if let Some(mut edns) = edns.cloned() {
986 edns.set_rcode_high(header.response_code().high());
988
989 let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
990 additional_count.0 += count.0;
991 additional_count.1 |= count.1;
992 } else if header.response_code().high() > 0 {
993 warn!(
994 "response code: {} for request: {} requires EDNS but none available",
995 header.response_code(),
996 header.id()
997 );
998 }
999
1000 if include_signature {
1004 let count = count_was_truncated(encoder.emit_all(signature.iter()))?;
1005 additional_count.0 += count.0;
1006 additional_count.1 |= count.1;
1007 }
1008
1009 let counts = HeaderCounts {
1010 query_count,
1011 answer_count: answer_count.0,
1012 nameserver_count: nameserver_count.0,
1013 additional_count: additional_count.0,
1014 };
1015 let was_truncated =
1016 header.truncated() || answer_count.1 || nameserver_count.1 || additional_count.1;
1017
1018 let final_header = update_header_counts(header, was_truncated, counts);
1019 place.replace(encoder, final_header)?;
1020 Ok(final_header)
1021}
1022
1023impl BinEncodable for Message {
1024 fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1025 emit_message_parts(
1026 &self.header,
1027 &mut self.queries.iter(),
1028 &mut self.answers.iter(),
1029 &mut self.name_servers.iter(),
1030 &mut self.additionals.iter(),
1031 self.edns.as_ref(),
1032 &self.signature,
1033 encoder,
1034 )?;
1035
1036 Ok(())
1037 }
1038}
1039
1040impl<'r> BinDecodable<'r> for Message {
1041 fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1042 let mut header = Header::read(decoder)?;
1043
1044 let count = header.query_count() as usize;
1049 let mut queries = Vec::with_capacity(count);
1050 for _ in 0..count {
1051 queries.push(Query::read(decoder)?);
1052 }
1053
1054 let answer_count = header.answer_count() as usize;
1056 let name_server_count = header.name_server_count() as usize;
1057 let additional_count = header.additional_count() as usize;
1058
1059 let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1060 let (name_servers, _, _) = Self::read_records(decoder, name_server_count, false)?;
1061 let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1062
1063 if let Some(edns) = &edns {
1065 let high_response_code = edns.rcode_high();
1066 header.merge_response_code(high_response_code);
1067 }
1068
1069 Ok(Self {
1070 header,
1071 queries,
1072 answers,
1073 name_servers,
1074 additionals,
1075 signature,
1076 edns,
1077 })
1078 }
1079}
1080
1081impl fmt::Display for Message {
1082 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1083 let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1084 for d in slice {
1085 writeln!(f, ";; {d}")?;
1086 }
1087
1088 Ok(())
1089 };
1090
1091 let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1092 for d in slice {
1093 writeln!(f, "{d}")?;
1094 }
1095
1096 Ok(())
1097 };
1098
1099 writeln!(f, "; header {header}", header = self.header())?;
1100
1101 if let Some(edns) = self.extensions() {
1102 writeln!(f, "; edns {edns}")?;
1103 }
1104
1105 writeln!(f, "; query")?;
1106 write_query(self.queries(), f)?;
1107
1108 if self.header().message_type() == MessageType::Response
1109 || self.header().op_code() == OpCode::Update
1110 {
1111 writeln!(f, "; answers {}", self.answer_count())?;
1112 write_slice(self.answers(), f)?;
1113 writeln!(f, "; nameservers {}", self.name_server_count())?;
1114 write_slice(self.name_servers(), f)?;
1115 writeln!(f, "; additionals {}", self.additional_count())?;
1116 write_slice(self.additionals(), f)?;
1117 }
1118
1119 Ok(())
1120 }
1121}
1122
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `message` and decodes the resulting bytes into a new message.
    fn get_message_after_emitting_and_reading(message: &Message) -> Message {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&byte_vec);
        Message::read(&mut decoder).unwrap()
    }

    /// Asserts that `message` is unchanged by an encode/decode round trip.
    fn test_emit_and_read(message: Message) {
        let got = get_message_after_emitting_and_reading(&message);

        assert_eq!(got, message);
    }

    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(false)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail);

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail)
            .add_query(Query::new())
            .update_counts();

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());
        message.update_counts();

        test_emit_and_read(message);
    }

    #[test]
    fn test_header_counts_correction_after_emit_read() {
        let mut message = Message::new();

        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());

        // Deliberately skip `update_counts` and store wrong counts: emitting
        // must replace them with the real section lengths.
        message.set_query_count(1);
        message.set_answer_count(5);
        message.set_name_server_count(5);

        let got = get_message_after_emitting_and_reading(&message);

        assert_eq!(got.query_count(), 0);
        assert_eq!(got.answer_count(), 1);
        assert_eq!(got.name_server_count(), 1);
        assert_eq!(got.additional_count(), 1);
    }

    #[test]
    fn test_legit_message() {
        // The leading 0x10, 0x00 decode to the id 4096 asserted below.
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
            0x03, b'w', b'w', b'w',
            0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e',
            0x03, b'c', b'o', b'm',
            0x00,
            0x00, 0x01, 0x00, 0x01,
            0xC0, 0x0C,
            0x00, 0x01, 0x00, 0x01,
            0x00, 0x00, 0x00, 0x02,
            0x00, 0x04,
            0x5D, 0xB8, 0xD7, 0x0E,
        ];

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4_096);

        // The decoded message must survive another encode/decode round trip.
        let message = get_message_after_emitting_and_reading(&message);

        assert_eq!(message.id(), 4_096);
    }

    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];

        assert!(Message::from_bytes(buf).is_err());
    }

    #[test]
    fn nsec_deserialization() {
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];

        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }
}