From b5525e494e42ace92bea210c845d7fdc9a748705 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Tue, 7 Jan 2025 09:48:09 -0800 Subject: [PATCH 1/6] sa: GetRevokedCerts returns explicit shards too Change GetRevokedCerts to return a combined list of certs for a given shard, calculating shard membership temporally _and_ by explicit assignment to a shard in the revokedCertificates table. This functionality is gated on the ShardIdx field of GetRevokedCertsRequest. If it is zero, revoked certs will only be returned from a given temporal shard (and we assume that no certs have been assigned to any explicit shard yet). After we start sending the ShardIdx field, and also start writing entries to the revokedCertificates table, this will result in CRL sizes doubling for several months until we retire the temporal sharding code, since most revoked certificates will be included in one shard based on their entry in revokedCertificates, and a different shard based on their issuance time. --- sa/saro.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sa/saro.go b/sa/saro.go index 6860cb9657a..0046878d924 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1050,18 +1050,26 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden }) } -// GetRevokedCerts gets a request specifying an issuer and a period of time, -// and writes to the output stream the set of all certificates issued by that -// issuer which expire during that period of time and which have been revoked. +// GetRevokedCerts returns a stream of revoked certificates for a single CRL shard. +// +// If ShardIdx is zero, GetRevokedCerts calculates shard membership based +// solely on temporal sharding. +// +// If ShardIdx is nonzero, GetRevokedCerts calculates shard membership based +// on temporal sharding _and_ explicit sharding (that is, sharding based on +// the shardIdx field of the revokedCertificates table). +// // The starting timestamp is treated as inclusive (certs with exactly that // notAfter date are included), but the ending timestamp is exclusive (certs // with exactly that notAfter date are *not* included). func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { if req.ShardIdx != 0 { - return ssa.getRevokedCertsFromRevokedCertificatesTable(req, stream) - } else { - return ssa.getRevokedCertsFromCertificateStatusTable(req, stream) + err := ssa.getRevokedCertsFromRevokedCertificatesTable(req, stream) + if err != nil { + return err + } } + return ssa.getRevokedCertsFromCertificateStatusTable(req, stream) } // getRevokedCertsFromRevokedCertificatesTable uses the new revokedCertificates From ea93022fb4b1f12b4c4d64763776e7288ddf259b Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Tue, 7 Jan 2025 12:00:16 -0800 Subject: [PATCH 2/6] update comment --- sa/saro.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sa/saro.go b/sa/saro.go index 0046878d924..7a5a68d3382 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1057,7 +1057,9 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden // // If ShardIdx is nonzero, GetRevokedCerts calculates shard membership based // on temporal sharding _and_ explicit sharding (that is, sharding based on -// the shardIdx field of the revokedCertificates table). +// the shardIdx field of the revokedCertificates table). Most revoked certificates +// will be present in two shards: one based on explicit sharding and one based +// on temporal sharding (a few will have the same shard for both). // // The starting timestamp is treated as inclusive (certs with exactly that // notAfter date are included), but the ending timestamp is exclusive (certs From 4a079846d4fb3f0e86dbdd8079638b5b33b2b9bc Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Thu, 9 Jan 2025 09:58:21 -0800 Subject: [PATCH 3/6] Make tests work --- sa/sa_test.go | 144 ++++++++++++++++++++++++++++++-------------------- sa/saro.go | 75 +++++++++++++++++++++----- 2 files changed, 149 insertions(+), 70 deletions(-) diff --git a/sa/sa_test.go b/sa/sa_test.go index 1e281a95f21..f9f6c7a9142 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -22,6 +22,7 @@ import ( "os" "reflect" "slices" + "sort" "strings" "sync" "testing" @@ -3354,39 +3355,54 @@ func TestGetRevokedCerts(t *testing.T) { } func TestGetRevokedCertsByShard(t *testing.T) { - sa, _, cleanUp := initSA(t) + sa, fc, cleanUp := initSA(t) defer cleanUp() - // Add a cert to the DB to test with. We use AddPrecertificate because it sets - // up the certificateStatus row we need. This particular cert has a notAfter - // date of Mar 6 2023, and we lie about its IssuerNameID to make things easy. reg := createWorkingRegistration(t, sa) - eeCert, err := core.LoadCert("../test/hierarchy/ee-e1.cert.pem") - test.AssertNotError(t, err, "failed to load test cert") - _, err = sa.AddSerial(ctx, &sapb.AddSerialRequest{ - RegID: reg.Id, - Serial: core.SerialToString(eeCert.SerialNumber), - Created: timestamppb.New(eeCert.NotBefore), - Expires: timestamppb.New(eeCert.NotAfter), - }) - test.AssertNotError(t, err, "failed to add test serial") - _, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{ - Der: eeCert.Raw, - RegID: reg.Id, - Issued: timestamppb.New(eeCert.NotBefore), - IssuerNameID: 1, - }) - test.AssertNotError(t, err, "failed to add test cert") + + // Add two certs to the DB to test with. We use AddPrecertificate because it sets + // up the certificateStatus row we need. These certs have a notAfter + // date of Mar 7 2023, and we lie about their IssuerNameID to make things easy. + fc.Set(mustTime("2023-03-01 00:00")) + + // Create a certificate and add it to the tables we need. + makeCert := func() *x509.Certificate { + _, cert := test.ThrowAwayCert(t, fc) + _, err := sa.AddSerial(ctx, &sapb.AddSerialRequest{ + RegID: reg.Id, + Serial: core.SerialToString(cert.SerialNumber), + Created: timestamppb.New(cert.NotBefore), + Expires: timestamppb.New(cert.NotAfter), + }) + test.AssertNotError(t, err, "failed to add test serial") + _, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{ + Der: cert.Raw, + RegID: reg.Id, + Issued: timestamppb.New(cert.NotBefore), + IssuerNameID: 1, + }) + test.AssertNotError(t, err, "failed to add test cert") + return cert + } + + eeCert1 := makeCert() + eeCert2 := makeCert() + t.Logf("eeCert1: %x", eeCert1.SerialNumber) + t.Logf("eeCert2: %x", eeCert2.SerialNumber) // Check that it worked. status, err := sa.GetCertificateStatus( - ctx, &sapb.Serial{Serial: core.SerialToString(eeCert.SerialNumber)}) + ctx, &sapb.Serial{Serial: core.SerialToString(eeCert1.SerialNumber)}) + test.AssertNotError(t, err, "GetCertificateStatus failed") + test.AssertEquals(t, core.OCSPStatus(status.Status), core.OCSPStatusGood) + status, err = sa.GetCertificateStatus( + ctx, &sapb.Serial{Serial: core.SerialToString(eeCert2.SerialNumber)}) test.AssertNotError(t, err, "GetCertificateStatus failed") test.AssertEquals(t, core.OCSPStatus(status.Status), core.OCSPStatusGood) - // Here's a little helper func we'll use to call GetRevokedCerts and count - // how many results it returned. - countRevokedCerts := func(req *sapb.GetRevokedCertsRequest) (int, error) { + // Here's a little helper func we'll use to call GetRevokedCerts and return + // a sorted list of serials. + getRevokedCerts := func(req *sapb.GetRevokedCertsRequest) []string { stream := make(chan *corepb.CRLEntry) mockServerStream := &fakeServerStream[corepb.CRLEntry]{output: stream} var err error @@ -3394,78 +3410,90 @@ func TestGetRevokedCertsByShard(t *testing.T) { err = sa.GetRevokedCerts(req, mockServerStream) close(stream) }() - entriesReceived := 0 - for range stream { - entriesReceived++ + var serials []string + for e := range stream { + serials = append(serials, e.Serial) } - return entriesReceived, err + if err != nil { + t.Fatalf("getting revoked certs: %s", err) + } + sort.Strings(serials) + return serials } - // The basic request covers a time range and shard that should include this certificate. + // The basic request covers a time range and shard that should include both certificates. basicRequest := &sapb.GetRevokedCertsRequest{ IssuerNameID: 1, - ShardIdx: 9, + ShardIdx: 97, ExpiresAfter: mustTimestamp("2023-03-01 00:00"), + ExpiresBefore: mustTimestamp("2023-04-01 00:00"), RevokedBefore: mustTimestamp("2023-04-01 00:00"), } + t.Logf("expires : %s, basicRequest: %+v", eeCert1.NotAfter, basicRequest) + // Nothing's been revoked yet. Count should be zero. - count, err := countRevokedCerts(basicRequest) - test.AssertNotError(t, err, "zero rows shouldn't result in error") - test.AssertEquals(t, count, 0) + serials := getRevokedCerts(basicRequest) + if len(serials) > 0 { + t.Errorf("before revoking, GetRevokedCerts(%+v) = %s, want []", basicRequest, serials) + } - // Revoke the certificate, providing the ShardIdx so it gets written into - // both the certificateStatus and revokedCertificates tables. - _, err = sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{ - IssuerID: 1, - Serial: core.SerialToString(eeCert.SerialNumber), - Date: mustTimestamp("2023-01-01 00:00"), - Reason: 1, - Response: []byte{1, 2, 3}, - ShardIdx: 9, - }) - test.AssertNotError(t, err, "failed to revoke test cert") + revoke := func(serial *big.Int, shardIdx int64) { + _, err = sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{ + IssuerID: 1, + Serial: core.SerialToString(serial), + Date: mustTimestamp("2023-03-04 00:00"), + Reason: 1, + Response: []byte{1, 2, 3}, + ShardIdx: shardIdx, + }) + test.AssertNotError(t, err, "failed to revoke test cert") + } + + // First certificate: revoke without ShardIdx + revoke(eeCert1.SerialNumber, 0) + // Second certificate: revoke with ShardIdx = 97. + revoke(eeCert2.SerialNumber, 97) // Check that it worked in the most basic way. c, err := sa.dbMap.SelectNullInt( - ctx, "SELECT count(*) FROM revokedCertificates") + ctx, "SELECT count(*) FROM revokedCertificates where shardIdx = 97;") test.AssertNotError(t, err, "SELECT from revokedCertificates failed") test.Assert(t, c.Valid, "SELECT from revokedCertificates got no result") test.AssertEquals(t, c.Int64, int64(1)) - // Asking for revoked certs now should return one result. - count, err = countRevokedCerts(basicRequest) - test.AssertNotError(t, err, "normal usage shouldn't result in error") - test.AssertEquals(t, count, 1) + // Asking for revoked certs now should return two results. + serials = getRevokedCerts(basicRequest) + if len(serials) != 2 { + t.Errorf("GetRevokedCerts(%+v) = %d, want %d", basicRequest, len(serials), 2) + } // Asking for revoked certs from a different issuer should return zero results. - count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{ + count := len(getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: 5678, ShardIdx: basicRequest.ShardIdx, ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: basicRequest.RevokedBefore, - }) - test.AssertNotError(t, err, "zero rows shouldn't result in error") + })) test.AssertEquals(t, count, 0) // Asking for revoked certs from a different shard should return zero results. - count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{ + count = len(getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: basicRequest.IssuerNameID, ShardIdx: 8, ExpiresAfter: basicRequest.ExpiresAfter, RevokedBefore: basicRequest.RevokedBefore, - }) - test.AssertNotError(t, err, "zero rows shouldn't result in error") + })) test.AssertEquals(t, count, 0) // Asking for revoked certs with an old RevokedBefore should return no results. - count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{ + count = len(getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: basicRequest.IssuerNameID, ShardIdx: basicRequest.ShardIdx, ExpiresAfter: basicRequest.ExpiresAfter, RevokedBefore: mustTimestamp("2020-03-01 00:00"), - }) - test.AssertNotError(t, err, "zero rows shouldn't result in error") + })) test.AssertEquals(t, count, 0) } diff --git a/sa/saro.go b/sa/saro.go index 7a5a68d3382..1261ecb757d 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1065,19 +1065,53 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden // notAfter date are included), but the ending timestamp is exclusive (certs // with exactly that notAfter date are *not* included). func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { + // The two different methods of finding certs may return the same serial. + // We'd like to deduplicate them. + seen := make(map[string]bool) + send := func(e *corepb.CRLEntry) error { + if !seen[e.Serial] { + seen[e.Serial] = true + return stream.Send(e) + } + return nil + } + + // getterFunc is one of getRevokedCertsFromRevokedCertificatesTable or getRevokedCertsFromCertificateStatusTable. + // We assume that getterfunc closes the channel once it is done. + type getterFunc func(context.Context, *sapb.GetRevokedCertsRequest, chan<- *corepb.CRLEntry) error + + // Collect a bunch of CRLEntries, deduplicate them, and send them out on the stream. + sendAll := func(f getterFunc) error { + ch := make(chan *corepb.CRLEntry) + var err error + go func() { + err = f(stream.Context(), req, ch) + }() + for crlEntry := range ch { + err2 := send(crlEntry) + if err2 != nil { + return err2 + } + } + return err + } + + // If a shard index was specified, get entries by that explicit shard in addition + // to getting by temporal shard. if req.ShardIdx != 0 { - err := ssa.getRevokedCertsFromRevokedCertificatesTable(req, stream) + err := sendAll(ssa.getRevokedCertsFromRevokedCertificatesTable) if err != nil { return err } } - return ssa.getRevokedCertsFromCertificateStatusTable(req, stream) + + return sendAll(ssa.getRevokedCertsFromCertificateStatusTable) } // getRevokedCertsFromRevokedCertificatesTable uses the new revokedCertificates // table to implement GetRevokedCerts. It must only be called when the request // contains a non-zero ShardIdx. -func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { +func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(ctx context.Context, req *sapb.GetRevokedCertsRequest, stream chan<- *corepb.CRLEntry) error { if req.ShardIdx == 0 { return errors.New("can't select shard 0 from revokedCertificates table") } @@ -1101,12 +1135,12 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(re return fmt.Errorf("initializing db map: %w", err) } - rows, err := selector.QueryContext(stream.Context(), clauses, params...) + rows, err := selector.QueryContext(ctx, clauses, params...) if err != nil { return fmt.Errorf("reading db: %w", err) } - return rows.ForEach(func(row *revokedCertModel) error { + err = rows.ForEach(func(row *revokedCertModel) error { // Double-check that the cert wasn't revoked between the time at which we're // constructing this snapshot CRL and right now. If the cert was revoked // at-or-after the "atTime", we'll just include it in the next generation @@ -1115,17 +1149,25 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(re return nil } - return stream.Send(&corepb.CRLEntry{ + stream <- &corepb.CRLEntry{ Serial: row.Serial, Reason: int32(row.RevokedReason), RevokedAt: timestamppb.New(row.RevokedDate), - }) + } + + return nil }) + if err != nil { + return err + } + + close(stream) + return nil } // getRevokedCertsFromCertificateStatusTable uses the old certificateStatus // table to implement GetRevokedCerts. -func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { +func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(ctx context.Context, req *sapb.GetRevokedCertsRequest, stream chan<- *corepb.CRLEntry) error { atTime := req.RevokedBefore.AsTime() clauses := ` @@ -1145,12 +1187,13 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req return fmt.Errorf("initializing db map: %w", err) } - rows, err := selector.QueryContext(stream.Context(), clauses, params...) + rows, err := selector.QueryContext(ctx, clauses, params...) if err != nil { return fmt.Errorf("reading db: %w", err) } - return rows.ForEach(func(row *crlEntryModel) error { + fmt.Printf("querying for notAfter >= %s, notAfter < %s\n", req.ExpiresAfter.AsTime(), req.ExpiresBefore.AsTime()) + err = rows.ForEach(func(row *crlEntryModel) error { // Double-check that the cert wasn't revoked between the time at which we're // constructing this snapshot CRL and right now. If the cert was revoked // at-or-after the "atTime", we'll just include it in the next generation @@ -1159,12 +1202,20 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req return nil } - return stream.Send(&corepb.CRLEntry{ + stream <- &corepb.CRLEntry{ Serial: row.Serial, Reason: int32(row.RevokedReason), RevokedAt: timestamppb.New(row.RevokedDate), - }) + } + + return nil }) + if err != nil { + return nil + } + + close(stream) + return nil } // GetMaxExpiration returns the timestamp of the farthest-future notAfter date From 4708117f3c7f72f7dc905d34fa60f09134b2522f Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Thu, 9 Jan 2025 15:19:40 -0800 Subject: [PATCH 4/6] Update tests --- sa/sa_test.go | 147 +++++++++++++++++++++++++++++++++----------------- sa/saro.go | 96 ++++++++++++--------------------- test/certs.go | 6 ++- 3 files changed, 136 insertions(+), 113 deletions(-) diff --git a/sa/sa_test.go b/sa/sa_test.go index f9f6c7a9142..45afe52f711 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -3354,51 +3354,70 @@ func TestGetRevokedCerts(t *testing.T) { test.AssertEquals(t, count, 0) } -func TestGetRevokedCertsByShard(t *testing.T) { +func TestGetRevokedCertsWithShard(t *testing.T) { sa, fc, cleanUp := initSA(t) defer cleanUp() reg := createWorkingRegistration(t, sa) - // Add two certs to the DB to test with. We use AddPrecertificate because it sets - // up the certificateStatus row we need. These certs have a notAfter - // date of Mar 7 2023, and we lie about their IssuerNameID to make things easy. fc.Set(mustTime("2023-03-01 00:00")) + // Make up an IssuerNameID to make things simpler. + issuerNameID := int64(834) + // Create a certificate and add it to the tables we need. makeCert := func() *x509.Certificate { _, cert := test.ThrowAwayCert(t, fc) + // We depend on specifics of the lifetime set by test.ThrowAwayCert, so verify. + lifetime := cert.NotAfter.Sub(cert.NotBefore) + if lifetime != 6*24*time.Hour { + t.Fatalf("cert lifetime: got %s, want 6 days", lifetime) + } _, err := sa.AddSerial(ctx, &sapb.AddSerialRequest{ RegID: reg.Id, Serial: core.SerialToString(cert.SerialNumber), Created: timestamppb.New(cert.NotBefore), Expires: timestamppb.New(cert.NotAfter), }) - test.AssertNotError(t, err, "failed to add test serial") + if err != nil { + t.Fatalf("adding serial: %s", err) + } _, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{ Der: cert.Raw, RegID: reg.Id, Issued: timestamppb.New(cert.NotBefore), - IssuerNameID: 1, + IssuerNameID: int64(issuerNameID), }) - test.AssertNotError(t, err, "failed to add test cert") + if err != nil { + t.Fatalf("adding cert: %s", err) + } return cert } + // Two certs issued at the same time, with the same expiration. + // eeCert1 will be revoked without an explicit ShardIdx. + // eeCert2 will be revoked _with_ an explicit ShardIdx. eeCert1 := makeCert() eeCert2 := makeCert() - t.Logf("eeCert1: %x", eeCert1.SerialNumber) - t.Logf("eeCert2: %x", eeCert2.SerialNumber) + + // eeCert3 is issued two days after the others and will be revoked + // with the same explicit ShardIdx. It will show up in a different + // temporal shard than eeCert1 and eeCert2, because we are querying + // as if the shard width for CRLs is one day. + fc.Add(2 * 24 * time.Hour) + eeCert3 := makeCert() // Check that it worked. - status, err := sa.GetCertificateStatus( - ctx, &sapb.Serial{Serial: core.SerialToString(eeCert1.SerialNumber)}) - test.AssertNotError(t, err, "GetCertificateStatus failed") - test.AssertEquals(t, core.OCSPStatus(status.Status), core.OCSPStatusGood) - status, err = sa.GetCertificateStatus( - ctx, &sapb.Serial{Serial: core.SerialToString(eeCert2.SerialNumber)}) - test.AssertNotError(t, err, "GetCertificateStatus failed") - test.AssertEquals(t, core.OCSPStatus(status.Status), core.OCSPStatusGood) + for _, c := range []*x509.Certificate{eeCert1, eeCert2, eeCert3} { + status, err := sa.GetCertificateStatus( + ctx, &sapb.Serial{Serial: core.SerialToString(c.SerialNumber)}) + if err != nil { + t.Fatalf("GetCertificateStatus: %s", err) + } + if status.Status != string(core.OCSPStatusGood) { + t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood) + } + } // Here's a little helper func we'll use to call GetRevokedCerts and return // a sorted list of serials. @@ -3415,86 +3434,114 @@ func TestGetRevokedCertsByShard(t *testing.T) { serials = append(serials, e.Serial) } if err != nil { - t.Fatalf("getting revoked certs: %s", err) + t.Fatalf("GetRevokedCerts(%+v): %s", req, err) } sort.Strings(serials) return serials } // The basic request covers a time range and shard that should include both certificates. + // The ExpiresBefore field is set based on the 6-day lifetime of certs from test.ThrowAwayCert basicRequest := &sapb.GetRevokedCertsRequest{ - IssuerNameID: 1, + IssuerNameID: issuerNameID, ShardIdx: 97, ExpiresAfter: mustTimestamp("2023-03-01 00:00"), - ExpiresBefore: mustTimestamp("2023-04-01 00:00"), + ExpiresBefore: mustTimestamp("2023-03-08 00:00"), RevokedBefore: mustTimestamp("2023-04-01 00:00"), } - t.Logf("expires : %s, basicRequest: %+v", eeCert1.NotAfter, basicRequest) - // Nothing's been revoked yet. Count should be zero. serials := getRevokedCerts(basicRequest) if len(serials) > 0 { - t.Errorf("before revoking, GetRevokedCerts(%+v) = %s, want []", basicRequest, serials) + t.Errorf("GetRevokedCerts (before revocations) = %s, want []", serials) } - revoke := func(serial *big.Int, shardIdx int64) { - _, err = sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{ - IssuerID: 1, - Serial: core.SerialToString(serial), + revoke := func(cert *x509.Certificate, shardIdx int64) { + t.Logf("revoking %x with shardIdx %d", cert.SerialNumber, shardIdx) + _, err := sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{ + IssuerID: issuerNameID, + Serial: core.SerialToString(cert.SerialNumber), Date: mustTimestamp("2023-03-04 00:00"), Reason: 1, Response: []byte{1, 2, 3}, ShardIdx: shardIdx, }) - test.AssertNotError(t, err, "failed to revoke test cert") + if err != nil { + t.Fatalf("sa.RevokeCertificate %s", err) + } } // First certificate: revoke without ShardIdx - revoke(eeCert1.SerialNumber, 0) + revoke(eeCert1, 0) // Second certificate: revoke with ShardIdx = 97. - revoke(eeCert2.SerialNumber, 97) + revoke(eeCert2, 97) + // Third certificate: revoke with ShardIdx = 97. + // But note that the temporal shard is different from the other two. + revoke(eeCert3, 97) // Check that it worked in the most basic way. - c, err := sa.dbMap.SelectNullInt( - ctx, "SELECT count(*) FROM revokedCertificates where shardIdx = 97;") - test.AssertNotError(t, err, "SELECT from revokedCertificates failed") - test.Assert(t, c.Valid, "SELECT from revokedCertificates got no result") - test.AssertEquals(t, c.Int64, int64(1)) + query := "SELECT count(*) FROM revokedCertificates where shardIdx = 97;" + c, err := sa.dbMap.SelectNullInt(ctx, query) + if err != nil { + t.Fatalf("query %q: %s", query, err) + } + if !c.Valid { + t.Fatalf("query %q: no results", query) + } + if c.Int64 != 2 { + t.Fatalf("query %q: got %d results, want %d", query, c.Int64, 2) + } - // Asking for revoked certs now should return two results. - serials = getRevokedCerts(basicRequest) - if len(serials) != 2 { - t.Errorf("GetRevokedCerts(%+v) = %d, want %d", basicRequest, len(serials), 2) + expectSerials := func(message string, serials []string, certs ...*x509.Certificate) { + t.Helper() + var expectedSerials []string + for _, c := range certs { + expectedSerials = append(expectedSerials, core.SerialToString(c.SerialNumber)) + } + sort.Strings(expectedSerials) + if !reflect.DeepEqual(serials, expectedSerials) { + t.Errorf("%s: want %s, got %s", message, expectedSerials, serials) + } } + serials = getRevokedCerts(basicRequest) + expectSerials("GetRevokedCerts (after revocations)", serials, eeCert1, eeCert2, eeCert3) - // Asking for revoked certs from a different issuer should return zero results. - count := len(getRevokedCerts(&sapb.GetRevokedCertsRequest{ + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: 5678, ShardIdx: basicRequest.ShardIdx, ExpiresAfter: basicRequest.ExpiresAfter, ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: basicRequest.RevokedBefore, - })) - test.AssertEquals(t, count, 0) + }) + expectSerials("GetRevokedCerts with nonexistent issuer", serials) - // Asking for revoked certs from a different shard should return zero results. - count = len(getRevokedCerts(&sapb.GetRevokedCertsRequest{ + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ + IssuerNameID: basicRequest.IssuerNameID, + ShardIdx: 0, + ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, + RevokedBefore: basicRequest.RevokedBefore, + }) + expectSerials("GetRevokedCerts with no shardIdx specified (temporal sharding only)", serials, eeCert1, eeCert2) + + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: basicRequest.IssuerNameID, ShardIdx: 8, ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: basicRequest.RevokedBefore, - })) - test.AssertEquals(t, count, 0) + }) + expectSerials("GetRevokedCerts for explicit shard with no revocations (temporal sharding only)", serials, eeCert1, eeCert2) // Asking for revoked certs with an old RevokedBefore should return no results. - count = len(getRevokedCerts(&sapb.GetRevokedCertsRequest{ + serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{ IssuerNameID: basicRequest.IssuerNameID, ShardIdx: basicRequest.ShardIdx, ExpiresAfter: basicRequest.ExpiresAfter, + ExpiresBefore: basicRequest.ExpiresBefore, RevokedBefore: mustTimestamp("2020-03-01 00:00"), - })) - test.AssertEquals(t, count, 0) + }) + expectSerials("GetRevokedCerts for old RevokedBefore", serials) } func TestGetMaxExpiration(t *testing.T) { diff --git a/sa/saro.go b/sa/saro.go index 1261ecb757d..ad2f83fd8cb 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1050,6 +1050,25 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden }) } +// crlDeduper implements grpc.ServerStreamingServer[corepb.CRLEntry]. +// +// It passes CRLEntry's to the inner ServerStreamingServer, with the +// exception that it omits any CRLEntry with the same serial as a previously +// sent one. +type crlDeduper struct { + grpc.ServerStreamingServer[corepb.CRLEntry] + + seen map[string]bool +} + +func (cd crlDeduper) Send(crl *corepb.CRLEntry) error { + if !cd.seen[crl.Serial] { + cd.seen[crl.Serial] = true + return cd.ServerStreamingServer.Send(crl) + } + return nil +} + // GetRevokedCerts returns a stream of revoked certificates for a single CRL shard. // // If ShardIdx is zero, GetRevokedCerts calculates shard membership based @@ -1065,53 +1084,23 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden // notAfter date are included), but the ending timestamp is exclusive (certs // with exactly that notAfter date are *not* included). func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { - // The two different methods of finding certs may return the same serial. - // We'd like to deduplicate them. - seen := make(map[string]bool) - send := func(e *corepb.CRLEntry) error { - if !seen[e.Serial] { - seen[e.Serial] = true - return stream.Send(e) - } - return nil - } - - // getterFunc is one of getRevokedCertsFromRevokedCertificatesTable or getRevokedCertsFromCertificateStatusTable. - // We assume that getterfunc closes the channel once it is done. - type getterFunc func(context.Context, *sapb.GetRevokedCertsRequest, chan<- *corepb.CRLEntry) error - - // Collect a bunch of CRLEntries, deduplicate them, and send them out on the stream. - sendAll := func(f getterFunc) error { - ch := make(chan *corepb.CRLEntry) - var err error - go func() { - err = f(stream.Context(), req, ch) - }() - for crlEntry := range ch { - err2 := send(crlEntry) - if err2 != nil { - return err2 - } - } - return err + crlDeduper := crlDeduper{ + ServerStreamingServer: stream, + seen: make(map[string]bool), } - - // If a shard index was specified, get entries by that explicit shard in addition - // to getting by temporal shard. if req.ShardIdx != 0 { - err := sendAll(ssa.getRevokedCertsFromRevokedCertificatesTable) + err := ssa.getRevokedCertsFromRevokedCertificatesTable(req, crlDeduper) if err != nil { return err } } - - return sendAll(ssa.getRevokedCertsFromCertificateStatusTable) + return ssa.getRevokedCertsFromCertificateStatusTable(req, crlDeduper) } // getRevokedCertsFromRevokedCertificatesTable uses the new revokedCertificates // table to implement GetRevokedCerts. It must only be called when the request // contains a non-zero ShardIdx. -func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(ctx context.Context, req *sapb.GetRevokedCertsRequest, stream chan<- *corepb.CRLEntry) error { +func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { if req.ShardIdx == 0 { return errors.New("can't select shard 0 from revokedCertificates table") } @@ -1135,12 +1124,12 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(ct return fmt.Errorf("initializing db map: %w", err) } - rows, err := selector.QueryContext(ctx, clauses, params...) + rows, err := selector.QueryContext(stream.Context(), clauses, params...) if err != nil { return fmt.Errorf("reading db: %w", err) } - err = rows.ForEach(func(row *revokedCertModel) error { + return rows.ForEach(func(row *revokedCertModel) error { // Double-check that the cert wasn't revoked between the time at which we're // constructing this snapshot CRL and right now. If the cert was revoked // at-or-after the "atTime", we'll just include it in the next generation @@ -1149,25 +1138,17 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromRevokedCertificatesTable(ct return nil } - stream <- &corepb.CRLEntry{ + return stream.Send(&corepb.CRLEntry{ Serial: row.Serial, Reason: int32(row.RevokedReason), RevokedAt: timestamppb.New(row.RevokedDate), - } - - return nil + }) }) - if err != nil { - return err - } - - close(stream) - return nil } // getRevokedCertsFromCertificateStatusTable uses the old certificateStatus // table to implement GetRevokedCerts. -func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(ctx context.Context, req *sapb.GetRevokedCertsRequest, stream chan<- *corepb.CRLEntry) error { +func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { atTime := req.RevokedBefore.AsTime() clauses := ` @@ -1187,13 +1168,12 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(ctx return fmt.Errorf("initializing db map: %w", err) } - rows, err := selector.QueryContext(ctx, clauses, params...) + rows, err := selector.QueryContext(stream.Context(), clauses, params...) if err != nil { return fmt.Errorf("reading db: %w", err) } - fmt.Printf("querying for notAfter >= %s, notAfter < %s\n", req.ExpiresAfter.AsTime(), req.ExpiresBefore.AsTime()) - err = rows.ForEach(func(row *crlEntryModel) error { + return rows.ForEach(func(row *crlEntryModel) error { // Double-check that the cert wasn't revoked between the time at which we're // constructing this snapshot CRL and right now. If the cert was revoked // at-or-after the "atTime", we'll just include it in the next generation @@ -1202,20 +1182,12 @@ func (ssa *SQLStorageAuthorityRO) getRevokedCertsFromCertificateStatusTable(ctx return nil } - stream <- &corepb.CRLEntry{ + return stream.Send(&corepb.CRLEntry{ Serial: row.Serial, Reason: int32(row.RevokedReason), RevokedAt: timestamppb.New(row.RevokedDate), - } - - return nil + }) }) - if err != nil { - return nil - } - - close(stream) - return nil } // GetMaxExpiration returns the timestamp of the farthest-future notAfter date diff --git a/test/certs.go b/test/certs.go index 6dd1ce5a239..2bcd7c5796a 100644 --- a/test/certs.go +++ b/test/certs.go @@ -64,14 +64,18 @@ func LoadSigner(filename string) (crypto.Signer, error) { // ThrowAwayCert is a small test helper function that creates a self-signed // certificate with one SAN. It returns the parsed certificate and its serial // in string form for convenience. +// // The certificate returned from this function is the bare minimum needed for // most tests and isn't a robust example of a complete end entity certificate. +// +// Returned certificates have NotBefore == clk.Now(), and NotBefore 6 days in the +// future. func ThrowAwayCert(t *testing.T, clk clock.Clock) (string, *x509.Certificate) { var nameBytes [3]byte _, _ = rand.Read(nameBytes[:]) name := fmt.Sprintf("%s.example.com", hex.EncodeToString(nameBytes[:])) - var serialBytes [16]byte + var serialBytes [18]byte _, _ = rand.Read(serialBytes[:]) serial := big.NewInt(0).SetBytes(serialBytes[:]) From 824e6e295e894da2a7a28b7377b6fdbbaa8330f1 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Thu, 9 Jan 2025 15:57:30 -0800 Subject: [PATCH 5/6] More test fixes --- sa/sa_test.go | 41 +++++++++++++---------------------------- sa/saro.go | 3 +++ 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/sa/sa_test.go b/sa/sa_test.go index 45afe52f711..65bbd376740 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -3391,6 +3391,13 @@ func TestGetRevokedCertsWithShard(t *testing.T) { if err != nil { t.Fatalf("adding cert: %s", err) } + status, err := sa.GetCertificateStatus(ctx, &sapb.Serial{Serial: core.SerialToString(cert.SerialNumber)}) + if err != nil { + t.Fatalf("GetCertificateStatus: %s", err) + } + if status.Status != string(core.OCSPStatusGood) { + t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood) + } return cert } @@ -3407,18 +3414,6 @@ func TestGetRevokedCertsWithShard(t *testing.T) { fc.Add(2 * 24 * time.Hour) eeCert3 := makeCert() - // Check that it worked. - for _, c := range []*x509.Certificate{eeCert1, eeCert2, eeCert3} { - status, err := sa.GetCertificateStatus( - ctx, &sapb.Serial{Serial: core.SerialToString(c.SerialNumber)}) - if err != nil { - t.Fatalf("GetCertificateStatus: %s", err) - } - if status.Status != string(core.OCSPStatusGood) { - t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood) - } - } - // Here's a little helper func we'll use to call GetRevokedCerts and return // a sorted list of serials. getRevokedCerts := func(req *sapb.GetRevokedCertsRequest) []string { @@ -3436,11 +3431,11 @@ func TestGetRevokedCertsWithShard(t *testing.T) { if err != nil { t.Fatalf("GetRevokedCerts(%+v): %s", req, err) } - sort.Strings(serials) return serials } - // The basic request covers a time range and shard that should include both certificates. + // The basic request covers a time range that includes eeCert1's and eeCert2's NotAfter, + // but excludes eeCert3's NotAfter. // The ExpiresBefore field is set based on the 6-day lifetime of certs from test.ThrowAwayCert basicRequest := &sapb.GetRevokedCertsRequest{ IssuerNameID: issuerNameID, @@ -3450,7 +3445,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) { RevokedBefore: mustTimestamp("2023-04-01 00:00"), } - // Nothing's been revoked yet. Count should be zero. + // Nothing's been revoked yet. Should get no results. serials := getRevokedCerts(basicRequest) if len(serials) > 0 { t.Errorf("GetRevokedCerts (before revocations) = %s, want []", serials) @@ -3479,19 +3474,8 @@ func TestGetRevokedCertsWithShard(t *testing.T) { // But note that the temporal shard is different from the other two. revoke(eeCert3, 97) - // Check that it worked in the most basic way. - query := "SELECT count(*) FROM revokedCertificates where shardIdx = 97;" - c, err := sa.dbMap.SelectNullInt(ctx, query) - if err != nil { - t.Fatalf("query %q: %s", query, err) - } - if !c.Valid { - t.Fatalf("query %q: no results", query) - } - if c.Int64 != 2 { - t.Fatalf("query %q: got %d results, want %d", query, c.Int64, 2) - } - + // expectSerials registers an error if the provided serials don't match the serials + // of the provded certs (after sorting). expectSerials := func(message string, serials []string, certs ...*x509.Certificate) { t.Helper() var expectedSerials []string @@ -3499,6 +3483,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) { expectedSerials = append(expectedSerials, core.SerialToString(c.SerialNumber)) } sort.Strings(expectedSerials) + sort.Strings(serials) if !reflect.DeepEqual(serials, expectedSerials) { t.Errorf("%s: want %s, got %s", message, expectedSerials, serials) } diff --git a/sa/saro.go b/sa/saro.go index ad2f83fd8cb..7b53888d278 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1084,6 +1084,9 @@ func (cd crlDeduper) Send(crl *corepb.CRLEntry) error { // notAfter date are included), but the ending timestamp is exclusive (certs // with exactly that notAfter date are *not* included). func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { + if core.IsAnyNilOrZero(req.IssuerNameID, req.ExpiresAfter, req.ExpiresBefore, req.RevokedBefore) { + return errors.New("incomplete request for GetRevokedCerts") + } crlDeduper := crlDeduper{ ServerStreamingServer: stream, seen: make(map[string]bool), From aed3458d17031de53e56c061a167bd7aabb885a3 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Thu, 9 Jan 2025 16:03:50 -0800 Subject: [PATCH 6/6] lints --- sa/sa_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sa/sa_test.go b/sa/sa_test.go index 65bbd376740..bdbb4a3a2c5 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -3386,7 +3386,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) { Der: cert.Raw, RegID: reg.Id, Issued: timestamppb.New(cert.NotBefore), - IssuerNameID: int64(issuerNameID), + IssuerNameID: issuerNameID, }) if err != nil { t.Fatalf("adding cert: %s", err) @@ -3475,7 +3475,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) { revoke(eeCert3, 97) // expectSerials registers an error if the provided serials don't match the serials - // of the provded certs (after sorting). + // of the provided certs (after sorting). expectSerials := func(message string, serials []string, certs ...*x509.Certificate) { t.Helper() var expectedSerials []string