alloy_provider/layers/
mock.rs

1//! Mock Provider Layer
2
3use std::{collections::VecDeque, sync::Arc};
4
5use crate::{utils, EthCallMany, EthGetBlock};
6use alloy_eips::{BlockId, BlockNumberOrTag};
7use alloy_json_rpc::{ErrorPayload, RpcRecv, RpcSend};
8use alloy_network::Network;
9use alloy_primitives::{
10    Address, BlockHash, Bytes, StorageKey, StorageValue, TxHash, U128, U256, U64,
11};
12use alloy_rpc_client::NoParams;
13use alloy_rpc_types_eth::{
14    AccessListResult, Bundle, EIP1186AccountProofResponse, EthCallResponse, Filter, Log,
15};
16use alloy_transport::{TransportError, TransportErrorKind, TransportResult};
17use parking_lot::RwLock;
18use serde::Serialize;
19
20use crate::{Caller, EthCall, Provider, ProviderCall, ProviderLayer, RpcWithBlock};
21
22/// A mock provider layer that returns responses that have been pushed to the [`Asserter`].
23#[derive(Debug, Clone)]
24pub struct MockLayer {
25    asserter: Asserter,
26}
27
28impl MockLayer {
29    /// Instantiate a new mock layer with the given [`Asserter`].
30    pub fn new(asserter: Asserter) -> Self {
31        Self { asserter }
32    }
33}
34
35impl<P, N> ProviderLayer<P, N> for MockLayer
36where
37    P: Provider<N>,
38    N: Network,
39{
40    type Provider = MockProvider<P, N>;
41
42    fn layer(&self, inner: P) -> Self::Provider {
43        MockProvider::new(inner, self.asserter.clone())
44    }
45}
46
47/// Container for pushing responses into the [`MockProvider`].
48#[derive(Debug, Clone, Default)]
49pub struct Asserter {
50    responses: Arc<RwLock<VecDeque<MockResponse>>>,
51}
52
53impl Asserter {
54    /// Instantiate a new asserter.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Insert a successful response into the queue.
60    pub fn push_success<R: Serialize>(&self, response: R) {
61        self.responses
62            .write()
63            .push_back(MockResponse::Success(serde_json::to_value(response).unwrap()));
64    }
65
66    /// Push a server error payload into the queue.
67    pub fn push_error(&self, error: ErrorPayload) {
68        self.push_err(TransportError::err_resp(error));
69    }
70
71    /// Insert an error response into the queue.
72    pub fn push_err(&self, err: TransportError) {
73        self.responses.write().push_back(MockResponse::Err(err));
74    }
75
76    /// Pop front to get the next response from the queue.
77    pub fn pop_response(&self) -> Option<MockResponse> {
78        self.responses.write().pop_front()
79    }
80
81    /// Helper function to get and deserialize the next response from the asserter
82    pub fn pop_deser_response<T>(&self) -> Result<T, TransportError>
83    where
84        T: for<'de> serde::Deserialize<'de>,
85    {
86        let value = self.pop_response().ok_or(TransportErrorKind::custom(MockError::EmptyQueue));
87
88        match value {
89            Ok(MockResponse::Success(value)) => serde_json::from_value(value)
90                .map_err(|e| TransportErrorKind::custom(MockError::DeserError(e.to_string()))),
91            Ok(MockResponse::Err(err)) | Err(err) => Err(err),
92        }
93    }
94}
95
96/// A mock response that can be pushed into the asserter.
97#[derive(Debug)]
98pub enum MockResponse {
99    /// A successful response that will be deserialized into the expected type.
100    Success(serde_json::Value),
101    /// An error response.
102    Err(TransportError),
103}
104
105/// A [`MockProvider`] error.
106#[derive(Debug, thiserror::Error)]
107pub enum MockError {
108    /// An error occurred while deserializing the response from asserter into the expected type.
109    #[error("could not deserialize response {0}")]
110    DeserError(String),
111    /// The response queue is empty.
112    #[error("empty response queue")]
113    EmptyQueue,
114}
115
116/// A mock provider implementation that returns responses from the [`Asserter`].
117#[derive(Debug, Clone)]
118pub struct MockProvider<P: Provider<N>, N: Network> {
119    /// Inner dummy provider.
120    inner: P,
121    /// The [`Asserter`] to which response are pushed using [`Asserter::push_success`].
122    ///
123    /// Responses are popped from the asserter in the order they were pushed.
124    asserter: Asserter,
125    _network: std::marker::PhantomData<N>,
126}
127
128impl<P, N> MockProvider<P, N>
129where
130    P: Provider<N>,
131    N: Network,
132{
133    /// Instantiate a new mock provider.
134    pub fn new(inner: P, asserter: Asserter) -> Self {
135        Self { inner, asserter, _network: std::marker::PhantomData }
136    }
137
138    /// Return a reference to the asserter.
139    pub fn asserter(&self) -> &Asserter {
140        &self.asserter
141    }
142
143    /// Insert a successful response into the queue.
144    pub fn push_success<R: Serialize>(&self, response: R) {
145        self.asserter.push_success(response);
146    }
147
148    /// Push a JSON-RPC 2.0 [`ErrorPayload`] into the queue.
149    pub fn push_error(&self, error: ErrorPayload) {
150        self.asserter.push_error(error);
151    }
152
153    /// Push a [`TransportError`] into the queue.
154    pub fn push_err(&self, err: TransportError) {
155        self.asserter.push_err(err);
156    }
157
158    /// Helper function to get and deserialize the next response from the asserter
159    fn next_response<T>(&self) -> Result<T, TransportError>
160    where
161        T: for<'de> serde::Deserialize<'de>,
162    {
163        self.asserter.pop_deser_response()
164    }
165}
166
167#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
168impl<P, N> Provider<N> for MockProvider<P, N>
169where
170    P: Provider<N>,
171    N: Network,
172{
173    fn root(&self) -> &crate::RootProvider<N> {
174        self.inner.root()
175    }
176
177    fn get_accounts(&self) -> ProviderCall<NoParams, Vec<Address>> {
178        ProviderCall::Ready(Some(self.next_response()))
179    }
180
181    fn get_block_number(&self) -> ProviderCall<NoParams, U64, u64> {
182        ProviderCall::Ready(Some(self.next_response()))
183    }
184
185    fn get_blob_base_fee(&self) -> ProviderCall<NoParams, U128, u128> {
186        ProviderCall::Ready(Some(self.next_response()))
187    }
188
189    fn get_chain_id(&self) -> ProviderCall<NoParams, U64, u64> {
190        ProviderCall::Ready(Some(self.next_response()))
191    }
192
193    fn call<'req>(&self, tx: N::TransactionRequest) -> EthCall<N, Bytes> {
194        EthCall::call(self.asserter.clone(), tx)
195    }
196
197    fn call_many<'req>(
198        &self,
199        bundles: &'req Vec<Bundle>,
200    ) -> EthCallMany<'req, N, Vec<Vec<EthCallResponse>>> {
201        EthCallMany::new(self.asserter.clone(), bundles)
202    }
203
204    fn estimate_gas(&self, tx: N::TransactionRequest) -> EthCall<N, U64, u64> {
205        EthCall::gas_estimate(self.asserter.clone(), tx).map_resp(utils::convert_u64)
206    }
207
208    fn create_access_list<'a>(
209        &self,
210        _request: &'a N::TransactionRequest,
211    ) -> RpcWithBlock<&'a N::TransactionRequest, AccessListResult> {
212        let asserter = self.asserter.clone();
213        RpcWithBlock::new_provider(move |_block_id| {
214            let res = asserter.pop_deser_response();
215            ProviderCall::Ready(Some(res))
216        })
217    }
218
219    fn get_balance(&self, _address: Address) -> RpcWithBlock<Address, U256, U256> {
220        let asserter = self.asserter.clone();
221        RpcWithBlock::new_provider(move |_block_id| {
222            let res = asserter.pop_deser_response();
223            ProviderCall::Ready(Some(res))
224        })
225    }
226
227    fn get_gas_price(&self) -> ProviderCall<NoParams, U128, u128> {
228        ProviderCall::Ready(Some(self.next_response()))
229    }
230
231    fn get_account(&self, _address: Address) -> RpcWithBlock<Address, alloy_consensus::Account> {
232        let asserter = self.asserter.clone();
233        RpcWithBlock::new_provider(move |_block_id| {
234            let res = asserter.pop_deser_response();
235            ProviderCall::Ready(Some(res))
236        })
237    }
238
239    fn get_block(&self, block: BlockId) -> EthGetBlock<N::BlockResponse> {
240        let asserter = self.asserter.clone();
241        EthGetBlock::new_provider(
242            block,
243            Box::new(move |_kind| {
244                let res = asserter.pop_deser_response();
245                ProviderCall::Ready(Some(res))
246            }),
247        )
248    }
249
250    fn get_block_by_number(&self, number: BlockNumberOrTag) -> EthGetBlock<N::BlockResponse> {
251        let asserter = self.asserter.clone();
252        EthGetBlock::new_provider(
253            number.into(),
254            Box::new(move |_kind| {
255                let res = asserter.pop_deser_response();
256                ProviderCall::Ready(Some(res))
257            }),
258        )
259    }
260
261    fn get_block_by_hash(&self, hash: BlockHash) -> EthGetBlock<N::BlockResponse> {
262        let asserter = self.asserter.clone();
263        EthGetBlock::new_provider(
264            hash.into(),
265            Box::new(move |_kind| {
266                let res = asserter.pop_deser_response();
267                ProviderCall::Ready(Some(res))
268            }),
269        )
270    }
271
272    async fn get_block_transaction_count_by_hash(
273        &self,
274        _hash: BlockHash,
275    ) -> TransportResult<Option<u64>> {
276        let res = self.next_response::<Option<U64>>()?;
277        Ok(res.map(utils::convert_u64))
278    }
279
280    async fn get_block_transaction_count_by_number(
281        &self,
282        _block_number: BlockNumberOrTag,
283    ) -> TransportResult<Option<u64>> {
284        let res = self.next_response::<Option<U64>>()?;
285        Ok(res.map(utils::convert_u64))
286    }
287
288    fn get_block_receipts(
289        &self,
290        _block: BlockId,
291    ) -> ProviderCall<(BlockId,), Option<Vec<N::ReceiptResponse>>> {
292        ProviderCall::Ready(Some(self.next_response()))
293    }
294
295    fn get_code_at(&self, _address: Address) -> RpcWithBlock<Address, Bytes> {
296        let asserter = self.asserter.clone();
297        RpcWithBlock::new_provider(move |_block_id| {
298            let res = asserter.pop_deser_response();
299            ProviderCall::Ready(Some(res))
300        })
301    }
302
303    async fn get_logs(&self, _filter: &Filter) -> TransportResult<Vec<Log>> {
304        self.next_response()
305    }
306
307    fn get_proof(
308        &self,
309        _address: Address,
310        _keys: Vec<StorageKey>,
311    ) -> RpcWithBlock<(Address, Vec<StorageKey>), EIP1186AccountProofResponse> {
312        let asserter = self.asserter.clone();
313        RpcWithBlock::new_provider(move |_block_id| {
314            let res = asserter.pop_deser_response();
315            ProviderCall::Ready(Some(res))
316        })
317    }
318
319    fn get_storage_at(
320        &self,
321        _address: Address,
322        _key: U256,
323    ) -> RpcWithBlock<(Address, U256), StorageValue> {
324        let asserter = self.asserter.clone();
325        RpcWithBlock::new_provider(move |_block_id| {
326            let res = asserter.pop_deser_response();
327            ProviderCall::Ready(Some(res))
328        })
329    }
330
331    fn get_transaction_by_hash(
332        &self,
333        _hash: TxHash,
334    ) -> ProviderCall<(TxHash,), Option<N::TransactionResponse>> {
335        ProviderCall::Ready(Some(self.next_response()))
336    }
337
338    fn get_raw_transaction_by_hash(&self, _hash: TxHash) -> ProviderCall<(TxHash,), Option<Bytes>> {
339        ProviderCall::Ready(Some(self.next_response()))
340    }
341
342    fn get_transaction_count(
343        &self,
344        _address: Address,
345    ) -> RpcWithBlock<Address, U64, u64, fn(U64) -> u64> {
346        let asserter = self.asserter.clone();
347        RpcWithBlock::new_provider(move |_block_id| {
348            let res = asserter.pop_deser_response::<U64>();
349            let res = res.map(utils::convert_u64);
350            ProviderCall::Ready(Some(res))
351        })
352    }
353
354    fn get_transaction_receipt(
355        &self,
356        _hash: TxHash,
357    ) -> ProviderCall<(TxHash,), Option<N::ReceiptResponse>> {
358        ProviderCall::Ready(Some(self.next_response()))
359    }
360
361    async fn get_uncle(
362        &self,
363        tag: BlockId,
364        _idx: u64,
365    ) -> TransportResult<Option<N::BlockResponse>> {
366        match tag {
367            BlockId::Hash(_) | BlockId::Number(_) => self.next_response(),
368        }
369    }
370
371    /// Gets the number of uncles for the block specified by the tag [BlockId].
372    async fn get_uncle_count(&self, tag: BlockId) -> TransportResult<u64> {
373        match tag {
374            BlockId::Hash(_) | BlockId::Number(_) => {
375                self.next_response::<U64>().map(utils::convert_u64)
376            }
377        }
378    }
379}
380
381/// [`Caller`] implementation for the [`Asserter`] to `eth_call` ops in the [`MockProvider`].
382impl<N: Network, Resp: RpcRecv> Caller<N, Resp> for Asserter {
383    fn call(
384        &self,
385        _params: crate::EthCallParams<N>,
386    ) -> TransportResult<ProviderCall<crate::EthCallParams<N>, Resp>> {
387        provider_eth_call(self)
388    }
389
390    fn call_many(
391        &self,
392        _params: crate::EthCallManyParams<'_>,
393    ) -> TransportResult<ProviderCall<crate::EthCallManyParams<'static>, Resp>> {
394        provider_eth_call(self)
395    }
396
397    fn estimate_gas(
398        &self,
399        _params: crate::EthCallParams<N>,
400    ) -> TransportResult<ProviderCall<crate::EthCallParams<N>, Resp>> {
401        provider_eth_call(self)
402    }
403}
404
405fn provider_eth_call<Params: RpcSend, Resp: RpcRecv>(
406    asserter: &Asserter,
407) -> TransportResult<ProviderCall<Params, Resp>> {
408    let value = asserter.pop_response().ok_or(TransportErrorKind::custom(MockError::EmptyQueue));
409
410    let res = match value {
411        Ok(MockResponse::Success(value)) => serde_json::from_value(value)
412            .map_err(|e| TransportErrorKind::custom(MockError::DeserError(e.to_string()))),
413        Ok(MockResponse::Err(err)) | Err(err) => Err(err),
414    };
415
416    Ok(ProviderCall::Ready(Some(res)))
417}
418
#[cfg(test)]
mod tests {
    use alloy_primitives::bytes;
    use alloy_rpc_types_eth::TransactionRequest;

