Skip to content

Commit

Permalink
Allow configuring max number of steps to be executed for starknet_call
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Jan 23, 2024
1 parent a99d71a commit 5b76ca0
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 20 deletions.
4 changes: 4 additions & 0 deletions cmd/juno/juno.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ const (
cnL2ChainIDF = "cn-l2-chain-id"
cnCoreContractAddressF = "cn-core-contract-address"
cnUnverifiableRangeF = "cn-unverifiable-range"
callMaxStepsF = "rpc-call-max-steps"

defaultConfig = ""
defaulHost = "localhost"
Expand Down Expand Up @@ -105,6 +106,7 @@ const (
defaultCNL1ChainID = ""
defaultCNL2ChainID = ""
defaultCNCoreContractAddressStr = ""
defaultCallMaxSteps = 4_000_000

configFlagUsage = "The yaml configuration file."
logLevelFlagUsage = "Options: debug, info, warn, error."
Expand Down Expand Up @@ -146,6 +148,7 @@ const (
dbCacheSizeUsage = "Determines the amount of memory (in megabytes) allocated for caching data in the database."
dbMaxHandlesUsage = "A soft limit on the number of open files that can be used by the DB"
gwAPIKeyUsage = "API key for gateway endpoints to avoid throttling" //nolint: gosec
callMaxStepsUsage = "Maximum number of steps to be executed in starknet_call requests"
)

var Version string
Expand Down Expand Up @@ -320,6 +323,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr
junoCmd.Flags().Int(dbMaxHandlesF, defaultMaxHandles, dbMaxHandlesUsage)
junoCmd.MarkFlagsRequiredTogether(cnNameF, cnFeederURLF, cnGatewayURLF, cnL1ChainIDF, cnL2ChainIDF, cnCoreContractAddressF, cnUnverifiableRangeF) //nolint:lll
junoCmd.MarkFlagsMutuallyExclusive(networkF, cnNameF)
junoCmd.Flags().Uint(callMaxStepsF, defaultCallMaxSteps, callMaxStepsUsage)

return junoCmd
}
16 changes: 16 additions & 0 deletions cmd/juno/juno_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func TestConfigPrecedence(t *testing.T) {
defaultRPCMaxBlockScan := uint(math.MaxUint)
defaultMaxCacheSize := uint(8)
defaultMaxHandles := 1024
defaultCallMaxSteps := uint(4_000_000)

tests := map[string]struct {
cfgFile bool
Expand Down Expand Up @@ -106,6 +107,7 @@ func TestConfigPrecedence(t *testing.T) {
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"custom network config file": {
Expand Down Expand Up @@ -149,6 +151,7 @@ cn-unverifiable-range: [0,10]
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"default config with no flags": {
Expand Down Expand Up @@ -179,6 +182,7 @@ cn-unverifiable-range: [0,10]
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"config file path is empty string": {
Expand Down Expand Up @@ -209,6 +213,7 @@ cn-unverifiable-range: [0,10]
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"config file doesn't exist": {
Expand Down Expand Up @@ -244,6 +249,7 @@ cn-unverifiable-range: [0,10]
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"config file with all settings but without any other flags": {
Expand Down Expand Up @@ -281,6 +287,7 @@ pprof: true
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"config file with some settings but without any other flags": {
Expand Down Expand Up @@ -315,6 +322,7 @@ http-port: 4576
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"all flags without config file": {
Expand Down Expand Up @@ -347,6 +355,7 @@ http-port: 4576
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"some flags without config file": {
Expand Down Expand Up @@ -380,6 +389,7 @@ http-port: 4576
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"all setting set in both config file and flags": {
Expand Down Expand Up @@ -437,6 +447,7 @@ db-cache-size: 8
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: 9,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"some setting set in both config file and flags": {
Expand Down Expand Up @@ -473,6 +484,7 @@ network: goerli
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"some setting set in default, config file and flags": {
Expand Down Expand Up @@ -505,6 +517,7 @@ network: goerli
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"only set env variables": {
Expand Down Expand Up @@ -535,6 +548,7 @@ network: goerli
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"some setting set in both env variables and flags": {
Expand Down Expand Up @@ -566,6 +580,7 @@ network: goerli
RPCMaxBlockScan: defaultRPCMaxBlockScan,
DBCacheSize: defaultMaxCacheSize,
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
"some setting set in both env variables and config file": {
Expand Down Expand Up @@ -598,6 +613,7 @@ network: goerli
DBCacheSize: defaultMaxCacheSize,
GatewayAPIKey: "apikey",
DBMaxHandles: defaultMaxHandles,
RPCCallMaxSteps: defaultCallMaxSteps,
},
},
}
Expand Down
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type Config struct {
MaxVMs uint `mapstructure:"max-vms"`
MaxVMQueue uint `mapstructure:"max-vm-queue"`
RPCMaxBlockScan uint `mapstructure:"rpc-max-block-scan"`
RPCCallMaxSteps uint `mapstructure:"rpc-call-max-steps"`

DBCacheSize uint `mapstructure:"db-cache-size"`
DBMaxHandles int `mapstructure:"db-max-handles"`
Expand Down Expand Up @@ -145,7 +146,7 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen

throttledVM := NewThrottledVM(vm.New(log), cfg.MaxVMs, int32(cfg.MaxVMQueue))
rpcHandler := rpc.New(chain, synchronizer, throttledVM, version, log).WithGateway(gatewayClient).WithFeeder(client)
rpcHandler = rpcHandler.WithFilterLimit(cfg.RPCMaxBlockScan)
rpcHandler = rpcHandler.WithFilterLimit(cfg.RPCMaxBlockScan).WithCallMaxSteps(uint64(cfg.RPCCallMaxSteps))
services = append(services, rpcHandler)
// to improve RPC throughput we double GOMAXPROCS
maxGoroutines := 2 * runtime.GOMAXPROCS(0)
Expand Down
5 changes: 3 additions & 2 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *Thrott
}

func (tvm *ThrottledVM) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber,
blockTimestamp uint64, state core.StateReader, network *utils.Network,
blockTimestamp uint64, state core.StateReader, network *utils.Network, maxSteps uint64,
) ([]*felt.Felt, error) {
var ret []*felt.Felt
throttler := (*utils.Throttler[vm.VM])(tvm)
return ret, throttler.Do(func(vm *vm.VM) error {
var err error
ret, err = (*vm).Call(contractAddr, classHash, selector, calldata, blockNumber, blockTimestamp, state, network)
ret, err = (*vm).Call(contractAddr, classHash, selector, calldata, blockNumber, blockTimestamp,
state, network, maxSteps)
return err
})
}
Expand Down
10 changes: 8 additions & 2 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ type Handler struct {

blockTraceCache *lru.Cache[traceCacheKey, []TracedBlockTransaction]

filterLimit uint
filterLimit uint
callMaxSteps uint64
}

type subscription struct {
Expand Down Expand Up @@ -132,6 +133,11 @@ func (h *Handler) WithFilterLimit(limit uint) *Handler {
return h
}

func (h *Handler) WithCallMaxSteps(maxSteps uint64) *Handler {
h.callMaxSteps = maxSteps
return h
}

func (h *Handler) WithIDGen(idgen func() uint64) *Handler {
h.idgen = idgen
return h
Expand Down Expand Up @@ -1253,7 +1259,7 @@ func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Er
}

res, err := h.vm.Call(&call.ContractAddress, classHash, &call.EntryPointSelector,
call.Calldata, header.Number, header.Timestamp, state, h.bcReader.Network())
call.Calldata, header.Number, header.Timestamp, state, h.bcReader.Network(), h.callMaxSteps)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrInternal.CloneWithData(err.Error())
Expand Down
37 changes: 36 additions & 1 deletion rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2973,7 +2973,8 @@ func TestCall(t *testing.T) {
t.Cleanup(mockCtrl.Finish)

mockReader := mocks.NewMockReader(mockCtrl)
handler := rpc.New(mockReader, nil, nil, "", utils.NewNopZapLogger())
mockVM := mocks.NewMockVM(mockCtrl)
handler := rpc.New(mockReader, nil, mockVM, "", utils.NewNopZapLogger())

t.Run("empty blockchain", func(t *testing.T) {
mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound)
Expand Down Expand Up @@ -3010,6 +3011,40 @@ func TestCall(t *testing.T) {
require.Nil(t, res)
assert.Equal(t, rpc.ErrContractNotFound, rpcErr)
})

t.Run("ok", func(t *testing.T) {
handler = handler.WithCallMaxSteps(1337)

contractAddr := new(felt.Felt).SetUint64(1)
selector := new(felt.Felt).SetUint64(2)
classHash := new(felt.Felt).SetUint64(3)
calldata := []felt.Felt{
*new(felt.Felt).SetUint64(4),
*new(felt.Felt).SetUint64(5),
}
expectedRes := []*felt.Felt{
new(felt.Felt).SetUint64(6),
new(felt.Felt).SetUint64(7),
}

mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil)
mockReader.EXPECT().HeadsHeader().Return(&core.Header{
Number: 100,
Timestamp: 101,
}, nil)
mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil)
mockReader.EXPECT().Network().Return(&utils.Mainnet)
mockVM.EXPECT().Call(contractAddr, classHash, selector, calldata, uint64(100),
uint64(101), gomock.Any(), &utils.Mainnet, uint64(1337)).Return(expectedRes, nil)

res, rpcErr := handler.Call(rpc.FunctionCall{
ContractAddress: *contractAddr,
EntryPointSelector: *selector,
Calldata: calldata,
}, rpc.BlockID{Latest: true})
require.Nil(t, rpcErr)
require.Equal(t, expectedRes, res)
})
}

func TestEstimateMessageFee(t *testing.T) {
Expand Down
10 changes: 7 additions & 3 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
};

use blockifier::{
abi::constants::{INITIAL_GAS_COST, N_STEPS_RESOURCE},
abi::constants::{INITIAL_GAS_COST, N_STEPS_RESOURCE, MAX_STEPS_PER_TX, MAX_VALIDATE_STEPS_PER_TX},
block_context::{BlockContext, GasPrices, FeeTokenAddresses},
execution::{
common_hints::ExecutionMode,
Expand Down Expand Up @@ -68,6 +68,7 @@ pub extern "C" fn cairoVMCall(
block_number: c_ulonglong,
block_timestamp: c_ulonglong,
chain_id: *const c_char,
max_steps: c_ulonglong,
) {
let reader = JunoStateReader::new(reader_handle, block_number);
let contract_addr_felt = ptr_to_felt(contract_address);
Expand Down Expand Up @@ -113,6 +114,7 @@ pub extern "C" fn cairoVMCall(
block_timestamp,
StarkFelt::default(),
GAS_PRICES,
Some(max_steps),
),
&AccountTransactionContext::Deprecated(DeprecatedAccountTransactionContext::default()),
ExecutionMode::Execute,
Expand Down Expand Up @@ -204,6 +206,7 @@ pub extern "C" fn cairoVMExecute(
eth_l1_gas_price: felt_to_u128(gas_price_wei_felt),
strk_l1_gas_price: felt_to_u128(gas_price_strk_felt),
},
None
);
let mut state = CachedState::new(reader, GlobalContractCache::default());
let charge_fee = skip_charge_fee == 0;
Expand Down Expand Up @@ -396,6 +399,7 @@ fn build_block_context(
block_timestamp: c_ulonglong,
sequencer_address: StarkFelt,
gas_prices: GasPrices,
max_steps: Option<c_ulonglong>,
) -> BlockContext {
BlockContext {
chain_id: ChainId(chain_id_str.into()),
Expand Down Expand Up @@ -432,8 +436,8 @@ fn build_block_context(
(KECCAK_BUILTIN_NAME.to_string(), N_STEPS_FEE_WEIGHT * 2048.0),
])
.into(),
invoke_tx_max_n_steps: 3_000_000,
validate_max_n_steps: 3_000_000,
invoke_tx_max_n_steps: max_steps.unwrap_or(MAX_STEPS_PER_TX as u64).try_into().unwrap(),
validate_max_n_steps: MAX_VALIDATE_STEPS_PER_TX as u32,
max_recursion_depth: 50,
}
}
7 changes: 4 additions & 3 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package vm
//#include <stddef.h>
// extern void cairoVMCall(char* contract_address, char* class_hash, char* entry_point_selector, char** calldata,
// size_t len_calldata, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id);
// unsigned long long block_timestamp, char* chain_id, unsigned long long max_steps);
//
// extern void cairoVMExecute(char* txns_json, char* classes_json, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fees_on_l1_json,
Expand All @@ -31,7 +31,7 @@ import (
//go:generate mockgen -destination=../mocks/mock_vm.go -package=mocks github.com/NethermindEth/juno/vm VM
type VM interface {
Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber,
blockTimestamp uint64, state core.StateReader, network *utils.Network,
blockTimestamp uint64, state core.StateReader, network *utils.Network, maxSteps uint64,
) ([]*felt.Felt, error)
Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network *utils.Network, paidFeesOnL1 []*felt.Felt,
Expand Down Expand Up @@ -111,7 +111,7 @@ func makePtrFromFelt(val *felt.Felt) unsafe.Pointer {
}

func (v *vm) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber,
blockTimestamp uint64, state core.StateReader, network *utils.Network,
blockTimestamp uint64, state core.StateReader, network *utils.Network, maxSteps uint64,
) ([]*felt.Felt, error) {
context := &callContext{
state: state,
Expand Down Expand Up @@ -149,6 +149,7 @@ func (v *vm) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.
C.ulonglong(blockNumber),
C.ulonglong(blockTimestamp),
chainID,
C.ulonglong(maxSteps),
)

for _, ptr := range calldataPtrs {
Expand Down
Loading

0 comments on commit 5b76ca0

Please sign in to comment.