Skip to content

Commit

Permalink
Added Databricks CLI & Azure CLI authentication and a basic integrati…
Browse files Browse the repository at this point in the history
…on test
  • Loading branch information
nfx committed Mar 27, 2024
1 parent a50c376 commit bb96aac
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 6 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/acceptance.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: acceptance

on:
pull_request:
types: [opened, synchronize]

permissions:
id-token: write
contents: read
pull-requests: write

jobs:
integration:
if: github.event_name == 'pull_request'
environment: admin
runs-on: larger
steps:
- name: Checkout Code
uses: actions/checkout@v2.5.0

- name: Unshallow
run: git fetch --prune --unshallow

- uses: actions/checkout@v3

- uses: r-lib/actions/setup-r@v2
with:
r-version: release
use-public-rspm: true

- uses: azure/login@v1
with:
client-id: ${{ secrets.ARM_CLIENT_ID }}
tenant-id: ${{ secrets.ARM_TENANT_ID }}
allow-no-subscriptions: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: devtools
env:
R_COMPILE_AND_INSTALL_PACKAGES: never

- name: Run tests
run: Rscript -e "devtools::test()"
env:
CLOUD_ENV: "${{ vars.CLOUD_ENV }}"
DATABRICKS_HOST: "${{ secrets.DATABRICKS_HOST }}"

169 changes: 164 additions & 5 deletions R/api_client.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f

# cfg is the current unified authentication config of direct parameters,
# environment variables, and values loaded from ~/.databrickscfg file
cfg <- list(host = coalesce(host, Sys.getenv("DATABRICKS_HOST"), from_cli$host),
token = coalesce(token, Sys.getenv("DATABRICKS_TOKEN"), from_cli$token),
client_id = coalesce(Sys.getenv("DATABRICKS_CLIENT_ID"), from_cli$client_id),
client_secret = coalesce(Sys.getenv("DATABRICKS_CLIENT_SECRET"), from_cli$client_secret))
cfg <- new.env()
cfg$host = coalesce(host, from_cli$host, Sys.getenv("DATABRICKS_HOST"))
cfg$token = coalesce(token, from_cli$token, Sys.getenv("DATABRICKS_TOKEN"))
cfg$client_id = coalesce(from_cli$client_id, Sys.getenv("DATABRICKS_CLIENT_ID"))
cfg$client_secret = coalesce(from_cli$client_secret, Sys.getenv("DATABRICKS_CLIENT_SECRET"))

