Skip to content

Commit

Permalink
[DECO-2483] Handle Azure authentication when WorkspaceResourceID is p…
Browse files Browse the repository at this point in the history
…rovided
  • Loading branch information
hectorcast-db committed Aug 31, 2023
1 parent 3a8494e commit 7588861
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenSource;
import com.databricks.sdk.core.utils.AzureUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.*;
Expand All @@ -27,6 +28,15 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
}

/**
 * Creates a {@link CliTokenSource} configured with an {@code az account get-access-token} command
 * line that names both the subscription and the resource, requesting JSON output.
 *
 * @param config configuration whose {@code getAllEnv} is handed to the token source
 * @param resource value passed after {@code --resource}
 * @param subscription value passed after {@code --subscription}
 * @return a token source built from the assembled command
 */
@Override
public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) {
  List<String> azCommand = new ArrayList<>();
  Collections.addAll(azCommand, "az", "account", "get-access-token");
  Collections.addAll(azCommand, "--subscription", subscription);
  Collections.addAll(azCommand, "--resource", resource);
  Collections.addAll(azCommand, "--output", "json");
  // The three names are presumably the fields read from the CLI's JSON output -- confirm in CliTokenSource.
  return new CliTokenSource(azCommand, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
}

@Override
public HeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()) {
Expand All @@ -35,20 +45,44 @@ public HeaderFactory configure(DatabricksConfig config) {

try {
ensureHostPresent(config, mapper);
CliTokenSource tokenSource;
CliTokenSource mgmtTokenSource;
String resource = config.getEffectiveAzureLoginAppId();
CliTokenSource tokenSource = tokenSourceFor(config, resource);
CliTokenSource mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
tokenSource.getToken(); // We need this for checking if Azure CLI is installed.
Optional<String> subscription = getSubscription(config);

if (subscription.isPresent()) {
try {
// This will fail if the user has access to the workspace, but not to the subscription itself.
// In such case, we fall back to not using the subscription.
tokenSource = tokenSourceFor(config, resource, subscription.get());
tokenSource.getToken();
mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get());
} catch (DatabricksException e) {
LOG.warn("Failed to get token for subscription. Using resource only token.");
tokenSource = tokenSourceFor(config, resource);
mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}
} else {
LOG.warn("azure_workspace_resource_id field not provided. " +
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors.");
tokenSource = tokenSourceFor(config, resource);
mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}

tokenSource.getToken(); // We need this for checking if Azure CLI is installed
try {
mgmtTokenSource.getToken();
} catch (Exception e) {
LOG.debug("Not including service management token in headers", e);
mgmtTokenSource = null;
}
TokenSource finalToken = tokenSource;
CliTokenSource finalMgmtTokenSource = mgmtTokenSource;
return () -> {
Token token = tokenSource.getToken();
Token token = finalToken.getToken();
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken());
if (finalMgmtTokenSource != null) {
Expand All @@ -67,3 +101,5 @@ public HeaderFactory configure(DatabricksConfig config) {
}
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class DatabricksConfig {
sensitive = true)
private String googleCredentials;

/** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */
/** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */
@ConfigAttribute(
value = "azure_workspace_resource_id",
env = "DATABRICKS_AZURE_RESOURCE_ID",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
* Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request,
Expand All @@ -27,9 +28,31 @@ public HeaderFactory configure(DatabricksConfig config) {
return null;
}
ensureHostPresent(config, mapper);
RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
RefreshableTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
RefreshableTokenSource innerToken;
RefreshableTokenSource cloudToken;
Optional<String> subscription = getSubscription(config);
if (subscription.isPresent()) {
try {
// This will fail if the service principal has access to the workspace, but not to the subscription itself.
// In such case, we fall back to not using the subscription.
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get());
cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get());
innerToken.getToken();
cloudToken.getToken();
} catch (DatabricksException e) {
LOG.warn("Failed to get token for subscription. Using resource only token.");
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}
} else {
LOG.warn("azure_workspace_resource_id field not provided. " +
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors.");
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}

RefreshableTokenSource inner = innerToken;
RefreshableTokenSource cloud = cloudToken;

return () -> {
Map<String, String> headers = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.sdk.core.utils;

import com.databricks.sdk.core.AzureCliCredentialsProvider;
import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.http.Request;
Expand All @@ -11,11 +12,14 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.*;

public interface AzureUtils {
static final Logger LOG = LoggerFactory.getLogger(AzureUtils.class);

/**
* Creates a RefreshableTokenSource for the specified Azure resource.
Expand All @@ -30,20 +34,46 @@ public interface AzureUtils {
* Azure resource.
*/
default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
  // Only the "resource" endpoint parameter is assembled here. The token URL is derived by the
  // map-based overload this call delegates to, so no endpoint locals are needed in this method.
  Map<String, String> endpointParams = new HashMap<>();
  endpointParams.put("resource", resource);
  return tokenSourceFor(config, endpointParams);
}

/**
 * Creates a RefreshableTokenSource whose endpoint parameters carry both the resource and the
 * subscription.
 *
 * @param config the configuration forwarded to the map-based overload
 * @param resource value stored under the {@code "resource"} endpoint parameter
 * @param subscription value stored under the {@code "subscription"} endpoint parameter
 * @return the token source produced by the map-based overload
 */
default RefreshableTokenSource tokenSourceFor(
    DatabricksConfig config, String resource, String subscription) {
  Map<String, String> params = new HashMap<>();
  params.put("resource", resource);
  params.put("subscription", subscription);
  return tokenSourceFor(config, params);
}

default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, Map<String, String> endpointParams) {
String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
return new ClientCredentials.Builder()
.withHttpClient(config.getHttpClient())
.withClientId(config.getAzureClientId())
.withClientSecret(config.getAzureClientSecret())
.withTokenUrl(tokenUrl)
.withEndpointParameters(endpointParams)
.withAuthParameterPosition(AuthParameterPosition.BODY)
.build();
.withHttpClient(config.getHttpClient())
.withClientId(config.getAzureClientId())
.withClientSecret(config.getAzureClientSecret())
.withTokenUrl(tokenUrl)
.withEndpointParameters(endpointParams)
.withAuthParameterPosition(AuthParameterPosition.BODY)
.build();
}

/**
 * Extracts the subscription ID from the configured Azure workspace resource ID.
 *
 * <p>The ID is expected to look like {@code /subscriptions/<subscription-id>/...}. The value is
 * returned only when the first non-empty path segment is {@code subscriptions} (compared
 * case-insensitively) and it is followed by a non-empty segment. The leading slash is optional.
 *
 * @param config configuration providing the workspace resource ID
 * @return the subscription ID, or an empty Optional when the resource ID is unset, empty, or
 *     does not start with a subscription segment (a warning is logged in that last case)
 */
default Optional<String> getSubscription(DatabricksConfig config) {
  String resourceId = config.getAzureWorkspaceResourceId();
  if (resourceId == null || resourceId.isEmpty()) {
    return Optional.empty();
  }
  String[] components = resourceId.split("/");
  // A leading "/" produces an empty first segment; skip it so IDs with and without the leading
  // slash resolve to the same subscription.
  int start = 0;
  while (start < components.length && components[start].isEmpty()) {
    start++;
  }
  boolean hasSubscription =
      start + 1 < components.length
          && "subscriptions".equalsIgnoreCase(components[start])
          && !components[start + 1].isEmpty();
  if (!hasSubscription) {
    LOG.warn("Invalid azure workspace resource ID");
    return Optional.empty();
  }
  return Optional.of(components[start + 1]);
}


default String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
JsonNode properties = jsonResponse.get("properties");
if (properties == null) {
Expand All @@ -69,7 +99,7 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) {
}

String armEndpoint = config.getAzureEnvironment().getResourceManagerEndpoint();
Token token = tokenSourceFor(config, armEndpoint).getToken();
Token token = tokenSourceFor(config, "resource", armEndpoint).getToken();
String requestUrl =
armEndpoint + config.getAzureWorkspaceResourceId() + "?api-version=2018-04-01";
Request req = new Request("GET", requestUrl);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenSource;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.time.LocalDateTime;

import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;

/**
 * Tests for {@link AzureCliCredentialsProvider#configure(DatabricksConfig)}, checking which
 * {@code tokenSourceFor} overload is requested depending on whether a workspace resource ID (and
 * therefore a subscription) is configured.
 */
class AzureCliCredentialsProviderTest {

  private static final String WORKSPACE_RESOURCE_ID =
      "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws";
  private static final String SUBSCRIPTION = "2a2345f8";
  private static final String TOKEN = "t-123";
  private static final String TOKEN_TYPE = "token-type";
  public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/";

  /** Returns a mocked token source whose {@code getToken()} always yields the fixed test token. */
  private static CliTokenSource mockTokenSource() {
    CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class);
    Mockito.when(tokenSource.getToken())
        .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now()));
    return tokenSource;
  }

  /**
   * Returns a spied provider whose {@code tokenSourceFor} overloads are stubbed to hand back the
   * given token source for the Databricks resource and the management endpoint, both with and
   * without the test subscription.
   */
  private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(
      TokenSource tokenSource) {
    AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider());
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION));
    return provider;
  }

  /** With a workspace resource ID set, only the subscription-aware overload is requested. */
  @Test
  void testWorkSpaceIDUsage() {
    AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
    DatabricksConfig config =
        new DatabricksConfig()
            .setHost(".azuredatabricks.")
            .setCredentialsProvider(provider)
            .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    // JUnit's assertEquals takes (expected, actual); keeping that order makes failure messages
    // report the right values.
    assertEquals(TOKEN_TYPE + " " + TOKEN, token);
    Mockito.verify(provider, times(1))
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
  }

  /**
   * When the first {@code getToken()} call throws a {@link DatabricksException}, the provider is
   * expected to request the resource-only token source as well.
   */
  @Test
  void testFallbackWhenFailsToGetTokenForSubscription() {
    CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class);
    Mockito.when(tokenSource.getToken())
        .thenThrow(new DatabricksException("error"))
        .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now()));

    AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource);

    DatabricksConfig config =
        new DatabricksConfig()
            .setHost(".azuredatabricks.")
            .setCredentialsProvider(provider)
            .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    assertEquals(TOKEN_TYPE + " " + TOKEN, token);

    Mockito.verify(provider, times(1))
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
  }

  /** Without a workspace resource ID, only the resource-only overload is requested. */
  @Test
  void testGetTokenWithoutWorkspaceResourceID() {
    AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
    DatabricksConfig config =
        new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider);

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    assertEquals(TOKEN_TYPE + " " + TOKEN, token);
    Mockito.verify(provider, never())
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
  }
}
Loading

0 comments on commit 7588861

Please sign in to comment.