diff --git a/cmd/gossamer/flags.go b/cmd/gossamer/flags.go index e723c4b712..ee401050ed 100644 --- a/cmd/gossamer/flags.go +++ b/cmd/gossamer/flags.go @@ -43,10 +43,10 @@ var ( Name: "roles", Usage: "Roles of the gossamer node", } - // RewindFlag rewinds the head of the chain by the given number of blocks. Useful for development + // RewindFlag rewinds the head of the chain to the given block number. Useful for development RewindFlag = cli.IntFlag{ Name: "rewind", - Usage: "Rewind head of chain by given number of blocks", + Usage: "Rewind head of chain to the given block number", } ) diff --git a/dot/services.go b/dot/services.go index 4d649182c1..3fa1335a8a 100644 --- a/dot/services.go +++ b/dot/services.go @@ -64,7 +64,7 @@ func createStateService(cfg *Config) (*state.Service, error) { } if cfg.State.Rewind != 0 { - err = stateSrvc.Rewind(cfg.State.Rewind) + err = stateSrvc.Rewind(int64(cfg.State.Rewind)) if err != nil { return nil, fmt.Errorf("failed to rewind state: %w", err) } diff --git a/dot/state/service.go b/dot/state/service.go index 77c60a19dd..fb544ab12e 100644 --- a/dot/state/service.go +++ b/dot/state/service.go @@ -19,6 +19,7 @@ package state import ( "bytes" "fmt" + "math/big" "os" "path/filepath" @@ -321,13 +322,22 @@ func (s *Service) Start() error { return nil } -// Rewind rewinds the chain by the given number of blocks. +// Rewind rewinds the chain to the given block number. // If the given number of blocks is greater than the chain height, it will rewind to genesis. -func (s *Service) Rewind(numBlocks int) error { +func (s *Service) Rewind(toBlock int64) error { num, _ := s.Block.BestBlockNumber() + if toBlock > num.Int64() { + return fmt.Errorf("cannot rewind, given height is higher than our current height") + } + + logger.Info("rewinding state...", "current height", num, "desired height", toBlock) + + root, err := s.Block.GetBlockByNumber(big.NewInt(toBlock)) + if err != nil { + return err + } - logger.Info("rewinding state...", "current height", num, "to rewind", numBlocks) - s.Block.bt.Rewind(numBlocks) + s.Block.bt = blocktree.NewBlockTreeFromRoot(root.Header, s.db) newHead := s.Block.BestBlockHash() header, _ := s.Block.BestBlockHeader()