1use std::{
2 borrow::Borrow,
3 io,
4 net::{Ipv4Addr, SocketAddr},
5 ops::{Deref, DerefMut},
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use futures_util::stream::{self, Fuse, Stream, StreamExt};
11#[cfg(feature = "tokio")]
12use tokio::net::TcpStream;
13
14#[cfg(feature = "tokio")]
15use crate::ToProxyAddrs;
16use crate::{
17 io::{AsyncSocket, AsyncSocketExt},
18 Error,
19 IntoTargetAddr,
20 Result,
21 TargetAddr,
22};
23
/// SOCKS4 command codes, sent as the `CD` byte of a client request.
#[derive(Copy, Clone)]
#[repr(u8)]
enum CommandV4 {
    /// Open an outbound TCP connection to the destination (code 1).
    Connect = 1,
    /// Ask the proxy to listen for an inbound connection from the destination (code 2).
    Bind = 2,
}
30
31#[derive(Debug)]
35pub struct Socks4Stream<S> {
36 socket: S,
37 target: TargetAddr<'static>,
38}
39
40impl<S> Deref for Socks4Stream<S> {
41 type Target = S;
42
43 fn deref(&self) -> &Self::Target {
44 &self.socket
45 }
46}
47
48impl<S> DerefMut for Socks4Stream<S> {
49 fn deref_mut(&mut self) -> &mut Self::Target {
50 &mut self.socket
51 }
52}
53
#[cfg(feature = "tokio")]
impl Socks4Stream<TcpStream> {
    /// Opens a TCP connection to `proxy` and issues a SOCKS4 CONNECT for `target`
    /// without a user id.
    ///
    /// # Errors
    /// Returns an error if the proxy cannot be reached, `target` cannot be turned into
    /// a target address, or the proxy rejects the request.
    pub async fn connect<'t, P, T>(proxy: P, target: T) -> Result<Socks4Stream<TcpStream>>
    where
        P: ToProxyAddrs,
        T: IntoTargetAddr<'t>,
    {
        Self::execute_command(proxy, target, None, CommandV4::Connect).await
    }

    /// Like [`Socks4Stream::connect`], but sends `user_id` in the request.
    ///
    /// # Errors
    /// In addition to the `connect` errors, fails with `Error::InvalidAuthValues`
    /// when `user_id` is not 1 to 255 bytes long.
    pub async fn connect_with_userid<'a, 't, P, T>(
        proxy: P,
        target: T,
        user_id: &'a str,
    ) -> Result<Socks4Stream<TcpStream>>
    where
        P: ToProxyAddrs,
        T: IntoTargetAddr<'t>,
    {
        let user_id = Some(user_id);
        Self::execute_command(proxy, target, user_id, CommandV4::Connect).await
    }

    /// Validates the user id, resolves the addresses, then runs `command` over a
    /// fresh TCP connection to the proxy.
    async fn execute_command<'a, 't, P, T>(
        proxy: P,
        target: T,
        user_id: Option<&'a str>,
        command: CommandV4,
    ) -> Result<Socks4Stream<TcpStream>>
    where
        P: ToProxyAddrs,
        T: IntoTargetAddr<'t>,
    {
        Self::validate_userid(user_id)?;

        // Same evaluation order as a single constructor call: proxy stream first, then target.
        let proxy_addrs = proxy.to_proxy_addrs().fuse();
        let target_addr = target.into_target_addr()?;
        let mut connector = Socks4Connector::new(user_id, command, proxy_addrs, target_addr);

        let stream = connector.execute().await?;
        Ok(stream)
    }
}
114
115impl<S> Socks4Stream<S>
116where S: AsyncSocket + Unpin
117{
118 pub async fn connect_with_socket<'t, T>(socket: S, target: T) -> Result<Socks4Stream<S>>
125 where T: IntoTargetAddr<'t> {
126 Self::execute_command_with_socket(socket, target, None, CommandV4::Connect).await
127 }
128
129 pub async fn connect_with_userid_and_socket<'a, 't, T>(
137 socket: S,
138 target: T,
139 user_id: &'a str,
140 ) -> Result<Socks4Stream<S>>
141 where
142 T: IntoTargetAddr<'t>,
143 {
144 Self::execute_command_with_socket(socket, target, Some(user_id), CommandV4::Connect).await
145 }
146
147 fn validate_userid(user_id: Option<&str>) -> Result<()> {
148 if let Some(user_id) = user_id {
151 let user_id_len = user_id.len();
152 if !(1..=255).contains(&user_id_len) {
153 Err(Error::InvalidAuthValues("userid length should between 1 to 255"))?
154 }
155 }
156
157 Ok(())
158 }
159
160 async fn execute_command_with_socket<'a, 't, T>(
161 socket: S,
162 target: T,
163 user_id: Option<&'a str>,
164 command: CommandV4,
165 ) -> Result<Socks4Stream<S>>
166 where
167 T: IntoTargetAddr<'t>,
168 {
169 Self::validate_userid(user_id)?;
170
171 let sock = Socks4Connector::new(user_id, command, stream::empty().fuse(), target.into_target_addr()?)
172 .execute_with_socket(socket)
173 .await?;
174
175 Ok(sock)
176 }
177
178 pub fn into_inner(self) -> S {
180 self.socket
181 }
182
183 pub fn target_addr(&self) -> TargetAddr<'_> {
185 match &self.target {
186 TargetAddr::Ip(addr) => TargetAddr::Ip(*addr),
187 TargetAddr::Domain(domain, port) => {
188 let domain: &str = domain.borrow();
189 TargetAddr::Domain(domain.into(), *port)
190 },
191 }
192 }
193}
194
195pub struct Socks4Connector<'a, 't, S> {
197 user_id: Option<&'a str>,
198 command: CommandV4,
199 #[allow(dead_code)]
200 proxy: Fuse<S>,
201 target: TargetAddr<'t>,
202 buf: [u8; 513],
203 ptr: usize,
204 len: usize,
205}
206
207impl<'a, 't, S> Socks4Connector<'a, 't, S>
208where S: Stream<Item = Result<SocketAddr>> + Unpin
209{
210 fn new(user_id: Option<&'a str>, command: CommandV4, proxy: Fuse<S>, target: TargetAddr<'t>) -> Self {
211 Socks4Connector {
212 user_id,
213 command,
214 proxy,
215 target,
216 buf: [0; 513],
217 ptr: 0,
218 len: 0,
219 }
220 }
221
222 #[cfg(feature = "tokio")]
223 pub async fn execute(&mut self) -> Result<Socks4Stream<TcpStream>> {
225 let next_addr = self.proxy.select_next_some().await?;
226 let tcp = TcpStream::connect(next_addr)
227 .await
228 .map_err(|_| Error::ProxyServerUnreachable)?;
229
230 self.execute_with_socket(tcp).await
231 }
232
233 pub async fn execute_with_socket<T: AsyncSocket + Unpin>(&mut self, mut socket: T) -> Result<Socks4Stream<T>> {
234 self.prepare_send_request()?;
236 socket.write_all(&self.buf[self.ptr..self.len]).await?;
237
238 let target = self.receive_reply(&mut socket).await?;
239
240 Ok(Socks4Stream { socket, target })
241 }
242
243 fn prepare_send_request(&mut self) -> Result<()> {
244 self.ptr = 0;
245 self.buf[..2].copy_from_slice(&[0x04, self.command as u8]);
246 match &self.target {
247 TargetAddr::Ip(SocketAddr::V4(addr)) => {
248 self.buf[2..4].copy_from_slice(&addr.port().to_be_bytes());
249 self.buf[4..8].copy_from_slice(&addr.ip().octets());
250 self.len = 8;
251 if let Some(user_id) = self.user_id {
252 let usr_byts = user_id.as_bytes();
253 let user_id_len = usr_byts.len();
254 self.len += user_id_len;
255 self.buf[8..self.len].copy_from_slice(usr_byts);
256 }
257 self.buf[self.len] = 0; self.len += 1;
259 },
260 TargetAddr::Ip(SocketAddr::V6(_)) => {
261 return Err(Error::AddressTypeNotSupported);
262 },
263 TargetAddr::Domain(domain, port) => {
264 self.buf[2..4].copy_from_slice(&port.to_be_bytes());
265 self.buf[4..8].copy_from_slice(&[0, 0, 0, 1]);
266 self.len = 8;
267 if let Some(user_id) = self.user_id {
268 let usr_byts = user_id.as_bytes();
269 let user_id_len = usr_byts.len();
270 self.len += user_id_len;
271 self.buf[8..self.len].copy_from_slice(usr_byts);
272 }
273 self.buf[self.len] = 0; self.len += 1;
275 let domain = domain.as_bytes();
276 let domain_len = domain.len();
277 self.buf[self.len..self.len + domain_len].copy_from_slice(domain);
278 self.len += domain_len;
279 self.buf[self.len] = 0;
280 self.len += 1;
281 },
282 };
283 Ok(())
284 }
285
286 fn prepare_recv_reply(&mut self) {
287 self.ptr = 0;
288 self.len = 8;
289 }
290
291 async fn receive_reply<T: AsyncSocket + Unpin>(&mut self, tcp: &mut T) -> Result<TargetAddr<'static>> {
292 self.prepare_recv_reply();
308 self.ptr += tcp.read_exact(&mut self.buf[self.ptr..self.len]).await?;
309 if self.buf[0] != 0 {
310 return Err(Error::InvalidResponseVersion);
311 }
312
313 match self.buf[1] {
314 0x5A => {}, 0x5B => return Err(Error::GeneralSocksServerFailure), 0x5C => return Err(Error::IdentdAuthFailure), 0x5D => return Err(Error::InvalidUserIdAuthFailure), _ => return Err(Error::UnknownError),
319 }
320
321 let port = u16::from_be_bytes([self.buf[2], self.buf[3]]);
322
323 let target = Ipv4Addr::from([self.buf[4], self.buf[5], self.buf[6], self.buf[7]]);
324
325 Ok(TargetAddr::Ip(SocketAddr::new(target.into(), port)))
326 }
327}
328
329pub struct Socks4Listener<S> {
336 inner: Socks4Stream<S>,
337}
338
#[cfg(feature = "tokio")]
impl Socks4Listener<TcpStream> {
    /// Connects to `proxy` and issues a SOCKS4 BIND for `target` without a user id.
    ///
    /// # Errors
    /// Returns an error if the proxy cannot be reached, `target` is not a usable
    /// address, or the proxy rejects the request.
    pub async fn bind<'t, P, T>(proxy: P, target: T) -> Result<Socks4Listener<TcpStream>>
    where
        P: ToProxyAddrs,
        T: IntoTargetAddr<'t>,
    {
        Self::bind_to_target(None, proxy, target).await
    }

