use std::iter;
use std::mem;
use std::ops::Deref;
use std::sync::Arc;
use log::debug;
use super::{Edns, Header, MessageType, OpCode, Query, ResponseCode};
use crate::error::*;
use crate::rr::{Record, RecordType};
use crate::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode};
#[cfg(feature = "dnssec")]
use crate::rr::dnssec::rdata::DNSSECRecordType;
/// A DNS message: a header plus the query, answer, name-server and additional
/// sections, with any OPT (EDNS) and SIG(0) records from the additional section
/// kept in dedicated fields.
#[derive(Clone, Debug, PartialEq, Default)]
pub struct Message {
/// Header holding the id, flags, op code, response code and section counts.
header: Header,
/// Entries of the query section.
queries: Vec<Query>,
/// Records of the answer section.
answers: Vec<Record>,
/// Records of the name-server section.
name_servers: Vec<Record>,
/// Records of the additional section, excluding OPT and (with the `dnssec`
/// feature) SIG records, which decoding routes to `edns` and `sig0`.
additionals: Vec<Record>,
/// SIG(0) records; emitted after all other additional data unless the
/// encoder is in signing mode.
sig0: Vec<Record>,
/// EDNS data, populated from the OPT record when one is decoded.
edns: Option<Edns>,
}
pub fn update_header_counts(
current_header: &Header,
is_truncated: bool,
counts: HeaderCounts,
) -> Header {
assert!(counts.query_count <= u16::max_value() as usize);
assert!(counts.answer_count <= u16::max_value() as usize);
assert!(counts.nameserver_count <= u16::max_value() as usize);
assert!(counts.additional_count <= u16::max_value() as usize);
let mut header = current_header.clone();
header.set_query_count(counts.query_count as u16);
header.set_answer_count(counts.answer_count as u16);
header.set_name_server_count(counts.nameserver_count as u16);
header.set_additional_count(counts.additional_count as u16);
header.set_truncated(is_truncated);
header
}
/// Per-section entry totals that `update_header_counts` writes into a header.
pub struct HeaderCounts {
/// Number of entries in the query section.
pub query_count: usize,
/// Number of records in the answer section.
pub answer_count: usize,
/// Number of records in the name-server section.
pub nameserver_count: usize,
/// Number of records in the additional section.
pub additional_count: usize,
}
impl Message {
    /// Creates an empty message: a fresh header, no records in any section and
    /// no EDNS data.
    pub fn new() -> Self {
        Message {
            header: Header::new(),
            queries: Vec::new(),
            answers: Vec::new(),
            name_servers: Vec::new(),
            additionals: Vec::new(),
            sig0: Vec::new(),
            edns: None,
        }
    }

    /// Builds a response message that carries only the given id, op code and
    /// response code.
    pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Message {
        let mut message: Message = Message::new();
        message.set_message_type(MessageType::Response);
        message.set_id(id);
        message.set_response_code(response_code);
        message.set_op_code(op_code);
        message
    }

    /// Returns a record-less copy of this message with the truncated flag set.
    ///
    /// The id, message type, op code, authoritative, recursion-desired and
    /// recursion-available flags, the response code and any EDNS data are
    /// copied; queries and records are not.
    pub fn truncate(&self) -> Self {
        let mut truncated: Message = Message::new();
        truncated.set_id(self.id());
        truncated.set_message_type(self.message_type());
        truncated.set_op_code(self.op_code());
        truncated.set_authoritative(self.authoritative());
        truncated.set_truncated(true);
        truncated.set_recursion_desired(self.recursion_desired());
        truncated.set_recursion_available(self.recursion_available());
        truncated.set_response_code(self.response_code());

        if let Some(edns) = self.edns() {
            truncated.set_edns(edns.clone());
        }

        truncated
    }

    /// Sets the message id in the header.
    pub fn set_id(&mut self, id: u16) -> &mut Self {
        self.header.set_id(id);
        self
    }

