tokio_socks/
lib.rs

1use std::{
2    borrow::Cow,
3    io::Result as IoResult,
4    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
5    pin::Pin,
6    task::{Context, Poll},
7    vec,
8};
9
10use either::Either;
11pub use error::Error;
12use futures_util::{
13    future,
14    stream::{self, Once, Stream},
15};
16
17pub type Result<T> = std::result::Result<T, Error>;
18
19/// A trait for objects which can be converted or resolved to one or more
20/// `SocketAddr` values, which are going to be connected as the the proxy
21/// server.
22///
23/// This trait is similar to `std::net::ToSocketAddrs` but allows asynchronous
24/// name resolution.
25pub trait ToProxyAddrs {
26    type Output: Stream<Item = Result<SocketAddr>> + Unpin;
27
28    fn to_proxy_addrs(&self) -> Self::Output;
29}
30
31macro_rules! trivial_impl_to_proxy_addrs {
32    ($t: ty) => {
33        impl ToProxyAddrs for $t {
34            type Output = Once<future::Ready<Result<SocketAddr>>>;
35
36            fn to_proxy_addrs(&self) -> Self::Output {
37                stream::once(future::ready(Ok(SocketAddr::from(*self))))
38            }
39        }
40    };
41}
42
43trivial_impl_to_proxy_addrs!(SocketAddr);
44trivial_impl_to_proxy_addrs!((IpAddr, u16));
45trivial_impl_to_proxy_addrs!((Ipv4Addr, u16));
46trivial_impl_to_proxy_addrs!((Ipv6Addr, u16));
47trivial_impl_to_proxy_addrs!(SocketAddrV4);
48trivial_impl_to_proxy_addrs!(SocketAddrV6);
49
50impl<'a> ToProxyAddrs for &'a [SocketAddr] {
51    type Output = ProxyAddrsStream;
52
53    fn to_proxy_addrs(&self) -> Self::Output {
54        let addrs = self.to_vec();
55        ProxyAddrsStream(Some(IoResult::Ok(addrs.into_iter())))
56    }
57}
58
59impl ToProxyAddrs for str {
60    type Output = ProxyAddrsStream;
61
62    fn to_proxy_addrs(&self) -> Self::Output {
63        ProxyAddrsStream(Some(self.to_socket_addrs()))
64    }
65}
66
67impl<'a> ToProxyAddrs for (&'a str, u16) {
68    type Output = ProxyAddrsStream;
69
70    fn to_proxy_addrs(&self) -> Self::Output {
71        ProxyAddrsStream(Some(self.to_socket_addrs()))
72    }
73}
74
75impl<'a, T: ToProxyAddrs + ?Sized> ToProxyAddrs for &'a T {
76    type Output = T::Output;
77
78    fn to_proxy_addrs(&self) -> Self::Output {
79        (**self).to_proxy_addrs()
80    }
81}
82
83pub struct ProxyAddrsStream(Option<IoResult<vec::IntoIter<SocketAddr>>>);
84
85impl Stream for ProxyAddrsStream {
86    type Item = Result<SocketAddr>;
87
88    fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89        match self.0.as_mut() {
90            Some(Ok(iter)) => Poll::Ready(iter.next().map(Result::Ok)),
91            Some(Err(_)) => {
92                let err = self.0.take().unwrap().unwrap_err();
93                Poll::Ready(Some(Err(err.into())))
94            },
95            None => unreachable!(),
96        }
97    }
98}
99
/// The destination a SOCKS proxy should connect to on the client's behalf.
#[derive(PartialEq, Eq, Debug)]
pub enum TargetAddr<'a> {
    /// Connect to an IP address.
    Ip(SocketAddr),

    /// Connect to a fully-qualified domain name.
    ///
    /// The domain name will be passed along to the proxy server and DNS lookup
    /// will happen there.
    Domain(Cow<'a, str>, u16),
}

