1use alloy_contract::Result as ContractResult;
2use alloy_primitives::{Address, U256};
3use alloy_provider::{
4 Identity, Provider, ProviderBuilder,
5 fillers::{JoinFill, RecommendedFillers},
6};
7use tempo_chainspec::hardfork::TempoHardfork;
8use tempo_contracts::precompiles::{
9 ACCOUNT_KEYCHAIN_ADDRESS,
10 IAccountKeychain::{IAccountKeychainInstance, KeyInfo},
11 INonce::INonceInstance,
12 NONCE_PRECOMPILE_ADDRESS, getAllowedCallsReturn, getRemainingLimitReturn,
13};
14use tempo_primitives::transaction::{CallScope, TEMPO_EXPIRING_NONCE_KEY};
15
16use crate::{
17 TempoFillers, TempoNetwork,
18 fillers::{ExpiringNonceFiller, NonceKeyFiller, Random2DNonceFiller},
19};
20
21#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
23#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
24pub trait TempoProviderExt: Provider<TempoNetwork> {
25 fn account_keychain(&self) -> IAccountKeychainInstance<&Self, TempoNetwork>
27 where
28 Self: Sized,
29 {
30 IAccountKeychainInstance::new(ACCOUNT_KEYCHAIN_ADDRESS, self)
31 }
32
33 fn nonce_manager(&self) -> INonceInstance<&Self, TempoNetwork>
35 where
36 Self: Sized,
37 {
38 INonceInstance::new(NONCE_PRECOMPILE_ADDRESS, self)
39 }
40
41 async fn get_transaction_count_with_nonce_key(
46 &self,
47 account: Address,
48 nonce_key: U256,
49 ) -> ContractResult<u64>
50 where
51 Self: Sized,
52 {
53 if nonce_key.is_zero() {
54 return self
55 .get_transaction_count(account)
56 .await
57 .map_err(Into::into);
58 }
59
60 if nonce_key == TEMPO_EXPIRING_NONCE_KEY {
61 return Ok(0);
62 }
63
64 self.nonce_manager()
65 .getNonce(account, nonce_key)
66 .call()
67 .await
68 }
69
70 async fn get_keychain_key(&self, account: Address, key_id: Address) -> ContractResult<KeyInfo>
72 where
73 Self: Sized,
74 {
75 self.account_keychain().getKey(account, key_id).call().await
76 }
77
78 async fn get_keychain_remaining_limit(
80 &self,
81 account: Address,
82 key_id: Address,
83 token: Address,
84 ) -> ContractResult<U256>
85 where
86 Self: Sized,
87 {
88 self.get_keychain_remaining_limit_with_period(account, key_id, token)
89 .await
90 .map(|getRemainingLimitReturn { remaining, .. }| remaining)
91 }
92
93 async fn get_keychain_remaining_limit_with_period(
95 &self,
96 account: Address,
97 key_id: Address,
98 token: Address,
99 ) -> ContractResult<getRemainingLimitReturn>
100 where
101 Self: Sized,
102 {
103 self.account_keychain()
104 .getRemainingLimitWithPeriod(account, key_id, token)
105 .call()
106 .await
107 }
108
109 async fn get_keychain_allowed_calls(
113 &self,
114 account: Address,
115 key_id: Address,
116 ) -> ContractResult<Option<Vec<CallScope>>>
117 where
118 Self: Sized,
119 {
120 self.account_keychain()
121 .getAllowedCalls(account, key_id)
122 .call()
123 .await
124 .map(|getAllowedCallsReturn { isScoped, scopes }| {
125 isScoped.then(|| scopes.into_iter().map(Into::into).collect())
126 })
127 }
128
129 async fn get_keychain_transaction_key(&self) -> ContractResult<Address>
131 where
132 Self: Sized,
133 {
134 self.account_keychain().getTransactionKey().call().await
135 }
136
137 async fn is_hardfork_active(
141 &self,
142 hardfork: TempoHardfork,
143 ) -> Result<bool, alloy_transport::TransportError>
144 where
145 Self: Sized,
146 {
147 #[derive(Debug, serde::Deserialize)]
148 struct Response {
149 active: String,
150 }
151
152 let resp: Response = self.raw_request("tempo_forkSchedule".into(), ()).await?;
153
154 Ok(resp
155 .active
156 .parse::<TempoHardfork>()
157 .is_ok_and(|h| h >= hardfork))
158 }
159}
160
161#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
162#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
163impl<P> TempoProviderExt for P where P: Provider<TempoNetwork> {}
164
165pub trait TempoProviderBuilderExt {
167 fn with_random_2d_nonces(
171 self,
172 ) -> ProviderBuilder<
173 Identity,
174 JoinFill<Identity, TempoFillers<Random2DNonceFiller>>,
175 TempoNetwork,
176 >;
177
178 fn with_expiring_nonces(
184 self,
185 ) -> ProviderBuilder<
186 Identity,
187 JoinFill<Identity, TempoFillers<ExpiringNonceFiller>>,
188 TempoNetwork,
189 >;
190
191 fn with_nonce_key_filler(
198 self,
199 ) -> ProviderBuilder<Identity, JoinFill<Identity, TempoFillers<NonceKeyFiller>>, TempoNetwork>;
200}
201
202impl TempoProviderBuilderExt
203 for ProviderBuilder<
204 Identity,
205 JoinFill<Identity, <TempoNetwork as RecommendedFillers>::RecommendedFillers>,
206 TempoNetwork,
207 >
208{
209 fn with_random_2d_nonces(
210 self,
211 ) -> ProviderBuilder<
212 Identity,
213 JoinFill<Identity, TempoFillers<Random2DNonceFiller>>,
214 TempoNetwork,
215 > {
216 ProviderBuilder::default().filler(TempoFillers::default())
217 }
218
219 fn with_expiring_nonces(
220 self,
221 ) -> ProviderBuilder<
222 Identity,
223 JoinFill<Identity, TempoFillers<ExpiringNonceFiller>>,
224 TempoNetwork,
225 > {
226 ProviderBuilder::default().filler(TempoFillers::default())
227 }
228
229 fn with_nonce_key_filler(
230 self,
231 ) -> ProviderBuilder<Identity, JoinFill<Identity, TempoFillers<NonceKeyFiller>>, TempoNetwork>
232 {
233 ProviderBuilder::default().filler(TempoFillers::default())
234 }
235}
236
#[cfg(test)]
mod tests {
    use alloy::sol_types::SolCall;
    use alloy_primitives::{Address, Bytes, U64, U256};
    use alloy_provider::{Identity, ProviderBuilder, fillers::JoinFill, mock::Asserter};
    use tempo_contracts::precompiles::{
        IAccountKeychain::{
            CallScope as AbiCallScope, KeyInfo, SelectorRule as AbiSelectorRule, SignatureType,
            getAllowedCallsCall, getKeyCall, getRemainingLimitWithPeriodCall,
            getTransactionKeyCall,
        },
        INonce::getNonceCall,
        getAllowedCallsReturn, getRemainingLimitReturn,
    };
    use tempo_primitives::transaction::{CallScope, SelectorRule, TEMPO_EXPIRING_NONCE_KEY};

    use crate::{
        TempoFillers, TempoNetwork,
        fillers::{ExpiringNonceFiller, NonceKeyFiller, Random2DNonceFiller},
        provider::ext::{TempoProviderBuilderExt, TempoProviderExt},
    };

    /// Account address queried by the nonce and keychain tests.
    fn account() -> Address {
        Address::repeat_byte(0x11)
    }

    /// Access-key identifier queried by the keychain tests.
    fn key_id() -> Address {
        Address::repeat_byte(0x22)
    }

    /// Token whose spending limit is queried by the remaining-limit tests.
    fn token() -> Address {
        Address::repeat_byte(0x33)
    }

    /// Builds a provider whose RPC responses are served, in order, from `asserter`.
    fn mock_provider(asserter: Asserter) -> impl alloy_provider::Provider<TempoNetwork> {
        ProviderBuilder::<_, _, TempoNetwork>::default().connect_mocked_client(asserter)
    }

    /// Returns a fresh asserter together with a mock provider that reads from it.
    fn setup() -> (Asserter, impl alloy_provider::Provider<TempoNetwork>) {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        (asserter, provider)
    }

    /// Queues ABI-encoded call output as the next successful RPC response.
    fn queue_abi(asserter: &Asserter, encoded: Vec<u8>) {
        asserter.push_success(&Bytes::from(encoded));
    }

