Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] Azure join method - use subscription ID from attested data #49157

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
t.Run("Azure method", func(t *testing.T) {
subID := uuid.NewString()
resourceGroup := "rg"
rsID := resourceID(subID, resourceGroup, "test-vm")
rsID := vmResourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, a.clock.Now())
Expand All @@ -618,13 +618,20 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
require.NoError(t, err)
require.NoError(t, a.UpsertToken(ctx, azureToken))

vmClient := &mockAzureVMClient{vm: &azure.VirtualMachine{
ID: rsID,
Name: "test-vm",
Subscription: subID,
ResourceGroup: resourceGroup,
VMID: vmID,
}}
vmClient := &mockAzureVMClient{
vms: map[string]*azure.VirtualMachine{
rsID: {
ID: rsID,
Name: "test-vm",
Subscription: subID,
ResourceGroup: resourceGroup,
VMID: vmID,
},
},
}
getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{
subID: vmClient,
})

tlsConfig, err := fixtures.LocalTLSConfig()
require.NoError(t, err)
Expand Down Expand Up @@ -666,7 +673,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
AccessToken: accessToken,
}
return req, nil
}, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClient(vmClient))
}, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClientGetter(getVMClient))
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})
Expand Down
57 changes: 28 additions & 29 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ type accessTokenClaims struct {

type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error)

type vmClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error)

type azureRegisterConfig struct {
clock clockwork.Clock
certificateAuthorities []*x509.Certificate
verify azureVerifyTokenFunc
vmClient azure.VirtualMachinesClient
getVMClient vmClientGetter
}

func azureVerifyFuncFromOIDCVerifier(cfg *oidc.Config) azureVerifyTokenFunc {
Expand Down Expand Up @@ -140,6 +142,12 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {
}
cfg.certificateAuthorities = certs
}
if cfg.getVMClient == nil {
cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil)
return client, trace.Wrap(err)
}
}
return nil
}

Expand All @@ -148,42 +156,42 @@ type azureRegisterOption func(cfg *azureRegisterConfig)
// parseAndVeryAttestedData verifies that an attested data document was signed
// by Azure. If verification is successful, it returns the ID of the VM that
// produced the document.
func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (string, error) {
func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) {
var signedAD signedAttestedData
if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
if signedAD.Encoding != "pkcs7" {
return "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding)
return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding)
}

sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----"
sigBER, _ := pem.Decode([]byte(sigPEM))
if sigBER == nil {
return "", trace.AccessDenied("unable to decode attested data document")
return "", "", trace.AccessDenied("unable to decode attested data document")
}

p7, err := pkcs7.Parse(sigBER.Bytes)
if err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
var ad attestedData
if err := utils.FastUnmarshal(p7.Content, &ad); err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
if ad.Nonce != challenge {
return "", trace.AccessDenied("challenge is missing or does not match")
return "", "", trace.AccessDenied("challenge is missing or does not match")
}

if len(p7.Certificates) == 0 {
return "", trace.AccessDenied("no certificates for signature")
return "", "", trace.AccessDenied("no certificates for signature")
}
fixAzureSigningAlgorithm(p7)

// Azure only sends the leaf cert, so we have to fetch the intermediate.
intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0])
if err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
if intermediate != nil {
p7.Certificates = append(p7.Certificates, intermediate)
Expand All @@ -195,15 +203,15 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s
}

if err := p7.VerifyWithChain(pool); err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}

return ad.ID, nil
return ad.SubscriptionID, ad.ID, nil
}

// verifyVMIdentity verifies that the provided access token came from the
// correct Azure VM.
func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) {
func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, subscriptionID, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) {
tokenClaims, err := cfg.verify(ctx, accessToken)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -231,24 +239,15 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

rsID, err := arm.ParseResourceID(tokenClaims.ResourceID)
tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
Token: accessToken,
ExpiresOn: tokenClaims.Expiry.Time(),
})
vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential)
if err != nil {
return nil, trace.Wrap(err)
}

vmClient := cfg.vmClient
if vmClient == nil {
tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
Token: accessToken,
ExpiresOn: tokenClaims.Expiry.Time(),
})
var err error
vmClient, err = azure.NewVirtualMachinesClient(rsID.SubscriptionID, tokenCredential, nil)
if err != nil {
return nil, trace.Wrap(err)
}
}

resourceID, err := arm.ParseResourceID(tokenClaims.ResourceID)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -324,12 +323,12 @@ func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *p
return trace.AccessDenied("this token does not support the Azure join method")
}

vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
if err != nil {
return trace.Wrap(err)
}

vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, vmID, requestStart)
vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart)
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading
Loading