Skip to content

Commit

Permalink
Remove best provider option for OpenCL provider ID
Browse files Browse the repository at this point in the history
  • Loading branch information
fasmat committed Aug 4, 2023
1 parent 366da6d commit 7f8438f
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 46 deletions.
36 changes: 20 additions & 16 deletions activation/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math"
"runtime"
"sync"

Expand Down Expand Up @@ -82,7 +83,7 @@ type PostSetupOpts struct {
DataDir string `mapstructure:"smeshing-opts-datadir"`
NumUnits uint32 `mapstructure:"smeshing-opts-numunits"`
MaxFileSize uint64 `mapstructure:"smeshing-opts-maxfilesize"`
ProviderID int `mapstructure:"smeshing-opts-provider"`
ProviderID uint64 `mapstructure:"smeshing-opts-provider"`
Throttle bool `mapstructure:"smeshing-opts-throttle"`
Scrypt config.ScryptParams `mapstructure:"smeshing-opts-scrypt"`
ComputeBatchSize uint64 `mapstructure:"smeshing-opts-compute-batch-size"`
Expand Down Expand Up @@ -169,23 +170,32 @@ func DefaultPostSetupOpts() PostSetupOpts {
DataDir: opts.DataDir,
NumUnits: opts.NumUnits,
MaxFileSize: opts.MaxFileSize,
ProviderID: opts.ProviderID,
ProviderID: math.MaxUint64,
Throttle: opts.Throttle,
Scrypt: opts.Scrypt,
ComputeBatchSize: opts.ComputeBatchSize,
}
}

func (o PostSetupOpts) ToInitOpts() config.InitOpts {
return config.InitOpts{
func (o PostSetupOpts) ToInitOpts() (config.InitOpts, error) {
if o.ProviderID == math.MaxUint64 {
return config.InitOpts{}, fmt.Errorf("no provider ID specified: %d", o.ProviderID)
}
if o.ProviderID > math.MaxUint32 {
return config.InitOpts{}, fmt.Errorf("invalid provider id: %d", o.ProviderID)
}

providerID := uint32(o.ProviderID)
initOpts := config.InitOpts{
DataDir: o.DataDir,
NumUnits: o.NumUnits,
MaxFileSize: o.MaxFileSize,
ProviderID: o.ProviderID,
ProviderID: &providerID,
Throttle: o.Throttle,
Scrypt: o.Scrypt,
ComputeBatchSize: o.ComputeBatchSize,
}
return initOpts, nil
}

// PostSetupManager implements the PostProvider interface.
Expand Down Expand Up @@ -368,27 +378,21 @@ func (mgr *PostSetupManager) PrepareInitializer(ctx context.Context, opts PostSe
return fmt.Errorf("post setup session in progress")
}

if opts.ProviderID == config.BestProviderID {
p, err := mgr.BestProvider()
if err != nil {
return err
}

mgr.logger.Info("found best compute provider: id: %d, model: %v, device type: %v", p.ID, p.Model, p.DeviceType)
opts.ProviderID = int(p.ID)
}

var err error
mgr.commitmentAtxId, err = mgr.commitmentAtx(ctx, opts.DataDir)
if err != nil {
return err
}

initOpts, err := opts.ToInitOpts()
if err != nil {
return err
}
newInit, err := initialization.NewInitializer(
initialization.WithNodeId(mgr.id.Bytes()),
initialization.WithCommitmentAtxId(mgr.commitmentAtxId.Bytes()),
initialization.WithConfig(mgr.cfg.ToConfig()),
initialization.WithInitOpts(opts.ToInitOpts()),
initialization.WithInitOpts(initOpts),
initialization.WithLogger(mgr.logger.Zap()),
)
if err != nil {
Expand Down
14 changes: 1 addition & 13 deletions activation/post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,6 @@ func TestPostSetupManager_PrepareInitializer(t *testing.T) {
req.Error(mgr.PrepareInitializer(ctx, opts))
}

func TestPostSetupManager_PrepareInitializer_BestProvider(t *testing.T) {
req := require.New(t)

mgr := newTestPostManager(t)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

mgr.opts.ProviderID = config.BestProviderID
req.NoError(mgr.PrepareInitializer(ctx, mgr.opts))
}

// Checks that the sequence of calls for initialization (first
// PrepareInitializer and then StartSession) is enforced.
func TestPostSetupManager_InitializationCallSequence(t *testing.T) {
Expand Down Expand Up @@ -442,7 +430,7 @@ func newTestPostManager(tb testing.TB, o ...newPostSetupMgrOptionFunc) *testPost

opts := DefaultPostSetupOpts()
opts.DataDir = tb.TempDir()
opts.ProviderID = int(initialization.CPUProviderID())
opts.ProviderID = uint64(initialization.CPUProviderID())
opts.Scrypt.N = 2 // Speedup initialization in tests.

goldenATXID := types.ATXID{2, 3, 4}
Expand Down
18 changes: 10 additions & 8 deletions activation/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,20 @@ func Test_Validation_VRFNonce(t *testing.T) {
LabelsPerUnit: postCfg.LabelsPerUnit,
}

initOpts := DefaultPostSetupOpts()
initOpts.DataDir = t.TempDir()
initOpts.ProviderID = int(initialization.CPUProviderID())
opts := DefaultPostSetupOpts()
opts.DataDir = t.TempDir()
opts.ProviderID = uint64(initialization.CPUProviderID())

nodeId := types.BytesToNodeID(make([]byte, 32))
commitmentAtxId := types.EmptyATXID

initOpts, err := opts.ToInitOpts()
r.NoError(err)
init, err := initialization.NewInitializer(
initialization.WithNodeId(nodeId.Bytes()),
initialization.WithCommitmentAtxId(commitmentAtxId.Bytes()),
initialization.WithConfig(postCfg.ToConfig()),
initialization.WithInitOpts(initOpts.ToInitOpts()),
initialization.WithInitOpts(initOpts),
)
r.NoError(err)
r.NoError(init.Initialize(context.Background()))
Expand All @@ -54,27 +56,27 @@ func Test_Validation_VRFNonce(t *testing.T) {
t.Run("valid vrf nonce", func(t *testing.T) {
t.Parallel()

require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, initOpts.NumUnits))
require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, opts.NumUnits))
})

t.Run("invalid vrf nonce", func(t *testing.T) {
t.Parallel()

nonce := types.VRFPostIndex(1)
require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, &nonce, meta, initOpts.NumUnits))
require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, &nonce, meta, opts.NumUnits))
})

