socks5_impl/protocol/address.rs

1#[cfg(feature = "tokio")]
2use crate::protocol::AsyncStreamOperation;
3use crate::protocol::StreamOperation;
4#[cfg(feature = "tokio")]
5use async_trait::async_trait;
6use bytes::BufMut;
7use std::{
8    io::Cursor,
9    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
10};
11#[cfg(feature = "tokio")]
12use tokio::io::{AsyncRead, AsyncReadExt};
13
/// The `ATYP` field of a SOCKS5 address.
///
/// Each variant's discriminant is its on-the-wire code, so the enum can be
/// cast directly to `u8`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default)]
#[repr(u8)]
pub enum AddressType {
    /// IPv4 address: 4 octets follow (`0x01`).
    #[default]
    IPv4 = 0x01,
    /// Domain name: a one-byte length followed by that many name bytes (`0x03`).
    Domain = 0x03,
    /// IPv6 address: 16 octets follow (`0x04`).
    IPv6 = 0x04,
}

impl TryFrom<u8> for AddressType {
    type Error = std::io::Error;

    /// Decodes an `ATYP` byte.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::InvalidInput`] for any
    /// code other than `0x01`, `0x03` or `0x04`.
    fn try_from(code: u8) -> core::result::Result<Self, Self::Error> {
        match code {
            0x01 => Ok(AddressType::IPv4),
            0x03 => Ok(AddressType::Domain),
            0x04 => Ok(AddressType::IPv6),
            other => {
                // Format the message only on the failure path so that a
                // successful decode does not allocate.
                let err = format!("Unsupported address type code {other:#x}");
                Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
            }
        }
    }
}

