1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncWrite};
6use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
7
8use crate::{
9 OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf,
10};
11
12#[derive(Debug, Clone)]
45pub struct TcpListener {
46 inner: Socket,
47}
48
49impl TcpListener {
50 pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
58 super::each_addr(addr, |addr| async move {
59 let sa = SockAddr::from(addr);
60 let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
61 socket.socket.set_reuse_address(true)?;
62 socket.socket.bind(&sa)?;
63 socket.listen(128)?;
64 Ok(Self { inner: socket })
65 })
66 .await
67 }
68
69 pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
71 Ok(Self {
72 inner: Socket::from_socket2(Socket2::from(stream))?,
73 })
74 }
75
76 pub fn close(self) -> impl Future<Output = io::Result<()>> {
79 self.inner.close()
80 }
81
82 pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
88 let (socket, addr) = self.inner.accept().await?;
89 let stream = TcpStream { inner: socket };
90 Ok((stream, addr.as_socket().expect("should be SocketAddr")))
91 }
92
93 pub fn local_addr(&self) -> io::Result<SocketAddr> {
117 self.inner
118 .local_addr()
119 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
120 }
121}
122
123impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
124
125#[derive(Debug, Clone)]
147pub struct TcpStream {
148 inner: Socket,
149}
150
151impl TcpStream {
152 pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
154 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
155
156 super::each_addr(addr, |addr| async move {
157 let addr2 = SockAddr::from(addr);
158 let socket = if cfg!(windows) {
159 let bind_addr = if addr.is_ipv4() {
160 SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
161 } else if addr.is_ipv6() {
162 SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
163 } else {
164 return Err(io::Error::new(
165 io::ErrorKind::AddrNotAvailable,
166 "Unsupported address domain.",
167 ));
168 };
169 Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
170 } else {
171 Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
172 };
173 socket.connect_async(&addr2).await?;
174 Ok(Self { inner: socket })
175 })
176 .await
177 }
178
179 pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
181 Ok(Self {
182 inner: Socket::from_socket2(Socket2::from(stream))?,
183 })
184 }
185
186 pub fn close(self) -> impl Future<Output = io::Result<()>> {
189 self.inner.close()
190 }
191
192 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
194 self.inner
195 .peer_addr()
196 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
197 }
198
199 pub fn local_addr(&self) -> io::Result<SocketAddr> {
201 self.inner
202 .local_addr()
203 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
204 }
205
206 pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
213 crate::split(self)
214 }
215
216 pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
222 crate::into_split(self)
223 }
224
225 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
227 self.inner.to_poll_fd()
228 }
229
230 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
232 self.inner.into_poll_fd()
233 }
234}
235
236impl AsyncRead for TcpStream {
237 #[inline]
238 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
239 (&*self).read(buf).await
240 }
241
242 #[inline]
243 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
244 (&*self).read_vectored(buf).await
245 }
246}
247
248impl AsyncRead for &TcpStream {
249 #[inline]
250 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
251 self.inner.recv(buf).await
252 }
253
254 #[inline]
255 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
256 self.inner.recv_vectored(buf).await
257 }
258}
259
260impl AsyncWrite for TcpStream {
261 #[inline]
262 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
263 (&*self).write(buf).await
264 }
265
266 #[inline]
267 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
268 (&*self).write_vectored(buf).await
269 }
270
271 #[inline]
272 async fn flush(&mut self) -> io::Result<()> {
273 (&*self).flush().await
274 }
275
276 #[inline]
277 async fn shutdown(&mut self) -> io::Result<()> {
278 (&*self).shutdown().await
279 }
280}
281
282impl AsyncWrite for &TcpStream {
283 #[inline]
284 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
285 self.inner.send(buf).await
286 }
287
288 #[inline]
289 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
290 self.inner.send_vectored(buf).await
291 }
292
293 #[inline]
294 async fn flush(&mut self) -> io::Result<()> {
295 Ok(())
296 }
297
298 #[inline]
299 async fn shutdown(&mut self) -> io::Result<()> {
300 self.inner.shutdown().await
301 }
302}
303
304impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);