Skip to content

Commit

Permalink
bfgd: fix loops unconditionally exited after one interation (SA4004) (#…
Browse files Browse the repository at this point in the history
…108)

Co-authored-by: ClaytonNorthey92 <clayton.northey@gmail.com>
  • Loading branch information
joshuasing and ClaytonNorthey92 authored May 8, 2024
1 parent 29f116f commit ea1a1c9
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 88 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ lint-deps:
GOBIN=$(shell go env GOPATH)/bin go install mvdan.cc/gofumpt@latest
GOBIN=$(shell go env GOPATH)/bin go install github.com/google/addlicense@latest

staticcheck:
$(shell go env GOPATH)/bin/staticcheck ./...

staticcheck-deps:
GOBIN=$(shell go env GOPATH)/bin go install honnef.co/go/tools/cmd/staticcheck@latest

tidy:
go mod tidy

Expand Down
116 changes: 28 additions & 88 deletions database/bfgd/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,25 @@ func New(ctx context.Context, uri string) (*pgdb, error) {
return p, nil
}

func (pg *pgdb) Version(ctx context.Context) (int, error) {
func (p *pgdb) Version(ctx context.Context) (int, error) {
log.Tracef("Version")
defer log.Tracef("Version exit")

const selectVersion = `SELECT * FROM version LIMIT 1;`
var dbVersion int
if err := pg.db.QueryRowContext(ctx, selectVersion).Scan(&dbVersion); err != nil {
if err := p.db.QueryRowContext(ctx, selectVersion).Scan(&dbVersion); err != nil {
return -1, err
}
return dbVersion, nil
}

func (pg *pgdb) L2KeystonesCount(ctx context.Context) (int, error) {
func (p *pgdb) L2KeystonesCount(ctx context.Context) (int, error) {
log.Tracef("L2KeystonesCount")
defer log.Tracef("L2KeystonesCount exit")

const selectCount = `SELECT COUNT(*) FROM l2_keystones;`
var count int
if err := pg.db.QueryRowContext(ctx, selectCount).Scan(&count); err != nil {
if err := p.db.QueryRowContext(ctx, selectCount).Scan(&count); err != nil {
return 0, err
}

Expand Down Expand Up @@ -890,33 +890,14 @@ func (p *pgdb) BtcBlockCanonicalHeight(ctx context.Context) (uint64, error) {
log.Tracef("BtcBlockCanonicalHeight")
defer log.Tracef("BtcBlockCanonicalHeight exit")

sql := `
SELECT COALESCE(MAX(height),0)
FROM btc_blocks_can
`
const q = `SELECT COALESCE(MAX(height),0) FROM btc_blocks_can LIMIT 1`

rows, err := p.db.QueryContext(ctx, sql)
if err != nil {
var result uint64
if err := p.db.QueryRowContext(ctx, q).Scan(&result); err != nil {
return 0, err
}

defer rows.Close()

for rows.Next() {
var result uint64
err = rows.Scan(&result)
if err != nil {
return 0, err
}

return result, nil
}

if err = rows.Err(); err != nil {
return 0, err
}

return 0, errors.New("should not get here")
return result, nil
}

func (p *pgdb) AccessPublicKeyInsert(ctx context.Context, publicKey *bfgd.AccessPublicKey) error {
Expand Down Expand Up @@ -946,64 +927,40 @@ func (p *pgdb) AccessPublicKeyExists(ctx context.Context, publicKey *bfgd.Access
log.Tracef("AccessPublicKeyExists")
defer log.Tracef("AccessPublicKeyExists exit")

const sql = `
const q = `
SELECT EXISTS (
SELECT * FROM access_public_keys WHERE public_key = $1
)
`

rows, err := p.db.QueryContext(ctx, sql, publicKey.PublicKey)
if err != nil {
var exists bool
if err := p.db.QueryRowContext(ctx, q, publicKey.PublicKey).Scan(&exists); err != nil {
return false, err
}

defer rows.Close()

for rows.Next() {
var exists bool
err = rows.Scan(&exists)
if err != nil {
return false, err
}

return exists, nil
}

if err = rows.Err(); err != nil {
return false, err
}

return false, errors.New("should not get here")
return exists, nil
}

func (p *pgdb) AccessPublicKeyDelete(ctx context.Context, publicKey *bfgd.AccessPublicKey) error {
log.Tracef("AccessPublicKeyDelete")
log.Tracef("AccessPublicKeyDelete exit")
defer log.Tracef("AccessPublicKeyDelete exit")

sql := fmt.Sprintf(`
WITH deleted AS (
DELETE FROM access_public_keys WHERE public_key = $1
RETURNING *
) SELECT count(*) FROM deleted;
`)
const q = `
DELETE FROM access_public_keys WHERE public_key = $1
`

rows, err := p.db.QueryContext(ctx, sql, publicKey.PublicKey)
res, err := p.db.ExecContext(ctx, q, publicKey.PublicKey)
if err != nil {
return err
}

for rows.Next() {
var count int
if err := rows.Scan(&count); err != nil {
return err
}

return database.NotFoundError("public key not found")
}

if err := rows.Err(); err != nil {
rows, err := res.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return database.NotFoundError("public key not found")
}

return nil
}
Expand All @@ -1015,40 +972,23 @@ func (p *pgdb) canonicalChainTipL2BlockNumber(ctx context.Context) (*uint32, err
log.Tracef("canonicalChainTipL2BlockNumber")
defer log.Tracef("canonicalChainTipL2BlockNumber exit")

sql := fmt.Sprintf(`
const q = `
SELECT l2_keystones.l2_block_number
FROM btc_blocks_can
INNER JOIN pop_basis ON pop_basis.btc_block_hash = btc_blocks_can.hash
INNER JOIN l2_keystones ON l2_keystones.l2_keystone_abrev_hash
= pop_basis.l2_keystone_abrev_hash
ORDER BY l2_block_number DESC LIMIT 1
`)
`

rows, err := p.db.QueryContext(ctx, sql)
if err != nil {
var l2BlockNumber uint32
if err := p.db.QueryRowContext(ctx, q).Scan(&l2BlockNumber); err != nil {
return nil, err
}

defer rows.Close()

for rows.Next() {
var l2BlockNumber uint32
err := rows.Scan(&l2BlockNumber)
if err != nil {
return nil, err
}

return &l2BlockNumber, nil
}

if rows.Err() != nil {
return nil, rows.Err()
}

return nil, nil
return &l2BlockNumber, nil
}

func (p *pgdb) refreshBTCBlocksCanonical(ctx context.Context) error {
Expand Down
68 changes: 68 additions & 0 deletions e2e/e2e_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3622,6 +3622,74 @@ func TestDeleteAccessPublicKeyThatDoesNotExist(t *testing.T) {
}
}

func TestDeleteAccessPublicKey(t *testing.T) {
db, pgUri, sdb, cleanup := createTestDB(context.Background(), t)
defer func() {
db.Close()
sdb.Close()
cleanup()
}()

privateKeyOne, err := dcrsecp256k1.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}

publicKeyOne := hex.EncodeToString(privateKeyOne.PubKey().SerializeCompressed())

ctx, cancel := defaultTestContext()
defer cancel()

_, _, bfgPrivateWsUrl, _ := createBfgServer(ctx, t, pgUri, "", 1)

c, _, err := websocket.Dial(ctx, bfgPrivateWsUrl, nil)
if err != nil {
t.Fatal(err)
}
defer c.CloseNow()

bws := &bfgWs{
conn: protocol.NewWSConn(c),
}

assertPing(ctx, t, c, bfgapi.CmdPingRequest)

if err := bfgapi.Write(ctx, bws.conn, "someid", &bfgapi.AccessPublicKeyCreateRequest{
PublicKey: publicKeyOne,
}); err != nil {
t.Fatal(err)
}

command, _, _, err := bfgapi.Read(ctx, bws.conn)
if err != nil {
t.Fatal(err)
}

if command != bfgapi.CmdAccessPublicKeyCreateResponse {
t.Fatalf("unexpected command %s", command)
}

if err := bfgapi.Write(ctx, bws.conn, "someid", &bfgapi.AccessPublicKeyDeleteRequest{
PublicKey: publicKeyOne,
}); err != nil {
t.Fatal(err)
}

command, _, v, err := bfgapi.Read(ctx, bws.conn)
if err != nil {
t.Fatal(err)
}

if command != bfgapi.CmdAccessPublicKeyDeleteResponse {
t.Fatalf("unexpected command %s", command)
}

resp := v.(*bfgapi.AccessPublicKeyDeleteResponse)
if resp.Error != nil {
t.Fatalf("unexpected error: %s", resp.Error.Message)
}
}

func createBtcBlock(ctx context.Context, t *testing.T, db bfgd.Database, count int, height int, lastHash []byte, l2BlockNumber uint32) bfgd.BtcBlock {
header := make([]byte, 80)
hash := make([]byte, 32)
Expand Down

0 comments on commit ea1a1c9

Please sign in to comment.