1use alloy_contract::Result as ContractResult;
2use alloy_network::Network;
3use alloy_primitives::{Address, U256};
4use alloy_provider::{
5 Identity, Provider, ProviderBuilder, ProviderLayer, RootProvider,
6 fillers::{JoinFill, TxFiller},
7};
8use alloy_rpc_client::{BuiltInConnectionString, ConnectionConfig};
9use alloy_transport::{
10 Authorization, BoxTransport, TransportConnect, TransportError, TransportErrorKind,
11};
12use std::str::FromStr;
13use tempo_chainspec::hardfork::TempoHardfork;
14use tempo_contracts::precompiles::{
15 ACCOUNT_KEYCHAIN_ADDRESS,
16 IAccountKeychain::{IAccountKeychainInstance, KeyInfo},
17 INonce::INonceInstance,
18 NONCE_PRECOMPILE_ADDRESS, getAllowedCallsReturn, getRemainingLimitReturn,
19};
20use tempo_primitives::transaction::{CallScope, TEMPO_EXPIRING_NONCE_KEY};
21
22use crate::{
23 TempoFillers, TempoNetwork,
24 fillers::{ExpiringNonceFiller, NonceKeyFiller, Random2DNonceFiller, SponsorFiller},
25 transport::{AuthHeaderTransport, RelayConnector, SponsorshipMode},
26};
27
28#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
30#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
31pub trait TempoProviderExt: Provider<TempoNetwork> {
32 fn account_keychain(&self) -> IAccountKeychainInstance<&Self, TempoNetwork>
34 where
35 Self: Sized,
36 {
37 IAccountKeychainInstance::new(ACCOUNT_KEYCHAIN_ADDRESS, self)
38 }
39
40 fn nonce_manager(&self) -> INonceInstance<&Self, TempoNetwork>
42 where
43 Self: Sized,
44 {
45 INonceInstance::new(NONCE_PRECOMPILE_ADDRESS, self)
46 }
47
48 async fn get_transaction_count_with_nonce_key(
53 &self,
54 account: Address,
55 nonce_key: U256,
56 ) -> ContractResult<u64>
57 where
58 Self: Sized,
59 {
60 if nonce_key.is_zero() {
61 return self
62 .get_transaction_count(account)
63 .await
64 .map_err(Into::into);
65 }
66
67 if nonce_key == TEMPO_EXPIRING_NONCE_KEY {
68 return Ok(0);
69 }
70
71 self.nonce_manager()
72 .getNonce(account, nonce_key)
73 .call()
74 .await
75 }
76
77 async fn get_keychain_key(&self, account: Address, key_id: Address) -> ContractResult<KeyInfo>
79 where
80 Self: Sized,
81 {
82 self.account_keychain().getKey(account, key_id).call().await
83 }
84
85 async fn get_keychain_remaining_limit(
87 &self,
88 account: Address,
89 key_id: Address,
90 token: Address,
91 ) -> ContractResult<U256>
92 where
93 Self: Sized,
94 {
95 self.get_keychain_remaining_limit_with_period(account, key_id, token)
96 .await
97 .map(|getRemainingLimitReturn { remaining, .. }| remaining)
98 }
99
100 async fn get_keychain_remaining_limit_with_period(
102 &self,
103 account: Address,
104 key_id: Address,
105 token: Address,
106 ) -> ContractResult<getRemainingLimitReturn>
107 where
108 Self: Sized,
109 {
110 self.account_keychain()
111 .getRemainingLimitWithPeriod(account, key_id, token)
112 .call()
113 .await
114 }
115
116 async fn get_keychain_allowed_calls(
120 &self,
121 account: Address,
122 key_id: Address,
123 ) -> ContractResult<Option<Vec<CallScope>>>
124 where
125 Self: Sized,
126 {
127 self.account_keychain()
128 .getAllowedCalls(account, key_id)
129 .call()
130 .await
131 .map(|getAllowedCallsReturn { isScoped, scopes }| {
132 isScoped.then(|| scopes.into_iter().map(Into::into).collect())
133 })
134 }
135
136 async fn get_keychain_transaction_key(&self) -> ContractResult<Address>
138 where
139 Self: Sized,
140 {
141 self.account_keychain().getTransactionKey().call().await
142 }
143
144 async fn is_hardfork_active(
148 &self,
149 hardfork: TempoHardfork,
150 ) -> Result<bool, alloy_transport::TransportError>
151 where
152 Self: Sized,
153 {
154 #[derive(Debug, serde::Deserialize)]
155 struct Response {
156 active: String,
157 }
158
159 let resp: Response = self.raw_request("tempo_forkSchedule".into(), ()).await?;
160
161 Ok(resp
162 .active
163 .parse::<TempoHardfork>()
164 .is_ok_and(|h| h >= hardfork))
165 }
166}
167
168#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
169#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
170impl<P> TempoProviderExt for P where P: Provider<TempoNetwork> {}
171
/// Settings for talking to a sponsor RPC endpoint from a [`SponsoredProviderBuilder`].
///
/// Build one with [`SponsorConfig::sign_and_relay`] (also the [`Default`]),
/// [`SponsorConfig::sign_only`] or [`SponsorConfig::new`], then optionally refine the
/// connection with `with_connection_config` / `with_auth`.
#[derive(Clone, Debug)]
pub struct SponsorConfig {
    /// Sponsorship mode handed to `RelayConnector::with_config` when connecting.
    mode: SponsorshipMode,
    /// Connection settings (including optional auth) used to open the sponsor endpoint.
    config: ConnectionConfig,
    /// Passed to `RelayConnector::with_config` as `forward_headers`; presumably controls
    /// whether request headers are forwarded to the sponsor — TODO confirm.
    forward_headers: bool,
}
179
180impl SponsorConfig {
181 pub const fn new(mode: SponsorshipMode, forward_headers: bool) -> Self {
183 Self {
184 mode,
185 config: ConnectionConfig::new(),
186 forward_headers,
187 }
188 }
189
190 pub const fn sign_and_relay() -> Self {
192 Self::new(SponsorshipMode::SignAndRelay, true)
193 }
194
195 pub const fn sign_only() -> Self {
197 Self::new(SponsorshipMode::SignOnly, false)
198 }
199
200 pub fn with_connection_config(mut self, mut connection_config: ConnectionConfig) -> Self {
205 if connection_config.auth.is_none() {
206 connection_config.auth = self.config.auth;
207 }
208 self.config = connection_config;
209 self
210 }
211
212 pub fn with_auth(mut self, auth: Authorization) -> Self {
214 self.config = self.config.with_auth(auth);
215 self
216 }
217}
218
219impl Default for SponsorConfig {
220 fn default() -> Self {
221 Self::sign_and_relay()
222 }
223}
224
/// A [`ProviderBuilder`] bundled with a sponsor RPC URL and its [`SponsorConfig`].
///
/// Returned by [`TempoProviderBuilderExt::sponsor`] and
/// [`TempoProviderBuilderExt::sponsor_with_config`]; finish it with
/// [`SponsoredProviderBuilder::connect`].
#[derive(Debug)]
pub struct SponsoredProviderBuilder<L, F, N = TempoNetwork> {
    /// The wrapped builder; when produced by the extension trait its filler stack is
    /// `JoinFill<SponsorFiller, _>`.
    inner: ProviderBuilder<L, F, N>,
    /// Sponsor endpoint connection string; parsed only when `connect` is called.
    sponsor_rpc: String,
    /// Mode, connection settings and header-forwarding flag for the sponsor leg.
    sponsor_config: SponsorConfig,
}
235
/// A parsed built-in connection string paired with the [`ConnectionConfig`] used to
/// open it.
///
/// Implements [`TransportConnect`] so it can be passed to `RelayConnector` as the sponsor
/// connection.
#[derive(Clone, Debug)]
struct ConfiguredBuiltInConnection {
    /// Endpoint to connect to.
    connection: BuiltInConnectionString,
    /// Settings forwarded to `connect_boxed_with`; its `auth`, if set, is also applied via
    /// `AuthHeaderTransport`.
    config: ConnectionConfig,
}
241
242impl TransportConnect for ConfiguredBuiltInConnection {
243 fn is_local(&self) -> bool {
244 self.connection.is_local()
245 }
246
247 async fn get_transport(&self) -> Result<BoxTransport, TransportError> {
248 let transport = self
249 .connection
250 .connect_boxed_with(self.config.clone())
251 .await?;
252 Ok(match self.config.auth.clone() {
253 Some(auth) => BoxTransport::new(AuthHeaderTransport::new(transport, auth)?),
254 None => transport,
255 })
256 }
257}
258
259impl<L, F, N> SponsoredProviderBuilder<L, F, N> {
260 pub async fn connect(self, default_rpc: &str) -> Result<F::Provider, TransportError>
262 where
263 L: ProviderLayer<RootProvider<N>, N>,
264 F: TxFiller<N> + ProviderLayer<L::Provider, N>,
265 N: Network,
266 {
267 let default =
268 BuiltInConnectionString::from_str(default_rpc).map_err(TransportErrorKind::custom)?;
269 let SponsorConfig {
270 mode,
271 config,
272 forward_headers,
273 } = self.sponsor_config;
274 let sponsor = ConfiguredBuiltInConnection {
275 connection: BuiltInConnectionString::from_str(&self.sponsor_rpc)
276 .map_err(TransportErrorKind::custom)?,
277 config,
278 };
279 let connect = RelayConnector::with_config(default, sponsor, mode, forward_headers);
280 self.inner.connect_with(&connect).await
281 }
282}
283
/// Tempo-specific constructors and adapters for [`ProviderBuilder`].
pub trait TempoProviderBuilderExt<L, F>: Sized {
    /// Wraps the current fillers as `JoinFill<SponsorFiller, F>` and targets the sponsor
    /// endpoint `sponsor_rpc` using [`SponsorConfig::default`].
    ///
    /// Call [`SponsoredProviderBuilder::connect`] on the result to build the provider.
    fn sponsor(
        self,
        sponsor_rpc: impl Into<String>,
    ) -> SponsoredProviderBuilder<L, JoinFill<SponsorFiller, F>, TempoNetwork>;

