From aa0daf98e82c71a6eb3c8c65a154cee1afd6ffe2 Mon Sep 17 00:00:00 2001 From: Sam Wilson Date: Sat, 25 Jan 2025 13:47:05 -0500 Subject: [PATCH] update storage trie type (#1070 #1071) --- src/ethereum/prague/state.py | 21 +++++++++++---------- src/ethereum_optimized/state_db.py | 10 +++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/ethereum/prague/state.py b/src/ethereum/prague/state.py index 1cb9cdcdc7..eae92ed528 100644 --- a/src/ethereum/prague/state.py +++ b/src/ethereum/prague/state.py @@ -19,7 +19,7 @@ from dataclasses import dataclass, field from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple -from ethereum_types.bytes import Bytes +from ethereum_types.bytes import Bytes, Bytes32 from ethereum_types.frozen import modify from ethereum_types.numeric import U256, Uint @@ -37,12 +37,13 @@ class State: _main_trie: Trie[Address, Optional[Account]] = field( default_factory=lambda: Trie(secured=True, default=None) ) - _storage_tries: Dict[Address, Trie[Bytes, U256]] = field( + _storage_tries: Dict[Address, Trie[Bytes32, U256]] = field( default_factory=dict ) _snapshots: List[ Tuple[ - Trie[Address, Optional[Account]], Dict[Address, Trie[Bytes, U256]] + Trie[Address, Optional[Account]], + Dict[Address, Trie[Bytes32, U256]], ] ] = field(default_factory=list) created_accounts: Set[Address] = field(default_factory=set) @@ -55,8 +56,8 @@ class TransientStorage: within a transaction. """ - _tries: Dict[Address, Trie[Bytes, U256]] = field(default_factory=dict) - _snapshots: List[Dict[Address, Trie[Bytes, U256]]] = field( + _tries: Dict[Address, Trie[Bytes32, U256]] = field(default_factory=dict) + _snapshots: List[Dict[Address, Trie[Bytes32, U256]]] = field( default_factory=list ) @@ -261,7 +262,7 @@ def mark_account_created(state: State, address: Address) -> None: state.created_accounts.add(address) -def get_storage(state: State, address: Address, key: Bytes) -> U256: +def get_storage(state: State, address: Address, key: Bytes32) -> U256: """ Get a value at a storage key on an account. Returns `U256(0)` if the storage key has not been set previously. @@ -291,7 +292,7 @@ def get_storage(state: State, address: Address, key: Bytes) -> U256: def set_storage( - state: State, address: Address, key: Bytes, value: U256 + state: State, address: Address, key: Bytes32, value: U256 ) -> None: """ Set a value at a storage key on an account. Setting to `U256(0)` deletes @@ -626,7 +627,7 @@ def write_code(sender: Account) -> None: modify_state(state, address, write_code) -def get_storage_original(state: State, address: Address, key: Bytes) -> U256: +def get_storage_original(state: State, address: Address, key: Bytes32) -> U256: """ Get the original value in a storage slot i.e. the value before the current transaction began. This function reads the value from the snapshots taken @@ -660,7 +661,7 @@ def get_storage_original(state: State, address: Address, key: Bytes) -> U256: def get_transient_storage( - transient_storage: TransientStorage, address: Address, key: Bytes + transient_storage: TransientStorage, address: Address, key: Bytes32 ) -> U256: """ Get a value at a storage key on an account from transient storage. @@ -691,7 +692,7 @@ def get_transient_storage( def set_transient_storage( transient_storage: TransientStorage, address: Address, - key: Bytes, + key: Bytes32, value: U256, ) -> None: """ diff --git a/src/ethereum_optimized/state_db.py b/src/ethereum_optimized/state_db.py index edcb165b38..d9c9ced6be 100644 --- a/src/ethereum_optimized/state_db.py +++ b/src/ethereum_optimized/state_db.py @@ -26,7 +26,7 @@ "package" ) -from ethereum_types.bytes import Bytes, Bytes20 +from ethereum_types.bytes import Bytes, Bytes20, Bytes32 from ethereum_types.numeric import U256, Uint from ethereum.crypto.hash import Hash32 @@ -76,7 +76,7 @@ class State: db: Any dirty_accounts: Dict[Address, Optional[Account_]] - dirty_storage: Dict[Address, Dict[Bytes, U256]] + dirty_storage: Dict[Address, Dict[Bytes32, U256]] destroyed_accounts: Set[Address] tx_restore_points: List[int] journal: List[Any] @@ -328,7 +328,7 @@ def rollback_transaction(state: State) -> None: _rollback_transaction(state) @add_item(patches) - def get_storage(state: State, address: Address, key: Bytes) -> U256: + def get_storage(state: State, address: Address, key: Bytes32) -> U256: """ See `state`. """ @@ -345,7 +345,7 @@ def get_storage(state: State, address: Address, key: Bytes) -> U256: @add_item(patches) def get_storage_original( - state: State, address: Address, key: Bytes + state: State, address: Address, key: Bytes32 ) -> U256: """ See `state`. @@ -357,7 +357,7 @@ def get_storage_original( @add_item(patches) def set_storage( - state: State, address: Address, key: Bytes, value: U256 + state: State, address: Address, key: Bytes32, value: U256 ) -> None: """ See `state`.