diff --git a/sszgen/generator/generator.go b/sszgen/generator/generator.go index 6dd06fb..9422e6f 100644 --- a/sszgen/generator/generator.go +++ b/sszgen/generator/generator.go @@ -1313,3 +1313,21 @@ func uintVToName(v *Value) string { panic(fmt.Sprintf("unknown uint size, %d bytes. field name=%s", v.s, v.name)) } } + +func uintVToLowerCaseName(v *Value) string { + if v.t != TypeUint { + panic(fmt.Sprintf("type %v for %s not expected", v.t, v.name)) + } + switch v.s { + case 8: + return "uint64" + case 4: + return "uint32" + case 2: + return "uint16" + case 1: + return "uint8" + default: + panic(fmt.Sprintf("unknown uint size, %d bytes. field name=%s", v.s, v.name)) + } +} diff --git a/sszgen/generator/hash.go b/sszgen/generator/hash.go index 9e4752c..318d5a9 100755 --- a/sszgen/generator/hash.go +++ b/sszgen/generator/hash.go @@ -161,8 +161,8 @@ func (v *Value) hashTreeRoot(name string, appendBytes bool) string { case TypeUint: if v.ref != "" || v.obj != "" { - // alias to Uint64 - name = fmt.Sprintf("uint64(%s)", name) + // alias to uint* + name = fmt.Sprintf("%s(%s)", uintVToLowerCaseName(v), name) } bitLen := v.fixedSize() * 8 return fmt.Sprintf("hh.PutUint%d(%s)", bitLen, name) diff --git a/sszgen/generator/marshal.go b/sszgen/generator/marshal.go index 57b9e8e..ea0881a 100644 --- a/sszgen/generator/marshal.go +++ b/sszgen/generator/marshal.go @@ -55,8 +55,8 @@ func (v *Value) marshal() string { case TypeUint: var name string if v.ref != "" || v.obj != "" { - // alias to Uint64 - name = fmt.Sprintf("uint64(::.%s)", v.name) + // alias to uint* + name = fmt.Sprintf("%s(::.%s)", uintVToLowerCaseName(v), v.name) } else { name = "::." + v.name } diff --git a/sszgen/testcases/uint.go b/sszgen/testcases/uint.go new file mode 100644 index 0000000..c7949af --- /dev/null +++ b/sszgen/testcases/uint.go @@ -0,0 +1,15 @@ +package testcases + +//go:generate go run ../main.go --path uint.go + +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 + +type Uints struct { + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 +} diff --git a/sszgen/testcases/uint_encoding.go b/sszgen/testcases/uint_encoding.go new file mode 100644 index 0000000..d50cd26 --- /dev/null +++ b/sszgen/testcases/uint_encoding.go @@ -0,0 +1,91 @@ +// Code generated by fastssz. DO NOT EDIT. +// Hash: a58c2659c9995524aa00aaa7e97600338e9a844b877948cbb37e3fb620c5e8a6 +// Version: 0.1.3 +package testcases + +import ( + ssz "github.com/ferranbt/fastssz" +) + +// MarshalSSZ ssz marshals the Uints object +func (u *Uints) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(u) +} + +// MarshalSSZTo ssz marshals the Uints object to a target array +func (u *Uints) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Uint8' + dst = ssz.MarshalUint8(dst, uint8(u.Uint8)) + + // Field (1) 'Uint16' + dst = ssz.MarshalUint16(dst, uint16(u.Uint16)) + + // Field (2) 'Uint32' + dst = ssz.MarshalUint32(dst, uint32(u.Uint32)) + + // Field (3) 'Uint64' + dst = ssz.MarshalUint64(dst, uint64(u.Uint64)) + + return +} + +// UnmarshalSSZ ssz unmarshals the Uints object +func (u *Uints) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 15 { + return ssz.ErrSize + } + + // Field (0) 'Uint8' + u.Uint8 = Uint8(ssz.UnmarshallUint8(buf[0:1])) + + // Field (1) 'Uint16' + u.Uint16 = Uint16(ssz.UnmarshallUint16(buf[1:3])) + + // Field (2) 'Uint32' + u.Uint32 = Uint32(ssz.UnmarshallUint32(buf[3:7])) + + // Field (3) 'Uint64' + u.Uint64 = Uint64(ssz.UnmarshallUint64(buf[7:15])) + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Uints object +func (u *Uints) SizeSSZ() (size int) { + size = 15 + return +} + +// HashTreeRoot ssz hashes the Uints object +func (u *Uints) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(u) +} + +// HashTreeRootWith ssz hashes the Uints object with a hasher +func (u *Uints) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Uint8' + hh.PutUint8(uint8(u.Uint8)) + + // Field (1) 'Uint16' + hh.PutUint16(uint16(u.Uint16)) + + // Field (2) 'Uint32' + hh.PutUint32(uint32(u.Uint32)) + + // Field (3) 'Uint64' + hh.PutUint64(uint64(u.Uint64)) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Uints object +func (u *Uints) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(u) +} diff --git a/sszgen/testcases/uint_test.go b/sszgen/testcases/uint_test.go new file mode 100644 index 0000000..dfd2733 --- /dev/null +++ b/sszgen/testcases/uint_test.go @@ -0,0 +1,35 @@ +package testcases + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUint(t *testing.T) { + s := Uints{ + Uint8: Uint8(123), + Uint16: Uint16(12345), + Uint32: Uint32(1234567890), + Uint64: Uint64(123456789000), + } + expectedHash := [32]byte{ + 0xea, 0xfc, 0xf7, 0xa2, 0x41, 0x8, 0x51, 0xa2, + 0xa0, 0xb0, 0x23, 0x68, 0xff, 0x4, 0x44, 0xbd, + 0x24, 0xc9, 0x9b, 0xff, 0xe7, 0x81, 0xca, 0x49, + 0xb6, 0xf7, 0xd4, 0x99, 0x28, 0xf3, 0xee, 0xeb, + } + + bytes, err := s.MarshalSSZ() + assert.NoError(t, err) + + var s2 Uints + err = s2.UnmarshalSSZ(bytes) + assert.NoError(t, err) + + assert.Equal(t, s, s2) + + h, err := s.HashTreeRoot() + assert.NoError(t, err) + + assert.Equal(t, h, expectedHash) +}