From eaf1c541afe10cf2746bd8812f6c97c6d9b8be09 Mon Sep 17 00:00:00 2001 From: Peli de Halleux Date: Mon, 9 Dec 2024 16:32:44 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=9A=80=20add=20prompt=20caching?= =?UTF-8?q?=20support=20and=20improve=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../content/docs/reference/scripts/context.md | 10 ++- .../docs/reference/scripts/prompt-caching.mdx | 37 ++++++++++ packages/core/src/chat.ts | 74 ++++++++++++------- packages/core/src/chattypes.ts | 27 +++++-- packages/core/src/openai.ts | 1 + packages/core/src/promptdom.ts | 61 +++++---------- packages/core/src/runpromptcontext.ts | 10 ++- packages/core/src/transformers.ts | 4 +- packages/core/src/types/prompt_template.d.ts | 11 +++ 9 files changed, 155 insertions(+), 80 deletions(-) create mode 100644 docs/src/content/docs/reference/scripts/prompt-caching.mdx diff --git a/docs/src/content/docs/reference/scripts/context.md b/docs/src/content/docs/reference/scripts/context.md index 725f9c6f7f..dd6241cd4e 100644 --- a/docs/src/content/docs/reference/scripts/context.md +++ b/docs/src/content/docs/reference/scripts/context.md @@ -200,13 +200,19 @@ def("FILE", env.files, { sliceSample: 100 }) ### Prompt Caching -You can specify `ephemeral: true` to enable prompt caching optimization. In particular, a `def` with `ephemeral` will be rendered at the back of the prompt -to persist the [cache prefix](https://openai.com/index/api-prompt-caching/). +You can use `cacheControl: "ephemeral"` to specify that the prompt can be cached +for a short amount of time, and enable prompt caching optimization, which is supported (differently) by various LLM providers. + +```js "ephemeral: true" +$`...`.cacheControl("ephemeral") +``` ```js "ephemeral: true" def("FILE", env.files, { ephemeral: true }) ``` +Read more about [prompt caching](/genaiscript/reference/scripts/prompt-caching). + ### Safety: Prompt Injection detection You can schedule a check for prompt injection/jai break with your configured [content safety](/genaiscript/reference/scripts/content-safety) provider. diff --git a/docs/src/content/docs/reference/scripts/prompt-caching.mdx b/docs/src/content/docs/reference/scripts/prompt-caching.mdx new file mode 100644 index 0000000000..13e6914692 --- /dev/null +++ b/docs/src/content/docs/reference/scripts/prompt-caching.mdx @@ -0,0 +1,37 @@ +--- +title: Prompt Caching +sidebar: + order: 80 +--- + +Prompt caching is a feature that can reduce processing time and costs for repetitive prompts. +It is supported by various LLM providers, but the implementation may vary. + +- OpenAI implements an automatic [cache prefix](https://openai.com/index/api-prompt-caching/). +- Anthropic supports settting [cache breakpoints](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) + +## `ephemeral` + +You can mark `def` section or `$` function with `cacheControl` set as `"ephemeral"` to enable prompt caching optimization. This essentially means that it +is acceptable for the LLM provider to cache the prompt for a short amount of time. + +```js +def("FILE", env.files, { cacheControl: "ephemeral" }) +``` + +```js +$`Some very cool prompt`.cacheControl("ephemeral") +``` + +## LLM provider supporet + +In most cases, the `ephemeral` hint is ignored by LLM providers. However, the following are supported + +### OpenAI, Azure OpenAI + +[Prompt caching](https://platform.openai.com/docs/guides/prompt-caching) of the prompt prefix +is automatically enabled by OpenAI. + +### Anthropic + +- Anthropic: it is translated into a `'cache-control': { ... }` field in the message object diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index fcec008f57..1ff738bada 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -343,9 +343,9 @@ ${fenceMD(content, " ")} appendUserMessage( messages, `- ${call.name}(${JSON.stringify(call.arguments || {})}) -\`\`\`\`\` + ${toolResult.join("\n\n")} -\`\`\`\`\` + ` ) else @@ -407,12 +407,12 @@ schema: ${f.args?.schema || ""}, error: ${f.validation.schemaError}` ) .join("\n\n") - const repairMsg = dedent`DATA_FORMAT_ISSUES: -\`\`\` + const repairMsg = +` ${repair} -\`\`\` + -Repair the DATA_FORMAT_ISSUES. THIS IS IMPORTANT.` +Repair the . THIS IS IMPORTANT.` trace.fence(repairMsg, "markdown") messages.push({ role: "user", @@ -962,43 +962,68 @@ export function tracePromptResult(trace: MarkdownTrace, resp: RunPromptResult) { export function appendUserMessage( messages: ChatCompletionMessageParam[], - content: string + content: string, + options?: { ephemeral?: boolean } ) { if (!content) return - const last = messages.at(-1) as ChatCompletionUserMessageParam - if (last?.role === "user") last.content += "\n" + content - else - messages.push({ + const { ephemeral } = options || {} + let last = messages.at(-1) as ChatCompletionUserMessageParam + if ( + last?.role !== "user" || + !!ephemeral !== (last?.cacheControl === "ephemeral") + ) { + last = { role: "user", - content, - } as ChatCompletionUserMessageParam) + content: "", + } satisfies ChatCompletionUserMessageParam + if (ephemeral) last.cacheControl = "ephemeral" + messages.push(last) + } + if (last.content) last.content += "\n" + content + else last.content = content } export function appendAssistantMessage( messages: ChatCompletionMessageParam[], - content: string + content: string, + options?: { ephemeral?: boolean } ) { if (!content) return - const last = messages.at(-1) as ChatCompletionAssistantMessageParam - if (last?.role === "assistant") last.content += "\n" + content - else - messages.push({ + const { ephemeral } = options || {} + let last = messages.at(-1) as ChatCompletionAssistantMessageParam + if ( + last?.role !== "assistant" || + !!ephemeral !== (last?.cacheControl === "ephemeral") + ) { + last = { role: "assistant", - content, - } satisfies ChatCompletionAssistantMessageParam) + content: "", + } satisfies ChatCompletionAssistantMessageParam + if (ephemeral) last.cacheControl = "ephemeral" + messages.push(last) + } + if (last.content) last.content += "\n" + content + else last.content = content } export function appendSystemMessage( messages: ChatCompletionMessageParam[], - content: string + content: string, + options?: { ephemeral?: boolean } ) { if (!content) return + const { ephemeral } = options || {} + let last = messages[0] as ChatCompletionSystemMessageParam - if (last?.role !== "system") { + if ( + last?.role !== "system" || + !!ephemeral !== (last?.cacheControl === "ephemeral") + ) { last = { role: "system", content: "", } as ChatCompletionSystemMessageParam + if (ephemeral) last.cacheControl = "ephemeral" messages.unshift(last) } if (last.content) last.content += SYSTEM_FENCE @@ -1012,10 +1037,9 @@ export function addToolDefinitionsMessage( appendSystemMessage( messages, ` -TOOLS: -\`\`\`yaml + ${YAMLStringify(tools.map((t) => t.spec))} -\`\`\` + ` ) } diff --git a/packages/core/src/chattypes.ts b/packages/core/src/chattypes.ts index f96cca3094..4c463180b6 100644 --- a/packages/core/src/chattypes.ts +++ b/packages/core/src/chattypes.ts @@ -10,7 +10,7 @@ import OpenAI from "openai" /** * Interface representing a custom AI Chat Interface request. */ -export interface AICIRequest { +export interface AICIRequest extends ChatCompletionMessageParamCacheControl { role: "aici" // The role for this type of request content?: string // Optional content of the request error?: unknown // Optional error information @@ -44,19 +44,32 @@ export type ChatCompletionTokenLogprob = OpenAI.ChatCompletionTokenLogprob export type ChatCompletion = OpenAI.Chat.Completions.ChatCompletion export type ChatCompletionChoice = OpenAI.Chat.Completions.ChatCompletion.Choice +export interface ChatCompletionMessageParamCacheControl { + cacheControl?: PromptCacheControlType +} + // Parameters for a system message in a chat completion export type ChatCompletionSystemMessageParam = - OpenAI.Chat.Completions.ChatCompletionSystemMessageParam + OpenAI.Chat.Completions.ChatCompletionSystemMessageParam & + ChatCompletionMessageParamCacheControl // Parameters for a tool message in a chat completion export type ChatCompletionToolMessageParam = - OpenAI.Chat.Completions.ChatCompletionToolMessageParam + OpenAI.Chat.Completions.ChatCompletionToolMessageParam & + ChatCompletionMessageParamCacheControl +export type ChatCompletionFunctionMessageParam = + OpenAI.Chat.Completions.ChatCompletionFunctionMessageParam & + ChatCompletionMessageParamCacheControl /** * Type representing parameters for chat completion messages, including custom AICIRequest. */ export type ChatCompletionMessageParam = - | OpenAI.Chat.Completions.ChatCompletionMessageParam + | ChatCompletionSystemMessageParam + | ChatCompletionUserMessageParam + | ChatCompletionAssistantMessageParam + | ChatCompletionToolMessageParam + | ChatCompletionFunctionMessageParam | AICIRequest /** @@ -75,11 +88,13 @@ export type CreateChatCompletionRequest = Omit< // Parameters for an assistant message in a chat completion export type ChatCompletionAssistantMessageParam = - OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam + OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam & + ChatCompletionMessageParamCacheControl // Parameters for a user message in a chat completion export type ChatCompletionUserMessageParam = - OpenAI.Chat.Completions.ChatCompletionUserMessageParam + OpenAI.Chat.Completions.ChatCompletionUserMessageParam & + ChatCompletionMessageParamCacheControl // Image content part of a chat completion export type ChatCompletionContentPartImage = diff --git a/packages/core/src/openai.ts b/packages/core/src/openai.ts index 63f5ff4d37..9084619ae2 100644 --- a/packages/core/src/openai.ts +++ b/packages/core/src/openai.ts @@ -124,6 +124,7 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async ( const postReq = structuredClone({ ...req, + messages: req.messages.map(({ cacheControl, ...rest }) => ({ ...rest })), stream: true, stream_options: { include_usage: true }, model, diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts index e1f06464f5..fe612d6396 100644 --- a/packages/core/src/promptdom.ts +++ b/packages/core/src/promptdom.ts @@ -75,7 +75,8 @@ export interface PromptNode extends ContextExpansionOptions { error?: unknown // Error information if present tokens?: number // Token count for the node /** - * This text is likely to change within 5 to 10 minutes. + * Definte a prompt caching breakpoint. + * This prompt prefix (including this text) is cacheable for a short amount of time. */ ephemeral?: boolean @@ -478,7 +479,7 @@ export function createDefData( options?: DefDataOptions ) { if (data === undefined) return undefined - let { format, headers, priority } = options || {} + let { format, headers, priority, ephemeral } = options || {} if ( !format && Array.isArray(data) && @@ -512,7 +513,7 @@ ${trimNewlines(text)} ${trimNewlines(text)} ` // TODO maxTokens does not work well with data - return createTextNode(value, { priority }) + return createTextNode(value, { priority, ephemeral }) } // Function to append a child node to a parent node. @@ -616,28 +617,6 @@ export interface PromptNodeRender { disposables: AsyncDisposable[] // Disposables } -/** - * To optimize chat caching with openai, move defs to the back of the prompt - * @see https://platform.openai.com/docs/guides/prompt-caching - * @param mode - * @param root - */ -async function layoutPromptNode(root: PromptNode) { - let changed = false - await visitNode(root, { - node: (n) => { - // sort children - const before = n.children?.map((c) => c.preview)?.join("\n") - n.children?.sort( - (a, b) => (a.ephemeral ? 1 : -1) - (b.ephemeral ? 1 : -1) - ) - const after = n.children?.map((c) => c.preview)?.join("\n") - changed = changed || before !== after - }, - }) - return changed -} - export function resolveFenceFormat(modelid: string): FenceFormat { return DEFAULT_FENCE_FORMAT } @@ -1105,9 +1084,6 @@ export async function renderPromptNode( if (await deduplicatePromptNode(trace, node)) await tracePromptNode(trace, node, { label: "deduplicate" }) - if (await layoutPromptNode(node)) - await tracePromptNode(trace, node, { label: "layout" }) - if (flexTokens) await flexPromptNode(node, { ...options, @@ -1121,11 +1097,14 @@ export async function renderPromptNode( if (safety) await tracePromptNode(trace, node, { label: "safety" }) const messages: ChatCompletionMessageParam[] = [] - const appendSystem = (content: string) => - appendSystemMessage(messages, content) - const appendUser = (content: string) => appendUserMessage(messages, content) - const appendAssistant = (content: string) => - appendAssistantMessage(messages, content) + const appendSystem = (content: string, options: { ephemeral?: boolean }) => + appendSystemMessage(messages, content, options) + const appendUser = (content: string, options: { ephemeral?: boolean }) => + appendUserMessage(messages, content, options) + const appendAssistant = ( + content: string, + options: { ephemeral?: boolean } + ) => appendAssistantMessage(messages, content, options) const images: PromptImage[] = [] const errors: unknown[] = [] @@ -1144,8 +1123,8 @@ export async function renderPromptNode( errors.push(n.error) }, text: async (n) => { - if (n.resolved !== undefined) appendUser(n.resolved) - else if (typeof n.value === "string") appendUser(n.value) + if (n.resolved !== undefined) appendUser(n.resolved, n) + else if (typeof n.value === "string") appendUser(n.value, n) }, def: async (n) => { const value = n.resolved @@ -1162,19 +1141,19 @@ export async function renderPromptNode( }, assistant: async (n) => { const value = await n.resolved - if (value != undefined) appendAssistant(value) + if (value != undefined) appendAssistant(value, n) }, system: async (n) => { const value = await n.resolved - if (value != undefined) appendSystem(value) + if (value != undefined) appendSystem(value, n) }, stringTemplate: async (n) => { const value = n.resolved const role = n.role || "user" if (value != undefined) { - if (role === "system") appendSystem(value) - else if (role === "assistant") appendAssistant(value) - else appendUser(value) + if (role === "system") appendSystem(value, n) + else if (role === "assistant") appendAssistant(value, n) + else appendUser(value, n) } }, image: async (n) => { @@ -1214,7 +1193,7 @@ export async function renderPromptNode( \`\`\`${format + "-schema"} ${trimNewlines(schemaText)} \`\`\`` - appendUser(text) + appendUser(text, n) n.tokens = estimateTokens(text, encoder) if (trace && format !== "json") trace.detailsFenced( diff --git a/packages/core/src/runpromptcontext.ts b/packages/core/src/runpromptcontext.ts index 9cc558bd44..f0a76f220c 100644 --- a/packages/core/src/runpromptcontext.ts +++ b/packages/core/src/runpromptcontext.ts @@ -141,9 +141,7 @@ export function createChatTurnGenerationContext( $: (strings, ...args) => { const current = createStringTemplateNode(strings, args) appendChild(node, current) - const res: PromptTemplateString = Object.freeze(< - PromptTemplateString - >{ + const res: PromptTemplateString = Object.freeze({ priority: (priority) => { current.priority = priority return res @@ -168,7 +166,11 @@ export function createChatTurnGenerationContext( current.role = r return res }, - }) + cacheControl: (cc) => { + current.ephemeral = cc === "ephemeral" + return res + } + } satisfies PromptTemplateString) return res }, def: (name, body, defOptions) => { diff --git a/packages/core/src/transformers.ts b/packages/core/src/transformers.ts index d151c49365..f9b70845c9 100644 --- a/packages/core/src/transformers.ts +++ b/packages/core/src/transformers.ts @@ -24,7 +24,7 @@ import { PLimitPromiseQueue } from "./concurrency" function progressBar(): ProgressCallback { const progress: Record = {} return (cb: ProgressInfo) => { - switch (cb.status as string) { + switch (cb.status) { case "progress": const p = progress[cb.file] || 0 const cp = Math.floor(cb.progress) @@ -34,7 +34,7 @@ function progressBar(): ProgressCallback { } break case "ready": { - logVerbose(`model ${(cb as any).model} ready`) + logVerbose(`model ${cb.model} ready`) logVerbose(``) break } diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts index b4ddd9220d..9b73b2eb59 100644 --- a/packages/core/src/types/prompt_template.d.ts +++ b/packages/core/src/types/prompt_template.d.ts @@ -903,16 +903,19 @@ interface ContextExpansionOptions { * Specifies an maximum of estimated tokens for this entry; after which it will be truncated. */ maxTokens?: number + /* * Value that is conceptually similar to a zIndex (higher number == higher priority). * If a rendered prompt has more message tokens than can fit into the available context window, the prompt renderer prunes messages with the lowest priority from the ChatMessages result, preserving the order in which they were declared. This means your extension code can safely declare TSX components for potentially large pieces of context like conversation history and codebase context. */ priority?: number + /** * Controls the proportion of tokens allocated from the container's budget to this element. * It defaults to 1 on all elements. */ flex?: number + /** * This text is likely to change and will probably break the prefix cache. */ @@ -2431,6 +2434,8 @@ interface ImportTemplateOptions { allowExtraArguments?: boolean } +type PromptCacheControlType = "ephemeral" + interface PromptTemplateString { /** * Set a priority similar to CSS z-index @@ -2462,6 +2467,12 @@ interface PromptTemplateString { * Updates the role of the message */ role(role: ChatMessageRole): PromptTemplateString + + /** + * Configure the cacheability of the prompt. + * @param value cache control type + */ + cacheControl(value: PromptCacheControlType): PromptTemplateString } type ImportTemplateArgumentType =