1use core::fmt;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
15#[cfg(feature = "__dnssec")]
16use crate::dnssec::{Algorithm, SupportedAlgorithms};
17use crate::{
18 error::*,
19 rr::{
20 DNSClass, Name, RData, Record, RecordType,
21 rdata::{
22 OPT,
23 opt::{EdnsCode, EdnsOption},
24 },
25 },
26 serialize::binary::{BinEncodable, BinEncoder},
27};
28
29#[derive(Debug, PartialEq, Eq, Clone)]
32#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
33pub struct Edns {
34 rcode_high: u8,
37 version: u8,
39 flags: EdnsFlags,
40 max_payload: u16,
42
43 options: OPT,
44}
45
46impl Default for Edns {
47 fn default() -> Self {
48 Self {
49 rcode_high: 0,
50 version: 0,
51 flags: EdnsFlags::default(),
52 max_payload: 512,
53 options: OPT::default(),
54 }
55 }
56}
57
58impl Edns {
59 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub fn rcode_high(&self) -> u8 {
66 self.rcode_high
67 }
68
69 pub fn version(&self) -> u8 {
71 self.version
72 }
73
74 pub fn flags(&self) -> &EdnsFlags {
76 &self.flags
77 }
78
79 pub fn flags_mut(&mut self) -> &mut EdnsFlags {
81 &mut self.flags
82 }
83
84 pub fn max_payload(&self) -> u16 {
86 self.max_payload
87 }
88
89 pub fn option(&self, code: EdnsCode) -> Option<&EdnsOption> {
91 self.options.get(code)
92 }
93
94 pub fn options(&self) -> &OPT {
96 &self.options
97 }
98
99 pub fn options_mut(&mut self) -> &mut OPT {
101 &mut self.options
102 }
103
104 pub fn set_rcode_high(&mut self, rcode_high: u8) -> &mut Self {
106 self.rcode_high = rcode_high;
107 self
108 }
109
110 pub fn set_version(&mut self, version: u8) -> &mut Self {
112 self.version = version;
113 self
114 }
115
116 #[cfg(feature = "__dnssec")]
118 pub fn enable_dnssec(&mut self) {
119 self.set_dnssec_ok(true);
120 self.set_default_algorithms();
121 }
122
123 #[cfg(feature = "__dnssec")]
125 pub fn set_default_algorithms(&mut self) -> &mut Self {
126 let mut algorithms = SupportedAlgorithms::new();
127
128 for algorithm in [
129 Algorithm::RSASHA256,
130 Algorithm::RSASHA512,
131 Algorithm::ECDSAP256SHA256,
132 Algorithm::ECDSAP384SHA384,
133 Algorithm::ED25519,
134 ] {
135 if algorithm.is_supported() {
136 algorithms.set(algorithm);
137 }
138 }
139
140 let dau = EdnsOption::DAU(algorithms);
141
142 self.options_mut().insert(dau);
143 self
144 }
145
146 pub fn set_dnssec_ok(&mut self, dnssec_ok: bool) -> &mut Self {
148 self.flags.dnssec_ok = dnssec_ok;
149 self
150 }
151
152 pub fn set_max_payload(&mut self, max_payload: u16) -> &mut Self {
155 self.max_payload = max_payload.max(512);
156 self
157 }
158
159 #[deprecated(note = "Please use options_mut().insert() to modify")]
161 pub fn set_option(&mut self, option: EdnsOption) {
162 self.options.insert(option);
163 }
164}
165
166impl<'a> From<&'a Record> for Edns {
168 fn from(value: &'a Record) -> Self {
169 assert!(value.record_type() == RecordType::OPT);
170
171 let rcode_high = ((value.ttl() & 0xFF00_0000u32) >> 24) as u8;
172 let version = ((value.ttl() & 0x00FF_0000u32) >> 16) as u8;
173 let flags = EdnsFlags::from((value.ttl() & 0x0000_FFFFu32) as u16);
174 let max_payload = u16::from(value.dns_class());
175
176 let options = match value.data() {
177 RData::Update0(..) | RData::NULL(..) => {
178 OPT::default()
180 }
181 RData::OPT(option_data) => {
182 option_data.clone() }
184 _ => {
185 panic!("rr_type doesn't match the RData: {:?}", value.data()) }
188 };
189
190 Self {
191 rcode_high,
192 version,
193 flags,
194 max_payload,
195 options,
196 }
197 }
198}
199
200impl<'a> From<&'a Edns> for Record {
201 fn from(value: &'a Edns) -> Self {
204 let mut ttl: u32 = u32::from(value.rcode_high()) << 24;
206 ttl |= u32::from(value.version()) << 16;
207 ttl |= u32::from(u16::from(value.flags));
208
209 let mut record = Self::from_rdata(Name::root(), ttl, RData::OPT(value.options().clone()));
214
215 record.set_dns_class(DNSClass::for_opt(value.max_payload()));
216
217 record
218 }
219}
220
221impl BinEncodable for Edns {
222 fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
223 encoder.emit(0)?; RecordType::OPT.emit(encoder)?; DNSClass::for_opt(self.max_payload()).emit(encoder)?; let mut ttl = u32::from(self.rcode_high()) << 24;
229 ttl |= u32::from(self.version()) << 16;
230 ttl |= u32::from(u16::from(self.flags));
231
232 encoder.emit_u32(ttl)?;
233
234 let place = encoder.place::<u16>()?;
236 self.options.emit(encoder)?;
237 let len = encoder.len_since_place(&place);
238 assert!(len <= u16::MAX as usize);
239
240 place.replace(encoder, len as u16)?;
241 Ok(())
242 }
243}
244
245impl fmt::Display for Edns {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
247 let version = self.version;
248 let dnssec_ok = self.flags.dnssec_ok;
249 let z_flags = self.flags.z;
250 let max_payload = self.max_payload;
251
252 write!(
253 f,
254 "version: {version} dnssec_ok: {dnssec_ok} z_flags: {z_flags} max_payload: {max_payload} opts: {opts_len}",
255 opts_len = self.options().as_ref().len()
256 )
257 }
258}
259
/// Flags carried in the low 16 bits of an EDNS OPT record's TTL field.
///
/// The most significant bit is the DNSSEC OK (DO) bit. The remaining 15 bits
/// are kept in `z`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct EdnsFlags {
    /// The DNSSEC OK (DO) bit: the most significant bit of the flags word.
    pub dnssec_ok: bool,
    /// The remaining flag bits.
    ///
    /// Bit 15 is always zero after decoding. When converting back to a `u16`,
    /// bit 15 of this value is ignored and `dnssec_ok` supplies that bit instead.
    pub z: u16,
}

