1use std::{
2 borrow::Cow,
3 io::Result as IoResult,
4 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
5 pin::Pin,
6 task::{Context, Poll},
7 vec,
8};
9
10use either::Either;
11pub use error::Error;
12use futures_util::{
13 future,
14 stream::{self, Once, Stream},
15};
16
17pub type Result<T> = std::result::Result<T, Error>;
18
19pub trait ToProxyAddrs {
26 type Output: Stream<Item = Result<SocketAddr>> + Unpin;
27
28 fn to_proxy_addrs(&self) -> Self::Output;
29}
30
31macro_rules! trivial_impl_to_proxy_addrs {
32 ($t: ty) => {
33 impl ToProxyAddrs for $t {
34 type Output = Once<future::Ready<Result<SocketAddr>>>;
35
36 fn to_proxy_addrs(&self) -> Self::Output {
37 stream::once(future::ready(Ok(SocketAddr::from(*self))))
38 }
39 }
40 };
41}
42
43trivial_impl_to_proxy_addrs!(SocketAddr);
44trivial_impl_to_proxy_addrs!((IpAddr, u16));
45trivial_impl_to_proxy_addrs!((Ipv4Addr, u16));
46trivial_impl_to_proxy_addrs!((Ipv6Addr, u16));
47trivial_impl_to_proxy_addrs!(SocketAddrV4);
48trivial_impl_to_proxy_addrs!(SocketAddrV6);
49
50impl<'a> ToProxyAddrs for &'a [SocketAddr] {
51 type Output = ProxyAddrsStream;
52
53 fn to_proxy_addrs(&self) -> Self::Output {
54 let addrs = self.to_vec();
55 ProxyAddrsStream(Some(IoResult::Ok(addrs.into_iter())))
56 }
57}
58
59impl ToProxyAddrs for str {
60 type Output = ProxyAddrsStream;
61
62 fn to_proxy_addrs(&self) -> Self::Output {
63 ProxyAddrsStream(Some(self.to_socket_addrs()))
64 }
65}
66
67impl<'a> ToProxyAddrs for (&'a str, u16) {
68 type Output = ProxyAddrsStream;
69
70 fn to_proxy_addrs(&self) -> Self::Output {
71 ProxyAddrsStream(Some(self.to_socket_addrs()))
72 }
73}
74
75impl<'a, T: ToProxyAddrs + ?Sized> ToProxyAddrs for &'a T {
76 type Output = T::Output;
77
78 fn to_proxy_addrs(&self) -> Self::Output {
79 (**self).to_proxy_addrs()
80 }
81}
82
83pub struct ProxyAddrsStream(Option<IoResult<vec::IntoIter<SocketAddr>>>);
84
85impl Stream for ProxyAddrsStream {
86 type Item = Result<SocketAddr>;
87
88 fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89 match self.0.as_mut() {
90 Some(Ok(iter)) => Poll::Ready(iter.next().map(Result::Ok)),
91 Some(Err(_)) => {
92 let err = self.0.take().unwrap().unwrap_err();
93 Poll::Ready(Some(Err(err.into())))
94 },
95 None => unreachable!(),
96 }
97 }
98}
99
/// A target address: either a numeric socket address or a domain name with a port.
///
/// `Clone` and `Hash` are derived so targets can be duplicated cheaply
/// (a borrowed domain stays borrowed) and used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetAddr<'a> {
    /// A numeric IP address together with a port.
    Ip(SocketAddr),

    /// A domain name and a port. The name is stored as given and is not
    /// resolved when the value is constructed.
    Domain(Cow<'a, str>, u16),
}
112
113impl<'a> TargetAddr<'a> {
114 pub fn to_owned(&self) -> TargetAddr<'static> {
117 match self {
118 TargetAddr::Ip(addr) => TargetAddr::Ip(*addr),
119 TargetAddr::Domain(domain, port) => TargetAddr::Domain(String::from(domain.clone()).into(), *port),
120 }
121 }
122}
123
124impl<'a> ToSocketAddrs for TargetAddr<'a> {
125 type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;
126
127 fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
128 Ok(match self {
129 TargetAddr::Ip(addr) => Either::Left(addr.to_socket_addrs()?),
130 TargetAddr::Domain(domain, port) => Either::Right((&**domain, *port).to_socket_addrs()?),
131 })
132 }
133}
134
135pub trait IntoTargetAddr<'a> {
137 fn into_target_addr(self) -> Result<TargetAddr<'a>>;
139}
140
141macro_rules! trivial_impl_into_target_addr {
142 ($t: ty) => {
143 impl<'a> IntoTargetAddr<'a> for $t {
144 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
145 Ok(TargetAddr::Ip(SocketAddr::from(self)))
146 }
147 }
148 };
149}
150
151trivial_impl_into_target_addr!(SocketAddr);
152trivial_impl_into_target_addr!((IpAddr, u16));
153trivial_impl_into_target_addr!((Ipv4Addr, u16));
154trivial_impl_into_target_addr!((Ipv6Addr, u16));
155trivial_impl_into_target_addr!(SocketAddrV4);
156trivial_impl_into_target_addr!(SocketAddrV6);
157
158impl<'a> IntoTargetAddr<'a> for TargetAddr<'a> {
159 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
160 Ok(self)
161 }
162}
163
164impl<'a> IntoTargetAddr<'a> for (&'a str, u16) {
165 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
166 if let Ok(addr) = self.0.parse::<IpAddr>() {
168 return (addr, self.1).into_target_addr();
169 }
170
171 if self.0.len() > 255 {
173 return Err(Error::InvalidTargetAddress("overlong domain"));
174 }
175 Ok(TargetAddr::Domain(self.0.into(), self.1))
178 }
179}
180
181impl<'a> IntoTargetAddr<'a> for &'a str {
182 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
183 if let Ok(addr) = self.parse::<SocketAddr>() {
185 return addr.into_target_addr();
186 }
187
188 let mut parts_iter = self.rsplitn(2, ':');
189 let port: u16 = parts_iter
190 .next()
191 .and_then(|port_str| port_str.parse().ok())
192 .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
193 let domain = parts_iter
194 .next()
195 .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
196 if domain.len() > 255 {
197 return Err(Error::InvalidTargetAddress("overlong domain"));
198 }
199 Ok(TargetAddr::Domain(domain.into(), port))
200 }
201}
202
203impl IntoTargetAddr<'static> for String {
204 fn into_target_addr(mut self) -> Result<TargetAddr<'static>> {
205 if let Ok(addr) = self.parse::<SocketAddr>() {
207 return addr.into_target_addr();
208 }
209
210 let mut parts_iter = self.rsplitn(2, ':');
211 let port: u16 = parts_iter
212 .next()
213 .and_then(|port_str| port_str.parse().ok())
214 .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
215 let domain_len = parts_iter
216 .next()
217 .ok_or(Error::InvalidTargetAddress("invalid address format"))?
218 .len();
219 if domain_len > 255 {
220 return Err(Error::InvalidTargetAddress("overlong domain"));
221 }
222 self.truncate(domain_len);
223 Ok(TargetAddr::Domain(self.into(), port))
224 }
225}
226
227impl IntoTargetAddr<'static> for (String, u16) {
228 fn into_target_addr(self) -> Result<TargetAddr<'static>> {
229 let addr = (self.0.as_str(), self.1).into_target_addr()?;
230 if let TargetAddr::Ip(addr) = addr {
231 Ok(TargetAddr::Ip(addr))
232 } else {
233 Ok(TargetAddr::Domain(self.0.into(), self.1))
234 }
235 }
236}
237
238impl<'a, T> IntoTargetAddr<'a> for &'a T
239where T: IntoTargetAddr<'a> + Copy
240{
241 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
242 (*self).into_target_addr()
243 }
244}
245
/// Authentication method to use with the proxy.
enum Authentication<'a> {
    /// Username/password credentials (method id `0x02`).
    Password { username: &'a str, password: &'a str },
    /// No authentication (method id `0x00`).
    None,
}

