diff --git a/packages/cli/src/azuretoken.ts b/packages/cli/src/azuretoken.ts index 888c0449f9..502c6baece 100644 --- a/packages/cli/src/azuretoken.ts +++ b/packages/cli/src/azuretoken.ts @@ -1,28 +1,22 @@ import { AZURE_OPENAI_TOKEN_EXPIRATION, AZURE_OPENAI_TOKEN_SCOPES, + MODEL_PROVIDER_AZURE, } from "../../core/src/constants" +import { LanguageModelAuthenticationToken } from "../../core/src/host" import { logVerbose } from "../../core/src/util" -export interface AuthenticationToken { - token: string - expiresOnTimestamp: number -} - -export function isAzureTokenExpired(token: AuthenticationToken) { - return !token || token.expiresOnTimestamp < Date.now() - 5_000 // avoid data races -} - export async function createAzureToken( signal: AbortSignal -): Promise { +): Promise { const { DefaultAzureCredential } = await import("@azure/identity") const azureToken = await new DefaultAzureCredential().getToken( AZURE_OPENAI_TOKEN_SCOPES.slice(), { abortSignal: signal } ) const res = { - token: azureToken.token, + provider: MODEL_PROVIDER_AZURE, + token: "Bearer " + azureToken.token, expiresOnTimestamp: azureToken.expiresOnTimestamp ? azureToken.expiresOnTimestamp : Date.now() + AZURE_OPENAI_TOKEN_EXPIRATION, diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts index 15899a1095..1bfa4deb1e 100644 --- a/packages/cli/src/nodehost.ts +++ b/packages/cli/src/nodehost.ts @@ -37,15 +37,13 @@ import { RuntimeHost, setRuntimeHost, ResponseStatus, + LanguageModelAuthenticationToken, + isLanguageModelAuthenticationTokenExpired, } from "../../core/src/host" import { AbortSignalOptions, TraceOptions } from "../../core/src/trace" import { logVerbose, unique } from "../../core/src/util" import { parseModelIdentifier } from "../../core/src/models" -import { - AuthenticationToken, - createAzureToken, - isAzureTokenExpired, -} from "./azuretoken" +import { createAzureToken } from "./azuretoken" import { LanguageModel } from "../../core/src/chat" import { errorMessage } from "../../core/src/error" import { BrowserManager } from "./playwright" @@ -151,7 +149,14 @@ export class NodeHost implements RuntimeHost { } clientLanguageModel: LanguageModel - private _azureToken: AuthenticationToken + setLanguageModelConnectionToken(token: LanguageModelAuthenticationToken) { + this._languageModelAuthenticationTokens[token.provider] = token + } + + private readonly _languageModelAuthenticationTokens: Record< + string, + LanguageModelAuthenticationToken + > = {} async getLanguageModelConfiguration( modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions @@ -160,20 +165,25 @@ export class NodeHost implements RuntimeHost { await this.parseDefaults() const tok = await parseTokenFromEnv(process.env, modelId) if (!askToken && tok?.token) tok.token = "***" + console.log({ askToken, tok }) if ( askToken && tok && !tok.token && tok.provider === MODEL_PROVIDER_AZURE ) { - if (isAzureTokenExpired(this._azureToken)) { + let aztok = + this._languageModelAuthenticationTokens[MODEL_PROVIDER_AZURE] + if (isLanguageModelAuthenticationTokenExpired(aztok)) { logVerbose( - `fetching azure token (${this._azureToken?.expiresOnTimestamp >= Date.now() ? `expired ${new Date(this._azureToken.expiresOnTimestamp).toLocaleString()}` : "not available"})` + `fetching azure token (${aztok?.expiresOnTimestamp >= Date.now() ? `expired ${new Date(aztok.expiresOnTimestamp).toLocaleString()}` : "not available"})` ) - this._azureToken = await createAzureToken(signal) + aztok = this._languageModelAuthenticationTokens[ + MODEL_PROVIDER_AZURE + ] = await createAzureToken(signal) } - if (!this._azureToken) throw new Error("Azure token not available") - tok.token = "Bearer " + this._azureToken.token + if (!aztok) throw new Error("Azure token not available") + tok.token = aztok.token } if (!tok && this.clientLanguageModel) { return { diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts index 3f484e7e2b..a73e9d6288 100644 --- a/packages/cli/src/run.ts +++ b/packages/cli/src/run.ts @@ -35,7 +35,11 @@ import { isCancelError, errorMessage } from "../../core/src/error" import { Fragment, GenerationResult } from "../../core/src/generation" import { parseKeyValuePair } from "../../core/src/fence" import { filePathOrUrlToWorkspaceFile, writeText } from "../../core/src/fs" -import { host, runtimeHost } from "../../core/src/host" +import { + host, + LanguageModelAuthenticationToken, + runtimeHost, +} from "../../core/src/host" import { isJSONLFilename, appendJSONL } from "../../core/src/jsonl" import { resolveModelConnectionInfo } from "../../core/src/models" import { @@ -121,11 +125,17 @@ export async function runScript( options: Partial & TraceOptions & CancellationOptions & { + modelToken?: LanguageModelAuthenticationToken infoCb?: (partialResponse: { text: string }) => void partialCb?: (progress: ChatCompletionsProgressReport) => void } ): Promise<{ exitCode: number; result?: GenerationResult }> { - const { trace = new MarkdownTrace(), infoCb, partialCb } = options || {} + const { + trace = new MarkdownTrace(), + infoCb, + partialCb, + modelToken, + } = options || {} let result: GenerationResult const excludedFiles = options.excludedFiles const excludeGitIgnore = !!options.excludeGitIgnore @@ -163,6 +173,8 @@ export async function runScript( return { exitCode, result } } + if (modelToken) runtimeHost.setLanguageModelConnectionToken(modelToken) + if (out) { if (removeOut) await emptyDir(out) await ensureDir(out) diff --git a/packages/cli/src/server.ts b/packages/cli/src/server.ts index 07836fdbcc..04848f8a3f 100644 --- a/packages/cli/src/server.ts +++ b/packages/cli/src/server.ts @@ -226,7 +226,13 @@ export async function startServer(options: { port: string }) { case "script.start": { cancelAll() - const { script, files = [], options = {}, runId } = data + const { + script, + files = [], + options = {}, + runId, + modelToken, + } = data const canceller = new AbortSignalCancellationController() const trace = new MarkdownTrace() @@ -252,6 +258,7 @@ export async function startServer(options: { port: string }) { logVerbose(`run ${runId}: starting`) const runner = runScript(script, files, { ...options, + modelToken, trace, cancellationToken: canceller.token, infoCb: ({ text }) => { diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index 0251c83119..9ec14fcd15 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -227,3 +227,4 @@ export const CONSOLE_COLOR_WARNING = 95 export const CONSOLE_COLOR_ERROR = 91 export const PLAYWRIGHT_DEFAULT_BROWSER = "chromium" +export const LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET = 5_000 \ No newline at end of file diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts index 403b401653..125818cc24 100644 --- a/packages/core/src/host.ts +++ b/packages/core/src/host.ts @@ -3,6 +3,7 @@ import { CancellationToken } from "./cancellation" import { LanguageModel } from "./chat" import { Progress } from "./progress" import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace" +import { LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET } from "./constants" // this is typically an instance of TextDecoder export interface UTF8Decoder { @@ -27,11 +28,16 @@ export enum LogLevel { export type APIType = "openai" | "azure" | "localai" -export interface LanguageModelConfiguration { +export interface LanguageModelAuthenticationToken { provider: string + token: string + expiresOnTimestamp?: number +} + +export interface LanguageModelConfiguration + extends LanguageModelAuthenticationToken { model: string base: string - token: string curlHeaders?: Record type?: APIType source?: string @@ -39,6 +45,16 @@ export interface LanguageModelConfiguration { version?: string } +export function isLanguageModelAuthenticationTokenExpired( + token: LanguageModelAuthenticationToken +) { + return ( + token?.expiresOnTimestamp > 0 && + token?.expiresOnTimestamp < + Date.now() - LANGUAGE_MODEL_TOKEN_EXPIRATION_OFFSET + ) // avoid data races +} + export interface RetrievalClientOptions { progress?: Progress token?: CancellationToken @@ -128,6 +144,9 @@ export interface RuntimeHost extends Host { models: ModelService workspace: Omit + setLanguageModelConnectionToken( + token: LanguageModelAuthenticationToken + ): void readSecret(name: string): Promise // executes a process exec( diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts index 3faaae0b70..b3275bf496 100644 --- a/packages/core/src/models.ts +++ b/packages/core/src/models.ts @@ -4,7 +4,11 @@ import { MODEL_PROVIDER_OPENAI, } from "./constants" import { errorMessage } from "./error" -import { LanguageModelConfiguration, host } from "./host" +import { + LanguageModelAuthenticationToken, + LanguageModelConfiguration, + host, +} from "./host" import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace" import { assert } from "./util" diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index 4d7d05836b..4d730a1fa5 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -4,6 +4,7 @@ import { randomHex } from "../crypto" import { errorMessage } from "../error" import { GenerationResult } from "../generation" import { + LanguageModelAuthenticationToken, LanguageModelConfiguration, ResponseStatus, ServerResponse, @@ -267,12 +268,14 @@ export class WebSocketClient extends EventTarget { jsSource?: string signal: AbortSignal trace: MarkdownTrace + modelToken?: LanguageModelAuthenticationToken infoCb: (partialResponse: { text: string }) => void partialCb: (progress: ChatCompletionsProgressReport) => void } ) { const runId = randomHex(6) - const { signal, infoCb, partialCb, trace, ...optionsRest } = options + const { signal, infoCb, partialCb, trace, modelToken, ...optionsRest } = + options let resolve: (value: GenerationResult) => void let reject: (reason?: any) => void const promise = new Promise((res, rej) => { @@ -299,6 +302,7 @@ export class WebSocketClient extends EventTarget { runId, script, files, + modelToken, options: optionsRest, }) if (!res.response?.ok) { diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index 3bd55bea11..cdd5b55740 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -1,6 +1,10 @@ import { ChatCompletionAssistantMessageParam } from "../chattypes" import { GenerationResult } from "../generation" -import { LanguageModelConfiguration, ResponseStatus } from "../host" +import { + LanguageModelAuthenticationToken, + LanguageModelConfiguration, + ResponseStatus, +} from "../host" export interface RequestMessage { type: string @@ -85,6 +89,7 @@ export interface PromptScriptStart extends RequestMessage { script: string files?: string[] options: Partial + modelToken?: LanguageModelAuthenticationToken } export interface PromptScriptStartResponse extends ResponseStatus { diff --git a/packages/core/src/testhost.ts b/packages/core/src/testhost.ts index a11afae805..4f02384298 100644 --- a/packages/core/src/testhost.ts +++ b/packages/core/src/testhost.ts @@ -8,6 +8,7 @@ import { UTF8Encoder, setRuntimeHost, RuntimeHost, + LanguageModelAuthenticationToken, } from "./host" import { TraceOptions } from "./trace" import { @@ -80,6 +81,9 @@ export class TestHost implements RuntimeHost { browse(url: string, options?: BrowseSessionOptions): Promise { throw new Error("Method not implemented.") } + setLanguageModelConnectionToken( + token: LanguageModelAuthenticationToken + ): void {} getLanguageModelConfiguration( modelId: string ): Promise { diff --git a/packages/vscode/src/azuremanager.ts b/packages/vscode/src/azuremanager.ts index 2bd9155179..42428edbaa 100644 --- a/packages/vscode/src/azuremanager.ts +++ b/packages/vscode/src/azuremanager.ts @@ -17,7 +17,7 @@ export class AzureManager { } async getOpenAIToken() { - if (this._session) return this._session.accessToken + if (this._session) return "Bearer " + this._session.accessToken // select account const accounts = await vscode.authentication.getAccounts("microsoft") @@ -53,7 +53,7 @@ export class AzureManager { } ) this._session = session - return this._session.accessToken + return "Bearer " + this._session.accessToken } catch {} try { @@ -68,7 +68,7 @@ export class AzureManager { } ) this._session = session - return this._session.accessToken + return "Bearer " + this._session.accessToken } catch (e) { const msg = errorMessage(e) vscode.window.showErrorMessage(msg) diff --git a/packages/vscode/src/state.ts b/packages/vscode/src/state.ts index 03115a778e..5e4a62803a 100644 --- a/packages/vscode/src/state.ts +++ b/packages/vscode/src/state.ts @@ -20,9 +20,13 @@ import { AI_REQUESTS_CACHE, TOOL_ID, GENAI_ANYJS_GLOB, + MODEL_PROVIDER_CLIENT, } from "../../core/src/constants" import { isCancelError } from "../../core/src/error" -import { resolveModelConnectionInfo } from "../../core/src/models" +import { + parseModelIdentifier, + resolveModelConnectionInfo, +} from "../../core/src/models" import { parseProject } from "../../core/src/parser" import { MarkdownTrace } from "../../core/src/trace" import { @@ -321,7 +325,16 @@ tests/ } if (connectionToken?.type === "localai") await startLocalAI() - // todo: send js source + const modelToken = + connectionToken.provider === MODEL_PROVIDER_CLIENT && + connectionToken.token + ? { + provider: parseModelIdentifier(connectionToken.model) + .provider, + token: connectionToken.token, + expiresOnTimestamp: connectionToken.expiresOnTimestamp, + } + : undefined const { runId, request } = await this.host.server.client.startScript( template.id, files, @@ -334,6 +347,7 @@ tests/ label, cache: cache ? template.cache : undefined, vars: parametersToVars(options.parameters), + modelToken, } ) r.runId = runId diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts index f51b9001cf..96840e58b6 100644 --- a/packages/vscode/src/vshost.ts +++ b/packages/vscode/src/vshost.ts @@ -210,7 +210,7 @@ export class VSCodeHost extends EventTarget implements Host { ) { const azureToken = await this.azure.getOpenAIToken() if (!azureToken) throw new Error("Azure token not available") - tok.token = "Bearer " + azureToken + tok.token = azureToken tok.curlHeaders = { Authorization: "Bearer ***", }