    use super::*;
    use crate::ProviderBuilder;

    #[tokio::test]
    async fn test_mock() {
        let provider = ProviderBuilder::mocked();
        let asserter = provider.asserter();

        // Responses are consumed strictly in the order they were pushed.
        let block_numbers = [21965802u64, 21965803];
        for number in block_numbers {
            asserter.push_success(number);
        }
        asserter.push_err(TransportError::NullResp);

        for expected in block_numbers {
            assert_eq!(provider.get_block_number().await.unwrap(), expected);
        }

        let queued_err = provider.get_block_number().await.unwrap_err();
        assert!(matches!(queued_err, TransportError::NullResp));

        // A drained queue yields an error rather than a panic.
        let empty_err = provider.get_block_number().await.unwrap_err();
        assert!(empty_err.to_string().contains("empty response queue"));

        let accounts = vec![Address::with_last_byte(1), Address::with_last_byte(2)];
        asserter.push_success(accounts.clone());
        assert_eq!(provider.get_accounts().await.unwrap(), accounts);

        let call_resp = bytes!("12345678");
        asserter.push_success(call_resp.clone());
        let call_out = provider.call(TransactionRequest::default()).await.unwrap();
        assert_eq!(call_out, call_resp);

        let balance = U256::from(123456780);
        asserter.push_success(balance);
        assert_eq!(provider.get_balance(Address::default()).await.unwrap(), balance);
    }
}