t.Run("wrong commitmentAtxId", func(t *testing.T) {
t.Parallel()

commitmentAtxId := types.ATXID{1, 2, 3}
require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, initOpts.NumUnits))
require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, opts.NumUnits))
})

t.Run("numUnits can be smaller", func(t *testing.T) {
t.Parallel()

require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, initOpts.NumUnits-1))
require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, opts.NumUnits-1))
})
}

Expand Down
2 changes: 1 addition & 1 deletion api/grpcserver/smesher_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s SmesherService) StartSmeshing(ctx context.Context, in *pb.StartSmeshingR
opts.DataDir = in.Opts.DataDir
opts.NumUnits = in.Opts.NumUnits
opts.MaxFileSize = in.Opts.MaxFileSize
opts.ProviderID = int(in.Opts.ProviderId)
opts.ProviderID = uint64(in.Opts.ProviderId)
opts.Throttle = in.Opts.Throttle

coinbaseAddr, err := types.StringToAddress(in.Coinbase.Address)
Expand Down
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func AddCommands(cmd *cobra.Command) {
cfg.SMESHING.Opts.NumUnits, "")
cmd.PersistentFlags().Uint64Var(&cfg.SMESHING.Opts.MaxFileSize, "smeshing-opts-maxfilesize",
cfg.SMESHING.Opts.MaxFileSize, "")
cmd.PersistentFlags().IntVar(&cfg.SMESHING.Opts.ProviderID, "smeshing-opts-provider",
cmd.PersistentFlags().Uint64Var(&cfg.SMESHING.Opts.ProviderID, "smeshing-opts-provider",
cfg.SMESHING.Opts.ProviderID, "")
cmd.PersistentFlags().BoolVar(&cfg.SMESHING.Opts.Throttle, "smeshing-opts-throttle",
cfg.SMESHING.Opts.Throttle, "")
Expand Down
2 changes: 1 addition & 1 deletion config/presets/fastnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func fastnet() config.Config {

conf.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte("1")).String()
conf.SMESHING.Start = false
conf.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID())
conf.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID())
conf.SMESHING.Opts.NumUnits = 2
conf.SMESHING.Opts.Throttle = true
// Override proof of work flags to use light mode (less memory intensive)
Expand Down
2 changes: 1 addition & 1 deletion config/presets/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func standalone() config.Config {

conf.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte("1")).String()
conf.SMESHING.Start = true
conf.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID())
conf.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID())
conf.SMESHING.Opts.NumUnits = 1
conf.SMESHING.Opts.Throttle = true
conf.SMESHING.Opts.DataDir = conf.DataDirParent
Expand Down
2 changes: 1 addition & 1 deletion config/presets/testnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func testnet() config.Config {

conf.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte("1")).String()
conf.SMESHING.Start = false
conf.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID())
conf.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID())
conf.SMESHING.Opts.NumUnits = 2
conf.SMESHING.Opts.Throttle = true

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ require (
github.com/spacemeshos/go-scale v1.1.10
github.com/spacemeshos/merkle-tree v0.2.2
github.com/spacemeshos/poet v0.8.7
github.com/spacemeshos/post v0.8.11
github.com/spacemeshos/post v0.8.12-0.20230804123744-7e32021db1c1
github.com/spf13/afero v1.9.5
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,8 @@ github.com/spacemeshos/merkle-tree v0.2.2 h1:+zF+17CwVebq9UzShunUBXv16rEVKIJHh2C
github.com/spacemeshos/merkle-tree v0.2.2/go.mod h1:0Q/z4S5Kt9qz/c3qWa7hKA1yT7n7odyysbPIUTRu/xo=
github.com/spacemeshos/poet v0.8.7 h1:Q85SDIlV0Asn7AJ9TaIbE1upimcXSUixYVeSY2bFqXI=
github.com/spacemeshos/poet v0.8.7/go.mod h1:eRT87JENlVfItdJvKtnFCnSFcceeYTRrsXfF5u5sE8k=
github.com/spacemeshos/post v0.8.11 h1:uZo8n5wF0fGOxRzXHabU/Buu9aLyPT4YWi6ZFms9mWg=
github.com/spacemeshos/post v0.8.11/go.mod h1:LDj6XQht1ZvTZurBJ+LZNf17t92qIQymWgU2sY6V2Zs=
github.com/spacemeshos/post v0.8.12-0.20230804123744-7e32021db1c1 h1:+T+Pg/fHVBS7rPS4TBSUDRUo7DhLHreSFhjRlzcGpr0=
github.com/spacemeshos/post v0.8.12-0.20230804123744-7e32021db1c1/go.mod h1:LDj6XQht1ZvTZurBJ+LZNf17t92qIQymWgU2sY6V2Zs=
github.com/spacemeshos/sha256-simd v0.1.0 h1:G7Mfu5RYdQiuE+wu4ZyJ7I0TI74uqLhFnKblEnSpjYI=
github.com/spacemeshos/sha256-simd v0.1.0/go.mod h1:O8CClVIilId7RtuCMV2+YzMj6qjVn75JsxOxaE8vcfM=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
Expand Down
2 changes: 1 addition & 1 deletion node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ func getTestDefaultConfig(tb testing.TB) *config.Config {
cfg.SMESHING = config.DefaultSmeshingConfig()
cfg.SMESHING.Start = true
cfg.SMESHING.Opts.NumUnits = cfg.POST.MinNumUnits + 1
cfg.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID())
cfg.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID())

// note: these need to be set sufficiently low enough that turbohare finishes well before the LayerDurationSec
cfg.HARE.RoundDuration = 2
Expand Down

0 comments on commit 7f8438f

Please sign in to comment.