diff --git a/wallet/rescan.go b/wallet/rescan.go index 36e97bb9c..ea1ea28fc 100644 --- a/wallet/rescan.go +++ b/wallet/rescan.go @@ -497,9 +497,13 @@ func (w *Wallet) rescanPoint(dbtx walletdb.ReadTx) (*chainhash.Hash, error) { return &rescanPoint, nil } -// SetBirthState sets the birthday state in the database. +// SetBirthState sets the birthday state in the database. This should be called +// before initial sync has happened. func (w *Wallet) SetBirthState(ctx context.Context, bs *udb.BirthdayState) error { const op errors.Op = "wallet.SetBirthState" + if bs == nil { + return errors.E(op, errors.Invalid, "nil birthday state") + } err := walletdb.Update(ctx, w.db, func(dbtx walletdb.ReadWriteTx) error { return udb.SetBirthState(dbtx, bs) }) @@ -509,7 +513,8 @@ func (w *Wallet) SetBirthState(ctx context.Context, bs *udb.BirthdayState) error return nil } -// BirthState returns the birthday state. +// BirthState returns the birthday state. Will return a nil state if none has +// been set. func (w *Wallet) BirthState(ctx context.Context) (bs *udb.BirthdayState, err error) { const op errors.Op = "wallet.BirthState" err = walletdb.View(ctx, w.db, func(dbtx walletdb.ReadTx) error { @@ -521,3 +526,42 @@ func (w *Wallet) BirthState(ctx context.Context) (bs *udb.BirthdayState, err err } return bs, nil } + +// SetBirthStateAndScan sets the wallet birthstate. This version should be used +// to change the birth state after initial sync has happened. It will find an +// adequate block from time or height. +func (w *Wallet) SetBirthStateAndScan(ctx context.Context, bs *udb.BirthdayState) error { + if bs == nil { + return errors.New("nil birthday state") + } + if !(bs.SetFromTime || bs.SetFromHeight) { + return errors.New("nothing to set") + } + tipHash, _ := w.MainChainTip(ctx) + syncedHeader, err := w.BlockHeader(ctx, &tipHash) + if err != nil { + return err + } + h := syncedHeader + for { + if (bs.SetFromTime && (h.Height == 0 || h.Timestamp.After(bs.Time))) || + bs.SetFromHeight && h.Height <= bs.Height { + bh := h.BlockHash() + height := h.Height + bs.Hash = bh + bs.Height = height + bs.SetFromTime = false + bs.SetFromHeight = false + if err := w.SetBirthState(ctx, bs); err != nil { + return err + } + log.Infof("Set wallet birthday to block %d (%v).", height, bh) + break + } + h, err = w.BlockHeader(ctx, &h.PrevBlock) + if err != nil { + return err + } + } + return nil +} diff --git a/wallet/udb/txmined.go b/wallet/udb/txmined.go index efaf23196..06a42d755 100644 --- a/wallet/udb/txmined.go +++ b/wallet/udb/txmined.go @@ -331,7 +331,8 @@ type BirthdayState struct { SetFromHeight, SetFromTime bool } -// SetBirthState sets the birthday state in the database. +// SetBirthState sets the birthday state in the database. *BirthdayState must +// not be nil. // // [0:1] Options (1 byte) // [1:33] Birthblock block header hash (32 bytes) @@ -361,7 +362,8 @@ func SetBirthState(dbtx walletdb.ReadWriteTx, bs *BirthdayState) error { return ns.Put(rootBirthState, v) } -// BirthState returns the current birthday state. +// BirthState returns the current birthday state. Will return nil if none has +// been set. func BirthState(dbtx walletdb.ReadTx) *BirthdayState { ns := dbtx.ReadBucket(wtxmgrBucketKey) const ( diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index 017a7ae48..dff46d861 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -5,11 +5,16 @@ package wallet import ( + "context" "encoding/hex" + "fmt" "math" "testing" + "time" "decred.org/dcrwallet/v5/errors" + "decred.org/dcrwallet/v5/wallet/udb" + "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/chaincfg/v3" ) @@ -138,3 +143,91 @@ func TestVotingXprivFromSeed(t *testing.T) { } } } + +func TestSetBirthStateAndScan(t *testing.T) { + t.Parallel() + ctx := context.Background() + + cfg := basicWalletConfig + w, teardown := testWallet(ctx, t, &cfg, nil) + defer teardown() + + tg := maketg(t, cfg.Params) + tw := &tw{t, w} + forest := new(SidechainForest) + + for i := 1; i < 10; i++ { + name := fmt.Sprintf("%va", i) + b := tg.nextBlock(name, nil, nil) + mustAddBlockNode(t, forest, b.BlockNode) + t.Logf("Generated block %v name %q", b.Hash, name) + } + b9aHash := tg.blockHashByName("9a") + bestChain := tw.evaluateBestChain(ctx, forest, 9, b9aHash) + tw.chainSwitch(ctx, forest, bestChain) + tw.assertNoBetterChain(ctx, forest) + + tests := []struct { + name string + bs *udb.BirthdayState + wantBHash *chainhash.Hash + wantErr bool + }{{ + name: "ok middle", + bs: &udb.BirthdayState{ + SetFromHeight: true, + Height: 6, + }, + wantBHash: tg.blockHashByName("6a"), + }, { + name: "ok past tip", + bs: &udb.BirthdayState{ + SetFromHeight: true, + Height: 20, + }, + wantBHash: b9aHash, + }, { + name: "ok genesis", + bs: &udb.BirthdayState{ + SetFromHeight: true, + Height: 0, + }, + wantBHash: &cfg.Params.GenesisHash, + }, { + name: "ok from time", + bs: &udb.BirthdayState{ + SetFromTime: true, + Time: time.Now(), + }, + wantBHash: b9aHash, + }, { + name: "nil birthday passed", + wantErr: true, + }, { + name: "nothing to set", + bs: new(udb.BirthdayState), + wantErr: true, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := w.SetBirthStateAndScan(ctx, test.bs) + if test.wantErr { + if err == nil { + t.Fatalf("expected error: %v", err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + bs, err := w.BirthState(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if bs.Hash != *test.wantBHash { + t.Fatalf("wanted birthday hash %v but got %v", test.wantBHash, bs.Hash) + } + }) + } +}