impl<'a> TargetAddr<'a> {
    /// Returns a copy of this target that owns all of its data.
    ///
    /// Use it to drop the `'a` borrow (the result is `TargetAddr<'static>`).
    /// A domain name is always copied into a fresh `String`.
    pub fn to_owned(&self) -> TargetAddr<'static> {
        match *self {
            TargetAddr::Ip(addr) => TargetAddr::Ip(addr),
            TargetAddr::Domain(ref domain, port) => {
                TargetAddr::Domain(Cow::Owned(String::from(&**domain)), port)
            }
        }
    }
}
123
124impl<'a> ToSocketAddrs for TargetAddr<'a> {
125    type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;
126
127    fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
128        Ok(match self {
129            TargetAddr::Ip(addr) => Either::Left(addr.to_socket_addrs()?),
130            TargetAddr::Domain(domain, port) => Either::Right((&**domain, *port).to_socket_addrs()?),
131        })
132    }
133}
134
135/// A trait for objects that can be converted to `TargetAddr`.
136pub trait IntoTargetAddr<'a> {
137    /// Converts the value of self to a `TargetAddr`.
138    fn into_target_addr(self) -> Result<TargetAddr<'a>>;
139}
140
141macro_rules! trivial_impl_into_target_addr {
142    ($t: ty) => {
143        impl<'a> IntoTargetAddr<'a> for $t {
144            fn into_target_addr(self) -> Result<TargetAddr<'a>> {
145                Ok(TargetAddr::Ip(SocketAddr::from(self)))
146            }
147        }
148    };
149}
150
151trivial_impl_into_target_addr!(SocketAddr);
152trivial_impl_into_target_addr!((IpAddr, u16));
153trivial_impl_into_target_addr!((Ipv4Addr, u16));
154trivial_impl_into_target_addr!((Ipv6Addr, u16));
155trivial_impl_into_target_addr!(SocketAddrV4);
156trivial_impl_into_target_addr!(SocketAddrV6);
157
158impl<'a> IntoTargetAddr<'a> for TargetAddr<'a> {
159    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
160        Ok(self)
161    }
162}
163
164impl<'a> IntoTargetAddr<'a> for (&'a str, u16) {
165    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
166        // Try IP address first
167        if let Ok(addr) = self.0.parse::<IpAddr>() {
168            return (addr, self.1).into_target_addr();
169        }
170
171        // Treat as domain name
172        if self.0.len() > 255 {
173            return Err(Error::InvalidTargetAddress("overlong domain"));
174        }
175        // TODO: Should we validate the domain format here?
176
177        Ok(TargetAddr::Domain(self.0.into(), self.1))
178    }
179}
180
181impl<'a> IntoTargetAddr<'a> for &'a str {
182    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
183        // Try IP address first
184        if let Ok(addr) = self.parse::<SocketAddr>() {
185            return addr.into_target_addr();
186        }
187
188        let mut parts_iter = self.rsplitn(2, ':');
189        let port: u16 = parts_iter
190            .next()
191            .and_then(|port_str| port_str.parse().ok())
192            .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
193        let domain = parts_iter
194            .next()
195            .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
196        if domain.len() > 255 {
197            return Err(Error::InvalidTargetAddress("overlong domain"));
198        }
199        Ok(TargetAddr::Domain(domain.into(), port))
200    }
201}
202
203impl IntoTargetAddr<'static> for String {
204    fn into_target_addr(mut self) -> Result<TargetAddr<'static>> {
205        // Try IP address first
206        if let Ok(addr) = self.parse::<SocketAddr>() {
207            return addr.into_target_addr();
208        }
209
210        let mut parts_iter = self.rsplitn(2, ':');
211        let port: u16 = parts_iter
212            .next()
213            .and_then(|port_str| port_str.parse().ok())
214            .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
215        let domain_len = parts_iter
216            .next()
217            .ok_or(Error::InvalidTargetAddress("invalid address format"))?
218            .len();
219        if domain_len > 255 {
220            return Err(Error::InvalidTargetAddress("overlong domain"));
221        }
222        self.truncate(domain_len);
223        Ok(TargetAddr::Domain(self.into(), port))
224    }
225}
226
227impl IntoTargetAddr<'static> for (String, u16) {
228    fn into_target_addr(self) -> Result<TargetAddr<'static>> {
229        let addr = (self.0.as_str(), self.1).into_target_addr()?;
230        if let TargetAddr::Ip(addr) = addr {
231            Ok(TargetAddr::Ip(addr))
232        } else {
233            Ok(TargetAddr::Domain(self.0.into(), self.1))
234        }
235    }
236}
237
238impl<'a, T> IntoTargetAddr<'a> for &'a T
239where T: IntoTargetAddr<'a> + Copy
240{
241    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
242        (*self).into_target_addr()
243    }
244}
245
/// Client authentication methods supported by this crate.
#[derive(Debug)]
enum Authentication<'a> {
    /// Username/password authentication.
    Password { username: &'a str, password: &'a str },
    /// No authentication.
    None,
}

