From bb96aac144422722ecf433d51aad71b861b4b5ad Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Wed, 27 Mar 2024 20:16:13 +0100
Subject: [PATCH] Added Databricks CLI & Azure CLI authentication and a basic integration test

---
Review notes (everything between the `---` line and the first `diff --git`
line is ignored by `git am`):

- `pat` credentials now send the `Authorization` header instead of
  `Authentication`.
- `.create_token()` works with its default `expiry = NULL` and falls back to
  the "Bearer" scheme when no token type is given.
- Expiry parsing tries every known layout instead of returning after the
  first one; unparsable values fall back to the default token lifetime.
- CLI invocation passes the executable and its (shell-quoted) arguments
  separately, keeps stderr out of the JSON and checks the exit status.
- `Sys.which()` results are tested with `nzchar()`, not with `tryCatch()`.
- Hunk headers and the diffstat were recounted by hand; the `index` lines
  still carry the blob ids of the original patch.

 .github/workflows/acceptance.yaml  |  47 ++++++++
 R/api_client.R                     | 173 ++++++++++++++++++++++++++++-
 tests/testthat/test_api_client.R   |   2 +-
 tests/testthat/test_current_user.R |   9 ++
 4 files changed, 224 insertions(+), 7 deletions(-)
 create mode 100644 .github/workflows/acceptance.yaml
 create mode 100644 tests/testthat/test_current_user.R

diff --git a/.github/workflows/acceptance.yaml b/.github/workflows/acceptance.yaml
new file mode 100644
index 00000000..e66d6ae2
--- /dev/null
+++ b/.github/workflows/acceptance.yaml
@@ -0,0 +1,47 @@
+name: acceptance
+
+on:
+  pull_request:
+    types: [opened, synchronize]
+
+permissions:
+  id-token: write
+  contents: read
+  pull-requests: write
+
+jobs:
+  integration:
+    if: github.event_name == 'pull_request'
+    environment: admin
+    runs-on: larger
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v2.5.0
+
+      - name: Unshallow
+        run: git fetch --prune --unshallow
+
+      - uses: actions/checkout@v3
+
+      - uses: r-lib/actions/setup-r@v2
+        with:
+          r-version: release
+          use-public-rspm: true
+
+      - uses: azure/login@v1
+        with:
+          client-id: ${{ secrets.ARM_CLIENT_ID }}
+          tenant-id: ${{ secrets.ARM_TENANT_ID }}
+          allow-no-subscriptions: true
+
+      - uses: r-lib/actions/setup-r-dependencies@v2
+        with:
+          extra-packages: devtools
+        env:
+          R_COMPILE_AND_INSTALL_PACKAGES: never
+
+      - name: Run tests
+        run: Rscript -e "devtools::test()"
+        env:
+          CLOUD_ENV: "${{ vars.CLOUD_ENV }}"
+          DATABRICKS_HOST: "${{ secrets.DATABRICKS_HOST }}"
diff --git a/R/api_client.R b/R/api_client.R
index 6b2972f2..3a9dd762 100644
--- a/R/api_client.R
+++ b/R/api_client.R
@@ -86,10 +86,11 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
 
   # cfg is the current unified authentication config of direct parameters,
   # environment variables, and values loaded from ~/.databrickscfg file
