use crate::{ErrorPayload, Id, Response, SerializedRequest};
use alloy_primitives::map::HashSet;
use serde::{
de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
Deserialize, Serialize,
};
use serde_json::value::RawValue;
use std::{fmt, marker::PhantomData};
/// A JSON-RPC request packet: either one serialized request or a batch of them.
///
/// When serialized via `Serialize`, `Single` is emitted as the request itself and
/// `Batch` as a sequence of requests.
#[derive(Debug, Clone)]
pub enum RequestPacket {
    /// One request, sent on its own.
    Single(SerializedRequest),
    /// Zero or more requests, sent together as a batch.
    Batch(Vec<SerializedRequest>),
}
/// Collects requests into a [`RequestPacket::Batch`].
///
/// The result is always a batch, even when the iterator yields exactly one
/// request or none at all.
impl FromIterator<SerializedRequest> for RequestPacket {
    fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
        let requests: Vec<SerializedRequest> = iter.into_iter().collect();
        Self::Batch(requests)
    }
}
impl From<SerializedRequest> for RequestPacket {
fn from(req: SerializedRequest) -> Self {
Self::Single(req)
}
}
impl Serialize for RequestPacket {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Self::Single(single) => single.serialize(serializer),
Self::Batch(batch) => batch.serialize(serializer),
}
}
}
impl RequestPacket {
    /// Creates an empty [`RequestPacket::Batch`] with room for `capacity` requests.
    pub fn with_capacity(capacity: usize) -> Self {
        let requests = Vec::with_capacity(capacity);
        Self::Batch(requests)
    }

    /// Consumes the packet and returns its serialized JSON form.
    ///
    /// A single request yields its stored body via `take_request`; a batch is
    /// serialized with `serde_json::value::to_raw_value`.
    ///
    /// # Errors
    ///
    /// Returns an error only if serializing a batch fails.
    pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
        match self {
            Self::Batch(requests) => serde_json::value::to_raw_value(&requests),
            Self::Single(request) => Ok(request.take_request()),
        }
    }

    /// Returns the IDs of every `eth_subscribe` request in this packet.
    pub fn subscription_request_ids(&self) -> HashSet<&Id> {
        // View both variants uniformly as a slice of requests.
        let requests: &[SerializedRequest] = match self {
            Self::Single(request) => std::slice::from_ref(request),
            Self::Batch(requests) => requests,
        };
        requests
            .iter()
            .filter(|request| request.method() == "eth_subscribe")
            .map(|request| request.id())
            .collect()
    }

    /// Returns the number of requests in the packet (always `1` for `Single`).
    pub fn len(&self) -> usize {
        match self {
            Self::Batch(requests) => requests.len(),
            Self::Single(_) => 1,
        }
    }

    /// Returns `true` only for an empty batch; a `Single` packet is never empty.
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Batch(requests) if requests.is_empty())
    }

    /// Appends `req` to the packet.
    ///
    /// A `Single` packet is first promoted to a `Batch` holding the original
    /// request, so the result is `[original, req]`.
    pub fn push(&mut self, req: SerializedRequest) {
        match self {
            Self::Batch(requests) => requests.push(req),
            Self::Single(_) => {
                let mut requests = Vec::with_capacity(10);
                // Swap in a non-allocating placeholder so the lone request can be
                // moved out by value; the placeholder is overwritten below.
                if let Self::Single(first) = std::mem::replace(self, Self::Batch(Vec::new())) {
                    requests.push(first);
                }
                requests.push(req);
                *self = Self::Batch(requests);
            }
        }
    }
}
/// A JSON-RPC response packet: either one response or a batch of responses.
///
/// Both type parameters default to `Box<RawValue>`, i.e. the payload and the
/// error data are kept as unparsed raw JSON unless other types are chosen.
#[derive(Debug, Clone)]
pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
    /// A lone response.
    Single(Response<Payload, ErrData>),
    /// A batch of responses.
    Batch(Vec<Response<Payload, ErrData>>),
}
/// Collects responses into a packet.
///
/// Exactly one response yields [`ResponsePacket::Single`]; zero responses, or two
/// or more, yield [`ResponsePacket::Batch`].
impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
    for ResponsePacket<Payload, ErrData>
{
    fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
        let mut responses = iter.into_iter();
        let first = match responses.next() {
            Some(first) => first,
            None => return Self::Batch(vec![]),
        };
        // Only pull a second element once we know there was a first one.
        match responses.next() {
            None => Self::Single(first),
            Some(second) => {
                let mut batch = vec![first, second];
                batch.extend(responses);
                Self::Batch(batch)
            }
        }
    }
}
/// Converts a vector of responses into a packet.
///
/// A one-element vector becomes [`ResponsePacket::Single`]; any other length,
/// including zero, becomes [`ResponsePacket::Batch`] holding the vector as-is.
impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
    fn from(mut value: Vec<Response<Payload, ErrData>>) -> Self {
        if value.len() != 1 {
            return Self::Batch(value);
        }
        Self::Single(value.pop().expect("length was checked to be exactly one"))
    }
}
impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
where
Payload: Deserialize<'de>,
ErrData: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ResponsePacketVisitor<Payload, ErrData> {
marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
}
impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
where
Payload: Deserialize<'de>,
ErrData: Deserialize<'de>,
{
type Value = ResponsePacket<Payload, ErrData>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a single response or a batch of responses")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut responses = Vec::new();
while let Some(response) = seq.next_element()? {
responses.push(response);
}
Ok(ResponsePacket::Batch(responses))
}
fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let response =
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
Ok(ResponsePacket::Single(response))
}
}
deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
}
}
/// A [`ResponsePacket`] whose payloads and error data are borrowed `RawValue`s
/// rather than owned ones.
pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;

