From ea1a1c94ca04b445fc2e797c2593707a7eb190fb Mon Sep 17 00:00:00 2001 From: Joshua Sing Date: Thu, 9 May 2024 01:59:44 +1000 Subject: [PATCH] bfgd: fix loops unconditionally exited after one interation (SA4004) (#108) Co-authored-by: ClaytonNorthey92 --- Makefile | 6 ++ database/bfgd/postgres/postgres.go | 116 +++++++---------------------- e2e/e2e_ext_test.go | 68 +++++++++++++++++ 3 files changed, 102 insertions(+), 88 deletions(-) diff --git a/Makefile b/Makefile index b0c26cb8..3fb41ec3 100644 --- a/Makefile +++ b/Makefile @@ -66,6 +66,12 @@ lint-deps: GOBIN=$(shell go env GOPATH)/bin go install mvdan.cc/gofumpt@latest GOBIN=$(shell go env GOPATH)/bin go install github.com/google/addlicense@latest +staticcheck: + $(shell go env GOPATH)/bin/staticcheck ./... + +staticcheck-deps: + GOBIN=$(shell go env GOPATH)/bin go install honnef.co/go/tools/cmd/staticcheck@latest + tidy: go mod tidy diff --git a/database/bfgd/postgres/postgres.go b/database/bfgd/postgres/postgres.go index 91552b71..3c363349 100644 --- a/database/bfgd/postgres/postgres.go +++ b/database/bfgd/postgres/postgres.go @@ -83,25 +83,25 @@ func New(ctx context.Context, uri string) (*pgdb, error) { return p, nil } -func (pg *pgdb) Version(ctx context.Context) (int, error) { +func (p *pgdb) Version(ctx context.Context) (int, error) { log.Tracef("Version") defer log.Tracef("Version exit") const selectVersion = `SELECT * FROM version LIMIT 1;` var dbVersion int - if err := pg.db.QueryRowContext(ctx, selectVersion).Scan(&dbVersion); err != nil { + if err := p.db.QueryRowContext(ctx, selectVersion).Scan(&dbVersion); err != nil { return -1, err } return dbVersion, nil } -func (pg *pgdb) L2KeystonesCount(ctx context.Context) (int, error) { +func (p *pgdb) L2KeystonesCount(ctx context.Context) (int, error) { log.Tracef("L2KeystonesCount") defer log.Tracef("L2KeystonesCount exit") const selectCount = `SELECT COUNT(*) FROM l2_keystones;` var count int - if err := pg.db.QueryRowContext(ctx, selectCount).Scan(&count); err != nil { + if err := p.db.QueryRowContext(ctx, selectCount).Scan(&count); err != nil { return 0, err } @@ -890,33 +890,14 @@ func (p *pgdb) BtcBlockCanonicalHeight(ctx context.Context) (uint64, error) { log.Tracef("BtcBlockCanonicalHeight") defer log.Tracef("BtcBlockCanonicalHeight exit") - sql := ` - SELECT COALESCE(MAX(height),0) - FROM btc_blocks_can - ` + const q = `SELECT COALESCE(MAX(height),0) FROM btc_blocks_can LIMIT 1` - rows, err := p.db.QueryContext(ctx, sql) - if err != nil { + var result uint64 + if err := p.db.QueryRowContext(ctx, q).Scan(&result); err != nil { return 0, err } - defer rows.Close() - - for rows.Next() { - var result uint64 - err = rows.Scan(&result) - if err != nil { - return 0, err - } - - return result, nil - } - - if err = rows.Err(); err != nil { - return 0, err - } - - return 0, errors.New("should not get here") + return result, nil } func (p *pgdb) AccessPublicKeyInsert(ctx context.Context, publicKey *bfgd.AccessPublicKey) error { @@ -946,64 +927,40 @@ func (p *pgdb) AccessPublicKeyExists(ctx context.Context, publicKey *bfgd.Access log.Tracef("AccessPublicKeyExists") defer log.Tracef("AccessPublicKeyExists exit") - const sql = ` + const q = ` SELECT EXISTS ( SELECT * FROM access_public_keys WHERE public_key = $1 ) ` - rows, err := p.db.QueryContext(ctx, sql, publicKey.PublicKey) - if err != nil { + var exists bool + if err := p.db.QueryRowContext(ctx, q, publicKey.PublicKey).Scan(&exists); err != nil { return false, err } - defer rows.Close() - - for rows.Next() { - var exists bool - err = rows.Scan(&exists) - if err != nil { - return false, err - } - - return exists, nil - } - - if err = rows.Err(); err != nil { - return false, err - } - - return false, errors.New("should not get here") + return exists, nil } func (p *pgdb) AccessPublicKeyDelete(ctx context.Context, publicKey *bfgd.AccessPublicKey) error { log.Tracef("AccessPublicKeyDelete") - log.Tracef("AccessPublicKeyDelete exit") + defer log.Tracef("AccessPublicKeyDelete exit") - sql := fmt.Sprintf(` - WITH deleted AS ( - DELETE FROM access_public_keys WHERE public_key = $1 - RETURNING * - ) SELECT count(*) FROM deleted; - `) + const q = ` + DELETE FROM access_public_keys WHERE public_key = $1 + ` - rows, err := p.db.QueryContext(ctx, sql, publicKey.PublicKey) + res, err := p.db.ExecContext(ctx, q, publicKey.PublicKey) if err != nil { return err } - for rows.Next() { - var count int - if err := rows.Scan(&count); err != nil { - return err - } - - return database.NotFoundError("public key not found") - } - - if err := rows.Err(); err != nil { + rows, err := res.RowsAffected() + if err != nil { return err } + if rows == 0 { + return database.NotFoundError("public key not found") + } return nil } @@ -1015,40 +972,23 @@ func (p *pgdb) canonicalChainTipL2BlockNumber(ctx context.Context) (*uint32, err log.Tracef("canonicalChainTipL2BlockNumber") defer log.Tracef("canonicalChainTipL2BlockNumber exit") - sql := fmt.Sprintf(` + const q = ` SELECT l2_keystones.l2_block_number - FROM btc_blocks_can INNER JOIN pop_basis ON pop_basis.btc_block_hash = btc_blocks_can.hash INNER JOIN l2_keystones ON l2_keystones.l2_keystone_abrev_hash = pop_basis.l2_keystone_abrev_hash - + ORDER BY l2_block_number DESC LIMIT 1 - `) + ` - rows, err := p.db.QueryContext(ctx, sql) - if err != nil { + var l2BlockNumber uint32 + if err := p.db.QueryRowContext(ctx, q).Scan(&l2BlockNumber); err != nil { return nil, err } - defer rows.Close() - - for rows.Next() { - var l2BlockNumber uint32 - err := rows.Scan(&l2BlockNumber) - if err != nil { - return nil, err - } - - return &l2BlockNumber, nil - } - - if rows.Err() != nil { - return nil, rows.Err() - } - - return nil, nil + return &l2BlockNumber, nil } func (p *pgdb) refreshBTCBlocksCanonical(ctx context.Context) error { diff --git a/e2e/e2e_ext_test.go b/e2e/e2e_ext_test.go index 371e3962..63c2823c 100644 --- a/e2e/e2e_ext_test.go +++ b/e2e/e2e_ext_test.go @@ -3622,6 +3622,74 @@ func TestDeleteAccessPublicKeyThatDoesNotExist(t *testing.T) { } } +func TestDeleteAccessPublicKey(t *testing.T) { + db, pgUri, sdb, cleanup := createTestDB(context.Background(), t) + defer func() { + db.Close() + sdb.Close() + cleanup() + }() + + privateKeyOne, err := dcrsecp256k1.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + publicKeyOne := hex.EncodeToString(privateKeyOne.PubKey().SerializeCompressed()) + + ctx, cancel := defaultTestContext() + defer cancel() + + _, _, bfgPrivateWsUrl, _ := createBfgServer(ctx, t, pgUri, "", 1) + + c, _, err := websocket.Dial(ctx, bfgPrivateWsUrl, nil) + if err != nil { + t.Fatal(err) + } + defer c.CloseNow() + + bws := &bfgWs{ + conn: protocol.NewWSConn(c), + } + + assertPing(ctx, t, c, bfgapi.CmdPingRequest) + + if err := bfgapi.Write(ctx, bws.conn, "someid", &bfgapi.AccessPublicKeyCreateRequest{ + PublicKey: publicKeyOne, + }); err != nil { + t.Fatal(err) + } + + command, _, _, err := bfgapi.Read(ctx, bws.conn) + if err != nil { + t.Fatal(err) + } + + if command != bfgapi.CmdAccessPublicKeyCreateResponse { + t.Fatalf("unexpected command %s", command) + } + + if err := bfgapi.Write(ctx, bws.conn, "someid", &bfgapi.AccessPublicKeyDeleteRequest{ + PublicKey: publicKeyOne, + }); err != nil { + t.Fatal(err) + } + + command, _, v, err := bfgapi.Read(ctx, bws.conn) + if err != nil { + t.Fatal(err) + } + + if command != bfgapi.CmdAccessPublicKeyDeleteResponse { + t.Fatalf("unexpected command %s", command) + } + + resp := v.(*bfgapi.AccessPublicKeyDeleteResponse) + if resp.Error != nil { + t.Fatalf("unexpected error: %s", resp.Error.Message) + } +} + func createBtcBlock(ctx context.Context, t *testing.T, db bfgd.Database, count int, height int, lastHash []byte, l2BlockNumber uint32) bfgd.BtcBlock { header := make([]byte, 80) hash := make([]byte, 32)