impl<'a> Authentication<'a> {
    /// Returns the SOCKS5 method identifier (RFC 1928) for this method.
    fn id(&self) -> u8 {
        const NO_AUTHENTICATION_REQUIRED: u8 = 0x00;
        const USERNAME_PASSWORD: u8 = 0x02;

        match *self {
            Authentication::None => NO_AUTHENTICATION_REQUIRED,
            Authentication::Password { .. } => USERNAME_PASSWORD,
        }
    }
}
261
262mod error;
263pub mod io;
264pub mod tcp;
265
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use futures_util::StreamExt;

    use super::*;

    /// Drains the stream produced by `t` into a vector, panicking on any
    /// error item.
    fn to_proxy_addrs<T: ToProxyAddrs>(t: T) -> Result<Vec<SocketAddr>> {
        Ok(block_on(t.to_proxy_addrs().map(Result::unwrap).collect()))
    }

    #[test]
    fn converts_socket_addr_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = to_proxy_addrs(addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the blanket `&T` impl is actually exercised.
        let res = to_proxy_addrs(&addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addrs_to_proxy_addrs() -> Result<()> {
        let addrs = [
            SocketAddr::from(([1, 1, 1, 1], 443)),
            SocketAddr::from(([8, 8, 8, 8], 53)),
        ];
        let res = to_proxy_addrs(&addrs[..])?;
        assert_eq!(&res[..], &addrs);
        Ok(())
    }

    fn into_target_addr<'a, T>(t: T) -> Result<TargetAddr<'a>>
    where T: IntoTargetAddr<'a> {
        t.into_target_addr()
    }

    #[test]
    fn converts_socket_addr_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = into_target_addr(addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the blanket `&T where T: Copy` impl is exercised.
        let res = into_target_addr(&addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_str_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = addr.to_string();
        let res = into_target_addr(ip_str.as_str())?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_ip_str_and_port_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = addr.ip().to_string();
        let res = into_target_addr((ip_str.as_str(), addr.port()))?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_domain_to_target_addr() -> Result<()> {
        let domain = "www.example.com:80";
        let res = into_target_addr(domain)?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);

        let res = into_target_addr(domain.to_owned())?;
        assert_eq!(TargetAddr::Domain(Cow::Owned("www.example.com".to_owned()), 80), res);
        Ok(())
    }

    #[test]
    fn converts_domain_and_port_to_target_addr() -> Result<()> {
        let domain = "www.example.com";
        let res = into_target_addr((domain, 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);
        Ok(())
    }

    #[test]
    fn overlong_domain_to_target_addr_should_fail() {
        let domain = format!("www.{:a<1$}.com:80", 'a', 300);
        assert!(into_target_addr(domain.as_str()).is_err());
        let domain = format!("www.{:a<1$}.com", 'a', 300);
        assert!(into_target_addr((domain.as_str(), 80)).is_err());
    }

    #[test]
    fn addr_with_invalid_port_to_target_addr_should_fail() {
        let addr = "[ffff::1]:65536";
        assert!(into_target_addr(addr).is_err());
        let addr = "www.example.com:65536";
        assert!(into_target_addr(addr).is_err());
    }
}