tempo_transaction_pool/
state_cache.rs1use alloy_primitives::{Address, B256, U256, map::DefaultHashBuilder};
4use dashmap::DashMap;
5use revm::{Database, DatabaseRef, bytecode::Bytecode, state::AccountInfo};
6use std::sync::atomic::{AtomicUsize, Ordering};
7
/// Read-through cache of account, storage, and bytecode lookups that can be shared by any
/// number of [`StateCacheDb`] wrappers borrowing it.
///
/// Every map is paired with a counter of inserted entries. [`StateCacheDb`] compares these
/// counters against the `StateCache::MAX_*` limits before inserting, so the cache stops growing
/// once a limit is reached.
///
/// NOTE(review): nothing in this module evicts or invalidates entries. Confirm that the owner
/// replaces or drops the cache whenever the underlying state changes; otherwise later lookups
/// return stale values.
#[derive(Debug, Default)]
pub(crate) struct StateCache {
    /// Account info keyed by address. `None` results are cached as well, so addresses the
    /// backing database reports as absent are not queried again.
    accounts: DashMap<Address, Option<AccountInfo>, DefaultHashBuilder>,
    /// Storage values keyed by `(account address, slot index)`.
    storage: DashMap<(Address, U256), U256, DefaultHashBuilder>,
    /// Bytecode keyed by code hash.
    contracts: DashMap<B256, Bytecode, DefaultHashBuilder>,
    /// Number of entries inserted into `accounts`; checked against `MAX_ACCOUNTS`.
    account_count: AtomicUsize,
    /// Number of entries inserted into `storage`; checked against `MAX_STORAGE_SLOTS`.
    storage_count: AtomicUsize,
    /// Number of entries inserted into `contracts`; checked against `MAX_CONTRACTS`.
    contract_count: AtomicUsize,
}
31
32impl StateCache {
33 const MAX_ACCOUNTS: usize = 1 << 17;
38 const MAX_STORAGE_SLOTS: usize = 1 << 18;
40 const MAX_CONTRACTS: usize = 1 << 12;
42}
43
/// Database adapter that answers account, storage, and bytecode reads from a borrowed
/// [`StateCache`]. On a miss it falls back to the wrapped `db` and records the result in the
/// cache. Block-hash lookups are always forwarded to `db` uncached.
#[derive(Debug)]
pub(crate) struct StateCacheDb<'a, DB> {
    /// Shared cache consulted before `db` and populated after misses.
    cache: &'a StateCache,
    /// Backing database queried on cache misses.
    db: DB,
}
53
54impl<'a, DB> StateCacheDb<'a, DB> {
55 pub(crate) const fn new(cache: &'a StateCache, db: DB) -> Self {
56 Self { cache, db }
57 }
58}
59
60impl<DB: DatabaseRef> DatabaseRef for StateCacheDb<'_, DB> {
61 type Error = DB::Error;
62
63 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
64 if let Some(account) = self.cache.accounts.get(&address) {
65 return Ok(account.clone());
66 }
67 let account = self.db.basic_ref(address)?;
68 if self.cache.account_count.load(Ordering::Relaxed) < StateCache::MAX_ACCOUNTS
69 && self
70 .cache
71 .accounts
72 .insert(address, account.clone())
73 .is_none()
74 {
75 self.cache.account_count.fetch_add(1, Ordering::Relaxed);
76 }
77 Ok(account)
78 }
79
80 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
81 if let Some(code) = self.cache.contracts.get(&code_hash) {
82 return Ok(code.clone());
83 }
84 let code = self.db.code_by_hash_ref(code_hash)?;
85 if self.cache.contract_count.load(Ordering::Relaxed) < StateCache::MAX_CONTRACTS
86 && self
87 .cache
88 .contracts
89 .insert(code_hash, code.clone())
90 .is_none()
91 {
92 self.cache.contract_count.fetch_add(1, Ordering::Relaxed);
93 }
94 Ok(code)
95 }
96
97 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
98 if let Some(value) = self.cache.storage.get(&(address, index)) {
99 return Ok(*value);
100 }
101 let value = self.db.storage_ref(address, index)?;
102 if self.cache.storage_count.load(Ordering::Relaxed) < StateCache::MAX_STORAGE_SLOTS
103 && self.cache.storage.insert((address, index), value).is_none()
104 {
105 self.cache.storage_count.fetch_add(1, Ordering::Relaxed);
106 }
107 Ok(value)
108 }
109
110 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
111 self.db.block_hash_ref(number)
112 }
113}
114
115impl<DB: DatabaseRef> Database for StateCacheDb<'_, DB> {
116 type Error = DB::Error;
117
118 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
119 self.basic_ref(address)
120 }
121
122 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
123 self.code_by_hash_ref(code_hash)
124 }
125
126 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
127 self.storage_ref(address, index)
128 }
129
130 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
131 self.block_hash_ref(number)
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Backing database that returns fixed values and counts every read that reaches it,
    /// so tests can tell cache hits from misses.
    #[derive(Default)]
    struct CountingDb {
        reads: AtomicUsize,
    }

    impl DatabaseRef for CountingDb {
        type Error = core::convert::Infallible;

        fn basic_ref(&self, _address: Address) -> Result<Option<AccountInfo>, Self::Error> {
            self.reads.fetch_add(1, Ordering::Relaxed);
            Ok(Some(AccountInfo {
                balance: U256::from(1),
                ..Default::default()
            }))
        }

        fn code_by_hash_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
            self.reads.fetch_add(1, Ordering::Relaxed);
            Ok(Bytecode::default())
        }

        fn storage_ref(&self, _address: Address, _index: U256) -> Result<U256, Self::Error> {
            self.reads.fetch_add(1, Ordering::Relaxed);
            Ok(U256::from(42))
        }

        fn block_hash_ref(&self, _number: u64) -> Result<B256, Self::Error> {
            Ok(B256::ZERO)
        }
    }

    /// Builds a distinct code hash for every `i` (big-endian `i` in the low-order bytes).
    fn code_hash(i: usize) -> B256 {
        let mut bytes = [0u8; 32];
        let tail = i.to_be_bytes();
        bytes[32 - tail.len()..].copy_from_slice(&tail);
        B256::from(bytes)
    }

    #[test]
    fn caches_reads_across_instances() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let address = Address::with_last_byte(1);
        let slot = U256::from(7);

        {
            let db = StateCacheDb::new(&cache, &inner);
            assert_eq!(db.storage_ref(address, slot).unwrap(), U256::from(42));
            assert!(db.basic_ref(address).unwrap().is_some());
        }
        assert_eq!(inner.reads.load(Ordering::Relaxed), 2);

        // A fresh wrapper over the same cache must not reach the backing database again.
        let db = StateCacheDb::new(&cache, &inner);
        assert_eq!(db.storage_ref(address, slot).unwrap(), U256::from(42));
        assert!(db.basic_ref(address).unwrap().is_some());
        assert_eq!(inner.reads.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn caches_bytecode_across_instances() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let hash = code_hash(1);

        StateCacheDb::new(&cache, &inner).code_by_hash_ref(hash).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), 1);

        StateCacheDb::new(&cache, &inner).code_by_hash_ref(hash).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn mutable_database_api_shares_the_cache() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let address = Address::with_last_byte(2);
        let slot = U256::from(3);
        let hash = B256::with_last_byte(9);

        let mut db = StateCacheDb::new(&cache, &inner);
        for _ in 0..2 {
            assert_eq!(db.storage(address, slot).unwrap(), U256::from(42));
            assert!(db.basic(address).unwrap().is_some());
            db.code_by_hash(hash).unwrap();
        }
        // Only the first round of lookups reaches the backing database.
        assert_eq!(inner.reads.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn stops_caching_contracts_at_capacity() {
        let cache = StateCache::default();
        let inner = CountingDb::default();
        let db = StateCacheDb::new(&cache, &inner);
        let total = StateCache::MAX_CONTRACTS + 8;

        for i in 0..total {
            db.code_by_hash_ref(code_hash(i)).unwrap();
        }
        assert_eq!(cache.contracts.len(), StateCache::MAX_CONTRACTS);
        assert_eq!(
            cache.contract_count.load(Ordering::Relaxed),
            StateCache::MAX_CONTRACTS
        );
        assert_eq!(inner.reads.load(Ordering::Relaxed), total);

        // Entries admitted before the limit was reached are still served from the cache...
        db.code_by_hash_ref(code_hash(0)).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), total);

        // ...while keys rejected at capacity keep going to the backing database.
        db.code_by_hash_ref(code_hash(total - 1)).unwrap();
        assert_eq!(inner.reads.load(Ordering::Relaxed), total + 1);
    }
}