Skip to content

Commit

Permalink
[v16] Allow athenaevents to bypass SNS (#42463)
Browse files Browse the repository at this point in the history
* Allow athenaevents to bypass SNS

* Double up the athena tests that hit AWS
  • Loading branch information
espadolini authored Jun 5, 2024
1 parent 4dbb5ce commit fb46701
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 110 deletions.
5 changes: 2 additions & 3 deletions examples/dynamoathenamigration/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,14 @@ func newMigrateTask(ctx context.Context, cfg Config, awsCfg aws.Config) (*task,
dynamoClient: dynamodb.NewFromConfig(awsCfg),
s3Downloader: manager.NewDownloader(s3Client),
eventsEmitter: athena.NewPublisher(athena.PublisherConfig{
TopicARN: cfg.TopicARN,
SNSPublisher: sns.NewFromConfig(awsCfg, func(o *sns.Options) {
MessagePublisher: athena.SNSPublisherFunc(cfg.TopicARN, sns.NewFromConfig(awsCfg, func(o *sns.Options) {
o.Retryer = retry.NewStandard(func(so *retry.StandardOptions) {
so.MaxAttempts = 30
so.MaxBackoff = 1 * time.Minute
// Use bigger rate limit to handle default sdk throttling: https://github.com/aws/aws-sdk-go-v2/issues/1665
so.RateLimiter = ratelimit.NewTokenRateLimit(1000000)
})
}),
})),
Uploader: manager.NewUploader(s3Client),
PayloadBucket: cfg.LargePayloadBucket,
PayloadPrefix: cfg.LargePayloadPrefix,
Expand Down
13 changes: 11 additions & 2 deletions lib/events/athena/athena.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ const (
defaultBatchItems = 20000
// defaultBatchInterval defines default batch interval.
defaultBatchInterval = 1 * time.Minute

// topicARNBypass is a magic value for TopicARN that signifies that the
// Athena audit log should send messages directly to SQS instead of going
// through a SNS topic.
topicARNBypass = "bypass"
)

// Config structure represents Athena configuration.
Expand All @@ -62,7 +67,9 @@ type Config struct {

// Publisher settings.

// TopicARN where to emit events in SNS (required).
// TopicARN where to emit events in SNS (required). If TopicARN is "bypass"
// (i.e. [topicArnBypass]) then the events should be emitted directly to the
// SQS queue reachable at QueryURL.
TopicARN string
// LargeEventsS3 is location on S3 where temporary large events (>256KB)
// are stored before converting it to Parquet and moving to long term
Expand Down Expand Up @@ -106,7 +113,9 @@ type Config struct {

// Batcher settings.

// QueueURL is URL of SQS, which is set as subscriber to SNS topic (required).
// QueueURL is URL of SQS, which is set as subscriber to SNS topic if we're
// emitting to SNS, or used directly to send messages if we're bypassing SNS
// (required).
QueueURL string
// BatchMaxItems defines how many items can be stored in single Parquet
// batch (optional).
Expand Down
6 changes: 3 additions & 3 deletions lib/events/athena/athena_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func TestPublisherConsumer(t *testing.T) {
ID: uuid.NewString(),
Time: time.Now().UTC(),
Type: events.AppCreateEvent,
Code: strings.Repeat("d", 2*maxSNSMessageSize),
Code: strings.Repeat("d", 2*maxDirectMessageSize),
},
AppMetadata: apievents.AppMetadata{
AppName: "app-large",
Expand Down Expand Up @@ -418,8 +418,8 @@ func TestPublisherConsumer(t *testing.T) {
fq := newFakeQueue()
p := &publisher{
PublisherConfig: PublisherConfig{
SNSPublisher: fq,
Uploader: fS3,
MessagePublisher: fq,
Uploader: fS3,
},
}
cfg := validCollectCfgForTests(t)
Expand Down
56 changes: 28 additions & 28 deletions lib/events/athena/fakequeue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ import (
"sync"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sns"
snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types"
"github.com/aws/aws-sdk-go-v2/service/sqs"
sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/google/uuid"
)

Expand All @@ -42,39 +40,55 @@ type fakeQueue struct {
}

type fakeQueueMessage struct {
payload string
attributes map[string]snsTypes.MessageAttributeValue
payload string
s3Based bool
}

func newFakeQueue() *fakeQueue {
return &fakeQueue{}
}

func (f *fakeQueue) Publish(ctx context.Context, params *sns.PublishInput, optFns ...func(*sns.Options)) (*sns.PublishOutput, error) {
func (f *fakeQueue) Publish(ctx context.Context, base64Body string, s3Based bool) error {
f.mu.Lock()
defer f.mu.Unlock()
if len(f.publishErrors) > 0 {
err := f.publishErrors[0]
f.publishErrors = f.publishErrors[1:]
return nil, err
return err
}
f.msgs = append(f.msgs, fakeQueueMessage{
payload: *params.Message,
attributes: params.MessageAttributes,
payload: base64Body,
s3Based: s3Based,
})
return nil, nil
return nil
}

func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
msgs := f.dequeue()
if len(msgs) == 0 {
return &sqs.ReceiveMessageOutput{}, nil
}
out := make([]sqsTypes.Message, 0, 10)
out := make([]sqstypes.Message, 0, len(msgs))
for _, msg := range msgs {
out = append(out, sqsTypes.Message{
Body: aws.String(msg.payload),
MessageAttributes: snsToSqsAttributes(msg.attributes),
var messageAttributes map[string]sqstypes.MessageAttributeValue
if msg.s3Based {
messageAttributes = map[string]sqstypes.MessageAttributeValue{
payloadTypeAttr: {
DataType: aws.String("String"),
StringValue: aws.String(payloadTypeS3Based),
},
}
} else {
messageAttributes = map[string]sqstypes.MessageAttributeValue{
payloadTypeAttr: {
DataType: aws.String("String"),
StringValue: aws.String(payloadTypeRawProtoEvent),
},
}
}
out = append(out, sqstypes.Message{
Body: &msg.payload,
MessageAttributes: messageAttributes,
ReceiptHandle: aws.String(uuid.NewString()),
})
}
Expand All @@ -83,20 +97,6 @@ func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessa
}, nil
}

func snsToSqsAttributes(in map[string]snsTypes.MessageAttributeValue) map[string]sqsTypes.MessageAttributeValue {
if in == nil {
return nil
}
out := map[string]sqsTypes.MessageAttributeValue{}
for k, v := range in {
out[k] = sqsTypes.MessageAttributeValue{
DataType: v.DataType,
StringValue: v.StringValue,
}
}
return out
}

func (f *fakeQueue) dequeue() []fakeQueueMessage {
f.mu.Lock()
defer f.mu.Unlock()
Expand Down
55 changes: 51 additions & 4 deletions lib/events/athena/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,20 @@ import (
)

func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaSearchSessionEventsBySessionID(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaSearchSessionEventsBySessionID(t, bypassSNSTrue)
})
}

func testIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
auditLogger := &EventuallyConsistentAuditLogger{
Inner: ac.log,
// Additional 5s is used to compensate for uploading parquet on s3.
Expand All @@ -55,9 +66,20 @@ func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) {
}

func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaSessionEventsCRUD(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaSessionEventsCRUD(t, bypassSNSTrue)
})
}

func testIntegrationAthenaSessionEventsCRUD(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
auditLogger := &EventuallyConsistentAuditLogger{
Inner: ac.log,
// Additional 5s is used to compensate for uploading parquet on s3.
Expand All @@ -72,9 +94,20 @@ func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) {
}

func TestIntegrationAthenaEventPagination(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaEventPagination(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaEventPagination(t, bypassSNSTrue)
})
}

func testIntegrationAthenaEventPagination(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
auditLogger := &EventuallyConsistentAuditLogger{
Inner: ac.log,
// Additional 5s is used to compensate for uploading parquet on s3.
Expand All @@ -89,10 +122,24 @@ func TestIntegrationAthenaEventPagination(t *testing.T) {
}

func TestIntegrationAthenaLargeEvents(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaLargeEvents(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaLargeEvents(t, bypassSNSTrue)
})
}

func testIntegrationAthenaLargeEvents(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

ac := SetupAthenaContext(t, ctx, AthenaContextConfig{MaxBatchSize: 1})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{
MaxBatchSize: 1,
BypassSNS: bypassSNS,
})
in := &apievents.SessionStart{
Metadata: apievents.Metadata{
Index: 2,
Expand Down
Loading

0 comments on commit fb46701

Please sign in to comment.