use std::collections::VecDeque;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::ready;
use futures::stream::Stream;
use futures::task::{Context, Poll};
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
use crate::error::CdpError;
use crate::error::Result;
/// Transport type underneath the websocket: tokio-tungstenite's
/// `MaybeTlsStream` wrapping a tokio `TcpStream`.
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
/// Websocket connection that queues outgoing commands and, when polled as a
/// stream, yields the decoded incoming messages.
///
/// `T` is the event type carried by the [`Message`] items the stream produces.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
/// Commands queued by `submit_command` that have not yet been handed to the
/// websocket sink.
pending_commands: VecDeque<MethodCall>,
/// The websocket that commands are written to and messages are read from.
ws: WebSocketStream<ConnectStream>,
/// Value used to build the next `CallId`; advanced with wrapping arithmetic.
next_id: usize,
/// Set once the sink has reported ready for the in-flight command; cleared
/// after a successful flush of the websocket.
needs_flush: bool,
/// The command most recently passed to the sink, kept until the sink
/// reports ready again.
pending_flush: Option<MethodCall>,
/// Ties the otherwise unused type parameter `T` to the struct.
_marker: PhantomData<T>,
}
impl<T: EventMessage + Unpin> Connection<T> {
pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
let config = WebSocketConfig {
max_message_size: None,
max_frame_size: None,
..Default::default()
};
let (ws, _) =
tokio_tungstenite::connect_async_with_config(debug_ws_url.as_ref(), Some(config), true)
.await?;
Ok(Self {
pending_commands: Default::default(),
ws,
next_id: 0,
needs_flush: false,
pending_flush: None,
_marker: Default::default(),
})
}
}
impl<T: EventMessage> Connection<T> {
fn next_call_id(&mut self) -> CallId {
let id = CallId::new(self.next_id);
self.next_id = self.next_id.wrapping_add(1);
id
}
pub fn submit_command(
&mut self,
method: MethodId,
session_id: Option<SessionId>,
params: serde_json::Value,
) -> serde_json::Result<CallId> {
let id = self.next_call_id();
let call = MethodCall {
id,
method,
session_id: session_id.map(Into::into),
params,
};
self.pending_commands.push_back(call);
Ok(id)
}
fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
if self.needs_flush {
if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
self.needs_flush = false;
}
}
if self.pending_flush.is_none() && !self.needs_flush {
if let Some(cmd) = self.pending_commands.pop_front() {
tracing::trace!("Sending {:?}", cmd);
let msg = serde_json::to_string(&cmd)?;
self.ws.start_send_unpin(msg.into())?;
self.pending_flush = Some(cmd);
}
}
Ok(())
}
}
impl<T: EventMessage + Unpin> Stream for Connection<T> {
type Item = Result<Message<T>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let pin = self.get_mut();
loop {
if let Err(err) = pin.start_send_next(cx) {
return Poll::Ready(Some(Err(err)));
}
if let Some(call) = pin.pending_flush.take() {
if pin.ws.poll_ready_unpin(cx).is_ready() {
pin.needs_flush = true;
continue;
} else {
pin.pending_flush = Some(call);
}
}
break;
}
match ready!(pin.ws.poll_next_unpin(cx)) {
Some(Ok(WsMessage::Text(text))) => {
let ready = match serde_json::from_str::<Message<T>>(&text) {
Ok(msg) => {
tracing::trace!("Received {:?}", msg);
Ok(msg)
}
Err(err) => {
tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", msg = text, "Failed to parse raw WS message {err}");
Err(err.into())
}
};
Poll::Ready(Some(ready))
}
Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Some(Ok(msg)) => Poll::Ready(Some(Err(CdpError::UnexpectedWsMessage(msg)))),
Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
None => {
Poll::Ready(None)
}
}
}
}