    /// Sets the message type in the header.
    pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
        self.header.set_message_type(message_type);
        self
    }

    /// Sets the op code in the header.
    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
        self.header.set_op_code(op_code);
        self
    }

    /// Sets the authoritative flag in the header.
    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
        self.header.set_authoritative(authoritative);
        self
    }

    /// Sets the truncated flag in the header.
    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
        self.header.set_truncated(truncated);
        self
    }

    /// Sets the recursion-desired flag in the header.
    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
        self.header.set_recursion_desired(recursion_desired);
        self
    }

    /// Sets the recursion-available flag in the header.
    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
        self.header.set_recursion_available(recursion_available);
        self
    }

    /// Sets the authentic-data flag in the header.
    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
        self.header.set_authentic_data(authentic_data);
        self
    }

    /// Sets the checking-disabled flag in the header.
    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
        self.header.set_checking_disabled(checking_disabled);
        self
    }

    /// Sets the response code in the header.
    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
        self.header.set_response_code(response_code);
        self
    }

    /// Appends a query to the query section.
    pub fn add_query(&mut self, query: Query) -> &mut Self {
        self.queries.push(query);
        self
    }

    /// Appends every query yielded by `queries` to the query section.
    pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
    where
        Q: IntoIterator<Item = Query, IntoIter = I>,
        I: Iterator<Item = Query>,
    {
        self.queries.extend(queries);
        self
    }

    /// Appends a record to the answer section.
    pub fn add_answer(&mut self, record: Record) -> &mut Self {
        self.answers.push(record);
        self
    }

    /// Appends every record yielded by `records` to the answer section.
    pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
    where
        R: IntoIterator<Item = Record, IntoIter = I>,
        I: Iterator<Item = Record>,
    {
        self.answers.extend(records);
        self
    }

    /// Installs `records` as the answer section.
    ///
    /// # Panics
    ///
    /// Panics if the answer section is not empty.
    pub fn insert_answers(&mut self, records: Vec<Record>) {
        assert!(self.answers.is_empty());
        self.answers = records;
    }

    /// Appends a record to the name-server section.
    pub fn add_name_server(&mut self, record: Record) -> &mut Self {
        self.name_servers.push(record);
        self
    }

    /// Appends every record yielded by `records` to the name-server section.
    pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
    where
        R: IntoIterator<Item = Record, IntoIter = I>,
        I: Iterator<Item = Record>,
    {
        self.name_servers.extend(records);
        self
    }

    /// Installs `records` as the name-server section.
    ///
    /// # Panics
    ///
    /// Panics if the name-server section is not empty.
    pub fn insert_name_servers(&mut self, records: Vec<Record>) {
        assert!(self.name_servers.is_empty());
        self.name_servers = records;
    }

    /// Appends a record to the additional section.
    pub fn add_additional(&mut self, record: Record) -> &mut Self {
        self.additionals.push(record);
        self
    }

    /// Installs `records` as the additional section.
    ///
    /// # Panics
    ///
    /// Panics if the additional section is not empty.
    pub fn insert_additionals(&mut self, records: Vec<Record>) {
        assert!(self.additionals.is_empty());
        self.additionals = records;
    }

    /// Attaches EDNS data to the message, replacing any existing EDNS data.
    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
        self.edns = Some(edns);
        self
    }

    /// Appends a SIG(0) record to the message.
    ///
    /// # Panics
    ///
    /// Panics if `record` is not of the DNSSEC `SIG` record type.
    #[cfg(feature = "dnssec")]
    pub fn add_sig0(&mut self, record: Record) -> &mut Self {
        assert_eq!(RecordType::DNSSEC(DNSSECRecordType::SIG), record.rr_type());
        self.sig0.push(record);
        self
    }

    /// Returns the message header.
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// Returns the message id from the header.
    pub fn id(&self) -> u16 {
        self.header.id()
    }

    /// Returns the message type from the header.
    pub fn message_type(&self) -> MessageType {
        self.header.message_type()
    }

    /// Returns the op code from the header.
    pub fn op_code(&self) -> OpCode {
        self.header.op_code()
    }

    /// Returns the authoritative flag from the header.
    pub fn authoritative(&self) -> bool {
        self.header.authoritative()
    }

    /// Returns the truncated flag from the header.
    pub fn truncated(&self) -> bool {
        self.header.truncated()
    }

    /// Returns the recursion-desired flag from the header.
    pub fn recursion_desired(&self) -> bool {
        self.header.recursion_desired()
    }

    /// Returns the recursion-available flag from the header.
    pub fn recursion_available(&self) -> bool {
        self.header.recursion_available()
    }

    /// Returns the authentic-data flag from the header.
    pub fn authentic_data(&self) -> bool {
        self.header.authentic_data()
    }

    /// Returns the checking-disabled flag from the header.
    pub fn checking_disabled(&self) -> bool {
        self.header.checking_disabled()
    }

    /// Returns the response code, combining the high part stored in the EDNS
    /// data (0 when there is none) with the code stored in the header.
    pub fn response_code(&self) -> ResponseCode {
        ResponseCode::from(
            self.edns.as_ref().map_or(0, Edns::rcode_high),
            self.header.response_code(),
        )
    }

    /// Returns the entries of the query section.
    pub fn queries(&self) -> &[Query] {
        &self.queries
    }

    /// Returns the records of the answer section.
    pub fn answers(&self) -> &[Record] {
        &self.answers
    }

    /// Moves the answer records out of the message, leaving the section empty.
    pub fn take_answers(&mut self) -> Vec<Record> {
        mem::take(&mut self.answers)
    }

    /// Returns the records of the name-server section.
    pub fn name_servers(&self) -> &[Record] {
        &self.name_servers
    }

    /// Moves the name-server records out of the message, leaving the section
    /// empty.
    pub fn take_name_servers(&mut self) -> Vec<Record> {
        mem::take(&mut self.name_servers)
    }

    /// Returns the records of the additional section.
    pub fn additionals(&self) -> &[Record] {
        &self.additionals
    }

    /// Moves the additional records out of the message, leaving the section
    /// empty.
    pub fn take_additionals(&mut self) -> Vec<Record> {
        mem::take(&mut self.additionals)
    }

    /// Returns the EDNS data, if any is attached.
    pub fn edns(&self) -> Option<&Edns> {
        self.edns.as_ref()
    }

    /// Returns the EDNS data for modification, attaching a new `Edns` first if
    /// the message has none.
    pub fn edns_mut(&mut self) -> &mut Edns {
        self.edns.get_or_insert_with(Edns::new)
    }

    /// Returns the maximum payload size from the EDNS data.
    ///
    /// Falls back to 512 when no EDNS data is attached and never reports less
    /// than 512.
    pub fn max_payload(&self) -> u16 {
        let advertised: u16 = self.edns.as_ref().map_or(512, Edns::max_payload);
        advertised.max(512)
    }

    /// Returns the EDNS version, or 0 when no EDNS data is attached.
    pub fn version(&self) -> u8 {
        self.edns.as_ref().map_or(0, Edns::version)
    }

    /// Returns the SIG(0) records attached to the message.
    pub fn sig0(&self) -> &[Record] {
        &self.sig0
    }

    /// Rewrites the header counts from the current section lengths and clears
    /// the truncated flag (test helper).
    ///
    /// # Panics
    ///
    /// Panics if any section holds more entries than fit in a `u16`.
    #[cfg(test)]
    pub fn update_counts(&mut self) -> &mut Self {
        self.header = update_header_counts(
            &self.header,
            false,
            HeaderCounts {
                query_count: self.queries.len(),
                answer_count: self.answers.len(),
                nameserver_count: self.name_servers.len(),
                additional_count: self.additionals.len(),
            },
        );
        self
    }

    /// Decodes `count` queries from `decoder`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while decoding a query.
    pub fn read_queries(decoder: &mut BinDecoder, count: usize) -> ProtoResult<Vec<Query>> {
        let mut queries = Vec::with_capacity(count);
        for _ in 0..count {
            queries.push(Query::read(decoder)?);
        }
        Ok(queries)
    }

    /// Decodes `count` records from `decoder`.
    ///
    /// Returns `(records, edns, sig0s)`. When `is_additional` is `false` every
    /// record lands in `records`. When it is `true`, an OPT record is
    /// converted into the returned `Edns` and, with the `dnssec` feature
    /// enabled, SIG records are collected into `sig0s`.
    ///
    /// # Errors
    ///
    /// Returns an error if a record fails to decode, if more than one OPT
    /// record is present, or if any record follows a SIG record.
    #[cfg_attr(not(feature = "dnssec"), allow(unused_mut))]
    pub fn read_records(
        decoder: &mut BinDecoder,
        count: usize,
        is_additional: bool,
    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
        let mut records: Vec<Record> = Vec::with_capacity(count);
        let mut edns: Option<Edns> = None;
        let mut sig0s: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });

        // Once a SIG record has been seen, any further non-SIG record is an error.
        let mut saw_sig0 = false;
        for _ in 0..count {
            let record = Record::read(decoder)?;

            if !is_additional {
                if saw_sig0 {
                    return Err("sig0 must be final resource record".into());
                }
                records.push(record)
            } else {
                match record.rr_type() {
                    #[cfg(feature = "dnssec")]
                    RecordType::DNSSEC(DNSSECRecordType::SIG) => {
                        saw_sig0 = true;
                        sig0s.push(record);
                    }
                    RecordType::OPT => {
                        if saw_sig0 {
                            return Err("sig0 must be final resource record".into());
                        }
                        if edns.is_some() {
                            return Err("more than one edns record present".into());
                        }
                        edns = Some((&record).into());
                    }
                    _ => {
                        if saw_sig0 {
                            return Err("sig0 must be final resource record".into());
                        }
                        records.push(record);
                    }
                }
            }
        }

        Ok((records, edns, sig0s))
    }

    /// Decodes a message from its binary form.
    ///
    /// # Errors
    ///
    /// Returns an error if `buffer` cannot be decoded as a message.
    pub fn from_vec(buffer: &[u8]) -> ProtoResult<Message> {
        let mut decoder = BinDecoder::new(buffer);
        Message::read(&mut decoder)
    }

    /// Encodes the message into a newly allocated byte vector.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails.
    pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
        let mut buffer = Vec::with_capacity(512);
        {
            // Scope the encoder so its borrow of `buffer` ends before the
            // buffer is returned.
            let mut encoder = BinEncoder::new(&mut buffer);
            self.emit(&mut encoder)?;
        }
        Ok(buffer)
    }

    /// Asks `finalizer` for the closing records of this message and attaches
    /// them.
    ///
    /// With the `dnssec` feature enabled, SIG records are stored as SIG(0)
    /// records; every other record is appended to the additional section.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `finalizer`, if any.
    ///
    /// # Panics
    ///
    /// Panics if `finalizer` itself panics (as `NoopMessageFinalizer` does).
    // Without the `dnssec` feature the match below has a single arm.
    #[allow(clippy::match_single_binding)]
    pub fn finalize<MF: MessageFinalizer>(
        &mut self,
        finalizer: &MF,
        inception_time: u32,
    ) -> ProtoResult<()> {
        debug!("finalizing message: {:?}", self);
        let finals: Vec<Record> = finalizer.finalize_message(self, inception_time)?;

        for fin in finals {
            match fin.rr_type() {
                #[cfg(feature = "dnssec")]
                RecordType::DNSSEC(DNSSECRecordType::SIG) => self.add_sig0(fin),
                _ => self.add_additional(fin),
            };
        }

        Ok(())
    }
}
/// Lets a `Message` be used wherever a `&Header` is expected by exposing its
/// header through dereferencing.
impl Deref for Message {
    type Target = Header;