    /// Like [`Socks4Listener::bind`], but sends `user_id` in the request.
    ///
    /// # Errors
    /// Additionally fails with `Error::InvalidAuthValues` when `user_id` is not
    /// 1 to 255 bytes long.
    pub async fn bind_with_userid<'a, 't, P, T>(
        proxy: P,
        target: T,
        user_id: &'a str,
    ) -> Result<Socks4Listener<TcpStream>>
    where
        P: ToProxyAddrs,
        T: IntoTargetAddr<'t>,
    {
        Self::bind_to_target(Some(user_id), proxy, target).await
    }

    /// Validates the user id, then runs the BIND request over a fresh TCP connection.
    async fn bind_to_target<'a, 't, P, T>(
        user_id: Option<&'a str>,
        proxy: P,
        target: T,
    ) -> Result<Socks4Listener<TcpStream>>
    where
        P: ToProxyAddrs,
        T: IntoTargetAddr<'t>,
    {
        // Apply the same user-id rules as the CONNECT path. Previously BIND skipped
        // this check, so an oversized user id reached the request encoder unchecked.
        Socks4Stream::<TcpStream>::validate_userid(user_id)?;

        let socket = Socks4Connector::new(
            user_id,
            CommandV4::Bind,
            proxy.to_proxy_addrs().fuse(),
            target.into_target_addr()?,
        )
        .execute()
        .await?;

        Ok(Socks4Listener { inner: socket })
    }
}
401
402impl<S> Socks4Listener<S>
403where S: AsyncSocket + Unpin
404{
405 pub async fn bind_with_socket<'t, T>(socket: S, target: T) -> Result<Socks4Listener<S>>
416 where T: IntoTargetAddr<'t> {
417 Self::bind_to_target_with_socket(None, socket, target).await
418 }
419
420 pub async fn bind_with_user_and_socket<'a, 't, T>(
431 socket: S,
432 target: T,
433 user_id: &'a str,
434 ) -> Result<Socks4Listener<S>>
435 where
436 T: IntoTargetAddr<'t>,
437 {
438 Self::bind_to_target_with_socket(Some(user_id), socket, target).await
439 }
440
441 async fn bind_to_target_with_socket<'a, 't, T>(
442 auth: Option<&'a str>,
443 socket: S,
444 target: T,
445 ) -> Result<Socks4Listener<S>>
446 where
447 T: IntoTargetAddr<'t>,
448 {
449 let socket = Socks4Connector::new(
450 auth,
451 CommandV4::Bind,
452 stream::empty().fuse(),
453 target.into_target_addr()?,
454 )
455 .execute_with_socket(socket)
456 .await?;
457
458 Ok(Socks4Listener { inner: socket })
459 }
460
461 pub fn bind_addr(&self) -> TargetAddr {
466 self.inner.target_addr()
467 }
468
469 pub async fn accept(mut self) -> Result<Socks4Stream<S>> {
475 let mut connector = Socks4Connector {
476 user_id: None,
477 command: CommandV4::Bind,
478 proxy: stream::empty().fuse(),
479 target: self.inner.target,
480 buf: [0; 513],
481 ptr: 0,
482 len: 0,
483 };
484
485 let target = connector.receive_reply(&mut self.inner.socket).await?;
486
487 Ok(Socks4Stream {
488 socket: self.inner.socket,
489 target,
490 })
491 }
492}
493
#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncRead for Socks4Stream<T>
where
    T: tokio::io::AsyncRead + Unpin,
{
    /// Delegates reads directly to the wrapped socket.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        tokio::io::AsyncRead::poll_read(Pin::new(&mut this.socket), cx, buf)
    }
}
506
#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncWrite for Socks4Stream<T>
where
    T: tokio::io::AsyncWrite + Unpin,
{
    /// Delegates writes directly to the wrapped socket.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        tokio::io::AsyncWrite::poll_write(Pin::new(&mut this.socket), cx, buf)
    }

    /// Delegates flushing directly to the wrapped socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        tokio::io::AsyncWrite::poll_flush(Pin::new(&mut this.socket), cx)
    }

    /// Delegates shutdown directly to the wrapped socket.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut this.socket), cx)
    }
}
523
#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncRead for Socks4Stream<T>
where
    T: futures_io::AsyncRead + Unpin,
{
    /// Delegates reads directly to the wrapped socket.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        futures_io::AsyncRead::poll_read(Pin::new(&mut this.socket), cx, buf)
    }
}
532
#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncWrite for Socks4Stream<T>
where
    T: futures_io::AsyncWrite + Unpin,
{
    /// Delegates writes directly to the wrapped socket.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        futures_io::AsyncWrite::poll_write(Pin::new(&mut this.socket), cx, buf)
    }

    /// Delegates flushing directly to the wrapped socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        futures_io::AsyncWrite::poll_flush(Pin::new(&mut this.socket), cx)
    }

    /// Delegates closing directly to the wrapped socket.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        futures_io::AsyncWrite::poll_close(Pin::new(&mut this.socket), cx)
    }
}