Skip to content

Commit

Permalink
better error message on expired tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Dec 17, 2024
1 parent 85728ff commit e155a9b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 27 deletions.
32 changes: 24 additions & 8 deletions packages/cli/src/azuretoken.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
} from "../../core/src/host"
import { logVerbose } from "../../core/src/util"
import type { TokenCredential } from "@azure/identity"
import { serializeError } from "../../core/src/error"

/**
* This module provides functions to handle Azure authentication tokens,
Expand Down Expand Up @@ -63,6 +64,7 @@ export async function createAzureToken(
credential = new DefaultAzureCredential()
break
}

// Obtain the Azure token using the DefaultAzureCredential
const azureToken = await credential.getToken(scopes.slice(), {
abortSignal,
Expand All @@ -87,35 +89,49 @@ export async function createAzureToken(

class AzureTokenResolverImpl implements AzureTokenResolver {
_token: AuthenticationToken
_resolver: Promise<AuthenticationToken>
_error: any
_resolver: Promise<{ token?: AuthenticationToken; error?: SerializedError }>

constructor(
public readonly name: string,
public readonly envName: string,
public readonly scopes: readonly string[]
) {}

get error(): SerializedError {
return this._error
}

async token(
credentialsType: AzureCredentialsType,
optoins?: { signal?: AbortSignal }
): Promise<AuthenticationToken> {
): Promise<{ token?: AuthenticationToken; error?: SerializedError }> {
// cached
const { signal } = optoins || {}

if (isAzureTokenExpired(this._token)) this._token = undefined
if (this._token) return this._token
if (this._token || this._error)
return { token: this._token, error: this._error }
if (!this._resolver) {
const scope = await runtimeHost.readSecret(this.envName)
const scopes = scope ? scope.split(",") : this.scopes
this._resolver = createAzureToken(
scopes,
credentialsType,
signal || new AbortController().signal
).then((res) => {
this._token = res
this._resolver = undefined
return res
})
)
.then((res) => {
this._token = res
this._error = undefined
this._resolver = undefined
return { token: this._token, error: this._error }
})
.catch((err) => {
this._resolver = undefined
this._token = undefined
this._error = serializeError(err)
return { token: this._token, error: this._error }
})
}
return this._resolver
}
Expand Down
44 changes: 33 additions & 11 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -294,34 +294,56 @@ export class NodeHost implements RuntimeHost {
modelId: string,
options?: { token?: boolean } & AbortSignalOptions & TraceOptions
): Promise<LanguageModelConfiguration> {
const { signal, token: askToken } = options || {}
const { token: askToken, trace } = options || {}
const tok = await parseTokenFromEnv(process.env, modelId)
if (!askToken && tok?.token) tok.token = "***"
if (askToken && tok && !tok.token) {
if (
tok.provider === MODEL_PROVIDER_AZURE_OPENAI ||
tok.provider === MODEL_PROVIDER_AZURE_SERVERLESS_OPENAI
) {
const azureToken = await this.azureToken.token(
tok.azureCredentialsType,
options
)
if (!azureToken)
const { token: azureToken, error: azureTokenError } =
await this.azureToken.token(
tok.azureCredentialsType,
options
)
if (!azureToken) {
if (azureTokenError) {
logError(
`Azure OpenAI token not available for ${modelId}`
)
logVerbose(azureTokenError.message)
trace.error(
`Azure OpenAI token not available for ${modelId}`,
azureTokenError
)
}
throw new Error(
`Azure OpenAI token not available for ${modelId}`
)
}
tok.token = "Bearer " + azureToken.token
} else if (
tok.provider === MODEL_PROVIDER_AZURE_SERVERLESS_MODELS
) {
const azureToken = await this.azureServerlessToken.token(
tok.azureCredentialsType,
options
)
if (!azureToken)
const { token: azureToken, error: azureTokenError } =
await this.azureServerlessToken.token(
tok.azureCredentialsType,
options
)
if (!azureToken) {
if (azureTokenError) {
logError(`Azure AI token not available for ${modelId}`)
logVerbose(azureTokenError.message)
trace.error(
`Azure AI token not available for ${modelId}`,
azureTokenError
)
}
throw new Error(
`Azure AI token not available for ${modelId}`
)
}
tok.token = "Bearer " + azureToken.token
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ export interface AzureTokenResolver {
token(
credentialsType: AzureCredentialsType,
options?: AbortSignalOptions
): Promise<AuthenticationToken>
): Promise<{ token?: AuthenticationToken; error?: SerializedError }>
}

export type ModelConfiguration = Readonly<
Expand Down
16 changes: 9 additions & 7 deletions packages/core/src/usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,6 @@ export class GenerationStats {
const c = this.cost()
if (this.model && isNaN(c) && isCosteable(this.model))
unknowns.add(this.model)
const au = this.accumulatedUsage()
if (au?.total_tokens > 0 && (this.resolvedModel || c)) {
logVerbose(
`${indent}${this.label ? `${this.label} (${this.resolvedModel})` : this.resolvedModel}> ${au.total_tokens} tokens (${au.prompt_tokens} -> ${au.completion_tokens}) ${renderCost(c)}`
)
}
if (this.chatTurns.length > 1) {
const chatTurns = this.chatTurns.slice(0, 10)
for (const { messages, usage, model: turnModel } of chatTurns) {
Expand All @@ -277,7 +271,15 @@ export class GenerationStats {
if (this.chatTurns.length > chatTurns.length)
logVerbose(`${indent} ...`)
}
for (const child of this.children) child.logTokens(indent + " ")
const children = this.children.slice(0, 10)
for (const child of children) child.logTokens(indent + " ")
if (this.children.length > children.length) logVerbose(`${indent} ...`)
const au = this.accumulatedUsage()
if (au?.total_tokens > 0 && (this.resolvedModel || c)) {
logVerbose(
`${indent}${this.label ? `${this.label} (${this.resolvedModel})` : this.resolvedModel}> ${au.total_tokens} tokens (${au.prompt_tokens} -> ${au.completion_tokens}) ${renderCost(c)}`
)
}
if (unknowns.size)
logVerbose(`missing pricing for ${[...unknowns].join(", ")}`)
}
Expand Down

0 comments on commit e155a9b

Please sign in to comment.