tempo_precompiles_macros/
utils.rs

1//! Utility functions for the contract macro implementation.
2
3use alloy::primitives::{U256, keccak256};
4use syn::{Attribute, Lit, Type};
5
/// Return type for [`extract_attributes`]: `(slot, base_slot)`.
///
/// Each element is `None` when the corresponding attribute is absent on the field.
/// [`extract_attributes`] returns an error rather than ever producing both as `Some`.
type ExtractedAttributes = (Option<U256>, Option<U256>);
8
9/// Parses a slot value from a literal.
10///
11/// Supports:
12/// - Integer literals: decimal (`42`) or hexadecimal (`0x2a`)
13/// - String literals: computes keccak256 hash of the string
14fn parse_slot_value(value: &Lit) -> syn::Result<U256> {
15    match value {
16        Lit::Int(int) => {
17            let lit_str = int.to_string();
18            let slot = if let Some(hex) = lit_str.strip_prefix("0x") {
19                U256::from_str_radix(hex, 16)
20            } else {
21                U256::from_str_radix(&lit_str, 10)
22            }
23            .map_err(|_| syn::Error::new_spanned(int, "Invalid slot number"))?;
24            Ok(slot)
25        }
26        Lit::Str(lit) => Ok(keccak256(lit.value().as_bytes()).into()),
27        _ => Err(syn::Error::new_spanned(
28            value,
29            "slot attribute must be an integer or a string literal",
30        )),
31    }
32}
33
/// Converts an identifier from camelCase/PascalCase (or snake_case) to snake_case.
///
/// Identifiers that are entirely upper case (e.g. `DOMAIN_SEPARATOR`) are returned
/// unchanged, as those are assumed to be constant/immutable names.
///
/// An underscore is inserted before an upper-case character that starts a new word:
/// - when the previous character was not upper case (`balanceOf` -> `balance_of`), or
/// - when it ends an acronym run and is followed by a lower-case character
///   (`ERC20Token` -> `erc20_token`, `HTTPServer` -> `http_server`).
///
/// No underscore is added at the very start or directly after an existing `_`, so
/// `Foo_Bar` becomes `foo_bar` rather than `foo__bar`.
pub(crate) fn to_snake_case(s: &str) -> String {
    // Entirely upper-case names (SCREAMING_SNAKE_CASE constants) are preserved as-is.
    let constant = s.to_uppercase();
    if s == constant {
        return constant;
    }

    let mut result = String::with_capacity(s.len() + 4);
    let mut chars = s.chars().peekable();
    let mut prev_upper = false;

    // `while let` rather than `for`: the acronym rule needs to peek at the next char.
    while let Some(c) = chars.next() {
        if c.is_uppercase() {
            let starts_word =
                !prev_upper || chars.peek().is_some_and(|&next| next.is_lowercase());
            // Avoid a leading `_` and avoid doubling an underscore already present.
            if starts_word && !result.is_empty() && !result.ends_with('_') {
                result.push('_');
            }
            // Unicode-aware lowering to match the Unicode-aware `is_uppercase` check;
            // `to_ascii_lowercase` would leave non-ASCII capitals such as `É` unchanged.
            result.extend(c.to_lowercase());
            prev_upper = true;
        } else {
            result.push(c);
            prev_upper = false;
        }
    }

    result
}
63
/// Converts a snake_case identifier to camelCase.
///
/// The input is split on `_` and empty pieces (from leading, trailing or repeated
/// underscores) are dropped. The first remaining piece is kept verbatim. Every
/// following piece has only its first character upper-cased; the rest of each piece
/// is left untouched, so `DOMAIN_SEPARATOR` becomes `DOMAINSEPARATOR`.
pub(crate) fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());

    let words = s.split('_').filter(|word| !word.is_empty());
    for (index, word) in words.enumerate() {
        if index == 0 {
            camel.push_str(word);
            continue;
        }

        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            camel.extend(head.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }

    camel
}
87
88/// Extracts `#[slot(N)]`, `#[base_slot(N)]` attributes from a field's attributes.
89///
90/// This function iterates through the attributes a single time to find all
91/// relevant values. It returns a tuple containing:
92/// - The slot number (if present)
93/// - The base_slot number (if present)
94///
95/// # Errors
96///
97/// Returns an error if:
98/// - Both `#[slot]` and `#[base_slot]` are present on the same field
99/// - Duplicate attributes of the same type are found
100pub(crate) fn extract_attributes(attrs: &[Attribute]) -> syn::Result<ExtractedAttributes> {
101    let mut slot_attr: Option<U256> = None;
102    let mut base_slot_attr: Option<U256> = None;
103
104    for attr in attrs {
105        // Extract `#[slot(N)]` attribute
106        if attr.path().is_ident("slot") {
107            if slot_attr.is_some() {
108                return Err(syn::Error::new_spanned(attr, "duplicate `slot` attribute"));
109            }
110            if base_slot_attr.is_some() {
111                return Err(syn::Error::new_spanned(
112                    attr,
113                    "cannot use both `slot` and `base_slot` attributes on the same field",
114                ));
115            }
116
117            let value: Lit = attr.parse_args()?;
118            slot_attr = Some(parse_slot_value(&value)?);
119        }
120        // Extract `#[base_slot(N)]` attribute
121        else if attr.path().is_ident("base_slot") {
122            if base_slot_attr.is_some() {
123                return Err(syn::Error::new_spanned(
124                    attr,
125                    "duplicate `base_slot` attribute",
126                ));
127            }
128            if slot_attr.is_some() {
129                return Err(syn::Error::new_spanned(
130                    attr,
131                    "cannot use both `slot` and `base_slot` attributes on the same field",
132                ));
133            }
134
135            let value: Lit = attr.parse_args()?;
136            base_slot_attr = Some(parse_slot_value(&value)?);
137        }
138    }
139
140    Ok((slot_attr, base_slot_attr))
141}
142
143/// Extracts array sizes from the `#[storable_arrays(...)]` attribute.
144///
145/// Parses attributes like `#[storable_arrays(1, 2, 4, 8)]` and returns a vector
146/// of the specified sizes. Returns `None` if the attribute is not present.
147///
148/// # Format
149///
150/// The attribute should be a comma-separated list of positive integer literals:
151/// ```ignore
152/// #[storable_arrays(1, 2, 4, 8, 16, 32)]
153/// ```
154///
155/// # Errors
156///
157/// Returns an error if:
158/// - The attribute is present but has invalid syntax
159/// - Any size is 0 or exceeds 256
160/// - Duplicate array sizes are specified
161pub(crate) fn extract_storable_array_sizes(attrs: &[Attribute]) -> syn::Result<Option<Vec<usize>>> {
162    for attr in attrs {
163        if attr.path().is_ident("storable_arrays") {
164            // Parse the attribute arguments as a comma-separated list
165            let parsed = attr.parse_args_with(
166                syn::punctuated::Punctuated::<Lit, syn::Token![,]>::parse_terminated,
167            )?;
168
169            let mut sizes = Vec::new();
170            for lit in parsed {
171                if let Lit::Int(int) = lit {
172                    let size = int.base10_parse::<usize>().map_err(|_| {
173                        syn::Error::new_spanned(
174                            &int,
175                            "Invalid array size: must be a positive integer",
176                        )
177                    })?;
178
179                    if size == 0 {
180                        return Err(syn::Error::new_spanned(
181                            &int,
182                            "Array size must be greater than 0",
183                        ));
184                    }
185
186                    if size > 256 {
187                        return Err(syn::Error::new_spanned(
188                            &int,
189                            "Array size must not exceed 256",
190                        ));
191                    }
192
193                    if sizes.contains(&size) {
194                        return Err(syn::Error::new_spanned(
195                            &int,
196                            format!("Duplicate array size: {size}"),
197                        ));
198                    }
199
200                    sizes.push(size);
201                } else {
202                    return Err(syn::Error::new_spanned(
203                        lit,
204                        "Array sizes must be integer literals",
205                    ));
206                }
207            }
208
209            if sizes.is_empty() {
210                return Err(syn::Error::new_spanned(
211                    attr,
212                    "storable_arrays attribute requires at least one size",
213                ));
214            }
215
216            return Ok(Some(sizes));
217        }
218    }
219
220    Ok(None)
221}
222
223/// Extracts the type parameters from Mapping<K, V>.
224///
225/// Returns Some((key_type, value_type)) if the type is a Mapping, None otherwise.
226pub(crate) fn extract_mapping_types(ty: &Type) -> Option<(&Type, &Type)> {
227    if let Type::Path(type_path) = ty {
228        let last_segment = type_path.path.segments.last()?;
229
230        // Check if the type is named "Mapping"
231        if last_segment.ident != "Mapping" {
232            return None;
233        }
234
235        // Extract generic arguments
236        if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
237            let mut iter = args.args.iter();
238
239            // First argument: key type
240            let key_type = if let Some(syn::GenericArgument::Type(ty)) = iter.next() {
241                ty
242            } else {
243                return None;
244            };
245
246            // Second argument: value type
247            let value_type = if let Some(syn::GenericArgument::Type(ty)) = iter.next() {
248                ty
249            } else {
250                return None;
251            };
252
253            return Some((key_type, value_type));
254        }
255    }
256    None
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Returns the identifier of the last path segment of a path type
    /// (e.g. `Address` for `alloy::primitives::Address`), or `None` for other types.
    fn last_ident(ty: &Type) -> Option<String> {
        match ty {
            Type::Path(type_path) => type_path
                .path
                .segments
                .last()
                .map(|segment| segment.ident.to_string()),
            _ => None,
        }
    }

    /// Returns `true` if `extract_storable_array_sizes` rejects the single attribute.
    fn storable_rejects(attr: Attribute) -> bool {
        extract_storable_array_sizes(&[attr]).is_err()
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("balanceOf"), "balance_of");
        assert_eq!(to_snake_case("transferFrom"), "transfer_from");
        assert_eq!(to_snake_case("name"), "name");
        assert_eq!(to_snake_case("already_snake"), "already_snake");
        assert_eq!(to_snake_case("updateQuoteToken"), "update_quote_token");
        assert_eq!(to_snake_case("DOMAIN_SEPARATOR"), "DOMAIN_SEPARATOR");
        assert_eq!(to_snake_case("ERC20Token"), "erc20_token");
        assert_eq!(to_snake_case("HTTPServer"), "http_server");
        assert_eq!(to_snake_case("_privateVar"), "_private_var");
    }

    #[test]
    fn test_to_camel_case() {
        assert_eq!(to_camel_case("balance_of"), "balanceOf");
        assert_eq!(to_camel_case("transfer_from"), "transferFrom");
        assert_eq!(to_camel_case("update_quote_token"), "updateQuoteToken");
        assert_eq!(to_camel_case("name"), "name");
        assert_eq!(to_camel_case("token"), "token");
        assert_eq!(to_camel_case("alreadycamelCase"), "alreadycamelCase");
        assert_eq!(to_camel_case("DOMAIN_SEPARATOR"), "DOMAINSEPARATOR");
        assert_eq!(to_camel_case("_a__b_"), "aB");
    }

    #[test]
    fn test_parse_slot_value() {
        let lit: Lit = parse_quote!(42);
        assert_eq!(parse_slot_value(&lit).unwrap(), U256::from(42u64));

        let lit: Lit = parse_quote!(0x2a);
        assert_eq!(parse_slot_value(&lit).unwrap(), U256::from(42u64));

        let lit: Lit = parse_quote!("my.storage.slot");
        let expected: U256 = keccak256("my.storage.slot".as_bytes()).into();
        assert_eq!(parse_slot_value(&lit).unwrap(), expected);

        // Only integer and string literals are accepted.
        let lit: Lit = parse_quote!(true);
        assert!(parse_slot_value(&lit).is_err());
    }

    #[test]
    fn test_extract_attributes() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[slot(5)])];
        assert_eq!(
            extract_attributes(&attrs).unwrap(),
            (Some(U256::from(5u64)), None)
        );

        let attrs: Vec<Attribute> = vec![parse_quote!(#[base_slot(7)])];
        assert_eq!(
            extract_attributes(&attrs).unwrap(),
            (None, Some(U256::from(7u64)))
        );

        // Unrelated attributes are ignored.
        let attrs: Vec<Attribute> = vec![parse_quote!(#[doc = "unrelated"])];
        assert_eq!(extract_attributes(&attrs).unwrap(), (None, None));

        // Duplicate attributes are rejected.
        let attrs: Vec<Attribute> = vec![parse_quote!(#[slot(1)]), parse_quote!(#[slot(2)])];
        assert!(extract_attributes(&attrs).is_err());
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[base_slot(1)]), parse_quote!(#[base_slot(2)])];
        assert!(extract_attributes(&attrs).is_err());

        // `slot` and `base_slot` are mutually exclusive, in either order.
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[slot(1)]), parse_quote!(#[base_slot(2)])];
        assert!(extract_attributes(&attrs).is_err());
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[base_slot(2)]), parse_quote!(#[slot(1)])];
        assert!(extract_attributes(&attrs).is_err());
    }

    #[test]
    fn test_extract_storable_array_sizes() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[storable_arrays(1, 2, 4, 8)])];
        assert_eq!(
            extract_storable_array_sizes(&attrs).unwrap(),
            Some(vec![1, 2, 4, 8])
        );

        let attrs: Vec<Attribute> = vec![parse_quote!(#[storable_arrays(256)])];
        assert_eq!(extract_storable_array_sizes(&attrs).unwrap(), Some(vec![256]));

        let attrs: Vec<Attribute> = vec![parse_quote!(#[doc = "unrelated"])];
        assert_eq!(extract_storable_array_sizes(&attrs).unwrap(), None);
        assert_eq!(extract_storable_array_sizes(&[]).unwrap(), None);

        assert!(storable_rejects(parse_quote!(#[storable_arrays()])));
        assert!(storable_rejects(parse_quote!(#[storable_arrays(0)])));
        assert!(storable_rejects(parse_quote!(#[storable_arrays(257)])));
        assert!(storable_rejects(parse_quote!(#[storable_arrays(2, 2)])));
        assert!(storable_rejects(parse_quote!(#[storable_arrays("4")])));
    }

    #[test]
    fn test_extract_mapping_types() {
        // Simple mapping: key and value types are returned in order.
        let ty: Type = parse_quote!(Mapping<Address, U256>);
        let (key, value) = extract_mapping_types(&ty).expect("simple mapping");
        assert_eq!(last_ident(key).as_deref(), Some("Address"));
        assert_eq!(last_ident(value).as_deref(), Some("U256"));

        // Nested mapping: the value is itself a mapping that can be unpacked again.
        let ty: Type = parse_quote!(Mapping<Address, Mapping<Address, U256>>);
        let (key, value) = extract_mapping_types(&ty).expect("nested mapping");
        assert_eq!(last_ident(key).as_deref(), Some("Address"));
        let (inner_key, inner_value) = extract_mapping_types(value).expect("inner mapping");
        assert_eq!(last_ident(inner_key).as_deref(), Some("Address"));
        assert_eq!(last_ident(inner_value).as_deref(), Some("U256"));

        // A path-qualified mapping is recognised by its last segment.
        let ty: Type = parse_quote!(crate::storage::Mapping<u64, bool>);
        let (key, value) = extract_mapping_types(&ty).expect("qualified mapping");
        assert_eq!(last_ident(key).as_deref(), Some("u64"));
        assert_eq!(last_ident(value).as_deref(), Some("bool"));

        // Test non-mapping type
        let ty: Type = parse_quote!(String);
        assert!(extract_mapping_types(&ty).is_none());

        // Test non-mapping generic type
        let ty: Type = parse_quote!(Vec<u8>);
        assert!(extract_mapping_types(&ty).is_none());

        // Malformed mappings: missing generics, or only one type argument.
        let ty: Type = parse_quote!(Mapping);
        assert!(extract_mapping_types(&ty).is_none());
        let ty: Type = parse_quote!(Mapping<Address>);
        assert!(extract_mapping_types(&ty).is_none());

        // A reference to a mapping is not itself a path type.
        let ty: Type = parse_quote!(&Mapping<Address, U256>);
        assert!(extract_mapping_types(&ty).is_none());
    }
}