diff --git a/pkg/util/azure/credential.go b/pkg/util/azure/credential.go index 52df1798f60..d4cbb8888b9 100644 --- a/pkg/util/azure/credential.go +++ b/pkg/util/azure/credential.go @@ -26,54 +26,36 @@ import ( "github.com/pkg/errors" ) -// NewCredential chains the config credential , workload identity credential , managed identity credential +// NewCredential constructs a Credential that tries the config credential, workload identity credential +// and managed identity credential according to the provided creds. func NewCredential(creds map[string]string, options policy.ClientOptions) (azcore.TokenCredential, error) { - var ( - credential []azcore.TokenCredential - errMsgs []string - ) - additionalTenants := []string{} if tenants := creds[CredentialKeyAdditionallyAllowedTenants]; tenants != "" { additionalTenants = strings.Split(tenants, ";") } // config credential - cfgCred, err := newConfigCredential(creds, configCredentialOptions{ - ClientOptions: options, - AdditionallyAllowedTenants: additionalTenants, - }) - if err == nil { - credential = append(credential, cfgCred) - } else { - errMsgs = append(errMsgs, err.Error()) + if len(creds[CredentialKeyClientSecret]) > 0 || + len(creds[CredentialKeyClientCertificate]) > 0 || + len(creds[CredentialKeyClientCertificatePath]) > 0 || + len(creds[CredentialKeyUsername]) > 0 { + return newConfigCredential(creds, configCredentialOptions{ + ClientOptions: options, + AdditionallyAllowedTenants: additionalTenants, + }) } // workload identity credential - wic, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ - AdditionallyAllowedTenants: additionalTenants, - ClientOptions: options, - }) - if err == nil { - credential = append(credential, wic) - } else { - errMsgs = append(errMsgs, err.Error()) + if len(os.Getenv("AZURE_FEDERATED_TOKEN_FILE")) > 0 { + return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ + AdditionallyAllowedTenants: additionalTenants, + ClientOptions: options, + }) } - //managed identity credential + // managed identity credential o := &azidentity.ManagedIdentityCredentialOptions{ClientOptions: options, ID: azidentity.ClientID(creds[CredentialKeyClientID])} - msi, err := azidentity.NewManagedIdentityCredential(o) - if err == nil { - credential = append(credential, msi) - } else { - errMsgs = append(errMsgs, err.Error()) - } - - if len(credential) == 0 { - return nil, errors.Errorf("failed to create Azure credential: %s", strings.Join(errMsgs, "\n\t")) - } - - return azidentity.NewChainedTokenCredential(credential, nil) + return azidentity.NewManagedIdentityCredential(o) } type configCredentialOptions struct { diff --git a/pkg/util/azure/credential_test.go b/pkg/util/azure/credential_test.go index bdd33eadb02..35e34550d87 100644 --- a/pkg/util/azure/credential_test.go +++ b/pkg/util/azure/credential_test.go @@ -17,36 +17,63 @@ limitations under the License. package azure import ( - "context" + "os" "testing" - "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewCredential(t *testing.T) { options := policy.ClientOptions{} - // no credentials - creds := map[string]string{} - tokenCredential, _ := NewCredential(creds, options) - - var scopes []string - scopes = append(scopes, "https://management.core.windows.net//.default") - ctx, _ := context.WithTimeout(context.Background(), time.Second*2) - _, err := tokenCredential.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) + // invalid client secret credential (missing tenant ID) + creds := map[string]string{ + CredentialKeyClientID: "clientid", + CredentialKeyClientSecret: "secret", + } + _, err := NewCredential(creds, options) require.NotNil(t, err) - // config credential + // valid client secret credential creds = map[string]string{ CredentialKeyTenantID: "tenantid", CredentialKeyClientID: "clientid", CredentialKeyClientSecret: "secret", } - _, err = NewCredential(creds, options) + tokenCredential, err := NewCredential(creds, options) require.Nil(t, err) + assert.IsType(t, &azidentity.ClientSecretCredential{}, tokenCredential) + + // client certificate credential + certData, err := readCertData() + require.Nil(t, err) + creds = map[string]string{ + CredentialKeyTenantID: "tenantid", + CredentialKeyClientID: "clientid", + CredentialKeyClientCertificate: certData, + } + tokenCredential, err = NewCredential(creds, options) + require.Nil(t, err) + assert.IsType(t, &azidentity.ClientCertificateCredential{}, tokenCredential) + + // workload identity credential + os.Setenv(CredentialKeyTenantID, "tenantid") + os.Setenv(CredentialKeyClientID, "clientid") + os.Setenv("AZURE_FEDERATED_TOKEN_FILE", "/tmp/token") + creds = map[string]string{} + tokenCredential, err = NewCredential(creds, options) + require.Nil(t, err) + assert.IsType(t, &azidentity.WorkloadIdentityCredential{}, tokenCredential) + os.Clearenv() + + // managed identity credential + creds = map[string]string{} + tokenCredential, err = NewCredential(creds, options) + require.Nil(t, err) + assert.IsType(t, &azidentity.ManagedIdentityCredential{}, tokenCredential) } func Test_newConfigCredential(t *testing.T) { @@ -77,10 +104,12 @@ func Test_newConfigCredential(t *testing.T) { require.True(t, ok) // client certificate + certData, err := readCertData() + require.Nil(t, err) creds = map[string]string{ - CredentialKeyTenantID: "clientid", - CredentialKeyClientID: "clientid", - CredentialKeyClientCertificatePath: "testdata/certificate.pem", + CredentialKeyTenantID: "clientid", + CredentialKeyClientID: "clientid", + CredentialKeyClientCertificate: certData, } credential, err = newConfigCredential(creds, options) require.Nil(t, err) @@ -101,3 +130,11 @@ func Test_newConfigCredential(t *testing.T) { _, ok = credential.(*azidentity.UsernamePasswordCredential) require.True(t, ok) } + +func readCertData() (string, error) { + data, err := os.ReadFile("testdata/certificate.pem") + if err != nil { + return "", err + } + return string(data), nil +}