    #[test]
    fn test_with_random_nonces() {
        let builder = ProviderBuilder::new_with_network::<TempoNetwork>().with_random_2d_nonces();
        let _: ProviderBuilder<_, JoinFill<Identity, TempoFillers<Random2DNonceFiller>>, _> =
            builder;
    }

    #[test]
    fn test_with_expiring_nonces() {
        let builder = ProviderBuilder::new_with_network::<TempoNetwork>().with_expiring_nonces();
        let _: ProviderBuilder<_, JoinFill<Identity, TempoFillers<ExpiringNonceFiller>>, _> =
            builder;
    }

    #[test]
    fn test_with_nonce_key_filler() {
        let builder = ProviderBuilder::new_with_network::<TempoNetwork>().with_nonce_key_filler();
        let _: ProviderBuilder<_, JoinFill<Identity, TempoFillers<NonceKeyFiller>>, _> = builder;
    }

    #[tokio::test]
    async fn test_get_keychain_key() {
        let (asserter, provider) = setup();
        let key_info = KeyInfo {
            signatureType: SignatureType::P256,
            keyId: key_id(),
            expiry: 1_234_567_890,
            enforceLimits: true,
            isRevoked: false,
        };
        queue_abi(&asserter, getKeyCall::abi_encode_returns(&key_info));

        let fetched = provider
            .get_keychain_key(account(), key_id())
            .await
            .expect("key info call succeeds");

        assert_eq!(fetched, key_info);
    }

    #[tokio::test]
    async fn test_get_transaction_count_with_protocol_nonce_key() {
        let (asserter, provider) = setup();
        let nonce = 42_u64;
        asserter.push_success(&U64::from(nonce));

        let fetched = provider
            .get_transaction_count_with_nonce_key(account(), U256::ZERO)
            .await
            .expect("protocol nonce query succeeds");

        assert_eq!(fetched, nonce);
    }

    #[tokio::test]
    async fn test_get_transaction_count_with_expiring_nonce_key() {
        // No response is queued; the expiring key is expected to resolve without an RPC request.
        let provider = mock_provider(Asserter::new());

        let fetched = provider
            .get_transaction_count_with_nonce_key(account(), TEMPO_EXPIRING_NONCE_KEY)
            .await
            .expect("expiring nonce query succeeds");

        assert_eq!(fetched, 0);
    }

    #[tokio::test]
    async fn test_get_transaction_count_with_2d_nonce_key() {
        let (asserter, provider) = setup();
        let nonce_key = U256::from(7_u64);
        let nonce = 42_u64;
        queue_abi(&asserter, getNonceCall::abi_encode_returns(&nonce));

        let fetched = provider
            .get_transaction_count_with_nonce_key(account(), nonce_key)
            .await
            .expect("2D nonce query succeeds");

        assert_eq!(fetched, nonce);
    }

    #[tokio::test]
    async fn test_nonce_manager_accessor() {
        let (asserter, provider) = setup();
        let nonce_key = U256::from(7_u64);
        let nonce = 42_u64;
        queue_abi(&asserter, getNonceCall::abi_encode_returns(&nonce));

        let fetched = provider
            .nonce_manager()
            .getNonce(account(), nonce_key)
            .call()
            .await
            .expect("typed nonce manager call succeeds");

        assert_eq!(fetched, nonce);
    }

