available model completion in ide (#480)

* using completion to mine local models

* handle no key

* support mining ollama models

* shuffling

* collect all known models

* get ollama list

* aici list models

* better completion

* more help on model selection
pelikhan authored May 28, 2024
1 parent 0475af1 commit e861438
Showing 19 changed files with 697 additions and 366 deletions.
38 changes: 36 additions & 2 deletions docs/src/content/docs/getting-started/configuration.mdx
@@ -22,6 +22,18 @@ You will need to configure the LLM connection and authorization secrets.
If you do not have access to an LLM, you can use a [local model](#local-models) for inference.
:::

## Model selection

The model used by the script is configured through the `model` field in the `script` function.
The model name is formatted as `provider:model-name`, where `provider` is the LLM provider
and `model-name` is provider-specific.

```js 'model: "openai:gpt-4"'
script({
model: "openai:gpt-4",
})
```

## `.env` file

GenAIScript uses a `.env` file to store the secrets.
@@ -77,6 +89,9 @@ the `.env` file will appear grayed out in Visual Studio Code.

## OpenAI

The `openai` provider is the default. It uses the `OPENAI_API_...` environment variables.

<Steps>

<ol>
@@ -94,11 +109,28 @@ OPENAI_API_KEY=sk_...
```
</li>

<li>

Set the `model` field in `script` to the model you want to use.

```js 'model: "openai:gpt-4"'
script({
model: "openai:gpt-4",
...
})
```

</li>

</ol>

</Steps>

## Azure OpenAI ([reference](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions))

<a id="azure" href=""></a>

The Azure OpenAI provider, `azure`, uses the `AZURE_OPENAI_...` environment variables.

<Steps>

@@ -131,7 +163,7 @@ AZURE_OPENAI_ENDPOINT=https://....openai.azure.com

Update the `model` field in the `script` function to match the model deployment name in your Azure resource.

```js "model"
```js 'model: "azure:deployment-id"'
script({
model: "azure:deployment-id",
...
@@ -189,6 +221,8 @@ OPENAI_API_TYPE=localai

Running tools locally may require additional GPU resources depending on the model you are using.

Use the `ollama` provider to access Ollama models.

<Steps>

<ol>
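For example, a script can target a local Ollama model directly (a sketch; the `phi3` tag below is illustrative, use any model you have pulled locally):

```js 'model: "ollama:phi3"'
script({
    model: "ollama:phi3",
})
```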
2 changes: 1 addition & 1 deletion packages/cli/src/llamaindexretrieval.ts
@@ -144,7 +144,7 @@ export class LlamaIndexRetrievalService

private async getModelToken(modelId: string) {
const { provider } = parseModelIdentifier(modelId)
const conn = await this.host.getSecretToken(modelId)
const conn = await this.host.getLanguageModelConfiguration(modelId)
if (provider === MODEL_PROVIDER_OLLAMA)
conn.base = conn.base.replace(/\/v1$/i, "")
return conn
6 changes: 4 additions & 2 deletions packages/cli/src/nodehost.ts
@@ -3,9 +3,9 @@ import prompts from "prompts"
import {
AskUserOptions,
Host,
LanguageModelConfiguration,
LogLevel,
ModelService,
OAIToken,
ReadFileOptions,
RetrievalService,
SHELL_EXEC_TIMEOUT,
@@ -75,7 +75,9 @@ export class NodeHost implements Host {
return process.env[name]
}

async getSecretToken(modelId: string): Promise<OAIToken> {
async getLanguageModelConfiguration(
modelId: string
): Promise<LanguageModelConfiguration> {
return await parseTokenFromEnv(process.env, modelId)
}

2 changes: 1 addition & 1 deletion packages/cli/src/test.ts
@@ -45,7 +45,7 @@ function parseModelSpec(m: string): ModelOptions {
}

async function resolveTestProvider(script: PromptScript) {
const token = await host.getSecretToken(script.model)
const token = await host.getLanguageModelConfiguration(script.model)
if (token && token.type === "azure") return token.base
return undefined
}
36 changes: 35 additions & 1 deletion packages/core/src/aici.ts
@@ -2,11 +2,12 @@ import {
ChatCompletionHandler,
ChatCompletionResponse,
LanguageModel,
LanguageModelInfo,
} from "./chat"
import { PromptNode, visitNode } from "./promptdom"
import { fromHex, logError, normalizeInt, utf8Decode } from "./util"
import { AICI_CONTROLLER, TOOL_ID } from "./constants"
import { host } from "./host"
import { LanguageModelConfiguration, host } from "./host"
import { NotSupportedError, RequestError } from "./error"
import { ChatCompletionContentPartText } from "openai/resources"
import { createFetch } from "./fetch"
@@ -356,7 +357,40 @@ const AICIChatCompletion: ChatCompletionHandler = async (
}
}

async function listModels(cfg: LanguageModelConfiguration) {
const { token, base, version } = cfg
const url = `${base}/${version || "v1"}/controllers/tags`
const fetch = await createFetch()
const res = await fetch(url, {
method: "GET",
headers: {
"api-key": token,
"user-agent": TOOL_ID,
accept: "application/json",
},
})
if (res.status !== 200) return []
const body = (await res.json()) as {
tags: {
tag: string
module_id: string
updated_at: number
updated_by: string
wasm_size: number
compiled_size: number
}[]
}
return body.tags.map(
(tag) =>
<LanguageModelInfo>{
id: tag.tag,
details: `${tag.module_id}`,
}
)
}

export const AICIModel = Object.freeze<LanguageModel>({
completer: AICIChatCompletion,
id: "aici",
listModels,
})
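
A rough usage sketch of the new `listModels` hook; the `base` URL and `token` below are placeholders, not real defaults:

```ts
import { AICIModel } from "./aici"

// Query an AICI server for its controller tags (sketch only; listModels
// returns [] when the tags endpoint does not answer with HTTP 200).
const models =
    (await AICIModel.listModels?.({
        provider: "aici",
        base: "http://localhost:4242",
        token: "<token>",
    })) ?? []
for (const m of models) console.log(`${m.id} (${m.details})`)
```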
17 changes: 13 additions & 4 deletions packages/core/src/chat.ts
@@ -3,7 +3,7 @@ import { Cache } from "./cache"
import { MarkdownTrace } from "./trace"
import { PromptImage } from "./promptdom"
import { AICIRequest } from "./aici"
import { OAIToken, host } from "./host"
import { LanguageModelConfiguration, host } from "./host"
import { GenerationOptions } from "./promptcontext"
import { JSON5TryParse, JSON5parse, isJSONObjectOrArray } from "./json5"
import { CancellationToken, checkCancelled } from "./cancellation"
@@ -80,7 +80,7 @@ export const ModelError = OpenAI.APIError

export type ChatCompletionRequestCacheKey = CreateChatCompletionRequest &
ModelOptions &
Omit<OAIToken, "token" | "source">
Omit<LanguageModelConfiguration, "token" | "source">

export type ChatCompletationRequestCacheValue = {
text: string
@@ -175,14 +175,23 @@ function encodeMessagesForLlama(req: CreateChatCompletionRequest) {
*/
export type ChatCompletionHandler = (
req: CreateChatCompletionRequest,
connection: OAIToken,
connection: LanguageModelConfiguration,
options: ChatCompletionsOptions,
trace: MarkdownTrace
) => Promise<ChatCompletionResponse>

export interface LanguageModelInfo {
id: string
details?: string
url?: string
}

export type ListModelsFunction = (cfg: LanguageModelConfiguration) => Promise<LanguageModelInfo[]>

export interface LanguageModel {
id: string
completer: ChatCompletionHandler
listModels?: ListModelsFunction
}

async function runToolCalls(
@@ -432,7 +441,7 @@ export function mergeGenerationOptions(
}

export async function executeChatSession(
connectionToken: OAIToken,
connectionToken: LanguageModelConfiguration,
cancellationToken: CancellationToken,
messages: ChatCompletionMessageParam[],
functions: ChatFunctionCallback[],
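The `LanguageModel` contract above is what each backend registers; a minimal sketch of a custom provider (the `echo` id is made up, and the `text` response field is an assumption inferred from the cache types earlier in this file):

```ts
const EchoModel: LanguageModel = {
    id: "echo",
    // Completer that simply returns the last user message (sketch only;
    // assumes ChatCompletionResponse exposes a `text` field).
    completer: async (req, connection, options, trace) => {
        const last = req.messages[req.messages.length - 1]
        const text = typeof last?.content === "string" ? last.content : ""
        return { text } as ChatCompletionResponse
    },
    // Optional discovery hook; this is what feeds model-name completion.
    listModels: async () => [
        { id: "echo", details: "debug provider that echoes the prompt" },
    ],
}
```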
9 changes: 7 additions & 2 deletions packages/core/src/connection.ts
@@ -17,14 +17,14 @@ import {
OPENAI_API_BASE,
} from "./constants"
import { fileExists, readText, writeText } from "./fs"
import { APIType, OAIToken } from "./host"
import { APIType, LanguageModelConfiguration } from "./host"
import { parseModelIdentifier } from "./models"
import { trimTrailingSlash } from "./util"

export async function parseTokenFromEnv(
env: Record<string, string>,
modelId: string
): Promise<OAIToken> {
): Promise<LanguageModelConfiguration> {
const { provider, model, tag } = parseModelIdentifier(modelId)

if (provider === MODEL_PROVIDER_OPENAI) {
@@ -61,6 +61,7 @@ export async function parseTokenFromEnv(
if (base && !URL.canParse(base))
throw new Error("OPENAI_API_BASE must be a valid URL")
return {
provider,
base,
type,
token,
@@ -116,6 +117,7 @@ export async function parseTokenFromEnv(
if (!base.endsWith("/openai/deployments"))
base += "/openai/deployments"
return {
provider,
base,
token,
type: "azure",
@@ -148,6 +150,7 @@ export async function parseTokenFromEnv(
if (base && !URL.canParse(base))
throw new Error(`${modelBase} must be a valid URL`)
return {
provider,
token,
base,
type,
@@ -164,6 +167,7 @@

if (provider === MODEL_PROVIDER_OLLAMA) {
return {
provider,
base: OLLAMA_API_BASE,
token: "ollama",
type: "openai",
@@ -173,6 +177,7 @@

if (provider === MODEL_PROVIDER_LITELLM) {
return {
provider,
base: LITELLM_API_BASE,
token: "litellm",
type: "openai",
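A quick sketch of what the parser now yields for a local model id (the `phi3` tag is illustrative):

```ts
import { parseTokenFromEnv } from "./connection"

// Ollama needs no API key; the parser falls back to the fixed local endpoint.
const cfg = await parseTokenFromEnv(
    process.env as Record<string, string>,
    "ollama:phi3"
)
console.log(cfg.provider) // "ollama"
console.log(cfg.base) // the OLLAMA_API_BASE constant
```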
28 changes: 28 additions & 0 deletions packages/core/src/constants.ts
@@ -101,6 +101,34 @@ export const DOCS_CONFIGURATION_LOCALAI_URL =
export const DOCS_CONFIGURATION_AICI_URL =
"https://microsoft.github.io/genaiscript/reference/scripts/aici/"

export const MODEL_PROVIDERS = Object.freeze([
{
id: MODEL_PROVIDER_OPENAI,
detail: "OpenAI or compatible",
url: DOCS_CONFIGURATION_OPENAI_URL,
},
{
id: MODEL_PROVIDER_AZURE,
detail: "Azure OpenAI deployment",
url: DOCS_CONFIGURATION_AZURE_OPENAI_URL,
},
{
id: MODEL_PROVIDER_OLLAMA,
detail: "Ollama local model",
url: DOCS_CONFIGURATION_OLLAMA_URL,
},
{
id: MODEL_PROVIDER_LITELLM,
detail: "LiteLLM proxy",
url: DOCS_CONFIGURATION_LITELLM_URL,
},
{
id: MODEL_PROVIDER_AICI,
detail: "AICI controller",
url: DOCS_CONFIGURATION_AICI_URL,
},
])

export const NEW_SCRIPT_TEMPLATE = `// use def to emit LLM variables
// https://microsoft.github.io/genaiscript/reference/scripts/context/#definition-def
def("FILE", env.files)
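The `MODEL_PROVIDERS` table above is the kind of data an IDE completion source can walk when suggesting `provider:` prefixes for the `model` field; a minimal sketch (the `CompletionItem` shape is illustrative, not the actual VS Code API):

```ts
import { MODEL_PROVIDERS } from "./constants"

interface CompletionItem {
    label: string
    detail?: string
    documentation?: string
}

// One completion item per known provider prefix (sketch).
function providerCompletions(): CompletionItem[] {
    return MODEL_PROVIDERS.map((p) => ({
        label: `${p.id}:`,
        detail: p.detail,
        documentation: p.url,
    }))
}
```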
5 changes: 3 additions & 2 deletions packages/core/src/host.ts
@@ -25,7 +25,8 @@ export enum LogLevel {

export type APIType = "openai" | "azure" | "localai"

export interface OAIToken {
export interface LanguageModelConfiguration {
provider: string
base: string
token: string
curlHeaders?: Record<string, string>
@@ -137,7 +138,7 @@ export interface Host {

// read a secret from the environment or a .env file
readSecret(name: string): Promise<string | undefined>
getSecretToken(modelId: string): Promise<OAIToken | undefined>
getLanguageModelConfiguration(modelId: string): Promise<LanguageModelConfiguration | undefined>

log(level: LogLevel, msg: string): void

3 changes: 2 additions & 1 deletion packages/core/src/index.ts
@@ -62,4 +62,5 @@ export * from "./html"
export * from "./parameters"
export * from "./scripts"
export * from "./math"
export * from "./fence"
export * from "./fence"
export * from "./ollama"
16 changes: 8 additions & 8 deletions packages/core/src/models.ts
@@ -8,15 +8,15 @@
MODEL_PROVIDER_OPENAI,
} from "./constants"
import { errorMessage } from "./error"
import { OAIToken, host } from "./host"
import { LanguageModelConfiguration, host } from "./host"
import { OllamaModel } from "./ollama"
import { OpenAIModel } from "./openai"
import { GenerationOptions } from "./promptcontext"
import { TraceOptions } from "./trace"

export function resolveLanguageModel(
options: GenerationOptions
): LanguageModel {
export function resolveLanguageModel(options: {
model?: string
languageModel?: LanguageModel
}): LanguageModel {
if (options.languageModel) return options.languageModel
const { provider } = parseModelIdentifier(options.model)
if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel
@@ -48,21 +48,21 @@ export function parseModelIdentifier(id: string) {

export interface ModelConnectionInfo
extends ModelConnectionOptions,
Partial<OAIToken> {
Partial<LanguageModelConfiguration> {
error?: string
model: string
}

export async function resolveModelConnectionInfo(
conn: ModelConnectionOptions,
options?: { model?: string; token?: boolean } & TraceOptions
): Promise<{ info: ModelConnectionInfo; token?: OAIToken }> {
): Promise<{ info: ModelConnectionInfo; token?: LanguageModelConfiguration }> {
const { trace } = options || {}
const model = options.model ?? conn.model ?? DEFAULT_MODEL
try {
trace?.startDetails(`⚙️ configuration`)
trace?.itemValue(`model`, model)
const secret = await host.getSecretToken(model)
const secret = await host.getLanguageModelConfiguration(model)
if (!secret) {
return { info: { ...conn, model } }
} else {
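A rough sketch of the refactored resolver in use (model ids are illustrative):

```ts
import { resolveLanguageModel } from "./models"

// The provider prefix of the model id selects the backend implementation.
const ollama = resolveLanguageModel({ model: "ollama:phi3" }) // OllamaModel
const openai = resolveLanguageModel({ model: "openai:gpt-4" }) // OpenAIModel
console.log(ollama.id, openai.id)
```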