diff --git a/blockchain/standalone/error.go b/blockchain/standalone/error.go index 4b2929c4ed..0ebec05923 100644 --- a/blockchain/standalone/error.go +++ b/blockchain/standalone/error.go @@ -4,58 +4,42 @@ package standalone -import ( - "fmt" -) - -// ErrorCode identifies a kind of error. -type ErrorCode int +// ErrorKind identifies a kind of error. It has full support for errors.Is and +// errors.As, so the caller can directly check against an error kind when +// determining the reason for an error. +type ErrorKind string // These constants are used to identify a specific RuleError. const ( // ErrUnexpectedDifficulty indicates specified bits do not align with // the expected value either because it doesn't match the calculated // value based on difficulty rules or it is out of the valid range. - ErrUnexpectedDifficulty ErrorCode = iota + ErrUnexpectedDifficulty = ErrorKind("ErrUnexpectedDifficulty") // ErrHighHash indicates the block does not hash to a value which is // lower than the required target difficultly. - ErrHighHash + ErrHighHash = ErrorKind("ErrHighHash") // ErrTSpendStartInvalidExpiry indicates that an invalid expiry was // provided to calculate the start of a treasury spend vote window. - ErrTSpendStartInvalidExpiry + ErrTSpendStartInvalidExpiry = ErrorKind("ErrTSpendStartInvalidExpiry") // ErrTSpendEndInvalidExpiry indicates that an invalid expiry was // provided to calculate the end of a treasury spend vote window. - ErrTSpendEndInvalidExpiry - - // numErrorCodes is the maximum error code number used in tests. - numErrorCodes + ErrTSpendEndInvalidExpiry = ErrorKind("ErrTSpendEndInvalidExpiry") ) -// Map of ErrorCode values back to their constant names for pretty printing. -var errorCodeStrings = map[ErrorCode]string{ - ErrUnexpectedDifficulty: "ErrUnexpectedDifficulty", - ErrHighHash: "ErrHighHash", - ErrTSpendStartInvalidExpiry: "ErrTSpendStartInvalidExpiry", - ErrTSpendEndInvalidExpiry: "ErrTSpendEndInvalidExpiry", -} - -// String returns the ErrorCode as a human-readable name. -func (e ErrorCode) String() string { - if s := errorCodeStrings[e]; s != "" { - return s - } - return fmt.Sprintf("Unknown ErrorCode (%d)", int(e)) +// Error satisfies the error interface and prints human-readable errors. +func (e ErrorKind) Error() string { + return string(e) } -// RuleError identifies a rule violation. The caller can use type assertions to -// determine if a failure was specifically due to a rule violation and access -// the ErrorCode field to ascertain the specific reason for the rule violation. +// RuleError identifies a rule violation. It has full support for errors.Is +// and errors.As, so the caller can ascertain the specific reason for the +// error by checking the underlying error. type RuleError struct { - ErrorCode ErrorCode // Describes the kind of error - Description string // Human readable description of the issue + Description string + Err error } // Error satisfies the error interface and prints human-readable errors. @@ -63,14 +47,12 @@ func (e RuleError) Error() string { return e.Description } -// ruleError creates an RuleError given a set of arguments. -func ruleError(c ErrorCode, desc string) RuleError { - return RuleError{ErrorCode: c, Description: desc} +// Unwrap returns the underlying wrapped error. +func (e RuleError) Unwrap() error { + return e.Err } -// IsErrorCode returns whether or not the provided error is a rule error with -// the provided error code. -func IsErrorCode(err error, c ErrorCode) bool { - e, ok := err.(RuleError) - return ok && e.ErrorCode == c +// ruleError creates a RuleError given a set of arguments. +func ruleError(kind ErrorKind, desc string) RuleError { + return RuleError{Err: kind, Description: desc} } diff --git a/blockchain/standalone/error_test.go b/blockchain/standalone/error_test.go index d2448f9ccf..1ed1598a3a 100644 --- a/blockchain/standalone/error_test.go +++ b/blockchain/standalone/error_test.go @@ -5,33 +5,28 @@ package standalone import ( + "errors" + "io" "testing" ) -// TestErrorCodeStringer tests the stringized output for the ErrorCode type. -func TestErrorCodeStringer(t *testing.T) { +// TestErrorKindStringer tests the stringized output for the ErrorKind type. +func TestErrorKindStringer(t *testing.T) { tests := []struct { - in ErrorCode + in ErrorKind want string }{ {ErrUnexpectedDifficulty, "ErrUnexpectedDifficulty"}, {ErrHighHash, "ErrHighHash"}, {ErrTSpendStartInvalidExpiry, "ErrTSpendStartInvalidExpiry"}, {ErrTSpendEndInvalidExpiry, "ErrTSpendEndInvalidExpiry"}, - {0xffff, "Unknown ErrorCode (65535)"}, - } - - // Detect additional error codes that don't have the stringer added. - if len(tests)-1 != int(numErrorCodes) { - t.Errorf("It appears an error code was added without adding an " + - "associated stringer test") } t.Logf("Running %d tests", len(tests)) for i, test := range tests { - result := test.in.String() + result := test.in.Error() if result != test.want { - t.Errorf("String #%d\n got: %s want: %s", i, result, test.want) + t.Errorf("%d: got: %s want: %s", i, result, test.want) continue } } @@ -43,8 +38,8 @@ func TestRuleError(t *testing.T) { in RuleError want string }{{ - RuleError{Description: "duplicate block"}, - "duplicate block", + RuleError{Description: "unexpected difficulty"}, + "unexpected difficulty", }, { RuleError{Description: "human-readable error"}, "human-readable error", @@ -55,50 +50,91 @@ func TestRuleError(t *testing.T) { for i, test := range tests { result := test.in.Error() if result != test.want { - t.Errorf("Error #%d\n got: %s want: %s", i, result, test.want) + t.Errorf("%d: got: %s want: %s", i, result, test.want) continue } } } -// TestIsErrorCode ensures IsErrorCode works as intended. -func TestIsErrorCode(t *testing.T) { +// TestRuleErrorKindIsAs ensures both ErrorKind and RuleError can be +// identified as being a specific error kind via errors.Is and unwrapped +// via errors.As. +func TestRuleErrorKindIsAs(t *testing.T) { tests := []struct { - name string - err error - code ErrorCode - want bool + name string + err error + target error + wantMatch bool + wantAs ErrorKind }{{ - name: "ErrUnexpectedDifficulty testing for ErrUnexpectedDifficulty", - err: ruleError(ErrUnexpectedDifficulty, ""), - code: ErrUnexpectedDifficulty, - want: true, + name: "ErrUnexpectedDifficulty == ErrUnexpectedDifficulty", + err: ErrUnexpectedDifficulty, + target: ErrUnexpectedDifficulty, + wantMatch: true, + wantAs: ErrUnexpectedDifficulty, + }, { + name: "RuleError.ErrUnexpectedDifficulty == ErrUnexpectedDifficulty", + err: ruleError(ErrUnexpectedDifficulty, ""), + target: ErrUnexpectedDifficulty, + wantMatch: true, + wantAs: ErrUnexpectedDifficulty, }, { - name: "ErrHighHash testing for ErrHighHash", - err: ruleError(ErrHighHash, ""), - code: ErrHighHash, - want: true, + name: "RuleError.ErrUnexpectedDifficulty == RuleError.ErrUnexpectedDifficulty", + err: ruleError(ErrUnexpectedDifficulty, ""), + target: ruleError(ErrUnexpectedDifficulty, ""), + wantMatch: true, + wantAs: ErrUnexpectedDifficulty, }, { - name: "ErrHighHash error testing for ErrUnexpectedDifficulty", - err: ruleError(ErrHighHash, ""), - code: ErrUnexpectedDifficulty, - want: false, + name: "ErrUnexpectedDifficulty != ErrHighHash", + err: ErrUnexpectedDifficulty, + target: ErrHighHash, + wantMatch: false, + wantAs: ErrUnexpectedDifficulty, }, { - name: "ErrHighHash error testing for unknown error code", - err: ruleError(ErrHighHash, ""), - code: 0xffff, - want: false, + name: "RuleError.ErrUnexpectedDifficulty != ErrHighHash", + err: ruleError(ErrUnexpectedDifficulty, ""), + target: ErrHighHash, + wantMatch: false, + wantAs: ErrUnexpectedDifficulty, }, { - name: "nil error testing for ErrUnexpectedDifficulty", - err: nil, - code: ErrUnexpectedDifficulty, - want: false, + name: "ErrUnexpectedDifficulty != RuleError.ErrHighHash", + err: ErrUnexpectedDifficulty, + target: ruleError(ErrHighHash, ""), + wantMatch: false, + wantAs: ErrUnexpectedDifficulty, + }, { + name: "RuleError.ErrUnexpectedDifficulty != RuleError.ErrHighHash", + err: ruleError(ErrUnexpectedDifficulty, ""), + target: ruleError(ErrHighHash, ""), + wantMatch: false, + wantAs: ErrUnexpectedDifficulty, + }, { + name: "RuleError.ErrUnexpectedDifficulty != io.EOF", + err: ruleError(ErrUnexpectedDifficulty, ""), + target: io.EOF, + wantMatch: false, + wantAs: ErrUnexpectedDifficulty, }} + for _, test := range tests { - result := IsErrorCode(test.err, test.code) - if result != test.want { - t.Errorf("%s: unexpected result -- got: %v want: %v", test.name, - result, test.want) + // Ensure the error matches or not depending on the expected result. + result := errors.Is(test.err, test.target) + if result != test.wantMatch { + t.Errorf("%s: incorrect error identification -- got %v, want %v", + test.name, result, test.wantMatch) + continue + } + + // Ensure the underlying error kind can be unwrapped is and is the + // expected kind. + var kind ErrorKind + if !errors.As(test.err, &kind) { + t.Errorf("%s: unable to unwrap to error kind", test.name) + continue + } + if kind != test.wantAs { + t.Errorf("%s: unexpected unwrapped error kind -- got %v, want %v", + test.name, kind, test.wantAs) continue } } diff --git a/blockchain/standalone/pow_test.go b/blockchain/standalone/pow_test.go index cdc73de13a..fa9c55f7f1 100644 --- a/blockchain/standalone/pow_test.go +++ b/blockchain/standalone/pow_test.go @@ -5,6 +5,7 @@ package standalone import ( + "errors" "math/big" "testing" @@ -248,17 +249,17 @@ func TestCheckProofOfWorkRange(t *testing.T) { name: "zero", bits: 0, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrUnexpectedDifficulty, ""), + err: ErrUnexpectedDifficulty, }, { name: "negative", bits: 0x1810000, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrUnexpectedDifficulty, ""), + err: ErrUnexpectedDifficulty, }, { name: "pow limit + 1", bits: 0x1d010000, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrUnexpectedDifficulty, ""), + err: ErrUnexpectedDifficulty, }} for _, test := range tests { @@ -269,15 +270,9 @@ func TestCheckProofOfWorkRange(t *testing.T) { } err := CheckProofOfWorkRange(test.bits, powLimit) - if test.err == nil && err != nil { - t.Errorf("%q: unexpected err -- got %v, want nil", test.name, err) - continue - } else if test.err != nil { - if !IsErrorCode(err, test.err.(RuleError).ErrorCode) { - t.Errorf("%q: unexpected err -- got %v, want %v", - test.name, err, test.err.(RuleError).ErrorCode) - continue - } + if !errors.Is(err, test.err) { + t.Errorf("%q: unexpected err -- got %v, want %v", test.name, err, + test.err) continue } } @@ -316,25 +311,25 @@ func TestCheckProofOfWork(t *testing.T) { hash: "000000000001ffff000000000000000000000000000000000000000000000001", bits: 0x1b01ffff, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrHighHash, ""), + err: ErrHighHash, }, { name: "hash satisfies target, but target too high at pow limit + 1", hash: "0000000000000000000000000000000000000000000000000000000000000001", bits: 0x1d010000, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrUnexpectedDifficulty, ""), + err: ErrUnexpectedDifficulty, }, { name: "zero target difficulty", hash: "0000000000000000000000000000000000000000000000000000000000000001", bits: 0, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrUnexpectedDifficulty, ""), + err: ErrUnexpectedDifficulty, }, { name: "negative target difficulty", hash: "0000000000000000000000000000000000000000000000000000000000000001", bits: 0x1810000, powLimit: mockMainNetPowLimit(), - err: ruleError(ErrUnexpectedDifficulty, ""), + err: ErrUnexpectedDifficulty, }} for _, test := range tests { @@ -351,15 +346,9 @@ func TestCheckProofOfWork(t *testing.T) { } err = CheckProofOfWork(hash, test.bits, powLimit) - if test.err == nil && err != nil { - t.Errorf("%q: unexpected err -- got %v, want nil", test.name, err) - continue - } else if test.err != nil { - if !IsErrorCode(err, test.err.(RuleError).ErrorCode) { - t.Errorf("%q: unexpected err -- got %v, want %v", - test.name, err, test.err.(RuleError).ErrorCode) - continue - } + if !errors.Is(err, test.err) { + t.Errorf("%q: unexpected err -- got %v, want %v", test.name, err, + test.err) continue } } diff --git a/blockchain/standalone/treasury.go b/blockchain/standalone/treasury.go index 5cc55b5912..70bcd153e2 100644 --- a/blockchain/standalone/treasury.go +++ b/blockchain/standalone/treasury.go @@ -29,8 +29,8 @@ func IsTreasuryVoteInterval(height, tvi uint64) bool { // this function is only called with an expiry that *IS* on a TVI. func CalculateTSpendWindowStart(expiry uint32, tvi, multiplier uint64) (uint32, error) { if !IsTreasuryVoteInterval(uint64(expiry-2), tvi) { - return 0, RuleError{ErrTSpendStartInvalidExpiry, - fmt.Sprintf("invalid start expiry: %v", expiry)} + return 0, ruleError(ErrTSpendStartInvalidExpiry, + fmt.Sprintf("invalid start expiry: %v", expiry)) } return expiry - uint32(tvi*multiplier) - 2, nil } @@ -40,8 +40,8 @@ func CalculateTSpendWindowStart(expiry uint32, tvi, multiplier uint64) (uint32, // this function is only called with an expiry that *IS* on a TVI. func CalculateTSpendWindowEnd(expiry uint32, tvi uint64) (uint32, error) { if !IsTreasuryVoteInterval(uint64(expiry-2), tvi) { - return 0, RuleError{ErrTSpendEndInvalidExpiry, - fmt.Sprintf("invalid end expiry: %v", expiry)} + return 0, ruleError(ErrTSpendEndInvalidExpiry, + fmt.Sprintf("invalid end expiry: %v", expiry)) } return expiry - 2, nil } diff --git a/blockchain/standalone/treasury_test.go b/blockchain/standalone/treasury_test.go index 389ce8df6b..9f727cf9aa 100644 --- a/blockchain/standalone/treasury_test.go +++ b/blockchain/standalone/treasury_test.go @@ -4,25 +4,24 @@ package standalone -import "testing" +import ( + "errors" + "testing" +) func TestTSpendExpiryNegative(t *testing.T) { // 5 is not a valid start for a tvi of 11 with a mul of 3. _, err := CalculateTSpendWindowStart(5, 11, 3) - if e, ok := err.(RuleError); !ok { - t.Fatal(err) - } else if e.ErrorCode != ErrTSpendStartInvalidExpiry { + if !errors.Is(err, ErrTSpendStartInvalidExpiry) { t.Fatalf("expected %v got %v", - ErrTSpendStartInvalidExpiry, e.ErrorCode) + ErrTSpendStartInvalidExpiry, err) } // 5 is not a valid end for a tvi of 11. _, err = CalculateTSpendWindowEnd(5, 11) - if e, ok := err.(RuleError); !ok { - t.Fatal(err) - } else if e.ErrorCode != ErrTSpendEndInvalidExpiry { + if !errors.Is(err, ErrTSpendEndInvalidExpiry) { t.Fatalf("expected %v got %v", - ErrTSpendEndInvalidExpiry, e.ErrorCode) + ErrTSpendEndInvalidExpiry, err) } } diff --git a/blockchain/validate.go b/blockchain/validate.go index 31c8b781cd..00c45a723e 100644 --- a/blockchain/validate.go +++ b/blockchain/validate.go @@ -529,15 +529,15 @@ func CheckProofOfStake(block *dcrutil.Block, posLimit int64) error { } // standaloneToChainRuleError attempts to convert the passed error from a -// standalone.RuleError to a blockchain.RuleError with the equivalent code. The -// error is simply passed through without modification if it is not a -// standalone.RuleError, not one of the specifically recognized error codes, or -// nil. +// standalone.RuleError to a blockchain.RuleError with the equivalent error +// kind. The error is simply passed through without modification if it is +// not a standalone.RuleError, not one of the specifically recognized +// error codes, or nil. func standaloneToChainRuleError(err error) error { // Convert standalone package rule errors to blockchain rule errors. var rErr standalone.RuleError if errors.As(err, &rErr) { - switch rErr.ErrorCode { + switch rErr.Err { case standalone.ErrUnexpectedDifficulty: return ruleError(ErrUnexpectedDifficulty, rErr.Description) case standalone.ErrHighHash: