1use crate::{ParamsWithBlock, Provider, ProviderCall, ProviderLayer, RootProvider, RpcWithBlock};
2use alloy_eips::BlockId;
3use alloy_json_rpc::{RpcError, RpcSend};
4use alloy_network::Network;
5use alloy_primitives::{keccak256, Address, Bytes, StorageKey, StorageValue, TxHash, B256, U256};
6use alloy_rpc_types_eth::{BlockNumberOrTag, EIP1186AccountProofResponse, Filter, Log};
7use alloy_transport::{TransportErrorKind, TransportResult};
8use lru::LruCache;
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::{io::BufReader, marker::PhantomData, num::NonZero, path::PathBuf, sync::Arc};
12#[derive(Debug, Clone)]
20pub struct CacheLayer {
21 cache: SharedCache,
23}
24
25impl CacheLayer {
26 pub fn new(max_items: u32) -> Self {
29 Self { cache: SharedCache::new(max_items) }
30 }
31
32 pub const fn max_items(&self) -> u32 {
34 self.cache.max_items()
35 }
36
37 pub fn cache(&self) -> SharedCache {
39 self.cache.clone()
40 }
41}
42
43impl<P, N> ProviderLayer<P, N> for CacheLayer
44where
45 P: Provider<N>,
46 N: Network,
47{
48 type Provider = CacheProvider<P, N>;
49
50 fn layer(&self, inner: P) -> Self::Provider {
51 CacheProvider::new(inner, self.cache())
52 }
53}
54
55#[derive(Debug, Clone)]
63pub struct CacheProvider<P, N> {
64 inner: P,
66 cache: SharedCache,
68 _pd: PhantomData<N>,
70}
71
72impl<P, N> CacheProvider<P, N>
73where
74 P: Provider<N>,
75 N: Network,
76{
77 pub const fn new(inner: P, cache: SharedCache) -> Self {
79 Self { inner, cache, _pd: PhantomData }
80 }
81}
82
/// Sends `$req` through the client behind the weak handle `$client`, attaching the request's
/// block id (defaulting to `latest`) to the params.
///
/// On success the response is serialized to JSON and stored in `$cache` under
/// `$req.params_hash()`, but only when the request targets a fixed block (a block hash or a
/// concrete block number). Responses for tag-based ids such as `latest` or `pending` change over
/// time, so caching them would later serve stale data. Caching is best-effort: a serialization or
/// hashing failure never turns a successful RPC response into an error.
macro_rules! rpc_call_with_block {
    ($cache:expr, $client:expr, $req:expr) => {{
        let client =
            $client.upgrade().ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"));
        let cache = $cache.clone();
        ProviderCall::BoxedFuture(Box::pin(async move {
            let client = client?;

            let result = client.request($req.method(), $req.params()).map_params(|params| {
                ParamsWithBlock::new(params, $req.block_id.unwrap_or(BlockId::latest()))
            });

            let res = result.await?;

            // Only deterministic block ids are safe to memoize.
            if !$req.has_block_tag() {
                if let (Ok(json_str), Ok(hash)) =
                    (serde_json::to_string(&res), $req.params_hash())
                {
                    let _ = cache.put(hash, json_str);
                }
            }

            Ok(res)
        }))
    }};
}
112
/// Answers `$req` from `$cache` when possible, otherwise delegates to `rpc_call_with_block!`.
///
/// Requests pinned to a block tag (`latest`, `pending`, ...) skip the cache lookup entirely.
/// For all other requests the cache is keyed by `$req.params_hash()`; a hashing failure, a
/// missing entry, or an entry that fails to deserialize all fall through to a live RPC call.
///
/// A cache hit expands to an early `return`, so this must be used where returning a
/// `ProviderCall` from the enclosing function or closure is valid.
macro_rules! cache_rpc_call_with_block {
    ($cache:expr, $client:expr, $req:expr) => {{
        if !$req.has_block_tag() {
            let hit = $req
                .params_hash()
                .ok()
                .and_then(|hash| $cache.get_deserialized(&hash).ok().flatten());
            if let Some(cached) = hit {
                return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
            }
        }

        rpc_call_with_block!($cache, $client, $req)
    }};
}
135
136#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
137#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
138impl<P, N> Provider<N> for CacheProvider<P, N>
139where
140 P: Provider<N>,
141 N: Network,
142{
143 #[inline(always)]
144 fn root(&self) -> &RootProvider<N> {
145 self.inner.root()
146 }
147
148 fn get_block_receipts(
149 &self,
150 block: BlockId,
151 ) -> ProviderCall<(BlockId,), Option<Vec<N::ReceiptResponse>>> {
152 let req = RequestType::new("eth_getBlockReceipts", (block,));
153
154 let redirect = req.has_block_tag();
155
156 if !redirect {
157 let params_hash = req.params_hash().ok();
158
159 if let Some(hash) = params_hash {
160 if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
161 return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
162 }
163 }
164 }
165
166 let client = self.inner.weak_client();
167 let cache = self.cache.clone();
168
169 ProviderCall::BoxedFuture(Box::pin(async move {
170 let client = client
171 .upgrade()
172 .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
173
174 let result = client.request(req.method(), req.params()).await?;
175
176 let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
177
178 if !redirect {
179 let hash = req.params_hash()?;
180 let _ = cache.put(hash, json_str);
181 }
182
183 Ok(result)
184 }))
185 }
186
187 fn get_code_at(&self, address: Address) -> RpcWithBlock<Address, Bytes> {
188 let client = self.inner.weak_client();
189 let cache = self.cache.clone();
190 RpcWithBlock::new_provider(move |block_id| {
191 let req = RequestType::new("eth_getCode", address).with_block_id(block_id);
192 cache_rpc_call_with_block!(cache, client, req)
193 })
194 }
195
196 async fn get_logs(&self, filter: &Filter) -> TransportResult<Vec<Log>> {
197 let req = RequestType::new("eth_getLogs", filter.clone());
198
199 let params_hash = req.params_hash().ok();
200
201 if let Some(hash) = params_hash {
202 if let Some(cached) = self.cache.get_deserialized(&hash)? {
203 return Ok(cached);
204 }
205 }
206
207 let result = self.inner.get_logs(filter).await?;
208
209 let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
210
211 let hash = req.params_hash()?;
212 let _ = self.cache.put(hash, json_str);
213
214 Ok(result)
215 }
216
217 fn get_proof(
218 &self,
219 address: Address,
220 keys: Vec<StorageKey>,
221 ) -> RpcWithBlock<(Address, Vec<StorageKey>), EIP1186AccountProofResponse> {
222 let client = self.inner.weak_client();
223 let cache = self.cache.clone();
224 RpcWithBlock::new_provider(move |block_id| {
225 let req =
226 RequestType::new("eth_getProof", (address, keys.clone())).with_block_id(block_id);
227 cache_rpc_call_with_block!(cache, client, req)
228 })
229 }
230
231 fn get_storage_at(
232 &self,
233 address: Address,
234 key: U256,
235 ) -> RpcWithBlock<(Address, U256), StorageValue> {
236 let client = self.inner.weak_client();
237 let cache = self.cache.clone();
238 RpcWithBlock::new_provider(move |block_id| {
239 let req = RequestType::new("eth_getStorageAt", (address, key)).with_block_id(block_id);
240 cache_rpc_call_with_block!(cache, client, req)
241 })
242 }
243
244 fn get_transaction_by_hash(
245 &self,
246 hash: TxHash,
247 ) -> ProviderCall<(TxHash,), Option<N::TransactionResponse>> {
248 let req = RequestType::new("eth_getTransactionByHash", (hash,));
249
250 let params_hash = req.params_hash().ok();
251
252 if let Some(hash) = params_hash {
253 if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
254 return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
255 }
256 }
257 let client = self.inner.weak_client();
258 let cache = self.cache.clone();
259 ProviderCall::BoxedFuture(Box::pin(async move {
260 let client = client
261 .upgrade()
262 .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
263 let result = client.request(req.method(), req.params()).await?;
264
265 let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
266 let hash = req.params_hash()?;
267 let _ = cache.put(hash, json_str);
268
269 Ok(result)
270 }))
271 }
272
273 fn get_raw_transaction_by_hash(&self, hash: TxHash) -> ProviderCall<(TxHash,), Option<Bytes>> {
274 let req = RequestType::new("eth_getRawTransactionByHash", (hash,));
275
276 let params_hash = req.params_hash().ok();
277
278 if let Some(hash) = params_hash {
279 if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
280 return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
281 }
282 }
283
284 let client = self.inner.weak_client();
285 let cache = self.cache.clone();
286 ProviderCall::BoxedFuture(Box::pin(async move {
287 let client = client
288 .upgrade()
289 .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
290
291 let result = client.request(req.method(), req.params()).await?;
292
293 let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
294 let hash = req.params_hash()?;
295 let _ = cache.put(hash, json_str);
296
297 Ok(result)
298 }))
299 }
300
301 fn get_transaction_receipt(
302 &self,
303 hash: TxHash,
304 ) -> ProviderCall<(TxHash,), Option<N::ReceiptResponse>> {
305 let req = RequestType::new("eth_getTransactionReceipt", (hash,));
306
307 let params_hash = req.params_hash().ok();
308
309 if let Some(hash) = params_hash {
310 if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
311 return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
312 }
313 }
314
315 let client = self.inner.weak_client();
316 let cache = self.cache.clone();
317 ProviderCall::BoxedFuture(Box::pin(async move {
318 let client = client
319 .upgrade()
320 .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
321
322 let result = client.request(req.method(), req.params()).await?;
323
324 let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
325 let hash = req.params_hash()?;
326 let _ = cache.put(hash, json_str);
327
328 Ok(result)
329 }))
330 }
331}
332
333struct RequestType<Params: RpcSend> {
335 method: &'static str,
336 params: Params,
337 block_id: Option<BlockId>,
338}
339
340impl<Params: RpcSend> RequestType<Params> {
341 const fn new(method: &'static str, params: Params) -> Self {
342 Self { method, params, block_id: None }
343 }
344
345 const fn with_block_id(mut self, block_id: BlockId) -> Self {
346 self.block_id = Some(block_id);
347 self
348 }
349
350 fn params_hash(&self) -> TransportResult<B256> {
351 let hash = serde_json::to_string(&self.params())
353 .map(|p| keccak256(format!("{}{}", self.method(), p).as_bytes()))
354 .map_err(RpcError::ser_err)?;
355
356 Ok(hash)
357 }
358
359 const fn method(&self) -> &'static str {
360 self.method
361 }
362
363 fn params(&self) -> Params {
364 self.params.clone()
365 }
366
367 const fn has_block_tag(&self) -> bool {
370 if let Some(block_id) = self.block_id {
371 return !matches!(
372 block_id,
373 BlockId::Hash(_) | BlockId::Number(BlockNumberOrTag::Number(_))
374 );
375 }
376 false
377 }
378}
379
380#[derive(Debug, Serialize, Deserialize)]
381struct FsCacheEntry {
382 key: B256,
384 value: String,
386}
387
388#[derive(Debug, Clone)]
390pub struct SharedCache {
391 inner: Arc<RwLock<LruCache<B256, String, alloy_primitives::map::FbBuildHasher<32>>>>,
392 max_items: NonZero<usize>,
393}
394
395impl SharedCache {
396 pub fn new(max_items: u32) -> Self {
398 let max_items = NonZero::new(max_items as usize).unwrap_or(NonZero::<usize>::MIN);
399 let inner = Arc::new(RwLock::new(LruCache::with_hasher(max_items, Default::default())));
400 Self { inner, max_items }
401 }
402
403 pub const fn max_items(&self) -> u32 {
405 self.max_items.get() as u32
406 }
407
408 pub fn put(&self, key: B256, value: String) -> TransportResult<bool> {
410 Ok(self.inner.write().put(key, value).is_some())
411 }
412
413 pub fn get(&self, key: &B256) -> Option<String> {
415 self.inner.write().get(key).cloned()
417 }
418
419 pub fn get_deserialized<T>(&self, key: &B256) -> TransportResult<Option<T>>
421 where
422 T: for<'de> Deserialize<'de>,
423 {
424 let Some(cached) = self.get(key) else { return Ok(None) };
425 let result = serde_json::from_str(&cached).map_err(TransportErrorKind::custom)?;
426 Ok(Some(result))
427 }
428
429 pub fn save_cache(&self, path: PathBuf) -> TransportResult<()> {
433 let entries: Vec<FsCacheEntry> = {
434 self.inner
435 .read()
436 .iter()
437 .map(|(key, value)| FsCacheEntry { key: *key, value: value.clone() })
438 .collect()
439 };
440 let file = std::fs::File::create(path).map_err(TransportErrorKind::custom)?;
441 serde_json::to_writer(file, &entries).map_err(TransportErrorKind::custom)?;
442 Ok(())
443 }
444
445 pub fn load_cache(&self, path: PathBuf) -> TransportResult<()> {
448 if !path.exists() {
449 return Ok(());
450 };
451 let file = std::fs::File::open(path).map_err(TransportErrorKind::custom)?;
452 let file = BufReader::new(file);
453 let entries: Vec<FsCacheEntry> =
454 serde_json::from_reader(file).map_err(TransportErrorKind::custom)?;
455 let mut cache = self.inner.write();
456 for entry in entries {
457 cache.put(entry.key, entry.value);
458 }
459
460 Ok(())
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProviderBuilder;
    use alloy_network::TransactionBuilder;
    use alloy_node_bindings::{utils::run_with_tempdir, Anvil};
    use alloy_primitives::{bytes, hex, Bytes, FixedBytes};
    use alloy_rpc_types_eth::{BlockId, TransactionRequest};

    /// Deploys a contract, then checks that two identical `eth_getProof` calls at block 1 agree.
    /// The cache is round-tripped through a file in a temporary directory.
    #[tokio::test]
    async fn test_get_proof() {
        run_with_tempdir("get-proof", |dir| async move {
            let layer = CacheLayer::new(100);
            let cache = layer.cache();
            let anvil = Anvil::new().block_time_f64(0.3).spawn();
            let provider = ProviderBuilder::new().layer(layer).on_http(anvil.endpoint_url());

            let deployer = anvil.addresses()[0];
            let cache_file = dir.join("rpc-cache-proof.txt");

            cache.load_cache(cache_file.clone()).unwrap();

            let init_code: Bytes = "0x6080604052348015600f57600080fd5b506101f28061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c80633fb5c1cb146100465780638381f58a14610062578063d09de08a14610080575b600080fd5b610060600480360381019061005b91906100ee565b61008a565b005b61006a610094565b604051610077919061012a565b60405180910390f35b61008861009a565b005b8060008190555050565b60005481565b6000808154809291906100ac90610174565b9190505550565b600080fd5b6000819050919050565b6100cb816100b8565b81146100d657600080fd5b50565b6000813590506100e8816100c2565b92915050565b600060208284031215610104576101036100b3565b5b6000610112848285016100d9565b91505092915050565b610124816100b8565b82525050565b600060208201905061013f600083018461011b565b92915050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052601160045260246000fd5b600061017f826100b8565b91507fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff82036101b1576101b0610145565b5b60018201905091905056fea264697066735822122067ac0f21f648b0cacd1b7260772852ad4a0f63e2cc174168c51a6887fd5197a964736f6c634300081a0033".parse().unwrap();

            let deploy = TransactionRequest::default()
                .with_from(deployer)
                .with_input(init_code)
                .with_max_fee_per_gas(1_000_000_000)
                .with_max_priority_fee_per_gas(1_000_000)
                .with_gas_limit(1_000_000)
                .with_nonce(0);

            let deploy_receipt =
                provider.send_transaction(deploy).await.unwrap().get_receipt().await.unwrap();
            let counter = deploy_receipt.contract_address.unwrap();

            // Storage slots 0 through 4.
            let slots: Vec<StorageKey> = (0u8..=4).map(FixedBytes::with_last_byte).collect();

            let first =
                provider.get_proof(counter, slots.clone()).block_id(1.into()).await.unwrap();
            let second = provider.get_proof(counter, slots).block_id(1.into()).await.unwrap();
            assert_eq!(first, second);

            cache.save_cache(cache_file).unwrap();
        })
        .await;
    }

    /// Sends a transaction and checks that repeated lookups of the transaction and of its
    /// receipt return equal values.
    #[tokio::test]
    async fn test_get_tx_by_hash_and_receipt() {
        run_with_tempdir("get-tx-by-hash", |dir| async move {
            let layer = CacheLayer::new(100);
            let cache = layer.cache();
            let anvil = Anvil::new().block_time_f64(0.3).spawn();
            let provider = ProviderBuilder::new()
                .disable_recommended_fillers()
                .layer(layer)
                .on_http(anvil.endpoint_url());

            let cache_file = dir.join("rpc-cache-tx.txt");
            cache.load_cache(cache_file.clone()).unwrap();

            let request = TransactionRequest::default()
                .from(anvil.addresses()[0])
                .to(Address::repeat_byte(5))
                .value(U256::ZERO)
                .input(bytes!("deadbeef").into());

            let tx_hash =
                *provider.send_transaction(request).await.expect("failed to send tx").tx_hash();

            let tx = provider.get_transaction_by_hash(tx_hash).await.unwrap();
            let tx_again = provider.get_transaction_by_hash(tx_hash).await.unwrap();
            assert_eq!(tx, tx_again);

            let receipt = provider.get_transaction_receipt(tx_hash).await.unwrap();
            let receipt_again = provider.get_transaction_receipt(tx_hash).await.unwrap();
            assert_eq!(receipt, receipt_again);

            cache.save_cache(cache_file).unwrap();
        })
        .await;
    }

    /// Submits a pre-signed raw transaction and checks that repeated `eth_getBlockReceipts`
    /// calls for its block agree and contain its receipt.
    #[tokio::test]
    async fn test_block_receipts() {
        run_with_tempdir("get-block-receipts", |dir| async move {
            let layer = CacheLayer::new(100);
            let cache = layer.cache();
            let anvil = Anvil::new().spawn();
            let provider = ProviderBuilder::new().layer(layer).on_http(anvil.endpoint_url());

            let cache_file = dir.join("rpc-cache-block-receipts.txt");
            cache.load_cache(cache_file.clone()).unwrap();

            let receipt = provider
                .send_raw_transaction(
                    bytes!("f865808477359400825208940000000000000000000000000000000000000000018082f4f5a00505e227c1c636c76fac55795db1a40a4d24840d81b40d2fe0cc85767f6bd202a01e91b437099a8a90234ac5af3cb7ca4fb1432e133f75f9a91678eaf5f487c74b").as_ref(),
                )
                .await
                .unwrap()
                .get_receipt()
                .await
                .unwrap();

            let block_number = receipt.block_number.unwrap();

            let receipts = provider.get_block_receipts(block_number.into()).await.unwrap();
            let receipts_again = provider.get_block_receipts(block_number.into()).await.unwrap();
            assert_eq!(receipts, receipts_again);

            assert!(receipts.is_some_and(|r| r[0] == receipt));

            cache.save_cache(cache_file).unwrap();
        })
        .await
    }

    /// Deploys a contract and checks that repeated `eth_getCode` calls at its deployment block
    /// agree.
    #[tokio::test]
    async fn test_get_code() {
        run_with_tempdir("get-code", |dir| async move {
            let layer = CacheLayer::new(100);
            let cache = layer.cache();
            let provider = ProviderBuilder::default()
                .with_gas_estimation()
                .layer(layer)
                .on_anvil_with_wallet();

            let cache_file = dir.join("rpc-cache-code.txt");
            cache.load_cache(cache_file.clone()).unwrap();

            let bytecode = hex::decode(
                "6080806040523460135760df908160198239f35b600080fdfe6080806040526004361015601257600080fd5b60003560e01c9081633fb5c1cb1460925781638381f58a146079575063d09de08a14603c57600080fd5b3460745760003660031901126074576000546000198114605e57600101600055005b634e487b7160e01b600052601160045260246000fd5b600080fd5b3460745760003660031901126074576020906000548152f35b34607457602036600319011260745760043560005500fea2646970667358221220e978270883b7baed10810c4079c941512e93a7ba1cd1108c781d4bc738d9090564736f6c634300081a0033",
            )
            .unwrap();
            let deploy = TransactionRequest::default()
                .with_nonce(0)
                .with_deploy_code(bytecode)
                .with_chain_id(31337);

            let receipt = provider.send_transaction(deploy).await.unwrap().get_receipt().await.unwrap();
            let counter = receipt.contract_address.unwrap();
            let block_id = BlockId::number(receipt.block_number.unwrap());

            let code = provider.get_code_at(counter).block_id(block_id).await.unwrap();
            let code_again = provider.get_code_at(counter).block_id(block_id).await.unwrap();
            assert_eq!(code, code_again);

            cache.save_cache(cache_file).unwrap();
        })
        .await;
    }
}