    /// Returns a reference to the message header.
    fn deref(&self) -> &Header {
        self.header()
    }
}
/// Produces the records that `Message::finalize` attaches to a message as its
/// last step.
pub trait MessageFinalizer: Send + Sync + 'static {
/// Returns the records to attach to `message`.
///
/// `Message::finalize` passes its `inception_time` argument as
/// `current_time` and stores the returned records as SIG(0) or additional
/// records.
fn finalize_message(&self, message: &Message, current_time: u32) -> ProtoResult<Vec<Record>>;
}
/// Placeholder finalizer type for callers that do not finalize messages.
///
/// It is never meant to be instantiated for use: `new` always yields `None`.
pub struct NoopMessageFinalizer;

impl NoopMessageFinalizer {
    /// Always returns `None`, i.e. "no finalizer".
    pub fn new() -> Option<Arc<Self>> {
        Option::None
    }
}
impl MessageFinalizer for NoopMessageFinalizer {
    /// Never produces records.
    ///
    /// # Panics
    ///
    /// Always panics: callers are expected to pass `None` rather than invoke
    /// this finalizer.
    fn finalize_message(
        &self,
        _message: &Message,
        _current_time: u32,
    ) -> ProtoResult<Vec<Record>> {
        panic!("Misused NoopMessageFinalizer, None should be used instead")
    }
}
pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
result.map(|count| (count, false)).or_else(|e| {
if let ProtoErrorKind::NotAllRecordsWritten { count } = e.kind() {
return Ok((*count, true));
}
Err(e)
})
}
/// A source of items that can be written to an encoder while reporting how many
/// were written.
pub trait EmitAndCount {
/// Writes the items to `encoder`; `emit_message_parts` uses the returned
/// number as the section's entry count.
fn emit(&mut self, encoder: &mut BinEncoder) -> ProtoResult<usize>;
}
/// Every iterator over references to encodable items can be emitted and
/// counted.
impl<'e, I, E> EmitAndCount for I
where
    I: Iterator<Item = &'e E>,
    E: 'e + BinEncodable,
{
    /// Hands the iterator to `BinEncoder::emit_all` and returns its result.
    fn emit(&mut self, encoder: &mut BinEncoder) -> ProtoResult<usize> {
        encoder.emit_all(self)
    }
}
/// Encodes a complete message from its individual parts.
///
/// A slot for the header is reserved first. The queries, answers, name
/// servers and additionals are then emitted in that order, followed by the
/// EDNS record (if any) and the SIG(0) records (skipped when the encoder is in
/// signing mode). Finally the reserved slot is filled with a copy of `header`
/// whose counts and truncated flag reflect what was actually written.
///
/// # Errors
///
/// Returns any encoder error other than `NotAllRecordsWritten` raised while
/// emitting records; that kind is recorded as truncation instead. Errors
/// while emitting queries are always returned.
///
/// # Panics
///
/// Panics if a section count does not fit in a `u16`.
// The parameter list mirrors the sections of a message one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn emit_message_parts<Q, A, N, D>(
    header: &Header,
    queries: &mut Q,
    answers: &mut A,
    name_servers: &mut N,
    additionals: &mut D,
    edns: Option<&Edns>,
    sig0: &[Record],
    encoder: &mut BinEncoder,
) -> ProtoResult<()>
where
    Q: EmitAndCount,
    A: EmitAndCount,
    N: EmitAndCount,
    D: EmitAndCount,
{
    let include_sig0 = encoder.mode() != EncodeMode::Signing;

    // Reserve the header's position; it is written last, once counts are known.
    let place = encoder.place::<Header>()?;

    let query_count = queries.emit(encoder)?;
    let (answer_count, answers_cut) = count_was_truncated(answers.emit(encoder))?;
    let (nameserver_count, name_servers_cut) = count_was_truncated(name_servers.emit(encoder))?;
    let (mut additional_count, mut additionals_cut) =
        count_was_truncated(additionals.emit(encoder))?;

    // EDNS and SIG(0) records are counted as part of the additional section.
    if let Some(edns) = edns {
        let opt_record = Record::from(edns);
        let (written, cut) = count_was_truncated(encoder.emit_all(iter::once(&opt_record)))?;
        additional_count += written;
        additionals_cut |= cut;
    }

    if include_sig0 {
        let (written, cut) = count_was_truncated(encoder.emit_all(sig0.iter()))?;
        additional_count += written;
        additionals_cut |= cut;
    }

    let was_truncated = answers_cut || name_servers_cut || additionals_cut;
    let final_header = update_header_counts(
        header,
        was_truncated,
        HeaderCounts {
            query_count,
            answer_count,
            nameserver_count,
            additional_count,
        },
    );
    place.replace(encoder, final_header)?;

    Ok(())
}
impl BinEncodable for Message {
    /// Encodes the message by passing its header, sections, EDNS data and
    /// SIG(0) records to `emit_message_parts`.
    fn emit(&self, encoder: &mut BinEncoder) -> ProtoResult<()> {
        let mut queries = self.queries.iter();
        let mut answers = self.answers.iter();
        let mut name_servers = self.name_servers.iter();
        let mut additionals = self.additionals.iter();

        emit_message_parts(
            &self.header,
            &mut queries,
            &mut answers,
            &mut name_servers,
            &mut additionals,
            self.edns.as_ref(),
            self.sig0.as_slice(),
            encoder,
        )
    }
}
impl<'r> BinDecodable<'r> for Message {
    /// Decodes a message: the header first, then each section using the
    /// counts the header declares.
    ///
    /// EDNS data and SIG(0) records are only taken from the additional
    /// section.
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
        let header = Header::read(decoder)?;

