hickory_proto/dnssec/
supported_algorithm.rs

1/*
2 * Copyright (C) 2015 Benjamin Fry <benjaminfry@me.com>
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
//! Bitmap for expressing the set of supported algorithms in EDNS.
18
19use alloc::vec::Vec;
20use core::fmt::{self, Display, Formatter};
21
22#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24
25use tracing::warn;
26
27use super::Algorithm;
28use crate::error::ProtoResult;
29use crate::serialize::binary::{BinEncodable, BinEncoder};
30
/// Set of DNSSEC algorithms that a client or server declares it supports.
///
/// Stored as a compact bitmap: every known algorithm owns one bit, and since fewer
/// than eight algorithms are tracked, a single byte is sufficient.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub struct SupportedAlgorithms {
    // one bit per known algorithm; a set bit means "supported"
    bit_map: u8,
}
38
39impl SupportedAlgorithms {
40    /// Return a new set of Supported algorithms
41    pub fn new() -> Self {
42        Self { bit_map: 0 }
43    }
44
45    /// Specify the entire set is supported
46    pub fn all() -> Self {
47        Self {
48            bit_map: 0b0111_1111,
49        }
50    }
51
52    /// Based on the set of Algorithms, return the supported set
53    pub fn from_vec(algorithms: &[Algorithm]) -> Self {
54        let mut supported = Self::new();
55
56        for a in algorithms {
57            supported.set(*a);
58        }
59
60        supported
61    }
62
63    fn pos(algorithm: Algorithm) -> Option<u8> {
64        // not using the values from the RFC's to keep the bit_map space condensed
65        #[allow(deprecated)]
66        let bit_pos: Option<u8> = match algorithm {
67            Algorithm::RSASHA1 => Some(0),
68            Algorithm::RSASHA256 => Some(1),
69            Algorithm::RSASHA1NSEC3SHA1 => Some(2),
70            Algorithm::RSASHA512 => Some(3),
71            Algorithm::ECDSAP256SHA256 => Some(4),
72            Algorithm::ECDSAP384SHA384 => Some(5),
73            Algorithm::ED25519 => Some(6),
74            Algorithm::RSAMD5 | Algorithm::DSA | Algorithm::Unknown(_) => None,
75        };
76
77        bit_pos.map(|b| 1u8 << b)
78    }
79
80    fn from_pos(pos: u8) -> Option<Algorithm> {
81        // TODO: should build a code generator or possibly a macro for deriving these inversions
82        #[allow(deprecated)]
83        match pos {
84            0 => Some(Algorithm::RSASHA1),
85            1 => Some(Algorithm::RSASHA256),
86            2 => Some(Algorithm::RSASHA1NSEC3SHA1),
87            3 => Some(Algorithm::RSASHA512),
88            4 => Some(Algorithm::ECDSAP256SHA256),
89            5 => Some(Algorithm::ECDSAP384SHA384),
90            6 => Some(Algorithm::ED25519),
91            _ => None,
92        }
93    }
94
95    /// Set the specified algorithm as supported
96    pub fn set(&mut self, algorithm: Algorithm) {
97        if let Some(bit_pos) = Self::pos(algorithm) {
98            self.bit_map |= bit_pos;
99        }
100    }
101
102    /// Returns true if the algorithm is supported
103    pub fn has(self, algorithm: Algorithm) -> bool {
104        if let Some(bit_pos) = Self::pos(algorithm) {
105            (bit_pos & self.bit_map) == bit_pos
106        } else {
107            false
108        }
109    }
110
111    /// Return an Iterator over the supported set.
112    pub fn iter(&self) -> SupportedAlgorithmsIter<'_> {
113        SupportedAlgorithmsIter::new(self)
114    }
115
116    /// Return the count of supported algorithms
117    pub fn len(self) -> u16 {
118        // this is pretty much guaranteed to be less that u16::MAX
119        self.iter().count() as u16
120    }
121
122    /// Return true if no SupportedAlgorithms are set, this implies the option is not supported
123    pub fn is_empty(self) -> bool {
124        self.bit_map == 0
125    }
126}
127
128impl Default for SupportedAlgorithms {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl Display for SupportedAlgorithms {
135    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
136        for a in self.iter() {
137            a.fmt(f)?;
138            f.write_str(", ")?;
139        }
140
141        Ok(())
142    }
143}
144
145impl<'a> From<&'a [u8]> for SupportedAlgorithms {
146    fn from(values: &'a [u8]) -> Self {
147        let mut supported = Self::new();
148
149        for a in values.iter().map(|i| Algorithm::from_u8(*i)) {
150            match a {
151                Algorithm::Unknown(v) => warn!("unrecognized algorithm: {}", v),
152                a => supported.set(a),
153            }
154        }
155
156        supported
157    }
158}
159
160impl<'a> From<&'a SupportedAlgorithms> for Vec<u8> {
161    fn from(value: &'a SupportedAlgorithms) -> Self {
162        let mut bytes = Self::with_capacity(8); // today this is less than 8
163
164        for a in value.iter() {
165            bytes.push(a.into());
166        }
167
168        bytes.shrink_to_fit();
169        bytes
170    }
171}
172
173impl From<Algorithm> for SupportedAlgorithms {
174    fn from(algorithm: Algorithm) -> Self {
175        Self::from_vec(&[algorithm])
176    }
177}
178
179pub struct SupportedAlgorithmsIter<'a> {
180    algorithms: &'a SupportedAlgorithms,
181    current: usize,
182}
183
184impl<'a> SupportedAlgorithmsIter<'a> {
185    pub fn new(algorithms: &'a SupportedAlgorithms) -> Self {
186        SupportedAlgorithmsIter {
187            algorithms,
188            current: 0,
189        }
190    }
191}
192
193impl Iterator for SupportedAlgorithmsIter<'_> {
194    type Item = Algorithm;
195    fn next(&mut self) -> Option<Self::Item> {
196        // some quick bounds checking
197        if self.current > u8::MAX as usize {
198            return None;
199        }
200
201        while let Some(algorithm) = SupportedAlgorithms::from_pos(self.current as u8) {
202            self.current += 1;
203            if self.algorithms.has(algorithm) {
204                return Some(algorithm);
205            }
206        }
207
208        None
209    }
210}
211
212impl BinEncodable for SupportedAlgorithms {
213    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
214        for a in self.iter() {
215            encoder.emit_u8(a.into())?;
216        }
217        Ok(())
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)]
    fn test_has() {
        let mut supported = SupportedAlgorithms::new();

        supported.set(Algorithm::RSASHA1);

        assert!(supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));

        let mut supported = SupportedAlgorithms::new();

        supported.set(Algorithm::RSASHA256);
        assert!(!supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));
        assert!(supported.has(Algorithm::RSASHA256));
    }

    #[allow(deprecated)]
    #[test]
    fn test_iterator() {
        let supported = SupportedAlgorithms::all();
        assert_eq!(supported.iter().count(), 7);

        // it just so happens that the iterator has a fixed order...
        let supported = SupportedAlgorithms::all();
        let mut iter = supported.iter();
        assert_eq!(iter.next(), Some(Algorithm::RSASHA1));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA256));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA1NSEC3SHA1));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA512));
        assert_eq!(iter.next(), Some(Algorithm::ECDSAP256SHA256));
        assert_eq!(iter.next(), Some(Algorithm::ECDSAP384SHA384));
        assert_eq!(iter.next(), Some(Algorithm::ED25519));
        // exhausted after the last known algorithm, and stays exhausted
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);

        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSASHA256);
        supported.set(Algorithm::RSASHA512);

        let mut iter = supported.iter();
        assert_eq!(iter.next(), Some(Algorithm::RSASHA256));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA512));
        assert_eq!(iter.next(), None);
    }

    #[test]
    #[allow(deprecated)]
    fn test_len_and_is_empty() {
        let empty = SupportedAlgorithms::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty, SupportedAlgorithms::default());

        let all = SupportedAlgorithms::all();
        assert!(!all.is_empty());
        assert_eq!(all.len(), 7);

        let single = SupportedAlgorithms::from(Algorithm::ED25519);
        assert!(!single.is_empty());
        assert_eq!(single.len(), 1);
        assert!(single.has(Algorithm::ED25519));
        assert!(!single.has(Algorithm::RSASHA256));
    }

    #[test]
    #[allow(deprecated)]
    fn test_unsupported_algorithms_are_ignored() {
        let supported = SupportedAlgorithms::from_vec(&[
            Algorithm::RSAMD5,
            Algorithm::DSA,
            Algorithm::Unknown(200),
        ]);

        assert!(supported.is_empty());
        assert_eq!(supported.len(), 0);
        assert!(!supported.has(Algorithm::RSAMD5));
        assert!(!supported.has(Algorithm::DSA));
    }

    #[test]
    #[allow(deprecated)]
    fn test_vec() {
        let supported = SupportedAlgorithms::all();
        let array: Vec<u8> = (&supported).into();
        let decoded: SupportedAlgorithms = (&array as &[_]).into();

        assert_eq!(supported, decoded);

        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSASHA256);
        supported.set(Algorithm::ECDSAP256SHA256);
        supported.set(Algorithm::ECDSAP384SHA384);
        supported.set(Algorithm::ED25519);
        let array: Vec<u8> = (&supported).into();
        let decoded: SupportedAlgorithms = (&array as &[_]).into();

        assert_eq!(supported, decoded);
        assert!(!supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));
        assert!(supported.has(Algorithm::RSASHA256));
        assert!(supported.has(Algorithm::ECDSAP256SHA256));
        assert!(supported.has(Algorithm::ECDSAP384SHA384));
        assert!(supported.has(Algorithm::ED25519));

        // an empty set encodes to no bytes and decodes back to an empty set
        let empty = SupportedAlgorithms::new();
        let array: Vec<u8> = (&empty).into();
        assert!(array.is_empty());
        let decoded: SupportedAlgorithms = (&array as &[_]).into();
        assert_eq!(empty, decoded);
    }
}