impl From<u16> for EdnsFlags {
    /// Splits a raw 16-bit flags word into the DO bit and the low 15 Z bits.
    fn from(flags: u16) -> Self {
        const DO_BIT: u16 = 1 << 15;
        Self {
            z: flags & !DO_BIT,
            dnssec_ok: (flags & DO_BIT) != 0,
        }
    }
}

impl From<EdnsFlags> for u16 {
    /// Packs the flags into a raw 16-bit word.
    ///
    /// `dnssec_ok` becomes bit 15, and the low 15 bits of `z` fill the rest.
    fn from(flags: EdnsFlags) -> Self {
        const DO_BIT: u16 = 1 << 15;
        let do_bit = if flags.dnssec_ok { DO_BIT } else { 0 };
        do_bit | (flags.z & !DO_BIT)
    }
}
294
#[cfg(all(test, feature = "__dnssec"))]
mod tests {
    use super::*;

    /// Round-trips a fully populated `Edns` through an OPT `Record` and checks
    /// that every field survives. Then checks that a removed option is no
    /// longer visible through `Edns::option`.
    #[test]
    fn test_encode_decode() {
        let mut original = Edns::new();
        *original.flags_mut() = EdnsFlags {
            dnssec_ok: true,
            z: 1,
        };
        original
            .set_max_payload(0x8008)
            .set_version(0x40)
            .set_rcode_high(0x01);
        original
            .options_mut()
            .insert(EdnsOption::DAU(SupportedAlgorithms::all()));

        let decoded = Edns::from(&Record::from(&original));

        assert_eq!(original.flags().dnssec_ok, decoded.flags().dnssec_ok);
        assert_eq!(original.flags().z, decoded.flags().z);
        assert_eq!(original.max_payload(), decoded.max_payload());
        assert_eq!(original.version(), decoded.version());
        assert_eq!(original.rcode_high(), decoded.rcode_high());
        assert_eq!(original.options(), decoded.options());

        // Inserting and then removing DAU must leave no DAU option behind.
        original
            .options_mut()
            .insert(EdnsOption::DAU(SupportedAlgorithms::all()));
        original.options_mut().remove(EdnsCode::DAU);
        assert!(original.option(EdnsCode::DAU).is_none());
    }
}