Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor initializer to be simpler to use and race free #84

Merged
merged 12 commits into from
Nov 9, 2022
231 changes: 112 additions & 119 deletions initialization/initialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"path/filepath"
"sync"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"

Expand All @@ -29,11 +28,19 @@ type (
ComputeProvider = gpu.ComputeProvider
)

// Status describes how far initialization has progressed, as reported by
// (*Initializer).Status.
type Status int

const (
// StatusNotStarted is reported when no labels have been written yet.
StatusNotStarted Status = iota
// StatusStarted is reported when a non-zero number of labels has been
// written but that number does not equal the target.
StatusStarted
// StatusInitializing is reported when the Initializer's mutex cannot be
// acquired, i.e. another operation such as Initialize currently holds it.
StatusInitializing
// StatusCompleted is reported when the number of labels written equals the
// target (NumUnits * LabelsPerUnit).
StatusCompleted
// StatusError is reported when the number of labels written cannot be read
// from the disk state.
StatusError
)

// Sentinel errors returned by the Initializer.
var (
// ErrNotInitializing is returned by Stop when no initialization is running.
ErrNotInitializing = errors.New("not initializing")
// ErrAlreadyInitializing is returned by Initialize when an initialization
// is already in progress.
ErrAlreadyInitializing = errors.New("already initializing")
// ErrCannotResetWhileInitializing is returned by Reset when an
// initialization is in progress.
ErrCannotResetWhileInitializing = errors.New("cannot reset while initializing")
// ErrStopped is returned by the compute worker when it observes a stop
// request.
ErrStopped = errors.New("gpu-post: stopped")
// ErrStateMetadataFileMissing is not referenced in the visible code.
// NOTE(review): presumably returned when the metadata file cannot be found
// while loading state — confirm against the metadata loader.
ErrStateMetadataFileMissing = errors.New("metadata file is missing")
)

Expand All @@ -45,67 +52,105 @@ func CPUProviderID() int {
return gpu.CPUProviderID()
}

// initializeOption collects the settings gathered from the functional options
// passed to NewInitializer.
type initializeOption struct {
// commitment is set by WithCommitment, which only accepts 32-byte values.
commitment []byte
// cfg is set by WithConfig; verify rejects a nil value.
cfg *Config
// initOpts is set by WithInitOpts; verify rejects a nil value.
initOpts *config.InitOpts
// logger is set by WithLogger, which rejects a nil logger.
logger Logger
}

// verify checks that every mandatory option was supplied and that the options
// are mutually consistent.
//
// It returns an error when the config or the init options are missing, when
// the commitment is not exactly 32 bytes long (which includes the case where
// WithCommitment was never applied), or when config.Validate rejects the
// config / init options pair.
func (opts *initializeOption) verify() error {
	if opts.cfg == nil {
		return errors.New("no config provided")
	}

	if opts.initOpts == nil {
		return errors.New("no init options provided")
	}

	// WithCommitment validates the length only when it is actually applied.
	// Checking here as well keeps an Initializer from being built with a nil
	// commitment when the option is omitted.
	if len(opts.commitment) != 32 {
		return fmt.Errorf("invalid `id` length; expected: 32, given: %v", len(opts.commitment))
	}

	return config.Validate(*opts.cfg, *opts.initOpts)
}

// initializeOptionFunc applies one setting to an initializeOption and returns
// an error when the supplied value is not acceptable.
type initializeOptionFunc func(*initializeOption) error

// WithCommitment returns an option that sets the commitment of the
// Initializer. The option fails when the commitment is not exactly 32 bytes
// long.
//
// The bytes are copied when WithCommitment is called, so later writes to the
// caller's slice cannot change the commitment held by the Initializer.
func WithCommitment(commitment []byte) initializeOptionFunc {
	// Detach from the caller's backing array; slices alias their storage.
	id := append([]byte(nil), commitment...)

	return func(opts *initializeOption) error {
		if len(id) != 32 {
			return fmt.Errorf("invalid `id` length; expected: 32, given: %v", len(id))
		}
		opts.commitment = id
		return nil
	}
}

// WithInitOpts returns an option that stores a pointer to the given init
// options (received by value) on the option set. The option never fails.
func WithInitOpts(initOpts config.InitOpts) initializeOptionFunc {
	apply := func(o *initializeOption) error {
		o.initOpts = &initOpts
		return nil
	}
	return apply
}

// WithConfig returns an option that stores a pointer to the given config
// (received by value) on the option set. The option never fails.
func WithConfig(cfg Config) initializeOptionFunc {
	apply := func(o *initializeOption) error {
		o.cfg = &cfg
		return nil
	}
	return apply
}

// WithLogger returns an option that installs the given logger on the option
// set. The option fails when the logger is nil.
func WithLogger(logger Logger) initializeOptionFunc {
	apply := func(o *initializeOption) error {
		if logger != nil {
			o.logger = logger
			return nil
		}
		return errors.New("logger is nil")
	}
	return apply
}

