From 83ff6e8d5b70052ad1dacd759ce94932315a641c Mon Sep 17 00:00:00 2001 From: Leonidas Maroulis Date: Wed, 27 Nov 2024 10:08:50 +0200 Subject: [PATCH] Update PVM host functions to 0.5.0 (#136) --- .../polkavm/host_call/accumulate_functions.go | 58 +++++--------- .../host_call/accumulate_functions_test.go | 80 +++++++------------ internal/polkavm/host_call/common.go | 6 +- .../polkavm/host_call/general_functions.go | 37 ++++----- .../host_call/general_functions_test.go | 8 +- internal/statetransition/accumulate.go | 22 ++--- internal/statetransition/on_transfer.go | 5 +- pkg/serialization/codec/jam/decode.go | 10 +-- pkg/serialization/codec/jam/encode.go | 10 +-- .../codec/jam/general_natural.go | 8 +- .../codec/jam/general_natural_test.go | 9 ++- .../codec/jam/trivial_natural.go | 4 +- .../codec/jam/trivial_natural_test.go | 16 ++-- 13 files changed, 116 insertions(+), 157 deletions(-) diff --git a/internal/polkavm/host_call/accumulate_functions.go b/internal/polkavm/host_call/accumulate_functions.go index 7f04d6f..57a6093 100644 --- a/internal/polkavm/host_call/accumulate_functions.go +++ b/internal/polkavm/host_call/accumulate_functions.go @@ -13,12 +13,12 @@ import ( "github.com/eigerco/strawberry/internal/state" ) -// Empower ΩE(ϱ, ω, μ, (x, y)) -func Empower(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error) { - if gas < EmpowerCost { +// Bless ΩB(ϱ, ω, μ, (x, y)) +func Bless(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error) { + if gas < BlessCost { return gas, regs, mem, ctxPair, ErrOutOfGas } - gas -= EmpowerCost + gas -= BlessCost // let [m, a, v, o, n] = ω7...12 managerServiceId, assignServiceId, designateServiceId, addr, servicesNr := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4] @@ -108,9 +108,8 @@ func Checkpoint(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPa ctxPair.ExceptionalCtx = ctxPair.RegularCtx - // Split the new ϱ' value into its lower and upper parts. + // Set the new ϱ' value into ω′7 regs[A0] = uint32(gas & ((1 << 32) - 1)) - regs[A1] = uint32(gas >> 32) return gas, regs, mem, ctxPair, nil } @@ -122,19 +121,14 @@ func New(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Ga } gas -= NewCost - // let [o, l, gl, gh, ml, mh] = ω7..13 - addr, preimageLength, gl, gh, ml, mh := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4], regs[A5] + // let [o, l, g, m] = ω7..11 + addr, preimageLength, gasLimitAccumulator, gasLimitTransfer := regs[A0], regs[A1], regs[A2], regs[A3] // c = μo⋅⋅⋅+32 if No⋅⋅⋅+32 ⊂ Vμ otherwise ∇ codeHashBytes := make([]byte, 32) if err := mem.Read(addr, codeHashBytes); err != nil { return gas, withCode(regs, OOB), mem, ctxPair, nil } - // let g = 2^32 ⋅ gh + gl - gasLimitAccumulator := uint64(gh)<<32 | uint64(gl) - - // let m = 2^32 ⋅ mh + ml - gasLimitTransfer := uint64(mh)<<32 | uint64(ml) codeHash := crypto.Hash(codeHashBytes) @@ -145,8 +139,8 @@ func New(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Ga {Hash: codeHash, Length: service.PreimageLength(preimageLength)}: {}, }, CodeHash: codeHash, - GasLimitForAccumulator: gasLimitAccumulator, - GasLimitOnTransfer: gasLimitTransfer, + GasLimitForAccumulator: uint64(gasLimitAccumulator), + GasLimitOnTransfer: uint64(gasLimitTransfer), } account.Balance = account.ThresholdBalance() @@ -177,8 +171,8 @@ func Upgrade(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) return gas, regs, mem, ctxPair, ErrOutOfGas } gas -= UpgradeCost - // let [o, gh, gl, mh, ml] = ω7...12 - addr, gl, gh, ml, mh := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4] + // let [o, g, m] = ω7...10 + addr, gasLimitAccumulator, gasLimitTransfer := regs[A0], regs[A1], regs[A2] // c = μo⋅⋅⋅+32 if No⋅⋅⋅+32 ⊂ Vμ otherwise ∇ codeHash := make([]byte, 32) @@ -186,28 +180,19 @@ func Upgrade(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) return gas, withCode(regs, OOB), mem, ctxPair, nil } - // let g = 2^32 ⋅ gh + gl - gasLimitAccumulator := uint64(gh)<<32 | uint64(gl) - - // let m = 2^32 ⋅ mh + ml - gasLimitTransfer := uint64(mh)<<32 | uint64(ml) - // (ω′7, (X′s)c, (X′s)g , (X′s)m) = (OK, c, g, m) if c ≠ ∇ currentService := ctxPair.RegularCtx.ServiceAccount() currentService.CodeHash = crypto.Hash(codeHash) - currentService.GasLimitForAccumulator = gasLimitAccumulator - currentService.GasLimitOnTransfer = gasLimitTransfer + currentService.GasLimitForAccumulator = uint64(gasLimitAccumulator) + currentService.GasLimitOnTransfer = uint64(gasLimitTransfer) ctxPair.RegularCtx.ServiceState[ctxPair.RegularCtx.ServiceId] = currentService return gas, withCode(regs, OK), mem, ctxPair, nil } // Transfer ΩT(ϱ, ω, μ, (x, y)) func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error) { - // let (d, al, ah, gl, gh, o) = ω7..13 - receiverId, al, ah, gl, gh, o := regs[A0], regs[A1], regs[A2], regs[A3], regs[A4], regs[A5] - - // let a = 2^32 ⋅ ah + al - newBalance := uint64(ah)<<32 | uint64(al) + // let (d, a, g, o) = ω7..11 + receiverId, newBalance, gasLimit, o := regs[A0], regs[A1], regs[A2], regs[A3] transferCost := TransferBaseCost + Gas(newBalance) if gas < transferCost { @@ -215,9 +200,6 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair } gas -= transferCost - // let g = 2^32 ⋅ gh + gl - gasLimit := uint64(gh)<<32 | uint64(gl) - // m = μo⋅⋅⋅+M if No⋅⋅⋅+M ⊂ Vμ otherwise ∇ m := make([]byte, service.TransferMemoSizeBytes) if err := mem.Read(o, m); err != nil { @@ -228,9 +210,9 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair deferredTransfer := service.DeferredTransfer{ SenderServiceIndex: ctxPair.RegularCtx.ServiceId, ReceiverServiceIndex: block.ServiceId(receiverId), - Balance: newBalance, + Balance: uint64(newBalance), Memo: service.Memo(m), - GasLimit: gasLimit, + GasLimit: uint64(gasLimit), } // let d = xd ∪ (xu)d @@ -244,7 +226,7 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair } // if g < (δ ∪ xn)[d]m - if gasLimit < receiverService.GasLimitOnTransfer { + if uint64(gasLimit) < receiverService.GasLimitOnTransfer { return gas, withCode(regs, LOW), mem, ctxPair, nil } @@ -255,7 +237,7 @@ func Transfer(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair // let b = (xs)b − a // if b < (xs)t - if ctxPair.RegularCtx.ServiceAccount().Balance-newBalance < ctxPair.RegularCtx.ServiceAccount().ThresholdBalance() { + if ctxPair.RegularCtx.ServiceAccount().Balance-uint64(newBalance) < ctxPair.RegularCtx.ServiceAccount().ThresholdBalance() { return gas, withCode(regs, CASH), mem, ctxPair, nil } @@ -286,7 +268,7 @@ func Quit(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair) (G } // if d ∈ {s, 2^32 − 1} - if block.ServiceId(receiverId) == ctxPair.RegularCtx.ServiceId || receiverId == math.MaxUint32 { + if block.ServiceId(receiverId) == ctxPair.RegularCtx.ServiceId || uint64(receiverId) == math.MaxUint64 { delete(ctxPair.RegularCtx.AccumulationState.ServiceState, ctxPair.RegularCtx.ServiceId) return gas, withCode(regs, OK), mem, ctxPair, ErrHalt } diff --git a/internal/polkavm/host_call/accumulate_functions_test.go b/internal/polkavm/host_call/accumulate_functions_test.go index e0933bd..b91021a 100644 --- a/internal/polkavm/host_call/accumulate_functions_test.go +++ b/internal/polkavm/host_call/accumulate_functions_test.go @@ -1,11 +1,9 @@ package host_call import ( - "maps" "math" "slices" "testing" - "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -75,15 +73,15 @@ func TestAccumulate(t *testing.T) { }{ { name: "empower", - fn: fnStd(Empower), + fn: fnStd(Bless), alloc: alloc{ A3: slices.Concat( - encodeNumber(uint32(123)), - encodeNumber(uint64(12341234)), - encodeNumber(uint32(234)), - encodeNumber(uint64(23452345)), - encodeNumber(uint32(345)), - encodeNumber(uint64(34563456)), + encodeNumber(t, uint32(123)), + encodeNumber(t, uint64(12341234)), + encodeNumber(t, uint32(234)), + encodeNumber(t, uint64(23452345)), + encodeNumber(t, uint32(345)), + encodeNumber(t, uint64(34563456)), ), }, initialRegs: deltaRegs{ @@ -159,17 +157,14 @@ func TestAccumulate(t *testing.T) { expectedDeltaRegs: checkUint64(t, 89), expectedX: checkpointCtx, expectedY: checkpointCtx, - }, { + }, + { name: "new", fn: fnStd(New), alloc: alloc{ A0: hash2bytes(randomHash), }, - initialRegs: merge( - deltaRegs{A1: 123123}, - storeUint64(123124123, A2, A3), - storeUint64(756846353, A4, A5), - ), + initialRegs: deltaRegs{A1: 123123, A2: 123124123, A3: 756846353}, expectedDeltaRegs: deltaRegs{ A0: uint32(currentServiceID), }, @@ -210,16 +205,14 @@ func TestAccumulate(t *testing.T) { }, }, }, - }, { + }, + { name: "upgrade", fn: fnStd(Upgrade), alloc: alloc{ A0: hash2bytes(randomHash), }, - initialRegs: merge( - storeUint64(345345345345, A1, A2), - storeUint64(456456456456, A3, A4), - ), + initialRegs: deltaRegs{A1: 3453453453, A2: 456456456}, expectedDeltaRegs: deltaRegs{ A0: uint32(OK), }, @@ -233,33 +226,31 @@ func TestAccumulate(t *testing.T) { ServiceId: currentServiceID, ServiceState: service.ServiceState{currentServiceID: { CodeHash: randomHash, - GasLimitForAccumulator: 345345345345, - GasLimitOnTransfer: 456456456456, + GasLimitForAccumulator: 3453453453, + GasLimitOnTransfer: 456456456, }}, }, }, { name: "transfer", fn: fnStd(Transfer), alloc: alloc{ - A5: fixedSizeBytes(service.TransferMemoSizeBytes, []byte("memo message")), + A3: fixedSizeBytes(service.TransferMemoSizeBytes, []byte("memo message")), + }, + initialRegs: deltaRegs{ + A0: 1234, // d: receiver + A1: 1000000000, // a + A2: 80, // g }, - initialRegs: merge( - deltaRegs{ - A0: 1234, // d: receiver - }, - storeUint64(100000000000, A1, A2), // a - storeUint64(80, A3, A4), // g - ), expectedDeltaRegs: deltaRegs{ A0: uint32(OK), }, - initialGas: 100000000100, + initialGas: 1000000100, expectedGas: 88, X: AccumulateContext{ ServiceId: block.ServiceId(123123123), ServiceState: service.ServiceState{ block.ServiceId(123123123): { - Balance: 100000000100, + Balance: 1000000100, }, }, AccumulationState: state.AccumulationState{ @@ -281,13 +272,13 @@ func TestAccumulate(t *testing.T) { ServiceId: block.ServiceId(123123123), ServiceState: service.ServiceState{ block.ServiceId(123123123): { - Balance: 100000000100, + Balance: 1000000100, }, }, DeferredTransfers: []service.DeferredTransfer{{ SenderServiceIndex: block.ServiceId(123123123), ReceiverServiceIndex: 1234, - Balance: 100000000000, + Balance: 1000000000, Memo: service.Memo(fixedSizeBytes(service.TransferMemoSizeBytes, []byte("memo message"))), GasLimit: 80, }}, @@ -724,21 +715,6 @@ func checkUint64(t *testing.T, gas uint64) deltaRegs { } } -func storeUint64(i uint64, reg1, reg2 Reg) deltaRegs { - return deltaRegs{ - reg1: uint32(math.Mod(float64(i), 1<<32)), - reg2: uint32(math.Floor(float64(i) / (1 << 32))), - } -} - -func merge[M ~map[K]V, K comparable, V any](dd ...M) M { - result := make(M) - for _, d := range dd { - maps.Copy(result, d) - } - return result -} - func fnStd(fn func(Gas, Registers, Memory, AccumulateContextPair) (Gas, Registers, Memory, AccumulateContextPair, error)) hostCall { return func(gas Gas, regs Registers, mem Memory, ctxPair AccumulateContextPair, timeslot jamtime.Timeslot) (Gas, Registers, Memory, AccumulateContextPair, error) { return fn(gas, regs, mem, ctxPair) @@ -775,6 +751,8 @@ func transform[S, S2 any](slice1 []S, fn func(S) S2) (slice []S2) { return slice } -func encodeNumber[T ~uint8 | ~uint16 | ~uint32 | ~uint64](v T) []byte { - return jam.SerializeTrivialNatural(v, uint8(unsafe.Sizeof(v))) +func encodeNumber[T ~uint8 | ~uint16 | ~uint32 | ~uint64](t *testing.T, v T) []byte { + res, err := jam.Marshal(v) + require.NoError(t, err) + return res } diff --git a/internal/polkavm/host_call/common.go b/internal/polkavm/host_call/common.go index 8f09b3d..545552d 100644 --- a/internal/polkavm/host_call/common.go +++ b/internal/polkavm/host_call/common.go @@ -13,7 +13,7 @@ const ( ReadCost WriteCost InfoCost - EmpowerCost + BlessCost AssignCost DesignateCost CheckpointCost @@ -31,7 +31,7 @@ const ( ReadID = 2 WriteID = 3 InfoID = 4 - EmpowerID = 5 + BlessID = 5 AssignID = 6 DesignateID = 7 CheckpointID = 8 @@ -93,7 +93,7 @@ func readNumber[U interface{ ~uint32 | ~uint64 }](mem Memory, addr uint32, lengt return } - jam.DeserializeTrivialNatural(b, &u) + err = jam.Unmarshal(b, &u) return } diff --git a/internal/polkavm/host_call/general_functions.go b/internal/polkavm/host_call/general_functions.go index aa42a8b..8b711d0 100644 --- a/internal/polkavm/host_call/general_functions.go +++ b/internal/polkavm/host_call/general_functions.go @@ -1,15 +1,13 @@ package host_call import ( - "github.com/eigerco/strawberry/pkg/serialization/codec/jam" "math" - "golang.org/x/crypto/blake2b" - "github.com/eigerco/strawberry/internal/block" "github.com/eigerco/strawberry/internal/crypto" "github.com/eigerco/strawberry/internal/polkavm" "github.com/eigerco/strawberry/internal/service" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" ) type AccountInfo struct { @@ -29,9 +27,8 @@ func GasRemaining(gas polkavm.Gas, regs polkavm.Registers) (polkavm.Gas, polkavm } gas -= GasRemainingCost - // Split the new ϱ' value into its lower and upper parts. - regs[polkavm.A0] = uint32(gas & ((1 << 32) - 1)) - regs[polkavm.A1] = uint32(gas >> 32) + // Set the new ϱ' value into ω′7 + regs[polkavm.A0] = uint32(gas) return gas, regs, nil } @@ -43,11 +40,11 @@ func Lookup(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s servi } gas -= LookupCost - sID := regs[polkavm.A0] + omega7 := regs[polkavm.A0] // Determine the lookup key 'a' a := s - if sID != math.MaxUint32 && sID != uint32(serviceId) { + if uint64(omega7) != math.MaxUint64 && omega7 != uint32(serviceId) { var exists bool // Lookup service account by serviceId in the serviceState a, exists = serviceState[serviceId] @@ -68,7 +65,7 @@ func Lookup(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s servi } // Compute the hash H(µho..ho+32) - hash := blake2b.Sum256(memorySlice) + hash := crypto.HashData(memorySlice) // Lookup value in storage (v) using the hash v, exists := a.Storage[hash] @@ -102,16 +99,16 @@ func Read(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s service } gas -= ReadCost - sID := regs[polkavm.A0] + omega7 := regs[polkavm.A0] ko := regs[polkavm.A1] kz := regs[polkavm.A2] bo := regs[polkavm.A3] bz := regs[polkavm.A4] a := s - if sID != math.MaxUint32 && sID != uint32(serviceId) { + if uint64(omega7) != math.MaxUint64 && omega7 != uint32(serviceId) { var exists bool - a, exists = serviceState[block.ServiceId(sID)] + a, exists = serviceState[block.ServiceId(omega7)] if !exists { return gas, regs, mem, polkavm.ErrAccountNotFound } @@ -125,7 +122,7 @@ func Read(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s service return gas, regs, mem, nil } - serviceIdBytes, err := jam.Marshal(sID) + serviceIdBytes, err := jam.Marshal(omega7) if err != nil { return gas, regs, mem, polkavm.ErrPanicf(err.Error()) } @@ -136,7 +133,7 @@ func Read(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s service hashInput = append(hashInput, keyData...) // Compute the hash H(E4(s) + keyData) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) v, exists := a.Storage[k] if !exists { @@ -183,7 +180,7 @@ func Write(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, s servic return gas, regs, mem, s, err } hashInput := append(serviceIdBytes, keyData...) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) a := s if vz == 0 { @@ -222,12 +219,12 @@ func Info(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, serviceId } gas -= InfoCost - sID := regs[polkavm.A0] - omega1 := regs[polkavm.A1] + omega7 := regs[polkavm.A0] + omega8 := regs[polkavm.A1] t, exists := serviceState[serviceId] - if sID != math.MaxUint32 { - t, exists = serviceState[block.ServiceId(sID)] + if uint64(omega7) != math.MaxUint64 { + t, exists = serviceState[block.ServiceId(omega7)] } if !exists { return gas, withCode(regs, NONE), mem, nil @@ -249,7 +246,7 @@ func Info(gas polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, serviceId return gas, regs, mem, polkavm.ErrPanicf(err.Error()) } - if err := mem.Write(omega1, m); err != nil { + if err := mem.Write(omega8, m); err != nil { regs[polkavm.A0] = uint32(OOB) return gas, regs, mem, nil } diff --git a/internal/polkavm/host_call/general_functions_test.go b/internal/polkavm/host_call/general_functions_test.go index d0044f0..bd0c3ca 100644 --- a/internal/polkavm/host_call/general_functions_test.go +++ b/internal/polkavm/host_call/general_functions_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/blake2b" "github.com/eigerco/strawberry/internal/block" "github.com/eigerco/strawberry/internal/crypto" @@ -31,7 +30,6 @@ func TestGasRemaining(t *testing.T) { initialRegs := polkavm.Registers{ polkavm.RA: polkavm.VmAddressReturnToHost, - polkavm.SP: memoryMap.StackAddressHigh, } initialGas := uint64(100) hostCall := func(hostCall uint32, gasCounter polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, x struct{}) (polkavm.Gas, polkavm.Registers, polkavm.Memory, struct{}, error) { @@ -92,7 +90,7 @@ func TestLookup(t *testing.T) { bo := memoryMap.RWDataAddress + 100 dataToHash := make([]byte, 32) copy(dataToHash, "hash") - hash := blake2b.Sum256(dataToHash) + hash := crypto.HashData(dataToHash) err := mem.Write(ho, dataToHash) require.NoError(t, err) @@ -153,7 +151,7 @@ func TestRead(t *testing.T) { hashInput := make([]byte, 0, len(serviceIdBytes)+len(keyData)) hashInput = append(hashInput, serviceIdBytes...) hashInput = append(hashInput, keyData...) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) sa := service.ServiceAccount{ Storage: map[crypto.Hash][]byte{ @@ -221,7 +219,7 @@ func TestWrite(t *testing.T) { require.NoError(t, err) hashInput := append(serviceIdBytes, keyData...) - k := blake2b.Sum256(hashInput) + k := crypto.HashData(hashInput) sa := service.ServiceAccount{ Balance: 200, diff --git a/internal/statetransition/accumulate.go b/internal/statetransition/accumulate.go index 5df6613..97b4464 100644 --- a/internal/statetransition/accumulate.go +++ b/internal/statetransition/accumulate.go @@ -32,9 +32,9 @@ type Accumulator struct { state *state.State } -// InvokePVM ΨA(U, N_S , N_G, ⟦O⟧) → (U, ⟦T⟧, H?, N_G) Equation 280 +// InvokePVM ΨA(U, N_S , N_G, ⟦O⟧) → (U, ⟦T⟧, H?, N_G) Equation (B.8) func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex block.ServiceId, gas uint64, accOperand []state.AccumulationOperand) (state.AccumulationState, []service.DeferredTransfer, *crypto.Hash, uint64) { - // if d[s]c = ∅ + // if ud[s]c = ∅ if accState.ServiceState[serviceIndex].Code() == nil { ctx, err := a.newCtx(accState, serviceIndex) if err != nil { @@ -62,7 +62,7 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b return ctx.AccumulationState, []service.DeferredTransfer{}, nil, 0 } - // F (equation 283) + // F (equation B.10) hostCallFunc := func(hostCall uint32, gasCounter polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, ctx polkavm.AccumulateContextPair) (polkavm.Gas, polkavm.Registers, polkavm.Memory, polkavm.AccumulateContextPair, error) { // s currentService := accState.ServiceState[serviceIndex] @@ -70,7 +70,6 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b switch hostCall { case host_call.GasID: gasCounter, regs, err = host_call.GasRemaining(gasCounter, regs) - ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService case host_call.LookupID: gasCounter, regs, mem, err = host_call.Lookup(gasCounter, regs, mem, currentService, serviceIndex, ctx.RegularCtx.AccumulationState.ServiceState) ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService @@ -82,9 +81,8 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService case host_call.InfoID: gasCounter, regs, mem, err = host_call.Info(gasCounter, regs, mem, serviceIndex, ctx.RegularCtx.AccumulationState.ServiceState) - ctx.RegularCtx.AccumulationState.ServiceState[ctx.RegularCtx.ServiceId] = currentService - case host_call.EmpowerID: - gasCounter, regs, mem, ctx, err = host_call.Empower(gasCounter, regs, mem, ctx) + case host_call.BlessID: + gasCounter, regs, mem, ctx, err = host_call.Bless(gasCounter, regs, mem, ctx) case host_call.AssignID: gasCounter, regs, mem, ctx, err = host_call.Assign(gasCounter, regs, mem, ctx) case host_call.DesignateID: @@ -110,7 +108,7 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b return gasCounter, regs, mem, ctx, err } - remainingGas, ret, newCtxPair, err := interpreter.InvokeWholeProgram(accState.ServiceState[serviceIndex].Code(), 10, gas, args, hostCallFunc, newCtxPair) + remainingGas, ret, newCtxPair, err := interpreter.InvokeWholeProgram(accState.ServiceState[serviceIndex].Code(), 5, gas, args, hostCallFunc, newCtxPair) if err != nil { errPanic := &polkavm.ErrPanic{} if errors.Is(err, polkavm.ErrOutOfGas) || errors.As(err, &errPanic) { @@ -128,7 +126,7 @@ func (a *Accumulator) InvokePVM(accState state.AccumulationState, serviceIndex b return newCtxPair.RegularCtx.AccumulationState, newCtxPair.RegularCtx.DeferredTransfers, nil, uint64(remainingGas) } -// newCtx (281) +// newCtx (B.9) func (a *Accumulator) newCtx(u state.AccumulationState, serviceIndex block.ServiceId) (polkavm.AccumulateContext, error) { serviceState := maps.Clone(u.ServiceState) delete(serviceState, serviceIndex) @@ -176,6 +174,10 @@ func (a *Accumulator) newServiceID(serviceIndex block.ServiceId) (block.ServiceI hashData := crypto.HashData(hashBytes) newId := block.ServiceId(0) - jam.DeserializeTrivialNatural(hashData[:], &newId) + err = jam.Unmarshal(hashData[:], &newId) + if err != nil { + return 0, err + } + return newId, nil } diff --git a/internal/statetransition/on_transfer.go b/internal/statetransition/on_transfer.go index 29be45d..dcf867f 100644 --- a/internal/statetransition/on_transfer.go +++ b/internal/statetransition/on_transfer.go @@ -1,13 +1,14 @@ package statetransition import ( + "log" + "github.com/eigerco/strawberry/internal/block" "github.com/eigerco/strawberry/internal/polkavm" "github.com/eigerco/strawberry/internal/polkavm/host_call" "github.com/eigerco/strawberry/internal/polkavm/interpreter" "github.com/eigerco/strawberry/internal/service" "github.com/eigerco/strawberry/pkg/serialization/codec/jam" - "log" ) // InvokePVMOnTransfer On-Transfer service-account invocation (ΨT). @@ -48,7 +49,7 @@ func InvokePVMOnTransfer(serviceState service.ServiceState, serviceIndex block.S return gasCounter, regs, mem, serviceAccount, err } - _, _, newServiceAccount, err := interpreter.InvokeWholeProgram(serviceCode, 15, gas, args, hostCallFunc, serviceAccount) + _, _, newServiceAccount, err := interpreter.InvokeWholeProgram(serviceCode, 10, gas, args, hostCallFunc, serviceAccount) if err != nil { // TODO handle errors appropriately log.Println("the virtual machine exited with an error", err) diff --git a/pkg/serialization/codec/jam/decode.go b/pkg/serialization/codec/jam/decode.go index ef7bc7d..9a762ad 100644 --- a/pkg/serialization/codec/jam/decode.go +++ b/pkg/serialization/codec/jam/decode.go @@ -341,7 +341,7 @@ func (br *byteReader) decodeUint(value reflect.Value) error { } var v uint64 - err = DeserializeUint64WithLength(serialized, l, &v) + err = deserializeUint64WithLength(serialized, l, &v) if err != nil { return fmt.Errorf(ErrDecodingUint, err) } @@ -419,19 +419,19 @@ func (br *byteReader) decodeFixedWidthInt(dstv reflect.Value) error { switch in.(type) { case uint8: var temp uint8 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) case uint16: var temp uint16 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) case uint32: var temp uint32 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) case uint64: var temp uint64 - DeserializeTrivialNatural(buf, &temp) + deserializeTrivialNatural(buf, &temp) dstv.Set(reflect.ValueOf(temp)) } diff --git a/pkg/serialization/codec/jam/encode.go b/pkg/serialization/codec/jam/encode.go index 7001d06..17a08d4 100644 --- a/pkg/serialization/codec/jam/encode.go +++ b/pkg/serialization/codec/jam/encode.go @@ -292,13 +292,13 @@ func (bw *byteWriter) encodeFixedWidthUint(i interface{}) error { switch v := i.(type) { case uint8: - data = SerializeTrivialNatural(v, 1) + data = serializeTrivialNatural(v, 1) case uint16: - data = SerializeTrivialNatural(v, 2) + data = serializeTrivialNatural(v, 2) case uint32: - data = SerializeTrivialNatural(v, 4) + data = serializeTrivialNatural(v, 4) case uint64: - data = SerializeTrivialNatural(v, 8) + data = serializeTrivialNatural(v, 8) default: return fmt.Errorf(ErrUnsupportedType, i) } @@ -341,7 +341,7 @@ func (bw *byteWriter) encodeLength(l int) error { } func (bw *byteWriter) encodeUint(i uint) error { - encodedBytes := SerializeUint64(uint64(i)) + encodedBytes := serializeUint64(uint64(i)) _, err := bw.Write(encodedBytes) diff --git a/pkg/serialization/codec/jam/general_natural.go b/pkg/serialization/codec/jam/general_natural.go index cff50ac..72d28dc 100644 --- a/pkg/serialization/codec/jam/general_natural.go +++ b/pkg/serialization/codec/jam/general_natural.go @@ -5,8 +5,8 @@ import ( "math" ) -// SerializeUint64 implements the general formula (able to encode naturals of up to 2^64) -func SerializeUint64(x uint64) []byte { +// serializeUint64 implements the general formula (able to encode naturals of up to 2^64) +func serializeUint64(x uint64) []byte { var l uint8 // Determine the length needed to represent the value for l = 0; l < 8; l++ { @@ -30,8 +30,8 @@ func SerializeUint64(x uint64) []byte { return bytes } -// DeserializeUint64WithLength deserializes a byte slice into a uint64 value, with length `l`. -func DeserializeUint64WithLength(serialized []byte, l uint8, u *uint64) error { +// deserializeUint64WithLength deserializes a byte slice into a uint64 value, with length `l`. +func deserializeUint64WithLength(serialized []byte, l uint8, u *uint64) error { *u = 0 n := len(serialized) diff --git a/pkg/serialization/codec/jam/general_natural_test.go b/pkg/serialization/codec/jam/general_natural_test.go index 958824b..218660d 100644 --- a/pkg/serialization/codec/jam/general_natural_test.go +++ b/pkg/serialization/codec/jam/general_natural_test.go @@ -2,11 +2,12 @@ package jam import ( "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "math" "math/bits" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecodeUint64(t *testing.T) { @@ -53,7 +54,7 @@ func TestEncodeDecodeUint64(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("uint64(%d)", tc.input), func(t *testing.T) { // Marshal the x value - serialized := SerializeUint64(tc.input) + serialized := serializeUint64(tc.input) // Check if the serialized output matches the expected output assert.Equal(t, tc.expected, serialized, "serialized output mismatch for x %d", tc.input) @@ -64,7 +65,7 @@ func TestEncodeDecodeUint64(t *testing.T) { } // Unmarshal the serialized data back into a uint64 var deserialized uint64 - err := DeserializeUint64WithLength(serialized, l, &deserialized) + err := deserializeUint64WithLength(serialized, l, &deserialized) require.NoError(t, err, "unmarshal(%v) returned an unexpected error", serialized) // Check if the deserialized value matches the original x diff --git a/pkg/serialization/codec/jam/trivial_natural.go b/pkg/serialization/codec/jam/trivial_natural.go index aa98250..1b3b6fe 100644 --- a/pkg/serialization/codec/jam/trivial_natural.go +++ b/pkg/serialization/codec/jam/trivial_natural.go @@ -4,7 +4,7 @@ import ( "math" ) -func SerializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte { +func serializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte { bytes := make([]byte, l) for i := uint8(0); i < l; i++ { bytes[i] = byte((x >> (8 * i)) & T(math.MaxUint8)) @@ -12,7 +12,7 @@ func SerializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint return bytes } -func DeserializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](serialized []byte, u *T) { +func deserializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](serialized []byte, u *T) { *u = 0 // Iterate over each byte in the serialized array diff --git a/pkg/serialization/codec/jam/trivial_natural_test.go b/pkg/serialization/codec/jam/trivial_natural_test.go index 10a77e9..9f8e78d 100644 --- a/pkg/serialization/codec/jam/trivial_natural_test.go +++ b/pkg/serialization/codec/jam/trivial_natural_test.go @@ -49,13 +49,13 @@ func TestSerializationTrivialNatural(t *testing.T) { var serialized []byte switch v := tc.x.(type) { case uint8: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) case uint16: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) case uint32: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) case uint64: - serialized = SerializeTrivialNatural(v, tc.l) + serialized = serializeTrivialNatural(v, tc.l) } assert.Equal(t, tc.expected, serialized, "serialized output mismatch") @@ -63,19 +63,19 @@ func TestSerializationTrivialNatural(t *testing.T) { switch v := tc.x.(type) { case uint8: var deserialized uint8 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") case uint16: var deserialized uint16 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") case uint32: var deserialized uint32 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") case uint64: var deserialized uint64 - DeserializeTrivialNatural(serialized, &deserialized) + deserializeTrivialNatural(serialized, &deserialized) assert.Equal(t, v, deserialized, "deserialized value mismatch") } })