Skip to content

Commit

Permalink
feat(awslambda): supports bucket upload.
Browse files Browse the repository at this point in the history
feat(awslambda): wires up bucket upload.

fixup! feat(awslambda): awaits version publish completion.
  • Loading branch information
outofcoffee committed Aug 2, 2024
1 parent b052339 commit 25a51d2
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 27 deletions.
126 changes: 126 additions & 0 deletions remote/awslambda/bucket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package awslambda

import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/google/uuid"
"os"
"path"
"strings"
)

func (m LambdaRemote) uploadBundleToBucket(zipContents *[]byte) (bucketName string, localBundlePath string, err error) {
localBundlePath, err = m.writeBundleToTempFile(zipContents)
if err != nil {
return "", "", err
}
bucketName, err = m.getBucketName()
if err != nil {
return "", "", err
}
if err = m.uploadToBucket(localBundlePath, bucketName); err != nil {
return "", "", fmt.Errorf("failed to upload file %v to bucket %v: %v", localBundlePath, bucketName, err)
}
return bucketName, localBundlePath, nil
}

func (m LambdaRemote) writeBundleToTempFile(zipContents *[]byte) (localBundlePath string, err error) {
temp, err := os.CreateTemp(os.TempDir(), "imposter-bundle-*.zip")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %v", err)
}
defer temp.Close()

localBundlePath = temp.Name()
if err = os.WriteFile(localBundlePath, *zipContents, 0644); err != nil {
return "", fmt.Errorf("failed to write bundle to temp file %v: %v", temp, err)
}
logger.Tracef("wrote bundle to temp file: %v", localBundlePath)
return localBundlePath, nil
}

func (m LambdaRemote) getBucketName() (bucketName string, err error) {
bucketName = m.Config[configKeyBucketName]
if bucketName == "" {
bucketName = "imposter-mock-" + strings.ReplaceAll(uuid.New().String(), "-", "")
m.Config[configKeyBucketName] = bucketName
if err = m.SaveConfig(); err != nil {
return "", fmt.Errorf("failed to save bucket name %v in config: %v", bucketName, err)
}
}
return bucketName, nil
}

func (m LambdaRemote) uploadToBucket(localPath string, bucketName string) error {
region, _, svc, err := m.initS3Client()
if err != nil {
return fmt.Errorf("failed to initialise S3 client: %v", err)
}
if err = ensureBucket(svc, bucketName, region); err != nil {
return fmt.Errorf("failed to ensure bucket %v exists: %v", bucketName, err)
}
if err = upload(svc, bucketName, localPath); err != nil {
return fmt.Errorf("failed to upload file %v to bucket %v: %v", localPath, bucketName, err)
}
return nil
}

func ensureBucket(svc *s3.S3, bucketName string, region string) error {
logger.Tracef("checking for bucket %v in region %v", bucketName, region)

if _, err := svc.HeadBucket(&s3.HeadBucketInput{Bucket: aws.String(bucketName)}); err != nil {
if err = createBucket(svc, bucketName, region); err != nil {
return err
}
}
logger.Debugf("bucket %v exists", bucketName)
return nil
}

func createBucket(svc *s3.S3, bucketName string, region string) error {
logger.Tracef("creating bucket %v in region %v", bucketName, region)

_, err := svc.CreateBucket(&s3.CreateBucketInput{
Bucket: aws.String(bucketName),
CreateBucketConfiguration: &s3.CreateBucketConfiguration{
LocationConstraint: aws.String(region),
},
})
if err != nil {
return fmt.Errorf("failed to create bucket %v in region %v: %v", bucketName, region, err)
}
logger.Debugf("created bucket %v in region %v", bucketName, region)
return nil
}

func upload(svc *s3.S3, bucketName string, localPath string) error {
logger.Tracef("uploading file %v to bucket %v", localPath, bucketName)

file, err := os.Open(localPath)
if err != nil {
return fmt.Errorf("failed to read file: %v: %v", localPath, err)
}
defer file.Close()

_, err = svc.PutObject(&s3.PutObjectInput{
Body: aws.ReadSeekCloser(file),
Bucket: aws.String(bucketName),
Key: aws.String(path.Base(localPath)),
})
if err != nil {
return fmt.Errorf("failed to upload file %v to bucket %v: %v", localPath, bucketName, err)
}
logger.Debugf("uploaded file %v to bucket %v", localPath, bucketName)
return nil
}

