1use crate::{
6 amm::AmmLiquidityCache, best::MergeBestTransactions, transaction::TempoPooledTransaction,
7 tt_2d_pool::AA2dPool, validator::TempoTransactionValidator,
8};
9use alloy_consensus::Transaction;
10use alloy_primitives::{
11 Address, B256, TxHash,
12 map::{AddressMap, AddressSet, HashMap},
13};
14use parking_lot::RwLock;
15use reth_chainspec::ChainSpecProvider;
16use reth_eth_wire_types::HandleMempoolData;
17use reth_provider::{ChangedAccount, StateProviderFactory};
18use reth_storage_api::StateProvider;
19use reth_transaction_pool::{
20 AddedTransactionOutcome, AllPoolTransactions, BestTransactions, BestTransactionsAttributes,
21 BlockInfo, CanonicalStateUpdate, CoinbaseTipOrdering, GetPooledTransactionLimit,
22 NewBlobSidecar, Pool, PoolResult, PoolSize, PoolTransaction, PropagatedTransactions,
23 TransactionEvents, TransactionOrigin, TransactionPool, TransactionPoolExt,
24 TransactionValidationOutcome, TransactionValidationTaskExecutor, TransactionValidator,
25 ValidPoolTransaction,
26 blobstore::InMemoryBlobStore,
27 error::{PoolError, PoolErrorKind},
28 identifier::TransactionId,
29};
30use revm::database::BundleAccount;
31use std::{sync::Arc, time::Instant};
32use tempo_chainspec::{
33 TempoChainSpec,
34 hardfork::{TempoHardfork, TempoHardforks},
35};
36use tempo_precompiles::{
37 TIP_FEE_MANAGER_ADDRESS,
38 account_keychain::AccountKeychain,
39 error::Result as TempoPrecompileResult,
40 nonce::NonceManager,
41 storage::Handler,
42 tip20::TIP20Token,
43 tip403_registry::{REJECT_ALL_POLICY_ID, TIP403Registry},
44};
45use tempo_primitives::Block;
46use tempo_revm::TempoStateAccess;
47
48pub struct TempoTransactionPool<Client> {
50 protocol_pool: Pool<
52 TransactionValidationTaskExecutor<TempoTransactionValidator<Client>>,
53 CoinbaseTipOrdering<TempoPooledTransaction>,
54 InMemoryBlobStore,
55 >,
56 aa_2d_pool: Arc<RwLock<AA2dPool>>,
58}
59
60impl<Client> TempoTransactionPool<Client> {
61 pub fn new(
62 protocol_pool: Pool<
63 TransactionValidationTaskExecutor<TempoTransactionValidator<Client>>,
64 CoinbaseTipOrdering<TempoPooledTransaction>,
65 InMemoryBlobStore,
66 >,
67 aa_2d_pool: AA2dPool,
68 ) -> Self {
69 Self {
70 protocol_pool,
71 aa_2d_pool: Arc::new(RwLock::new(aa_2d_pool)),
72 }
73 }
74}
75impl<Client> TempoTransactionPool<Client>
76where
77 Client: StateProviderFactory + ChainSpecProvider<ChainSpec = TempoChainSpec> + 'static,
78{
79 pub fn amm_liquidity_cache(&self) -> AmmLiquidityCache {
81 self.protocol_pool
82 .validator()
83 .validator()
84 .amm_liquidity_cache()
85 }
86
87 pub fn client(&self) -> &Client {
89 self.protocol_pool.validator().validator().client()
90 }
91
92 pub(crate) fn notify_aa_pool_on_state_updates(&self, state: &AddressMap<BundleAccount>) {
94 let (promoted, _mined) = self.aa_2d_pool.write().on_state_updates(state);
95 self.protocol_pool
97 .inner()
98 .notify_on_transaction_updates(promoted, Vec::new());
99 }
100
101 pub(crate) fn reset_2d_nonces_from_state(
105 &self,
106 seq_ids: Vec<crate::tt_2d_pool::AASequenceId>,
107 block_hash: B256,
108 ) -> Result<(), reth_provider::ProviderError> {
109 if seq_ids.is_empty() {
110 return Ok(());
111 }
112
113 let spec = TempoHardfork::default();
115 let mut state_provider = self.client().state_by_block_hash(block_hash)?;
116
117 let nonce_changes = state_provider
118 .with_read_only_storage_ctx(spec, || -> TempoPrecompileResult<_> {
119 let mut changes = HashMap::default();
120 for id in &seq_ids {
122 let current_nonce =
123 NonceManager::new().nonces[id.address][id.nonce_key].read()?;
124 changes.insert(*id, current_nonce);
125 }
126 Ok(changes)
127 })
128 .map_err(reth_provider::ProviderError::other)?;
129
130 let (promoted, _mined) = self.aa_2d_pool.write().on_nonce_changes(nonce_changes);
132 if !promoted.is_empty() {
133 self.protocol_pool
134 .inner()
135 .notify_on_transaction_updates(promoted, Vec::new());
136 }
137
138 Ok(())
139 }
140
141 pub(crate) fn remove_included_expiring_nonce_txs<'a>(
146 &self,
147 tx_hashes: impl Iterator<Item = &'a TxHash>,
148 ) {
149 self.aa_2d_pool
150 .write()
151 .remove_included_expiring_nonce_txs(tx_hashes);
152 }
153
154 pub fn evict_invalidated_transactions(
168 &self,
169 updates: &crate::maintain::TempoPoolUpdates,
170 ) -> Vec<TxHash> {
171 if !updates.has_invalidation_events() {
172 return Vec::new();
173 }
174
175 let mut state_provider = if !updates.validator_token_changes.is_empty()
180 || !updates.blacklist_additions.is_empty()
181 || !updates.whitelist_removals.is_empty()
182 || !updates.spending_limit_spends.is_empty()
183 {
184 self.client().latest().ok()
185 } else {
186 None
187 };
188
189 let tip_timestamp = self
191 .protocol_pool
192 .validator()
193 .validator()
194 .inner
195 .fork_tracker()
196 .tip_timestamp();
197 let spec = self.client().chain_spec().tempo_hardfork_at(tip_timestamp);
198
199 let mut policy_cache: AddressMap<Vec<u64>> = AddressMap::default();
203
204 let fee_manager_blacklisted: Vec<u64> = updates
208 .blacklist_additions
209 .iter()
210 .filter(|(_, account)| *account == TIP_FEE_MANAGER_ADDRESS)
211 .map(|(policy_id, _)| *policy_id)
212 .collect();
213 let fee_manager_unwhitelisted: Vec<u64> = updates
214 .whitelist_removals
215 .iter()
216 .filter(|(_, account)| *account == TIP_FEE_MANAGER_ADDRESS)
217 .map(|(policy_id, _)| *policy_id)
218 .collect();
219
220 let amm_cache = self.amm_liquidity_cache();
224 let has_active_validator_token_changes = !updates.validator_token_changes.is_empty() && {
225 let active_new_tokens: Vec<_> = updates
226 .validator_token_changes
227 .iter()
228 .filter(|(validator, _)| amm_cache.is_active_validator(validator))
229 .filter(|(_, new_token)| !amm_cache.is_active_validator_token(new_token))
230 .map(|(_, new_token)| *new_token)
231 .collect();
232 amm_cache.track_tokens(&active_new_tokens)
233 };
234
235 let mut to_remove = Vec::new();
236 let mut revoked_count = 0;
237 let mut spending_limit_count = 0;
238 let mut spending_limit_spend_count = 0;
239 let mut liquidity_count = 0;
240 let mut user_token_count = 0;
241 let mut blacklisted_count = 0;
242 let mut unwhitelisted_count = 0;
243
244 let all_txs = self.all_transactions();
245 for tx in all_txs.pending.iter().chain(all_txs.queued.iter()) {
246 let keychain_subject = tx.transaction.keychain_subject();
248
249 if !updates.revoked_keys.is_empty()
251 && let Some(ref subject) = keychain_subject
252 && subject.matches_revoked(&updates.revoked_keys)
253 {
254 to_remove.push(*tx.hash());
255 revoked_count += 1;
256 continue;
257 }
258
259 if !updates.spending_limit_changes.is_empty()
262 && let Some(ref subject) = keychain_subject
263 && subject.matches_spending_limit_update(&updates.spending_limit_changes)
264 {
265 to_remove.push(*tx.hash());
266 spending_limit_count += 1;
267 continue;
268 }
269
270 if !updates.spending_limit_spends.is_empty()
276 && let Some(ref subject) = keychain_subject
277 && subject.matches_spending_limit_update(&updates.spending_limit_spends)
278 && let Some(ref mut provider) = state_provider
279 && exceeds_spending_limit(provider, subject, tx.transaction.fee_token_cost())
280 {
281 to_remove.push(*tx.hash());
282 spending_limit_spend_count += 1;
283 continue;
284 }
285
286 if has_active_validator_token_changes && let Some(ref provider) = state_provider {
291 let user_token = tx
292 .transaction
293 .inner()
294 .fee_token()
295 .unwrap_or(tempo_precompiles::DEFAULT_FEE_TOKEN);
296 let cost = tx.transaction.fee_token_cost();
297
298 match amm_cache.has_enough_liquidity(user_token, cost, &**provider) {
299 Ok(true) => {}
300 Ok(false) => {
301 to_remove.push(*tx.hash());
302 liquidity_count += 1;
303 continue;
304 }
305 Err(_) => continue,
306 }
307 }
308
309 if !updates.blacklist_additions.is_empty()
313 && let Some(ref mut provider) = state_provider
314 && let Some(fee_token) = tx.transaction.inner().fee_token()
315 {
316 let fee_payer = tx
317 .transaction
318 .inner()
319 .fee_payer(tx.transaction.sender())
320 .unwrap_or(tx.transaction.sender());
321
322 let mut sender_evicted = false;
324 for &(blacklist_policy_id, blacklisted_account) in &updates.blacklist_additions {
325 if fee_payer != blacklisted_account {
326 continue;
327 }
328
329 let token_policies =
330 get_sender_policy_ids(provider, fee_token, spec, &mut policy_cache);
331
332 if token_policies
333 .as_ref()
334 .is_some_and(|ids| ids.contains(&blacklist_policy_id))
335 {
336 sender_evicted = true;
337 break;
338 }
339 }
340
341 let recipient_evicted = !sender_evicted
345 && !fee_manager_blacklisted.is_empty()
346 && get_recipient_policy_ids(provider, fee_token, spec)
347 .is_some_and(|ids| fee_manager_blacklisted.iter().any(|p| ids.contains(p)));
348
349 if sender_evicted || recipient_evicted {
350 to_remove.push(*tx.hash());
351 blacklisted_count += 1;
352 }
353 }
354
355 if !updates.whitelist_removals.is_empty()
359 && let Some(ref mut provider) = state_provider
360 && let Some(fee_token) = tx.transaction.inner().fee_token()
361 {
362 let fee_payer = tx
363 .transaction
364 .inner()
365 .fee_payer(tx.transaction.sender())
366 .unwrap_or(tx.transaction.sender());
367
368 let mut sender_evicted = false;
369 for &(whitelist_policy_id, unwhitelisted_account) in &updates.whitelist_removals {
370 if fee_payer != unwhitelisted_account {
371 continue;
372 }
373
374 let token_policies =
375 get_sender_policy_ids(provider, fee_token, spec, &mut policy_cache);
376
377 if token_policies
378 .as_ref()
379 .is_some_and(|ids| ids.contains(&whitelist_policy_id))
380 {
381 sender_evicted = true;
382 break;
383 }
384 }
385
386 let recipient_evicted = !sender_evicted
389 && !fee_manager_unwhitelisted.is_empty()
390 && get_recipient_policy_ids(provider, fee_token, spec).is_some_and(|ids| {
391 fee_manager_unwhitelisted.iter().any(|p| ids.contains(p))
392 });
393
394 if sender_evicted || recipient_evicted {
395 to_remove.push(*tx.hash());
396 unwhitelisted_count += 1;
397 }
398 }
399
400 if !updates.user_token_changes.is_empty()
406 && tx.transaction.inner().fee_token().is_none()
407 && updates
408 .user_token_changes
409 .contains(&tx.transaction.sender())
410 {
411 to_remove.push(*tx.hash());
412 user_token_count += 1;
413 }
414 }
415
416 if !to_remove.is_empty() {
417 tracing::debug!(
418 target: "txpool",
419 total = to_remove.len(),
420 revoked_count,
421 spending_limit_count,
422 spending_limit_spend_count,
423 liquidity_count,
424 user_token_count,
425 blacklisted_count,
426 unwhitelisted_count,
427 "Evicting invalidated transactions"
428 );
429 self.remove_transactions(to_remove.clone());
430 }
431 to_remove
432 }
433
434 fn add_validated_transaction(
435 &self,
436 origin: TransactionOrigin,
437 transaction: TransactionValidationOutcome<TempoPooledTransaction>,
438 ) -> PoolResult<AddedTransactionOutcome> {
439 match transaction {
440 TransactionValidationOutcome::Valid {
441 balance,
442 state_nonce,
443 bytecode_hash,
444 transaction,
445 propagate,
446 authorities,
447 } => {
448 if transaction.transaction().is_aa_2d() {
449 let transaction = transaction.into_transaction();
450 let sender_id = self
451 .protocol_pool
452 .inner()
453 .get_sender_id(transaction.sender());
454 let transaction_id = TransactionId::new(sender_id, transaction.nonce());
455 let tx = ValidPoolTransaction {
456 transaction,
457 transaction_id,
458 propagate,
459 timestamp: Instant::now(),
460 origin,
461 authority_ids: authorities
462 .map(|auths| self.protocol_pool.inner().get_sender_ids(auths)),
463 };
464
465 let tip_timestamp = self
467 .protocol_pool
468 .validator()
469 .validator()
470 .inner
471 .fork_tracker()
472 .tip_timestamp();
473 let hardfork = self.client().chain_spec().tempo_hardfork_at(tip_timestamp);
474
475 let added = self.aa_2d_pool.write().add_transaction(
476 Arc::new(tx),
477 state_nonce,
478 hardfork,
479 )?;
480 let hash = *added.hash();
481 if let Some(pending) = added.as_pending() {
482 if pending.discarded.iter().any(|tx| *tx.hash() == hash) {
483 return Err(PoolError::new(hash, PoolErrorKind::DiscardedOnInsert));
484 }
485 self.protocol_pool
486 .inner()
487 .on_new_pending_transaction(pending);
488 }
489
490 let state = added.transaction_state();
491 self.protocol_pool.inner().notify_event_listeners(&added);
493 self.protocol_pool
494 .inner()
495 .on_new_transaction(added.into_new_transaction_event());
496
497 Ok(AddedTransactionOutcome { hash, state })
498 } else {
499 self.protocol_pool
500 .inner()
501 .add_transactions(
502 origin,
503 std::iter::once(TransactionValidationOutcome::Valid {
504 balance,
505 state_nonce,
506 bytecode_hash,
507 transaction,
508 propagate,
509 authorities,
510 }),
511 )
512 .pop()
513 .unwrap()
514 }
515 }
516 invalid => {
517 self.protocol_pool
519 .inner()
520 .add_transactions(origin, Some(invalid))
521 .pop()
522 .unwrap()
523 }
524 }
525 }
526}
527
528impl<Client> Clone for TempoTransactionPool<Client> {
530 fn clone(&self) -> Self {
531 Self {
532 protocol_pool: self.protocol_pool.clone(),
533 aa_2d_pool: Arc::clone(&self.aa_2d_pool),
534 }
535 }
536}
537
538impl<Client> std::fmt::Debug for TempoTransactionPool<Client> {
540 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541 f.debug_struct("TempoTransactionPool")
542 .field("protocol_pool", &"Pool<...>")
543 .field("aa_2d_nonce_pool", &"AA2dPool<...>")
544 .field("paused_fee_token_pool", &"PausedFeeTokenPool<...>")
545 .finish_non_exhaustive()
546 }
547}
548
549impl<Client> TransactionPool for TempoTransactionPool<Client>
551where
552 Client: StateProviderFactory
553 + ChainSpecProvider<ChainSpec = TempoChainSpec>
554 + Send
555 + Sync
556 + 'static,
557 TempoPooledTransaction: reth_transaction_pool::EthPoolTransaction,
558{
559 type Transaction = TempoPooledTransaction;
560
561 fn pool_size(&self) -> PoolSize {
562 let mut size = self.protocol_pool.pool_size();
563 let (pending, queued) = self.aa_2d_pool.read().pending_and_queued_txn_count();
564 size.pending += pending;
565 size.queued += queued;
566 size
567 }
568
569 fn block_info(&self) -> BlockInfo {
570 self.protocol_pool.block_info()
571 }
572
573 async fn add_transaction_and_subscribe(
574 &self,
575 origin: TransactionOrigin,
576 transaction: Self::Transaction,
577 ) -> PoolResult<TransactionEvents> {
578 let tx = self
579 .protocol_pool
580 .validator()
581 .validate_transaction(origin, transaction)
582 .await;
583 let res = self.add_validated_transaction(origin, tx)?;
584 self.transaction_event_listener(res.hash)
585 .ok_or_else(|| PoolError::new(res.hash, PoolErrorKind::DiscardedOnInsert))
586 }
587
588 async fn add_transaction(
589 &self,
590 origin: TransactionOrigin,
591 transaction: Self::Transaction,
592 ) -> PoolResult<AddedTransactionOutcome> {
593 let tx = self
594 .protocol_pool
595 .validator()
596 .validate_transaction(origin, transaction)
597 .await;
598 self.add_validated_transaction(origin, tx)
599 }
600
601 async fn add_transactions(
602 &self,
603 origin: TransactionOrigin,
604 transactions: Vec<Self::Transaction>,
605 ) -> Vec<PoolResult<AddedTransactionOutcome>> {
606 if transactions.is_empty() {
607 return Vec::new();
608 }
609
610 if !transactions.iter().any(|tx| tx.is_aa_2d()) {
612 return self
613 .protocol_pool
614 .add_transactions(origin, transactions)
615 .await;
616 }
617
618 self.protocol_pool
619 .validator()
620 .validate_transactions_with_origin(origin, transactions)
621 .await
622 .into_iter()
623 .map(|outcome| self.add_validated_transaction(origin, outcome))
624 .collect()
625 }
626
627 async fn add_transactions_with_origins(
628 &self,
629 transactions: Vec<(TransactionOrigin, Self::Transaction)>,
630 ) -> Vec<PoolResult<AddedTransactionOutcome>> {
631 if transactions.is_empty() {
632 return Vec::new();
633 }
634
635 if !transactions.iter().any(|(_, tx)| tx.is_aa_2d()) {
637 return self
638 .protocol_pool
639 .add_transactions_with_origins(transactions)
640 .await;
641 }
642
643 let origins = transactions
644 .iter()
645 .map(|(origin, _)| *origin)
646 .collect::<Vec<_>>();
647
648 self.protocol_pool
649 .validator()
650 .validate_transactions(transactions)
651 .await
652 .into_iter()
653 .zip(origins)
654 .map(|(outcome, origin)| self.add_validated_transaction(origin, outcome))
655 .collect()
656 }
657
658 fn transaction_event_listener(&self, tx_hash: B256) -> Option<TransactionEvents> {
659 self.protocol_pool.transaction_event_listener(tx_hash)
660 }
661
662 fn all_transactions_event_listener(
663 &self,
664 ) -> reth_transaction_pool::AllTransactionsEvents<Self::Transaction> {
665 self.protocol_pool.all_transactions_event_listener()
666 }
667
668 fn pending_transactions_listener_for(
669 &self,
670 kind: reth_transaction_pool::TransactionListenerKind,
671 ) -> tokio::sync::mpsc::Receiver<B256> {
672 self.protocol_pool.pending_transactions_listener_for(kind)
673 }
674
675 fn blob_transaction_sidecars_listener(&self) -> tokio::sync::mpsc::Receiver<NewBlobSidecar> {
676 self.protocol_pool.blob_transaction_sidecars_listener()
677 }
678
679 fn new_transactions_listener_for(
680 &self,
681 kind: reth_transaction_pool::TransactionListenerKind,
682 ) -> tokio::sync::mpsc::Receiver<reth_transaction_pool::NewTransactionEvent<Self::Transaction>>
683 {
684 self.protocol_pool.new_transactions_listener_for(kind)
685 }
686
687 fn pooled_transaction_hashes(&self) -> Vec<B256> {
688 let mut hashes = self.protocol_pool.pooled_transaction_hashes();
689 hashes.extend(self.aa_2d_pool.read().pooled_transactions_hashes_iter());
690 hashes
691 }
692
693 fn pooled_transaction_hashes_max(&self, max: usize) -> Vec<B256> {
694 let protocol_hashes = self.protocol_pool.pooled_transaction_hashes_max(max);
695 if protocol_hashes.len() >= max {
696 return protocol_hashes;
697 }
698 let remaining = max - protocol_hashes.len();
699 let mut hashes = protocol_hashes;
700 hashes.extend(
701 self.aa_2d_pool
702 .read()
703 .pooled_transactions_hashes_iter()
704 .take(remaining),
705 );
706 hashes
707 }
708
709 fn pooled_transactions(&self) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
710 let mut txs = self.protocol_pool.pooled_transactions();
711 txs.extend(self.aa_2d_pool.read().pooled_transactions_iter());
712 txs
713 }
714
715 fn pooled_transactions_max(
716 &self,
717 max: usize,
718 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
719 let mut txs = self.protocol_pool.pooled_transactions_max(max);
720 if txs.len() >= max {
721 return txs;
722 }
723
724 let remaining = max - txs.len();
725 txs.extend(
726 self.aa_2d_pool
727 .read()
728 .pooled_transactions_iter()
729 .take(remaining),
730 );
731 txs
732 }
733
734 fn get_pooled_transaction_elements(
735 &self,
736 tx_hashes: Vec<B256>,
737 limit: GetPooledTransactionLimit,
738 ) -> Vec<<Self::Transaction as PoolTransaction>::Pooled> {
739 let mut out = Vec::new();
740 self.append_pooled_transaction_elements(&tx_hashes, limit, &mut out);
741 out
742 }
743
744 fn append_pooled_transaction_elements(
745 &self,
746 tx_hashes: &[B256],
747 limit: GetPooledTransactionLimit,
748 out: &mut Vec<<Self::Transaction as PoolTransaction>::Pooled>,
749 ) {
750 let mut accumulated_size = 0;
751 self.aa_2d_pool.read().append_pooled_transaction_elements(
752 tx_hashes,
753 limit,
754 &mut accumulated_size,
755 out,
756 );
757
758 if limit.exceeds(accumulated_size) {
760 return;
761 }
762
763 let remaining_limit = match limit {
765 GetPooledTransactionLimit::None => GetPooledTransactionLimit::None,
766 GetPooledTransactionLimit::ResponseSizeSoftLimit(max) => {
767 GetPooledTransactionLimit::ResponseSizeSoftLimit(
768 max.saturating_sub(accumulated_size),
769 )
770 }
771 };
772
773 self.protocol_pool
774 .append_pooled_transaction_elements(tx_hashes, remaining_limit, out);
775 }
776
777 fn get_pooled_transaction_element(
778 &self,
779 tx_hash: B256,
780 ) -> Option<reth_primitives_traits::Recovered<<Self::Transaction as PoolTransaction>::Pooled>>
781 {
782 self.protocol_pool
783 .get_pooled_transaction_element(tx_hash)
784 .or_else(|| {
785 self.aa_2d_pool
786 .read()
787 .get(&tx_hash)
788 .and_then(|tx| tx.transaction.clone_into_pooled().ok())
789 })
790 }
791
792 fn best_transactions(
793 &self,
794 ) -> Box<dyn BestTransactions<Item = Arc<ValidPoolTransaction<Self::Transaction>>>> {
795 let left = self.protocol_pool.inner().best_transactions();
796 let right = self.aa_2d_pool.read().best_transactions();
797 Box::new(MergeBestTransactions::new(left, right))
798 }
799
800 fn best_transactions_with_attributes(
801 &self,
802 _attributes: BestTransactionsAttributes,
803 ) -> Box<dyn BestTransactions<Item = Arc<ValidPoolTransaction<Self::Transaction>>>> {
804 self.best_transactions()
805 }
806
807 fn pending_transactions(&self) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
808 let mut pending = self.protocol_pool.pending_transactions();
809 pending.extend(self.aa_2d_pool.read().pending_transactions());
810 pending
811 }
812
813 fn pending_transactions_max(
814 &self,
815 max: usize,
816 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
817 let protocol_txs = self.protocol_pool.pending_transactions_max(max);
818 if protocol_txs.len() >= max {
819 return protocol_txs;
820 }
821 let remaining = max - protocol_txs.len();
822 let mut txs = protocol_txs;
823 txs.extend(
824 self.aa_2d_pool
825 .read()
826 .pending_transactions()
827 .take(remaining),
828 );
829 txs
830 }
831
832 fn queued_transactions(&self) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
833 let mut queued = self.protocol_pool.queued_transactions();
834 queued.extend(self.aa_2d_pool.read().queued_transactions());
835 queued
836 }
837
838 fn pending_and_queued_txn_count(&self) -> (usize, usize) {
839 let (protocol_pending, protocol_queued) = self.protocol_pool.pending_and_queued_txn_count();
840 let (aa_pending, aa_queued) = self.aa_2d_pool.read().pending_and_queued_txn_count();
841 (protocol_pending + aa_pending, protocol_queued + aa_queued)
842 }
843
844 fn all_transactions(&self) -> AllPoolTransactions<Self::Transaction> {
845 let mut transactions = self.protocol_pool.all_transactions();
846 {
847 let aa_2d_pool = self.aa_2d_pool.read();
848 transactions
849 .pending
850 .extend(aa_2d_pool.pending_transactions());
851 transactions.queued.extend(aa_2d_pool.queued_transactions());
852 }
853 transactions
854 }
855
856 fn all_transaction_hashes(&self) -> Vec<B256> {
857 let mut hashes = self.protocol_pool.all_transaction_hashes();
858 hashes.extend(self.aa_2d_pool.read().all_transaction_hashes_iter());
859 hashes
860 }
861
862 fn remove_transactions(
863 &self,
864 hashes: Vec<B256>,
865 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
866 let mut txs = self.aa_2d_pool.write().remove_transactions(hashes.iter());
867 txs.extend(self.protocol_pool.remove_transactions(hashes));
868 txs
869 }
870
871 fn remove_transactions_and_descendants(
872 &self,
873 hashes: Vec<B256>,
874 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
875 let mut txs = self
876 .aa_2d_pool
877 .write()
878 .remove_transactions_and_descendants(hashes.iter());
879 txs.extend(
880 self.protocol_pool
881 .remove_transactions_and_descendants(hashes),
882 );
883 txs
884 }
885
886 fn remove_transactions_by_sender(
887 &self,
888 sender: Address,
889 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
890 let mut txs = self
891 .aa_2d_pool
892 .write()
893 .remove_transactions_by_sender(sender);
894 txs.extend(self.protocol_pool.remove_transactions_by_sender(sender));
895 txs
896 }
897
898 fn prune_transactions(
899 &self,
900 hashes: Vec<TxHash>,
901 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
902 let mut txs = self.aa_2d_pool.write().remove_transactions(hashes.iter());
903 txs.extend(self.protocol_pool.prune_transactions(hashes));
904 txs
905 }
906
907 fn retain_unknown<A: HandleMempoolData>(&self, announcement: &mut A) {
908 self.protocol_pool.retain_unknown(announcement);
909 if announcement.is_empty() {
910 return;
911 }
912 let aa_pool = self.aa_2d_pool.read();
913 announcement.retain_by_hash(|tx| !aa_pool.contains(tx))
914 }
915
916 fn contains(&self, tx_hash: &B256) -> bool {
917 self.protocol_pool.contains(tx_hash) || self.aa_2d_pool.read().contains(tx_hash)
918 }
919
920 fn get(&self, tx_hash: &B256) -> Option<Arc<ValidPoolTransaction<Self::Transaction>>> {
921 self.protocol_pool
922 .get(tx_hash)
923 .or_else(|| self.aa_2d_pool.read().get(tx_hash))
924 }
925
926 fn get_all(&self, txs: Vec<B256>) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
927 let mut result = self.aa_2d_pool.read().get_all(txs.iter());
928 result.extend(self.protocol_pool.get_all(txs));
929 result
930 }
931
932 fn on_propagated(&self, txs: PropagatedTransactions) {
933 self.protocol_pool.on_propagated(txs);
934 }
935
936 fn get_transactions_by_sender(
937 &self,
938 sender: Address,
939 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
940 let mut txs = self.protocol_pool.get_transactions_by_sender(sender);
941 txs.extend(
942 self.aa_2d_pool
943 .read()
944 .get_transactions_by_sender_iter(sender),
945 );
946 txs
947 }
948
949 fn get_pending_transactions_with_predicate(
950 &self,
951 mut predicate: impl FnMut(&ValidPoolTransaction<Self::Transaction>) -> bool,
952 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
953 let mut txs = self
954 .protocol_pool
955 .get_pending_transactions_with_predicate(&mut predicate);
956 txs.extend(
957 self.aa_2d_pool
958 .read()
959 .pending_transactions()
960 .filter(|tx| predicate(tx)),
961 );
962 txs
963 }
964
965 fn get_pending_transactions_by_sender(
966 &self,
967 sender: Address,
968 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
969 let mut txs = self
970 .protocol_pool
971 .get_pending_transactions_by_sender(sender);
972 txs.extend(
973 self.aa_2d_pool
974 .read()
975 .pending_transactions()
976 .filter(|tx| tx.sender() == sender),
977 );
978
979 txs
980 }
981
982 fn get_queued_transactions_by_sender(
983 &self,
984 sender: Address,
985 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
986 self.protocol_pool.get_queued_transactions_by_sender(sender)
987 }
988
989 fn get_highest_transaction_by_sender(
990 &self,
991 sender: Address,
992 ) -> Option<Arc<ValidPoolTransaction<Self::Transaction>>> {
993 self.protocol_pool.get_highest_transaction_by_sender(sender)
996 }
997
998 fn get_highest_consecutive_transaction_by_sender(
999 &self,
1000 sender: Address,
1001 on_chain_nonce: u64,
1002 ) -> Option<Arc<ValidPoolTransaction<Self::Transaction>>> {
1003 self.protocol_pool
1005 .get_highest_consecutive_transaction_by_sender(sender, on_chain_nonce)
1006 }
1007
1008 fn get_transaction_by_sender_and_nonce(
1009 &self,
1010 sender: Address,
1011 nonce: u64,
1012 ) -> Option<Arc<ValidPoolTransaction<Self::Transaction>>> {
1013 self.protocol_pool
1015 .get_transaction_by_sender_and_nonce(sender, nonce)
1016 }
1017
1018 fn get_transactions_by_origin(
1019 &self,
1020 origin: TransactionOrigin,
1021 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
1022 let mut txs = self.protocol_pool.get_transactions_by_origin(origin);
1023 txs.extend(
1024 self.aa_2d_pool
1025 .read()
1026 .get_transactions_by_origin_iter(origin),
1027 );
1028 txs
1029 }
1030
1031 fn get_pending_transactions_by_origin(
1032 &self,
1033 origin: TransactionOrigin,
1034 ) -> Vec<Arc<ValidPoolTransaction<Self::Transaction>>> {
1035 let mut txs = self
1036 .protocol_pool
1037 .get_pending_transactions_by_origin(origin);
1038 txs.extend(
1039 self.aa_2d_pool
1040 .read()
1041 .get_pending_transactions_by_origin_iter(origin),
1042 );
1043 txs
1044 }
1045
1046 fn unique_senders(&self) -> AddressSet {
1047 let mut senders = self.protocol_pool.unique_senders();
1048 senders.extend(self.aa_2d_pool.read().senders_iter().copied());
1049 senders
1050 }
1051
1052 fn get_blob(
1053 &self,
1054 tx_hash: B256,
1055 ) -> Result<
1056 Option<Arc<alloy_eips::eip7594::BlobTransactionSidecarVariant>>,
1057 reth_transaction_pool::blobstore::BlobStoreError,
1058 > {
1059 self.protocol_pool.get_blob(tx_hash)
1060 }
1061
1062 fn get_all_blobs(
1063 &self,
1064 tx_hashes: Vec<B256>,
1065 ) -> Result<
1066 Vec<(
1067 B256,
1068 Arc<alloy_eips::eip7594::BlobTransactionSidecarVariant>,
1069 )>,
1070 reth_transaction_pool::blobstore::BlobStoreError,
1071 > {
1072 self.protocol_pool.get_all_blobs(tx_hashes)
1073 }
1074
1075 fn get_all_blobs_exact(
1076 &self,
1077 tx_hashes: Vec<B256>,
1078 ) -> Result<
1079 Vec<Arc<alloy_eips::eip7594::BlobTransactionSidecarVariant>>,
1080 reth_transaction_pool::blobstore::BlobStoreError,
1081 > {
1082 self.protocol_pool.get_all_blobs_exact(tx_hashes)
1083 }
1084
1085 fn get_blobs_for_versioned_hashes_v1(
1086 &self,
1087 versioned_hashes: &[B256],
1088 ) -> Result<
1089 Vec<Option<alloy_eips::eip4844::BlobAndProofV1>>,
1090 reth_transaction_pool::blobstore::BlobStoreError,
1091 > {
1092 self.protocol_pool
1093 .get_blobs_for_versioned_hashes_v1(versioned_hashes)
1094 }
1095
1096 fn get_blobs_for_versioned_hashes_v2(
1097 &self,
1098 versioned_hashes: &[B256],
1099 ) -> Result<
1100 Option<Vec<alloy_eips::eip4844::BlobAndProofV2>>,
1101 reth_transaction_pool::blobstore::BlobStoreError,
1102 > {
1103 self.protocol_pool
1104 .get_blobs_for_versioned_hashes_v2(versioned_hashes)
1105 }
1106
1107 fn get_blobs_for_versioned_hashes_v3(
1108 &self,
1109 versioned_hashes: &[B256],
1110 ) -> Result<
1111 Vec<Option<alloy_eips::eip4844::BlobAndProofV2>>,
1112 reth_transaction_pool::blobstore::BlobStoreError,
1113 > {
1114 self.protocol_pool
1115 .get_blobs_for_versioned_hashes_v3(versioned_hashes)
1116 }
1117}
1118
1119impl<Client> TransactionPoolExt for TempoTransactionPool<Client>
1120where
1121 Client: StateProviderFactory + ChainSpecProvider<ChainSpec = TempoChainSpec> + 'static,
1122{
1123 type Block = Block;
1124
1125 fn set_block_info(&self, info: BlockInfo) {
1126 self.protocol_pool.set_block_info(info)
1127 }
1128
1129 fn on_canonical_state_change(&self, update: CanonicalStateUpdate<'_, Self::Block>) {
1130 self.protocol_pool.on_canonical_state_change(update)
1131 }
1132
1133 fn update_accounts(&self, accounts: Vec<ChangedAccount>) {
1134 self.protocol_pool.update_accounts(accounts)
1135 }
1136
1137 fn delete_blob(&self, tx: B256) {
1138 self.protocol_pool.delete_blob(tx)
1139 }
1140
1141 fn delete_blobs(&self, txs: Vec<B256>) {
1142 self.protocol_pool.delete_blobs(txs)
1143 }
1144
1145 fn cleanup_blobs(&self) {
1146 self.protocol_pool.cleanup_blobs()
1147 }
1148}
1149
1150pub(crate) fn exceeds_spending_limit(
1156 provider: &mut impl StateProvider,
1157 subject: &crate::transaction::KeychainSubject,
1158 fee_token_cost: alloy_primitives::U256,
1159) -> bool {
1160 let spec = TempoHardfork::default();
1162 let limit_key = AccountKeychain::spending_limit_key(subject.account, subject.key_id);
1163
1164 provider
1165 .with_read_only_storage_ctx(spec, || -> TempoPrecompileResult<bool> {
1166 let keychain = AccountKeychain::new();
1167 if !keychain.keys[subject.account][subject.key_id]
1168 .read()?
1169 .enforce_limits
1170 {
1171 return Ok(false);
1172 }
1173
1174 let remaining = keychain.spending_limits[limit_key][subject.fee_token].read()?;
1175 Ok(fee_token_cost > remaining)
1176 })
1177 .unwrap_or_default()
1178}
1179
1180fn get_sender_policy_ids(
1188 provider: &mut impl StateProvider,
1189 fee_token: Address,
1190 spec: TempoHardfork,
1191 cache: &mut AddressMap<Vec<u64>>,
1192) -> Option<Vec<u64>> {
1193 if let Some(cached) = cache.get(&fee_token) {
1194 return Some(cached.clone());
1195 }
1196
1197 provider.with_read_only_storage_ctx(spec, || {
1198 let policy_id = TIP20Token::from_address(fee_token)
1199 .and_then(|t| t.transfer_policy_id())
1200 .ok()
1201 .filter(|&id| id != REJECT_ALL_POLICY_ID)?;
1202
1203 let mut ids = vec![policy_id];
1204
1205 let registry = TIP403Registry::new();
1207 if let Ok(data) = registry.policy_records[policy_id].base.read()
1208 && data.is_compound()
1209 && let Ok(compound) = registry.policy_records[policy_id].compound.read()
1210 && compound.sender_policy_id != REJECT_ALL_POLICY_ID
1211 {
1212 ids.push(compound.sender_policy_id);
1213 }
1214
1215 cache.insert(fee_token, ids.clone());
1218 Some(ids)
1219 })
1220}
1221
1222fn get_recipient_policy_ids(
1232 provider: &mut impl StateProvider,
1233 fee_token: Address,
1234 spec: TempoHardfork,
1235) -> Option<Vec<u64>> {
1236 provider.with_read_only_storage_ctx(spec, || {
1237 let policy_id = TIP20Token::from_address(fee_token)
1238 .and_then(|t| t.transfer_policy_id())
1239 .ok()
1240 .filter(|&id| id != REJECT_ALL_POLICY_ID)?;
1241
1242 let mut ids = vec![policy_id];
1243
1244 let registry = TIP403Registry::new();
1245 if let Ok(data) = registry.policy_records[policy_id].base.read()
1246 && data.is_compound()
1247 && let Ok(compound) = registry.policy_records[policy_id].compound.read()
1248 && compound.recipient_policy_id != REJECT_ALL_POLICY_ID
1249 {
1250 ids.push(compound.recipient_policy_id);
1251 }
1252
1253 Some(ids)
1254 })
1255}
1256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transaction::KeychainSubject;
    use alloy_primitives::{U256, address};
    use reth_provider::test_utils::{ExtendedAccount, MockEthProvider};
    use reth_storage_api::StateProviderFactory;
    use tempo_contracts::precompiles::ITIP403Registry;
    use tempo_precompiles::{
        ACCOUNT_KEYCHAIN_ADDRESS, TIP403_REGISTRY_ADDRESS,
        account_keychain::{AccountKeychain, AuthorizedKey},
        tip20::slots as tip20_slots,
        tip403_registry::PolicyData,
    };

    /// Fee token whose transfer policy the policy-resolution tests configure.
    const FEE_TOKEN: Address = address!("20C0000000000000000000000000000000000001");

    /// Builds a Moderato-configured mock provider seeded with `accounts` and
    /// returns its latest state.
    fn state_with_accounts(
        accounts: impl IntoIterator<Item = (Address, ExtendedAccount)>,
    ) -> Box<dyn reth_storage_api::StateProvider> {
        let provider = MockEthProvider::default().with_chain_spec(std::sync::Arc::unwrap_or_clone(
            tempo_chainspec::spec::MODERATO.clone(),
        ));
        for (address, account) in accounts {
            provider.add_account(address, account);
        }
        provider.latest().unwrap()
    }

    /// Account entry for `FEE_TOKEN` with its transfer policy set to `policy_id`.
    fn fee_token_account(policy_id: u64) -> (Address, ExtendedAccount) {
        // The transfer policy ID lives in a packed slot; shift it to its byte offset.
        let packed = U256::from(policy_id) << (tip20_slots::TRANSFER_POLICY_ID_OFFSET * 8);
        let account = ExtendedAccount::new(0, U256::ZERO)
            .extend_storage([(tip20_slots::TRANSFER_POLICY_ID.into(), packed)]);
        (FEE_TOKEN, account)
    }

    /// Account entry for the TIP-403 registry holding a single policy record.
    ///
    /// `compound` is the packed compound record (see [`encode_compound`]);
    /// pass `None` to leave the compound slot unset.
    fn registry_account(
        policy_id: u64,
        policy_type: ITIP403Registry::PolicyType,
        compound: Option<U256>,
    ) -> (Address, ExtendedAccount) {
        let registry = TIP403Registry::new();
        let policy_data = PolicyData {
            policy_type: policy_type as u8,
            admin: Address::ZERO,
        };
        let base_slot = registry.policy_records[policy_id].base.base_slot();

        let mut account = ExtendedAccount::new(0, U256::ZERO)
            .extend_storage([(base_slot.into(), policy_data.encode_to_slot())]);
        if let Some(encoded) = compound {
            let compound_slot = registry.policy_records[policy_id].compound.base_slot();
            account = account.extend_storage([(compound_slot.into(), encoded)]);
        }

        (TIP403_REGISTRY_ADDRESS, account)
    }

    /// Packs compound sub-policy IDs into one slot word: sender in bits
    /// 0..64, recipient in bits 64..128, mint recipient in bits 128..192.
    fn encode_compound(sender: u64, recipient: u64, mint_recipient: u64) -> U256 {
        U256::from(sender) | (U256::from(recipient) << 64) | (U256::from(mint_recipient) << 128)
    }

    /// State where `FEE_TOKEN` uses compound policy `policy_id` with the given
    /// packed sub-policies.
    fn compound_policy_state(
        policy_id: u64,
        compound: U256,
    ) -> Box<dyn reth_storage_api::StateProvider> {
        state_with_accounts([
            fee_token_account(policy_id),
            registry_account(
                policy_id,
                ITIP403Registry::PolicyType::COMPOUND,
                Some(compound),
            ),
        ])
    }

    /// A keychain subject with freshly randomized addresses.
    fn random_subject() -> KeychainSubject {
        KeychainSubject {
            account: Address::random(),
            key_id: Address::random(),
            fee_token: Address::random(),
        }
    }

    /// State holding an authorized, non-expiring, non-revoked key for
    /// `subject`, optionally with a remaining spending limit for its fee token.
    fn keychain_state(
        subject: &KeychainSubject,
        enforce_limits: bool,
        remaining_limit: Option<U256>,
    ) -> Box<dyn reth_storage_api::StateProvider> {
        let keychain = AccountKeychain::new();

        let key_slot = keychain.keys[subject.account][subject.key_id].base_slot();
        let authorized_key = AuthorizedKey {
            signature_type: 0,
            expiry: u64::MAX,
            enforce_limits,
            is_revoked: false,
        }
        .encode_to_slot();

        let mut account = ExtendedAccount::new(0, U256::ZERO)
            .extend_storage([(key_slot.into(), authorized_key)]);
        if let Some(limit) = remaining_limit {
            let limit_key = AccountKeychain::spending_limit_key(subject.account, subject.key_id);
            let limit_slot = keychain.spending_limits[limit_key][subject.fee_token].slot();
            account = account.extend_storage([(limit_slot.into(), limit)]);
        }

        state_with_accounts([(ACCOUNT_KEYCHAIN_ADDRESS, account)])
    }

    #[test]
    fn compound_policy_sub_policy_matches_eviction_check() {
        let compound_policy_id: u64 = 5;
        let sender_sub_policy: u64 = 3;
        let recipient_sub_policy: u64 = 4;

        let mut state = compound_policy_state(
            compound_policy_id,
            encode_compound(sender_sub_policy, recipient_sub_policy, 0),
        );
        let mut cache: AddressMap<Vec<u64>> = AddressMap::default();

        let ids =
            get_sender_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default(), &mut cache)
                .expect("should resolve policy IDs");

        assert!(
            ids.contains(&compound_policy_id),
            "should contain compound policy ID"
        );
        assert!(
            ids.contains(&sender_sub_policy),
            "should contain sender sub-policy"
        );
    }

    #[test]
    fn compound_policy_sender_ids_exclude_recipient_sub_policy() {
        let compound_policy_id: u64 = 5;
        let sender_sub_policy: u64 = 3;
        let recipient_sub_policy: u64 = 4;

        let mut state = compound_policy_state(
            compound_policy_id,
            encode_compound(sender_sub_policy, recipient_sub_policy, 0),
        );
        let mut cache: AddressMap<Vec<u64>> = AddressMap::default();

        let ids =
            get_sender_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default(), &mut cache)
                .expect("should resolve policy IDs");

        assert!(ids.contains(&compound_policy_id));
        assert!(ids.contains(&sender_sub_policy));
        assert!(
            !ids.contains(&recipient_sub_policy),
            "sender policy IDs should not contain recipient_sub_policy"
        );
    }

    #[test]
    fn compound_policy_excludes_mint_recipient() {
        let compound_policy_id: u64 = 5;
        let mint_recipient_sub: u64 = 6;

        let mut state =
            compound_policy_state(compound_policy_id, encode_compound(3, 4, mint_recipient_sub));
        let mut cache: AddressMap<Vec<u64>> = AddressMap::default();

        let ids =
            get_sender_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default(), &mut cache)
                .expect("should resolve policy IDs");

        assert!(
            !ids.contains(&mint_recipient_sub),
            "mint_recipient must be excluded from sender policy IDs"
        );
    }

    #[test]
    fn sender_policy_ids_are_served_from_cache() {
        let mut state = compound_policy_state(5, encode_compound(3, 4, 0));
        let mut cache: AddressMap<Vec<u64>> = AddressMap::default();

        let ids =
            get_sender_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default(), &mut cache)
                .expect("should resolve policy IDs");
        assert_eq!(cache.get(&FEE_TOKEN), Some(&ids));

        // The empty state cannot produce these IDs, so a match proves a cache hit.
        let mut empty_state =
            state_with_accounts(std::iter::empty::<(Address, ExtendedAccount)>());
        let cached = get_sender_policy_ids(
            &mut empty_state,
            FEE_TOKEN,
            TempoHardfork::default(),
            &mut cache,
        );
        assert_eq!(cached, Some(ids));
    }

    #[test]
    fn reject_all_transfer_policy_resolves_to_none() {
        let mut state = state_with_accounts([fee_token_account(REJECT_ALL_POLICY_ID)]);
        let mut cache: AddressMap<Vec<u64>> = AddressMap::default();

        assert!(
            get_sender_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default(), &mut cache)
                .is_none()
        );
        assert!(cache.is_empty(), "unresolved lookups must not be cached");
        assert!(
            get_recipient_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default()).is_none()
        );
    }

    #[test]
    fn recipient_policy_ids_includes_recipient_sub_policy() {
        let compound_policy_id: u64 = 5;
        let sender_sub: u64 = 3;
        let recipient_sub: u64 = 4;

        let mut state =
            compound_policy_state(compound_policy_id, encode_compound(sender_sub, recipient_sub, 0));
        let ids = get_recipient_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default())
            .expect("should resolve policy IDs");

        assert!(
            ids.contains(&compound_policy_id),
            "should contain compound policy ID"
        );
        assert!(
            ids.contains(&recipient_sub),
            "should contain recipient sub-policy"
        );
        assert!(
            !ids.contains(&sender_sub),
            "recipient policy IDs should not contain sender sub-policy"
        );
    }

    #[test]
    fn recipient_policy_ids_simple_policy() {
        let simple_policy_id: u64 = 7;

        let mut state = state_with_accounts([
            fee_token_account(simple_policy_id),
            registry_account(simple_policy_id, ITIP403Registry::PolicyType::BLACKLIST, None),
        ]);
        let ids = get_recipient_policy_ids(&mut state, FEE_TOKEN, TempoHardfork::default())
            .expect("should resolve policy IDs");

        assert_eq!(ids, vec![simple_policy_id]);
    }

    #[test]
    fn exceeds_spending_limit_returns_true_when_cost_exceeds_remaining() {
        let subject = random_subject();
        let mut state = keychain_state(&subject, true, Some(U256::from(100)));

        assert!(exceeds_spending_limit(&mut state, &subject, U256::from(200)));
    }

    #[test]
    fn exceeds_spending_limit_returns_false_when_cost_within_limit() {
        let subject = random_subject();
        let mut state = keychain_state(&subject, true, Some(U256::from(500)));

        assert!(!exceeds_spending_limit(&mut state, &subject, U256::from(200)));
    }

    #[test]
    fn exceeds_spending_limit_returns_false_when_cost_equals_remaining() {
        let subject = random_subject();
        let mut state = keychain_state(&subject, true, Some(U256::from(200)));

        // The comparison is strict: spending exactly the remaining amount is allowed.
        assert!(!exceeds_spending_limit(&mut state, &subject, U256::from(200)));
    }

    #[test]
    fn exceeds_spending_limit_returns_true_when_no_limit_set() {
        let subject = random_subject();
        let mut state = keychain_state(&subject, true, None);

        assert!(exceeds_spending_limit(&mut state, &subject, U256::from(1)));
    }

    #[test]
    fn exceeds_spending_limit_returns_false_when_limits_not_enforced() {
        let subject = random_subject();
        let mut state = keychain_state(&subject, false, None);

        assert!(!exceeds_spending_limit(&mut state, &subject, U256::from(1)));
    }
}