use std::{
collections::HashMap,
io::{Cursor, Seek, Write},
};
use crate::{header_buffer, rdata::OPT, RCODE};
use super::{Header, PacketFlag, Question, ResourceRecord, WireFormat, OPCODE};
#[derive(Debug, Clone)]
pub struct Packet<'a> {
header: Header<'a>,
pub questions: Vec<Question<'a>>,
pub answers: Vec<ResourceRecord<'a>>,
pub name_servers: Vec<ResourceRecord<'a>>,
pub additional_records: Vec<ResourceRecord<'a>>,
}
impl<'a> Packet<'a> {
pub fn new_query(id: u16) -> Self {
Self {
header: Header::new_query(id),
questions: Vec::new(),
answers: Vec::new(),
name_servers: Vec::new(),
additional_records: Vec::new(),
}
}
pub fn new_reply(id: u16) -> Self {
Self {
header: Header::new_reply(id, OPCODE::StandardQuery),
questions: Vec::new(),
answers: Vec::new(),
name_servers: Vec::new(),
additional_records: Vec::new(),
}
}
pub fn id(&self) -> u16 {
self.header.id
}
pub fn set_id(&mut self, id: u16) {
self.header.id = id;
}
pub fn set_flags(&mut self, flags: PacketFlag) {
self.header.set_flags(flags);
}
pub fn remove_flags(&mut self, flags: PacketFlag) {
self.header.remove_flags(flags)
}
pub fn has_flags(&self, flags: PacketFlag) -> bool {
self.header.has_flags(flags)
}
pub fn rcode(&self) -> RCODE {
self.header.response_code
}
pub fn rcode_mut(&mut self) -> &mut RCODE {
&mut self.header.response_code
}
pub fn opcode(&self) -> OPCODE {
self.header.opcode
}
pub fn opcode_mut(&mut self) -> &mut OPCODE {
&mut self.header.opcode
}
pub fn opt(&self) -> Option<&OPT<'a>> {
self.header.opt.as_ref()
}
pub fn opt_mut(&mut self) -> &mut Option<OPT<'a>> {
&mut self.header.opt
}
pub fn into_reply(mut self) -> Self {
self.header = Header::new_reply(self.header.id, self.header.opcode);
self
}
pub fn parse(data: &'a [u8]) -> crate::Result<Self> {
let mut header = Header::parse(data)?;
let mut offset = 12;
let questions = Self::parse_section(data, &mut offset, header_buffer::questions(data)?)?;
let answers = Self::parse_section(data, &mut offset, header_buffer::answers(data)?)?;
let name_servers =
Self::parse_section(data, &mut offset, header_buffer::name_servers(data)?)?;
let mut additional_records: Vec<ResourceRecord> =
Self::parse_section(data, &mut offset, header_buffer::additional_records(data)?)?;
header.extract_info_from_opt_rr(
additional_records
.iter()
.position(|rr| rr.rdata.type_code() == crate::TYPE::OPT)
.map(|i| additional_records.remove(i)),
);
Ok(Self {
header,
questions,
answers,
name_servers,
additional_records,
})
}
fn parse_section<T: WireFormat<'a>>(
data: &'a [u8],
offset: &mut usize,
items_count: u16,
) -> crate::Result<Vec<T>> {
let mut section_items = Vec::with_capacity(items_count as usize);
for _ in 0..items_count {
section_items.push(T::parse(data, offset)?);
}
Ok(section_items)
}
pub fn build_bytes_vec(&self) -> crate::Result<Vec<u8>> {
let mut out = Cursor::new(Vec::with_capacity(900));
self.write_to(&mut out)?;
Ok(out.into_inner())
}
pub fn build_bytes_vec_compressed(&self) -> crate::Result<Vec<u8>> {
let mut out = Cursor::new(Vec::with_capacity(900));
self.write_compressed_to(&mut out)?;
Ok(out.into_inner())
}
pub fn write_to<T: Write>(&self, out: &mut T) -> crate::Result<()> {
self.write_header(out)?;
for e in &self.questions {
e.write_to(out)?;
}
for e in &self.answers {
e.write_to(out)?;
}
for e in &self.name_servers {
e.write_to(out)?;
}
if let Some(rr) = self.header.opt_rr() {
rr.write_to(out)?;
}
for e in &self.additional_records {
e.write_to(out)?;
}
out.flush()?;
Ok(())
}
pub fn write_compressed_to<T: Write + Seek>(&self, out: &mut T) -> crate::Result<()> {
self.write_header(out)?;
let mut name_refs = HashMap::new();
for e in &self.questions {
e.write_compressed_to(out, &mut name_refs)?;
}
for e in &self.answers {
e.write_compressed_to(out, &mut name_refs)?;
}
for e in &self.name_servers {
e.write_compressed_to(out, &mut name_refs)?;
}
if let Some(rr) = self.header.opt_rr() {
rr.write_to(out)?;
}
for e in &self.additional_records {
e.write_compressed_to(out, &mut name_refs)?;
}
out.flush()?;
Ok(())
}
fn write_header<T: Write>(&self, out: &mut T) -> crate::Result<()> {
self.header.write_to(
out,
self.questions.len() as u16,
self.answers.len() as u16,
self.name_servers.len() as u16,
self.additional_records.len() as u16 + u16::from(self.header.opt.is_some()),
)
}
}
#[cfg(test)]
mod tests {
use crate::{dns::CLASS, dns::TYPE, SimpleDnsError};
use super::*;
use std::convert::TryInto;
#[test]
fn parse_without_data_should_not_panic() {
assert!(matches!(
Packet::parse(&[]),
Err(SimpleDnsError::InsufficientData)
));
}
#[test]
fn build_query_correct() {
let mut query = Packet::new_query(1);
query.questions.push(Question::new(
"_srv._udp.local".try_into().unwrap(),
TYPE::TXT.into(),
CLASS::IN.into(),
false,
));
query.questions.push(Question::new(
"_srv2._udp.local".try_into().unwrap(),
TYPE::TXT.into(),
CLASS::IN.into(),
false,
));
let query = query.build_bytes_vec().unwrap();
let parsed = Packet::parse(&query);
assert!(parsed.is_ok());
let parsed = parsed.unwrap();
assert_eq!(2, parsed.questions.len());
assert_eq!("_srv._udp.local", parsed.questions[0].qname.to_string());
assert_eq!("_srv2._udp.local", parsed.questions[1].qname.to_string());
}
}