trust_dns_resolver/
hosts.rs1use std::collections::HashMap;
4use std::io;
5use std::path::Path;
6use std::str::FromStr;
7use std::sync::Arc;
8
9use proto::op::Query;
10use proto::rr::{Name, RecordType};
11use proto::rr::{RData, Record};
12use tracing::warn;
13
14use crate::dns_lru;
15use crate::lookup::Lookup;
16
17#[derive(Debug, Default)]
18struct LookupType {
19 a: Option<Lookup>,
21 aaaa: Option<Lookup>,
23}
24
25#[derive(Debug, Default)]
27pub struct Hosts {
28 by_name: HashMap<Name, LookupType>,
30}
31
32impl Hosts {
33 #[cfg(any(unix, windows))]
37 pub fn new() -> Self {
38 read_hosts_conf(hosts_path()).unwrap_or_default()
39 }
40
41 #[cfg(not(any(unix, windows)))]
43 pub fn new() -> Self {
44 Hosts::default()
45 }
46
47 pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
49 if !self.by_name.is_empty() {
50 if let Some(val) = self.by_name.get(query.name()) {
51 let result = match query.query_type() {
52 RecordType::A => val.a.clone(),
53 RecordType::AAAA => val.aaaa.clone(),
54 _ => None,
55 };
56
57 return result;
58 }
59 }
60 None
61 }
62
63 pub fn insert(&mut self, name: Name, record_type: RecordType, lookup: Lookup) {
65 assert!(record_type == RecordType::A || record_type == RecordType::AAAA);
66
67 let lookup_type = self
68 .by_name
69 .entry(name.clone())
70 .or_insert_with(LookupType::default);
71
72 let new_lookup = {
73 let old_lookup = match record_type {
74 RecordType::A => lookup_type.a.get_or_insert_with(|| {
75 let query = Query::query(name.clone(), record_type);
76 Lookup::new_with_max_ttl(query, Arc::from([]))
77 }),
78 RecordType::AAAA => lookup_type.aaaa.get_or_insert_with(|| {
79 let query = Query::query(name.clone(), record_type);
80 Lookup::new_with_max_ttl(query, Arc::from([]))
81 }),
82 _ => {
83 tracing::warn!("unsupported IP type from Hosts file: {:#?}", record_type);
84 return;
85 }
86 };
87
88 old_lookup.append(lookup)
89 };
90
91 match record_type {
93 RecordType::A => lookup_type.a = Some(new_lookup),
94 RecordType::AAAA => lookup_type.aaaa = Some(new_lookup),
95 _ => tracing::warn!("unsupported IP type from Hosts file"),
96 }
97 }
98
99 pub fn read_hosts_conf(mut self, src: impl io::Read) -> io::Result<Self> {
101 use std::io::{BufRead, BufReader};
102
103 use proto::rr::domain::TryParseIp;
104
105 for line in BufReader::new(src).lines() {
112 let line = line?;
114 let line = line.split('#').next().unwrap().trim();
115 if line.is_empty() {
116 continue;
117 }
118
119 let fields: Vec<_> = line.split_whitespace().collect();
120 if fields.len() < 2 {
121 continue;
122 }
123 let addr = if let Some(a) = fields[0].try_parse_ip() {
124 a
125 } else {
126 warn!("could not parse an IP from hosts file");
127 continue;
128 };
129
130 for domain in fields.iter().skip(1).map(|domain| domain.to_lowercase()) {
131 if let Ok(name) = Name::from_str(&domain) {
132 let record = Record::from_rdata(name.clone(), dns_lru::MAX_TTL, addr.clone());
133
134 match addr {
135 RData::A(..) => {
136 let query = Query::query(name.clone(), RecordType::A);
137 let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
138 self.insert(name.clone(), RecordType::A, lookup);
139 }
140 RData::AAAA(..) => {
141 let query = Query::query(name.clone(), RecordType::AAAA);
142 let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
143 self.insert(name.clone(), RecordType::AAAA, lookup);
144 }
145 _ => {
146 warn!("unsupported IP type from Hosts file: {:#?}", addr);
147 continue;
148 }
149 };
150
151 };
153 }
154 }
155
156 Ok(self)
157 }
158}
159
/// Location of the system hosts file on Unix-like systems.
#[cfg(unix)]
fn hosts_path() -> &'static str {
    const UNIX_HOSTS_FILE: &str = "/etc/hosts";
    UNIX_HOSTS_FILE
}
164
/// Location of the system hosts file on Windows:
/// `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// # Panics
///
/// Panics if the `SystemRoot` environment variable is not set.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    let system_root =
        std::env::var_os("SystemRoot").expect("Environment variable SystemRoot not found");
    Path::new(&system_root).join("System32\\drivers\\etc\\hosts")
}
172
173#[cfg(any(unix, windows))]
175#[cfg_attr(docsrs, doc(cfg(any(unix, windows))))]
176pub(crate) fn read_hosts_conf<P: AsRef<Path>>(path: P) -> io::Result<Hosts> {
177 use std::fs::File;
178
179 let file = File::open(path)?;
180 Hosts::default().read_hosts_conf(file)
181}
182
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Directory holding the resolver test fixtures; the workspace root defaults to `../..`
    /// unless `TDNS_WORKSPACE_ROOT` is set.
    fn tests_dir() -> String {
        let root = match env::var("TDNS_WORKSPACE_ROOT") {
            Ok(root) => root,
            Err(_) => String::from("../.."),
        };
        format!("{root}/crates/resolver/tests")
    }

    /// Looks up `name` with `record_type` in `hosts`, panicking if no static entry exists,
    /// and returns owned copies of the record data.
    fn static_rdatas(hosts: &Hosts, name: &str, record_type: RecordType) -> Vec<RData> {
        let query = Query::query(Name::from_str(name).unwrap(), record_type);
        hosts
            .lookup_static_host(&query)
            .unwrap()
            .iter()
            .cloned()
            .collect()
    }

    #[test]
    fn test_read_hosts_conf() {
        let hosts = read_hosts_conf(format!("{}/hosts", tests_dir())).unwrap();

        assert_eq!(
            static_rdatas(&hosts, "localhost", RecordType::A),
            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "localhost", RecordType::AAAA),
            vec![RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "broadcasthost", RecordType::A),
            vec![RData::A(Ipv4Addr::new(255, 255, 255, 255).into())]
        );
        assert_eq!(
            static_rdatas(&hosts, "example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 102).into())]
        );
        // Both aliases share a single address line in the fixture.
        for alias in ["a.example.com", "b.example.com"] {
            assert_eq!(
                static_rdatas(&hosts, alias, RecordType::A),
                vec![RData::A(Ipv4Addr::new(10, 0, 1, 111).into())]
            );
        }
    }
}