diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index ea2dcbd2e..5a2a7f709 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -24,7 +24,45 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { new ArrayList<>( Arrays.asList( "az", "account", "get-access-token", "--resource", resource, "--output", "json")); - return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + Optional subscription = getSubscription(config); + if (subscription.isPresent()) { + // This will fail if the user has access to the workspace, but not to the subscription + // itself. + // In such case, we fall back to not using the subscription. + List extendedCmd = new ArrayList<>(cmd); + extendedCmd.addAll(Arrays.asList("--subscription", subscription.get())); + try { + return getToken(config, extendedCmd); + } catch (DatabricksException ex) { + LOG.warn("Failed to get token for subscription. Using resource only token."); + } + } else { + LOG.warn( + "azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + } + + return getToken(config, cmd); + } + + protected CliTokenSource getToken(DatabricksConfig config, List cmd) { + CliTokenSource token = + new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + token.getToken(); // We need this to check if the CLI is installed and to validate the config. + return token; + } + + private Optional getSubscription(DatabricksConfig config) { + String resourceId = config.getAzureWorkspaceResourceId(); + if (resourceId == null || resourceId.equals("")) { + return Optional.empty(); + } + String[] components = resourceId.split("/"); + if (components.length < 3) { + LOG.warn("Invalid azure workspace resource ID"); + return Optional.empty(); + } + return Optional.of(components[2]); } @Override @@ -37,11 +75,10 @@ public HeaderFactory configure(DatabricksConfig config) { ensureHostPresent(config, mapper); String resource = config.getEffectiveAzureLoginAppId(); CliTokenSource tokenSource = tokenSourceFor(config, resource); - CliTokenSource mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - tokenSource.getToken(); // We need this for checking if Azure CLI is installed. + CliTokenSource mgmtTokenSource; try { - mgmtTokenSource.getToken(); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); } catch (Exception e) { LOG.debug("Not including service management token in headers", e); mgmtTokenSource = null; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 2fedbcf82..60feac6d5 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -74,7 +74,7 @@ public class DatabricksConfig { sensitive = true) private String googleCredentials; - /** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */ + /** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */ @ConfigAttribute( value = "azure_workspace_resource_id", env = "DATABRICKS_AZURE_RESOURCE_ID", diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java new file mode 100644 index 000000000..6b617f643 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -0,0 +1,103 @@ +package com.databricks.sdk.core; + +import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.times; + +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +class AzureCliCredentialsProviderTest { + + private static final String WORKSPACE_RESOURCE_ID = + "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + + private static CliTokenSource mockTokenSource() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()) + .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + return tokenSource; + } + + private static AzureCliCredentialsProvider getAzureCliCredentialsProvider( + TokenSource tokenSource) { + + AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); + Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList()); + + return provider; + } + + @Test + void testWorkSpaceIDUsage() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + ArgumentCaptor> argument = ArgumentCaptor.forClass(List.class); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, times(2)).getToken(any(), argument.capture()); + + List value = argument.getValue(); + value = value.subList(value.size() - 2, value.size()); + List expected = Arrays.asList("--subscription", SUBSCRIPTION); + assertEquals(expected, value); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = mockTokenSource(); + + AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); + Mockito.doThrow(new DatabricksException("error")).when(provider).getToken(any(), anyList()); + Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList()); + + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = + new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider); + + ArgumentCaptor> argument = ArgumentCaptor.forClass(List.class); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, times(2)).getToken(any(), argument.capture()); + + List value = argument.getValue(); + assertFalse(value.contains("--subscription")); + assertFalse(value.contains(SUBSCRIPTION)); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java new file mode 100644 index 000000000..e06683308 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java @@ -0,0 +1,60 @@ +package com.databricks.sdk.core.oauth; + +import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; + +import com.databricks.sdk.core.*; +import java.time.LocalDateTime; +import java.time.temporal.IsoFields; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +class AzureServicePrincipalCredentialsProviderTest { + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + private static RefreshableTokenSource mockTokenSource() { + RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class); + Mockito.when(tokenSource.getToken()) + .thenReturn( + new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS))); + return tokenSource; + } + + private static AzureServicePrincipalCredentialsProvider + getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) { + AzureServicePrincipalCredentialsProvider provider = + Mockito.spy(new AzureServicePrincipalCredentialsProvider()); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + return provider; + } + + @Test + void testGetToken() { + AzureServicePrincipalCredentialsProvider provider = + getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID"); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } +}