From 758886102fec6364ef9a39f8946a3b118ecfaa4b Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Thu, 31 Aug 2023 11:40:24 +0200 Subject: [PATCH] [DECO-2483] Handle Azure authentication when WorkspaceResourceID is provided --- .../sdk/core/AzureCliCredentialsProvider.java | 46 +++++++- .../databricks/sdk/core/DatabricksConfig.java | 2 +- ...reServicePrincipalCredentialsProvider.java | 29 ++++- .../databricks/sdk/core/utils/AzureUtils.java | 54 +++++++-- .../core/AzureCliCredentialsProviderTest.java | 98 ++++++++++++++++ ...rvicePrincipalCredentialsProviderTest.java | 108 ++++++++++++++++++ 6 files changed, 316 insertions(+), 21 deletions(-) create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index ea2dcbd2e..b32a3124c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -1,6 +1,7 @@ package com.databricks.sdk.core; import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.AzureUtils; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.*; @@ -27,6 +28,15 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); } + @Override + public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { + List cmd = + new ArrayList<>( + Arrays.asList( + "az", "account", "get-access-token", "--subscription", subscription, "--resource", resource, "--output", "json")); + return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + } + @Override public HeaderFactory configure(DatabricksConfig config) { if (!config.isAzure()) { @@ -35,20 +45,44 @@ public HeaderFactory configure(DatabricksConfig config) { try { ensureHostPresent(config, mapper); + CliTokenSource tokenSource; + CliTokenSource mgmtTokenSource; String resource = config.getEffectiveAzureLoginAppId(); - CliTokenSource tokenSource = tokenSourceFor(config, resource); - CliTokenSource mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - tokenSource.getToken(); // We need this for checking if Azure CLI is installed. + Optional subscription = getSubscription(config); + + if (subscription.isPresent()) { + try { + // This will fail if the user has access to the workspace, but not to the subscription itself. + // In such case, we fall back to not using the subscription. + tokenSource = tokenSourceFor(config, resource, subscription.get()); + tokenSource.getToken(); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); + } catch (DatabricksException e) { + LOG.warn("Failed to get token for subscription. Using resource only token."); + tokenSource = tokenSourceFor(config, resource); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + } else { + LOG.warn("azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + tokenSource = tokenSourceFor(config, resource); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + + tokenSource.getToken(); // We need this for checking if Azure CLI is installed try { mgmtTokenSource.getToken(); } catch (Exception e) { LOG.debug("Not including service management token in headers", e); mgmtTokenSource = null; } + TokenSource finalToken = tokenSource; CliTokenSource finalMgmtTokenSource = mgmtTokenSource; return () -> { - Token token = tokenSource.getToken(); + Token token = finalToken.getToken(); Map headers = new HashMap<>(); headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); if (finalMgmtTokenSource != null) { @@ -67,3 +101,5 @@ public HeaderFactory configure(DatabricksConfig config) { } } } + + diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 2fedbcf82..60feac6d5 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -74,7 +74,7 @@ public class DatabricksConfig { sensitive = true) private String googleCredentials; - /** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */ + /** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */ @ConfigAttribute( value = "azure_workspace_resource_id", env = "DATABRICKS_AZURE_RESOURCE_ID", diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index ea5a5dee8..e867409fa 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, @@ -27,9 +28,31 @@ public HeaderFactory configure(DatabricksConfig config) { return null; } ensureHostPresent(config, mapper); - RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - RefreshableTokenSource cloud = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + RefreshableTokenSource innerToken; + RefreshableTokenSource cloudToken; + Optional subscription = getSubscription(config); + if (subscription.isPresent()) { + try { + // This will fail if the service principal has access to the workspace, but not to the subscription itself. + // In such case, we fall back to not using the subscription. + innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get()); + cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); + innerToken.getToken(); + cloudToken.getToken(); + } catch (DatabricksException e) { + LOG.warn("Failed to get token for subscription. Using resource only token."); + innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + } else { + LOG.warn("azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + + RefreshableTokenSource inner = innerToken; + RefreshableTokenSource cloud = cloudToken; return () -> { Map headers = new HashMap<>(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index 1a73ea630..42d3cd332 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -1,5 +1,6 @@ package com.databricks.sdk.core.utils; +import com.databricks.sdk.core.AzureCliCredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.Request; @@ -11,11 +12,14 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.IOException; -import java.util.HashMap; -import java.util.Map; +import java.util.*; public interface AzureUtils { + static final Logger LOG = LoggerFactory.getLogger(AzureUtils.class); /** * Creates a RefreshableTokenSource for the specified Azure resource. @@ -30,20 +34,46 @@ public interface AzureUtils { * Azure resource. */ default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) { - String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); - String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; Map endpointParams = new HashMap<>(); endpointParams.put("resource", resource); + return tokenSourceFor(config, endpointParams); + } + + default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { + Map endpointParams = new HashMap<>(); + endpointParams.put("resource", resource); + endpointParams.put("subscription", subscription); + return tokenSourceFor(config, endpointParams); + } + + default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, Map endpointParams) { + String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); + String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; return new ClientCredentials.Builder() - .withHttpClient(config.getHttpClient()) - .withClientId(config.getAzureClientId()) - .withClientSecret(config.getAzureClientSecret()) - .withTokenUrl(tokenUrl) - .withEndpointParameters(endpointParams) - .withAuthParameterPosition(AuthParameterPosition.BODY) - .build(); + .withHttpClient(config.getHttpClient()) + .withClientId(config.getAzureClientId()) + .withClientSecret(config.getAzureClientSecret()) + .withTokenUrl(tokenUrl) + .withEndpointParameters(endpointParams) + .withAuthParameterPosition(AuthParameterPosition.BODY) + .build(); + } + + default Optional getSubscription(DatabricksConfig config) { + String resourceId = config.getAzureWorkspaceResourceId(); + if (resourceId == null || resourceId.equals("")) { + return Optional.empty(); + } + String[] components = resourceId.split("/"); + if (components.length < 3) { + LOG.warn("Invalid azure workspace resource ID"); + return Optional.empty(); + } + return Optional.of(components[2]); + } + default String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException { JsonNode properties = jsonResponse.get("properties"); if (properties == null) { @@ -69,7 +99,7 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) { } String armEndpoint = config.getAzureEnvironment().getResourceManagerEndpoint(); - Token token = tokenSourceFor(config, armEndpoint).getToken(); + Token token = tokenSourceFor(config, "resource", armEndpoint).getToken(); String requestUrl = armEndpoint + config.getAzureWorkspaceResourceId() + "?api-version=2018-04-01"; Request req = new Request("GET", requestUrl); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java new file mode 100644 index 000000000..82dedb68d --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -0,0 +1,98 @@ +package com.databricks.sdk.core; + +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.time.LocalDateTime; + +import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +class AzureCliCredentialsProviderTest { + + private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + + private static CliTokenSource mockTokenSource() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + return tokenSource; + } + private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(TokenSource tokenSource) { + + AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + + return provider; + } + + + @Test + void testWorkSpaceIDUsage() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource); + + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + + HeaderFactory header = provider.configure(config); + + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + +} \ No newline at end of file diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java new file mode 100644 index 000000000..3619bd7cc --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java @@ -0,0 +1,108 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.*; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.time.LocalDateTime; +import java.time.temporal.IsoFields; + +import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +class AzureServicePrincipalCredentialsProviderTest { + + private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + private static RefreshableTokenSource mockTokenSource() { + RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class); + Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS))); + return tokenSource; + } + private static AzureServicePrincipalCredentialsProvider getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) { + AzureServicePrincipalCredentialsProvider provider = Mockito.spy(new AzureServicePrincipalCredentialsProvider()); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + return provider; + } + + + @Test + void testWorkSpaceIDUsage() { + AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID") + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + + AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(tokenSource); + + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID") + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID"); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + +} \ No newline at end of file