use tungstenite::Error;
use std::io;
use gio::prelude::*;
use futures_io::{AsyncRead, AsyncWrite};
use tungstenite::client::{uri_mode, IntoClientRequest};
use tungstenite::handshake::client::Request;
use tungstenite::handshake::server::{Callback, NoCallback};
use tungstenite::stream::Mode;
use crate::{client_async_with_config, domain, port, Response, WebSocketConfig, WebSocketStream};
/// Stream type returned by `connect_async` and `connect_async_with_config`: a
/// `gio::SocketConnection` wrapped in the `IOStreamAsyncReadWrite` adapter.
pub type ConnectStream = IOStreamAsyncReadWrite<gio::SocketConnection>;
pub async fn connect_async<R>(
request: R,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
R: IntoClientRequest + Unpin,
{
connect_async_with_config(request, None).await
}
/// Connects to the server described by `request` and runs the client side of the
/// WebSocket setup, forwarding `config` to `client_async_with_config`.
///
/// The request is converted via [`IntoClientRequest`], the `domain` and `port` helpers
/// are applied to it, and a [`gio::SocketClient`] opens the connection. TLS is enabled
/// on that client exactly when `uri_mode` yields [`Mode::Tls`].
///
/// # Errors
///
/// Returns an error when the request conversion fails, when the `domain`, `port` or
/// `uri_mode` helpers return an error, when the gio connection attempt fails (converted
/// to an `io::Error`), when the resulting connection cannot be wrapped by
/// `IOStreamAsyncReadWrite` (`"Unsupported gio::IOStream"`), or when
/// `client_async_with_config` fails.
pub async fn connect_async_with_config<R>(
    request: R,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    // Distinct local names avoid shadowing the `domain` / `port` helper functions.
    let host = domain(&request)?;
    let port_number = port(&request)?;

    let client = gio::SocketClient::new();

    // The URI mode is resolved (and may fail) before any connection attempt is made.
    let use_tls = matches!(uri_mode(request.uri())?, Mode::Tls);
    client.set_tls(use_tls);

    let address = gio::NetworkAddress::new(host.as_str(), port_number);
    let connection = client
        .connect_future(&address)
        .await
        .map_err(to_std_io_error)?;

    // `new` hands the connection back on failure; it is discarded and replaced by an
    // I/O error.
    let stream = IOStreamAsyncReadWrite::new(connection)
        .map_err(|_rejected| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;

    client_async_with_config(request, stream, config).await
}
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
where
S: IsA<gio::IOStream> + Unpin,
{
accept_hdr_async(stream, NoCallback).await
}
/// Accepts a connection on `stream` using [`NoCallback`] and the given `config`.
///
/// Forwards to [`accept_hdr_async_with_config`] with [`NoCallback`] as the callback.
///
/// # Errors
///
/// Propagates any error returned by [`accept_hdr_async_with_config`].
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
where
    S: IsA<gio::IOStream> + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
pub async fn accept_hdr_async<S, C>(
stream: S,
callback: C,
) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
where
S: IsA<gio::IOStream> + Unpin,
C: Callback + Unpin,
{
accept_hdr_async_with_config(stream, callback, None).await
}
pub async fn accept_hdr_async_with_config<S, C>(
stream: S,
callback: C,
config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<IOStreamAsyncReadWrite<S>>, Error>
where
S: IsA<gio::IOStream> + Unpin,
C: Callback + Unpin,
{
let stream = IOStreamAsyncReadWrite::new(stream)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Unsupported gio::IOStream"))?;
crate::accept_hdr_async_with_config(stream, callback, config).await
}
/// Adapter that exposes a `gio::IOStream` through the `AsyncRead` and `AsyncWrite`
/// traits by delegating to async wrappers around its input and output streams.
#[derive(Debug)]
pub struct IOStreamAsyncReadWrite<T>
where
    T: IsA<gio::IOStream>,
{
    // Stored by `new` but never read afterwards in this file, hence the lint allowance.
    // NOTE(review): presumably held so the parent stream outlives both halves — confirm.
    #[allow(dead_code)]
    io_stream: T,
    /// Async reader built from the stream's pollable input stream.
    read: gio::InputStreamAsyncRead<gio::PollableInputStream>,
    /// Async writer built from the stream's pollable output stream.
    write: gio::OutputStreamAsyncWrite<gio::PollableOutputStream>,
}
// SAFETY — NOTE(review): this asserts that the adapter may be moved to another thread
// for every `T: IsA<gio::IOStream>`. Nothing in this file establishes that the wrapped
// gio objects are safe to send across threads; confirm against gio's threading
// guarantees for the stream types actually used.
unsafe impl<T: IsA<gio::IOStream>> Send for IOStreamAsyncReadWrite<T> {}
impl<T: IsA<gio::IOStream>> IOStreamAsyncReadWrite<T> {
    /// Builds the adapter from a gio stream.
    ///
    /// The stream's output and input streams are each cast to their pollable variant
    /// and converted into an async writer / reader.
    ///
    /// # Errors
    ///
    /// Hands `stream` back unchanged as `Err` when either half fails the cast or the
    /// conversion.
    fn new(stream: T) -> Result<Self, T> {
        let writer = stream
            .output_stream()
            .dynamic_cast::<gio::PollableOutputStream>()
            .ok()
            .and_then(|pollable| pollable.into_async_write().ok());

        let reader = stream
            .input_stream()
            .dynamic_cast::<gio::PollableInputStream>()
            .ok()
            .and_then(|pollable| pollable.into_async_read().ok());

        // Both halves are required; a single missing half rejects the whole stream.
        match (reader, writer) {
            (Some(read), Some(write)) => Ok(Self {
                io_stream: stream,
                read,
                write,
            }),
            _ => Err(stream),
        }
    }
}
use std::pin::Pin;
use std::task::{Context, Poll};
/// Reads are delegated to the inner `read` half.
impl<T: IsA<gio::IOStream> + Unpin> AsyncRead for IOStreamAsyncReadWrite<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Unpin the adapter to reach the field, then re-pin the reader for the call.
        let this = self.get_mut();
        Pin::new(&mut this.read).poll_read(cx, buf)
    }
}
/// Writes, flushes and closes are all delegated to the inner `write` half.
impl<T: IsA<gio::IOStream> + Unpin> AsyncWrite for IOStreamAsyncReadWrite<T> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Unpin the adapter to reach the field, then re-pin the writer for the call.
        let this = self.get_mut();
        Pin::new(&mut this.write).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.write).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.write).poll_close(cx)
    }
}
/// Converts a `glib::Error` into an `io::Error`, keeping the original error as payload.
///
/// When the error carries a `gio::IOErrorEnum`, that value is converted into the
/// `io::ErrorKind`; otherwise `io::ErrorKind::Other` is used.
fn to_std_io_error(error: glib::Error) -> io::Error {
    let kind = error
        .kind::<gio::IOErrorEnum>()
        .map_or(io::ErrorKind::Other, |gio_kind| gio_kind.into());
    io::Error::new(kind, error)
}