diff --git a/sa/sa_test.go b/sa/sa_test.go index 1e281a95f21..bdbb4a3a2c5 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -22,6 +22,7 @@ import ( "os" "reflect" "slices" + "sort" "strings" "sync" "testing" @@ -3353,40 +3354,69 @@ func TestGetRevokedCerts(t *testing.T) { test.AssertEquals(t, count, 0) } -func TestGetRevokedCertsByShard(t *testing.T) { - sa, _, cleanUp := initSA(t) +func TestGetRevokedCertsWithShard(t *testing.T) { + sa, fc, cleanUp := initSA(t) defer cleanUp() - // Add a cert to the DB to test with. We use AddPrecertificate because it sets - // up the certificateStatus row we need. This particular cert has a notAfter - // date of Mar 6 2023, and we lie about its IssuerNameID to make things easy. reg := createWorkingRegistration(t, sa) - eeCert, err := core.LoadCert("../test/hierarchy/ee-e1.cert.pem") - test.AssertNotError(t, err, "failed to load test cert") - _, err = sa.AddSerial(ctx, &sapb.AddSerialRequest{ - RegID: reg.Id, - Serial: core.SerialToString(eeCert.SerialNumber), - Created: timestamppb.New(eeCert.NotBefore), - Expires: timestamppb.New(eeCert.NotAfter), - }) - test.AssertNotError(t, err, "failed to add test serial") - _, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{ - Der: eeCert.Raw, - RegID: reg.Id, - Issued: timestamppb.New(eeCert.NotBefore), - IssuerNameID: 1, - }) - test.AssertNotError(t, err, "failed to add test cert") - // Check that it worked. - status, err := sa.GetCertificateStatus( - ctx, &sapb.Serial{Serial: core.SerialToString(eeCert.SerialNumber)}) - test.AssertNotError(t, err, "GetCertificateStatus failed") - test.AssertEquals(t, core.OCSPStatus(status.Status), core.OCSPStatusGood) + fc.Set(mustTime("2023-03-01 00:00")) - // Here's a little helper func we'll use to call GetRevokedCerts and count - // how many results it returned. - countRevokedCerts := func(req *sapb.GetRevokedCertsRequest) (int, error) { + // Make up an IssuerNameID to make things simpler. + issuerNameID := int64(834) + + // Create a certificate and add it to the tables we need. + makeCert := func() *x509.Certificate { + _, cert := test.ThrowAwayCert(t, fc) + // We depend on specifics of the lifetime set by test.ThrowAwayCert, so verify. + lifetime := cert.NotAfter.Sub(cert.NotBefore) + if lifetime != 6*24*time.Hour { + t.Fatalf("cert lifetime: got %s, want 6 days", lifetime) + } + _, err := sa.AddSerial(ctx, &sapb.AddSerialRequest{ + RegID: reg.Id, + Serial: core.SerialToString(cert.SerialNumber), + Created: timestamppb.New(cert.NotBefore), + Expires: timestamppb.New(cert.NotAfter), + }) + if err != nil { + t.Fatalf("adding serial: %s", err) + } + _, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{ + Der: cert.Raw, + RegID: reg.Id, + Issued: timestamppb.New(cert.NotBefore), + IssuerNameID: issuerNameID, + }) + if err != nil { + t.Fatalf("adding cert: %s", err) + } + status, err := sa.GetCertificateStatus(ctx, &sapb.Serial{Serial: core.SerialToString(cert.SerialNumber)}) + if err != nil { + t.Fatalf("GetCertificateStatus: %s", err) + } + if status.Status != string(core.OCSPStatusGood) { + t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood) + } + return cert + } + + // Two certs issued at the same time, with the same expiration. + // eeCert1 will be revoked without an explicit ShardIdx. + // eeCert2 will be revoked _with_ an explicit ShardIdx. + eeCert1 := makeCert() + eeCert2 := makeCert() + + // eeCert3 is issued two days after the others and will be revoked + // with the same explicit ShardIdx. It will show up in a different + // temporal shard than eeCert1 and eeCert2, because we are querying + // as if the shard width for CRLs is one day. + fc.Add(2 * 24 * time.Hour) + eeCert3 := makeCert() + + // Here's a little helper func we'll use to call GetRevokedCerts and return + // a sorted list of serials. + getRevokedCerts := func(req *sapb.GetRevokedCertsRequest) []string { stream := make(chan *corepb.CRLEntry) mockServerStream := &fakeServerStream[corepb.CRLEntry]{output: stream} var err error @@ -3394,79 +3424,109 @@ func TestGetRevokedCertsByShard(t *testing.T) { err = sa.GetRevokedCerts(req, mockServerStream) close(stream) }() - entriesReceived := 0 - for range stream { - entriesReceived++ + var serials []string + for e := range stream { + serials = append(serials, e.Serial) } - return entriesReceived, err + if err != nil { + t.Fatalf("GetRevokedCerts(%+v): %s", req, err) + } + return serials } - // The basic request covers a time range and shard that should include this certificate. + // The basic request covers a time range that includes eeCert1's and eeCert2's NotAfter, + // but excludes eeCert3's NotAfter. + // The ExpiresBefore field is set based on the 6-day lifetime of certs from test.ThrowAwayCert basicRequest := &sapb.GetRevokedCertsRequest{ - IssuerNameID: 1, - ShardIdx: 9, + IssuerNameID: issuerNameID, + ShardIdx: 97, ExpiresAfter: mustTimestamp("2023-03-01 00:00"), + ExpiresBefore: mustTimestamp("2023-03-08 00:00"), RevokedBefore: mustTimestamp("2023-04-01 00:00"), } - // Nothing's been revoked yet. Count should be zero. - count, err := countRevokedCerts(basicRequest) - test.AssertNotError(t, err, "zero rows shouldn't result in error") - test.AssertEquals(t, count, 0) - - // Revoke the certificate, providing the ShardIdx so it gets written into - // both the certificateStatus and revokedCertificates tables. - _, err = sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{ - IssuerID: 1, - Serial: core.SerialToString(eeCert.SerialNumber), - Date: mustTimestamp("2023-01-01 00:00"), - Reason: 1, - Response: []byte{1, 2, 3}, - ShardIdx: 9, - }) - test.AssertNotError(t, err, "failed to revoke test cert") + // Nothing's been revoked yet. Should get no results. + serials := getRevokedCerts(basicRequest) + if len(serials) > 0 { + t.Errorf("GetRevokedCerts (before revocations) = %s, want []", serials) + } - // Check that it worked in the most basic way. - c, err := sa.dbMap.SelectNullInt( - ctx, "SELECT count(*) FROM revokedCertificates") - test.AssertNotError(t, err, "SELECT from revokedCertificates failed") - test.Assert(t, c.Valid, "SELECT from revokedCertificates got no result") - test.AssertEquals(t, c.Int64, int64(1)) + revoke := func(cert *x509.Certificate, shardIdx int64) { + t.Logf("revoking %x with shardIdx %d", cert.SerialNumber, shardIdx) + _, err := sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{ + IssuerID: issuerNameID, + Serial: core.SerialToString(cert.SerialNumber), + Date: mustTimestamp("2023-03-04 00:00"), + Reason: 1, + Response: []byte{1, 2, 3}, + ShardIdx: shardIdx, + }) + if err != nil { + t.Fatalf("sa.RevokeCertificate %s", err) + } + } - // Asking for revoked certs now should return one result. - count, err = countRevokedCerts(basicRequest) - test.AssertNotError(t, err, "normal usage shouldn't result in error") - test.AssertEquals(t, count, 1) + // First certificate: revoke without ShardIdx + revoke(eeCert1, 0) + // Second certificate: revoke with ShardIdx = 97. + revoke(eeCert2, 97) + // Third certificate: revoke with ShardIdx = 97. + // But note that the temporal shard is different from the other two. + revoke(eeCert3, 97) + + // expectSerials registers an error if the provided serials don't match the serials + // of the provided certs (after sorting). + expectSerials := func(message string, serials []string, certs ...*x509.Certificate) { + t.Helper() + var expectedSerials []string + for _, c := range certs { + expectedSerials = append(expectedSerials, core.SerialToString(c.SerialNumber)) + } + sort.Strings(expectedSerials) + sort.Strings(serials) + if !reflect.DeepEqual(serials, expectedSerials) { + t.Errorf("%s: want %s, got %s", message, expectedSerials, serials) + } + } + serials = getRevokedCerts(basicRequest) + expectSerials("GetRevokedCerts (after revocations)", serials, eeCert1, eeCert2, eeCert3) - // Asking for revoked certs from a different issuer should return zero results. - count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{ + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: 5678, ShardIdx: basicRequest.ShardIdx, ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: basicRequest.RevokedBefore, }) - test.AssertNotError(t, err, "zero rows shouldn't result in error") - test.AssertEquals(t, count, 0) + expectSerials("GetRevokedCerts with nonexistent issuer", serials) - // Asking for revoked certs from a different shard should return zero results. - count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{ + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ + IssuerNameID: basicRequest.IssuerNameID, + ShardIdx: 0, + ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, + RevokedBefore: basicRequest.RevokedBefore, + }) + expectSerials("GetRevokedCerts with no shardIdx specified (temporal sharding only)", serials, eeCert1, eeCert2) + + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: basicRequest.IssuerNameID, ShardIdx: 8, ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: basicRequest.RevokedBefore, }) - test.AssertNotError(t, err, "zero rows shouldn't result in error") - test.AssertEquals(t, count, 0) + expectSerials("GetRevokedCerts for explicit shard with no revocations (temporal sharding only)", serials, eeCert1, eeCert2) // Asking for revoked certs with an old RevokedBefore should return no results. - count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{ + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: basicRequest.IssuerNameID, ShardIdx: basicRequest.ShardIdx, ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: mustTimestamp("2020-03-01 00:00"), }) - test.AssertNotError(t, err, "zero rows shouldn't result in error") - test.AssertEquals(t, count, 0) + expectSerials("GetRevokedCerts for old RevokedBefore", serials) } func TestGetMaxExpiration(t *testing.T) { diff --git a/sa/saro.go b/sa/saro.go index 6860cb9657a..7b53888d278 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1050,18 +1050,54 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden }) } -// GetRevokedCerts gets a request specifying an issuer and a period of time, -// and writes to the output stream the set of all certificates issued by that -// issuer which expire during that period of time and which have been revoked. +// crlDeduper implements grpc.ServerStreamingServer[corepb.CRLEntry]. +// +// It passes CRLEntry's to the inner ServerStreamingServer, with the +// exception that it omits any CRLEntry with the same serial as a previously +// sent one. +type crlDeduper struct { + grpc.ServerStreamingServer[corepb.CRLEntry] + + seen map[string]bool +} + +func (cd crlDeduper) Send(crl *corepb.CRLEntry) error { + if !cd.seen[crl.Serial] { + cd.seen[crl.Serial] = true + return cd.ServerStreamingServer.Send(crl) + } + return nil +} + +// GetRevokedCerts returns a stream of revoked certificates for a single CRL shard. +// +// If ShardIdx is zero, GetRevokedCerts calculates shard membership based +// solely on temporal sharding. +// +// If ShardIdx is nonzero, GetRevokedCerts calculates shard membership based +// on temporal sharding _and_ explicit sharding (that is, sharding based on +// the shardIdx field of the revokedCertificates table). Most revoked certificates +// will be present in two shards: one based on explicit sharding and one based +// on temporal sharding (a few will have the same shard for both). +// // The starting timestamp is treated as inclusive (certs with exactly that // notAfter date are included), but the ending timestamp is exclusive (certs // with exactly that notAfter date are *not* included). func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { + if core.IsAnyNilOrZero(req.IssuerNameID, req.ExpiresAfter, req.ExpiresBefore, req.RevokedBefore) { + return errors.New("incomplete request for GetRevokedCerts") + } + crlDeduper := crlDeduper{ + ServerStreamingServer: stream, + seen: make(map[string]bool), + } if req.ShardIdx != 0 { - return ssa.getRevokedCertsFromRevokedCertificatesTable(req, stream) - } else { - return ssa.getRevokedCertsFromCertificateStatusTable(req, stream) + err := ssa.getRevokedCertsFromRevokedCertificatesTable(req, crlDeduper) + if err != nil { + return err + } } + return ssa.getRevokedCertsFromCertificateStatusTable(req, crlDeduper) } // getRevokedCertsFromRevokedCertificatesTable uses the new revokedCertificates diff --git a/test/certs.go b/test/certs.go index 6dd1ce5a239..2bcd7c5796a 100644 --- a/test/certs.go +++ b/test/certs.go @@ -64,14 +64,18 @@ func LoadSigner(filename string) (crypto.Signer, error) { // ThrowAwayCert is a small test helper function that creates a self-signed // certificate with one SAN. It returns the parsed certificate and its serial // in string form for convenience. +// // The certificate returned from this function is the bare minimum needed for // most tests and isn't a robust example of a complete end entity certificate. +// +// Returned certificates have NotBefore == clk.Now(), and NotBefore 6 days in the +// future. func ThrowAwayCert(t *testing.T, clk clock.Clock) (string, *x509.Certificate) { var nameBytes [3]byte _, _ = rand.Read(nameBytes[:]) name := fmt.Sprintf("%s.example.com", hex.EncodeToString(nameBytes[:])) - var serialBytes [16]byte + var serialBytes [18]byte _, _ = rand.Read(serialBytes[:]) serial := big.NewInt(0).SetBytes(serialBytes[:])