Skip to content

Commit

Permalink
admin: use SA to get metadata before blocking key (#7377)
Browse files Browse the repository at this point in the history
Use two existing SA methods, KeyBlocked and GetSerialsByKey, to replace
the direct database access previously used by the blockSPKIHash method.
This is less efficient than before -- it now streams the whole set of
affected serials rather than just counting them -- but doing so prevents
us from needing an additional SA method just for counting.

Also update the default mock StorageAuthority and
StorageAuthorityReadOnly provided by the mocks package to return actual
stream objects (which stream zero results) instead of nil, so that tests
can attempt to read from the resulting stream without getting a nil
pointer exception.

Part of #7350
  • Loading branch information
aarongable authored Mar 11, 2024
1 parent 7e5c1ca commit ffef10a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 36 deletions.
26 changes: 19 additions & 7 deletions cmd/admin/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"os/user"

"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -53,21 +54,32 @@ func (a *admin) spkiHashFromPrivateKey(keyFile string) ([]byte, error) {
}

func (a *admin) blockSPKIHash(ctx context.Context, spkiHash []byte, comment string) error {
var exists bool
err := a.dbMap.SelectOne(ctx, &exists, "SELECT EXISTS(SELECT 1 FROM blockedKeys WHERE keyHash = ? LIMIT 1)", spkiHash[:])
exists, err := a.saroc.KeyBlocked(ctx, &sapb.SPKIHash{KeyHash: spkiHash})
if err != nil {
return fmt.Errorf("checking if key is already blocked: %w", err)
}
if exists {
if exists.Exists {
return errors.New("the provided key already exists in the 'blockedKeys' table")
}

var count int
err = a.dbMap.SelectOne(ctx, &count, "SELECT COUNT(*) as count FROM keyHashToSerial WHERE keyHash = ? AND certNotAfter > NOW()", spkiHash[:])
stream, err := a.saroc.GetSerialsByKey(ctx, &sapb.SPKIHash{KeyHash: spkiHash})
if err != nil {
return fmt.Errorf("counting affected certificates: %w", err)
return fmt.Errorf("setting up stream of serials from SA: %s", err)
}
a.log.Infof("Found %d certificates matching the provided key", count)

var count int
for {
_, err := stream.Recv()
if err != nil {
if err == io.EOF {
break
}
return fmt.Errorf("streaming serials from SA: %s", err)
}
count++
}

a.log.Infof("Found %d unexpired certificates matching the provided key", count)

u, err := user.Current()
if err != nil {
Expand Down
25 changes: 4 additions & 21 deletions cmd/admin/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ import (

"github.com/letsencrypt/boulder/core"
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/sa"
"github.com/letsencrypt/boulder/mocks"
sapb "github.com/letsencrypt/boulder/sa/proto"
"github.com/letsencrypt/boulder/test"
"github.com/letsencrypt/boulder/test/vars"
)

func TestSPKIHashFromPrivateKey(t *testing.T) {
Expand Down Expand Up @@ -73,30 +72,15 @@ func TestBlockSPKIHash(t *testing.T) {
keyHash, err := core.KeyDigest(privKey.Public())
test.AssertNotError(t, err, "computing test SPKI hash")

dbMap, err := sa.DBMapForTest(vars.DBConnSA)
test.AssertNotError(t, err, "creating test dbMap")
defer test.ResetBoulderTestDatabase(t)

for _, serial := range []string{"foo", "bar", "baz"} {
_, err = dbMap.ExecContext(
context.Background(),
"INSERT INTO keyHashToSerial(keyHash, certNotAfter, certSerial) VALUES (?, ?, ?)",
keyHash[:],
fc.Now().Add(24*time.Hour),
serial,
)
test.AssertNotError(t, err, "inserting fake serial into test db")
}

a := admin{sac: &msa, dbMap: dbMap, clk: fc, log: log}
a := admin{saroc: &mocks.StorageAuthorityReadOnly{}, sac: &msa, clk: fc, log: log}

// A full run should result in one request with the right fields.
msa.reset()
log.Clear()
a.dryRun = false
err = a.blockSPKIHash(context.Background(), keyHash[:], "hello world")
test.AssertNotError(t, err, "")
test.AssertEquals(t, len(log.GetAllMatching("Found 3 certificates")), 1)
test.AssertEquals(t, len(log.GetAllMatching("Found 0 unexpired certificates")), 1)
test.AssertEquals(t, len(msa.blockRequests), 1)
test.AssertByteEquals(t, msa.blockRequests[0].KeyHash, keyHash[:])
test.AssertContains(t, msa.blockRequests[0].Comment, "hello world")
Expand All @@ -108,8 +92,7 @@ func TestBlockSPKIHash(t *testing.T) {
a.sac = dryRunSAC{log: log}
err = a.blockSPKIHash(context.Background(), keyHash[:], "")
test.AssertNotError(t, err, "")
test.AssertEquals(t, len(log.GetAllMatching("Found 3 certificates")), 1)
test.AssertEquals(t, len(log.GetAllMatching("Found 0 unexpired certificates")), 1)
test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 1)
test.AssertEquals(t, len(msa.blockRequests), 0)

}
29 changes: 21 additions & 8 deletions mocks/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
Expand Down Expand Up @@ -53,6 +54,18 @@ func NewStorageAuthority(clk clock.Clock) *StorageAuthority {
return &StorageAuthority{StorageAuthorityReadOnly{clk}}
}

// serverStreamClient is a mock which satisfies the grpc.ClientStream interface,
// allowing it to be returned by methods where the server returns a stream of
// results. This simple mock will always return zero results.
type serverStreamClient[T any] struct {
grpc.ClientStream
}

// Recv immediately returns the EOF error, indicating that the stream is done.
func (c *serverStreamClient[T]) Recv() (*T, error) {
return nil, io.EOF
}

const (
test1KeyPublicJSON = `{"kty":"RSA","n":"yNWVhtYEKJR21y9xsHV-PD_bYwbXSeNuFal46xYxVfRL5mqha7vttvjB_vc7Xg2RvgCxHPCqoxgMPTzHrZT75LjCwIW2K_klBYN8oYvTwwmeSkAz6ut7ZxPv-nZaT5TJhGk0NT2kh_zSpdriEJ_3vW-mqxYbbBmpvHqsa1_zx9fSuHYctAZJWzxzUZXykbWMWQZpEiE0J4ajj51fInEzVn7VxV-mzfMyboQjujPh7aNJxAWSq4oQEJJDgWwSh9leyoJoPpONHxh5nEE5AjE01FkGICSxjpZsF-w8hOTI3XXohUdu29Se26k2B0PolDSuj0GIQU6-W9TdLXSjBb2SpQ","e":"AQAB"}`
test2KeyPublicJSON = `{"kty":"RSA","n":"qnARLrT7Xz4gRcKyLdydmCr-ey9OuPImX4X40thk3on26FkMznR3fRjs66eLK7mmPcBZ6uOJseURU6wAaZNmemoYx1dMvqvWWIyiQleHSD7Q8vBrhR6uIoO4jAzJZR-ChzZuSDt7iHN-3xUVspu5XGwXU_MVJZshTwp4TaFx5elHIT_ObnTvTOU3Xhish07AbgZKmWsVbXh5s-CrIicU4OexJPgunWZ_YJJueOKmTvnLlTV4MzKR2oZlBKZ27S0-SfdV_QDx_ydle5oMAyKVtlAV35cyPMIsYNwgUGBCdY_2Uzi5eX0lTc7MPRwz6qR1kip-i59VcGcUQgqHV6Fyqw","e":"AQAB"}`
Expand Down Expand Up @@ -270,22 +283,22 @@ func (sa *StorageAuthorityReadOnly) GetRevocationStatus(_ context.Context, req *

// SerialsForIncident is a mock
func (sa *StorageAuthorityReadOnly) SerialsForIncident(ctx context.Context, _ *sapb.SerialsForIncidentRequest, _ ...grpc.CallOption) (sapb.StorageAuthorityReadOnly_SerialsForIncidentClient, error) {
return nil, nil
return &serverStreamClient[sapb.IncidentSerial]{}, nil
}

// SerialsForIncident is a mock
func (sa *StorageAuthority) SerialsForIncident(ctx context.Context, _ *sapb.SerialsForIncidentRequest, _ ...grpc.CallOption) (sapb.StorageAuthority_SerialsForIncidentClient, error) {
return nil, nil
return &serverStreamClient[sapb.IncidentSerial]{}, nil
}

// GetRevokedCerts is a mock
func (sa *StorageAuthorityReadOnly) GetRevokedCerts(ctx context.Context, _ *sapb.GetRevokedCertsRequest, _ ...grpc.CallOption) (sapb.StorageAuthorityReadOnly_GetRevokedCertsClient, error) {
return nil, nil
return &serverStreamClient[corepb.CRLEntry]{}, nil
}

// GetRevokedCerts is a mock
func (sa *StorageAuthority) GetRevokedCerts(ctx context.Context, _ *sapb.GetRevokedCertsRequest, _ ...grpc.CallOption) (sapb.StorageAuthority_GetRevokedCertsClient, error) {
return nil, nil
return &serverStreamClient[corepb.CRLEntry]{}, nil
}

// GetMaxExpiration is a mock
Expand Down Expand Up @@ -579,22 +592,22 @@ func (sa *StorageAuthorityReadOnly) GetAuthorization2(ctx context.Context, id *s

// GetSerialsByKey is a mock
func (sa *StorageAuthorityReadOnly) GetSerialsByKey(ctx context.Context, _ *sapb.SPKIHash, _ ...grpc.CallOption) (sapb.StorageAuthorityReadOnly_GetSerialsByKeyClient, error) {
return nil, nil
return &serverStreamClient[sapb.Serial]{}, nil
}

// GetSerialsByKey is a mock
func (sa *StorageAuthority) GetSerialsByKey(ctx context.Context, _ *sapb.SPKIHash, _ ...grpc.CallOption) (sapb.StorageAuthority_GetSerialsByKeyClient, error) {
return nil, nil
return &serverStreamClient[sapb.Serial]{}, nil
}

// GetSerialsByAccount is a mock
func (sa *StorageAuthorityReadOnly) GetSerialsByAccount(ctx context.Context, _ *sapb.RegistrationID, _ ...grpc.CallOption) (sapb.StorageAuthorityReadOnly_GetSerialsByAccountClient, error) {
return nil, nil
return &serverStreamClient[sapb.Serial]{}, nil
}

// GetSerialsByAccount is a mock
func (sa *StorageAuthority) GetSerialsByAccount(ctx context.Context, _ *sapb.RegistrationID, _ ...grpc.CallOption) (sapb.StorageAuthority_GetSerialsByAccountClient, error) {
return nil, nil
return &serverStreamClient[sapb.Serial]{}, nil
}

// RevokeCertificate is a mock
Expand Down

0 comments on commit ffef10a

Please sign in to comment.