Skip to content

Commit

Permalink
Add shuffling function
Browse files Browse the repository at this point in the history
  • Loading branch information
pantrif committed Nov 14, 2024
1 parent 1382b2a commit d67f0ab
Show file tree
Hide file tree
Showing 4 changed files with 904 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fmt:
.PHONY: lint
## lint: Runs golangci-lint run
lint:
golangci-lint run
golangci-lint run --timeout=5m

.PHONY: build-bandersnatch
## build-bandersnatch: Builds the bandersnatch library
Expand Down
72 changes: 72 additions & 0 deletions internal/common/shuffling.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package common

import (
"github.com/eigerco/strawberry/internal/crypto"
"github.com/eigerco/strawberry/pkg/serialization"
"github.com/eigerco/strawberry/pkg/serialization/codec"
"golang.org/x/crypto/blake2b"
)

var serializer = serialization.NewSerializer(codec.NewJamCodec())

// DeterministicShuffle performs a deterministic shuffle of the sequence s based on the hash h (appendix F)
func DeterministicShuffle(length uint32, h crypto.Hash) ([]uint32, error) {
s := make([]uint32, length)
for i := uint32(0); i < length; i++ {
s[i] = i
}

r, err := generateRandomNumbers(h, length)
if err != nil {
return nil, err
}
return recursiveShuffle(s, r), nil
}

// recursiveShuffle recursively shuffles the sequence s using the random numbers r
func recursiveShuffle(s []uint32, r []uint32) []uint32 {
l := len(s)
if l == 0 {
return []uint32{}
}

index := r[0] % uint32(l)
head := s[index]

sPost := make([]uint32, l)
copy(sPost, s)

sPost[index] = sPost[l-1]
sPost = sPost[:l-1]

return append([]uint32{head}, recursiveShuffle(sPost, r[1:])...)
}

// generateRandomNumbers (Q_l(h)) generates a sequence of l uint32 numbers from the hash h
func generateRandomNumbers(h crypto.Hash, l uint32) ([]uint32, error) {
r := make([]uint32, l)
for i := uint32(0); i < l; i++ {
k := i / 8
kBytes, err := serializer.Encode(k)
if err != nil {
return nil, err
}

input := append(h[:], kBytes...)
hash := blake2b.Sum256(input)

p := (4 * i) % 32
var b [4]byte
for j := uint32(0); j < 4; j++ {
b[j] = hash[(p+j)%32]
}

var rI uint32
err = serializer.Decode(b[:], &rI)
if err != nil {
return nil, err
}
r[i] = rI
}
return r, nil
}
47 changes: 47 additions & 0 deletions tests/integration/shuffling_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package integration

import (
"encoding/hex"
"fmt"
"os"
"testing"

"github.com/eigerco/strawberry/internal/common"
"github.com/eigerco/strawberry/internal/crypto"
"github.com/eigerco/strawberry/pkg/serialization"
"github.com/eigerco/strawberry/pkg/serialization/codec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type TestCase struct {
Input uint32
Entropy string
ExpectedOutput []uint32 `json:"output"`
}

func TestShuffleVectors(t *testing.T) {
filePath := "vectors/shuffling/shuffle_tests.json"
fileData, err := os.ReadFile(filePath)
require.NoError(t, err)

var testCases []TestCase
s := serialization.NewSerializer(&codec.JSONCodec{})
err = s.Decode(fileData, &testCases)
require.NoError(t, err)

for idx, testCase := range testCases {
t.Run(
fmt.Sprintf("Test case %d: Input=%d", idx+1, testCase.Input),
func(t *testing.T) {
entropyBytes, err := hex.DecodeString(testCase.Entropy)
require.NoError(t, err)

shuffledSequence, err := common.DeterministicShuffle(testCase.Input, crypto.Hash(entropyBytes))
require.NoError(t, err)

assert.Equal(t, testCase.ExpectedOutput, shuffledSequence)
},
)
}
}
Loading

0 comments on commit d67f0ab

Please sign in to comment.