[v16] Allow athenaevents to bypass SNS #42463

Merged (2 commits) on Jun 5, 2024
5 changes: 2 additions & 3 deletions examples/dynamoathenamigration/migration.go
@@ -165,15 +165,14 @@ func newMigrateTask(ctx context.Context, cfg Config, awsCfg aws.Config) (*task,
dynamoClient: dynamodb.NewFromConfig(awsCfg),
s3Downloader: manager.NewDownloader(s3Client),
eventsEmitter: athena.NewPublisher(athena.PublisherConfig{
TopicARN: cfg.TopicARN,
SNSPublisher: sns.NewFromConfig(awsCfg, func(o *sns.Options) {
MessagePublisher: athena.SNSPublisherFunc(cfg.TopicARN, sns.NewFromConfig(awsCfg, func(o *sns.Options) {
o.Retryer = retry.NewStandard(func(so *retry.StandardOptions) {
so.MaxAttempts = 30
so.MaxBackoff = 1 * time.Minute
// Use bigger rate limit to handle default sdk throttling: https://github.com/aws/aws-sdk-go-v2/issues/1665
so.RateLimiter = ratelimit.NewTokenRateLimit(1000000)
})
}),
})),
Uploader: manager.NewUploader(s3Client),
PayloadBucket: cfg.LargePayloadBucket,
PayloadPrefix: cfg.LargePayloadPrefix,
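The migration example no longer hands the publisher a raw SNS client: the client and the topic ARN are wrapped in athena.SNSPublisherFunc, so the publisher now depends on a narrower message-publishing contract rather than on SNS directly. A minimal sketch of what that contract appears to be, inferred from the fakeQueue.Publish signature further down in this diff (the interface name and comments here are assumptions, not the verbatim code in lib/events/athena):

import "context"

// MessagePublisher is assumed to be the abstraction that replaced the concrete
// SNS client; SNSPublisherFunc (seen in migration.go above) would adapt a topic
// ARN plus SNS client to this shape.
type MessagePublisher interface {
    // Publish sends one base64-encoded audit event payload; s3Based marks
    // payloads that were offloaded to S3 because they exceeded the size limit.
    Publish(ctx context.Context, base64Body string, s3Based bool) error
}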
13 changes: 11 additions & 2 deletions lib/events/athena/athena.go
@@ -52,6 +52,11 @@ const (
defaultBatchItems = 20000
// defaultBatchInterval defines default batch interval.
defaultBatchInterval = 1 * time.Minute

// topicARNBypass is a magic value for TopicARN that signifies that the
// Athena audit log should send messages directly to SQS instead of going
through an SNS topic.
topicARNBypass = "bypass"
)

// Config structure represents Athena configuration.
@@ -62,7 +67,9 @@ type Config struct {

// Publisher settings.

// TopicARN where to emit events in SNS (required).
// TopicARN where to emit events in SNS (required). If TopicARN is "bypass"
// (i.e. [topicARNBypass]) then events are emitted directly to the
// SQS queue reachable at QueueURL.
TopicARN string
// LargeEventsS3 is location on S3 where temporary large events (>256KB)
// are stored before converting it to Parquet and moving to long term
@@ -106,7 +113,9 @@ type Config struct {

// Batcher settings.

// QueueURL is URL of SQS, which is set as subscriber to SNS topic (required).
// QueueURL is the URL of the SQS queue, which is subscribed to the SNS topic
// when emitting through SNS, or used directly to send messages when bypassing
// SNS (required).
QueueURL string
// BatchMaxItems defines how many items can be stored in single Parquet
// batch (optional).
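With the new topicARNBypass constant, TopicARN doubles as a mode switch: a real ARN keeps the SNS path, while the literal value "bypass" sends events straight to the SQS queue at QueueURL. A hedged illustration of the two configurations (every value below is invented, and the other required Config fields are omitted):

// Illustrative only: ARNs and URLs are made up, remaining fields omitted.
viaSNS := athena.Config{
    TopicARN: "arn:aws:sns:us-east-1:123456789012:audit-events",          // publish through SNS
    QueueURL: "https://sqs.us-east-1.amazonaws.com/123456789012/audit-q", // queue subscribed to the topic
}

bypassSNS := athena.Config{
    TopicARN: "bypass",                                                   // topicARNBypass: skip SNS
    QueueURL: "https://sqs.us-east-1.amazonaws.com/123456789012/audit-q", // events are sent straight here
}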
6 changes: 3 additions & 3 deletions lib/events/athena/athena_test.go
@@ -335,7 +335,7 @@ func TestPublisherConsumer(t *testing.T) {
ID: uuid.NewString(),
Time: time.Now().UTC(),
Type: events.AppCreateEvent,
Code: strings.Repeat("d", 2*maxSNSMessageSize),
Code: strings.Repeat("d", 2*maxDirectMessageSize),
},
AppMetadata: apievents.AppMetadata{
AppName: "app-large",
@@ -418,8 +418,8 @@ func TestPublisherConsumer(t *testing.T) {
fq := newFakeQueue()
p := &publisher{
PublisherConfig: PublisherConfig{
SNSPublisher: fq,
Uploader: fS3,
MessagePublisher: fq,
Uploader: fS3,
},
}
cfg := validCollectCfgForTests(t)
56 changes: 28 additions & 28 deletions lib/events/athena/fakequeue_test.go
@@ -23,10 +23,8 @@ import (
"sync"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sns"
snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types"
"github.com/aws/aws-sdk-go-v2/service/sqs"
sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/google/uuid"
)

@@ -42,39 +40,55 @@ type fakeQueue struct {
}

type fakeQueueMessage struct {
payload string
attributes map[string]snsTypes.MessageAttributeValue
payload string
s3Based bool
}

func newFakeQueue() *fakeQueue {
return &fakeQueue{}
}

func (f *fakeQueue) Publish(ctx context.Context, params *sns.PublishInput, optFns ...func(*sns.Options)) (*sns.PublishOutput, error) {
func (f *fakeQueue) Publish(ctx context.Context, base64Body string, s3Based bool) error {
f.mu.Lock()
defer f.mu.Unlock()
if len(f.publishErrors) > 0 {
err := f.publishErrors[0]
f.publishErrors = f.publishErrors[1:]
return nil, err
return err
}
f.msgs = append(f.msgs, fakeQueueMessage{
payload: *params.Message,
attributes: params.MessageAttributes,
payload: base64Body,
s3Based: s3Based,
})
return nil, nil
return nil
}

func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
msgs := f.dequeue()
if len(msgs) == 0 {
return &sqs.ReceiveMessageOutput{}, nil
}
out := make([]sqsTypes.Message, 0, 10)
out := make([]sqstypes.Message, 0, len(msgs))
for _, msg := range msgs {
out = append(out, sqsTypes.Message{
Body: aws.String(msg.payload),
MessageAttributes: snsToSqsAttributes(msg.attributes),
var messageAttributes map[string]sqstypes.MessageAttributeValue
if msg.s3Based {
messageAttributes = map[string]sqstypes.MessageAttributeValue{
payloadTypeAttr: {
DataType: aws.String("String"),
StringValue: aws.String(payloadTypeS3Based),
},
}
} else {
messageAttributes = map[string]sqstypes.MessageAttributeValue{
payloadTypeAttr: {
DataType: aws.String("String"),
StringValue: aws.String(payloadTypeRawProtoEvent),
},
}
}
out = append(out, sqstypes.Message{
Body: &msg.payload,
MessageAttributes: messageAttributes,
ReceiptHandle: aws.String(uuid.NewString()),
})
}
@@ -83,20 +97,6 @@ func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessa
}, nil
}

func snsToSqsAttributes(in map[string]snsTypes.MessageAttributeValue) map[string]sqsTypes.MessageAttributeValue {
if in == nil {
return nil
}
out := map[string]sqsTypes.MessageAttributeValue{}
for k, v := range in {
out[k] = sqsTypes.MessageAttributeValue{
DataType: v.DataType,
StringValue: v.StringValue,
}
}
return out
}

func (f *fakeQueue) dequeue() []fakeQueueMessage {
f.mu.Lock()
defer f.mu.Unlock()
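The fake queue now implements the bare Publish(ctx, base64Body, s3Based) shape and reconstructs the payloadTypeAttr message attribute itself on receive, which is what real SQS delivery would carry once SNS is bypassed. For comparison, a direct-to-SQS publisher with the same shape might look roughly like the sketch below; SQSPublisherFunc and the attribute string values are assumptions for illustration, not necessarily what the PR ships.

package athenasketch

import (
    "context"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/sqs"
    sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
)

// Placeholder values: the real constants live in lib/events/athena.
const (
    payloadTypeAttr          = "payload_type"
    payloadTypeRawProtoEvent = "raw_proto_event"
    payloadTypeS3Based       = "s3_event"
)

// SQSPublisherFunc is a hypothetical adapter that sends each event body
// straight to an SQS queue, tagging it with the same payload-type attribute
// the consumer expects when messages arrive via SNS.
func SQSPublisherFunc(queueURL string, client *sqs.Client) func(ctx context.Context, base64Body string, s3Based bool) error {
    return func(ctx context.Context, base64Body string, s3Based bool) error {
        payloadType := payloadTypeRawProtoEvent
        if s3Based {
            payloadType = payloadTypeS3Based
        }
        _, err := client.SendMessage(ctx, &sqs.SendMessageInput{
            QueueUrl:    aws.String(queueURL),
            MessageBody: aws.String(base64Body),
            MessageAttributes: map[string]sqstypes.MessageAttributeValue{
                payloadTypeAttr: {
                    DataType:    aws.String("String"),
                    StringValue: aws.String(payloadType),
                },
            },
        })
        return err
    }
}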
55 changes: 51 additions & 4 deletions lib/events/athena/integration_test.go
@@ -37,9 +37,20 @@ import (
)

func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaSearchSessionEventsBySessionID(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaSearchSessionEventsBySessionID(t, bypassSNSTrue)
})
}

func testIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
auditLogger := &EventuallyConsistentAuditLogger{
Inner: ac.log,
// Additional 5s is used to compensate for uploading parquet on s3.
@@ -55,9 +66,20 @@ func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) {
}

func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaSessionEventsCRUD(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaSessionEventsCRUD(t, bypassSNSTrue)
})
}

func testIntegrationAthenaSessionEventsCRUD(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
auditLogger := &EventuallyConsistentAuditLogger{
Inner: ac.log,
// Additional 5s is used to compensate for uploading parquet on s3.
@@ -72,9 +94,20 @@ func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) {
}

func TestIntegrationAthenaEventPagination(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaEventPagination(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaEventPagination(t, bypassSNSTrue)
})
}

func testIntegrationAthenaEventPagination(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{BypassSNS: bypassSNS})
auditLogger := &EventuallyConsistentAuditLogger{
Inner: ac.log,
// Additional 5s is used to compensate for uploading parquet on s3.
@@ -89,10 +122,24 @@ func TestIntegrationAthenaEventPagination(t *testing.T) {
}

func TestIntegrationAthenaLargeEvents(t *testing.T) {
t.Run("sns", func(t *testing.T) {
const bypassSNSFalse = false
testIntegrationAthenaLargeEvents(t, bypassSNSFalse)
})
t.Run("sqs", func(t *testing.T) {
const bypassSNSTrue = true
testIntegrationAthenaLargeEvents(t, bypassSNSTrue)
})
}

func testIntegrationAthenaLargeEvents(t *testing.T, bypassSNS bool) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

ac := SetupAthenaContext(t, ctx, AthenaContextConfig{MaxBatchSize: 1})
ac := SetupAthenaContext(t, ctx, AthenaContextConfig{
MaxBatchSize: 1,
BypassSNS: bypassSNS,
})
in := &apievents.SessionStart{
Metadata: apievents.Metadata{
Index: 2,
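Each integration test now runs twice, once publishing through SNS and once bypassing it. SetupAthenaContext's internals are not part of this diff; presumably the BypassSNS flag ends up selecting the magic TopicARN value, roughly like this hypothetical mapping:

// Hypothetical: how a test helper might translate BypassSNS into Config.
topicARN := realTopicARN // ARN of the test SNS topic
if cfg.BypassSNS {
    topicARN = "bypass" // topicARNBypass: emit directly to the SQS queue
}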