Skip to content

Commit

Permalink
Refactor authentication token handling for Azure and adjust interface…
Browse files Browse the repository at this point in the history
…s across packages
  • Loading branch information
pelikhan committed Sep 5, 2024
1 parent c3da0fc commit 2612099
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 36 deletions.
16 changes: 5 additions & 11 deletions packages/cli/src/azuretoken.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
import {
AZURE_OPENAI_TOKEN_EXPIRATION,
AZURE_OPENAI_TOKEN_SCOPES,
MODEL_PROVIDER_AZURE,
} from "../../core/src/constants"
import { LanguageModelAuthenticationToken } from "../../core/src/host"
import { logVerbose } from "../../core/src/util"

export interface AuthenticationToken {
token: string
expiresOnTimestamp: number
}

export function isAzureTokenExpired(token: AuthenticationToken) {
return !token || token.expiresOnTimestamp < Date.now() - 5_000 // avoid data races
}

export async function createAzureToken(
signal: AbortSignal
): Promise<AuthenticationToken> {
): Promise<LanguageModelAuthenticationToken> {
const { DefaultAzureCredential } = await import("@azure/identity")
const azureToken = await new DefaultAzureCredential().getToken(
AZURE_OPENAI_TOKEN_SCOPES.slice(),
{ abortSignal: signal }
)
const res = {
token: azureToken.token,
provider: MODEL_PROVIDER_AZURE,
token: "Bearer " + azureToken.token,
expiresOnTimestamp: azureToken.expiresOnTimestamp
? azureToken.expiresOnTimestamp
: Date.now() + AZURE_OPENAI_TOKEN_EXPIRATION,

Check failure on line 22 in packages/cli/src/azuretoken.ts

View workflow job for this annotation

GitHub Actions / build

The `MODEL_PROVIDER_AZURE` constant is used but not imported in this file.
Expand Down
32 changes: 21 additions & 11 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@ import {
RuntimeHost,
setRuntimeHost,
ResponseStatus,
LanguageModelAuthenticationToken,
isLanguageModelAuthenticationTokenExpired,
} from "../../core/src/host"
import { AbortSignalOptions, TraceOptions } from "../../core/src/trace"
import { logVerbose, unique } from "../../core/src/util"
import { parseModelIdentifier } from "../../core/src/models"
import {
AuthenticationToken,
createAzureToken,
isAzureTokenExpired,
} from "./azuretoken"
import { createAzureToken } from "./azuretoken"
import { LanguageModel } from "../../core/src/chat"
import { errorMessage } from "../../core/src/error"
import { BrowserManager } from "./playwright"
Expand Down Expand Up @@ -151,7 +149,14 @@ export class NodeHost implements RuntimeHost {
}
clientLanguageModel: LanguageModel

private _azureToken: AuthenticationToken
setLanguageModelConnectionToken(token: LanguageModelAuthenticationToken) {
this._languageModelAuthenticationTokens[token.provider] = token
}

private readonly _languageModelAuthenticationTokens: Record<
string,
LanguageModelAuthenticationToken
> = {}

Check failure on line 159 in packages/cli/src/nodehost.ts

View workflow job for this annotation

GitHub Actions / build

The type of `_languageModelAuthenticationTokens` is not defined. It should be `Record<string, LanguageModelAuthenticationToken>`.
async getLanguageModelConfiguration(
modelId: string,
options?: { token?: boolean } & AbortSignalOptions & TraceOptions
Expand All @@ -160,20 +165,25 @@ export class NodeHost implements RuntimeHost {
await this.parseDefaults()
const tok = await parseTokenFromEnv(process.env, modelId)
if (!askToken && tok?.token) tok.token = "***"
console.log({ askToken, tok })
if (
askToken &&
tok &&
!tok.token &&
tok.provider === MODEL_PROVIDER_AZURE
) {
if (isAzureTokenExpired(this._azureToken)) {
let aztok =
this._languageModelAuthenticationTokens[MODEL_PROVIDER_AZURE]
if (isLanguageModelAuthenticationTokenExpired(aztok)) {
logVerbose(
`fetching azure token (${this._azureToken?.expiresOnTimestamp >= Date.now() ? `expired ${new Date(this._azureToken.expiresOnTimestamp).toLocaleString()}` : "not available"})`
`fetching azure token (${aztok?.expiresOnTimestamp >= Date.now() ? `expired ${new Date(aztok.expiresOnTimestamp).toLocaleString()}` : "not available"})`
)
this._azureToken = await createAzureToken(signal)
aztok = this._languageModelAuthenticationTokens[
MODEL_PROVIDER_AZURE
] = await createAzureToken(signal)
}
if (!this._azureToken) throw new Error("Azure token not available")
tok.token = "Bearer " + this._azureToken.token
if (!aztok) throw new Error("Azure token not available")
tok.token = aztok.token
}
if (!tok && this.clientLanguageModel) {
return <LanguageModelConfiguration>{
Expand Down
16 changes: 14 additions & 2 deletions packages/cli/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ import { isCancelError, errorMessage } from "../../core/src/error"
import { Fragment, GenerationResult } from "../../core/src/generation"
import { parseKeyValuePair } from "../../core/src/fence"
import { filePathOrUrlToWorkspaceFile, writeText } from "../../core/src/fs"
import { host, runtimeHost } from "../../core/src/host"
import {
host,
LanguageModelAuthenticationToken,
runtimeHost,
} from "../../core/src/host"
import { isJSONLFilename, appendJSONL } from "../../core/src/jsonl"
import { resolveModelConnectionInfo } from "../../core/src/models"
import {
Expand Down Expand Up @@ -121,11 +125,17 @@ export async function runScript(
options: Partial<PromptScriptRunOptions> &
TraceOptions &
CancellationOptions & {
modelToken?: LanguageModelAuthenticationToken
infoCb?: (partialResponse: { text: string }) => void
partialCb?: (progress: ChatCompletionsProgressReport) => void
}
): Promise<{ exitCode: number; result?: GenerationResult }> {
const { trace = new MarkdownTrace(), infoCb, partialCb } = options || {}
const {
trace = new MarkdownTrace(),
infoCb,
partialCb,
modelToken,
} = options || {}
let result: GenerationResult
const excludedFiles = options.excludedFiles
const excludeGitIgnore = !!options.excludeGitIgnore
Expand Down Expand Up @@ -163,6 +173,8 @@ export async function runScript(
return { exitCode, result }
}

if (modelToken) runtimeHost.setLanguageModelConnectionToken(modelToken)

if (out) {
if (removeOut) await emptyDir(out)
await ensureDir(out)
Expand Down
9 changes: 8 additions & 1 deletion packages/cli/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ export async function startServer(options: { port: string }) {
case "script.start": {
cancelAll()

const { script, files = [], options = {}, runId } = data
const {
script,
files = [],
options = {},
runId,
modelToken,
} = data
const canceller =
new AbortSignalCancellationController()
const trace = new MarkdownTrace()
Expand All @@ -252,6 +258,7 @@ export async function startServer(options: { port: string }) {
logVerbose(`run ${runId}: starting`)
const runner = runScript(script, files, {
...options,
modelToken,
trace,
cancellationToken: canceller.token,
infoCb: ({ text }) => {
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,4 @@ export const CONSOLE_COLOR_WARNING = 95
export const CONSOLE_COLOR_ERROR = 91

export const PLAYWRIGHT_DEFAULT_BROWSER = "chromium"
export const LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET = 5_000
23 changes: 21 additions & 2 deletions packages/core/src/host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { CancellationToken } from "./cancellation"
import { LanguageModel } from "./chat"
import { Progress } from "./progress"
import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace"
import { LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET } from "./constants"

// this is typically an instance of TextDecoder
export interface UTF8Decoder {
Expand All @@ -27,18 +28,33 @@ export enum LogLevel {

export type APIType = "openai" | "azure" | "localai"

export interface LanguageModelConfiguration {
export interface LanguageModelAuthenticationToken {
provider: string
token: string
expiresOnTimestamp?: number
}

export interface LanguageModelConfiguration
extends LanguageModelAuthenticationToken {
model: string
base: string
token: string
curlHeaders?: Record<string, string>
type?: APIType
source?: string
aici?: boolean
version?: string
}

export function isLanguageModelAuthenticationTokenExpired(
token: LanguageModelAuthenticationToken
) {
return (
token?.expiresOnTimestamp > 0 &&
token?.expiresOnTimestamp <
Date.now() - LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET
) // avoid data races
}

Check failure on line 56 in packages/core/src/host.ts

View workflow job for this annotation

GitHub Actions / build

The `LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET` constant is used but not imported in this file.

export interface RetrievalClientOptions {
progress?: Progress
token?: CancellationToken
Expand Down Expand Up @@ -128,6 +144,9 @@ export interface RuntimeHost extends Host {
models: ModelService
workspace: Omit<WorkspaceFileSystem, "grep">

setLanguageModelConnectionToken(
token: LanguageModelAuthenticationToken
): void
readSecret(name: string): Promise<string | undefined>
// executes a process
exec(
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import {
MODEL_PROVIDER_OPENAI,
} from "./constants"
import { errorMessage } from "./error"
import { LanguageModelConfiguration, host } from "./host"
import {
LanguageModelAuthenticationToken,
LanguageModelConfiguration,
host,
} from "./host"
import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace"
import { assert } from "./util"

Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/server/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { randomHex } from "../crypto"
import { errorMessage } from "../error"
import { GenerationResult } from "../generation"
import {
LanguageModelAuthenticationToken,
LanguageModelConfiguration,
ResponseStatus,
ServerResponse,
Expand Down Expand Up @@ -267,12 +268,14 @@ export class WebSocketClient extends EventTarget {
jsSource?: string
signal: AbortSignal
trace: MarkdownTrace
modelToken?: LanguageModelAuthenticationToken
infoCb: (partialResponse: { text: string }) => void
partialCb: (progress: ChatCompletionsProgressReport) => void
}
) {
const runId = randomHex(6)
const { signal, infoCb, partialCb, trace, ...optionsRest } = options
const { signal, infoCb, partialCb, trace, modelToken, ...optionsRest } =
options
let resolve: (value: GenerationResult) => void
let reject: (reason?: any) => void
const promise = new Promise<GenerationResult>((res, rej) => {
Expand All @@ -299,6 +302,7 @@ export class WebSocketClient extends EventTarget {
runId,
script,
files,
modelToken,
options: optionsRest,
})
if (!res.response?.ok) {
Expand Down
7 changes: 6 additions & 1 deletion packages/core/src/server/messages.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import { ChatCompletionAssistantMessageParam } from "../chattypes"
import { GenerationResult } from "../generation"
import { LanguageModelConfiguration, ResponseStatus } from "../host"
import {
LanguageModelAuthenticationToken,
LanguageModelConfiguration,
ResponseStatus,
} from "../host"

export interface RequestMessage {
type: string
Expand Down Expand Up @@ -85,6 +89,7 @@ export interface PromptScriptStart extends RequestMessage {
script: string
files?: string[]
options: Partial<PromptScriptRunOptions>
modelToken?: LanguageModelAuthenticationToken
}

export interface PromptScriptStartResponse extends ResponseStatus {
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/testhost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
UTF8Encoder,
setRuntimeHost,
RuntimeHost,
LanguageModelAuthenticationToken,
} from "./host"
import { TraceOptions } from "./trace"
import {
Expand Down Expand Up @@ -80,6 +81,9 @@ export class TestHost implements RuntimeHost {
browse(url: string, options?: BrowseSessionOptions): Promise<BrowserPage> {
throw new Error("Method not implemented.")
}
setLanguageModelConnectionToken(
token: LanguageModelAuthenticationToken
): void {}
getLanguageModelConfiguration(
modelId: string
): Promise<LanguageModelConfiguration> {
Expand Down
6 changes: 3 additions & 3 deletions packages/vscode/src/azuremanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export class AzureManager {
}

async getOpenAIToken() {
if (this._session) return this._session.accessToken
if (this._session) return "Bearer " + this._session.accessToken

// select account
const accounts = await vscode.authentication.getAccounts("microsoft")
Expand Down Expand Up @@ -53,7 +53,7 @@ export class AzureManager {
}
)
this._session = session
return this._session.accessToken
return "Bearer " + this._session.accessToken
} catch {}

try {
Expand All @@ -68,7 +68,7 @@ export class AzureManager {
}
)
this._session = session
return this._session.accessToken
return "Bearer " + this._session.accessToken
} catch (e) {
const msg = errorMessage(e)
vscode.window.showErrorMessage(msg)
Expand Down
18 changes: 16 additions & 2 deletions packages/vscode/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ import {
AI_REQUESTS_CACHE,
TOOL_ID,
GENAI_ANYJS_GLOB,
MODEL_PROVIDER_CLIENT,
} from "../../core/src/constants"
import { isCancelError } from "../../core/src/error"
import { resolveModelConnectionInfo } from "../../core/src/models"
import {
parseModelIdentifier,
resolveModelConnectionInfo,
} from "../../core/src/models"
import { parseProject } from "../../core/src/parser"
import { MarkdownTrace } from "../../core/src/trace"
import {
Expand Down Expand Up @@ -321,7 +325,16 @@ tests/
}
if (connectionToken?.type === "localai") await startLocalAI()

// todo: send js source
const modelToken =
connectionToken.provider === MODEL_PROVIDER_CLIENT &&
connectionToken.token
? {
provider: parseModelIdentifier(connectionToken.model)
.provider,
token: connectionToken.token,
expiresOnTimestamp: connectionToken.expiresOnTimestamp,
}
: undefined
const { runId, request } = await this.host.server.client.startScript(
template.id,
files,
Expand All @@ -334,6 +347,7 @@ tests/
label,
cache: cache ? template.cache : undefined,
vars: parametersToVars(options.parameters),
modelToken,
}
)
r.runId = runId
Expand Down
2 changes: 1 addition & 1 deletion packages/vscode/src/vshost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ export class VSCodeHost extends EventTarget implements Host {
) {
const azureToken = await this.azure.getOpenAIToken()
if (!azureToken) throw new Error("Azure token not available")
tok.token = "Bearer " + azureToken
tok.token = azureToken
tok.curlHeaders = {
Authorization: "Bearer ***",
}
Expand Down

0 comments on commit 2612099

Please sign in to comment.