From e3f2b88c731728a4dcc260ddebcb768589a5ffd8 Mon Sep 17 00:00:00 2001 From: Yoav Gross Date: Thu, 5 Dec 2024 13:19:59 +0200 Subject: [PATCH] feat(blockifier): iterate over aliases on the state diff --- .../src/blockifier/transaction_executor.rs | 12 ++-- .../src/state/stateful_compression.rs | 38 +++++++++++++ .../src/state/stateful_compression_test.rs | 55 ++++++++++++++++++- 3 files changed, 98 insertions(+), 7 deletions(-) diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index bec2bd5b8c..56dde1e623 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -16,6 +16,7 @@ use crate::context::BlockContext; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; +use crate::state::stateful_compression::allocate_aliases; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::TransactionExecutionInfo; use crate::transaction::transaction_execution::Transaction; @@ -166,13 +167,12 @@ impl TransactionExecutor { .collect::>()?; log::debug!("Final block weights: {:?}.", self.bouncer.get_accumulated_weights()); + let mut block_state = self.block_state.take().expect(BLOCK_STATE_ACCESS_ERR); + if self.block_context.versioned_constants.enable_stateful_compression { + allocate_aliases(&mut block_state)?; + } Ok(( - self.block_state - .as_mut() - .expect(BLOCK_STATE_ACCESS_ERR) - .to_state_diff()? - .state_maps - .into(), + block_state.to_state_diff()?.state_maps.into(), visited_segments, *self.bouncer.get_accumulated_weights(), )) diff --git a/crates/blockifier/src/state/stateful_compression.rs b/crates/blockifier/src/state/stateful_compression.rs index 34e24609c2..8a34cff00e 100644 --- a/crates/blockifier/src/state/stateful_compression.rs +++ b/crates/blockifier/src/state/stateful_compression.rs @@ -1,3 +1,5 @@ +use std::collections::{BTreeSet, HashMap}; + use starknet_api::core::{ContractAddress, PatriciaKey}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; @@ -16,6 +18,9 @@ type AliasKey = StorageKey; const ALIAS_CONTRACT_ADDRESS: u8 = 2; // The storage key of the alias counter in the alias contract. const ALIAS_COUNTER_STORAGE_KEY: u8 = 0; +// The maximal contract address for which aliases are not used and all keys are serialized as is, +// without compression. +const MAX_NON_COMPRESSED_CONTRACT_ADDRESS: u8 = 15; // The minimal value for a key to be allocated an alias. Smaller keys are serialized as is (their // alias is identical to the key). const MIN_VALUE_FOR_ALIAS_ALLOC: Felt = Felt::from_hex_unchecked("0x80"); @@ -26,6 +31,39 @@ pub fn get_alias_contract_address() -> ContractAddress { pub fn get_alias_counter_storage_key() -> StorageKey { StorageKey::from(ALIAS_COUNTER_STORAGE_KEY) } +pub fn get_max_non_compressed_contract_address() -> ContractAddress { + ContractAddress::from(MAX_NON_COMPRESSED_CONTRACT_ADDRESS) +} + +/// Allocates aliases for the new addresses and storage keys in the alias contract. +/// Iterates over the addresses in ascending order. For each address, sets an alias for the new +/// storage keys (in ascending order) and for the address itself. +pub fn allocate_aliases(state: &mut CachedState) -> StateResult<()> { + let writes = state.borrow_updated_state_cache()?.clone().writes; + + // Collect the addresses and the storage keys that need aliases. + let mut addresses = BTreeSet::new(); + let mut sorted_storage_keys = HashMap::new(); + addresses.extend(writes.class_hashes.keys().chain(writes.nonces.keys())); + for (address, storage_key) in writes.storage.keys() { + addresses.insert(address); + if address > &get_max_non_compressed_contract_address() { + sorted_storage_keys.entry(address).or_insert_with(BTreeSet::new).insert(storage_key); + } + } + + // Iterate over the addresses and the storage keys and update the aliases. + let mut alias_updater = AliasUpdater::new(state)?; + for address in addresses { + if let Some(storage_keys) = sorted_storage_keys.get(address) { + for key in storage_keys { + alias_updater.set_alias(key)?; + } + } + alias_updater.set_alias(&StorageKey(address.0))?; + } + alias_updater.finalize_updates() +} /// Updates the alias contract with the new keys. struct AliasUpdater<'a, S: StateReader> { diff --git a/crates/blockifier/src/state/stateful_compression_test.rs b/crates/blockifier/src/state/stateful_compression_test.rs index 66f90b8f7d..1ee20c3eb6 100644 --- a/crates/blockifier/src/state/stateful_compression_test.rs +++ b/crates/blockifier/src/state/stateful_compression_test.rs @@ -1,17 +1,20 @@ use std::collections::HashMap; use rstest::rstest; -use starknet_api::core::ContractAddress; +use starknet_api::core::{ClassHash, ContractAddress}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use super::{ + allocate_aliases, get_alias_contract_address, get_alias_counter_storage_key, + get_max_non_compressed_contract_address, AliasUpdater, MIN_VALUE_FOR_ALIAS_ALLOC, }; use crate::state::cached_state::CachedState; +use crate::state::state_api::{State, StateReader}; use crate::test_utils::dict_state_reader::DictStateReader; fn insert_to_alias_contract( @@ -98,3 +101,53 @@ fn test_alias_updater( assert_eq!(storage_diff, expected_storage_diff); } + +#[test] +fn test_iterate_aliases() { + let mut state = initial_state(0); + state + .set_storage_at(ContractAddress::from(0x201_u16), StorageKey::from(0x300_u16), Felt::ONE) + .unwrap(); + state + .set_storage_at( + get_max_non_compressed_contract_address(), + StorageKey::from(0x301_u16), + Felt::ONE, + ) + .unwrap(); + state.get_class_hash_at(ContractAddress::from(0x202_u16)).unwrap(); + state.set_class_hash_at(ContractAddress::from(0x202_u16), ClassHash::default()).unwrap(); + state.increment_nonce(ContractAddress::from(0x200_u16)).unwrap(); + + allocate_aliases(&mut state).unwrap(); + let storage_diff = state.to_state_diff().unwrap().state_maps.storage; + assert_eq!( + storage_diff, + vec![ + ( + (get_alias_contract_address(), get_alias_counter_storage_key()), + MIN_VALUE_FOR_ALIAS_ALLOC + Felt::from(4_u8) + ), + ( + (get_alias_contract_address(), StorageKey::from(0x200_u16)), + MIN_VALUE_FOR_ALIAS_ALLOC + ), + ( + (get_alias_contract_address(), StorageKey::from(0x300_u16)), + MIN_VALUE_FOR_ALIAS_ALLOC + Felt::ONE + ), + ( + (get_alias_contract_address(), StorageKey::from(0x201_u16)), + MIN_VALUE_FOR_ALIAS_ALLOC + Felt::TWO + ), + ( + (get_alias_contract_address(), StorageKey::from(0x202_u16)), + MIN_VALUE_FOR_ALIAS_ALLOC + Felt::THREE + ), + ((ContractAddress::from(0x201_u16), StorageKey::from(0x300_u16)), Felt::ONE), + ((get_max_non_compressed_contract_address(), StorageKey::from(0x301_u16)), Felt::ONE), + ] + .into_iter() + .collect() + ); +}