    #[tokio::test]
    async fn test_get_keychain_remaining_limit() {
        let (asserter, provider) = setup();
        let remaining = U256::from(42_u64);
        queue_abi(
            &asserter,
            getRemainingLimitWithPeriodCall::abi_encode_returns(&getRemainingLimitReturn {
                remaining,
                periodEnd: 0,
            }),
        );

        let fetched = provider
            .get_keychain_remaining_limit(account(), key_id(), token())
            .await
            .expect("remaining limit call succeeds");

        assert_eq!(fetched, remaining);
    }

    #[tokio::test]
    async fn test_get_keychain_remaining_limit_with_period() {
        let (asserter, provider) = setup();
        let limit = getRemainingLimitReturn {
            remaining: U256::from(42_u64),
            periodEnd: 123,
        };
        queue_abi(
            &asserter,
            getRemainingLimitWithPeriodCall::abi_encode_returns(&limit),
        );

        let fetched = provider
            .get_keychain_remaining_limit_with_period(account(), key_id(), token())
            .await
            .expect("remaining limit with period call succeeds");

        assert_eq!(fetched, limit);
    }

    #[tokio::test]
    async fn test_get_keychain_allowed_calls_maps_unrestricted_to_none() {
        let (asserter, provider) = setup();
        let unscoped = getAllowedCallsReturn {
            isScoped: false,
            scopes: vec![],
        };
        queue_abi(&asserter, getAllowedCallsCall::abi_encode_returns(&unscoped));

        let fetched = provider
            .get_keychain_allowed_calls(account(), key_id())
            .await
            .expect("allowed calls query succeeds");

        assert_eq!(fetched, None);
    }

    #[tokio::test]
    async fn test_get_keychain_allowed_calls_maps_scopes() {
        let (asserter, provider) = setup();
        let target = Address::repeat_byte(0x33);
        let recipient = Address::repeat_byte(0x44);
        let selector: [u8; 4] = [0xaa, 0xbb, 0xcc, 0xdd];

        // On-chain (ABI) representation returned by the precompile.
        let scoped = getAllowedCallsReturn {
            isScoped: true,
            scopes: vec![AbiCallScope {
                target,
                selectorRules: vec![AbiSelectorRule {
                    selector: selector.into(),
                    recipients: vec![recipient],
                }],
            }],
        };
        queue_abi(&asserter, getAllowedCallsCall::abi_encode_returns(&scoped));

        let fetched = provider
            .get_keychain_allowed_calls(account(), key_id())
            .await
            .expect("allowed calls query succeeds");

        // Same data, expressed with the primitive `CallScope` types.
        let expected = vec![CallScope {
            target,
            selector_rules: vec![SelectorRule {
                selector,
                recipients: vec![recipient],
            }],
        }];
        assert_eq!(fetched, Some(expected));
    }

    #[tokio::test]
    async fn test_get_keychain_transaction_key() {
        let (asserter, provider) = setup();
        let transaction_key = Address::repeat_byte(0x44);
        queue_abi(
            &asserter,
            getTransactionKeyCall::abi_encode_returns(&transaction_key),
        );

        let fetched = provider
            .get_keychain_transaction_key()
            .await
            .expect("transaction key call succeeds");

        assert_eq!(fetched, transaction_key);
    }

    #[tokio::test]
    async fn test_account_keychain_accessor() {
        let (asserter, provider) = setup();
        let key_info = KeyInfo {
            signatureType: SignatureType::Secp256k1,
            keyId: key_id(),
            expiry: u64::MAX,
            enforceLimits: false,
            isRevoked: true,
        };
        queue_abi(&asserter, getKeyCall::abi_encode_returns(&key_info));

        let fetched = provider
            .account_keychain()
            .getKey(account(), key_id())
            .call()
            .await
            .expect("typed instance call succeeds");

        assert_eq!(fetched, key_info);
    }

    #[tokio::test]
    async fn test_get_keychain_key_propagates_errors() {
        let (asserter, provider) = setup();
        asserter.push_failure_msg("boom");

        let err = provider
            .get_keychain_key(account(), key_id())
            .await
            .expect_err("errors should propagate");

        assert!(matches!(err, alloy_contract::Error::TransportError(_)));
        assert!(err.to_string().contains("boom"));
    }
}