diff --git a/packages/cli/src/parse.ts b/packages/cli/src/parse.ts index 4fccf038d2..5401d69a04 100644 --- a/packages/cli/src/parse.ts +++ b/packages/cli/src/parse.ts @@ -98,7 +98,7 @@ export async function parseJinja2( } ) { let src = await readFile(file, { encoding: "utf-8" }) - if (PROMPTY_REGEX.test(file)) src = promptyParse(src).content + if (PROMPTY_REGEX.test(file)) src = promptyParse(file, src).content else if (MD_REGEX.test(file)) src = splitMarkdown(src).content const vars: Record = parseOptionsVars( @@ -188,7 +188,7 @@ export async function parseTokens( options: { excludedFiles: string[]; model: string } ) { const { model = DEFAULT_MODEL } = options || {} - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) const files = await expandFiles(filesGlobs, options?.excludedFiles) console.log(`parsing ${files.length} files`) @@ -222,7 +222,7 @@ export async function prompty2genaiscript( : replaceExt(f, ".genai.mts") console.log(`${f} -> ${gf}`) const content = await readText(f) - const doc = promptyParse(content) + const doc = promptyParse(f, content) const script = promptyToGenAIScript(doc) await writeText(gf, script) } diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts index 7b607d7206..a9216a766e 100644 --- a/packages/cli/src/run.ts +++ b/packages/cli/src/run.ts @@ -272,7 +272,7 @@ export async function runScript( DOCS_CONFIGURATION_URL ) } - trace.options.encoder = await resolveTokenEncoder(info.model) + trace.options.encoder = (await resolveTokenEncoder(info.model)).encode await runtimeHost.models.pullModel(info.model) let tokenColor = 0 diff --git a/packages/core/src/anthropic.ts b/packages/core/src/anthropic.ts index f2cead1a24..d54f0efcb8 100644 --- a/packages/core/src/anthropic.ts +++ b/packages/core/src/anthropic.ts @@ -197,7 +197,7 @@ export const AnthropicChatCompletion: ChatCompletionHandler = async ( const { requestOptions, partialCb, cancellationToken, inner } = options 
const { headers } = requestOptions || {} const { model } = parseModelIdentifier(req.model) - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) const anthropic = new Anthropic({ baseURL: cfg.base, diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index 5229edef24..74625b7b66 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -140,7 +140,7 @@ async function runToolCalls( ) { const projFolder = host.projectFolder() const { cancellationToken, trace, model } = options || {} - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) assert(!!trace) let edits: Edits[] = [] diff --git a/packages/core/src/encoders.test.ts b/packages/core/src/encoders.test.ts index 05274b8bc1..ad58c1f0a1 100644 --- a/packages/core/src/encoders.test.ts +++ b/packages/core/src/encoders.test.ts @@ -6,27 +6,27 @@ import { encode as defaultEncode } from "gpt-tokenizer" describe("resolveTokenEncoder", () => { test("gpt-3.5-turbo", async () => { const encoder = await resolveTokenEncoder("gpt-3.5-turbo") - const result = encoder("test line") + const result = encoder.encode("test line") assert.deepEqual(result, [1985, 1584]) }) test("gpt-4", async () => { const encoder = await resolveTokenEncoder("gpt-4") - const result = encoder("test line") + const result = encoder.encode("test line") assert.deepEqual(result, [1985, 1584]) }) test("gpt-4o", async () => { const encoder = await resolveTokenEncoder("gpt-4o") - const result = encoder("test line") + const result = encoder.encode("test line") assert.deepEqual(result, [3190, 2543]) }) test("gpt-4o-mini", async () => { const encoder = await resolveTokenEncoder("gpt-4o-mini") - const result = encoder("test line") + const result = encoder.encode("test line") assert.deepEqual(result, [3190, 2543]) }) test("gpt-4o forbidden", async () => { const encoder = await resolveTokenEncoder("gpt-4o") - const result 
= encoder("<|im_end|>") + const result = encoder.encode("<|im_end|>") assert.deepEqual(result, [27, 91, 321, 13707, 91, 29]) }) }) diff --git a/packages/core/src/encoders.ts b/packages/core/src/encoders.ts index 95b51388c7..a1a54a9bd5 100644 --- a/packages/core/src/encoders.ts +++ b/packages/core/src/encoders.ts @@ -1,6 +1,11 @@ // Import the function to parse model identifiers import { parseModelIdentifier } from "./models" +export interface TokenEncoders { + encode: TokenEncoder + decode: TokenDecoder +} + /** * Resolves the appropriate token encoder based on the given model ID. * @param modelId - The identifier for the model to resolve the encoder for. @@ -8,7 +13,7 @@ import { parseModelIdentifier } from "./models" */ export async function resolveTokenEncoder( modelId: string -): Promise<TokenEncoder> { +): Promise<TokenEncoders> { // Parse the model identifier to extract the model information const { model } = parseModelIdentifier(modelId) const module = model // Assign model to module for dynamic import path @@ -16,11 +21,17 @@ export async function resolveTokenEncoder( const options = { disallowedSpecial: new Set() } try { // Attempt to dynamically import the encoder module for the specified model - const mod = await import(`gpt-tokenizer/model/${module}`) - return (line) => mod.encode(line, options) // Return the encoder function + const { encode, decode } = await import(`gpt-tokenizer/model/${module}`) + return Object.freeze({ + encode: (line) => encode(line, options), // Return the default encoder function + decode, + }) } catch (e) { // If the specific model encoder is not found, default to gpt-4o encoder - const { encode } = await import("gpt-tokenizer") - return (line) => encode(line, options) // Return the default encoder function + const { encode, decode } = await import("gpt-tokenizer") + return Object.freeze({ + encode: (line) => encode(line, options), // Return the default encoder function + decode, + }) } } diff --git a/packages/core/src/git.ts b/packages/core/src/git.ts
index f1fb381c77..ba5811470a 100644 --- a/packages/core/src/git.ts +++ b/packages/core/src/git.ts @@ -282,7 +282,7 @@ export class GitClient implements Git { } if (!nameOnly && llmify) { res = llmifyDiff(res) - const encoder = await resolveTokenEncoder( + const { encode: encoder } = await resolveTokenEncoder( runtimeHost.defaultModelOptions.model || DEFAULT_MODEL ) const tokens = estimateTokens(res, encoder) diff --git a/packages/core/src/globals.ts b/packages/core/src/globals.ts index f206f149d4..b981c6ccd3 100644 --- a/packages/core/src/globals.ts +++ b/packages/core/src/globals.ts @@ -123,14 +123,14 @@ export function installGlobals() { glb.tokenizers = Object.freeze({ count: async (text, options) => { - const encoder = await resolveTokenEncoder( + const { encode: encoder } = await resolveTokenEncoder( options?.model || runtimeHost.defaultModelOptions.model ) const c = await estimateTokens(text, encoder) return c }, truncate: async (text, maxTokens, options) => { - const encoder = await resolveTokenEncoder( + const { encode: encoder } = await resolveTokenEncoder( options?.model || runtimeHost.defaultModelOptions.model ) return await truncateTextToTokens(text, maxTokens, encoder, options) diff --git a/packages/core/src/openai.ts b/packages/core/src/openai.ts index 38c7dc59f6..8d98b6a859 100644 --- a/packages/core/src/openai.ts +++ b/packages/core/src/openai.ts @@ -73,7 +73,7 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async ( const { headers = {}, ...rest } = requestOptions || {} const { token, source, ...cfgNoToken } = cfg const { model } = parseModelIdentifier(req.model) - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) const cache = !!cacheOrName || !!cacheName const cacheStore = getChatCompletionCache( diff --git a/packages/core/src/parsers.ts b/packages/core/src/parsers.ts index 236b0ca630..286ab75004 100644 --- a/packages/core/src/parsers.ts +++ b/packages/core/src/parsers.ts 
@@ -34,7 +34,7 @@ export async function createParsers(options: { model: string }): Promise { const { trace, model } = options - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) return Object.freeze({ JSON5: (text, options) => JSON5TryParse(filenameOrFileToContent(text), options?.defaultValue), @@ -120,6 +120,6 @@ export async function createParsers(options: { }, diff: (f1, f2) => llmifyDiff(createDiff(f1, f2)), tidyData: (rows, options) => tidyData(rows, options), - hash: async (text, options) => await hash(text, options) + hash: async (text, options) => await hash(text, options), }) } diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts index c1d956a634..83b2f3dda7 100644 --- a/packages/core/src/promptdom.ts +++ b/packages/core/src/promptdom.ts @@ -568,7 +568,7 @@ async function resolvePromptNode( model: string, root: PromptNode ): Promise<{ errors: number }> { - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) let err = 0 const names = new Set() const uniqueName = (n_: string) => { @@ -742,7 +742,7 @@ async function resolveImportPrompty( args: Record, options: ImportTemplateOptions ) { - const { messages } = promptyParse(f.content) + const { messages } = promptyParse(f.filename, f.content) for (const message of messages) { const txt = jinjaRenderChatMessage(message, args) if (message.role === "assistant") @@ -761,7 +761,7 @@ async function truncatePromptNode( options?: TraceOptions ): Promise { const { trace } = options || {} - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } = await resolveTokenEncoder(model) let truncated = false const cap = (n: { @@ -923,7 +923,7 @@ export async function renderPromptNode( ): Promise { const { trace, flexTokens } = options || {} const { model } = parseModelIdentifier(modelId) - const encoder = await resolveTokenEncoder(model) + const { encode: encoder } 
= await resolveTokenEncoder(model) await resolvePromptNode(model, node) await tracePromptNode(trace, node) diff --git a/packages/core/src/prompty.test.ts b/packages/core/src/prompty.test.ts index e24a41d715..2f5e5b5115 100644 --- a/packages/core/src/prompty.test.ts +++ b/packages/core/src/prompty.test.ts @@ -4,7 +4,7 @@ import assert from "node:assert/strict" describe("promptyParse", () => { test("correctly parses an empty markdown string", () => { - const result = promptyParse("") + const result = promptyParse(undefined, "") assert.deepStrictEqual(result, { meta: {}, frontmatter: {}, @@ -15,7 +15,7 @@ test("correctly parses a markdown string without frontmatter", () => { const content = "This is a sample content without frontmatter." - const result = promptyParse(content) + const result = promptyParse(undefined, content) assert.deepStrictEqual(result, { meta: {}, frontmatter: {}, @@ -40,7 +40,7 @@ sample: --- # Heading Content below heading.` - const result = promptyParse(markdownString) + const result = promptyParse(undefined, markdownString) assert.deepStrictEqual(result.frontmatter, { name: "Test", description: "A test description", @@ -59,7 +59,7 @@ assistant: Assistant's reply user: Another message from the user` - const result = promptyParse(markdownContent) + const result = promptyParse(undefined, markdownContent) assert.deepStrictEqual(result.messages, [ { role: "user", content: "User's message" }, { role: "assistant", content: "Assistant's reply" }, @@ -69,7 +69,7 @@ Another message from the user` test("correctly handles a markdown string with content but without roles", () => { const markdownContent = `Just some content without specifying roles.` - const result = promptyParse(markdownContent) + const result = promptyParse(undefined, markdownContent) assert.deepStrictEqual(result.messages, [ { role: "system", content: markdownContent }, ]) diff --git a/packages/core/src/types/prompt_template.d.ts
b/packages/core/src/types/prompt_template.d.ts index 6673357d99..81ec68c892 100644 --- a/packages/core/src/types/prompt_template.d.ts +++ b/packages/core/src/types/prompt_template.d.ts @@ -1083,6 +1083,7 @@ interface ParseZipOptions { } type TokenEncoder = (text: string) => number[] +type TokenDecoder = (lines: Iterable<number>) => string interface CSVParseOptions { delimiter?: string