diff --git a/rsa.go b/rsa.go index 564ccdcf..4e45b02d 100644 --- a/rsa.go +++ b/rsa.go @@ -396,15 +396,24 @@ func newRSAKey3(isPriv bool, n, e, d, p, q, dp, dq, qinv BigInt) (C.GO_EVP_PKEY_ } comps = append(comps, required[:]...) - // OpenSSL 3.0 and 3.1 required all the precomputed values if - // P and Q are present. See: - // https://github.com/openssl/openssl/pull/22334 - if p != nil && q != nil && dp != nil && dq != nil && qinv != nil { - precomputed := [...]bigIntParam{ - {OSSL_PKEY_PARAM_RSA_FACTOR1, p}, {OSSL_PKEY_PARAM_RSA_FACTOR2, q}, - {OSSL_PKEY_PARAM_RSA_EXPONENT1, dp}, {OSSL_PKEY_PARAM_RSA_EXPONENT2, dq}, {OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv}, + if p != nil && q != nil { + allPrecomputedExists := dp != nil && dq != nil && qinv != nil + // The precomputed values should only be passed if P and Q are present + // and every precomputed value is present. (If any precomputed value is + // missing, don't pass any of them.) + // + // In OpenSSL 3.0 and 3.1, we must also omit P and Q if any precomputed + // value is missing. See https://github.com/openssl/openssl/pull/22334 + if vMinor >= 2 || allPrecomputedExists { + comps = append(comps, bigIntParam{OSSL_PKEY_PARAM_RSA_FACTOR1, p}, bigIntParam{OSSL_PKEY_PARAM_RSA_FACTOR2, q}) + } + if allPrecomputedExists { + comps = append(comps, + bigIntParam{OSSL_PKEY_PARAM_RSA_EXPONENT1, dp}, + bigIntParam{OSSL_PKEY_PARAM_RSA_EXPONENT2, dq}, + bigIntParam{OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv}, + ) } - comps = append(comps, precomputed[:]...) } for _, comp := range comps { diff --git a/rsa_test.go b/rsa_test.go index 35de951a..94775332 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto" "crypto/rsa" + "fmt" "math/big" "strconv" "testing" @@ -13,11 +14,80 @@ import ( ) func TestRSAKeyGeneration(t *testing.T) { + for _, size := range []int{2048, 3072} { + size := size + t.Run(strconv.Itoa(size), func(t *testing.T) { + t.Parallel() + _, _, _, _, _, _, _, _, err := openssl.GenerateKeyRSA(size) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func testRSAEncryptDecryptPKCS1(t *testing.T, priv *openssl.PrivateKeyRSA, pub *openssl.PublicKeyRSA) { + msg := []byte("hi!") + enc, err := openssl.EncryptRSAPKCS1(pub, msg) + if err != nil { + t.Fatalf("EncryptPKCS1v15: %v", err) + } + dec, err := openssl.DecryptRSAPKCS1(priv, enc) + if err != nil { + t.Fatalf("DecryptPKCS1v15: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("got:%x want:%x", dec, msg) + } +} + +func TestRSAEncryptDecryptPKCS1(t *testing.T) { for _, size := range []int{2048, 3072} { size := size t.Run(strconv.Itoa(size), func(t *testing.T) { t.Parallel() priv, pub := newRSAKey(t, size) + testRSAEncryptDecryptPKCS1(t, priv, pub) + }) + } +} + +func TestRSAEncryptDecryptPKCS1_MissingPrecomputedValues(t *testing.T) { + n, e, d, p, q, dp, dq, qinv, err := openssl.GenerateKeyRSA(2048) + if err != nil { + t.Fatalf("GenerateKeyRSA: %v", err) + } + tt := []struct { + withDp bool + withDq bool + withQinv bool + }{ + {true, true, false}, + {true, false, true}, + {false, true, true}, + {false, false, false}, + {false, false, true}, + {false, true, false}, + {true, false, false}, + {true, true, true}, + } + for _, tt := range tt { + tt := tt + t.Run(fmt.Sprintf("dp=%v,dq=%v,qinv=%v", tt.withDp, tt.withDq, tt.withQinv), func(t *testing.T) { + t.Parallel() + dp1, dq1, qinv1 := dp, dq, qinv + if !tt.withDp { + dp1 = nil + } + if !tt.withDq { + dq1 = nil + } + if !tt.withQinv { + qinv1 = nil + } + + priv, pub := newRSAKeyFromParams(t, n, e, d, p, q, dp1, dq1, qinv1) + testRSAEncryptDecryptPKCS1(t, priv, pub) msg := []byte("hi!") enc, err := openssl.EncryptRSAPKCS1(pub, msg) if err != nil { @@ -34,7 +104,7 @@ func TestRSAKeyGeneration(t *testing.T) { } } -func TestEncryptDecryptOAEP(t *testing.T) { +func TestRSAEncryptDecryptOAEP(t *testing.T) { sha256 := openssl.NewSHA256() msg := []byte("hi!") label := []byte("ho!") @@ -57,7 +127,7 @@ func TestEncryptDecryptOAEP(t *testing.T) { } } -func TestEncryptDecryptOAEP_EmptyLabel(t *testing.T) { +func TestRSAEncryptDecryptOAEP_EmptyLabel(t *testing.T) { sha256 := openssl.NewSHA256() msg := []byte("hi!") label := []byte("") @@ -80,7 +150,7 @@ func TestEncryptDecryptOAEP_EmptyLabel(t *testing.T) { } } -func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) { +func TestRSAEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) { if openssl.SymCryptProviderAvailable() { t.Skip("SymCrypt provider does not support MGF1 hash") } @@ -107,7 +177,7 @@ func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) { } } -func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { +func TestRSAEncryptDecryptOAEP_WrongLabel(t *testing.T) { sha256 := openssl.NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) @@ -124,7 +194,7 @@ func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { } } -func TestSignVerifyPKCS1v15(t *testing.T) { +func TestRSASignVerifyPKCS1v15(t *testing.T) { sha256 := openssl.NewSHA256() priv, pub := newRSAKey(t, 2048) msg := []byte("hi!") @@ -151,7 +221,7 @@ func TestSignVerifyPKCS1v15(t *testing.T) { } } -func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { +func TestRSASignVerifyPKCS1v15_Unhashed(t *testing.T) { if openssl.SymCryptProviderAvailable() { t.Skip("SymCrypt provider does not support unhashed PKCS1v15") } @@ -168,7 +238,7 @@ func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { } } -func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { +func TestRSASignVerifyPKCS1v15_Invalid(t *testing.T) { sha256 := openssl.NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) @@ -184,7 +254,7 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { } } -func TestSignVerifyRSAPSS(t *testing.T) { +func TestRSASignVerifyRSAPSS(t *testing.T) { // Test cases taken from // https://github.com/golang/go/blob/54182ff54a687272dd7632c3a963e036ce03cb7c/src/crypto/rsa/pss_test.go#L200. const keyBits = 2048 @@ -225,15 +295,18 @@ func newRSAKey(t *testing.T, size int) (*openssl.PrivateKeyRSA, *openssl.PublicK if err != nil { t.Fatalf("GenerateKeyRSA(%d): %v", size, err) } - // Exercise omission of precomputed value - Dp = nil + return newRSAKeyFromParams(t, N, E, D, P, Q, Dp, Dq, Qinv) +} + +func newRSAKeyFromParams(t *testing.T, N, E, D, P, Q, Dp, Dq, Qinv openssl.BigInt) (*openssl.PrivateKeyRSA, *openssl.PublicKeyRSA) { + t.Helper() priv, err := openssl.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) if err != nil { - t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) + t.Fatalf("NewPrivateKeyRSA: %v", err) } pub, err := openssl.NewPublicKeyRSA(N, E) if err != nil { - t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) + t.Fatalf("NewPublicKeyRSA: %v", err) } return priv, pub }