async_tungstenite/
gio.rs

1//! `gio` integration.
2use tungstenite::Error;
3
4use std::io;
5
6use gio::prelude::*;
7
8use futures_io::{AsyncRead, AsyncWrite};
9
10use tungstenite::client::{uri_mode, IntoClientRequest};
11use tungstenite::handshake::client::Request;
12use tungstenite::handshake::server::{Callback, NoCallback};
13use tungstenite::stream::Mode;
14
15use crate::{client_async_with_config, domain, port, Response, WebSocketConfig, WebSocketStream};
16
17/// Type alias for the stream type of the `connect_async()` functions.
18pub type ConnectStream = IOStreamAsyncReadWrite<gio::SocketConnection>;
19
20/// Connect to a given URL.
21///
22/// Accepts any request that implements [`IntoClientRequest`], which is often just `&str`, but can
23/// be a variety of types such as `httparse::Request` or [`tungstenite::http::Request`] for more
24/// complex uses.
25///
26/// ```no_run
27/// # use tungstenite::client::IntoClientRequest;
28///
29/// # async fn test() {
30/// use tungstenite::http::{Method, Request};
31/// use async_tungstenite::gio::connect_async;
32///
33/// let mut request = "wss://api.example.com".into_client_request().unwrap();
34/// request.headers_mut().insert("api-key", "42".parse().unwrap());
35///
36/// let (stream, response) = connect_async(request).await.unwrap();
37/// # }
38/// ```
39pub async fn connect_async<R>(
40    request: R,
41) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
42where
43    R: IntoClientRequest + Unpin,
44{
45    connect_async_with_config(request, None).await
46}
47
48/// Connect to a given URL with a given WebSocket configuration.
49pub async fn connect_async_with_config<R>(
50    request: R,
51    config: Option<WebSocketConfig>,
52) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
53where
54    R: IntoClientRequest + Unpin,
55{
56    let request: Request = request.into_client_request()?;
57
58    let domain = domain(&request)?;
59    let port = port(&request)?;
60
61    let client = gio::SocketClient::new();
62
63    // Make sure we check domain and mode first. URL must be valid.
64    let mode = uri_mode(request.uri())?;
65    if let Mode::Tls = mode {
66        client.set_tls(true);
67    } else {
68        client.set_tls(false);
69    }
70
71    let connectable = gio::NetworkAddress::new(domain.as_str(), port);
72
73    let socket = client
74        .connect_future(&connectable)
75        .await
76        .map_err(to_std_io_error)?;
77    let socket = IOStreamAsyncReadWrite::new(socket)
78        .map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;
79
80    client_async_with_config(request, socket, config).await
81}
82
83/// Accepts a new WebSocket connection with the provided stream.
84///
85/// This function will internally call `server::accept` to create a
86/// handshake representation and returns a future representing the
87/// resolution of the WebSocket handshake. The returned future will resolve
88/// to either `WebSocketStream<S>` or `Error` depending if it's successful
89/// or not.
90///
91/// This is typically used after a socket has been accepted from a
92/// `TcpListener`. That socket is then passed to this function to perform
93/// the server half of the accepting a client's websocket connection.
94pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
95where
96    S: IsA<gio::IOStream> + Unpin,
97{
98    accept_hdr_async(stream, NoCallback).await
99}
100
101/// The same as `accept_async()` but the one can specify a websocket configuration.
102/// Please refer to `accept_async()` for more details.
103pub async fn accept_async_with_config<S>(
104    stream: S,
105    config: Option<WebSocketConfig>,
106) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
107where
108    S: IsA<gio::IOStream> + Unpin,
109{
110    accept_hdr_async_with_config(stream, NoCallback, config).await
111}
112
113/// Accepts a new WebSocket connection with the provided stream.
114///
115/// This function does the same as `accept_async()` but accepts an extra callback
116/// for header processing. The callback receives headers of the incoming
117/// requests and is able to add extra headers to the reply.
118pub async fn accept_hdr_async<S, C>(
119    stream: S,
120    callback: C,
121) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
122where
123    S: IsA<gio::IOStream> + Unpin,
124    C: Callback + Unpin,
125{
126    accept_hdr_async_with_config(stream, callback, None).await
127}
128
129/// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
130/// Please refer to `accept_hdr_async()` for more details.
131pub async fn accept_hdr_async_with_config<S, C>(
132    stream: S,
133    callback: C,
134    config: Option<WebSocketConfig>,
135) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
136where
137    S: IsA<gio::IOStream> + Unpin,
138    C: Callback + Unpin,
139{
140    let stream = IOStreamAsyncReadWrite::new(stream)
141        .map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;
142
143    crate::accept_hdr_async_with_config(stream, callback, config).await
144}
145
146/// Adapter for `gio::IOStream` to provide `AsyncRead` and `AsyncWrite`.
147#[derive(Debug)]
148pub struct IOStreamAsyncReadWrite<T: IsA<gio::IOStream>> {
149    #[allow(dead_code)]
150    io_stream: T,
151    read: gio::InputStreamAsyncRead<gio::PollableInputStream>,
152    write: gio::OutputStreamAsyncWrite<gio::PollableOutputStream>,
153}
154
155unsafe impl<T: IsA<gio::IOStream>> Send for IOStreamAsyncReadWrite<T> {}
156
157impl<T: IsA<gio::IOStream>> IOStreamAsyncReadWrite<T> {
158    /// Create a new `gio::IOStream` adapter
159    fn new(stream: T) -> Result<IOStreamAsyncReadWrite<T>, T> {
160        let write = stream
161            .output_stream()
162            .dynamic_cast::<gio::PollableOutputStream>()
163            .ok()
164            .and_then(|s| s.into_async_write().ok());
165
166        let read = stream
167            .input_stream()
168            .dynamic_cast::<gio::PollableInputStream>()
169            .ok()
170            .and_then(|s| s.into_async_read().ok());
171
172        let (read, write) = match (read, write) {
173            (Some(read), Some(write)) => (read, write),
174            _ => return Err(stream),
175        };
176
177        Ok(IOStreamAsyncReadWrite {
178            io_stream: stream,
179            read,
180            write,
181        })
182    }
183}
184
185use std::pin::Pin;
186use std::task::{Context, Poll};
187
188impl<T: IsA<gio::IOStream> + Unpin> AsyncRead for IOStreamAsyncReadWrite<T> {
189    fn poll_read(
190        self: Pin<&mut Self>,
191        cx: &mut Context<'_>,
192        buf: &mut [u8],
193    ) -> Poll<Result<usize, io::Error>> {
194        Pin::new(&mut Pin::get_mut(self).read).poll_read(cx, buf)
195    }
196}
197
198impl<T: IsA<gio::IOStream> + Unpin> AsyncWrite for IOStreamAsyncReadWrite<T> {
199    fn poll_write(
200        self: Pin<&mut Self>,
201        cx: &mut Context<'_>,
202        buf: &[u8],
203    ) -> Poll<Result<usize, io::Error>> {
204        Pin::new(&mut Pin::get_mut(self).write).poll_write(cx, buf)
205    }
206
207    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
208        Pin::new(&mut Pin::get_mut(self).write).poll_close(cx)
209    }
210
211    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
212        Pin::new(&mut Pin::get_mut(self).write).poll_flush(cx)
213    }
214}
215
216fn to_std_io_error(error: glib::Error) -> io::Error {
217    match error.kind::<gio::IOErrorEnum>() {
218        Some(io_error_enum) => io::Error::new(io_error_enum.into(), error),
219        None => io::Error::new(io::ErrorKind::Other, error),
220    }
221}