// ethers_providers/ext/dev_rpc.rs
use crate::{Middleware, MiddlewareError, ProviderError};
use async_trait::async_trait;
use ethers_core::types::U256;
use thiserror::Error;

use std::fmt::Debug;
42
/// Middleware wrapper that exposes development-node RPC helpers
/// (`evm_snapshot` / `evm_revert`) on top of the wrapped middleware `M`.
#[derive(Debug, Clone)]
pub struct DevRpcMiddleware<M>(M);
46
47#[derive(Error, Debug)]
49pub enum DevRpcMiddlewareError<M: Middleware> {
50 #[error("{0}")]
52 MiddlewareError(M::Error),
53
54 #[error("{0}")]
56 ProviderError(ProviderError),
57
58 #[error("Could not revert to snapshot")]
60 NoSnapshot,
61}
62
63#[async_trait]
64impl<M: Middleware> Middleware for DevRpcMiddleware<M> {
65 type Error = DevRpcMiddlewareError<M>;
66 type Provider = M::Provider;
67 type Inner = M;
68
69 fn inner(&self) -> &M {
70 &self.0
71 }
72}
73
74impl<M: Middleware> MiddlewareError for DevRpcMiddlewareError<M> {
75 type Inner = M::Error;
76
77 fn from_err(src: M::Error) -> DevRpcMiddlewareError<M> {
78 DevRpcMiddlewareError::MiddlewareError(src)
79 }
80
81 fn as_inner(&self) -> Option<&Self::Inner> {
82 match self {
83 DevRpcMiddlewareError::MiddlewareError(e) => Some(e),
84 _ => None,
85 }
86 }
87}
88
89impl<M> From<ProviderError> for DevRpcMiddlewareError<M>
90where
91 M: Middleware,
92{
93 fn from(src: ProviderError) -> Self {
94 Self::ProviderError(src)
95 }
96}
97
98impl<M: Middleware> DevRpcMiddleware<M> {
99 pub fn new(inner: M) -> Self {
101 Self(inner)
102 }
103
104 pub async fn snapshot(&self) -> Result<U256, DevRpcMiddlewareError<M>> {
110 self.provider().request::<(), U256>("evm_snapshot", ()).await.map_err(From::from)
111 }
112
113 pub async fn revert_to_snapshot(&self, id: U256) -> Result<(), DevRpcMiddlewareError<M>> {
115 let ok = self
116 .provider()
117 .request::<[U256; 1], bool>("evm_revert", [id])
118 .await
119 .map_err(DevRpcMiddlewareError::ProviderError)?;
120 if ok {
121 Ok(())
122 } else {
123 Err(DevRpcMiddlewareError::NoSnapshot)
124 }
125 }
126}
127
#[cfg(test)]
#[cfg(not(feature = "celo"))]
mod tests {
    use super::*;
    use crate::{Http, Provider};
    use ethers_core::utils::Anvil;
    use std::convert::TryFrom;

    /// Fetches the latest block number together with that block's timestamp.
    async fn chain_head(
        client: &DevRpcMiddleware<Provider<Http>>,
    ) -> (ethers_core::types::U64, U256) {
        let number = client.get_block_number().await.unwrap();
        let block = client.get_block(number).await.unwrap().unwrap();
        (number, block.timestamp)
    }

    #[tokio::test]
    async fn test_snapshot() {
        let anvil = Anvil::new().spawn();
        let provider = Provider::<Http>::try_from(anvil.endpoint()).unwrap();
        let client = DevRpcMiddleware::new(provider);

        // Checkpoint 0: genesis head.
        let head0 = chain_head(&client).await;
        let snap_id0 = client.snapshot().await.unwrap();
        client.provider().mine(1).await.unwrap();

        // Checkpoint 1: one block later.
        let head1 = chain_head(&client).await;
        let snap_id1 = client.snapshot().await.unwrap();
        client.provider().mine(5).await.unwrap();

        // Checkpoint 2: five more blocks later.
        let head2 = chain_head(&client).await;
        let snap_id2 = client.snapshot().await.unwrap();
        client.provider().mine(5).await.unwrap();

        // Unwind the checkpoints newest-first; each revert must restore both
        // the block number and the timestamp recorded at that checkpoint.
        client.revert_to_snapshot(snap_id2).await.unwrap();
        assert_eq!(chain_head(&client).await, head2);

        client.revert_to_snapshot(snap_id1).await.unwrap();
        assert_eq!(chain_head(&client).await, head1);

        // snap_id1 was consumed by the revert above, so reusing it must fail.
        let reused = client.revert_to_snapshot(snap_id1).await;
        assert!(reused.is_err());

        client.revert_to_snapshot(snap_id0).await.unwrap();
        assert_eq!(chain_head(&client).await, head0);
    }
}