    /// Like [`Self::sponsor`], but with an explicit [`SponsorConfig`].
    fn sponsor_with_config(
        self,
        sponsor_rpc: impl Into<String>,
        sponsor_config: SponsorConfig,
    ) -> SponsoredProviderBuilder<L, JoinFill<SponsorFiller, F>, TempoNetwork>;

    /// Returns a builder with no layers (`Identity`) whose only fillers are
    /// `TempoFillers<Random2DNonceFiller>`.
    ///
    /// NOTE(review): the return type cannot carry the receiver's layers or fillers, so any
    /// configuration already applied to `self` is not preserved.
    fn with_random_2d_nonces(
        self,
    ) -> ProviderBuilder<
        Identity,
        JoinFill<Identity, TempoFillers<Random2DNonceFiller>>,
        TempoNetwork,
    >;

    /// Returns a builder with no layers (`Identity`) whose only fillers are
    /// `TempoFillers<ExpiringNonceFiller>`.
    ///
    /// NOTE(review): the receiver's layers and fillers are not preserved (see return type).
    fn with_expiring_nonces(
        self,
    ) -> ProviderBuilder<
        Identity,
        JoinFill<Identity, TempoFillers<ExpiringNonceFiller>>,
        TempoNetwork,
    >;

    /// Returns a builder with no layers (`Identity`) whose only fillers are
    /// `TempoFillers<NonceKeyFiller>`.
    ///
    /// NOTE(review): the receiver's layers and fillers are not preserved (see return type).
    fn with_nonce_key_filler(
        self,
    ) -> ProviderBuilder<Identity, JoinFill<Identity, TempoFillers<NonceKeyFiller>>, TempoNetwork>;
}
337
338impl<L, F> TempoProviderBuilderExt<L, F> for ProviderBuilder<L, F, TempoNetwork>
339where
340 F: TxFiller<TempoNetwork>,
341{
342 fn sponsor(
343 self,
344 sponsor_rpc: impl Into<String>,
345 ) -> SponsoredProviderBuilder<L, JoinFill<SponsorFiller, F>, TempoNetwork> {
346 self.sponsor_with_config(sponsor_rpc, SponsorConfig::default())
347 }
348
349 fn sponsor_with_config(
350 self,
351 sponsor_rpc: impl Into<String>,
352 sponsor_config: SponsorConfig,
353 ) -> SponsoredProviderBuilder<L, JoinFill<SponsorFiller, F>, TempoNetwork> {
354 SponsoredProviderBuilder {
355 inner: self.map_filler(|fillers| JoinFill::new(SponsorFiller, fillers)),
356 sponsor_rpc: sponsor_rpc.into(),
357 sponsor_config,
358 }
359 }
360
361 fn with_random_2d_nonces(
362 self,
363 ) -> ProviderBuilder<
364 Identity,
365 JoinFill<Identity, TempoFillers<Random2DNonceFiller>>,
366 TempoNetwork,
367 > {
368 ProviderBuilder::default().filler(TempoFillers::default())
369 }
370
371 fn with_expiring_nonces(
372 self,
373 ) -> ProviderBuilder<
374 Identity,
375 JoinFill<Identity, TempoFillers<ExpiringNonceFiller>>,
376 TempoNetwork,
377 > {
378 ProviderBuilder::default().filler(TempoFillers::default())
379 }
380
381 fn with_nonce_key_filler(
382 self,
383 ) -> ProviderBuilder<Identity, JoinFill<Identity, TempoFillers<NonceKeyFiller>>, TempoNetwork>
384 {
385 ProviderBuilder::default().filler(TempoFillers::default())
386 }
387}
388
// Unit tests for the provider and builder extension traits.
//
// RPC-backed tests run against alloy's mocked client: each `Asserter::push_*` call supplies
// the response for a request the provider is about to make, so every push precedes the
// call under test.
#[cfg(test)]
mod tests {
    use alloy::sol_types::SolCall;
    use alloy_primitives::{Address, Bytes, U64, U256};
    use alloy_provider::{Identity, ProviderBuilder, fillers::JoinFill, mock::Asserter};
    use tempo_contracts::precompiles::{
        IAccountKeychain::{
            CallScope as AbiCallScope, KeyInfo, SelectorRule as AbiSelectorRule, SignatureType,
            getAllowedCallsCall, getKeyCall, getRemainingLimitWithPeriodCall,
            getTransactionKeyCall,
        },
        INonce::getNonceCall,
        getAllowedCallsReturn, getRemainingLimitReturn,
    };
    use tempo_primitives::transaction::{CallScope, SelectorRule, TEMPO_EXPIRING_NONCE_KEY};

