Skip to content

Commit

Permalink
refactor: gateway compiler handle declare tx (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware authored Jul 23, 2024
1 parent 5b08d60 commit fc5813a
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 65 deletions.
100 changes: 55 additions & 45 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, CasmContractEntryPoints,
};
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::RPCDeclareTransaction;
use starknet_sierra_compile::compile::compile_sierra_to_casm;
Expand All @@ -19,6 +20,7 @@ use crate::utils::is_subsequence;
#[path = "compilation_test.rs"]
mod compilation_test;

// TODO(Arni): Pass the compiler with dependancy injection.
#[derive(Clone)]
pub struct GatewayCompiler {
#[allow(dead_code)]
Expand All @@ -29,64 +31,56 @@ impl GatewayCompiler {
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
/// class wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(
pub fn process_declare_tx(
&self,
declare_tx: &RPCDeclareTransaction,
) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);
let rpc_contract_class = &tx.contract_class;
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);

// Compile Sierra to Casm.
let casm_contract_class = self.compile(cairo_lang_contract_class)?;

validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;
validate_casm_class(&casm_contract_class)?;

Ok(ClassInfo::new(
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
rpc_contract_class.sierra_program.len(),
rpc_contract_class.abi.len(),
)?)
}

// TODO(Arni): Pass the compilation args from the config.
fn compile(
&self,
cairo_lang_contract_class: CairoLangContractClass,
) -> Result<CasmContractClass, GatewayError> {
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
self.validate_casm_class(&casm_contract_class)?;
let casm_contract_class =
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
Ok(casm_contract_class)
}
}

// TODO(Arni): Add test.
fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator =
external.iter().chain(l1_handler.iter()).chain(constructor.iter());
// TODO(Arni): Add test.
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
Ok(())
}
Ok(())
}

// TODO(Arni): Add to a config.
Expand All @@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec<String> {
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

/// Validates that the compiled class hash of the compiled contract class matches the supplied
/// compiled class hash.
fn validate_compiled_class_hash(
casm_contract_class: &CasmContractClass,
supplied_compiled_class_hash: &CompiledClassHash,
) -> Result<(), GatewayError> {
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
if compiled_class_hash != *supplied_compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: *supplied_compiled_class_hash,
hash_result: compiled_class_hash,
});
}
Ok(())
}
20 changes: 10 additions & 10 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ fn gateway_compiler() -> GatewayCompiler {
GatewayCompiler { config: Default::default() }
}

// TODO(Arni): Redesign this test once the compiler is passed with dependancy injection.
#[rstest]
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let expected_hash = tx.compiled_class_hash;
let wrong_supplied_hash = CompiledClassHash::default();
tx.compiled_class_hash = wrong_supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.compile_contract_class(&declare_tx);
let result = gateway_compiler.process_declare_tx(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
if supplied == wrong_supplied_hash && hash_result == expected_hash
);
}

Expand All @@ -45,7 +45,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.compile_contract_class(&declare_tx);
let result = gateway_compiler.process_declare_tx(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
Expand All @@ -55,15 +55,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
}

#[rstest]
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
fn test_process_declare_tx_success(gateway_compiler: GatewayCompiler) {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn process_tx(
// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => {
Some(gateway_compiler.compile_contract_class(declare_tx)?)
Some(gateway_compiler.process_declare_tx(declare_tx)?)
}
_ => None,
};
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn test_stateful_tx_validator(
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: GatewayCompilerConfig {} }
.compile_contract_class(declare_tx)
.process_declare_tx(declare_tx)
.unwrap(),
),
_ => None,
Expand Down
16 changes: 13 additions & 3 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,21 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
)
}

pub fn declare_tx() -> RPCTransaction {
/// Get the contract class used for testing.
pub fn contract_class() -> ContractClass {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
}

/// Get the compiled class hash corresponding to the contract class used for testing.
pub fn compiled_class_hash() -> CompiledClassHash {
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
}

pub fn declare_tx() -> RPCTransaction {
let contract_class = contract_class();
let compiled_class_hash = compiled_class_hash();

let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let account_address = account_contract.get_instance_address(0);
Expand Down
10 changes: 5 additions & 5 deletions crates/starknet_sierra_compile/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@ use cairo_lang_starknet_classes::contract_class::{
};
use cairo_lang_utils::bigint::BigUintAsHex;
use starknet_api::rpc_transaction::{
ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType,
ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType,
};
use starknet_api::state::EntryPoint as StarknetApiEntryPoint;
use starknet_types_core::felt::Felt;

/// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi`
/// field is None as it is not relevant for the compilation.
pub fn into_contract_class_for_compilation(
starknet_api_contract_class: &StarknetApiContractClass,
rpc_contract_class: &RpcContractClass,
) -> CairoLangContractClass {
let sierra_program =
starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
let entry_points_by_type =
into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type);
into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type);

CairoLangContractClass {
sierra_program,
sierra_program_debug_info: None,
contract_class_version: starknet_api_contract_class.contract_class_version.clone(),
contract_class_version: rpc_contract_class.contract_class_version.clone(),
entry_points_by_type,
abi: None,
}
Expand Down

0 comments on commit fc5813a

Please sign in to comment.