1use std::{collections::VecDeque, sync::Arc};
4
5use crate::{utils, EthCallMany, EthGetBlock};
6use alloy_eips::{BlockId, BlockNumberOrTag};
7use alloy_json_rpc::{ErrorPayload, RpcRecv, RpcSend};
8use alloy_network::Network;
9use alloy_primitives::{
10 Address, BlockHash, Bytes, StorageKey, StorageValue, TxHash, U128, U256, U64,
11};
12use alloy_rpc_client::NoParams;
13use alloy_rpc_types_eth::{
14 AccessListResult, Bundle, EIP1186AccountProofResponse, EthCallResponse, Filter, Log,
15};
16use alloy_transport::{TransportError, TransportErrorKind, TransportResult};
17use parking_lot::RwLock;
18use serde::Serialize;
19
20use crate::{Caller, EthCall, Provider, ProviderCall, ProviderLayer, RpcWithBlock};
21
22#[derive(Debug, Clone)]
24pub struct MockLayer {
25 asserter: Asserter,
26}
27
28impl MockLayer {
29 pub fn new(asserter: Asserter) -> Self {
31 Self { asserter }
32 }
33}
34
35impl<P, N> ProviderLayer<P, N> for MockLayer
36where
37 P: Provider<N>,
38 N: Network,
39{
40 type Provider = MockProvider<P, N>;
41
42 fn layer(&self, inner: P) -> Self::Provider {
43 MockProvider::new(inner, self.asserter.clone())
44 }
45}
46
47#[derive(Debug, Clone, Default)]
49pub struct Asserter {
50 responses: Arc<RwLock<VecDeque<MockResponse>>>,
51}
52
53impl Asserter {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn push_success<R: Serialize>(&self, response: R) {
61 self.responses
62 .write()
63 .push_back(MockResponse::Success(serde_json::to_value(response).unwrap()));
64 }
65
66 pub fn push_error(&self, error: ErrorPayload) {
68 self.push_err(TransportError::err_resp(error));
69 }
70
71 pub fn push_err(&self, err: TransportError) {
73 self.responses.write().push_back(MockResponse::Err(err));
74 }
75
76 pub fn pop_response(&self) -> Option<MockResponse> {
78 self.responses.write().pop_front()
79 }
80
81 pub fn pop_deser_response<T>(&self) -> Result<T, TransportError>
83 where
84 T: for<'de> serde::Deserialize<'de>,
85 {
86 let value = self.pop_response().ok_or(TransportErrorKind::custom(MockError::EmptyQueue));
87
88 match value {
89 Ok(MockResponse::Success(value)) => serde_json::from_value(value)
90 .map_err(|e| TransportErrorKind::custom(MockError::DeserError(e.to_string()))),
91 Ok(MockResponse::Err(err)) | Err(err) => Err(err),
92 }
93 }
94}
95
96#[derive(Debug)]
98pub enum MockResponse {
99 Success(serde_json::Value),
101 Err(TransportError),
103}
104
105#[derive(Debug, thiserror::Error)]
107pub enum MockError {
108 #[error("could not deserialize response {0}")]
110 DeserError(String),
111 #[error("empty response queue")]
113 EmptyQueue,
114}
115
116#[derive(Debug, Clone)]
118pub struct MockProvider<P: Provider<N>, N: Network> {
119 inner: P,
121 asserter: Asserter,
125 _network: std::marker::PhantomData<N>,
126}
127
128impl<P, N> MockProvider<P, N>
129where
130 P: Provider<N>,
131 N: Network,
132{
133 pub fn new(inner: P, asserter: Asserter) -> Self {
135 Self { inner, asserter, _network: std::marker::PhantomData }
136 }
137
138 pub fn asserter(&self) -> &Asserter {
140 &self.asserter
141 }
142
143 pub fn push_success<R: Serialize>(&self, response: R) {
145 self.asserter.push_success(response);
146 }
147
148 pub fn push_error(&self, error: ErrorPayload) {
150 self.asserter.push_error(error);
151 }
152
153 pub fn push_err(&self, err: TransportError) {
155 self.asserter.push_err(err);
156 }
157
158 fn next_response<T>(&self) -> Result<T, TransportError>
160 where
161 T: for<'de> serde::Deserialize<'de>,
162 {
163 self.asserter.pop_deser_response()
164 }
165}
166
167#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
168impl<P, N> Provider<N> for MockProvider<P, N>
169where
170 P: Provider<N>,
171 N: Network,
172{
173 fn root(&self) -> &crate::RootProvider<N> {
174 self.inner.root()
175 }
176
177 fn get_accounts(&self) -> ProviderCall<NoParams, Vec<Address>> {
178 ProviderCall::Ready(Some(self.next_response()))
179 }
180
181 fn get_block_number(&self) -> ProviderCall<NoParams, U64, u64> {
182 ProviderCall::Ready(Some(self.next_response()))
183 }
184
185 fn get_blob_base_fee(&self) -> ProviderCall<NoParams, U128, u128> {
186 ProviderCall::Ready(Some(self.next_response()))
187 }
188
189 fn get_chain_id(&self) -> ProviderCall<NoParams, U64, u64> {
190 ProviderCall::Ready(Some(self.next_response()))
191 }
192
193 fn call<'req>(&self, tx: N::TransactionRequest) -> EthCall<N, Bytes> {
194 EthCall::call(self.asserter.clone(), tx)
195 }
196
197 fn call_many<'req>(
198 &self,
199 bundles: &'req Vec<Bundle>,
200 ) -> EthCallMany<'req, N, Vec<Vec<EthCallResponse>>> {
201 EthCallMany::new(self.asserter.clone(), bundles)
202 }
203
204 fn estimate_gas(&self, tx: N::TransactionRequest) -> EthCall<N, U64, u64> {
205 EthCall::gas_estimate(self.asserter.clone(), tx).map_resp(utils::convert_u64)
206 }
207
208 fn create_access_list<'a>(
209 &self,
210 _request: &'a N::TransactionRequest,
211 ) -> RpcWithBlock<&'a N::TransactionRequest, AccessListResult> {
212 let asserter = self.asserter.clone();
213 RpcWithBlock::new_provider(move |_block_id| {
214 let res = asserter.pop_deser_response();
215 ProviderCall::Ready(Some(res))
216 })
217 }
218
219 fn get_balance(&self, _address: Address) -> RpcWithBlock<Address, U256, U256> {
220 let asserter = self.asserter.clone();
221 RpcWithBlock::new_provider(move |_block_id| {
222 let res = asserter.pop_deser_response();
223 ProviderCall::Ready(Some(res))
224 })
225 }
226
227 fn get_gas_price(&self) -> ProviderCall<NoParams, U128, u128> {
228 ProviderCall::Ready(Some(self.next_response()))
229 }
230
231 fn get_account(&self, _address: Address) -> RpcWithBlock<Address, alloy_consensus::Account> {
232 let asserter = self.asserter.clone();
233 RpcWithBlock::new_provider(move |_block_id| {
234 let res = asserter.pop_deser_response();
235 ProviderCall::Ready(Some(res))
236 })
237 }
238
239 fn get_block(&self, block: BlockId) -> EthGetBlock<N::BlockResponse> {
240 let asserter = self.asserter.clone();
241 EthGetBlock::new_provider(
242 block,
243 Box::new(move |_kind| {
244 let res = asserter.pop_deser_response();
245 ProviderCall::Ready(Some(res))
246 }),
247 )
248 }
249
250 fn get_block_by_number(&self, number: BlockNumberOrTag) -> EthGetBlock<N::BlockResponse> {
251 let asserter = self.asserter.clone();
252 EthGetBlock::new_provider(
253 number.into(),
254 Box::new(move |_kind| {
255 let res = asserter.pop_deser_response();
256 ProviderCall::Ready(Some(res))
257 }),
258 )
259 }
260
261 fn get_block_by_hash(&self, hash: BlockHash) -> EthGetBlock<N::BlockResponse> {
262 let asserter = self.asserter.clone();
263 EthGetBlock::new_provider(
264 hash.into(),
265 Box::new(move |_kind| {
266 let res = asserter.pop_deser_response();
267 ProviderCall::Ready(Some(res))
268 }),
269 )
270 }
271
272 async fn get_block_transaction_count_by_hash(
273 &self,
274 _hash: BlockHash,
275 ) -> TransportResult<Option<u64>> {
276 let res = self.next_response::<Option<U64>>()?;
277 Ok(res.map(utils::convert_u64))
278 }
279
280 async fn get_block_transaction_count_by_number(
281 &self,
282 _block_number: BlockNumberOrTag,
283 ) -> TransportResult<Option<u64>> {
284 let res = self.next_response::<Option<U64>>()?;
285 Ok(res.map(utils::convert_u64))
286 }
287
288 fn get_block_receipts(
289 &self,
290 _block: BlockId,
291 ) -> ProviderCall<(BlockId,), Option<Vec<N::ReceiptResponse>>> {
292 ProviderCall::Ready(Some(self.next_response()))
293 }
294
295 fn get_code_at(&self, _address: Address) -> RpcWithBlock<Address, Bytes> {
296 let asserter = self.asserter.clone();
297 RpcWithBlock::new_provider(move |_block_id| {
298 let res = asserter.pop_deser_response();
299 ProviderCall::Ready(Some(res))
300 })
301 }
302
303 async fn get_logs(&self, _filter: &Filter) -> TransportResult<Vec<Log>> {
304 self.next_response()
305 }
306
307 fn get_proof(
308 &self,
309 _address: Address,
310 _keys: Vec<StorageKey>,
311 ) -> RpcWithBlock<(Address, Vec<StorageKey>), EIP1186AccountProofResponse> {
312 let asserter = self.asserter.clone();
313 RpcWithBlock::new_provider(move |_block_id| {
314 let res = asserter.pop_deser_response();
315 ProviderCall::Ready(Some(res))
316 })
317 }
318
319 fn get_storage_at(
320 &self,
321 _address: Address,
322 _key: U256,
323 ) -> RpcWithBlock<(Address, U256), StorageValue> {
324 let asserter = self.asserter.clone();
325 RpcWithBlock::new_provider(move |_block_id| {
326 let res = asserter.pop_deser_response();
327 ProviderCall::Ready(Some(res))
328 })
329 }
330
331 fn get_transaction_by_hash(
332 &self,
333 _hash: TxHash,
334 ) -> ProviderCall<(TxHash,), Option<N::TransactionResponse>> {
335 ProviderCall::Ready(Some(self.next_response()))
336 }
337
338 fn get_raw_transaction_by_hash(&self, _hash: TxHash) -> ProviderCall<(TxHash,), Option<Bytes>> {
339 ProviderCall::Ready(Some(self.next_response()))
340 }
341
342 fn get_transaction_count(
343 &self,
344 _address: Address,
345 ) -> RpcWithBlock<Address, U64, u64, fn(U64) -> u64> {
346 let asserter = self.asserter.clone();
347 RpcWithBlock::new_provider(move |_block_id| {
348 let res = asserter.pop_deser_response::<U64>();
349 let res = res.map(utils::convert_u64);
350 ProviderCall::Ready(Some(res))
351 })
352 }
353
354 fn get_transaction_receipt(
355 &self,
356 _hash: TxHash,
357 ) -> ProviderCall<(TxHash,), Option<N::ReceiptResponse>> {
358 ProviderCall::Ready(Some(self.next_response()))
359 }
360
361 async fn get_uncle(
362 &self,
363 tag: BlockId,
364 _idx: u64,
365 ) -> TransportResult<Option<N::BlockResponse>> {
366 match tag {
367 BlockId::Hash(_) | BlockId::Number(_) => self.next_response(),
368 }
369 }
370
371 async fn get_uncle_count(&self, tag: BlockId) -> TransportResult<u64> {
373 match tag {
374 BlockId::Hash(_) | BlockId::Number(_) => {
375 self.next_response::<U64>().map(utils::convert_u64)
376 }
377 }
378 }
379}
380
381impl<N: Network, Resp: RpcRecv> Caller<N, Resp> for Asserter {
383 fn call(
384 &self,
385 _params: crate::EthCallParams<N>,
386 ) -> TransportResult<ProviderCall<crate::EthCallParams<N>, Resp>> {
387 provider_eth_call(self)
388 }
389
390 fn call_many(
391 &self,
392 _params: crate::EthCallManyParams<'_>,
393 ) -> TransportResult<ProviderCall<crate::EthCallManyParams<'static>, Resp>> {
394 provider_eth_call(self)
395 }
396
397 fn estimate_gas(
398 &self,
399 _params: crate::EthCallParams<N>,
400 ) -> TransportResult<ProviderCall<crate::EthCallParams<N>, Resp>> {
401 provider_eth_call(self)
402 }
403}
404
405fn provider_eth_call<Params: RpcSend, Resp: RpcRecv>(
406 asserter: &Asserter,
407) -> TransportResult<ProviderCall<Params, Resp>> {
408 let value = asserter.pop_response().ok_or(TransportErrorKind::custom(MockError::EmptyQueue));
409
410 let res = match value {
411 Ok(MockResponse::Success(value)) => serde_json::from_value(value)
412 .map_err(|e| TransportErrorKind::custom(MockError::DeserError(e.to_string()))),
413 Ok(MockResponse::Err(err)) | Err(err) => Err(err),
414 };
415
416 Ok(ProviderCall::Ready(Some(res)))
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProviderBuilder;
    use alloy_primitives::bytes;
    use alloy_rpc_types_eth::TransactionRequest;

    /// Checks that queued responses are consumed in FIFO order across several RPC methods.
    /// Also checks that queued errors and an empty queue surface as errors.
    #[tokio::test]
    async fn test_mock() {
        let provider = ProviderBuilder::mocked();
        let asserter = provider.asserter();

        // Two block numbers, then an error, then nothing left.
        asserter.push_success(21965802);
        asserter.push_success(21965803);
        asserter.push_err(TransportError::NullResp);

        assert_eq!(provider.get_block_number().await.unwrap(), 21965802);
        assert_eq!(provider.get_block_number().await.unwrap(), 21965803);

        let queued_err = provider.get_block_number().await.unwrap_err();
        assert!(matches!(queued_err, TransportError::NullResp));

        let empty_err = provider.get_block_number().await.unwrap_err();
        assert!(empty_err.to_string().contains("empty response queue"));

        // Accounts list.
        let accounts = vec![Address::with_last_byte(1), Address::with_last_byte(2)];
        asserter.push_success(accounts.clone());
        assert_eq!(provider.get_accounts().await.unwrap(), accounts);

        // eth_call goes through the asserter acting as the `Caller`.
        let expected_call = bytes!("12345678");
        asserter.push_success(expected_call.clone());
        let call_output = provider.call(TransactionRequest::default()).await.unwrap();
        assert_eq!(call_output, expected_call);

        // Balance goes through the `RpcWithBlock` path.
        let expected_balance = U256::from(123456780);
        asserter.push_success(expected_balance);
        let balance = provider.get_balance(Address::default()).await.unwrap();
        assert_eq!(balance, expected_balance);
    }
}