Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only omit RSA primes if precomputed values are missing in OpenSSL 3.0 and 3.1 #163

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,24 @@ func newRSAKey3(isPriv bool, n, e, d, p, q, dp, dq, qinv BigInt) (C.GO_EVP_PKEY_
}
comps = append(comps, required[:]...)

// OpenSSL 3.0 and 3.1 required all the precomputed values if
// P and Q are present. See:
// https://github.com/openssl/openssl/pull/22334
if p != nil && q != nil && dp != nil && dq != nil && qinv != nil {
precomputed := [...]bigIntParam{
{OSSL_PKEY_PARAM_RSA_FACTOR1, p}, {OSSL_PKEY_PARAM_RSA_FACTOR2, q},
{OSSL_PKEY_PARAM_RSA_EXPONENT1, dp}, {OSSL_PKEY_PARAM_RSA_EXPONENT2, dq}, {OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv},
if p != nil && q != nil {
allPrecomputedExists := dp != nil && dq != nil && qinv != nil
// The precomputed values should only be passed if P and Q are present
// and every precomputed value is present. (If any precomputed value is
// missing, don't pass any of them.)
//
// In OpenSSL 3.0 and 3.1, we must also omit P and Q if any precomputed
// value is missing. See https://github.com/openssl/openssl/pull/22334
if vMinor >= 2 || allPrecomputedExists {
comps = append(comps, bigIntParam{OSSL_PKEY_PARAM_RSA_FACTOR1, p}, bigIntParam{OSSL_PKEY_PARAM_RSA_FACTOR2, q})
}
if allPrecomputedExists {
comps = append(comps,
bigIntParam{OSSL_PKEY_PARAM_RSA_EXPONENT1, dp},
bigIntParam{OSSL_PKEY_PARAM_RSA_EXPONENT2, dq},
bigIntParam{OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv},
)
}
comps = append(comps, precomputed[:]...)
}

for _, comp := range comps {
Expand Down
97 changes: 85 additions & 12 deletions rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto"
"crypto/rsa"
"fmt"
"math/big"
"strconv"
"testing"
Expand All @@ -13,11 +14,80 @@ import (
)

func TestRSAKeyGeneration(t *testing.T) {
for _, size := range []int{2048, 3072} {
size := size
t.Run(strconv.Itoa(size), func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, _, _, err := openssl.GenerateKeyRSA(size)
if err != nil {
t.Fatal(err)
}
})
}
}

func testRSAEncryptDecryptPKCS1(t *testing.T, priv *openssl.PrivateKeyRSA, pub *openssl.PublicKeyRSA) {
msg := []byte("hi!")
enc, err := openssl.EncryptRSAPKCS1(pub, msg)
if err != nil {
t.Fatalf("EncryptPKCS1v15: %v", err)
}
dec, err := openssl.DecryptRSAPKCS1(priv, enc)
if err != nil {
t.Fatalf("DecryptPKCS1v15: %v", err)
}
if !bytes.Equal(dec, msg) {
t.Fatalf("got:%x want:%x", dec, msg)
}
}

func TestRSAEncryptDecryptPKCS1(t *testing.T) {
for _, size := range []int{2048, 3072} {
size := size
t.Run(strconv.Itoa(size), func(t *testing.T) {
t.Parallel()
priv, pub := newRSAKey(t, size)
testRSAEncryptDecryptPKCS1(t, priv, pub)
})
}
}

func TestRSAEncryptDecryptPKCS1_MissingPrecomputedValues(t *testing.T) {
n, e, d, p, q, dp, dq, qinv, err := openssl.GenerateKeyRSA(2048)
if err != nil {
t.Fatalf("GenerateKeyRSA: %v", err)
}
tt := []struct {
withDp bool
withDq bool
withQinv bool
}{
{true, true, false},
{true, false, true},
{false, true, true},
{false, false, false},
{false, false, true},
{false, true, false},
{true, false, false},
{true, true, true},
}
for _, tt := range tt {
tt := tt
t.Run(fmt.Sprintf("dp=%v,dq=%v,qinv=%v", tt.withDp, tt.withDq, tt.withQinv), func(t *testing.T) {
t.Parallel()
dp1, dq1, qinv1 := dp, dq, qinv
if !tt.withDp {
dp1 = nil
}
if !tt.withDq {
dq1 = nil
}
if !tt.withQinv {
qinv1 = nil
}

priv, pub := newRSAKeyFromParams(t, n, e, d, p, q, dp1, dq1, qinv1)
testRSAEncryptDecryptPKCS1(t, priv, pub)
msg := []byte("hi!")
enc, err := openssl.EncryptRSAPKCS1(pub, msg)
if err != nil {
Expand All @@ -34,7 +104,7 @@ func TestRSAKeyGeneration(t *testing.T) {
}
}

func TestEncryptDecryptOAEP(t *testing.T) {
func TestRSAEncryptDecryptOAEP(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
label := []byte("ho!")
Expand All @@ -57,7 +127,7 @@ func TestEncryptDecryptOAEP(t *testing.T) {
}
}

func TestEncryptDecryptOAEP_EmptyLabel(t *testing.T) {
func TestRSAEncryptDecryptOAEP_EmptyLabel(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
label := []byte("")
Expand All @@ -80,7 +150,7 @@ func TestEncryptDecryptOAEP_EmptyLabel(t *testing.T) {
}
}

func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
func TestRSAEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
if openssl.SymCryptProviderAvailable() {
t.Skip("SymCrypt provider does not support MGF1 hash")
}
Expand All @@ -107,7 +177,7 @@ func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
}
}

func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
func TestRSAEncryptDecryptOAEP_WrongLabel(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
Expand All @@ -124,7 +194,7 @@ func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
}
}

func TestSignVerifyPKCS1v15(t *testing.T) {
func TestRSASignVerifyPKCS1v15(t *testing.T) {
sha256 := openssl.NewSHA256()
priv, pub := newRSAKey(t, 2048)
msg := []byte("hi!")
Expand All @@ -151,7 +221,7 @@ func TestSignVerifyPKCS1v15(t *testing.T) {
}
}

func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) {
func TestRSASignVerifyPKCS1v15_Unhashed(t *testing.T) {
if openssl.SymCryptProviderAvailable() {
t.Skip("SymCrypt provider does not support unhashed PKCS1v15")
}
Expand All @@ -168,7 +238,7 @@ func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) {
}
}

func TestSignVerifyPKCS1v15_Invalid(t *testing.T) {
func TestRSASignVerifyPKCS1v15_Invalid(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
Expand All @@ -184,7 +254,7 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) {
}
}

func TestSignVerifyRSAPSS(t *testing.T) {
func TestRSASignVerifyRSAPSS(t *testing.T) {
// Test cases taken from
// https://github.com/golang/go/blob/54182ff54a687272dd7632c3a963e036ce03cb7c/src/crypto/rsa/pss_test.go#L200.
const keyBits = 2048
Expand Down Expand Up @@ -225,15 +295,18 @@ func newRSAKey(t *testing.T, size int) (*openssl.PrivateKeyRSA, *openssl.PublicK
if err != nil {
t.Fatalf("GenerateKeyRSA(%d): %v", size, err)
}
// Exercise omission of precomputed value
Dp = nil
return newRSAKeyFromParams(t, N, E, D, P, Q, Dp, Dq, Qinv)
}

func newRSAKeyFromParams(t *testing.T, N, E, D, P, Q, Dp, Dq, Qinv openssl.BigInt) (*openssl.PrivateKeyRSA, *openssl.PublicKeyRSA) {
t.Helper()
priv, err := openssl.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv)
if err != nil {
t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err)
t.Fatalf("NewPrivateKeyRSA: %v", err)
}
pub, err := openssl.NewPublicKeyRSA(N, E)
if err != nil {
t.Fatalf("NewPublicKeyRSA(%d): %v", size, err)
t.Fatalf("NewPublicKeyRSA: %v", err)
}
return priv, pub
}
Expand Down
Loading