impl BorrowedResponsePacket<'_> {
    /// Converts into an owned [`ResponsePacket`] by calling `Response::into_owned`
    /// on every contained response, keeping the `Single`/`Batch` shape.
    pub fn into_owned(self) -> ResponsePacket {
        match self {
            Self::Batch(responses) => {
                let owned: Vec<_> = responses.into_iter().map(Response::into_owned).collect();
                ResponsePacket::Batch(owned)
            }
            Self::Single(response) => ResponsePacket::Single(response.into_owned()),
        }
    }
}
impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
    /// Views the packet as a slice of responses (length one for `Single`).
    fn responses_slice(&self) -> &[Response<Payload, ErrData>] {
        match self {
            Self::Single(response) => std::slice::from_ref(response),
            Self::Batch(responses) => responses,
        }
    }

    /// Returns `true` if every response in the packet is successful.
    ///
    /// An empty batch is considered successful.
    pub fn is_success(&self) -> bool {
        self.responses_slice().iter().all(|response| response.is_success())
    }

    /// Returns `true` if at least one response in the packet is an error.
    pub fn is_error(&self) -> bool {
        self.responses_slice().iter().any(|response| response.is_error())
    }

    /// Returns the first error payload in the packet, if any.
    pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
        let mut errors = self.iter_errors();
        errors.next()
    }

    /// Iterates over the error payloads of the failed responses, in packet order.
    pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
        match self {
            Self::Batch(responses) => ResponsePacketErrorsIter::Batch(responses.iter()),
            Self::Single(response) => ResponsePacketErrorsIter::Single(Some(response)),
        }
    }

    /// Returns the responses whose `id` is in `ids`, preserving packet order.
    pub fn responses_by_ids(&self, ids: &HashSet<Id>) -> Vec<&Response<Payload, ErrData>> {
        self.responses_slice().iter().filter(|response| ids.contains(&response.id)).collect()
    }
}
/// Iterator over the error payloads in a [`ResponsePacket`]; this is the concrete
/// type behind `ResponsePacket::iter_errors`.
#[derive(Debug, Clone)]
enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
    /// The lone response of a `Single` packet, or `None` once it has been taken.
    Single(Option<&'a Response<Payload, ErrData>>),
    /// The not-yet-inspected remainder of a `Batch` packet.
    Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
}
impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
type Item = &'a ErrorPayload<ErrData>;
fn next(&mut self) -> Option<Self::Item> {
match self {
ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
ResponsePacketErrorsIter::Batch(batch) => loop {
let res = batch.next()?;
if let Some(err) = res.payload.as_error() {
return Some(err);
}
},
}
}
}