type Initializer struct {
numLabelsWritten atomic.Uint64
numLabelsWrittenChan chan uint64
numLabelsWritten atomic.Uint64

cfg Config
opts InitOpts
commitment []byte

diskState *DiskState
initializing bool
mtx sync.RWMutex

stopChan chan struct{}
doneChan chan struct{}
diskState *DiskState
mtx sync.RWMutex

logger Logger
}

func NewInitializer(cfg Config, opts config.InitOpts, commitment []byte) (*Initializer, error) {
if len(commitment) != 32 {
return nil, fmt.Errorf("invalid `id` length; expected: 32, given: %v", len(commitment))
func NewInitializer(opts ...initializeOptionFunc) (*Initializer, error) {
options := &initializeOption{
logger: shared.DisabledLogger{},
}

for _, opt := range opts {
if err := opt(options); err != nil {
return nil, err
}
}

if err := config.Validate(cfg, opts); err != nil {
if err := options.verify(); err != nil {
return nil, err
}

return &Initializer{
cfg: cfg,
opts: opts,
commitment: commitment,
diskState: NewDiskState(opts.DataDir, uint(cfg.BitsPerLabel)),
logger: shared.DisabledLogger{},
cfg: *options.cfg,
opts: *options.initOpts,
commitment: options.commitment,
diskState: NewDiskState(options.initOpts.DataDir, uint(options.cfg.BitsPerLabel)),
logger: options.logger,
}, nil
}

// Initialize is the process in which the prover commits to store some data, by having its storage filled with
// pseudo-random data with respect to a specific id. This data is the result of a computationally-expensive operation.
func (init *Initializer) Initialize() error {
init.mtx.Lock()

if init.initializing {
init.mtx.Unlock()
func (init *Initializer) Initialize(ctx context.Context) error {
if !init.mtx.TryLock() {
return ErrAlreadyInitializing
}

init.stopChan = make(chan struct{})
init.doneChan = make(chan struct{})
init.numLabelsWrittenChan = make(chan uint64)

init.initializing = true
init.mtx.Unlock()

defer func() {
init.mtx.Lock()
defer init.mtx.Unlock()
init.initializing = false

close(init.doneChan)
close(init.numLabelsWrittenChan)
}()
defer init.mtx.Unlock()

if numLabelsWritten, err := init.diskState.NumLabelsWritten(); err != nil {
return err
Expand All @@ -130,57 +175,23 @@ func (init *Initializer) Initialize() error {
init.opts.NumFiles, init.opts.NumUnits, init.cfg.LabelsPerUnit, init.cfg.BitsPerLabel, init.opts.DataDir)

for i := 0; i < int(init.opts.NumFiles); i++ {
if err := init.initFile(uint(init.opts.ComputeProviderID), i, numLabels, fileNumLabels); err != nil {
if err := init.initFile(ctx, uint(init.opts.ComputeProviderID), i, numLabels, fileNumLabels); err != nil {
return err
}
}

return nil
}

// isInitializing reports the current value of the initializing flag, read
// under the read lock.
func (init *Initializer) isInitializing() bool {
	init.mtx.RLock()
	running := init.initializing
	init.mtx.RUnlock()

	return running
}

// Stop asks a running initialization to abort and waits for it to finish.
//
// It returns ErrNotInitializing when no initialization is in progress, an
// error when the GPU layer reports a stop failure, and a "stop timeout" error
// when the initialization has not signalled completion within five seconds.
// Calling Stop more than once during the same session is safe.
func (init *Initializer) Stop() error {
	// Snapshot the session channels under the lock: Initialize replaces them
	// under the same lock, so reading the fields bare would be a data race.
	init.mtx.Lock()
	if !init.initializing {
		init.mtx.Unlock()
		return ErrNotInitializing
	}
	stopChan, doneChan := init.stopChan, init.doneChan
	select {
	case <-stopChan:
		// Already closed by an earlier Stop; closing again would panic.
	default:
		close(stopChan)
	}
	// Release before waiting: Initialize needs the lock to finish its cleanup
	// and close doneChan.
	init.mtx.Unlock()

	if res := gpu.Stop(); res != gpu.StopResultOk {
		return fmt.Errorf("gpu stop error: %s", res)
	}

	select {
	case <-doneChan:
	case <-time.After(5 * time.Second):
		return errors.New("stop timeout")
	}

	return nil
}

// SessionNumLabelsWrittenChan returns the receive side of the channel stored
// in numLabelsWrittenChan. The field is read under the read lock.
func (init *Initializer) SessionNumLabelsWrittenChan() <-chan uint64 {
	init.mtx.RLock()
	updates := init.numLabelsWrittenChan
	init.mtx.RUnlock()

	return updates
}

// SessionNumLabelsWritten returns the current value of the atomic
// labels-written counter.
func (init *Initializer) SessionNumLabelsWritten() uint64 {
	written := init.numLabelsWritten.Load()
	return written
}

func (init *Initializer) Reset() error {
if init.isInitializing() {
if !init.mtx.TryLock() {
return ErrCannotResetWhileInitializing
}

if err := init.VerifyStarted(); err != nil {
return err
}
defer init.mtx.Unlock()

files, err := os.ReadDir(init.opts.DataDir)
if err != nil {
Expand All @@ -203,38 +214,30 @@ func (init *Initializer) Reset() error {
return nil
}

func (init *Initializer) Started() (bool, error) {
numLabelsWritten, err := init.diskState.NumLabelsWritten()
if err != nil {
return false, err
func (init *Initializer) Status() Status {
if !init.mtx.TryLock() {
return StatusInitializing
}
defer init.mtx.Unlock()

return numLabelsWritten > 0, nil
}

func (init *Initializer) Completed() (bool, error) {
numLabelsWritten, err := init.diskState.NumLabelsWritten()
if err != nil {
return false, err
return StatusError
}

target := uint64(init.opts.NumUnits) * uint64(init.cfg.LabelsPerUnit)
return numLabelsWritten == target, nil
}

func (init *Initializer) VerifyStarted() error {
ok, err := init.Started()
if err != nil {
return err
if numLabelsWritten == target {
return StatusCompleted
}
if !ok {
return shared.ErrInitNotStarted

if numLabelsWritten > 0 {
return StatusStarted
}

return nil
return StatusNotStarted
fasmat marked this conversation as resolved.
Show resolved Hide resolved
}

func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabels uint64, fileNumLabels uint64) error {
func (init *Initializer) initFile(ctx context.Context, computeProviderID uint, fileIndex int, numLabels uint64, fileNumLabels uint64) error {
fileOffset := uint64(fileIndex) * fileNumLabels
fileTargetPosition := fileOffset + fileNumLabels
batchSize := uint64(config.DefaultComputeBatchSize)
Expand All @@ -254,7 +257,7 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
if numLabelsWritten > 0 {
if numLabelsWritten == fileNumLabels {
init.logger.Info("initialization: file #%v already initialized; number of labels: %v, start position: %v", fileIndex, numLabelsWritten, fileOffset)
init.updateSessionNumLabelsWritten(fileTargetPosition)
init.numLabelsWritten.Store(fileTargetPosition)
return nil
}

Expand All @@ -263,7 +266,7 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
if err := writer.Truncate(fileNumLabels); err != nil {
return err
}
init.updateSessionNumLabelsWritten(fileTargetPosition)
init.numLabelsWritten.Store(fileTargetPosition)
return nil
}

Expand All @@ -275,17 +278,22 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
currentPosition := numLabelsWritten
outputChan := make(chan []byte, 1024)

errGroup, ctx := errgroup.WithContext(context.Background())
errGroup, ctx := errgroup.WithContext(ctx)

// Start compute worker.
errGroup.Go(func() error {
defer close(outputChan)

for currentPosition < fileNumLabels {
select {
case <-init.stopChan:
case <-ctx.Done():
init.logger.Info("initialization: stopped")
return ErrStopped

if res := gpu.Stop(); res != gpu.StopResultOk {
fasmat marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("gpu stop error: %s", res)
}

return ctx.Err()
default:
}

Expand All @@ -307,7 +315,7 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
outputChan <- output
currentPosition += uint64(batchSize)

init.updateSessionNumLabelsWritten(fileOffset + currentPosition)
init.numLabelsWritten.Store(fileOffset + currentPosition)
}
return nil
})
Expand Down Expand Up @@ -344,17 +352,6 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
return nil
}

// updateSessionNumLabelsWritten stores the new labels-written count in the
// atomic counter and then offers it on the progress channel without blocking.
func (init *Initializer) updateSessionNumLabelsWritten(numLabelsWritten uint64) {
	init.numLabelsWritten.Store(numLabelsWritten)

	updates := init.numLabelsWrittenChan
	select {
	case updates <- numLabelsWritten:
		return
	default:
		// Nobody is receiving right now, so the update is dropped; a blocking
		// send would stall the Initializer until someone reads the channel.
	}
}

func (init *Initializer) verifyMetadata(m *Metadata) error {
if !bytes.Equal(init.commitment, m.Commitment) {
return ConfigMismatchError{
Expand Down Expand Up @@ -419,7 +416,3 @@ func (init *Initializer) saveMetadata() error {
// loadMetadata calls LoadMetadata with the data directory from the init
// options and returns its result unchanged.
func (init *Initializer) loadMetadata() (*Metadata, error) {
	dataDir := init.opts.DataDir
	return LoadMetadata(dataDir)
}

// SetLogger replaces the logger used by the Initializer.
// NOTE(review): the assignment is not guarded by init.mtx — confirm it is only
// called while no initialization is running.
func (init *Initializer) SetLogger(logger Logger) {
init.logger = logger
}
Loading