-  cfg <- list(host = coalesce(host, Sys.getenv("DATABRICKS_HOST"), from_cli$host),
-    token = coalesce(token, Sys.getenv("DATABRICKS_TOKEN"), from_cli$token),
-    client_id = coalesce(Sys.getenv("DATABRICKS_CLIENT_ID"), from_cli$client_id),
-    client_secret = coalesce(Sys.getenv("DATABRICKS_CLIENT_SECRET"), from_cli$client_secret))
+  cfg <- new.env()
+  cfg$host <- coalesce(host, from_cli$host, Sys.getenv("DATABRICKS_HOST"))
+  cfg$token <- coalesce(token, from_cli$token, Sys.getenv("DATABRICKS_TOKEN"))
+  cfg$client_id <- coalesce(from_cli$client_id, Sys.getenv("DATABRICKS_CLIENT_ID"))
+  cfg$client_secret <- coalesce(from_cli$client_secret, Sys.getenv("DATABRICKS_CLIENT_SECRET"))
 
   # add the missing https:// prefix to bare, ODBC-style hosts
   if (!is.null(cfg$host) && !startsWith(cfg$host, "http")) {
@@ -105,7 +106,7 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
     used <- c()
     sensitive <- c("token", "password", "client_secret", "google_credentials",
       "azure_client_secret")
-    for (attr in names(cfg)) {
+    for (attr in sort(names(cfg))) {
       value <- cfg[[attr]]
       if (is.null(value)) {
         next
@@ -140,6 +141,33 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
     return(function() {
-      c(Authentication = paste("Bearer", cfg$token))
+      c(Authorization = paste("Bearer", cfg$token))
     })
+  }, `databricks-cli` = function() {
+    token_source <- .databricks_cli_token_source(cfg)
+    if (is.null(token_source)) {
+      return(NULL)
+    }
+    result <- try(token_source$token(), silent = TRUE)
+    if (inherits(result, "try-error")) {
+      return(NULL)
+    }
+    return(function() {
+      token <- token_source$token()
+      return(token$headers())
+    })
+  }, `azure-cli` = function() {
+    if (!is_azure()) {
+      return(NULL)
+    }
+    # well-known application ID of the Azure Databricks resource
+    token_source <- .azure_cli_token_source("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
+    result <- try(token_source$token(), silent = TRUE)
+    if (inherits(result, "try-error")) {
+      return(NULL)
+    }
+    return(function() {
+      token <- token_source$token()
+      return(token$headers())
+    })
   })
 
   # authenticate follows the semantics of Unified Client Authentication and
@@ -212,3 +240,136 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
   return(list(is_aws = is_aws, is_azure = is_azure, is_gcp = is_gcp, do = do,
     debug_string = debug_string))
 }
+
+# Build a token object from the values reported by a CLI.
+#
+# `expiry` is a POSIXct instant; when it is NULL, empty or NA the token is
+# assumed to live for five minutes. A missing `token_type` becomes "Bearer".
+# Returns a list of closures: headers() builds the Authorization header and
+# valid() tells whether the token may still be sent.
+.create_token <- function(access_token, token_type = NULL, expiry = NULL) {
+  # is.na(NULL) is logical(0), which `if` rejects, so check the length first
+  if (length(expiry) != 1 || is.na(expiry)) {
+    expiry <- Sys.time() + as.difftime(300, units = "secs")
+  }
+  if (length(token_type) != 1 || is.na(token_type) || !nzchar(token_type)) {
+    token_type <- "Bearer"
+  }
+  headers <- function() {
+    c(Authorization = paste(token_type, access_token))
+  }
+  expired <- function() {
+    # Azure Databricks rejects tokens that expire in 30 seconds or less, so we
+    # refresh the token 40 seconds before it expires.
+    expiry - as.difftime(40, units = "secs") < Sys.time()
+  }
+  valid <- function() {
+    !is.null(access_token) && !expired()
+  }
+  list(headers = headers, valid = valid)
+}
+
+# Cache the token produced by `refresh` (a function without arguments) and call
+# `refresh` again only when the cached token stops being valid.
+.refreshable_token_source <- function(refresh) {
+  state <- new.env()
+  state$token <- NULL
+  # this is not thread-safe, but R is single-threaded
+  list(token = function() {
+    cached <- state$token
+    if (!is.null(cached) && cached$valid()) {
+      return(cached)
+    }
+    state$token <- refresh()
+    state$token
+  })
+}
+
+# Placeholder that shows the shape of a token source: a list with a token() function.
+.token_source <- function() {
+  list(token = function() {
+    stop("token method must be implemented", call. = FALSE)
+  })
+}
+
+# Token source backed by a command line tool that prints a JSON document.
+#
+# `cmd` is a character vector holding the executable followed by its arguments;
+# the `*_field` arguments name the JSON members to read.
+.cli_token_source <- function(cmd, token_type_field, access_token_field, expiry_field) {
+  parse_expiry <- function(expiry) {
+    if (!is.character(expiry) || length(expiry) != 1 || is.na(expiry)) {
+      return(as.POSIXct(NA))
+    }
+    # as.POSIXct() reports a format mismatch as NA rather than as an error, so
+    # every candidate is checked explicitly. The offset-aware form goes first:
+    # "Z" and "+hh:mm" are rewritten to the "+hhmm" notation of %z.
+    with_offset <- sub("Z$", "+0000", expiry)
+    with_offset <- sub("([+-][0-9]{2}):([0-9]{2})$", "\\1\\2", with_offset)
+    candidates <- list(c(with_offset, "%Y-%m-%dT%H:%M:%OS%z"),
+      c(expiry, "%Y-%m-%dT%H:%M:%OS"), c(expiry, "%Y-%m-%d %H:%M:%OS"))
+    for (candidate in candidates) {
+      parsed <- as.POSIXct(candidate[[1]], format = candidate[[2]])
+      if (!is.na(parsed)) {
+        return(parsed)
+      }
+    }
+    # unknown layout: .create_token() falls back to its default lifetime
+    as.POSIXct(NA)
+  }
+  refresh <- function() {
+    # stderr goes to a file so that diagnostics cannot corrupt the JSON on stdout
+    err_file <- tempfile("cli-stderr-")
+    on.exit(unlink(err_file), add = TRUE)
+    # `command` must be a single string; arguments are quoted because they
+    # include the user-supplied host. A non-zero exit code only raises a
+    # warning, which is muted here because the status is inspected below.
+    out <- tryCatch({
+      suppressWarnings(system2(cmd[[1]], args = shQuote(cmd[-1]), stdout = TRUE,
+        stderr = err_file))
+    }, error = function(e) {
+      stop("cannot execute CLI command: ", conditionMessage(e), call. = FALSE)
+    })
+    status <- attr(out, "status")
+    if (!is.null(status) && status != 0) {
+      err <- if (file.exists(err_file)) readLines(err_file, warn = FALSE) else character()
+      stop("cannot get access token: ", paste(c(err, out), collapse = "\n"), call. = FALSE)
+    }
+    raw <- paste(out, collapse = "\n")
+    it <- tryCatch(jsonlite::fromJSON(raw), error = function(e) {
+      stop(paste("cannot unmarshal CLI result:", raw, conditionMessage(e)), call. = FALSE)
+    })
+    access_token <- if (is.list(it)) it[[access_token_field]] else NULL
+    if (!is.character(access_token) || length(access_token) != 1 || !nzchar(access_token)) {
+      stop("cannot get access token: no ", access_token_field, " in CLI result", call. = FALSE)
+    }
+    .create_token(access_token, it[[token_type_field]], parse_expiry(it[[expiry_field]]))
+  }
+  .refreshable_token_source(refresh)
+}
+
+# Token source that runs `databricks auth token --host <cfg$host>`. Returns NULL
+# when no host is configured or the executable cannot be found.
+.databricks_cli_token_source <- function(cfg) {
+  if (is.null(cfg$host)) {
+    return(NULL)
+  }
+  # Sys.which() signals a missing executable with "" (or NA), never an error
+  found <- function(path) !is.na(path) && nzchar(path)
+  cli_path <- unname(Sys.which("databricks"))
+  if (!found(cli_path) && Sys.info()[["sysname"]] == "Windows") {
+    cli_path <- unname(Sys.which("databricks.exe"))
+  }
+  if (!found(cli_path)) {
+    return(NULL)
+  }
+  cmd <- c(cli_path, "auth", "token", "--host", cfg$host)
+  .cli_token_source(cmd, "token_type", "access_token", "expiry")
+}
+
+# Token source that runs `az account get-access-token` for `resource`.
+.azure_cli_token_source <- function(resource) {
+  cmd <- c("az", "account", "get-access-token", "--resource", resource, "--output",
+    "json")
+  .cli_token_source(cmd, "tokenType", "accessToken", "expiresOn")
+}
diff --git a/tests/testthat/test_api_client.R b/tests/testthat/test_api_client.R
index f75daf3c..fca97b16 100644
--- a/tests/testthat/test_api_client.R
+++ b/tests/testthat/test_api_client.R
@@ -14,7 +14,7 @@ test_that("loads configuration file", {
 
 test_that("parses configuration profile", {
   client <- databricks:::DatabricksClient(config_file = "./data/awscfg", profile="client-secret")
-  expected <- "host=https://another.cloud.databricks.com/, client_id=xxx, client_secret=***"
+  expected <- "client_id=xxx, client_secret=***, host=https://another.cloud.databricks.com/"
   expect_equal(expected, client$debug_string())
 })
 
diff --git a/tests/testthat/test_current_user.R b/tests/testthat/test_current_user.R
new file mode 100644
index 00000000..62009e58
--- /dev/null
+++ b/tests/testthat/test_current_user.R
@@ -0,0 +1,9 @@
+library(testthat)
+
+skip_if(Sys.getenv("DATABRICKS_HOST") == "", "Not integration test")
+
+test_that("detects current user", {
+  client <- databricks::DatabricksClient()
+  user <- databricks::me(client)
+  expect_false(is.null(user$userName))
+})