Skip to main content

tempo_transaction_pool/
state_cache.rs

1//! Tip-scoped concurrent cache of state reads shared across transaction validations.
2
3use alloy_primitives::{Address, B256, U256, map::DefaultHashBuilder};
4use dashmap::DashMap;
5use revm::{Database, DatabaseRef, bytecode::Bytecode, state::AccountInfo};
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8/// Concurrent cache of raw state reads anchored to a specific tip.
9///
10/// Transaction validation repeatedly reads the same state (system contract configuration,
11/// fee token slots, sender accounts) for every transaction. This cache shares those reads
12/// across all concurrent validation calls so only the first access hits the underlying
13/// state provider.
14///
15/// The validator replaces the cache whenever a new head block is processed, mirroring the
16/// lifecycle of its cached EVM environment.
17#[derive(Debug, Default)]
18pub(crate) struct StateCache {
19    /// Cached basic account info, including non-existent accounts (`None`).
20    accounts: DashMap<Address, Option<AccountInfo>, DefaultHashBuilder>,
21    /// Cached storage values keyed by account and slot.
22    storage: DashMap<(Address, U256), U256, DefaultHashBuilder>,
23    /// Cached bytecode keyed by code hash.
24    contracts: DashMap<B256, Bytecode, DefaultHashBuilder>,
25    /// Approximate entry counts for cap enforcement; `DashMap::len` locks every shard and is
26    /// too expensive for the insert path. Racing inserts may overshoot the caps slightly.
27    account_count: AtomicUsize,
28    storage_count: AtomicUsize,
29    contract_count: AtomicUsize,
30}
31
32impl StateCache {
33    /// Maximum number of cached accounts.
34    ///
35    /// The caps bound memory if a flood of unique accounts is validated within a single
36    /// block interval; once reached, additional reads fall through to the state provider.
37    const MAX_ACCOUNTS: usize = 1 << 17;
38    /// Maximum number of cached storage slots.
39    const MAX_STORAGE_SLOTS: usize = 1 << 18;
40    /// Maximum number of cached contracts.
41    const MAX_CONTRACTS: usize = 1 << 12;
42}
43
44/// A [`DatabaseRef`] adapter that serves reads from a shared [`StateCache`], falling back
45/// to the wrapped database and populating the cache on miss.
46#[derive(Debug)]
47pub(crate) struct StateCacheDb<'a, DB> {
48    /// The shared read cache.
49    cache: &'a StateCache,
50    /// The underlying database.
51    db: DB,
52}
53
54impl<'a, DB> StateCacheDb<'a, DB> {
55    pub(crate) const fn new(cache: &'a StateCache, db: DB) -> Self {
56        Self { cache, db }
57    }
58}
59
60impl<DB: DatabaseRef> DatabaseRef for StateCacheDb<'_, DB> {
61    type Error = DB::Error;
62
63    fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
64        if let Some(account) = self.cache.accounts.get(&address) {
65            return Ok(account.clone());
66        }
67        let account = self.db.basic_ref(address)?;
68        if self.cache.account_count.load(Ordering::Relaxed) < StateCache::MAX_ACCOUNTS
69            && self
70                .cache
71                .accounts
72                .insert(address, account.clone())
73                .is_none()
74        {
75            self.cache.account_count.fetch_add(1, Ordering::Relaxed);
76        }
77        Ok(account)
78    }
79
80    fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
81        if let Some(code) = self.cache.contracts.get(&code_hash) {
82            return Ok(code.clone());
83        }
84        let code = self.db.code_by_hash_ref(code_hash)?;
85        if self.cache.contract_count.load(Ordering::Relaxed) < StateCache::MAX_CONTRACTS
86            && self
87                .cache
88                .contracts
89                .insert(code_hash, code.clone())
90                .is_none()
91        {
92            self.cache.contract_count.fetch_add(1, Ordering::Relaxed);
93        }
94        Ok(code)
95    }
96
97    fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
98        if let Some(value) = self.cache.storage.get(&(address, index)) {
99            return Ok(*value);
100        }
101        let value = self.db.storage_ref(address, index)?;
102        if self.cache.storage_count.load(Ordering::Relaxed) < StateCache::MAX_STORAGE_SLOTS
103            && self.cache.storage.insert((address, index), value).is_none()
104        {
105            self.cache.storage_count.fetch_add(1, Ordering::Relaxed);
106        }
107        Ok(value)
108    }
109
110    fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
111        self.db.block_hash_ref(number)
112    }
113}
114
115impl<DB: DatabaseRef> Database for StateCacheDb<'_, DB> {
116    type Error = DB::Error;
117
118    fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
119        self.basic_ref(address)
120    }
121
122    fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
123        self.code_by_hash_ref(code_hash)
124    }
125
126    fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
127        self.storage_ref(address, index)
128    }
129
130    fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
131        self.block_hash_ref(number)
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test database that returns fixed values and counts every state read it serves.
    #[derive(Default)]
    struct CountingDb {
        reads: AtomicUsize,
    }

    impl DatabaseRef for CountingDb {
        type Error = core::convert::Infallible;

        fn basic_ref(&self, _address: Address) -> Result<Option<AccountInfo>, Self::Error> {
            self.reads.fetch_add(1, Ordering::Relaxed);
            Ok(Some(AccountInfo {
                balance: U256::from(1),
                ..Default::default()
            }))
        }

        fn code_by_hash_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
            self.reads.fetch_add(1, Ordering::Relaxed);
            Ok(Bytecode::default())
        }

        fn storage_ref(&self, _address: Address, _index: U256) -> Result<U256, Self::Error> {
            self.reads.fetch_add(1, Ordering::Relaxed);
            Ok(U256::from(42))
        }

        fn block_hash_ref(&self, _number: u64) -> Result<B256, Self::Error> {
            Ok(B256::ZERO)
        }
    }

    #[test]
    fn caches_reads_across_instances() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let address = Address::with_last_byte(1);
        let slot = U256::from(7);

        {
            let db = StateCacheDb::new(&cache, &inner);
            assert_eq!(db.storage_ref(address, slot).unwrap(), U256::from(42));
            assert!(db.basic_ref(address).unwrap().is_some());
        }
        assert_eq!(inner.reads.load(Ordering::Relaxed), 2);

        // A new adapter over the same cache serves all reads from memory.
        let db = StateCacheDb::new(&cache, &inner);
        assert_eq!(db.storage_ref(address, slot).unwrap(), U256::from(42));
        assert!(db.basic_ref(address).unwrap().is_some());
        assert_eq!(inner.reads.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn caches_bytecode_by_hash() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let hash = B256::with_last_byte(9);

        let mut db = StateCacheDb::new(&cache, &inner);
        db.code_by_hash_ref(hash).unwrap();
        // The mutable `Database` entry point is served from the same cache.
        db.code_by_hash(hash).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn stops_caching_once_cap_is_reached() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let db = StateCacheDb::new(&cache, &inner);
        let cap = StateCache::MAX_CONTRACTS;
        let hash = |i: usize| B256::left_padding_from(&i.to_be_bytes());

        // Fill the contract cache exactly to its cap, then miss once more.
        for i in 0..=cap {
            db.code_by_hash_ref(hash(i)).unwrap();
        }
        assert_eq!(cache.contract_count.load(Ordering::Relaxed), cap);
        assert_eq!(inner.reads.load(Ordering::Relaxed), cap + 1);

        // Entries cached before the cap was hit are still served from memory...
        db.code_by_hash_ref(hash(0)).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), cap + 1);

        // ...while the overflow entry falls through to the underlying database every time.
        db.code_by_hash_ref(hash(cap)).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), cap + 2);
    }
}