trust_dns_resolver/
hosts.rs

//! Static host lookups loaded from the system hosts file
2
3use std::collections::HashMap;
4use std::io;
5use std::path::Path;
6use std::str::FromStr;
7use std::sync::Arc;
8
9use proto::op::Query;
10use proto::rr::{Name, RecordType};
11use proto::rr::{RData, Record};
12use tracing::warn;
13
14use crate::dns_lru;
15use crate::lookup::Lookup;
16
17#[derive(Debug, Default)]
18struct LookupType {
19    /// represents the A record type
20    a: Option<Lookup>,
21    /// represents the AAAA record type
22    aaaa: Option<Lookup>,
23}
24
25/// Configuration for the local hosts file
26#[derive(Debug, Default)]
27pub struct Hosts {
28    /// Name -> RDatas map
29    by_name: HashMap<Name, LookupType>,
30}
31
32impl Hosts {
33    /// Creates a new configuration from the system hosts file,
34    /// only works for Windows and Unix-like OSes,
35    /// will return empty configuration on others
36    #[cfg(any(unix, windows))]
37    pub fn new() -> Self {
38        read_hosts_conf(hosts_path()).unwrap_or_default()
39    }
40
41    /// Creates a default configuration for non Windows or Unix-like OSes
42    #[cfg(not(any(unix, windows)))]
43    pub fn new() -> Self {
44        Hosts::default()
45    }
46
47    /// Look up the addresses for the given host from the system hosts file.
48    pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
49        if !self.by_name.is_empty() {
50            if let Some(val) = self.by_name.get(query.name()) {
51                let result = match query.query_type() {
52                    RecordType::A => val.a.clone(),
53                    RecordType::AAAA => val.aaaa.clone(),
54                    _ => None,
55                };
56
57                return result;
58            }
59        }
60        None
61    }
62
63    /// Insert a new Lookup for the associated `Name` and `RecordType`
64    pub fn insert(&mut self, name: Name, record_type: RecordType, lookup: Lookup) {
65        assert!(record_type == RecordType::A || record_type == RecordType::AAAA);
66
67        let lookup_type = self
68            .by_name
69            .entry(name.clone())
70            .or_insert_with(LookupType::default);
71
72        let new_lookup = {
73            let old_lookup = match record_type {
74                RecordType::A => lookup_type.a.get_or_insert_with(|| {
75                    let query = Query::query(name.clone(), record_type);
76                    Lookup::new_with_max_ttl(query, Arc::from([]))
77                }),
78                RecordType::AAAA => lookup_type.aaaa.get_or_insert_with(|| {
79                    let query = Query::query(name.clone(), record_type);
80                    Lookup::new_with_max_ttl(query, Arc::from([]))
81                }),
82                _ => {
83                    tracing::warn!("unsupported IP type from Hosts file: {:#?}", record_type);
84                    return;
85                }
86            };
87
88            old_lookup.append(lookup)
89        };
90
91        // replace the appended version
92        match record_type {
93            RecordType::A => lookup_type.a = Some(new_lookup),
94            RecordType::AAAA => lookup_type.aaaa = Some(new_lookup),
95            _ => tracing::warn!("unsupported IP type from Hosts file"),
96        }
97    }
98
99    /// parse configuration from `src`
100    pub fn read_hosts_conf(mut self, src: impl io::Read) -> io::Result<Self> {
101        use std::io::{BufRead, BufReader};
102
103        use proto::rr::domain::TryParseIp;
104
105        // lines in the src should have the form `addr host1 host2 host3 ...`
106        // line starts with `#` will be regarded with comments and ignored,
107        // also empty line also will be ignored,
108        // if line only include `addr` without `host` will be ignored,
109        // the src will be parsed to map in the form `Name -> LookUp`.
110
111        for line in BufReader::new(src).lines() {
112            // Remove comments from the line
113            let line = line?;
114            let line = line.split('#').next().unwrap().trim();
115            if line.is_empty() {
116                continue;
117            }
118
119            let fields: Vec<_> = line.split_whitespace().collect();
120            if fields.len() < 2 {
121                continue;
122            }
123            let addr = if let Some(a) = fields[0].try_parse_ip() {
124                a
125            } else {
126                warn!("could not parse an IP from hosts file");
127                continue;
128            };
129
130            for domain in fields.iter().skip(1).map(|domain| domain.to_lowercase()) {
131                if let Ok(name) = Name::from_str(&domain) {
132                    let record = Record::from_rdata(name.clone(), dns_lru::MAX_TTL, addr.clone());
133
134                    match addr {
135                        RData::A(..) => {
136                            let query = Query::query(name.clone(), RecordType::A);
137                            let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
138                            self.insert(name.clone(), RecordType::A, lookup);
139                        }
140                        RData::AAAA(..) => {
141                            let query = Query::query(name.clone(), RecordType::AAAA);
142                            let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
143                            self.insert(name.clone(), RecordType::AAAA, lookup);
144                        }
145                        _ => {
146                            warn!("unsupported IP type from Hosts file: {:#?}", addr);
147                            continue;
148                        }
149                    };
150
151                    // TODO: insert reverse lookup as well.
152                };
153            }
154        }
155
156        Ok(self)
157    }
158}
159
#[cfg(unix)]
fn hosts_path() -> &'static str {
    // Fixed location of the hosts file on Unix-like systems.
    const UNIX_HOSTS_PATH: &str = "/etc/hosts";
    UNIX_HOSTS_PATH
}
164
/// Location of the hosts file on Windows: `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// Falls back to `C:\Windows` as the system root when `SystemRoot` is not set,
/// instead of panicking. If that guess is wrong, opening the file fails and
/// `Hosts::new` returns an empty configuration.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    let system_root = std::env::var_os("SystemRoot")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| Path::new("C:\\Windows").to_path_buf());
    system_root.join("System32\\drivers\\etc\\hosts")
}
172
173/// parse configuration from `path`
174#[cfg(any(unix, windows))]
175#[cfg_attr(docsrs, doc(cfg(any(unix, windows))))]
176pub(crate) fn read_hosts_conf<P: AsRef<Path>>(path: P) -> io::Result<Hosts> {
177    use std::fs::File;
178
179    let file = File::open(path)?;
180    Hosts::default().read_hosts_conf(file)
181}
182
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Directory holding the resolver test fixtures, rooted at `TDNS_WORKSPACE_ROOT`
    /// (or `../..` when that variable is unset).
    fn tests_dir() -> String {
        let server_path = match env::var("TDNS_WORKSPACE_ROOT") {
            Ok(root) => root,
            Err(_) => "../..".to_owned(),
        };
        format!("{server_path}/crates/resolver/tests")
    }

    /// Looks up `name` with `record_type` and collects the record data, panicking
    /// if the hosts data has no matching entry.
    fn static_rdatas(hosts: &Hosts, name: &str, record_type: RecordType) -> Vec<RData> {
        let query = Query::query(Name::from_str(name).unwrap(), record_type);
        hosts
            .lookup_static_host(&query)
            .unwrap()
            .iter()
            .cloned()
            .collect()
    }

    #[test]
    fn test_read_hosts_conf() {
        let hosts = read_hosts_conf(format!("{}/hosts", tests_dir())).unwrap();

        assert_eq!(
            static_rdatas(&hosts, "localhost", RecordType::A),
            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "localhost", RecordType::AAAA),
            vec![RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "broadcasthost", RecordType::A),
            vec![RData::A(Ipv4Addr::new(255, 255, 255, 255).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 102).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "a.example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 111).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "b.example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 111).into())]
        );
    }
}