Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DECO-2483] Handle Azure authentication when WorkspaceResourceID is provided #145

Merged
merged 3 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenSource;
import com.databricks.sdk.core.utils.AzureUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.*;
Expand All @@ -27,6 +28,24 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
}

/**
 * Builds a token source backed by the Azure CLI for a given resource and subscription.
 *
 * <p>The command assembled here is {@code az account get-access-token --subscription
 * <subscription> --resource <resource> --output json}.
 *
 * @param config configuration whose {@code getAllEnv} method reference is handed to the token
 *     source
 * @param resource value passed to the CLI's {@code --resource} flag
 * @param subscription value passed to the CLI's {@code --subscription} flag
 * @return a {@link CliTokenSource} wrapping that command
 */
@Override
public CliTokenSource tokenSourceFor(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to somehow unify this with the other tokenSourceFor method. Not only do they share their implementation, but they probably must share their implementation (i.e. changes made to one will likely need to be made to the other in the future).

DatabricksConfig config, String resource, String subscription) {
List<String> cmd =
new ArrayList<>(
Arrays.asList(
"az",
"account",
"get-access-token",
"--subscription",
subscription,
"--resource",
resource,
"--output",
"json"));
// NOTE(review): the three string arguments are presumably the JSON field names that
// CliTokenSource reads from the CLI output — confirm against CliTokenSource.
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
}