// Written by hand instead of derived so that formatting an `Authentication`
// with `{:?}` (logs, panic messages, a containing type's Debug output) never
// reveals the password.
impl<'a> std::fmt::Debug for Authentication<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Authentication::Password { username, .. } => f
                .debug_struct("Password")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
            Authentication::None => f.write_str("None"),
        }
    }
}

impl<'a> Authentication<'a> {
    /// Returns the one-byte method identifier for this kind of authentication
    /// (`0x00` none, `0x02` username/password, matching RFC 1928 method codes).
    fn id(&self) -> u8 {
        match self {
            Authentication::Password { .. } => 0x02,
            Authentication::None => 0x00,
        }
    }
}
261
262mod error;
263pub mod io;
264pub mod tcp;
265
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use futures_util::StreamExt;

    use super::*;

    /// Drains the proxy address stream of `t`, returning the first error if any.
    fn to_proxy_addrs<T: ToProxyAddrs>(t: T) -> Result<Vec<SocketAddr>> {
        block_on(t.to_proxy_addrs().collect::<Vec<_>>())
            .into_iter()
            .collect()
    }

    #[test]
    fn converts_socket_addr_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = to_proxy_addrs(addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the blanket `&T` impl is the one exercised.
        let res = to_proxy_addrs(&addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addrs_to_proxy_addrs() -> Result<()> {
        let addrs = [
            SocketAddr::from(([1, 1, 1, 1], 443)),
            SocketAddr::from(([8, 8, 8, 8], 53)),
        ];
        let res = to_proxy_addrs(&addrs[..])?;
        assert_eq!(&res[..], &addrs);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_str_to_proxy_addrs() -> Result<()> {
        // A numeric address string is parsed directly; no DNS lookup happens.
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = to_proxy_addrs("1.1.1.1:443")?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    fn into_target_addr<'a, T>(t: T) -> Result<TargetAddr<'a>>
    where T: IntoTargetAddr<'a> {
        t.into_target_addr()
    }

    #[test]
    fn converts_socket_addr_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = into_target_addr(addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the blanket `&T where T: Copy` impl is exercised.
        let res = into_target_addr(&addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_str_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = format!("{}", addr);
        let res = into_target_addr(ip_str.as_str())?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_ip_str_and_port_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = format!("{}", addr.ip());
        let res = into_target_addr((ip_str.as_str(), addr.port()))?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_domain_to_target_addr() -> Result<()> {
        let domain = "www.example.com:80";
        let res = into_target_addr(domain)?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);

        let res = into_target_addr(domain.to_owned())?;
        assert_eq!(TargetAddr::Domain(Cow::Owned("www.example.com".to_owned()), 80), res);
        Ok(())
    }

    #[test]
    fn converts_domain_and_port_to_target_addr() -> Result<()> {
        let domain = "www.example.com";
        let res = into_target_addr((domain, 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);
        Ok(())
    }

    #[test]
    fn converts_owned_host_and_port_to_target_addr() -> Result<()> {
        let res = into_target_addr((String::from("www.example.com"), 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);

        let res = into_target_addr((String::from("1.1.1.1"), 443))?;
        assert_eq!(TargetAddr::Ip(SocketAddr::from(([1, 1, 1, 1], 443))), res);
        Ok(())
    }

    #[test]
    fn to_owned_produces_static_copy() -> Result<()> {
        let source = String::from("www.example.com:80");
        let borrowed = into_target_addr(source.as_str())?;
        // The annotation proves the copy no longer borrows `source`.
        let owned: TargetAddr<'static> = borrowed.to_owned();
        assert_eq!(borrowed, owned);
        assert_eq!(TargetAddr::Domain(Cow::Owned("www.example.com".to_owned()), 80), owned);
        Ok(())
    }

    #[test]
    fn overlong_domain_to_target_addr_should_fail() {
        let domain = format!("www.{:a<1$}.com:80", 'a', 300);
        assert!(into_target_addr(domain.as_str()).is_err());
        let domain = format!("www.{:a<1$}.com", 'a', 300);
        assert!(into_target_addr((domain.as_str(), 80)).is_err());
    }

    #[test]
    fn addr_with_invalid_port_to_target_addr_should_fail() {
        let addr = "[ffff::1]:65536";
        assert!(into_target_addr(addr).is_err());
        let addr = "www.example.com:65536";
        assert!(into_target_addr(addr).is_err());
    }

    #[test]
    fn addr_without_port_to_target_addr_should_fail() {
        assert!(into_target_addr("www.example.com").is_err());
        assert!(into_target_addr(String::from("www.example.com")).is_err());
    }
}