Skip to content

Commit

Permalink
feat: add encryption support (#627)
Browse files Browse the repository at this point in the history
Co-authored-by: Leonidas Vrachnis <leo.al.vra@gmail.com>
  • Loading branch information
mihir20 and lvrach authored Sep 11, 2024
1 parent c2d5275 commit e26c3e4
Show file tree
Hide file tree
Showing 4 changed files with 427 additions and 0 deletions.
66 changes: 66 additions & 0 deletions encrypt/aes_gcm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package encrypt

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
)

type encryptionAESGCM struct {
level int
}

func (e *encryptionAESGCM) Encrypt(src []byte, key string) ([]byte, error) {
if len(key) != e.level/8 {
return nil, fmt.Errorf("key length must be %d bytes", e.level/8)
}

block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}

aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonce := make([]byte, aesGCM.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}

ciphertext := aesGCM.Seal(nonce, nonce, src, nil)
return ciphertext, nil
}

func (e *encryptionAESGCM) Decrypt(src []byte, key string) ([]byte, error) {
if len(key) != e.level/8 {
return nil, fmt.Errorf("key length must be %d bytes", e.level/8)
}

block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}

aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonceSize := aesGCM.NonceSize()
if len(src) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}

nonce, ciphertext := src[:nonceSize], src[nonceSize:]
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}

return plaintext, nil
}
60 changes: 60 additions & 0 deletions encrypt/benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package encrypt

import (
"testing"

"github.com/stretchr/testify/require"
)

/*
BenchmarkEncryptDecrypt
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES128
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES128-12 803842 1444 ns/op 1616 B/op 13 allocs/op
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES192
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES192-12 805350 1443 ns/op 1744 B/op 13 allocs/op
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES256
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES256-12 744871 1516 ns/op 1872 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES128
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES128-12 1900 614516 ns/op 4204053 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES192
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES192-12 1755 672776 ns/op 4204180 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES256
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES256-12 1624 723403 ns/op 4204308 B/op 13 allocs/op
*/
func BenchmarkEncryptDecrypt(b *testing.B) {
tests := []struct {
payload []byte
name string
algo EncryptionAlgorithm
level EncryptionLevel
}{
{[]byte("small payload"), "SMALL_AESGCM_AES128", EncryptionAlgoAESGCM, EncryptionLevelAES128},
{[]byte("small payload"), "SMALL_AESGCM_AES192", EncryptionAlgoAESGCM, EncryptionLevelAES192},
{[]byte("small payload"), "SMALL_AESGCM_AES256", EncryptionAlgoAESGCM, EncryptionLevelAES256},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES128", EncryptionAlgoAESGCM, EncryptionLevelAES128},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES192", EncryptionAlgoAESGCM, EncryptionLevelAES192},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES256", EncryptionAlgoAESGCM, EncryptionLevelAES256},
}

for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) {
b.ReportAllocs()
encrypter, err := New(tt.algo, tt.level)
require.NoError(b, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(b, err)

plaintext := tt.payload

b.ResetTimer()
for i := 0; i < b.N; i++ {
ciphertext, err := encrypter.Encrypt(plaintext, key)
require.NoError(b, err)

_, err = encrypter.Decrypt(ciphertext, key)
require.NoError(b, err)
}
})
}
}
102 changes: 102 additions & 0 deletions encrypt/encrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package encrypt

import (
"fmt"
"strings"
)

// EncryptionAlgorithm is the interface that wraps the encryption algorithm method.
type EncryptionAlgorithm int

func (e EncryptionAlgorithm) String() string {
switch e {
case EncryptionAlgoAESGCM:
return "aes-gcm"
default:
return ""
}
}

// EncryptionLevel is the interface that wraps the encryption level method.
type EncryptionLevel int

func (e EncryptionLevel) String() string {
switch e {
case EncryptionLevelAES128, EncryptionLevelAES192, EncryptionLevelAES256:
return fmt.Sprintf("%d", e)
default:
return ""
}
}

func NewSettings(algo, level string) (EncryptionAlgorithm, EncryptionLevel, error) {
switch algo {
case "aes-gcm":
switch level {
case "128":
return EncryptionAlgoAESGCM, EncryptionLevelAES128, nil
case "192":
return EncryptionAlgoAESGCM, EncryptionLevelAES192, nil
case "256":
return EncryptionAlgoAESGCM, EncryptionLevelAES256, nil
default:
return 0, 0, fmt.Errorf("unknown encryption level for %s: %s", algo, level)
}
default:
return 0, 0, fmt.Errorf("unknown encryption algorithm: %s", algo)
}
}

var (
EncryptionAlgoAESGCM = EncryptionAlgorithm(1)
EncryptionLevelAES128 = EncryptionLevel(128)
EncryptionLevelAES192 = EncryptionLevel(192)
EncryptionLevelAES256 = EncryptionLevel(256)
)

func New(algo EncryptionAlgorithm, level EncryptionLevel) (*Encryptor, error) {
var err error
algo, level, err = NewSettings(algo.String(), level.String())
if err != nil {
return nil, err
}

switch algo {
case EncryptionAlgoAESGCM:
return &Encryptor{encryptionAESGCM: &encryptionAESGCM{level: int(level)}}, nil
default:
return nil, fmt.Errorf("unknown encryption algorithm: %d", algo)
}
}

type Encryptor struct {
*encryptionAESGCM
}

func (e *Encryptor) Encrypt(src []byte, key string) ([]byte, error) {
if e.encryptionAESGCM != nil {
return e.encryptionAESGCM.Encrypt(src, key)
}
return nil, fmt.Errorf("no encryption method available")
}

func (e *Encryptor) Decrypt(src []byte, key string) ([]byte, error) {
if e.encryptionAESGCM != nil {
return e.encryptionAESGCM.Decrypt(src, key)
}
return nil, fmt.Errorf("no decryption method available")
}

// SerializeSettings converts the EncryptionAlgorithm and EncryptionLevel to a string.
func SerializeSettings(algo EncryptionAlgorithm, level EncryptionLevel) string {
return fmt.Sprintf("%s:%s", algo.String(), level.String())
}

// DeserializeSettings converts a string to EncryptionAlgorithm and EncryptionLevel.
func DeserializeSettings(settings string) (EncryptionAlgorithm, EncryptionLevel, error) {
parts := strings.Split(settings, ":")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid settings format")
}
return NewSettings(parts[0], parts[1])
}
Loading

0 comments on commit e26c3e4

Please sign in to comment.