# add the missing https:// prefix to bare, ODBC-style hosts
if (!is.null(cfg$host) && !startsWith(cfg$host, "http")) {
Expand All @@ -105,7 +106,7 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
used <- c()
sensitive <- c("token", "password", "client_secret", "google_credentials",
"azure_client_secret")
for (attr in names(cfg)) {
for (attr in sort(names(cfg))) {
value <- cfg[[attr]]
if (is.null(value)) {
next
Expand Down Expand Up @@ -140,6 +141,32 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f
return(function() {
c(Authentication = paste("Bearer", cfg$token))
})
}, `databricks-cli` = function() {
token_source <- .databricks_cli_token_source(cfg)
if (is.null(token_source)) {
return(NULL)
}
result <- try(token_source$token(), silent = TRUE)
if (inherits(result, "try-error")) {
return(NULL)
}
return(function() {
token <- token_source$token()
return(token$headers())
})
}, `azure-cli` = function() {
if (!is_azure()) {
return(NULL)
}
token_source <- .azure_cli_token_source("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
result <- try(token_source$token(), silent = TRUE)
if (inherits(result, "try-error")) {
return(NULL)
}
return(function() {
token <- token_source$token()
return(token$headers())
})
})

# authenticate follows the semantics of Unified Client Authentication and
Expand Down Expand Up @@ -212,3 +239,135 @@ DatabricksClient <- function(profile = NULL, host = NULL, token = NULL, config_f

return(list(is_aws = is_aws, is_azure = is_azure, is_gcp = is_gcp, do = do, debug_string = debug_string))
}

.create_token <- function(access_token, token_type = NULL, expiry = NULL) {
if (is.na(expiry)) {
expiry <- Sys.time() + as.difftime(300, units = "secs")
}
state <- new.env()
state$access_token <- access_token
state$token_type <- token_type
state$expiry <- expiry
headers <- function() {
return(c(Authorization = paste(state$token_type, state$access_token)))
}
expired <- function() {
if (is.null(state$expiry)) {
return(FALSE)
}
# Azure Databricks rejects tokens that expire in 30 seconds or less, so we
# refresh the token 40 seconds before it expires.
potentially_expired <- state$expiry - as.difftime(40, units = "secs")
now <- Sys.time()
is_expired <- potentially_expired < now
return(is_expired)
}
valid <- function() {
if (is.null(state$access_token)) {
return(FALSE)
}
if (expired()) {
return(FALSE)
}
return(TRUE)
}
return(list(headers = headers, valid = valid))
}

.refreshable_token_source <- function(refresh) {
state <- new.env()
state$token <- NULL
# this is not thread-safe, but R is single-threaded
return(list(token = function() {
tok <- state$token
if (!is.null(tok)) {
is_valid <- tok$valid()
if (is_valid) {
return(tok)
}
}
state$token <- refresh()
return(state$token)
}))
}

.token_source <- function() {
return(list(token = function() {
stop("token method must be implemented", call. = FALSE)
}))
}

.cli_token_source <- function(cmd, token_type_field, access_token_field, expiry_field) {
parse_expiry <- function(expiry) {
formats <- c("%Y-%m-%dT%H:%M:%OS", "%Y-%m-%dT%H:%M:%S")
for (fmt in formats) {
tryCatch({
x <- as.POSIXct(expiry, format = fmt)
return(x)
}, error = function(e) {
# TODO: improve this
last_error <<- e
})
}
if (exists("last_error")) {
stop(last_error)
}
}
return(.refreshable_token_source(function() {
tryCatch({
# TODO: do better handling, so that we don't see a warning message, when
# Databricks OAuth is not configured for the given host on the given
# machine
out <- system2(command = cmd, stdout = TRUE, stderr = TRUE)
tryCatch({
it <- jsonlite::fromJSON(out)
expiry <- it[[expiry_field]]
expires_on <- parse_expiry(expiry)
access_token <- it[[access_token_field]]
token_type <- it[[token_type_field]]
return(.create_token(access_token, token_type, expires_on))
}, error = function(e) {
if (inherits(e, "error")) {
message <- if (length(e$message) > 0) e$message else ""
stop(paste("cannot unmarshal CLI result:", out, message))
} else if (inherits(e, "try-error")) {
stdout <- if (length(e$message) > 0) e$message else ""
stderr <- if (length(e$stderr) > 0) e$stderr else ""
message <- if (nchar(stdout) > 0) stdout else stderr
stop(paste("cannot get access token:", message))
}
})
}, error = function(e) {
stop("cannot execute CLI command:", e)
})
}))
}

.databricks_cli_token_source <- function(cfg) {
if (is.null(cfg$host)) {
return(NULL)
}
args <- c("auth", "token", "--host", cfg$host)
cli_path <- tryCatch({
# Try to find 'databricks' in PATH
(Sys.which("databricks"))
}, error = function(e) {
# If 'databricks' is not found, try to find 'databricks.exe'
if (Sys.info()["sysname"] == "Windows") {
(Sys.which("databricks.exe"))
} else {
(NULL)
}
})
if (is.null(cli_path)) {
return(NULL)
}
cmd <- c(cli_path, args)
return(.cli_token_source(cmd, "token_type", "access_token", "expiry"))
}

.azure_cli_token_source <- function(resource) {
cmd <- c("az", "account", "get-access-token", "--resource", resource, "--output",
"json")
return(.cli_token_source(cmd, "tokenType", "accessToken", "expiresOn"))
}
2 changes: 1 addition & 1 deletion tests/testthat/test_api_client.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ test_that("loads configuration file", {

test_that("parses configuration profile", {
client <- databricks:::DatabricksClient(config_file = "./data/awscfg", profile="client-secret")
expected <- "host=https://another.cloud.databricks.com/, client_id=xxx, client_secret=***"
expected <- "client_id=xxx, client_secret=***, host=https://another.cloud.databricks.com/"
expect_equal(expected, client$debug_string())
})

Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test_current_user.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
library(testthat)

skip_if(Sys.getenv("DATABRICKS_HOST") == "", "Not integration test")

test_that("detects current user", {
client <- databricks::DatabricksClient()
user <- databricks::me(client)
expect_false(is.null(user$userName))
})

0 comments on commit bb96aac

Please sign in to comment.