        let queries = Self::read_queries(decoder, usize::from(header.query_count()))?;

        let (answers, _, _) =
            Self::read_records(decoder, usize::from(header.answer_count()), false)?;
        let (name_servers, _, _) =
            Self::read_records(decoder, usize::from(header.name_server_count()), false)?;
        let (additionals, edns, sig0) =
            Self::read_records(decoder, usize::from(header.additional_count()), true)?;

        Ok(Message {
            header,
            queries,
            answers,
            name_servers,
            additionals,
            sig0,
            edns,
        })
    }
}
/// Round-trips a message that consists of a header only.
#[test]
fn test_emit_and_read_header() {
    let mut header_only = Message::new();
    header_only.set_id(10);
    header_only.set_message_type(MessageType::Response);
    header_only.set_op_code(OpCode::Update);
    header_only.set_authoritative(true);
    header_only.set_truncated(false);
    header_only.set_recursion_desired(true);
    header_only.set_recursion_available(true);
    header_only.set_response_code(ResponseCode::ServFail);

    test_emit_and_read(header_only);
}
/// Round-trips a message that carries a single default query.
#[test]
fn test_emit_and_read_query() {
    let mut with_query = Message::new();
    with_query.set_id(10);
    with_query.set_message_type(MessageType::Response);
    with_query.set_op_code(OpCode::Update);
    with_query.set_authoritative(true);
    with_query.set_truncated(true);
    with_query.set_recursion_desired(true);
    with_query.set_recursion_available(true);
    with_query.set_response_code(ResponseCode::ServFail);
    with_query.add_query(Query::new());
    // Bring the header counts in line with the sections before comparing.
    with_query.update_counts();

    test_emit_and_read(with_query);
}
/// Round-trips a message with one default record in each record section.
#[test]
fn test_emit_and_read_records() {
    let mut with_records = Message::new();
    with_records
        .set_id(10)
        .set_message_type(MessageType::Response)
        .set_op_code(OpCode::Update)
        .set_authoritative(true)
        .set_truncated(true)
        .set_recursion_desired(true)
        .set_recursion_available(true)
        .set_authentic_data(true)
        .set_checking_disabled(true)
        .set_response_code(ResponseCode::ServFail)
        .add_answer(Record::new())
        .add_name_server(Record::new())
        .add_additional(Record::new())
        .update_counts();

    test_emit_and_read(with_records);
}
/// Encodes `message`, decodes the resulting bytes and asserts that the decoded
/// message equals the original.
#[cfg(test)]
fn test_emit_and_read(message: Message) {
    let wire: Vec<u8> = message.to_vec().unwrap();
    let got = Message::from_vec(&wire).unwrap();

    assert_eq!(got, message);
}
/// Decodes a captured response, re-encodes it, and checks that the id
/// survives both the first decode and the round trip.
///
/// The per-byte annotations follow the RFC 1035 message layout.
#[test]
#[rustfmt::skip]
fn test_legit_message() {
    let captured: Vec<u8> = vec![
        0x10,0x00,0x81,0x80, // id = 4096, flags
        0x00,0x01,0x00,0x01, // 1 query, 1 answer
        0x00,0x00,0x00,0x00, // 0 name servers, 0 additionals
        0x03,b'w',b'w',b'w', // query name: www
        0x07,b'e',b'x',b'a', //             example
        b'm',b'p',b'l',b'e',
        0x03,b'c',b'o',b'm', //             com
        0x00,                // root label terminates the name
        0x00,0x01,0x00,0x01, // query type 1, class 1
        0xC0,0x0C,           // answer name: compression pointer to offset 12
        0x00,0x01,0x00,0x01, // answer type 1, class 1
        0x00,0x00,0x00,0x02, // TTL = 2
        0x00,0x04,           // RDATA length = 4
        0x5D,0xB8,0xD8,0x22, // RDATA: 93.184.216.34
    ];

    let mut first_decoder = BinDecoder::new(&captured);
    let parsed = Message::read(&mut first_decoder).unwrap();
    assert_eq!(parsed.id(), 4096);

    let mut re_encoded: Vec<u8> = Vec::with_capacity(512);
    {
        let mut encoder = BinEncoder::new(&mut re_encoded);
        parsed.emit(&mut encoder).unwrap();
    }

    let mut second_decoder = BinDecoder::new(&re_encoded);
    let reparsed = Message::read(&mut second_decoder).unwrap();
    assert_eq!(reparsed.id(), 4096);
}