// ethers_middleware/nonce_manager.rs

1use async_trait::async_trait;
2use ethers_core::types::{transaction::eip2718::TypedTransaction, *};
3use ethers_providers::{Middleware, MiddlewareError, PendingTransaction};
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use thiserror::Error;
6
7#[derive(Debug)]
8/// Middleware used for calculating nonces locally, useful for signing multiple
9/// consecutive transactions without waiting for them to hit the mempool
10pub struct NonceManagerMiddleware<M> {
11    inner: M,
12    init_guard: futures_locks::Mutex<()>,
13    initialized: AtomicBool,
14    nonce: AtomicU64,
15    address: Address,
16}
17
18impl<M> NonceManagerMiddleware<M>
19where
20    M: Middleware,
21{
22    /// Instantiates the nonce manager with a 0 nonce. The `address` should be the
23    /// address which you'll be sending transactions from
24    pub fn new(inner: M, address: Address) -> Self {
25        Self {
26            inner,
27            init_guard: Default::default(),
28            initialized: Default::default(),
29            nonce: Default::default(),
30            address,
31        }
32    }
33
34    /// Returns the next nonce to be used
35    pub fn next(&self) -> U256 {
36        let nonce = self.nonce.fetch_add(1, Ordering::SeqCst);
37        nonce.into()
38    }
39
40    pub async fn initialize_nonce(
41        &self,
42        block: Option<BlockId>,
43    ) -> Result<U256, NonceManagerError<M>> {
44        if self.initialized.load(Ordering::SeqCst) {
45            // return current nonce
46            return Ok(self.nonce.load(Ordering::SeqCst).into())
47        }
48
49        let _guard = self.init_guard.lock().await;
50
51        // do this again in case multiple tasks enter this codepath
52        if self.initialized.load(Ordering::SeqCst) {
53            // return current nonce
54            return Ok(self.nonce.load(Ordering::SeqCst).into())
55        }
56
57        // initialize the nonce the first time the manager is called
58        let nonce = self
59            .inner
60            .get_transaction_count(self.address, block)
61            .await
62            .map_err(MiddlewareError::from_err)?;
63        self.nonce.store(nonce.as_u64(), Ordering::SeqCst);
64        self.initialized.store(true, Ordering::SeqCst);
65        Ok(nonce)
66    } // guard dropped here
67
68    async fn get_transaction_count_with_manager(
69        &self,
70        block: Option<BlockId>,
71    ) -> Result<U256, NonceManagerError<M>> {
72        // initialize the nonce the first time the manager is called
73        if !self.initialized.load(Ordering::SeqCst) {
74            let nonce = self
75                .inner
76                .get_transaction_count(self.address, block)
77                .await
78                .map_err(MiddlewareError::from_err)?;
79            self.nonce.store(nonce.as_u64(), Ordering::SeqCst);
80            self.initialized.store(true, Ordering::SeqCst);
81        }
82
83        Ok(self.next())
84    }
85}
86
87#[derive(Error, Debug)]
88/// Thrown when an error happens at the Nonce Manager
89pub enum NonceManagerError<M: Middleware> {
90    /// Thrown when the internal middleware errors
91    #[error("{0}")]
92    MiddlewareError(M::Error),
93}
94
95impl<M: Middleware> MiddlewareError for NonceManagerError<M> {
96    type Inner = M::Error;
97
98    fn from_err(src: M::Error) -> Self {
99        NonceManagerError::MiddlewareError(src)
100    }
101
102    fn as_inner(&self) -> Option<&Self::Inner> {
103        match self {
104            NonceManagerError::MiddlewareError(e) => Some(e),
105        }
106    }
107}
108
109#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
110#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
111impl<M> Middleware for NonceManagerMiddleware<M>
112where
113    M: Middleware,
114{
115    type Error = NonceManagerError<M>;
116    type Provider = M::Provider;
117    type Inner = M;
118
119    fn inner(&self) -> &M {
120        &self.inner
121    }
122
123    async fn fill_transaction(
124        &self,
125        tx: &mut TypedTransaction,
126        block: Option<BlockId>,
127    ) -> Result<(), Self::Error> {
128        if tx.nonce().is_none() {
129            tx.set_nonce(self.get_transaction_count_with_manager(block).await?);
130        }
131
132        Ok(self.inner().fill_transaction(tx, block).await.map_err(MiddlewareError::from_err)?)
133    }
134
135    /// Signs and broadcasts the transaction. The optional parameter `block` can be passed so that
136    /// gas cost and nonce calculations take it into account. For simple transactions this can be
137    /// left to `None`.
138    async fn send_transaction<T: Into<TypedTransaction> + Send + Sync>(
139        &self,
140        tx: T,
141        block: Option<BlockId>,
142    ) -> Result<PendingTransaction<'_, Self::Provider>, Self::Error> {
143        let mut tx = tx.into();
144
145        if tx.nonce().is_none() {
146            tx.set_nonce(self.get_transaction_count_with_manager(block).await?);
147        }
148
149        match self.inner.send_transaction(tx.clone(), block).await {
150            Ok(tx_hash) => Ok(tx_hash),
151            Err(err) => {
152                let nonce = self.get_transaction_count(self.address, block).await?;
153                if nonce != self.nonce.load(Ordering::SeqCst).into() {
154                    // try re-submitting the transaction with the correct nonce if there
155                    // was a nonce mismatch
156                    self.nonce.store(nonce.as_u64(), Ordering::SeqCst);
157                    tx.set_nonce(nonce);
158                    self.inner.send_transaction(tx, block).await.map_err(MiddlewareError::from_err)
159                } else {
160                    // propagate the error otherwise
161                    Err(MiddlewareError::from_err(err))
162                }
163            }
164        }
165    }
166}