
Commit

chore: smol refactor
rkdud007 committed Jul 16, 2024
1 parent f463d59 commit a94ed6f
Showing 2 changed files with 110 additions and 73 deletions.
26 changes: 5 additions & 21 deletions crates/pre-processor/src/compile/module.rs
@@ -1,15 +1,15 @@
//! Preprocessor is responsible for identifying the required values.
//! This will be the most abstract layer of the preprocessor.

use alloy::primitives::ChainId;
use core::panic;
use hdp_cairo_runner::dry_run::DryRunResult;
use hdp_cairo_runner::{cairo_dry_run, input::dry_run::DryRunnerProgramInput};
use hdp_primitives::constant::DRY_CAIRO_RUN_OUTPUT_FILE;
use hdp_primitives::processed_types::cairo_format;
use hdp_primitives::task::ExtendedModule;
use hdp_provider::{evm::provider::EvmProvider, key::FetchKeyEnvelope};
use std::collections::{HashMap, HashSet};
use hdp_provider::evm::from_keys::categorize_fetch_keys;
use hdp_provider::evm::provider::EvmProvider;
use std::collections::HashMap;
use std::path::PathBuf;
use tracing::info;

@@ -54,7 +54,7 @@ impl Compilable for ModuleVec {
);

// 3. call provider using keys
let keys_maps_chain = categorize_fetch_keys_by_chain_id(dry_runned_module.fetch_keys);
let keys_maps_chain = categorize_fetch_keys(dry_runned_module.fetch_keys);
if keys_maps_chain.len() > 1 {
// TODO: This is a temporary solution. Need to handle multiple chain ids in the future
panic!("Multiple chain id is not supported yet");
@@ -67,9 +67,7 @@ impl Compilable for ModuleVec {
// But as this has not been used, for now we can just follow the batch's chain id
info!("3. Fetching proofs from provider...");
let provider = EvmProvider::new(compile_config.provider_config.clone());
let results = provider
.fetch_proofs_from_keys(keys.into_iter().collect())
.await?;
let results = provider.fetch_proofs_from_keys(keys).await?;

Ok(CompilationResult::new(
true,
@@ -84,20 +82,6 @@ impl Compilable for ModuleVec {
}
}

/// Categorize fetch keys by chain id
/// This is required to initiate multiple providers for different chain ids
fn categorize_fetch_keys_by_chain_id(
fetch_keys: Vec<FetchKeyEnvelope>,
) -> Vec<(ChainId, HashSet<FetchKeyEnvelope>)> {
let mut chain_id_map = std::collections::HashMap::new();
for key in fetch_keys {
let chain_id = key.get_chain_id();
let keys = chain_id_map.entry(chain_id).or_insert_with(HashSet::new);
keys.insert(key);
}
chain_id_map.into_iter().collect()
}

/// Generate input structure for preprocessor that need to pass to runner
async fn generate_input(
extended_modules: Vec<ExtendedModule>,
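
Note on the module.rs change above: the preprocessor no longer owns the per-chain grouping; it calls the provider crate's categorize_fetch_keys and hands the grouped keys straight to fetch_proofs_from_keys. Below is a minimal sketch of that flow under the assumption that the caller already holds an EvmProvider and a Vec<FetchKeyEnvelope> (setup and error handling elided; the helper name fetch_proofs_single_chain is made up for illustration).

use hdp_provider::evm::from_keys::categorize_fetch_keys;
use hdp_provider::evm::provider::EvmProvider;
use hdp_provider::key::FetchKeyEnvelope;

async fn fetch_proofs_single_chain(provider: &EvmProvider, fetch_keys: Vec<FetchKeyEnvelope>) {
    // Group the flat key list per chain id and per key type.
    let keys_maps_chain = categorize_fetch_keys(fetch_keys);
    if keys_maps_chain.len() > 1 {
        // Mirrors the diff: multiple chain ids are not supported yet.
        panic!("Multiple chain id is not supported yet");
    }
    // Take the single (chain_id, CategorizedFetchKeys) entry.
    let (_chain_id, keys) = keys_maps_chain.into_iter().next().expect("no fetch keys");
    // fetch_proofs_from_keys now accepts CategorizedFetchKeys directly
    // instead of a Vec<FetchKeyEnvelope>.
    let proofs = provider.fetch_proofs_from_keys(keys).await.unwrap();
    println!("fetched {} headers", proofs.headers.len());
}
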
157 changes: 105 additions & 52 deletions crates/provider/src/evm/from_keys.rs
@@ -4,7 +4,7 @@ use crate::key::{
AccountMemorizerKey, FetchKeyEnvelope, HeaderMemorizerKey, StorageMemorizerKey, TxMemorizerKey,
TxReceiptMemorizerKey,
};
use alloy::primitives::{Address, BlockNumber, Bytes, TxIndex, B256};
use alloy::primitives::{Address, BlockNumber, Bytes, ChainId, TxIndex, B256};
use alloy::transports::{RpcError, TransportErrorKind};
use eth_trie_proofs::tx_receipt_trie::TxReceiptsMptHandler;
use eth_trie_proofs::tx_trie::TxsMptHandler;
@@ -21,74 +21,119 @@ use std::collections::{HashMap, HashSet};
use std::time::Instant;
use tracing::info;

impl EvmProvider {
/// This is the public entry point of the provider.
pub async fn fetch_proofs_from_keys(
&self,
fetch_keys: Vec<FetchKeyEnvelope>,
) -> Result<ProcessedBlockProofs, ProviderError> {
let mut target_keys_for_header = Vec::new();
let mut target_keys_for_account = Vec::new();
let mut target_keys_for_storage = Vec::new();
let mut target_keys_for_tx = Vec::new();
let mut target_keys_for_tx_receipt = Vec::new();
for key in fetch_keys {
match key {
FetchKeyEnvelope::Header(header_key) => {
target_keys_for_header.push(header_key);
}
FetchKeyEnvelope::Account(account_key) => {
target_keys_for_header.push(HeaderMemorizerKey::new(
#[derive(Debug, Default)]
/// Fetch keys categorized into different subsets by key type.
pub struct CategorizedFetchKeys {
pub headers: HashSet<HeaderMemorizerKey>,
pub accounts: HashSet<AccountMemorizerKey>,
pub storage: HashSet<StorageMemorizerKey>,
pub txs: HashSet<TxMemorizerKey>,
pub tx_receipts: HashSet<TxReceiptMemorizerKey>,
}

impl CategorizedFetchKeys {
pub fn new(
headers: HashSet<HeaderMemorizerKey>,
accounts: HashSet<AccountMemorizerKey>,
storage: HashSet<StorageMemorizerKey>,
txs: HashSet<TxMemorizerKey>,
tx_receipts: HashSet<TxReceiptMemorizerKey>,
) -> Self {
Self {
headers,
accounts,
storage,
txs,
tx_receipts,
}
}
}

/// Categorize fetch keys
/// This is required to initiate multiple providers for different chains and fetch key types
pub fn categorize_fetch_keys(
fetch_keys: Vec<FetchKeyEnvelope>,
) -> Vec<(ChainId, CategorizedFetchKeys)> {
let mut chain_id_map: HashMap<u64, CategorizedFetchKeys> = std::collections::HashMap::new();

for key in fetch_keys {
let chain_id = key.get_chain_id();
let target_categorized_fetch_keys = chain_id_map.entry(chain_id).or_default();

match key {
FetchKeyEnvelope::Header(header_key) => {
target_categorized_fetch_keys.headers.insert(header_key);
}
FetchKeyEnvelope::Account(account_key) => {
target_categorized_fetch_keys
.headers
.insert(HeaderMemorizerKey::new(
account_key.chain_id,
account_key.block_number,
));
target_keys_for_account.push(account_key);
}
FetchKeyEnvelope::Storage(storage_key) => {
target_keys_for_header.push(HeaderMemorizerKey::new(
target_categorized_fetch_keys.accounts.insert(account_key);
}
FetchKeyEnvelope::Storage(storage_key) => {
target_categorized_fetch_keys
.headers
.insert(HeaderMemorizerKey::new(
storage_key.chain_id,
storage_key.block_number,
));
target_keys_for_storage.push(storage_key);
}
FetchKeyEnvelope::Tx(tx_key) => {
target_keys_for_header.push(HeaderMemorizerKey::new(
target_categorized_fetch_keys.storage.insert(storage_key);
}
FetchKeyEnvelope::Tx(tx_key) => {
target_categorized_fetch_keys
.headers
.insert(HeaderMemorizerKey::new(
tx_key.chain_id,
tx_key.block_number,
));
target_keys_for_tx.push(tx_key);
}
FetchKeyEnvelope::TxReceipt(tx_receipt_key) => {
target_keys_for_header.push(HeaderMemorizerKey::new(
target_categorized_fetch_keys.txs.insert(tx_key);
}
FetchKeyEnvelope::TxReceipt(tx_receipt_key) => {
target_categorized_fetch_keys
.headers
.insert(HeaderMemorizerKey::new(
tx_receipt_key.chain_id,
tx_receipt_key.block_number,
));
target_keys_for_tx_receipt.push(tx_receipt_key);
}
target_categorized_fetch_keys
.tx_receipts
.insert(tx_receipt_key);
}
}
}
chain_id_map.into_iter().collect()
}

impl EvmProvider {
/// This is the public entry point of the provider.
pub async fn fetch_proofs_from_keys(
&self,
fetch_keys: CategorizedFetchKeys,
) -> Result<ProcessedBlockProofs, ProviderError> {
// fetch proofs using keys and construct result
let (headers, mmr_metas) = self.get_headers_from_keys(target_keys_for_header).await?;
let mut accounts = if target_keys_for_account.is_empty() {
let (headers, mmr_metas) = self.get_headers_from_keys(fetch_keys.headers).await?;
let mut accounts = if fetch_keys.accounts.is_empty() {
HashSet::new()
} else {
self.get_accounts_from_keys(target_keys_for_account).await?
self.get_accounts_from_keys(fetch_keys.accounts).await?
};
let (accounts_from_storage_key, storages) = if target_keys_for_storage.is_empty() {
let (accounts_from_storage_key, storages) = if fetch_keys.storage.is_empty() {
(HashSet::new(), HashSet::new())
} else {
self.get_storages_from_keys(target_keys_for_storage).await?
self.get_storages_from_keys(fetch_keys.storage).await?
};
let transactions = if target_keys_for_tx.is_empty() {
let transactions = if fetch_keys.txs.is_empty() {
vec![]
} else {
self.get_txs_from_keys(target_keys_for_tx).await?
self.get_txs_from_keys(fetch_keys.txs).await?
};
let transaction_receipts = if target_keys_for_tx_receipt.is_empty() {
let transaction_receipts = if fetch_keys.tx_receipts.is_empty() {
vec![]
} else {
self.get_tx_receipts_from_keys(target_keys_for_tx_receipt)
self.get_tx_receipts_from_keys(fetch_keys.tx_receipts)
.await?
};
accounts.extend(accounts_from_storage_key);
Expand All @@ -105,7 +150,7 @@ impl EvmProvider {

async fn get_headers_from_keys(
&self,
keys: Vec<HeaderMemorizerKey>,
keys: HashSet<HeaderMemorizerKey>,
) -> Result<(HashSet<ProcessedHeader>, Vec<MMRMeta>), ProviderError> {
let start_fetch = Instant::now();

@@ -121,7 +166,7 @@ impl EvmProvider {
self._chunk_vec_blocks_for_indexer(block_range)
};

let chain_id = keys.first().unwrap().chain_id;
let chain_id = keys.iter().next().unwrap().chain_id;
let mut fetched_headers_proofs: HashSet<ProcessedHeader> = HashSet::new();
let mut mmrs = HashSet::new();

@@ -167,7 +212,7 @@ impl EvmProvider {

async fn get_accounts_from_keys(
&self,
keys: Vec<AccountMemorizerKey>,
keys: HashSet<AccountMemorizerKey>,
) -> Result<HashSet<ProcessedAccount>, ProviderError> {
let mut fetched_accounts_proofs: HashSet<ProcessedAccount> = HashSet::new();
let start_fetch = Instant::now();
@@ -218,7 +263,7 @@ impl EvmProvider {

async fn get_storages_from_keys(
&self,
keys: Vec<StorageMemorizerKey>,
keys: HashSet<StorageMemorizerKey>,
) -> Result<(HashSet<ProcessedAccount>, HashSet<ProcessedStorage>), ProviderError> {
let mut fetched_accounts_proofs: HashSet<ProcessedAccount> = HashSet::new();
let mut fetched_storage_proofs: HashSet<ProcessedStorage> = HashSet::new();
@@ -284,7 +329,7 @@ impl EvmProvider {

pub async fn get_txs_from_keys(
&self,
keys: Vec<TxMemorizerKey>,
keys: HashSet<TxMemorizerKey>,
) -> Result<Vec<ProcessedTransaction>, ProviderError> {
let mut fetched_transactions = vec![];
let start_fetch = Instant::now();
@@ -333,7 +378,7 @@ impl EvmProvider {

pub async fn get_tx_receipts_from_keys(
&self,
keys: Vec<TxReceiptMemorizerKey>,
keys: HashSet<TxReceiptMemorizerKey>,
) -> Result<Vec<ProcessedReceipt>, ProviderError> {
let mut fetched_transaction_receipts = vec![];
let start_fetch = Instant::now();
@@ -401,7 +446,9 @@ mod tests {
FetchKeyEnvelope::Header(HeaderMemorizerKey::new(target_chain_id, 2)),
FetchKeyEnvelope::Header(HeaderMemorizerKey::new(target_chain_id, 3)),
];
let proofs = provider.fetch_proofs_from_keys(keys).await.unwrap();
let (chain_id, fetched_keys) = categorize_fetch_keys(keys).into_iter().next().unwrap();
assert_eq!(chain_id, target_chain_id);
let proofs = provider.fetch_proofs_from_keys(fetched_keys).await.unwrap();
assert_eq!(proofs.headers.len(), 3);
}

@@ -423,7 +470,9 @@ mod tests {
target_address,
)),
];
let proofs = provider.fetch_proofs_from_keys(keys).await.unwrap();
let (chain_id, fetched_keys) = categorize_fetch_keys(keys).into_iter().next().unwrap();
assert_eq!(chain_id, target_chain_id);
let proofs = provider.fetch_proofs_from_keys(fetched_keys).await.unwrap();
assert_eq!(proofs.accounts[0].proofs.len(), 3);
assert_eq!(proofs.headers.len(), 3);
}
@@ -473,7 +522,9 @@ mod tests {
target_slot,
)),
];
let proofs = provider.fetch_proofs_from_keys(keys).await.unwrap();
let (chain_id, fetched_keys) = categorize_fetch_keys(keys).into_iter().next().unwrap();
assert_eq!(chain_id, target_chain_id);
let proofs = provider.fetch_proofs_from_keys(fetched_keys).await.unwrap();
let duration = start_fetch.elapsed();
println!("Time taken (Total Proofs Fetch): {:?}", duration);
assert_eq!(proofs.headers.len(), 6);
@@ -490,7 +541,9 @@ mod tests {
FetchKeyEnvelope::Tx(TxMemorizerKey::new(target_chain_id, 1001, 1)),
FetchKeyEnvelope::Tx(TxMemorizerKey::new(target_chain_id, 1000, 2)),
];
let proofs = provider.fetch_proofs_from_keys(keys).await.unwrap();
let (chain_id, fetched_keys) = categorize_fetch_keys(keys).into_iter().next().unwrap();
assert_eq!(chain_id, target_chain_id);
let proofs = provider.fetch_proofs_from_keys(fetched_keys).await.unwrap();
assert_eq!(proofs.headers.len(), 2);
assert_eq!(proofs.transactions.len(), 3);
}
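
For reference, a hedged sketch of the per-chain grouping that categorize_fetch_keys performs when the keys span more than one chain (the diff's own tests only exercise a single chain). The chain ids and block numbers below are made up, and the hdp_provider::key::HeaderMemorizerKey path is assumed from the existing imports.

use hdp_provider::evm::from_keys::categorize_fetch_keys;
use hdp_provider::key::{FetchKeyEnvelope, HeaderMemorizerKey};

fn group_keys_across_chains() {
    // Hypothetical keys spanning two chains (Sepolia and mainnet).
    let keys = vec![
        FetchKeyEnvelope::Header(HeaderMemorizerKey::new(11155111, 100)),
        FetchKeyEnvelope::Header(HeaderMemorizerKey::new(11155111, 101)),
        FetchKeyEnvelope::Header(HeaderMemorizerKey::new(1, 19_000_000)),
    ];
    let grouped = categorize_fetch_keys(keys);
    // One (ChainId, CategorizedFetchKeys) entry per chain id.
    assert_eq!(grouped.len(), 2);
    for (chain_id, categorized) in grouped {
        // Each bucket is a per-type HashSet; only headers are populated here.
        println!("chain {chain_id}: {} header keys", categorized.headers.len());
    }
}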