impl From<AddressType> for u8 {
    /// Encodes the address type as its `ATYP` wire byte.
    fn from(addr_type: AddressType) -> Self {
        // `#[repr(u8)]` with explicit discriminants makes the cast the wire code.
        addr_type as u8
    }
}
45
46/// SOCKS5 Adderss Format
47///
48/// ```plain
49/// +------+----------+----------+
50/// | ATYP | DST.ADDR | DST.PORT |
51/// +------+----------+----------+
52/// |  1   | Variable |    2     |
53/// +------+----------+----------+
54/// ```
55#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
56pub enum Address {
57    SocketAddress(SocketAddr),
58    DomainAddress(String, u16),
59}
60
61impl Address {
62    pub fn unspecified() -> Self {
63        Address::SocketAddress(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
64    }
65
66    pub fn get_type(&self) -> AddressType {
67        match self {
68            Self::SocketAddress(SocketAddr::V4(_)) => AddressType::IPv4,
69            Self::SocketAddress(SocketAddr::V6(_)) => AddressType::IPv6,
70            Self::DomainAddress(_, _) => AddressType::Domain,
71        }
72    }
73
74    pub fn port(&self) -> u16 {
75        match self {
76            Self::SocketAddress(addr) => addr.port(),
77            Self::DomainAddress(_, port) => *port,
78        }
79    }
80
81    pub fn domain(&self) -> String {
82        match self {
83            Self::SocketAddress(addr) => addr.ip().to_string(),
84            Self::DomainAddress(addr, _) => addr.clone(),
85        }
86    }
87
88    pub const fn max_serialized_len() -> usize {
89        1 + 1 + u8::MAX as usize + 2
90    }
91}
92
93impl StreamOperation for Address {
94    fn retrieve_from_stream<R: std::io::Read>(stream: &mut R) -> std::io::Result<Self> {
95        let mut atyp = [0; 1];
96        stream.read_exact(&mut atyp)?;
97        match AddressType::try_from(atyp[0])? {
98            AddressType::IPv4 => {
99                let mut buf = [0; 6];
100                stream.read_exact(&mut buf)?;
101                let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
102                let port = u16::from_be_bytes([buf[4], buf[5]]);
103                Ok(Self::SocketAddress(SocketAddr::from((addr, port))))
104            }
105            AddressType::Domain => {
106                let mut len = [0; 1];
107                stream.read_exact(&mut len)?;
108                let len = len[0] as usize;
109                let mut buf = vec![0; len + 2];
110                stream.read_exact(&mut buf)?;
111
112                let port = u16::from_be_bytes([buf[len], buf[len + 1]]);
113                buf.truncate(len);
114
115                let addr = match String::from_utf8(buf) {
116                    Ok(addr) => addr,
117                    Err(err) => {
118                        let err = format!("Invalid address encoding: {err}");
119                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err));
120                    }
121                };
122                Ok(Self::DomainAddress(addr, port))
123            }
124            AddressType::IPv6 => {
125                let mut buf = [0; 18];
126                stream.read_exact(&mut buf)?;
127                let port = u16::from_be_bytes([buf[16], buf[17]]);
128                let mut addr_bytes = [0; 16];
129                addr_bytes.copy_from_slice(&buf[..16]);
130                Ok(Self::SocketAddress(SocketAddr::from((Ipv6Addr::from(addr_bytes), port))))
131            }
132        }
133    }
134
135    fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
136        match self {
137            Self::SocketAddress(SocketAddr::V4(addr)) => {
138                buf.put_u8(AddressType::IPv4.into());
139                buf.put_slice(&addr.ip().octets());
140                buf.put_u16(addr.port());
141            }
142            Self::SocketAddress(SocketAddr::V6(addr)) => {
143                buf.put_u8(AddressType::IPv6.into());
144                buf.put_slice(&addr.ip().octets());
145                buf.put_u16(addr.port());
146            }
147            Self::DomainAddress(addr, port) => {
148                let addr = addr.as_bytes();
149                buf.put_u8(AddressType::Domain.into());
150                buf.put_u8(addr.len() as u8);
151                buf.put_slice(addr);
152                buf.put_u16(*port);
153            }
154        }
155    }
156
157    fn len(&self) -> usize {
158        match self {
159            Address::SocketAddress(SocketAddr::V4(_)) => 1 + 4 + 2,
160            Address::SocketAddress(SocketAddr::V6(_)) => 1 + 16 + 2,
161            Address::DomainAddress(addr, _) => 1 + 1 + addr.len() + 2,
162        }
163    }
164}
165
#[cfg(feature = "tokio")]
#[async_trait]
impl AsyncStreamOperation for Address {
    /// Asynchronously reads one SOCKS5 address (`ATYP`, `DST.ADDR`, `DST.PORT`)
    /// from `stream`.
    ///
    /// # Errors
    ///
    /// Propagates read errors from `stream`. Returns `InvalidInput` for an
    /// unknown `ATYP` byte and `InvalidData` when a domain name is not valid
    /// UTF-8.
    async fn retrieve_from_async_stream<R>(stream: &mut R) -> std::io::Result<Self>
    where
        R: AsyncRead + Unpin + Send + ?Sized,
    {
        let kind = AddressType::try_from(stream.read_u8().await?)?;
        let parsed = match kind {
            AddressType::IPv4 => {
                let mut octets = [0u8; 4];
                stream.read_exact(&mut octets).await?;
                let mut port_be = [0u8; 2];
                stream.read_exact(&mut port_be).await?;
                let port = u16::from_be_bytes(port_be);
                Self::SocketAddress(SocketAddr::from((Ipv4Addr::from(octets), port)))
            }
            AddressType::IPv6 => {
                let mut octets = [0u8; 16];
                stream.read_exact(&mut octets).await?;
                let mut port_be = [0u8; 2];
                stream.read_exact(&mut port_be).await?;
                let port = u16::from_be_bytes(port_be);
                Self::SocketAddress(SocketAddr::from((Ipv6Addr::from(octets), port)))
            }
            AddressType::Domain => {
                // Length-prefixed name followed by the big-endian port, read in one go.
                let name_len = usize::from(stream.read_u8().await?);
                let mut raw = vec![0u8; name_len + 2];
                stream.read_exact(&mut raw).await?;

                let port = u16::from_be_bytes([raw[name_len], raw[name_len + 1]]);
                raw.truncate(name_len);

                let host = String::from_utf8(raw).map_err(|err| {
                    let err = format!("Invalid address encoding: {err}");
                    std::io::Error::new(std::io::ErrorKind::InvalidData, err)
                })?;
                Self::DomainAddress(host, port)
            }
        };
        Ok(parsed)
    }
}
212
213impl ToSocketAddrs for Address {
214    type Iter = std::vec::IntoIter<SocketAddr>;
215
216    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
217        match self {
218            Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()),
219            Address::DomainAddress(addr, port) => Ok((addr.as_str(), *port).to_socket_addrs()?),
220        }
221    }
222}
223
224impl std::fmt::Display for Address {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        match self {
227            Address::DomainAddress(hostname, port) => write!(f, "{hostname}:{port}"),
228            Address::SocketAddress(socket_addr) => write!(f, "{socket_addr}"),
229        }
230    }
231}
232
233impl TryFrom<Address> for SocketAddr {
234    type Error = std::io::Error;
235
236    fn try_from(address: Address) -> std::result::Result<Self, Self::Error> {
237        match address {
238            Address::SocketAddress(addr) => Ok(addr),
239            Address::DomainAddress(addr, port) => {
240                if let Ok(addr) = addr.parse::<Ipv4Addr>() {
241                    Ok(SocketAddr::from((addr, port)))
242                } else if let Ok(addr) = addr.parse::<Ipv6Addr>() {
243                    Ok(SocketAddr::from((addr, port)))
244                } else {
245                    let err = format!("domain address {addr} is not supported");
246                    Err(Self::Error::new(std::io::ErrorKind::Unsupported, err))
247                }
248            }
249        }
250    }
251}
252
253impl TryFrom<&Address> for SocketAddr {
254    type Error = std::io::Error;
255
256    fn try_from(address: &Address) -> std::result::Result<Self, Self::Error> {
257        TryFrom::<Address>::try_from(address.clone())
258    }
259}
260
261impl From<Address> for Vec<u8> {
262    fn from(addr: Address) -> Self {
263        let mut buf = Vec::with_capacity(addr.len());
264        addr.write_to_buf(&mut buf);
265        buf
266    }
267}
268
269impl TryFrom<Vec<u8>> for Address {
270    type Error = std::io::Error;
271
272    fn try_from(data: Vec<u8>) -> std::result::Result<Self, Self::Error> {
273        let mut rdr = Cursor::new(data);
274        Self::retrieve_from_stream(&mut rdr)
275    }
276}
277
278impl TryFrom<&[u8]> for Address {
279    type Error = std::io::Error;
280
281    fn try_from(data: &[u8]) -> std::result::Result<Self, Self::Error> {
282        let mut rdr = Cursor::new(data);
283        Self::retrieve_from_stream(&mut rdr)
284    }
285}
286
287impl From<SocketAddr> for Address {
288    fn from(addr: SocketAddr) -> Self {
289        Address::SocketAddress(addr)
290    }
291}
292
293impl From<&SocketAddr> for Address {
294    fn from(addr: &SocketAddr) -> Self {
295        Address::SocketAddress(*addr)
296    }
297}
298
299impl From<(Ipv4Addr, u16)> for Address {
300    fn from((addr, port): (Ipv4Addr, u16)) -> Self {
301        Address::SocketAddress(SocketAddr::from((addr, port)))
302    }
303}
304
305impl From<(Ipv6Addr, u16)> for Address {
306    fn from((addr, port): (Ipv6Addr, u16)) -> Self {
307        Address::SocketAddress(SocketAddr::from((addr, port)))
308    }
309}
310
311impl From<(IpAddr, u16)> for Address {
312    fn from((addr, port): (IpAddr, u16)) -> Self {
313        Address::SocketAddress(SocketAddr::from((addr, port)))
314    }
315}
316
317impl From<(String, u16)> for Address {
318    fn from((addr, port): (String, u16)) -> Self {
319        Address::DomainAddress(addr, port)
320    }
321}
322
323impl From<(&str, u16)> for Address {
324    fn from((addr, port): (&str, u16)) -> Self {
325        Address::DomainAddress(addr.to_owned(), port)
326    }
327}
328
329impl From<&Address> for Address {
330    fn from(addr: &Address) -> Self {
331        addr.clone()
332    }
333}
334
335impl TryFrom<&str> for Address {
336    type Error = crate::Error;
337
338    fn try_from(addr: &str) -> std::result::Result<Self, Self::Error> {
339        if let Ok(addr) = addr.parse::<SocketAddr>() {
340            Ok(Address::SocketAddress(addr))
341        } else {
342            let (addr, port) = if let Some(pos) = addr.rfind(':') {
343                (&addr[..pos], &addr[pos + 1..])
344            } else {
345                (addr, "0")
346            };
347            let port = port.parse::<u16>()?;
348            Ok(Address::DomainAddress(addr.to_owned(), port))
349        }
350    }
351}
352
#[test]
fn test_address() {
    // Each case pairs an address with its expected SOCKS5 wire encoding; every
    // address must encode to exactly those bytes and decode back to itself.
    let cases = [
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com".to_owned(), 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];
    for (addr, expected) in cases {
        let mut buf: Vec<u8> = Vec::new();
        addr.write_to_buf(&mut buf);
        assert_eq!(buf, expected);
        let decoded = Address::retrieve_from_stream(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(addr, decoded);
    }
}
376
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_address_async() {
    // Same wire-format cases as the blocking test, driven through the async
    // writer and reader.
    let cases = [
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com".to_owned(), 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];
    for (addr, expected) in cases {
        let mut buf: Vec<u8> = Vec::new();
        addr.write_to_async_stream(&mut buf).await.unwrap();
        assert_eq!(buf, expected);
        let decoded = Address::retrieve_from_async_stream(&mut Cursor::new(&buf)).await.unwrap();
        assert_eq!(addr, decoded);
    }
}