use core::str::FromStr;
use alloc::collections::BTreeMap;
use alloy_primitives::{ruint::ParseError, Bytes, B256, U256};
use core::fmt::{Display, Formatter};
use serde::{Deserialize, Deserializer, Serialize};
/// A storage key that may be written either as a full 32-byte hash or as a 256-bit number.
///
/// Deserialization is `#[serde(untagged)]`, so serde tries the variants in declaration order:
/// [`JsonStorageKey::Hash`] first, then [`JsonStorageKey::Number`]. For example,
/// `"0x0000000000000000000000000000000000000000000000000000000000000abc"` deserializes as a
/// `Hash`, while `"0x0abc"` deserializes as a `Number`. Both forms map to the same 32-byte word
/// through [`JsonStorageKey::as_b256`]. Serialization writes the inner value with no tag.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum JsonStorageKey {
    /// A key given as a full 32-byte value.
    Hash(B256),
    /// A key given as a 256-bit integer; see [`JsonStorageKey::as_b256`] for its 32-byte form.
    Number(U256),
}
impl JsonStorageKey {
    /// Returns this key as a 32-byte word.
    ///
    /// A [`JsonStorageKey::Hash`] is returned as-is. A [`JsonStorageKey::Number`] goes through
    /// `B256::from(U256)`, which left-pads the value, so `Number(0xabc)` and
    /// `Hash(0x00…0abc)` produce the same word.
    pub fn as_b256(&self) -> B256 {
        match *self {
            Self::Number(number) => B256::from(number),
            Self::Hash(word) => word,
        }
    }
}
impl Default for JsonStorageKey {
fn default() -> Self {
Self::Hash(Default::default())
}
}
impl From<B256> for JsonStorageKey {
fn from(value: B256) -> Self {
Self::Hash(value)
}
}
impl From<[u8; 32]> for JsonStorageKey {
    /// Treats a raw 32-byte array as a [`JsonStorageKey::Hash`].
    fn from(raw: [u8; 32]) -> Self {
        Self::Hash(B256::from(raw))
    }
}
impl From<U256> for JsonStorageKey {
fn from(value: U256) -> Self {
Self::Number(value)
}
}
impl FromStr for JsonStorageKey {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(hash) = B256::from_str(s) {
return Ok(Self::Hash(hash));
}
s.parse().map(Self::Number)
}
}
impl Display for JsonStorageKey {
    /// Formats a [`JsonStorageKey::Hash`] using `B256`'s `Display`.
    ///
    /// A [`JsonStorageKey::Number`] is formatted as `0x`-prefixed lowercase hex without leading
    /// zeros, so zero renders as `0x0`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Number(num) => {
                // Render the whole `0x…` text first, then hand it to `str`'s `Display`, which
                // applies any fill/width flags carried by `f` to the complete string.
                let rendered = alloc::format!("{num:#x}");
                Display::fmt(&rendered, f)
            }
            Self::Hash(hash) => Display::fmt(hash, f),
        }
    }
}
pub fn from_bytes_to_b256<'de, D>(bytes: Bytes) -> Result<B256, D::Error>
where
D: Deserializer<'de>,
{
if bytes.0.len() > 32 {
return Err(serde::de::Error::custom("input too long to be a B256"));
}
let mut padded = [0u8; 32];
padded[32 - bytes.0.len()..].copy_from_slice(&bytes.0);
Ok(B256::from_slice(&padded))
}
/// Deserializes an optional map of storage entries into `Option<BTreeMap<B256, B256>>`.
///
/// Each key and value is first read as [`Bytes`] and then widened with
/// [`from_bytes_to_b256`], so entries shorter than 32 bytes are accepted and left-padded,
/// e.g. `{"0x01": "0x22"}`. An absent/`null` input yields `Ok(None)`.
///
/// If two input keys pad to the same 32-byte word, the entry inserted later replaces the
/// earlier one.
///
/// # Errors
///
/// Propagates any deserialization error. Also fails on the first key or value longer than
/// 32 bytes.
pub fn deserialize_storage_map<'de, D>(
    deserializer: D,
) -> Result<Option<BTreeMap<B256, B256>>, D::Error>
where
    D: Deserializer<'de>,
{
    let raw = Option::<BTreeMap<Bytes, Bytes>>::deserialize(deserializer)?;

    raw.map(|entries| {
        entries.into_iter().try_fold(BTreeMap::new(), |mut storage, (slot, word)| {
            // Convert the key before the value, stopping at the first failure.
            let slot = from_bytes_to_b256::<'de, D>(slot)?;
            let word = from_bytes_to_b256::<'de, D>(word)?;
            storage.insert(slot, word);
            Ok::<_, D::Error>(storage)
        })
    })
    .transpose()
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::{String, ToString};
    use serde_json::json;

    /// `Number(0)` displays in minimal hex form, not as a zero-padded 32-byte word.
    #[test]
    fn default_number_storage_key() {
        let key = JsonStorageKey::Number(Default::default());
        assert_eq!(key.to_string(), String::from("0x0"));
    }

    /// The `Default` key is the all-zero hash and displays all 64 hex digits.
    #[test]
    fn default_hash_storage_key() {
        let key = JsonStorageKey::default();
        assert_eq!(
            key.to_string(),
            String::from("0x0000000000000000000000000000000000000000000000000000000000000000")
        );
    }

    /// A hash key with or without the `0x` prefix normalizes to the same word.
    #[test]
    fn test_storage_key() {
        let cases = [
            "0x0000000000000000000000000000000000000000000000000000000000000001",
            "0000000000000000000000000000000000000000000000000000000000000001",
        ];
        let key: JsonStorageKey = serde_json::from_str(&json!(cases[0]).to_string()).unwrap();
        let key2: JsonStorageKey = serde_json::from_str(&json!(cases[1]).to_string()).unwrap();
        assert_eq!(key.as_b256(), key2.as_b256());
    }

    /// Deserializing and then displaying a key reproduces the exact input text.
    #[test]
    fn test_storage_key_serde_roundtrips() {
        let test_cases = [
            "0x0000000000000000000000000000000000000000000000000000000000000001", // Hash
            "0x0000000000000000000000000000000000000000000000000000000000000abc", // Hash
            "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
            "0xabc",  // Number
            "0xabcd", // Number
        ];
        for input in test_cases {
            let key: JsonStorageKey = serde_json::from_str(&json!(input).to_string()).unwrap();
            let output = key.to_string();
            assert_eq!(
                input, output,
                "Storage key roundtrip failed to preserve the exact hex representation for {}",
                input
            );
        }
    }

    /// A short value becomes `Number`, a full-width value becomes `Hash`, and both map to the
    /// same 32-byte word.
    #[test]
    fn test_as_b256() {
        let cases = [
            "0x0abc",
            "0x0000000000000000000000000000000000000000000000000000000000000abc",
        ];
        let num_key: JsonStorageKey = serde_json::from_str(&json!(cases[0]).to_string()).unwrap();
        let hash_key: JsonStorageKey = serde_json::from_str(&json!(cases[1]).to_string()).unwrap();
        assert_eq!(num_key, JsonStorageKey::Number(U256::from_str(cases[0]).unwrap()));
        assert_eq!(hash_key, JsonStorageKey::Hash(B256::from_str(cases[1]).unwrap()));
        assert_eq!(num_key.as_b256(), hash_key.as_b256());
    }

    /// `FromStr` prefers the hash interpretation, falls back to a number, and errors when
    /// neither parses.
    #[test]
    fn test_storage_key_from_str() {
        let hash = "0x0000000000000000000000000000000000000000000000000000000000000abc";
        assert_eq!(
            JsonStorageKey::from_str(hash).unwrap(),
            JsonStorageKey::Hash(B256::from_str(hash).unwrap())
        );
        assert_eq!(
            JsonStorageKey::from_str("0xabc").unwrap(),
            JsonStorageKey::Number(U256::from_str("0xabc").unwrap())
        );
        assert!(JsonStorageKey::from_str("not a key").is_err());
    }

    /// Storage-map keys and values shorter than 32 bytes are left-padded with zeros.
    #[test]
    fn test_deserialize_storage_map_pads_short_entries() {
        let mut de =
            serde_json::Deserializer::from_str(r#"{"0x01": "0x22", "0x0002": "0x00"}"#);
        let map = deserialize_storage_map(&mut de)
            .unwrap()
            .expect("an object input yields a map");
        assert_eq!(map.len(), 2);
        assert_eq!(map[&B256::with_last_byte(0x01)], B256::with_last_byte(0x22));
        assert_eq!(map[&B256::with_last_byte(0x02)], B256::default());
    }

    /// A JSON `null` storage map deserializes to `None`.
    #[test]
    fn test_deserialize_storage_map_null() {
        let mut de = serde_json::Deserializer::from_str("null");
        assert_eq!(deserialize_storage_map(&mut de).unwrap(), None);
    }

    /// A value longer than 32 bytes is rejected, not truncated.
    #[test]
    fn test_deserialize_storage_map_rejects_too_long() {
        let input = alloc::format!(r#"{{"0x01": "0x{}"}}"#, "11".repeat(33));
        let mut de = serde_json::Deserializer::from_str(&input);
        assert!(deserialize_storage_map(&mut de).is_err());
    }
}