    use crate::{
        TempoFillers, TempoNetwork,
        fillers::{ExpiringNonceFiller, NonceKeyFiller, Random2DNonceFiller},
        provider::ext::{SponsorConfig, TempoProviderBuilderExt, TempoProviderExt},
    };

    /// Builds a Tempo provider whose RPC responses are supplied by `asserter`.
    fn mock_provider(asserter: Asserter) -> impl alloy_provider::Provider<TempoNetwork> {
        ProviderBuilder::<_, _, TempoNetwork>::default().connect_mocked_client(asserter)
    }

    // Compile-level check that both sponsor entry points are callable on a Tempo builder.
    #[test]
    fn test_sponsor_builder_extension() {
        let _ = ProviderBuilder::<_, _, TempoNetwork>::default()
            .sponsor("https://sponsor.testnet.tempo.xyz");
        let _ = ProviderBuilder::<_, _, TempoNetwork>::default().sponsor_with_config(
            "https://sponsor.testnet.tempo.xyz",
            SponsorConfig::sign_only(),
        );
    }

    // The type annotations below pin the exact filler stacks returned by the nonce builders.
    #[test]
    fn test_with_random_nonces() {
        let _: ProviderBuilder<_, JoinFill<Identity, TempoFillers<Random2DNonceFiller>>, _> =
            ProviderBuilder::new_with_network::<TempoNetwork>().with_random_2d_nonces();
    }

    #[test]
    fn test_with_expiring_nonces() {
        let _: ProviderBuilder<_, JoinFill<Identity, TempoFillers<ExpiringNonceFiller>>, _> =
            ProviderBuilder::new_with_network::<TempoNetwork>().with_expiring_nonces();
    }

    #[test]
    fn test_with_nonce_key_filler() {
        let _: ProviderBuilder<_, JoinFill<Identity, TempoFillers<NonceKeyFiller>>, _> =
            ProviderBuilder::new_with_network::<TempoNetwork>().with_nonce_key_filler();
    }

    // `get_keychain_key` decodes the ABI-encoded `getKey` return into `KeyInfo`.
    #[tokio::test]
    async fn test_get_keychain_key() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let key_id = Address::repeat_byte(0x22);
        let expected = KeyInfo {
            signatureType: SignatureType::P256,
            keyId: key_id,
            expiry: 1_234_567_890,
            enforceLimits: true,
            isRevoked: false,
        };

        asserter.push_success(&Bytes::from(getKeyCall::abi_encode_returns(&expected)));

        let actual = provider
            .get_keychain_key(account, key_id)
            .await
            .expect("key info call succeeds");

        assert_eq!(actual, expected);
    }

    // Nonce key 0 goes through `eth_getTransactionCount`, so the mock returns a plain U64.
    #[tokio::test]
    async fn test_get_transaction_count_with_protocol_nonce_key() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let expected = 42_u64;

        asserter.push_success(&U64::from(expected));

        let actual = provider
            .get_transaction_count_with_nonce_key(account, U256::ZERO)
            .await
            .expect("protocol nonce query succeeds");

        assert_eq!(actual, expected);
    }

    // No response is queued: the expiring key must be answered without any RPC.
    #[tokio::test]
    async fn test_get_transaction_count_with_expiring_nonce_key() {
        let provider = mock_provider(Asserter::new());

        let actual = provider
            .get_transaction_count_with_nonce_key(
                Address::repeat_byte(0x11),
                TEMPO_EXPIRING_NONCE_KEY,
            )
            .await
            .expect("expiring nonce query succeeds");

        assert_eq!(actual, 0);
    }

    // Any other key is read from the nonce precompile via `getNonce`.
    #[tokio::test]
    async fn test_get_transaction_count_with_2d_nonce_key() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let nonce_key = U256::from(7_u64);
        let expected = 42_u64;

        asserter.push_success(&Bytes::from(getNonceCall::abi_encode_returns(&expected)));

        let actual = provider
            .get_transaction_count_with_nonce_key(account, nonce_key)
            .await
            .expect("2D nonce query succeeds");

        assert_eq!(actual, expected);
    }

    // The typed `nonce_manager()` instance can be used directly.
    #[tokio::test]
    async fn test_nonce_manager_accessor() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let nonce_key = U256::from(7_u64);
        let expected = 42_u64;

        asserter.push_success(&Bytes::from(getNonceCall::abi_encode_returns(&expected)));

        let actual = provider
            .nonce_manager()
            .getNonce(account, nonce_key)
            .call()
            .await
            .expect("typed nonce manager call succeeds");

        assert_eq!(actual, expected);
    }

    // `get_keychain_remaining_limit` is served by `getRemainingLimitWithPeriod` and keeps
    // only `remaining`.
    #[tokio::test]
    async fn test_get_keychain_remaining_limit() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let key_id = Address::repeat_byte(0x22);
        let token = Address::repeat_byte(0x33);
        let expected = U256::from(42_u64);

        asserter.push_success(&Bytes::from(
            getRemainingLimitWithPeriodCall::abi_encode_returns(&getRemainingLimitReturn {
                remaining: expected,
                periodEnd: 0,
            }),
        ));

        let actual = provider
            .get_keychain_remaining_limit(account, key_id, token)
            .await
            .expect("remaining limit call succeeds");

        assert_eq!(actual, expected);
    }

    // The `_with_period` variant returns the full decoded struct, including `periodEnd`.
    #[tokio::test]
    async fn test_get_keychain_remaining_limit_with_period() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let key_id = Address::repeat_byte(0x22);
        let token = Address::repeat_byte(0x33);
        let expected = getRemainingLimitReturn {
            remaining: U256::from(42_u64),
            periodEnd: 123,
        };

        asserter.push_success(&Bytes::from(
            getRemainingLimitWithPeriodCall::abi_encode_returns(&expected),
        ));

        let actual = provider
            .get_keychain_remaining_limit_with_period(account, key_id, token)
            .await
            .expect("remaining limit with period call succeeds");

        assert_eq!(actual, expected);
    }

    // `isScoped == false` maps to `None` regardless of the (empty) scopes list.
    #[tokio::test]
    async fn test_get_keychain_allowed_calls_maps_unrestricted_to_none() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let key_id = Address::repeat_byte(0x22);

        asserter.push_success(&Bytes::from(getAllowedCallsCall::abi_encode_returns(
            &getAllowedCallsReturn {
                isScoped: false,
                scopes: vec![],
            },
        )));

        let actual = provider
            .get_keychain_allowed_calls(account, key_id)
            .await
            .expect("allowed calls query succeeds");

        assert_eq!(actual, None);
    }

    // Scoped keys convert ABI `CallScope`/`SelectorRule` values into the primitive types.
    #[tokio::test]
    async fn test_get_keychain_allowed_calls_maps_scopes() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let key_id = Address::repeat_byte(0x22);
        let expected = vec![CallScope {
            target: Address::repeat_byte(0x33),
            selector_rules: vec![SelectorRule {
                selector: [0xaa, 0xbb, 0xcc, 0xdd],
                recipients: vec![Address::repeat_byte(0x44)],
            }],
        }];

        asserter.push_success(&Bytes::from(getAllowedCallsCall::abi_encode_returns(
            &getAllowedCallsReturn {
                isScoped: true,
                scopes: vec![AbiCallScope {
                    target: Address::repeat_byte(0x33),
                    selectorRules: vec![AbiSelectorRule {
                        selector: [0xaa, 0xbb, 0xcc, 0xdd].into(),
                        recipients: vec![Address::repeat_byte(0x44)],
                    }],
                }],
            },
        )));

        let actual = provider
            .get_keychain_allowed_calls(account, key_id)
            .await
            .expect("allowed calls query succeeds");

        assert_eq!(actual, Some(expected));
    }

    #[tokio::test]
    async fn test_get_keychain_transaction_key() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let expected = Address::repeat_byte(0x44);

        asserter.push_success(&Bytes::from(getTransactionKeyCall::abi_encode_returns(
            &expected,
        )));

        let actual = provider
            .get_keychain_transaction_key()
            .await
            .expect("transaction key call succeeds");

        assert_eq!(actual, expected);
    }

    // The typed `account_keychain()` instance can be used directly.
    #[tokio::test]
    async fn test_account_keychain_accessor() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());
        let account = Address::repeat_byte(0x11);
        let key_id = Address::repeat_byte(0x22);
        let expected = KeyInfo {
            signatureType: SignatureType::Secp256k1,
            keyId: key_id,
            expiry: u64::MAX,
            enforceLimits: false,
            isRevoked: true,
        };

        asserter.push_success(&Bytes::from(getKeyCall::abi_encode_returns(&expected)));

        let actual = provider
            .account_keychain()
            .getKey(account, key_id)
            .call()
            .await
            .expect("typed instance call succeeds");

        assert_eq!(actual, expected);
    }

    // A mocked RPC failure surfaces as `alloy_contract::Error::TransportError` carrying the
    // original message.
    #[tokio::test]
    async fn test_get_keychain_key_propagates_errors() {
        let asserter = Asserter::new();
        let provider = mock_provider(asserter.clone());

        asserter.push_failure_msg("boom");

        let err = provider
            .get_keychain_key(Address::repeat_byte(0x11), Address::repeat_byte(0x22))
            .await
            .expect_err("errors should propagate");

        assert!(matches!(err, alloy_contract::Error::TransportError(_)));
        assert!(err.to_string().contains("boom"));
    }
}