@Override
public HeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()) {
Expand All @@ -35,20 +54,49 @@ public HeaderFactory configure(DatabricksConfig config) {

try {
ensureHostPresent(config, mapper);
CliTokenSource tokenSource;
CliTokenSource mgmtTokenSource;
String resource = config.getEffectiveAzureLoginAppId();
CliTokenSource tokenSource = tokenSourceFor(config, resource);
CliTokenSource mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
tokenSource.getToken(); // We need this for checking if Azure CLI is installed.
Optional<String> subscription = getSubscription(config);

if (subscription.isPresent()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we refactor all of this into the tokenSourceFor method? I think that would prevent this configure() method from sprawling, and it seems to belong there in the first place.

try {
// This will fail if the user has access to the workspace, but not to the subscription
// itself.
// In such case, we fall back to not using the subscription.
tokenSource = tokenSourceFor(config, resource, subscription.get());
tokenSource.getToken();
mgmtTokenSource =
tokenSourceFor(
config,
config.getAzureEnvironment().getServiceManagementEndpoint(),
subscription.get());
} catch (DatabricksException e) {
LOG.warn("Failed to get token for subscription. Using resource only token.");
tokenSource = tokenSourceFor(config, resource);
mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}
} else {
LOG.warn(
"azure_workspace_resource_id field not provided. "
+ "It is recommended to specify this field in the Databricks configuration to avoid authentication errors.");
tokenSource = tokenSourceFor(config, resource);
mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}

tokenSource.getToken(); // We need this for checking if Azure CLI is installed
try {
mgmtTokenSource.getToken();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now get the token inside the "tokenSourceFor" function.

} catch (Exception e) {
LOG.debug("Not including service management token in headers", e);
mgmtTokenSource = null;
}
TokenSource finalToken = tokenSource;
CliTokenSource finalMgmtTokenSource = mgmtTokenSource;
return () -> {
Token token = tokenSource.getToken();
Token token = finalToken.getToken();
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken());
if (finalMgmtTokenSource != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class DatabricksConfig {
sensitive = true)
private String googleCredentials;

/** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */
/** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */
@ConfigAttribute(
value = "azure_workspace_resource_id",
env = "DATABRICKS_AZURE_RESOURCE_ID",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
* Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request,
Expand All @@ -27,9 +28,40 @@ public HeaderFactory configure(DatabricksConfig config) {
return null;
}
ensureHostPresent(config, mapper);
RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
RefreshableTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
RefreshableTokenSource innerToken;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I was confused when we were talking about this earlier. We don't need to change this provider at all: the tenant ID must be explicitly specified (see line 27). What I meant was that if a user is logged into the Azure CLI with a service principal, AzureCliCredentialsProvider will still take the same code path.

RefreshableTokenSource cloudToken;
Optional<String> subscription = getSubscription(config);
if (subscription.isPresent()) {
try {
// This will fail if the service principal has access to the workspace, but not to the
// subscription itself.
// In such case, we fall back to not using the subscription.
innerToken =
tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get());
cloudToken =
tokenSourceFor(
config,
config.getAzureEnvironment().getServiceManagementEndpoint(),
subscription.get());
innerToken.getToken();
cloudToken.getToken();
} catch (DatabricksException e) {
LOG.warn("Failed to get token for subscription. Using resource only token.");
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
cloudToken =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}
} else {
LOG.warn(
"azure_workspace_resource_id field not provided. "
+ "It is recommended to specify this field in the Databricks configuration to avoid authentication errors.");
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
cloudToken =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
}

RefreshableTokenSource inner = innerToken;
RefreshableTokenSource cloud = cloudToken;

return () -> {
Map<String, String> headers = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public interface AzureUtils {
Logger LOG = LoggerFactory.getLogger(AzureUtils.class);

/**
* Creates a RefreshableTokenSource for the specified Azure resource.
Expand All @@ -30,10 +32,23 @@ public interface AzureUtils {
* Azure resource.
*/
default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
  // Only the request parameters are assembled here; the Map-based overload this delegates to
  // derives the Active Directory token URL itself, so it is not computed in this method.
  Map<String, String> endpointParams = new HashMap<>();
  endpointParams.put("resource", resource);
  return tokenSourceFor(config, endpointParams);
}

/**
 * Creates a RefreshableTokenSource for the specified Azure resource, additionally sending the
 * given subscription as an endpoint parameter.
 *
 * @param config the configuration forwarded to the Map-based overload
 * @param resource value stored under the {@code "resource"} endpoint parameter
 * @param subscription value stored under the {@code "subscription"} endpoint parameter
 * @return the token source produced by the Map-based overload
 */
default RefreshableTokenSource tokenSourceFor(
    DatabricksConfig config, String resource, String subscription) {
  final Map<String, String> requestParams = new HashMap<>();
  requestParams.put("resource", resource);
  requestParams.put("subscription", subscription);

  return tokenSourceFor(config, requestParams);
}

default RefreshableTokenSource tokenSourceFor(
DatabricksConfig config, Map<String, String> endpointParams) {
String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
return new ClientCredentials.Builder()
.withHttpClient(config.getHttpClient())
.withClientId(config.getAzureClientId())
Expand All @@ -44,6 +59,19 @@ default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String re
.build();
}

/**
 * Extracts the subscription ID from the configured Azure workspace resource ID.
 *
 * <p>The resource ID is expected to have the form {@code
 * /subscriptions/<subscription-id>/resourceGroups/...}; the leading slash is optional. The
 * {@code subscriptions} keyword is matched case-insensitively.
 *
 * @param config the configuration to read the workspace resource ID from
 * @return the subscription ID, or an empty Optional when no resource ID is configured or the
 *     configured value does not contain a non-empty subscription segment
 */
default Optional<String> getSubscription(DatabricksConfig config) {
  String resourceId = config.getAzureWorkspaceResourceId();
  if (resourceId == null || resourceId.isEmpty()) {
    return Optional.empty();
  }
  String[] components = resourceId.split("/");
  // A leading "/" produces an empty first component; skip it so IDs with and without the
  // leading slash are handled the same way. split() may also return an empty array (e.g. "/").
  int keywordIndex = (components.length > 0 && components[0].isEmpty()) ? 1 : 0;
  int subscriptionIndex = keywordIndex + 1;
  if (components.length <= subscriptionIndex
      || !"subscriptions".equalsIgnoreCase(components[keywordIndex])
      || components[subscriptionIndex].isEmpty()) {
    // Returning an arbitrary path segment here would be passed on as a bogus subscription.
    LOG.warn("Invalid azure workspace resource ID");
    return Optional.empty();
  }
  return Optional.of(components[subscriptionIndex]);
}

default String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
JsonNode properties = jsonResponse.get("properties");
if (properties == null) {
Expand All @@ -69,7 +97,7 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) {
}

String armEndpoint = config.getAzureEnvironment().getResourceManagerEndpoint();
Token token = tokenSourceFor(config, armEndpoint).getToken();
Token token = tokenSourceFor(config, "resource", armEndpoint).getToken();
String requestUrl =
armEndpoint + config.getAzureWorkspaceResourceId() + "?api-version=2018-04-01";
Request req = new Request("GET", requestUrl);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package com.databricks.sdk.core;

import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;

import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenSource;
import java.time.LocalDateTime;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/**
 * Tests how {@link AzureCliCredentialsProvider#configure} chooses between subscription-scoped and
 * resource-only token sources, using a spied provider whose {@code tokenSourceFor} overloads are
 * stubbed.
 */
class AzureCliCredentialsProviderTest {

  private static final String WORKSPACE_RESOURCE_ID =
      "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws";
  private static final String SUBSCRIPTION = "2a2345f8";
  private static final String TOKEN = "t-123";
  private static final String TOKEN_TYPE = "token-type";
  public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/";

  /** Returns a mocked token source whose {@code getToken()} always yields the test token. */
  private static CliTokenSource mockTokenSource() {
    CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class);
    Mockito.when(tokenSource.getToken())
        .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now()));
    return tokenSource;
  }

  /**
   * Returns a spied provider whose {@code tokenSourceFor} overloads return {@code tokenSource} for
   * the Databricks resource and the management endpoint, with and without the subscription.
   */
  private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(
      TokenSource tokenSource) {

    AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider());
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION));

    return provider;
  }

  /** With a workspace resource ID configured, only the subscription-scoped source is used. */
  @Test
  void testWorkSpaceIDUsage() {
    AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
    DatabricksConfig config =
        new DatabricksConfig()
            .setHost(".azuredatabricks.")
            .setCredentialsProvider(provider)
            .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
    // messages report the right value as "expected".
    assertEquals(TOKEN_TYPE + " " + TOKEN, token);
    Mockito.verify(provider, times(1))
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
  }

  /**
   * When the first {@code getToken()} call throws a {@link DatabricksException}, the provider
   * falls back to the resource-only token source.
   */
  @Test
  void testFallbackWhenFailsToGetTokenForSubscription() {
    CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class);
    // First call fails (subscription-scoped attempt); subsequent calls succeed.
    Mockito.when(tokenSource.getToken())
        .thenThrow(new DatabricksException("error"))
        .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now()));

    AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource);

    DatabricksConfig config =
        new DatabricksConfig()
            .setHost(".azuredatabricks.")
            .setCredentialsProvider(provider)
            .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    assertEquals(TOKEN_TYPE + " " + TOKEN, token);

    Mockito.verify(provider, times(1))
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
  }

  /** Without a workspace resource ID, the subscription-scoped source is never requested. */
  @Test
  void testGetTokenWithoutWorkspaceResourceID() {
    AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
    DatabricksConfig config =
        new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider);

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    assertEquals(TOKEN_TYPE + " " + TOKEN, token);
    Mockito.verify(provider, never())
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION));
    Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
  }
}
Loading
Loading