func (m LambdaRemote) initS3Client() (region string, sess *awssession.Session, svc *s3.S3, err error) {
if m.Config[configKeyRegion] == "" {
return "", nil, nil, fmt.Errorf("region cannot be null")
}
region, sess = m.startAwsSession()
svc = s3.New(sess)
return region, sess, svc, nil
}
4 changes: 4 additions & 0 deletions remote/awslambda/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const defaultMemory = 768

const configKeyAnonAccess = "anonAccess"
const configKeyArchitecture = "architecture"
const configKeyBucketName = "bucketName"
const configKeyCreateAlias = "createAlias"
const configKeyEngineVersion = "engineVersion"
const configKeyFuncName = "functionName"
Expand All @@ -34,10 +35,12 @@ const configKeyMemory = "memory"
const configKeyPublishVersion = "publishVersion"
const configKeyRegion = "region"
const configKeySnapStart = "snapStart"
const configKeyUploadToBucket = "uploadToBucket"

var configKeys = []string{
configKeyAnonAccess,
configKeyArchitecture,
configKeyBucketName,
configKeyCreateAlias,
configKeyEngineVersion,
configKeyFuncName,
Expand All @@ -46,6 +49,7 @@ var configKeys = []string{
configKeyPublishVersion,
configKeyRegion,
configKeySnapStart,
configKeyUploadToBucket,
}

var logger = logging.GetLogger()
Expand Down
89 changes: 63 additions & 26 deletions remote/awslambda/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ const liveAliasName = "live"
const readyTimeoutSeconds = 360

func (m LambdaRemote) Deploy() error {
region, sess, svc, err := m.initAws()
region, sess, svc, err := m.initLambdaClient()
if err != nil {
return err
return fmt.Errorf("failed to initialise lambda client: %v", err)
}

roleName := stringutil.GetFirstNonEmpty(m.Config[configKeyIamRoleName], defaultIamRoleName)
Expand All @@ -39,6 +39,22 @@ func (m LambdaRemote) Deploy() error {
logger.Fatal(err)
}

var location codeLocation
if stringutil.ToBoolWithDefault(m.Config[configKeyUploadToBucket], true) {
bucketName, localBundlePath, err := m.uploadBundleToBucket(zipContents)
if err != nil {
return err
}
location = codeLocation{
bucket: bucketName,
objectKey: path.Base(localBundlePath),
}
} else {
location = codeLocation{
zipContents: zipContents,
}
}

snapStart := stringutil.ToBool(m.Config[configKeySnapStart])
funcArn, err := ensureFunctionExists(
svc,
Expand All @@ -47,7 +63,7 @@ func (m LambdaRemote) Deploy() error {
roleArn,
m.getMemorySize(),
m.getArchitecture(),
zipContents,
location,
snapStart,
)
if err != nil {
Expand All @@ -65,9 +81,7 @@ func (m LambdaRemote) Deploy() error {
}

var arnForUrl string

createAlias := stringutil.ToBool(m.Config[configKeyCreateAlias])
if createAlias {
if stringutil.ToBool(m.Config[configKeyCreateAlias]) {
aliasArn, err := createOrUpdateAlias(svc, funcArn, versionId, liveAliasName)
if err != nil {
return err
Expand Down Expand Up @@ -156,15 +170,18 @@ func awaitReady(svc *lambda.Lambda, funcArn string, checkVersion string) error {
if err != nil {
return err
}
logger.Tracef("function %v [version: %v] config %v", funcArn, checkVersion, *configuration)

var lastUpdateInProgress bool
if configuration.LastUpdateStatus != nil {
lastUpdateInProgress = *configuration.LastUpdateStatus != lambda.LastUpdateStatusSuccessful
lastUpdateStatus := *configuration.LastUpdateStatus
logger.Tracef("function %v [version: %v] lastUpdateStatus=%v", funcArn, checkVersion, lastUpdateStatus)
lastUpdateInProgress = lastUpdateStatus != lambda.LastUpdateStatusSuccessful
}
var stateIsPending bool
if configuration.State != nil {
stateIsPending = *configuration.State == lambda.StatePending
currentState := *configuration.State
logger.Tracef("function %v [version: %v] state=%v", funcArn, checkVersion, currentState)
stateIsPending = currentState == lambda.StatePending
}
if !lastUpdateInProgress && !stateIsPending {
logger.Debugf("function %v [version: %v] is ready", funcArn, checkVersion)
Expand Down Expand Up @@ -226,9 +243,9 @@ func createAlias(svc *lambda.Lambda, funcArn string, versionId string, aliasName
}

func (m LambdaRemote) Undeploy() error {
region, _, svc, err := m.initAws()
region, _, svc, err := m.initLambdaClient()
if err != nil {
return err
return fmt.Errorf("failed to initialise lambda client: %v", err)
}

funcName := m.getFunctionName()
Expand Down Expand Up @@ -259,9 +276,9 @@ func (m LambdaRemote) Undeploy() error {
}

func (m LambdaRemote) GetEndpoint() (*remote.EndpointDetails, error) {
_, _, svc, err := m.initAws()
_, _, svc, err := m.initLambdaClient()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialise lambda client: %v", err)
}

var funcArn string
Expand Down Expand Up @@ -292,7 +309,7 @@ func (m LambdaRemote) GetEndpoint() (*remote.EndpointDetails, error) {
return details, nil
}

func (m LambdaRemote) initAws() (region string, sess *awssession.Session, svc *lambda.Lambda, err error) {
func (m LambdaRemote) initLambdaClient() (region string, sess *awssession.Session, svc *lambda.Lambda, err error) {
if m.Config[configKeyRegion] == "" {
return "", nil, nil, fmt.Errorf("region cannot be null")
}
Expand Down Expand Up @@ -335,10 +352,11 @@ func ensureFunctionExists(
roleArn string,
memoryMb int64,
arch LambdaArchitecture,
zipContents *[]byte,
location codeLocation,
snapStart bool,
) (string, error) {
var funcArn string

result, err := checkFunctionExists(svc, funcName)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
Expand All @@ -350,7 +368,7 @@ func ensureFunctionExists(
roleArn,
memoryMb,
arch,
zipContents,
location,
snapStart,
)
if err != nil {
Expand All @@ -366,11 +384,10 @@ func ensureFunctionExists(

} else {
funcArn = *result.Configuration.FunctionArn

if err = ensureSnapStart(svc, funcArn, snapStart); err != nil {
return "", err
}
if err = updateFunctionCode(svc, funcArn, zipContents); err != nil {
if err = updateFunctionCode(svc, funcArn, location); err != nil {
return "", err
}
}
Expand All @@ -384,14 +401,20 @@ func checkFunctionExists(svc *lambda.Lambda, functionName string) (*lambda.GetFu
return result, err
}

type codeLocation struct {
bucket string
objectKey string
zipContents *[]byte
}

func createFunction(
svc *lambda.Lambda,
region string,
funcName string,
roleArn string,
memoryMb int64,
arch LambdaArchitecture,
zipContents *[]byte,
location codeLocation,
snapStart bool,
) (arn string, err error) {
logger.Debugf("creating function: %s in region: %s", funcName, region)
Expand All @@ -404,9 +427,6 @@ func createFunction(
}

input := &lambda.CreateFunctionInput{
Code: &lambda.FunctionCode{
ZipFile: *zipContents,
},
FunctionName: aws.String(funcName),
Handler: aws.String("io.gatehill.imposter.awslambda.HandlerV2"),
MemorySize: aws.Int64(memoryMb),
Expand All @@ -416,6 +436,17 @@ func createFunction(
Environment: buildEnv(),
}

if location.bucket != "" {
input.SetCode(&lambda.FunctionCode{
S3Bucket: aws.String(location.bucket),
S3Key: aws.String(location.objectKey),
})
} else {
input.SetCode(&lambda.FunctionCode{
ZipFile: *location.zipContents,
})
}

input.SetSnapStart(&lambda.SnapStart{
ApplyOn: aws.String(desiredConfig),
})
Expand All @@ -434,12 +465,18 @@ func createFunction(
return *result.FunctionArn, nil
}

func updateFunctionCode(svc *lambda.Lambda, funcArn string, zipContents *[]byte) error {
func updateFunctionCode(svc *lambda.Lambda, funcArn string, location codeLocation) error {
logger.Debugf("updating function code for: %s", funcArn)
_, err := svc.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{
input := &lambda.UpdateFunctionCodeInput{
FunctionName: aws.String(funcArn),
ZipFile: *zipContents,
})
}
if location.bucket != "" {
input.S3Bucket = aws.String(location.bucket)
input.S3Key = aws.String(location.objectKey)
} else {
input.ZipFile = *location.zipContents
}
_, err := svc.UpdateFunctionCode(input)
if err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion stringutil/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ func Sha1hash(input []byte) string {
}

func ToBool(input string) bool {
return ToBoolWithDefault(input, false)
}

func ToBoolWithDefault(input string, defaultValue bool) bool {
parsed, err := strconv.ParseBool(input)
if err != nil {
return false
return defaultValue
}
return parsed
}

0 comments on commit 25a51d2

Please sign in to comment.