// alloy_provider/layers/cache.rs

1use crate::{ParamsWithBlock, Provider, ProviderCall, ProviderLayer, RootProvider, RpcWithBlock};
2use alloy_eips::BlockId;
3use alloy_json_rpc::{RpcError, RpcSend};
4use alloy_network::Network;
5use alloy_primitives::{keccak256, Address, Bytes, StorageKey, StorageValue, TxHash, B256, U256};
6use alloy_rpc_types_eth::{BlockNumberOrTag, EIP1186AccountProofResponse, Filter, Log};
7use alloy_transport::{TransportErrorKind, TransportResult};
8use lru::LruCache;
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::{io::BufReader, marker::PhantomData, num::NonZero, path::PathBuf, sync::Arc};
12/// A provider layer that caches RPC responses and serves them on subsequent requests.
13///
14/// In order to initialize the caching layer, the path to the cache file is provided along with the
15/// max number of items that are stored in the in-memory LRU cache.
16///
17/// One can load the cache from the file system by calling `load_cache` and save the cache to the
18/// file system by calling `save_cache`.
19#[derive(Debug, Clone)]
20pub struct CacheLayer {
21    /// In-memory LRU cache, mapping requests to responses.
22    cache: SharedCache,
23}
24
25impl CacheLayer {
26    /// Instantiate a new cache layer with the the maximum number of
27    /// items to store.
28    pub fn new(max_items: u32) -> Self {
29        Self { cache: SharedCache::new(max_items) }
30    }
31
32    /// Returns the maximum number of items that can be stored in the cache, set at initialization.
33    pub const fn max_items(&self) -> u32 {
34        self.cache.max_items()
35    }
36
37    /// Returns the shared cache.
38    pub fn cache(&self) -> SharedCache {
39        self.cache.clone()
40    }
41}
42
43impl<P, N> ProviderLayer<P, N> for CacheLayer
44where
45    P: Provider<N>,
46    N: Network,
47{
48    type Provider = CacheProvider<P, N>;
49
50    fn layer(&self, inner: P) -> Self::Provider {
51        CacheProvider::new(inner, self.cache())
52    }
53}
54
55/// The [`CacheProvider`] holds the underlying in-memory LRU cache and overrides methods
56/// from the [`Provider`] trait. It attempts to fetch from the cache and fallbacks to
57/// the RPC in case of a cache miss.
58///
59/// Most importantly, the [`CacheProvider`] adds `save_cache` and `load_cache` methods
60/// to the provider interface, allowing users to save the cache to disk and load it
61/// from there on demand.
62#[derive(Debug, Clone)]
63pub struct CacheProvider<P, N> {
64    /// Inner provider.
65    inner: P,
66    /// In-memory LRU cache, mapping requests to responses.
67    cache: SharedCache,
68    /// Phantom data
69    _pd: PhantomData<N>,
70}
71
72impl<P, N> CacheProvider<P, N>
73where
74    P: Provider<N>,
75    N: Network,
76{
77    /// Instantiate a new cache provider.
78    pub const fn new(inner: P, cache: SharedCache) -> Self {
79        Self { inner, cache, _pd: PhantomData }
80    }
81}
82
/// Uses the underlying transport client to fetch data from the RPC.
///
/// This is specific to RPC requests that require the `block_id` parameter: the request params are
/// wrapped in [`ParamsWithBlock`], using the latest block when no block id is set.
///
/// Fetches from the RPC and saves the response to the cache, unless the request targets a block
/// tag (`latest`, `pending`, ...), whose data changes as the chain advances.
///
/// Evaluates to a [`ProviderCall::BoxedFuture`]. The future fails with a custom transport error
/// if the RPC client has been dropped.
macro_rules! rpc_call_with_block {
    ($cache:expr, $client:expr, $req:expr) => {{
        // Upgrade the weak client handle now; a dropped client surfaces as an error when the
        // future is polled.
        let client =
            $client.upgrade().ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"));
        let cache = $cache.clone();
        ProviderCall::BoxedFuture(Box::pin(async move {
            let client = client?;

            let result = client.request($req.method(), $req.params()).map_params(|params| {
                ParamsWithBlock::new(params, $req.block_id.unwrap_or(BlockId::latest()))
            });

            let res = result.await?;

            // Tagged-block responses are never read back from the cache and would go stale, so
            // only responses for blocks pinned by hash or number are stored.
            if !$req.has_block_tag() {
                let json_str = serde_json::to_string(&res).map_err(TransportErrorKind::custom)?;
                let hash = $req.params_hash()?;
                let _ = cache.put(hash, json_str);
            }

            Ok(res)
        }))
    }};
}
112
/// Serves a block-scoped request from the cache when possible, keyed by the request's param hash.
///
/// Requests that target a block tag skip the lookup entirely. On a cache miss (or when the hash
/// cannot be computed or the cached entry cannot be deserialized) the request is sent to the RPC
/// through `rpc_call_with_block!`.
///
/// Meant for overriding [`Provider`] methods that return `RpcWithBlock`; a cache hit `return`s
/// from the enclosing closure.
macro_rules! cache_rpc_call_with_block {
    ($cache:expr, $client:expr, $req:expr) => {{
        // Tagged blocks ("latest", "pending", ...) can resolve to different data over time, so
        // only requests pinned to a block hash or number consult the cache.
        if !$req.has_block_tag() {
            if let Ok(hash) = $req.params_hash() {
                if let Ok(Some(cached)) = $cache.get_deserialized(&hash) {
                    return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
                }
            }
        }

        // Cache miss or uncacheable request: go to the RPC.
        rpc_call_with_block!($cache, $client, $req)
    }};
}
135
136#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
137#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
138impl<P, N> Provider<N> for CacheProvider<P, N>
139where
140    P: Provider<N>,
141    N: Network,
142{
143    #[inline(always)]
144    fn root(&self) -> &RootProvider<N> {
145        self.inner.root()
146    }
147
148    fn get_block_receipts(
149        &self,
150        block: BlockId,
151    ) -> ProviderCall<(BlockId,), Option<Vec<N::ReceiptResponse>>> {
152        let req = RequestType::new("eth_getBlockReceipts", (block,));
153
154        let redirect = req.has_block_tag();
155
156        if !redirect {
157            let params_hash = req.params_hash().ok();
158
159            if let Some(hash) = params_hash {
160                if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
161                    return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
162                }
163            }
164        }
165
166        let client = self.inner.weak_client();
167        let cache = self.cache.clone();
168
169        ProviderCall::BoxedFuture(Box::pin(async move {
170            let client = client
171                .upgrade()
172                .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
173
174            let result = client.request(req.method(), req.params()).await?;
175
176            let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
177
178            if !redirect {
179                let hash = req.params_hash()?;
180                let _ = cache.put(hash, json_str);
181            }
182
183            Ok(result)
184        }))
185    }
186
187    fn get_code_at(&self, address: Address) -> RpcWithBlock<Address, Bytes> {
188        let client = self.inner.weak_client();
189        let cache = self.cache.clone();
190        RpcWithBlock::new_provider(move |block_id| {
191            let req = RequestType::new("eth_getCode", address).with_block_id(block_id);
192            cache_rpc_call_with_block!(cache, client, req)
193        })
194    }
195
196    async fn get_logs(&self, filter: &Filter) -> TransportResult<Vec<Log>> {
197        let req = RequestType::new("eth_getLogs", filter.clone());
198
199        let params_hash = req.params_hash().ok();
200
201        if let Some(hash) = params_hash {
202            if let Some(cached) = self.cache.get_deserialized(&hash)? {
203                return Ok(cached);
204            }
205        }
206
207        let result = self.inner.get_logs(filter).await?;
208
209        let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
210
211        let hash = req.params_hash()?;
212        let _ = self.cache.put(hash, json_str);
213
214        Ok(result)
215    }
216
217    fn get_proof(
218        &self,
219        address: Address,
220        keys: Vec<StorageKey>,
221    ) -> RpcWithBlock<(Address, Vec<StorageKey>), EIP1186AccountProofResponse> {
222        let client = self.inner.weak_client();
223        let cache = self.cache.clone();
224        RpcWithBlock::new_provider(move |block_id| {
225            let req =
226                RequestType::new("eth_getProof", (address, keys.clone())).with_block_id(block_id);
227            cache_rpc_call_with_block!(cache, client, req)
228        })
229    }
230
231    fn get_storage_at(
232        &self,
233        address: Address,
234        key: U256,
235    ) -> RpcWithBlock<(Address, U256), StorageValue> {
236        let client = self.inner.weak_client();
237        let cache = self.cache.clone();
238        RpcWithBlock::new_provider(move |block_id| {
239            let req = RequestType::new("eth_getStorageAt", (address, key)).with_block_id(block_id);
240            cache_rpc_call_with_block!(cache, client, req)
241        })
242    }
243
244    fn get_transaction_by_hash(
245        &self,
246        hash: TxHash,
247    ) -> ProviderCall<(TxHash,), Option<N::TransactionResponse>> {
248        let req = RequestType::new("eth_getTransactionByHash", (hash,));
249
250        let params_hash = req.params_hash().ok();
251
252        if let Some(hash) = params_hash {
253            if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
254                return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
255            }
256        }
257        let client = self.inner.weak_client();
258        let cache = self.cache.clone();
259        ProviderCall::BoxedFuture(Box::pin(async move {
260            let client = client
261                .upgrade()
262                .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
263            let result = client.request(req.method(), req.params()).await?;
264
265            let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
266            let hash = req.params_hash()?;
267            let _ = cache.put(hash, json_str);
268
269            Ok(result)
270        }))
271    }
272
273    fn get_raw_transaction_by_hash(&self, hash: TxHash) -> ProviderCall<(TxHash,), Option<Bytes>> {
274        let req = RequestType::new("eth_getRawTransactionByHash", (hash,));
275
276        let params_hash = req.params_hash().ok();
277
278        if let Some(hash) = params_hash {
279            if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
280                return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
281            }
282        }
283
284        let client = self.inner.weak_client();
285        let cache = self.cache.clone();
286        ProviderCall::BoxedFuture(Box::pin(async move {
287            let client = client
288                .upgrade()
289                .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
290
291            let result = client.request(req.method(), req.params()).await?;
292
293            let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
294            let hash = req.params_hash()?;
295            let _ = cache.put(hash, json_str);
296
297            Ok(result)
298        }))
299    }
300
301    fn get_transaction_receipt(
302        &self,
303        hash: TxHash,
304    ) -> ProviderCall<(TxHash,), Option<N::ReceiptResponse>> {
305        let req = RequestType::new("eth_getTransactionReceipt", (hash,));
306
307        let params_hash = req.params_hash().ok();
308
309        if let Some(hash) = params_hash {
310            if let Ok(Some(cached)) = self.cache.get_deserialized(&hash) {
311                return ProviderCall::BoxedFuture(Box::pin(async move { Ok(cached) }));
312            }
313        }
314
315        let client = self.inner.weak_client();
316        let cache = self.cache.clone();
317        ProviderCall::BoxedFuture(Box::pin(async move {
318            let client = client
319                .upgrade()
320                .ok_or_else(|| TransportErrorKind::custom_str("RPC client dropped"))?;
321
322            let result = client.request(req.method(), req.params()).await?;
323
324            let json_str = serde_json::to_string(&result).map_err(TransportErrorKind::custom)?;
325            let hash = req.params_hash()?;
326            let _ = cache.put(hash, json_str);
327
328            Ok(result)
329        }))
330    }
331}
332
333/// Internal type to handle different types of requests and generating their param hashes.
334struct RequestType<Params: RpcSend> {
335    method: &'static str,
336    params: Params,
337    block_id: Option<BlockId>,
338}
339
340impl<Params: RpcSend> RequestType<Params> {
341    const fn new(method: &'static str, params: Params) -> Self {
342        Self { method, params, block_id: None }
343    }
344
345    const fn with_block_id(mut self, block_id: BlockId) -> Self {
346        self.block_id = Some(block_id);
347        self
348    }
349
350    fn params_hash(&self) -> TransportResult<B256> {
351        // Merge the method + params and hash them.
352        let hash = serde_json::to_string(&self.params())
353            .map(|p| keccak256(format!("{}{}", self.method(), p).as_bytes()))
354            .map_err(RpcError::ser_err)?;
355
356        Ok(hash)
357    }
358
359    const fn method(&self) -> &'static str {
360        self.method
361    }
362
363    fn params(&self) -> Params {
364        self.params.clone()
365    }
366
367    /// Returns true if the BlockId has been set to a tag value such as "latest", "earliest", or
368    /// "pending".
369    const fn has_block_tag(&self) -> bool {
370        if let Some(block_id) = self.block_id {
371            return !matches!(
372                block_id,
373                BlockId::Hash(_) | BlockId::Number(BlockNumberOrTag::Number(_))
374            );
375        }
376        false
377    }
378}
379
380#[derive(Debug, Serialize, Deserialize)]
381struct FsCacheEntry {
382    /// Hash of the request params
383    key: B256,
384    /// Serialized response to the request from which the hash was computed.
385    value: String,
386}
387
388/// Shareable cache.
389#[derive(Debug, Clone)]
390pub struct SharedCache {
391    inner: Arc<RwLock<LruCache<B256, String, alloy_primitives::map::FbBuildHasher<32>>>>,
392    max_items: NonZero<usize>,
393}
394
395impl SharedCache {
396    /// Instantiate a new shared cache.
397    pub fn new(max_items: u32) -> Self {
398        let max_items = NonZero::new(max_items as usize).unwrap_or(NonZero::<usize>::MIN);
399        let inner = Arc::new(RwLock::new(LruCache::with_hasher(max_items, Default::default())));
400        Self { inner, max_items }
401    }
402
403    /// Maximum number of items that can be stored in the cache.
404    pub const fn max_items(&self) -> u32 {
405        self.max_items.get() as u32
406    }
407
408    /// Puts a value into the cache, and returns the old value if it existed.
409    pub fn put(&self, key: B256, value: String) -> TransportResult<bool> {
410        Ok(self.inner.write().put(key, value).is_some())
411    }
412
413    /// Gets a value from the cache, if it exists.
414    pub fn get(&self, key: &B256) -> Option<String> {
415        // Need to acquire a write guard to change the order of keys in LRU cache.
416        self.inner.write().get(key).cloned()
417    }
418
419    /// Get deserialized value from the cache.
420    pub fn get_deserialized<T>(&self, key: &B256) -> TransportResult<Option<T>>
421    where
422        T: for<'de> Deserialize<'de>,
423    {
424        let Some(cached) = self.get(key) else { return Ok(None) };
425        let result = serde_json::from_str(&cached).map_err(TransportErrorKind::custom)?;
426        Ok(Some(result))
427    }
428
429    /// Saves the cache to a file specified by the path.
430    /// If the files does not exist, it creates one.
431    /// If the file exists, it overwrites it.
432    pub fn save_cache(&self, path: PathBuf) -> TransportResult<()> {
433        let entries: Vec<FsCacheEntry> = {
434            self.inner
435                .read()
436                .iter()
437                .map(|(key, value)| FsCacheEntry { key: *key, value: value.clone() })
438                .collect()
439        };
440        let file = std::fs::File::create(path).map_err(TransportErrorKind::custom)?;
441        serde_json::to_writer(file, &entries).map_err(TransportErrorKind::custom)?;
442        Ok(())
443    }
444
445    /// Loads the cache from a file specified by the path.
446    /// If the file does not exist, it returns without error.
447    pub fn load_cache(&self, path: PathBuf) -> TransportResult<()> {
448        if !path.exists() {
449            return Ok(());
450        };
451        let file = std::fs::File::open(path).map_err(TransportErrorKind::custom)?;
452        let file = BufReader::new(file);
453        let entries: Vec<FsCacheEntry> =
454            serde_json::from_reader(file).map_err(TransportErrorKind::custom)?;
455        let mut cache = self.inner.write();
456        for entry in entries {
457            cache.put(entry.key, entry.value);
458        }
459
460        Ok(())
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProviderBuilder;
    use alloy_network::TransactionBuilder;
    use alloy_node_bindings::{utils::run_with_tempdir, Anvil};
    use alloy_primitives::{bytes, hex, Bytes, FixedBytes};
    use alloy_rpc_types_eth::{BlockId, TransactionRequest};

    #[tokio::test]
    async fn test_get_proof() {
        run_with_tempdir("get-proof", |dir| async move {
            let cache_layer = CacheLayer::new(100);
            let shared_cache = cache_layer.cache();
            let anvil = Anvil::new().block_time_f64(0.3).spawn();
            let provider = ProviderBuilder::new().layer(cache_layer).on_http(anvil.endpoint_url());

            let from = anvil.addresses()[0];
            let path = dir.join("rpc-cache-proof.txt");

            shared_cache.load_cache(path.clone()).unwrap();

            let calldata: Bytes = "0x6080604052348015600f57600080fd5b506101f28061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c80633fb5c1cb146100465780638381f58a14610062578063d09de08a14610080575b600080fd5b610060600480360381019061005b91906100ee565b61008a565b005b61006a610094565b604051610077919061012a565b60405180910390f35b61008861009a565b005b8060008190555050565b60005481565b6000808154809291906100ac90610174565b9190505550565b600080fd5b6000819050919050565b6100cb816100b8565b81146100d657600080fd5b50565b6000813590506100e8816100c2565b92915050565b600060208284031215610104576101036100b3565b5b6000610112848285016100d9565b91505092915050565b610124816100b8565b82525050565b600060208201905061013f600083018461011b565b92915050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052601160045260246000fd5b600061017f826100b8565b91507fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff82036101b1576101b0610145565b5b60018201905091905056fea264697066735822122067ac0f21f648b0cacd1b7260772852ad4a0f63e2cc174168c51a6887fd5197a964736f6c634300081a0033".parse().unwrap();

            let tx = TransactionRequest::default()
                .with_from(from)
                .with_input(calldata)
                .with_max_fee_per_gas(1_000_000_000)
                .with_max_priority_fee_per_gas(1_000_000)
                .with_gas_limit(1_000_000)
                .with_nonce(0);

            let tx_receipt = provider.send_transaction(tx).await.unwrap().get_receipt().await.unwrap();

            let counter_addr = tx_receipt.contract_address.unwrap();

            let keys = vec![
                FixedBytes::with_last_byte(0),
                FixedBytes::with_last_byte(0x1),
                FixedBytes::with_last_byte(0x2),
                FixedBytes::with_last_byte(0x3),
                FixedBytes::with_last_byte(0x4),
            ];

            let proof =
                provider.get_proof(counter_addr, keys.clone()).block_id(1.into()).await.unwrap(); // Received from RPC.
            let proof2 = provider.get_proof(counter_addr, keys).block_id(1.into()).await.unwrap(); // Received from cache.

            assert_eq!(proof, proof2);

            shared_cache.save_cache(path).unwrap();
        }).await;
    }

    #[tokio::test]
    async fn test_get_tx_by_hash_and_receipt() {
        run_with_tempdir("get-tx-by-hash", |dir| async move {
            let cache_layer = CacheLayer::new(100);
            let shared_cache = cache_layer.cache();
            let anvil = Anvil::new().block_time_f64(0.3).spawn();
            let provider = ProviderBuilder::new()
                .disable_recommended_fillers()
                .layer(cache_layer)
                .on_http(anvil.endpoint_url());

            let path = dir.join("rpc-cache-tx.txt");
            shared_cache.load_cache(path.clone()).unwrap();

            let req = TransactionRequest::default()
                .from(anvil.addresses()[0])
                .to(Address::repeat_byte(5))
                .value(U256::ZERO)
                .input(bytes!("deadbeef").into());

            let pending = provider.send_transaction(req).await.expect("failed to send tx");
            let tx_hash = *pending.tx_hash();
            // Wait until the transaction is mined. A pending transaction has no receipt yet, and
            // comparing two `None`s would not exercise the cache.
            pending.get_receipt().await.expect("failed to get receipt");

            let tx = provider.get_transaction_by_hash(tx_hash).await.unwrap(); // Received from RPC.
            let tx2 = provider.get_transaction_by_hash(tx_hash).await.unwrap(); // Received from cache.
            assert!(tx.is_some());
            assert_eq!(tx, tx2);

            let receipt = provider.get_transaction_receipt(tx_hash).await.unwrap(); // Received from RPC.
            let receipt2 = provider.get_transaction_receipt(tx_hash).await.unwrap(); // Received from cache.
            assert!(receipt.is_some());
            assert_eq!(receipt, receipt2);

            shared_cache.save_cache(path).unwrap();
        })
        .await;
    }

    #[tokio::test]
    async fn test_block_receipts() {
        run_with_tempdir("get-block-receipts", |dir| async move {
            let cache_layer = CacheLayer::new(100);
            let shared_cache = cache_layer.cache();
            let anvil = Anvil::new().spawn();
            let provider = ProviderBuilder::new().layer(cache_layer).on_http(anvil.endpoint_url());

            let path = dir.join("rpc-cache-block-receipts.txt");
            shared_cache.load_cache(path.clone()).unwrap();

            // Send txs

            let receipt = provider
                    .send_raw_transaction(
                        // Transfer 1 ETH from default EOA address to the Genesis address.
                        bytes!("f865808477359400825208940000000000000000000000000000000000000000018082f4f5a00505e227c1c636c76fac55795db1a40a4d24840d81b40d2fe0cc85767f6bd202a01e91b437099a8a90234ac5af3cb7ca4fb1432e133f75f9a91678eaf5f487c74b").as_ref()
                    )
                    .await.unwrap().get_receipt().await.unwrap();

            let block_number = receipt.block_number.unwrap();

            let receipts =
                provider.get_block_receipts(block_number.into()).await.unwrap(); // Received from RPC.
            let receipts2 =
                provider.get_block_receipts(block_number.into()).await.unwrap(); // Received from cache.
            assert_eq!(receipts, receipts2);

            assert!(receipts.is_some_and(|r| r[0] == receipt));

            shared_cache.save_cache(path).unwrap();
        })
        .await
    }

    #[tokio::test]
    async fn test_get_code() {
        run_with_tempdir("get-code", |dir| async move {
            let cache_layer = CacheLayer::new(100);
            let shared_cache = cache_layer.cache();
            let provider = ProviderBuilder::default().with_gas_estimation().layer(cache_layer).on_anvil_with_wallet();

            let path = dir.join("rpc-cache-code.txt");
            shared_cache.load_cache(path.clone()).unwrap();

            let bytecode = hex::decode(
                // solc v0.8.26; solc Counter.sol --via-ir --optimize --bin
                "6080806040523460135760df908160198239f35b600080fdfe6080806040526004361015601257600080fd5b60003560e01c9081633fb5c1cb1460925781638381f58a146079575063d09de08a14603c57600080fd5b3460745760003660031901126074576000546000198114605e57600101600055005b634e487b7160e01b600052601160045260246000fd5b600080fd5b3460745760003660031901126074576020906000548152f35b34607457602036600319011260745760043560005500fea2646970667358221220e978270883b7baed10810c4079c941512e93a7ba1cd1108c781d4bc738d9090564736f6c634300081a0033"
            ).unwrap();
            let tx = TransactionRequest::default().with_nonce(0).with_deploy_code(bytecode).with_chain_id(31337);

            let receipt = provider.send_transaction(tx).await.unwrap().get_receipt().await.unwrap();

            let counter_addr = receipt.contract_address.unwrap();

            let block_id = BlockId::number(receipt.block_number.unwrap());

            let code = provider.get_code_at(counter_addr).block_id(block_id).await.unwrap(); // Received from RPC.
            let code2 = provider.get_code_at(counter_addr).block_id(block_id).await.unwrap(); // Received from cache.
            assert_eq!(code, code2);

            shared_cache.save_cache(path).unwrap();
        })
        .await;
    }
}