use crate::{client::RpcClientInner, ClientRef};
use alloy_json_rpc::{
transform_response, try_deserialize_ok, Id, Request, RequestPacket, ResponsePacket, RpcParam,
RpcReturn, SerializedRequest,
};
use alloy_primitives::map::HashMap;
use alloy_transport::{Transport, TransportError, TransportErrorKind, TransportResult};
use futures::FutureExt;
use pin_project::pin_project;
use serde_json::value::RawValue;
use std::{
borrow::Cow,
future::{Future, IntoFuture},
marker::PhantomData,
pin::Pin,
task::{
self, ready,
Poll::{self, Ready},
},
};
use tokio::sync::oneshot;
/// Sending half of the one-shot channel that delivers one batched request's result (its raw
/// JSON response, or the error for it) to the matching [`Waiter`].
pub(crate) type Channel = oneshot::Sender<TransportResult<Box<RawValue>>>;
/// The open [`Channel`]s of a batch, keyed by the id of the request each one answers.
pub(crate) type ChannelMap = HashMap<Id, Channel>;
/// A set of JSON-RPC requests that are sent to the transport together as one batch packet.
///
/// Calls are queued with [`BatchRequest::add_call`], each of which returns a [`Waiter`] for
/// that call's individual response. The batch is dispatched with [`BatchRequest::send`] (or by
/// `.await`ing it directly). The waiters are only fed while the resulting [`BatchFuture`] is
/// being polled.
// NOTE(review): the `must_use` message mentions `send_batch`, which is not defined in this
// file; the dispatch method here is `send` — confirm whether the message is stale.
#[derive(Debug)]
#[must_use = "A BatchRequest does nothing unless sent via `send_batch` and `.await`"]
pub struct BatchRequest<'a, T> {
    /// The client this batch was created from. It is used to build requests and, on send,
    /// to obtain a clone of its transport.
    transport: ClientRef<'a, T>,
    /// The serialized requests queued so far.
    requests: RequestPacket,
    /// One response channel per queued request, keyed by that request's id.
    channels: ChannelMap,
}
/// Future for the response to a single request queued in a [`BatchRequest`].
///
/// It resolves once the batch's [`BatchFuture`] sends this request's result on the channel,
/// or drops that channel. The raw JSON is deserialized into `Resp` and then passed through
/// `Map` to produce `Output`.
#[must_use = "A Waiter does nothing unless the corresponding BatchRequest is sent via `send_batch` and `.await`, AND the Waiter is awaited."]
#[pin_project]
#[derive(Debug)]
pub struct Waiter<Resp, Output = Resp, Map = fn(Resp) -> Output> {
    /// Receiving half of the channel on which the batch future sends this request's result.
    #[pin]
    rx: oneshot::Receiver<TransportResult<Box<RawValue>>>,
    /// Post-processing applied to the deserialized response. It is held in an `Option` so the
    /// `FnOnce` can be taken out by value when the response arrives.
    map: Option<Map>,
    /// Marks `Resp` and `Output` as used without storing values of either type.
    // NOTE(review): the `fn() -> _` form presumably keeps this marker from affecting
    // `Send`/`Sync` and drop-check for `Resp`/`Output` — confirm.
    _resp: PhantomData<fn() -> (Output, Resp)>,
}
impl<Resp, Output, Map> Waiter<Resp, Output, Map> {
    /// Swap this waiter's post-processing for `map`.
    ///
    /// `map` is applied to the deserialized `Resp` when the response arrives, and its result
    /// becomes the waiter's new output. Any previously configured map is discarded.
    pub fn map_resp<NewOutput, NewMap>(self, map: NewMap) -> Waiter<Resp, NewOutput, NewMap>
    where
        NewMap: FnOnce(Resp) -> NewOutput,
    {
        let Self { rx, .. } = self;
        Waiter { rx, map: Some(map), _resp: PhantomData }
    }
}
impl<Resp> From<oneshot::Receiver<TransportResult<Box<RawValue>>>> for Waiter<Resp> {
    /// Wrap a response receiver in a waiter that yields the deserialized response unchanged.
    fn from(rx: oneshot::Receiver<TransportResult<Box<RawValue>>>) -> Self {
        // The default `Map` type parameter is a plain fn pointer, so coerce `identity` to it.
        let passthrough: fn(Resp) -> Resp = std::convert::identity;
        Self { rx, map: Some(passthrough), _resp: PhantomData }
    }
}
/// Resolves to this request's deserialized response, passed through the waiter's map.
///
/// If the channel's sender is dropped before a result is sent, the waiter resolves to a
/// custom transport error wrapping the receive error. That happens, for example, when the
/// batch future fails or is dropped.
impl<Resp, Output, Map> std::future::Future for Waiter<Resp, Output, Map>
where
    Resp: RpcReturn,
    Map: FnOnce(Resp) -> Output,
{
    type Output = TransportResult<Output>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let waiter = self.get_mut();
        let raw = match ready!(waiter.rx.poll_unpin(cx)) {
            Ok(raw) => raw,
            Err(closed) => return Poll::Ready(Err(TransportErrorKind::custom(closed))),
        };
        let decoded: Result<Resp, _> = try_deserialize_ok(raw);
        let map = waiter.map.take().expect("polled after completion");
        Ready(decoded.map(map))
    }
}
/// Future returned by [`BatchRequest::send`]. It dispatches the batch and feeds each queued
/// [`Waiter`] its response.
///
/// This is a state machine advanced by its [`Future`] implementation.
#[pin_project::pin_project(project = CallStateProj)]
#[derive(Debug)]
// NOTE(review): presumably allowed because this type (or its pin_project-generated
// projection) is reachable from the public API without being nameable there — confirm.
#[allow(unnameable_types)]
pub enum BatchFuture<Conn: Transport> {
    /// The batch is built but not yet handed to the transport.
    Prepared {
        /// Transport that will carry the batch.
        transport: Conn,
        /// Serialized requests to send.
        requests: RequestPacket,
        /// Response channels, keyed by request id.
        channels: ChannelMap,
    },
    /// Holds an error that is returned on the next poll. The `Option` lets the error be
    /// moved out through the pinned projection.
    SerError(Option<TransportError>),
    /// The batch has been sent; waiting for the transport's response.
    AwaitingResponse {
        /// Response channels, keyed by request id.
        channels: ChannelMap,
        /// The transport's in-flight response future.
        #[pin]
        fut: Conn::Future,
    },
    /// Terminal state. Polling again panics.
    Complete,
}
impl<'a, T> BatchRequest<'a, T> {
pub fn new(transport: &'a RpcClientInner<T>) -> Self {
Self {
transport,
requests: RequestPacket::Batch(Vec::with_capacity(10)),
channels: HashMap::with_capacity_and_hasher(10, Default::default()),
}
}
fn push_raw(
&mut self,
request: SerializedRequest,
) -> oneshot::Receiver<TransportResult<Box<RawValue>>> {
let (tx, rx) = oneshot::channel();
self.channels.insert(request.id().clone(), tx);
self.requests.push(request);
rx
}
fn push<Params: RpcParam, Resp: RpcReturn>(
&mut self,
request: Request<Params>,
) -> TransportResult<Waiter<Resp>> {
let ser = request.serialize().map_err(TransportError::ser_err)?;
Ok(self.push_raw(ser).into())
}
}
impl<Conn> BatchRequest<'_, Conn>
where
    Conn: Transport + Clone,
{
    /// Queue a call to `method` with `params` in this batch.
    ///
    /// The request is built by the batch's client and serialized immediately. The returned
    /// [`Waiter`] resolves to this call's response once the batch has been sent.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be serialized. Nothing is queued in that case.
    pub fn add_call<Params: RpcParam, Resp: RpcReturn>(
        &mut self,
        method: impl Into<Cow<'static, str>>,
        params: &Params,
    ) -> TransportResult<Waiter<Resp>> {
        let borrowed_params = Cow::Borrowed(params);
        let call = self.transport.make_request(method, borrowed_params);
        self.push(call)
    }

    /// Consume the batch and produce the [`BatchFuture`] that dispatches it.
    ///
    /// The queued [`Waiter`]s are only fed while the returned future is polled.
    pub fn send(self) -> BatchFuture<Conn> {
        let Self { transport: client, requests, channels } = self;
        let transport = client.transport.clone();
        BatchFuture::Prepared { transport, requests, channels }
    }
}
impl<T> IntoFuture for BatchRequest<'_, T>
where
T: Transport + Clone,
{
type Output = <BatchFuture<T> as Future>::Output;
type IntoFuture = BatchFuture<T>;
fn into_future(self) -> Self::IntoFuture {
self.send()
}
}
impl<T> BatchFuture<T>
where
    T: Transport + Clone,
{
    /// Handle the [`BatchFuture::Prepared`] state: wait for the transport to become ready,
    /// then dispatch the batch.
    ///
    /// If readiness fails, the future moves to [`BatchFuture::Complete`] and returns the
    /// error. Dropping the channel map along with the old state closes every waiter's channel.
    ///
    /// On success, the future moves to [`BatchFuture::AwaitingResponse`], wakes itself and
    /// returns `Pending`. The response future is therefore first polled on the next call to
    /// `poll`.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if called in any other state.
    fn poll_prepared(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<<Self as Future>::Output> {
        let CallStateProj::Prepared { transport, requests, channels } = self.as_mut().project()
        else {
            unreachable!("Called poll_prepared in incorrect state")
        };
        if let Err(e) = task::ready!(transport.poll_ready(cx)) {
            self.set(Self::Complete);
            return Poll::Ready(Err(e));
        }
        // The projection only gives `&mut` access to the fields, so move them out by leaving
        // empty placeholders. The whole variant is overwritten by `self.set` just below.
        let channels = std::mem::take(channels);
        let req = std::mem::replace(requests, RequestPacket::Batch(Vec::new()));
        let fut = transport.call(req);
        self.set(Self::AwaitingResponse { channels, fut });
        // The new response future has not been polled yet, so nothing else will wake this
        // task. Request another poll so that it gets driven.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
    /// Handle the [`BatchFuture::AwaitingResponse`] state: wait for the transport's response
    /// and distribute it to the waiters.
    ///
    /// Each response is passed through `transform_response` and sent to the waiter
    /// registered under the same id. Responses with an unknown id are ignored. Every waiter
    /// still left without a response then receives a `missing_batch_response` error.
    ///
    /// If the transport call itself fails, the error is returned instead. Dropping the
    /// channel map then closes every waiter's channel.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if called in any other state.
    fn poll_awaiting_response(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<<Self as Future>::Output> {
        let CallStateProj::AwaitingResponse { channels, fut } = self.as_mut().project() else {
            unreachable!("Called poll_awaiting_response in incorrect state")
        };
        let responses = match ready!(fut.poll(cx)) {
            Ok(responses) => responses,
            Err(e) => {
                self.set(Self::Complete);
                return Poll::Ready(Err(e));
            }
        };
        // Send results are ignored on purpose: a send only fails when the receiving Waiter
        // has already been dropped.
        match responses {
            // The transport may answer with a single response object rather than an array.
            ResponsePacket::Single(single) => {
                if let Some(tx) = channels.remove(&single.id) {
                    let _ = tx.send(transform_response(single));
                }
            }
            ResponsePacket::Batch(responses) => {
                for response in responses {
                    if let Some(tx) = channels.remove(&response.id) {
                        let _ = tx.send(transform_response(response));
                    }
                }
            }
        }
        // Channels still registered never received a response. Give each one a specific
        // error rather than letting it fail with a generic closed-channel error once the
        // map is dropped.
        for (id, tx) in channels.drain() {
            let _ = tx.send(Err(TransportErrorKind::missing_batch_response(id)));
        }
        self.set(Self::Complete);
        Poll::Ready(Ok(()))
    }
    /// Handle the [`BatchFuture::SerError`] state: return the stored error and complete.
    ///
    /// # Panics
    ///
    /// Panics if the stored error was already taken, or (via `unreachable!`) if called in any
    /// other state.
    fn poll_ser_error(
        mut self: Pin<&mut Self>,
        _cx: &mut task::Context<'_>,
    ) -> Poll<<Self as Future>::Output> {
        let e = if let CallStateProj::SerError(e) = self.as_mut().project() {
            e.take().expect("no error")
        } else {
            unreachable!("Called poll_ser_error in incorrect state")
        };
        self.set(Self::Complete);
        Poll::Ready(Err(e))
    }
}
/// Drives the batch to completion.
///
/// Resolves to `Ok(())` once every queued [`Waiter`] has been sent either its response or a
/// missing-response error. It resolves to `Err` if the transport fails (when becoming ready,
/// or while waiting for the response), or with the error held by the `SerError` state.
///
/// # Panics
///
/// Panics if polled again after it has returned [`Poll::Ready`].
impl<T> Future for BatchFuture<T>
where
    T: Transport + Clone,
{
    type Output = TransportResult<()>;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        // Read the state once and dispatch. Matching exhaustively (rather than chaining
        // `matches!` checks) makes the compiler flag this spot if a new state is ever added.
        // The arms bind nothing, so the shared borrow ends before `self` is moved into the
        // state handler.
        match &*self {
            Self::Prepared { .. } => self.poll_prepared(cx),
            Self::AwaitingResponse { .. } => self.poll_awaiting_response(cx),
            Self::SerError(_) => self.poll_ser_error(cx),
            Self::Complete => panic!("BatchFuture polled after completion"),
        }
    }
}