diff --git a/api/utils/keys/privatekey.go b/api/utils/keys/privatekey.go index bd325a16d1d02..16e0e0b7b642b 100644 --- a/api/utils/keys/privatekey.go +++ b/api/utils/keys/privatekey.go @@ -220,28 +220,9 @@ func ParsePrivateKey(keyPEM []byte) (*PrivateKey, error) { } switch block.Type { - case PKCS1PrivateKeyType: - cryptoSigner, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, trace.Wrap(err) - } - return NewPrivateKey(cryptoSigner, keyPEM) - case ECPrivateKeyType: - cryptoSigner, err := x509.ParseECPrivateKey(block.Bytes) - if err != nil { - return nil, trace.Wrap(err) - } - return NewPrivateKey(cryptoSigner, keyPEM) - case PKCS8PrivateKeyType: - priv, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - return nil, trace.Wrap(err) - } - cryptoSigner, ok := priv.(crypto.Signer) - if !ok { - return nil, trace.BadParameter("x509.ParsePKCS8PrivateKey returned an invalid private key of type %T", priv) - } - return NewPrivateKey(cryptoSigner, keyPEM) + case pivYubiKeyPrivateKeyType: + priv, err := parseYubiKeyPrivateKeyData(block.Bytes) + return priv, trace.Wrap(err, "parsing YubiKey private key") case OpenSSHPrivateKeyType: priv, err := ssh.ParseRawPrivateKey(keyPEM) if err != nil { @@ -258,12 +239,35 @@ func ParsePrivateKey(keyPEM []byte) (*PrivateKey, error) { cryptoSigner = *pEdwards } return NewPrivateKey(cryptoSigner, keyPEM) - case pivYubiKeyPrivateKeyType: - priv, err := parseYubiKeyPrivateKeyData(block.Bytes) - if err != nil { - return nil, trace.Wrap(err, "failed to parse YubiKey private key") + case PKCS1PrivateKeyType, PKCS8PrivateKeyType, ECPrivateKeyType: + // The DER format doesn't always exactly match the PEM header, various + // versions of Teleport and OpenSSL have been guilty of writing PKCS#8 + // data into an "RSA PRIVATE KEY" block or vice-versa, so we just try + // parsing every DER format. This matches the behavior of [tls.X509KeyPair]. + var preferredErr error + if priv, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + signer, ok := priv.(crypto.Signer) + if !ok { + return nil, trace.BadParameter("x509.ParsePKCS8PrivateKey returned an invalid private key of type %T", priv) + } + return NewPrivateKey(signer, keyPEM) + } else if block.Type == PKCS8PrivateKeyType { + preferredErr = err + } + if signer, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return NewPrivateKey(signer, keyPEM) + } else if block.Type == PKCS1PrivateKeyType { + preferredErr = err + } + if signer, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return NewPrivateKey(signer, keyPEM) + } else if block.Type == ECPrivateKeyType { + preferredErr = err } - return priv, nil + // If all three parse functions returned an error, preferedErr is + // guaranteed to be set to the error from the parse function that + // usually matches the PEM block type. + return nil, trace.Wrap(preferredErr, "parsing private key PEM") default: return nil, trace.BadParameter("unexpected private key PEM type %q", block.Type) } diff --git a/api/utils/keys/privatekey_test.go b/api/utils/keys/privatekey_test.go index 41b52b169c739..3d3eb1305f28b 100644 --- a/api/utils/keys/privatekey_test.go +++ b/api/utils/keys/privatekey_test.go @@ -35,7 +35,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMarshalAndParsePrivateKey(t *testing.T) { +func TestMarshalAndParseKey(t *testing.T) { rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) require.NoError(t, err) ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -54,6 +54,125 @@ func TestMarshalAndParsePrivateKey(t *testing.T) { gotKey, err := ParsePrivateKey(keyPEM) require.NoError(t, err) require.Equal(t, key, gotKey.Signer) + + pubKeyPEM, err := MarshalPublicKey(key.Public()) + require.NoError(t, err) + gotPubKey, err := ParsePublicKey(pubKeyPEM) + require.NoError(t, err) + require.Equal(t, key.Public(), gotPubKey) + }) + } +} + +func TestParseMismatchedPEMHeader(t *testing.T) { + rsaKey, err := ParsePrivateKey(rsaKeyPEM) + require.NoError(t, err) + rsaPKCS1DER := x509.MarshalPKCS1PrivateKey(rsaKey.Signer.(*rsa.PrivateKey)) + rsaPKCS8DER, err := x509.MarshalPKCS8PrivateKey(rsaKey.Signer) + require.NoError(t, err) + rsaPublicPKCS1DER := x509.MarshalPKCS1PublicKey(rsaKey.Public().(*rsa.PublicKey)) + rsaPublicPKIXDER, err := x509.MarshalPKIXPublicKey(rsaKey.Public()) + require.NoError(t, err) + + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + ecdsaPKCS8DER, err := x509.MarshalPKCS8PrivateKey(ecdsaKey) + require.NoError(t, err) + ecdsaECDER, err := x509.MarshalECPrivateKey(ecdsaKey) + require.NoError(t, err) + + for desc, tc := range map[string]struct { + pem []byte + expectKey crypto.Signer + }{ + "PKCS1 DER in PRIVATE KEY PEM": { + pem: pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: rsaPKCS1DER, + }), + expectKey: rsaKey.Signer, + }, + "RSA PKCS8 DER in RSA PRIVATE KEY PEM": { + pem: pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: rsaPKCS8DER, + }), + expectKey: rsaKey.Signer, + }, + "ECDSA PKCS8 DER in EC PRIVATE KEY PEM": { + pem: pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: ecdsaPKCS8DER, + }), + expectKey: ecdsaKey, + }, + "EC DER in PRIVATE KEY PEM": { + pem: pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: ecdsaECDER, + }), + expectKey: ecdsaKey, + }, + } { + t.Run(desc, func(t *testing.T) { + key, err := ParsePrivateKey(tc.pem) + require.NoError(t, err) + require.Equal(t, tc.expectKey, key.Signer) + }) + } + + for desc, tc := range map[string]struct { + pem []byte + expectKey crypto.PublicKey + }{ + "PKCS1 DER in PUBLIC KEY PEM": { + pem: pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: rsaPublicPKCS1DER, + }), + expectKey: rsaKey.Public(), + }, + "PKIX DER in RSA PUBLIC KEY PEM": { + pem: pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: rsaPublicPKIXDER, + }), + expectKey: rsaKey.Public(), + }, + } { + t.Run(desc, func(t *testing.T) { + pubKey, err := ParsePublicKey(tc.pem) + require.NoError(t, err) + require.Equal(t, tc.expectKey, pubKey) + }) + } +} + +// TestParseCorruptedKey tests that we actually return an error and don't panic +// when parsing some trivially corrupted key PEMs. This is mostly to validate +// that the preferredErr logic in Parse(Private|Public)Key returns an error for +// each PEM type. +func TestParseCorruptedKey(t *testing.T) { + for _, tc := range []string{ + "RSA PRIVATE KEY", + "PRIVATE KEY", + "EC PRIVATE KEY", + } { + t.Run(tc, func(t *testing.T) { + b := pem.EncodeToMemory(&pem.Block{Type: tc, Bytes: []byte("foo")}) + _, err := ParsePrivateKey(b) + require.Error(t, err) + }) + } + + for _, tc := range []string{ + "RSA PUBLIC KEY", + "PUBLIC KEY", + } { + t.Run(tc, func(t *testing.T) { + b := pem.EncodeToMemory(&pem.Block{Type: tc, Bytes: []byte("foo")}) + _, err := ParsePublicKey(b) + require.Error(t, err) }) } } diff --git a/api/utils/keys/publickey.go b/api/utils/keys/publickey.go index 26ef9192814d9..0979caf266c60 100644 --- a/api/utils/keys/publickey.go +++ b/api/utils/keys/publickey.go @@ -68,24 +68,26 @@ func ParsePublicKey(keyPEM []byte) (crypto.PublicKey, error) { } switch block.Type { - case PKCS1PublicKeyType: - pub, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes) - if pkcs1Err != nil { - // Failed to parse as PKCS#1. We have been known to stuff PKIX DER encoded RSA public keys into - // "RSA PUBLIC KEY" PEM blocks, so try to parse as PKIX. - pub, pkixErr := x509.ParsePKIXPublicKey(block.Bytes) - if pkixErr != nil { - // Parsing as both formats failed. We really should expect PKCS#1 in this PEM block, so return - // that error. - return nil, trace.Wrap(pkcs1Err) - } - return pub, nil - } - return pub, nil - case PKIXPublicKeyType: - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - return pub, trace.Wrap(err) + case PKCS1PublicKeyType, PKIXPublicKeyType: default: return nil, trace.BadParameter("unsupported public key type %q", block.Type) } + + // We have been known to stuff PKIX DER encoded RSA public keys into + // "RSA PUBLIC KEY" PEM blocks, so just try to parse either. + var preferredErr error + if pub, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil { + return pub, nil + } else if block.Type == PKIXPublicKeyType { + preferredErr = err + } + if pub, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil { + return pub, nil + } else if block.Type == PKCS1PublicKeyType { + preferredErr = err + } + // If both parse functions returned an error, preferedErr is guaranteed to + // be set to the error from the parse function that usually matches the PEM + // block type. + return nil, trace.Wrap(preferredErr, "parsing public key PEM") }