1use tungstenite::Error;
3
4use std::io;
5
6use gio::prelude::*;
7
8use futures_io::{AsyncRead, AsyncWrite};
9
10use tungstenite::client::{uri_mode, IntoClientRequest};
11use tungstenite::handshake::client::Request;
12use tungstenite::handshake::server::{Callback, NoCallback};
13use tungstenite::stream::Mode;
14
15use crate::{client_async_with_config, domain, port, Response, WebSocketConfig, WebSocketStream};
16
/// Stream type produced by [`connect_async`] and [`connect_async_with_config`]: a
/// `gio::SocketConnection` wrapped in the [`IOStreamAsyncReadWrite`] adapter.
pub type ConnectStream = IOStreamAsyncReadWrite<gio::SocketConnection>;
19
20pub async fn connect_async<R>(
40 request: R,
41) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
42where
43 R: IntoClientRequest + Unpin,
44{
45 connect_async_with_config(request, None).await
46}
47
48pub async fn connect_async_with_config<R>(
50 request: R,
51 config: Option<WebSocketConfig>,
52) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
53where
54 R: IntoClientRequest + Unpin,
55{
56 let request: Request = request.into_client_request()?;
57
58 let domain = domain(&request)?;
59 let port = port(&request)?;
60
61 let client = gio::SocketClient::new();
62
63 let mode = uri_mode(request.uri())?;
65 if let Mode::Tls = mode {
66 client.set_tls(true);
67 } else {
68 client.set_tls(false);
69 }
70
71 let connectable = gio::NetworkAddress::new(domain.as_str(), port);
72
73 let socket = client
74 .connect_future(&connectable)
75 .await
76 .map_err(to_std_io_error)?;
77 let socket = IOStreamAsyncReadWrite::new(socket)
78 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;
79
80 client_async_with_config(request, socket, config).await
81}
82
83pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
95where
96 S: IsA<gio::IOStream> + Unpin,
97{
98 accept_hdr_async(stream, NoCallback).await
99}
100
101pub async fn accept_async_with_config<S>(
104 stream: S,
105 config: Option<WebSocketConfig>,
106) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
107where
108 S: IsA<gio::IOStream> + Unpin,
109{
110 accept_hdr_async_with_config(stream, NoCallback, config).await
111}
112
113pub async fn accept_hdr_async<S, C>(
119 stream: S,
120 callback: C,
121) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
122where
123 S: IsA<gio::IOStream> + Unpin,
124 C: Callback + Unpin,
125{
126 accept_hdr_async_with_config(stream, callback, None).await
127}
128
129pub async fn accept_hdr_async_with_config<S, C>(
132 stream: S,
133 callback: C,
134 config: Option<WebSocketConfig>,
135) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
136where
137 S: IsA<gio::IOStream> + Unpin,
138 C: Callback + Unpin,
139{
140 let stream = IOStreamAsyncReadWrite::new(stream)
141 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;
142
143 crate::accept_hdr_async_with_config(stream, callback, config).await
144}
145
/// Adapter exposing a `gio::IOStream` as a `futures_io` `AsyncRead` + `AsyncWrite` stream.
///
/// Reads are forwarded to an async reader over the stream's pollable input side. Writes,
/// flushes and closes are forwarded to an async writer over its pollable output side.
#[derive(Debug)]
pub struct IOStreamAsyncReadWrite<T: IsA<gio::IOStream>> {
    // Never read directly (hence the lint allowance). Presumably held so the wrapped
    // stream object lives as long as the adapter — confirm.
    #[allow(dead_code)]
    io_stream: T,
    // Async reader built from the stream's input side cast to `PollableInputStream`.
    read: gio::InputStreamAsyncRead<gio::PollableInputStream>,
    // Async writer built from the stream's output side cast to `PollableOutputStream`.
    write: gio::OutputStreamAsyncWrite<gio::PollableOutputStream>,
}
154
// SAFETY: NOTE(review): this asserts `Send` for every `T: IsA<gio::IOStream>` without
// requiring `T: Send`. It also vouches for the `read`/`write` gio adapter fields. It is
// only sound if moving those gio objects and adapters to another thread is safe. Confirm
// this against gio/glib thread-safety guarantees before relying on it.
unsafe impl<T: IsA<gio::IOStream>> Send for IOStreamAsyncReadWrite<T> {}
156
157impl<T: IsA<gio::IOStream>> IOStreamAsyncReadWrite<T> {
158 fn new(stream: T) -> Result<IOStreamAsyncReadWrite<T>, T> {
160 let write = stream
161 .output_stream()
162 .dynamic_cast::<gio::PollableOutputStream>()
163 .ok()
164 .and_then(|s| s.into_async_write().ok());
165
166 let read = stream
167 .input_stream()
168 .dynamic_cast::<gio::PollableInputStream>()
169 .ok()
170 .and_then(|s| s.into_async_read().ok());
171
172 let (read, write) = match (read, write) {
173 (Some(read), Some(write)) => (read, write),
174 _ => return Err(stream),
175 };
176
177 Ok(IOStreamAsyncReadWrite {
178 io_stream: stream,
179 read,
180 write,
181 })
182 }
183}
184
185use std::pin::Pin;
186use std::task::{Context, Poll};
187
188impl<T: IsA<gio::IOStream> + Unpin> AsyncRead for IOStreamAsyncReadWrite<T> {
189 fn poll_read(
190 self: Pin<&mut Self>,
191 cx: &mut Context<'_>,
192 buf: &mut [u8],
193 ) -> Poll<Result<usize, io::Error>> {
194 Pin::new(&mut Pin::get_mut(self).read).poll_read(cx, buf)
195 }
196}
197
198impl<T: IsA<gio::IOStream> + Unpin> AsyncWrite for IOStreamAsyncReadWrite<T> {
199 fn poll_write(
200 self: Pin<&mut Self>,
201 cx: &mut Context<'_>,
202 buf: &[u8],
203 ) -> Poll<Result<usize, io::Error>> {
204 Pin::new(&mut Pin::get_mut(self).write).poll_write(cx, buf)
205 }
206
207 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
208 Pin::new(&mut Pin::get_mut(self).write).poll_close(cx)
209 }
210
211 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
212 Pin::new(&mut Pin::get_mut(self).write).poll_flush(cx)
213 }
214}
215
216fn to_std_io_error(error: glib::Error) -> io::Error {
217 match error.kind::<gio::IOErrorEnum>() {
218 Some(io_error_enum) => io::Error::new(io_error_enum.into(), error),
219 None => io::Error::new(io::ErrorKind::Other, error),
220 }
221}