diff --git a/activation/post.go b/activation/post.go index fa3cd97a445..d273e2f08bc 100644 --- a/activation/post.go +++ b/activation/post.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "math" "runtime" "sync" @@ -82,7 +83,7 @@ type PostSetupOpts struct { DataDir string `mapstructure:"smeshing-opts-datadir"` NumUnits uint32 `mapstructure:"smeshing-opts-numunits"` MaxFileSize uint64 `mapstructure:"smeshing-opts-maxfilesize"` - ProviderID int `mapstructure:"smeshing-opts-provider"` + ProviderID uint64 `mapstructure:"smeshing-opts-provider"` Throttle bool `mapstructure:"smeshing-opts-throttle"` Scrypt config.ScryptParams `mapstructure:"smeshing-opts-scrypt"` ComputeBatchSize uint64 `mapstructure:"smeshing-opts-compute-batch-size"` @@ -169,23 +170,32 @@ func DefaultPostSetupOpts() PostSetupOpts { DataDir: opts.DataDir, NumUnits: opts.NumUnits, MaxFileSize: opts.MaxFileSize, - ProviderID: opts.ProviderID, + ProviderID: math.MaxUint64, Throttle: opts.Throttle, Scrypt: opts.Scrypt, ComputeBatchSize: opts.ComputeBatchSize, } } -func (o PostSetupOpts) ToInitOpts() config.InitOpts { - return config.InitOpts{ +func (o PostSetupOpts) ToInitOpts() (config.InitOpts, error) { + if o.ProviderID == math.MaxUint64 { + return config.InitOpts{}, fmt.Errorf("no provider ID specified: %d", o.ProviderID) + } + if o.ProviderID > math.MaxUint32 { + return config.InitOpts{}, fmt.Errorf("invalid provider id: %d", o.ProviderID) + } + + providerID := uint32(o.ProviderID) + initOpts := config.InitOpts{ DataDir: o.DataDir, NumUnits: o.NumUnits, MaxFileSize: o.MaxFileSize, - ProviderID: o.ProviderID, + ProviderID: &providerID, Throttle: o.Throttle, Scrypt: o.Scrypt, ComputeBatchSize: o.ComputeBatchSize, } + return initOpts, nil } // PostSetupManager implements the PostProvider interface. @@ -368,27 +378,21 @@ func (mgr *PostSetupManager) PrepareInitializer(ctx context.Context, opts PostSe return fmt.Errorf("post setup session in progress") } - if opts.ProviderID == config.BestProviderID { - p, err := mgr.BestProvider() - if err != nil { - return err - } - - mgr.logger.Info("found best compute provider: id: %d, model: %v, device type: %v", p.ID, p.Model, p.DeviceType) - opts.ProviderID = int(p.ID) - } - var err error mgr.commitmentAtxId, err = mgr.commitmentAtx(ctx, opts.DataDir) if err != nil { return err } + initOpts, err := opts.ToInitOpts() + if err != nil { + return err + } newInit, err := initialization.NewInitializer( initialization.WithNodeId(mgr.id.Bytes()), initialization.WithCommitmentAtxId(mgr.commitmentAtxId.Bytes()), initialization.WithConfig(mgr.cfg.ToConfig()), - initialization.WithInitOpts(opts.ToInitOpts()), + initialization.WithInitOpts(initOpts), initialization.WithLogger(mgr.logger.Zap()), ) if err != nil { diff --git a/activation/post_test.go b/activation/post_test.go index c9402bea1ca..47daea5fd7f 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -108,18 +108,6 @@ func TestPostSetupManager_PrepareInitializer(t *testing.T) { req.Error(mgr.PrepareInitializer(ctx, opts)) } -func TestPostSetupManager_PrepareInitializer_BestProvider(t *testing.T) { - req := require.New(t) - - mgr := newTestPostManager(t) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - mgr.opts.ProviderID = config.BestProviderID - req.NoError(mgr.PrepareInitializer(ctx, mgr.opts)) -} - // Checks that the sequence of calls for initialization (first // PrepareInitializer and then StartSession) is enforced. func TestPostSetupManager_InitializationCallSequence(t *testing.T) { @@ -442,7 +430,7 @@ func newTestPostManager(tb testing.TB, o ...newPostSetupMgrOptionFunc) *testPost opts := DefaultPostSetupOpts() opts.DataDir = tb.TempDir() - opts.ProviderID = int(initialization.CPUProviderID()) + opts.ProviderID = uint64(initialization.CPUProviderID()) opts.Scrypt.N = 2 // Speedup initialization in tests. goldenATXID := types.ATXID{2, 3, 4} diff --git a/activation/validation_test.go b/activation/validation_test.go index 96e9e3e0eea..88d944ee8ad 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -29,18 +29,20 @@ func Test_Validation_VRFNonce(t *testing.T) { LabelsPerUnit: postCfg.LabelsPerUnit, } - initOpts := DefaultPostSetupOpts() - initOpts.DataDir = t.TempDir() - initOpts.ProviderID = int(initialization.CPUProviderID()) + opts := DefaultPostSetupOpts() + opts.DataDir = t.TempDir() + opts.ProviderID = uint64(initialization.CPUProviderID()) nodeId := types.BytesToNodeID(make([]byte, 32)) commitmentAtxId := types.EmptyATXID + initOpts, err := opts.ToInitOpts() + r.NoError(err) init, err := initialization.NewInitializer( initialization.WithNodeId(nodeId.Bytes()), initialization.WithCommitmentAtxId(commitmentAtxId.Bytes()), initialization.WithConfig(postCfg.ToConfig()), - initialization.WithInitOpts(initOpts.ToInitOpts()), + initialization.WithInitOpts(initOpts), ) r.NoError(err) r.NoError(init.Initialize(context.Background())) @@ -54,27 +56,27 @@ func Test_Validation_VRFNonce(t *testing.T) { t.Run("valid vrf nonce", func(t *testing.T) { t.Parallel() - require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, initOpts.NumUnits)) + require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, opts.NumUnits)) }) t.Run("invalid vrf nonce", func(t *testing.T) { t.Parallel() nonce := types.VRFPostIndex(1) - require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, &nonce, meta, initOpts.NumUnits)) + require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, &nonce, meta, opts.NumUnits)) }) t.Run("wrong commitmentAtxId", func(t *testing.T) { t.Parallel() commitmentAtxId := types.ATXID{1, 2, 3} - require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, initOpts.NumUnits)) + require.Error(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, opts.NumUnits)) }) t.Run("numUnits can be smaller", func(t *testing.T) { t.Parallel() - require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, initOpts.NumUnits-1)) + require.NoError(t, v.VRFNonce(nodeId, commitmentAtxId, nonce, meta, opts.NumUnits-1)) }) } diff --git a/api/grpcserver/smesher_service.go b/api/grpcserver/smesher_service.go index db0eeb84ccc..ff3a1fe72b0 100644 --- a/api/grpcserver/smesher_service.go +++ b/api/grpcserver/smesher_service.go @@ -82,7 +82,7 @@ func (s SmesherService) StartSmeshing(ctx context.Context, in *pb.StartSmeshingR opts.DataDir = in.Opts.DataDir opts.NumUnits = in.Opts.NumUnits opts.MaxFileSize = in.Opts.MaxFileSize - opts.ProviderID = int(in.Opts.ProviderId) + opts.ProviderID = uint64(in.Opts.ProviderId) opts.Throttle = in.Opts.Throttle coinbaseAddr, err := types.StringToAddress(in.Coinbase.Address) diff --git a/cmd/root.go b/cmd/root.go index 73807aa34f7..50329ac9e30 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -239,7 +239,7 @@ func AddCommands(cmd *cobra.Command) { cfg.SMESHING.Opts.NumUnits, "") cmd.PersistentFlags().Uint64Var(&cfg.SMESHING.Opts.MaxFileSize, "smeshing-opts-maxfilesize", cfg.SMESHING.Opts.MaxFileSize, "") - cmd.PersistentFlags().IntVar(&cfg.SMESHING.Opts.ProviderID, "smeshing-opts-provider", + cmd.PersistentFlags().Uint64Var(&cfg.SMESHING.Opts.ProviderID, "smeshing-opts-provider", cfg.SMESHING.Opts.ProviderID, "") cmd.PersistentFlags().BoolVar(&cfg.SMESHING.Opts.Throttle, "smeshing-opts-throttle", cfg.SMESHING.Opts.Throttle, "") diff --git a/config/presets/fastnet.go b/config/presets/fastnet.go index 4f4dc7f6356..523570452d5 100644 --- a/config/presets/fastnet.go +++ b/config/presets/fastnet.go @@ -56,7 +56,7 @@ func fastnet() config.Config { conf.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte("1")).String() conf.SMESHING.Start = false - conf.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID()) + conf.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID()) conf.SMESHING.Opts.NumUnits = 2 conf.SMESHING.Opts.Throttle = true // Override proof of work flags to use light mode (less memory intensive) diff --git a/config/presets/standalone.go b/config/presets/standalone.go index 018f48d73da..a409391d727 100644 --- a/config/presets/standalone.go +++ b/config/presets/standalone.go @@ -58,7 +58,7 @@ func standalone() config.Config { conf.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte("1")).String() conf.SMESHING.Start = true - conf.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID()) + conf.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID()) conf.SMESHING.Opts.NumUnits = 1 conf.SMESHING.Opts.Throttle = true conf.SMESHING.Opts.DataDir = conf.DataDirParent diff --git a/config/presets/testnet.go b/config/presets/testnet.go index 7f3dfb3411c..19e96ac1377 100644 --- a/config/presets/testnet.go +++ b/config/presets/testnet.go @@ -56,7 +56,7 @@ func testnet() config.Config { conf.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte("1")).String() conf.SMESHING.Start = false - conf.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID()) + conf.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID()) conf.SMESHING.Opts.NumUnits = 2 conf.SMESHING.Opts.Throttle = true diff --git a/go.mod b/go.mod index 8613e46ee40..ecef59ffe6f 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/spacemeshos/go-scale v1.1.10 github.com/spacemeshos/merkle-tree v0.2.2 github.com/spacemeshos/poet v0.8.7 - github.com/spacemeshos/post v0.8.11 + github.com/spacemeshos/post v0.8.12-0.20230804123744-7e32021db1c1 github.com/spf13/afero v1.9.5 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 diff --git a/go.sum b/go.sum index 6036a224e21..510eb907781 100644 --- a/go.sum +++ b/go.sum @@ -647,8 +647,8 @@ github.com/spacemeshos/merkle-tree v0.2.2 h1:+zF+17CwVebq9UzShunUBXv16rEVKIJHh2C github.com/spacemeshos/merkle-tree v0.2.2/go.mod h1:0Q/z4S5Kt9qz/c3qWa7hKA1yT7n7odyysbPIUTRu/xo= github.com/spacemeshos/poet v0.8.7 h1:Q85SDIlV0Asn7AJ9TaIbE1upimcXSUixYVeSY2bFqXI= github.com/spacemeshos/poet v0.8.7/go.mod h1:eRT87JENlVfItdJvKtnFCnSFcceeYTRrsXfF5u5sE8k= -github.com/spacemeshos/post v0.8.11 h1:uZo8n5wF0fGOxRzXHabU/Buu9aLyPT4YWi6ZFms9mWg= -github.com/spacemeshos/post v0.8.11/go.mod h1:LDj6XQht1ZvTZurBJ+LZNf17t92qIQymWgU2sY6V2Zs= +github.com/spacemeshos/post v0.8.12-0.20230804123744-7e32021db1c1 h1:+T+Pg/fHVBS7rPS4TBSUDRUo7DhLHreSFhjRlzcGpr0= +github.com/spacemeshos/post v0.8.12-0.20230804123744-7e32021db1c1/go.mod h1:LDj6XQht1ZvTZurBJ+LZNf17t92qIQymWgU2sY6V2Zs= github.com/spacemeshos/sha256-simd v0.1.0 h1:G7Mfu5RYdQiuE+wu4ZyJ7I0TI74uqLhFnKblEnSpjYI= github.com/spacemeshos/sha256-simd v0.1.0/go.mod h1:O8CClVIilId7RtuCMV2+YzMj6qjVn75JsxOxaE8vcfM= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= diff --git a/node/node_test.go b/node/node_test.go index 6e5b44b36de..b0c46d061a2 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -999,7 +999,7 @@ func getTestDefaultConfig(tb testing.TB) *config.Config { cfg.SMESHING = config.DefaultSmeshingConfig() cfg.SMESHING.Start = true cfg.SMESHING.Opts.NumUnits = cfg.POST.MinNumUnits + 1 - cfg.SMESHING.Opts.ProviderID = int(initialization.CPUProviderID()) + cfg.SMESHING.Opts.ProviderID = uint64(initialization.CPUProviderID()) // note: these need to be set sufficiently low enough that turbohare finishes well before the LayerDurationSec cfg.HARE.RoundDuration = 2