// alloy_provider/ext/trace/with_block.rs

1use crate::ProviderCall;
2use alloy_eips::BlockId;
3use alloy_json_rpc::{RpcRecv, RpcSend};
4use alloy_primitives::{map::HashSet, B256};
5use alloy_rpc_client::RpcCall;
6use alloy_rpc_types_trace::parity::TraceType;
7use alloy_transport::TransportResult;
8use std::future::IntoFuture;
9
10/// A builder for trace_* api calls.
11#[derive(Debug)]
12pub struct TraceBuilder<Params, Resp, Output = Resp, Map = fn(Resp) -> Output>
13where
14    Params: RpcSend,
15    Resp: RpcRecv,
16    Map: Fn(Resp) -> Output,
17{
18    inner: WithBlockInner<Params, Resp, Output, Map>,
19    block_id: Option<BlockId>,
20    trace_types: Option<HashSet<TraceType>>,
21}
22
23impl<Params, Resp, Output, Map> TraceBuilder<Params, Resp, Output, Map>
24where
25    Params: RpcSend,
26    Resp: RpcRecv,
27    Map: Fn(Resp) -> Output + Clone,
28{
29    /// Create a new [`TraceBuilder`] from a [`RpcCall`].
30    pub fn new_rpc(inner: RpcCall<Params, Resp, Output, Map>) -> Self {
31        Self { inner: WithBlockInner::RpcCall(inner), block_id: None, trace_types: None }
32    }
33
34    /// Create a new [`TraceBuilder`] from a closure producing a [`ProviderCall`].
35    pub fn new_provider<F>(get_call: F) -> Self
36    where
37        F: Fn(Option<BlockId>) -> ProviderCall<TraceParams<Params>, Resp, Output, Map>
38            + Send
39            + 'static,
40    {
41        let get_call = Box::new(get_call);
42
43        Self { inner: WithBlockInner::ProviderCall(get_call), block_id: None, trace_types: None }
44    }
45}
46
47impl<Params, Resp, Output, Map> From<RpcCall<Params, Resp, Output, Map>>
48    for TraceBuilder<Params, Resp, Output, Map>
49where
50    Params: RpcSend,
51    Resp: RpcRecv,
52    Map: Fn(Resp) -> Output + Clone,
53{
54    fn from(inner: RpcCall<Params, Resp, Output, Map>) -> Self {
55        Self::new_rpc(inner)
56    }
57}
58
59impl<F, Params, Resp, Output, Map> From<F> for TraceBuilder<Params, Resp, Output, Map>
60where
61    Params: RpcSend,
62    Resp: RpcRecv,
63    Map: Fn(Resp) -> Output + Clone,
64    F: Fn(Option<BlockId>) -> ProviderCall<TraceParams<Params>, Resp, Output, Map> + Send + 'static,
65{
66    fn from(inner: F) -> Self {
67        Self::new_provider(inner)
68    }
69}
70
71impl<Params, Resp, Output, Map> TraceBuilder<Params, Resp, Output, Map>
72where
73    Params: RpcSend,
74    Resp: RpcRecv,
75    Map: Fn(Resp) -> Output + 'static,
76{
77    /// Set the trace type.
78    pub fn trace_type(mut self, trace_type: TraceType) -> Self {
79        self.trace_types.get_or_insert_with(HashSet::default).insert(trace_type);
80        self
81    }
82
83    /// Set the trace types.
84    pub fn trace_types<I: IntoIterator<Item = TraceType>>(mut self, trace_types: I) -> Self {
85        self.trace_types.get_or_insert_with(HashSet::default).extend(trace_types);
86        self
87    }
88
89    /// Set the trace type to "trace".
90    pub fn trace(self) -> Self {
91        self.trace_type(TraceType::Trace)
92    }
93
94    /// Set the trace type to "vmTrace".
95    pub fn vm_trace(self) -> Self {
96        self.trace_type(TraceType::VmTrace)
97    }
98
99    /// Set the trace type to "stateDiff".
100    pub fn state_diff(self) -> Self {
101        self.trace_type(TraceType::StateDiff)
102    }
103
104    /// Get the trace types.
105    pub const fn get_trace_types(&self) -> Option<&HashSet<TraceType>> {
106        self.trace_types.as_ref()
107    }
108
109    /// Set the block id.
110    pub const fn block_id(mut self, block_id: BlockId) -> Self {
111        self.block_id = Some(block_id);
112        self
113    }
114
115    /// Set the block id to "pending".
116    pub const fn pending(self) -> Self {
117        self.block_id(BlockId::pending())
118    }
119
120    /// Set the block id to "latest".
121    pub const fn latest(self) -> Self {
122        self.block_id(BlockId::latest())
123    }
124
125    /// Set the block id to "earliest".
126    pub const fn earliest(self) -> Self {
127        self.block_id(BlockId::earliest())
128    }
129
130    /// Set the block id to "finalized".
131    pub const fn finalized(self) -> Self {
132        self.block_id(BlockId::finalized())
133    }
134
135    /// Set the block id to "safe".
136    pub const fn safe(self) -> Self {
137        self.block_id(BlockId::safe())
138    }
139
140    /// Set the block id to a specific height.
141    pub const fn number(self, number: u64) -> Self {
142        self.block_id(BlockId::number(number))
143    }
144
145    /// Set the block id to a specific hash, without requiring the hash be part
146    /// of the canonical chain.
147    pub const fn hash(self, hash: B256) -> Self {
148        self.block_id(BlockId::hash(hash))
149    }
150
151    /// Set the block id to a specific hash and require the hash be part of the
152    /// canonical chain.
153    pub const fn hash_canonical(self, hash: B256) -> Self {
154        self.block_id(BlockId::hash_canonical(hash))
155    }
156}
157
158impl<Params, Resp, Output, Map> IntoFuture for TraceBuilder<Params, Resp, Output, Map>
159where
160    Params: RpcSend,
161    Resp: RpcRecv,
162    Output: 'static,
163    Map: Fn(Resp) -> Output + 'static,
164{
165    type Output = TransportResult<Output>;
166
167    type IntoFuture = ProviderCall<TraceParams<Params>, Resp, Output, Map>;
168
169    fn into_future(self) -> Self::IntoFuture {
170        match self.inner {
171            WithBlockInner::RpcCall(inner) => {
172                let block_id = self.block_id;
173                let trace_types = self.trace_types;
174                let method = inner.method().to_string();
175                let inner = inner.map_params(|params| {
176                    TraceParams::new(&method, params, block_id, trace_types.clone())
177                });
178                ProviderCall::RpcCall(inner)
179            }
180            WithBlockInner::ProviderCall(get_call) => get_call(self.block_id),
181        }
182    }
183}
184
185/// Parameters for a trace call.
186///
187/// Contains optional block id and trace types to accomodate `trace_*` api calls that don't require
188/// them.
189#[derive(Debug, Clone)]
190pub struct TraceParams<Params: RpcSend> {
191    params: Params,
192    block_id: Option<BlockId>,
193    trace_types: Option<HashSet<TraceType>>,
194}
195
196impl<Params: RpcSend> TraceParams<Params> {}
197
198impl<Params: RpcSend> serde::Serialize for TraceParams<Params> {
199    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
200    where
201        S: serde::Serializer,
202    {
203        use serde::ser::SerializeTuple;
204        // Calculate tuple length based on optional fields
205        let len = 1 + self.trace_types.is_some() as usize + self.block_id.is_some() as usize;
206
207        let mut tup = serializer.serialize_tuple(len)?;
208
209        // Always serialize params first
210        tup.serialize_element(&self.params)?;
211
212        // Add trace_types if present
213        if let Some(trace_types) = &self.trace_types {
214            tup.serialize_element(trace_types)?;
215        }
216
217        // Add block_id last if present
218        if let Some(block_id) = &self.block_id {
219            tup.serialize_element(block_id)?;
220        }
221
222        tup.end()
223    }
224}
225
226impl<Params: RpcSend> TraceParams<Params> {
227    /// Create a new `TraceParams` with the given parameters.
228    ///
229    /// The `method` is used to determine which parameters to ignore according to the `trace_*` api
230    /// spec. See <https://reth.rs/jsonrpc/trace.html>.
231    pub fn new(
232        method: &String,
233        params: Params,
234        block_id: Option<BlockId>,
235        trace_types: Option<HashSet<TraceType>>,
236    ) -> Self {
237        let block_id = block_id.unwrap_or(BlockId::pending());
238        let trace_types = trace_types.unwrap_or_else(|| {
239            let mut set = HashSet::default();
240            set.insert(TraceType::Trace);
241            set
242        });
243        match method.as_str() {
244            "trace_call" => {
245                Self { params, block_id: Some(block_id), trace_types: Some(trace_types) }
246            }
247            "trace_callMany" => {
248                // Trace types are ignored as they are set per-tx-request in `params`.
249                Self { params, block_id: Some(block_id), trace_types: None }
250            }
251            "trace_replayTransaction"
252            | "trace_rawTransaction"
253            | "trace_replayBlockTransactions" => {
254                // BlockId is ignored
255                Self { params, block_id: None, trace_types: Some(trace_types) }
256            }
257            _ => {
258                unreachable!("{method} is not supported by TraceBuilder due to custom serialization requirements");
259            }
260        }
261    }
262}
263
/// Boxed closure that, given the builder's optional block id, produces a [`ProviderCall`] whose
/// params are [`TraceParams`].
type ProviderCallProducer<Params, Resp, Output, Map> =
    Box<dyn Fn(Option<BlockId>) -> ProviderCall<TraceParams<Params>, Resp, Output, Map> + Send>;
267
/// The two sources from which a [`TraceBuilder`] obtains the call it finally executes.
enum WithBlockInner<Params, Resp, Output, Map>
where
    Params: RpcSend,
    Resp: RpcRecv,
    Map: Fn(Resp) -> Output,
{
    /// A prepared RPC call; its params are wrapped into [`TraceParams`] in `into_future`.
    RpcCall(RpcCall<Params, Resp, Output, Map>),
    /// A closure invoked with the builder's block id in `into_future` to produce the call.
    ProviderCall(ProviderCallProducer<Params, Resp, Output, Map>),
}
277
278impl<Params, Resp, Output, Map> core::fmt::Debug for WithBlockInner<Params, Resp, Output, Map>
279where
280    Params: RpcSend,
281    Resp: RpcRecv,
282    Map: Fn(Resp) -> Output,
283{
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        match self {
286            Self::RpcCall(call) => f.debug_tuple("RpcCall").field(call).finish(),
287            Self::ProviderCall(_) => f.debug_struct("ProviderCall").finish(),
288        }
289    }
290}