feat: 🚀 add prompt caching support and improve handling

pelikhan committed Dec 9, 2024
1 parent 3c324a5 commit eaf1c54
Showing 9 changed files with 155 additions and 80 deletions.
10 changes: 8 additions & 2 deletions docs/src/content/docs/reference/scripts/context.md
@@ -200,13 +200,19 @@ def("FILE", env.files, { sliceSample: 100 })
### Prompt Caching
You can specify `ephemeral: true` to enable prompt caching optimization. In particular, a `def` with `ephemeral` will be rendered at the back of the prompt
to persist the [cache prefix](https://openai.com/index/api-prompt-caching/).
You can use `cacheControl: "ephemeral"` to mark a prompt section as cacheable
for a short amount of time, enabling the prompt caching optimizations that various LLM providers support (each in a different way).
```js "ephemeral: true"
$`...`.cacheControl("ephemeral")
```
```js "ephemeral: true"
def("FILE", env.files, { ephemeral: true })
```
Read more about [prompt caching](/genaiscript/reference/scripts/prompt-caching).

Check warning on line 215 in docs/src/content/docs/reference/scripts/context.md (GitHub Actions / build): The documentation for prompt caching is duplicated between `context.md` and a new file `prompt-caching.mdx`. This redundancy can be avoided by linking to the existing content or consolidating it in one place.
### Safety: Prompt Injection detection
You can schedule a check for prompt injection/jailbreak with your configured [content safety](/genaiscript/reference/scripts/content-safety) provider.
37 changes: 37 additions & 0 deletions docs/src/content/docs/reference/scripts/prompt-caching.mdx
@@ -0,0 +1,37 @@
---
title: Prompt Caching
sidebar:
order: 80
---

Prompt caching is a feature that can reduce processing time and costs for repetitive prompts.
It is supported by various LLM providers, but the implementation may vary.

- OpenAI implements an automatic [cache prefix](https://openai.com/index/api-prompt-caching/).
- Anthropic supports setting [cache breakpoints](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).

## `ephemeral`

You can mark a `def` section or a `$` template with `cacheControl` set to `"ephemeral"` to enable prompt caching optimization. This tells the LLM provider that it
is acceptable to cache the prompt prefix for a short amount of time.

```js
def("FILE", env.files, { cacheControl: "ephemeral" })
```

```js
$`Some very cool prompt`.cacheControl("ephemeral")
```

## LLM provider support

Check warning on line 26 in docs/src/content/docs/reference/scripts/prompt-caching.mdx (GitHub Actions / build): The section on `ephemeral` is duplicated between `context.md` and `prompt-caching.mdx`. This redundancy can be avoided by linking to the existing content or consolidating it in one place.

In most cases, the `ephemeral` hint is ignored by LLM providers. However, the following providers support it:

### OpenAI, Azure OpenAI

[Prompt caching](https://platform.openai.com/docs/guides/prompt-caching) of the prompt prefix
is automatically enabled by OpenAI.

### Anthropic

- The `ephemeral` hint is translated into a `'cache-control': { ... }` field on the message object, as sketched below.
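
For illustration only, here is a minimal sketch of how such a hint could be forwarded to Anthropic. The `cache_control` block follows Anthropic's published prompt caching format; the exact field names and the mapping performed by the adapter are assumptions, not a verbatim excerpt of the implementation.

```js
// Hypothetical sketch: a message carrying the GenAIScript-level ephemeral hint...
const message = {
    role: "user",
    content: "FILE: ...",
    cacheControl: "ephemeral", // local-only hint
}

// ...could be serialized as an Anthropic content block with a cache breakpoint
// (see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
const anthropicMessage = {
    role: "user",
    content: [
        {
            type: "text",
            text: message.content,
            cache_control: { type: "ephemeral" },
        },
    ],
}
```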

Check warning on line 37 in docs/src/content/docs/reference/scripts/prompt-caching.mdx (GitHub Actions / build): The frontmatter in `prompt-caching.mdx` is not necessary since it's a standalone document and does not need to be included in the sidebar or have a title. This can be removed to simplify the file.

Check warning on line 37 in docs/src/content/docs/reference/scripts/prompt-caching.mdx (GitHub Actions / build): The section on LLM provider support is duplicated between `context.md` and `prompt-caching.mdx`. This redundancy can be avoided by linking to the existing content or consolidating it in one place.
74 changes: 49 additions & 25 deletions packages/core/src/chat.ts
@@ -343,9 +343,9 @@ ${fenceMD(content, " ")}
appendUserMessage(
messages,
`- ${call.name}(${JSON.stringify(call.arguments || {})})
\`\`\`\`\`
<tool_result>
${toolResult.join("\n\n")}
\`\`\`\`\`
</tool_result>
`
)
else
@@ -407,12 +407,12 @@ schema: ${f.args?.schema || ""},
error: ${f.validation.schemaError}`
)
.join("\n\n")
const repairMsg = dedent`DATA_FORMAT_ISSUES:
\`\`\`
const repairMsg =
`<data_format_issues>
${repair}
\`\`\`
</data_format_issues>
Repair the DATA_FORMAT_ISSUES. THIS IS IMPORTANT.`
Repair the <data_format_issues>. THIS IS IMPORTANT.`
trace.fence(repairMsg, "markdown")
messages.push({
role: "user",
@@ -962,43 +962,68 @@ export function tracePromptResult(trace: MarkdownTrace, resp: RunPromptResult) {

export function appendUserMessage(
messages: ChatCompletionMessageParam[],
content: string
content: string,
options?: { ephemeral?: boolean }
) {
if (!content) return
const last = messages.at(-1) as ChatCompletionUserMessageParam
if (last?.role === "user") last.content += "\n" + content
else
messages.push({
const { ephemeral } = options || {}
let last = messages.at(-1) as ChatCompletionUserMessageParam
if (
last?.role !== "user" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
) {
last = {
role: "user",
content,
} as ChatCompletionUserMessageParam)
content: "",
} satisfies ChatCompletionUserMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
messages.push(last)
}
if (last.content) last.content += "\n" + content
else last.content = content
}
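
For context, a minimal usage sketch of the updated helper (not part of the diff): consecutive content with the same role and the same cache hint is merged into a single message, while a differing hint starts a new message so the cache breakpoint is preserved.

```ts
// Hypothetical usage, assuming appendUserMessage and ChatCompletionMessageParam
// are imported from the core package.
const messages: ChatCompletionMessageParam[] = []
appendUserMessage(messages, "stable project instructions", { ephemeral: true })
appendUserMessage(messages, "more stable instructions", { ephemeral: true }) // merged into the previous message
appendUserMessage(messages, "volatile user question") // different cache hint, so a new message is pushed
// messages now holds two user messages: one with cacheControl: "ephemeral", one without
```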

export function appendAssistantMessage(
messages: ChatCompletionMessageParam[],
content: string
content: string,
options?: { ephemeral?: boolean }
) {
if (!content) return
const last = messages.at(-1) as ChatCompletionAssistantMessageParam
if (last?.role === "assistant") last.content += "\n" + content
else
messages.push({
const { ephemeral } = options || {}
let last = messages.at(-1) as ChatCompletionAssistantMessageParam
if (
last?.role !== "assistant" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
) {
last = {
role: "assistant",
content,
} satisfies ChatCompletionAssistantMessageParam)
content: "",
} satisfies ChatCompletionAssistantMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
messages.push(last)
}
if (last.content) last.content += "\n" + content
else last.content = content
}

export function appendSystemMessage(
messages: ChatCompletionMessageParam[],
content: string
content: string,
options?: { ephemeral?: boolean }
) {
if (!content) return
const { ephemeral } = options || {}

let last = messages[0] as ChatCompletionSystemMessageParam
if (last?.role !== "system") {
if (
last?.role !== "system" ||
!!ephemeral !== (last?.cacheControl === "ephemeral")
) {
last = {
role: "system",
content: "",
} as ChatCompletionSystemMessageParam
if (ephemeral) last.cacheControl = "ephemeral"
messages.unshift(last)
}
if (last.content) last.content += SYSTEM_FENCE
@@ -1012,10 +1037,9 @@ export function addToolDefinitionsMessage(
appendSystemMessage(
messages,
`
TOOLS:
\`\`\`yaml
<tools>
${YAMLStringify(tools.map((t) => t.spec))}
\`\`\`
</tools>
`
)
}
27 changes: 21 additions & 6 deletions packages/core/src/chattypes.ts
@@ -10,7 +10,7 @@ import OpenAI from "openai"
/**
* Interface representing a custom AI Chat Interface request.
*/
export interface AICIRequest {
export interface AICIRequest extends ChatCompletionMessageParamCacheControl {
role: "aici" // The role for this type of request
content?: string // Optional content of the request
error?: unknown // Optional error information
@@ -44,19 +44,32 @@ export type ChatCompletionTokenLogprob = OpenAI.ChatCompletionTokenLogprob
export type ChatCompletion = OpenAI.Chat.Completions.ChatCompletion
export type ChatCompletionChoice = OpenAI.Chat.Completions.ChatCompletion.Choice

export interface ChatCompletionMessageParamCacheControl {
cacheControl?: PromptCacheControlType
}

// Parameters for a system message in a chat completion
export type ChatCompletionSystemMessageParam =
OpenAI.Chat.Completions.ChatCompletionSystemMessageParam
OpenAI.Chat.Completions.ChatCompletionSystemMessageParam &
ChatCompletionMessageParamCacheControl

// Parameters for a tool message in a chat completion
export type ChatCompletionToolMessageParam =
OpenAI.Chat.Completions.ChatCompletionToolMessageParam
OpenAI.Chat.Completions.ChatCompletionToolMessageParam &
ChatCompletionMessageParamCacheControl
export type ChatCompletionFunctionMessageParam =
OpenAI.Chat.Completions.ChatCompletionFunctionMessageParam &
ChatCompletionMessageParamCacheControl

/**
* Type representing parameters for chat completion messages, including custom AICIRequest.
*/
export type ChatCompletionMessageParam =
| OpenAI.Chat.Completions.ChatCompletionMessageParam
| ChatCompletionSystemMessageParam
| ChatCompletionUserMessageParam
| ChatCompletionAssistantMessageParam
| ChatCompletionToolMessageParam
| ChatCompletionFunctionMessageParam
| AICIRequest

/**
@@ -75,11 +88,13 @@ export type CreateChatCompletionRequest = Omit<

// Parameters for an assistant message in a chat completion
export type ChatCompletionAssistantMessageParam =
OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam
OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam &
ChatCompletionMessageParamCacheControl

// Parameters for a user message in a chat completion
export type ChatCompletionUserMessageParam =
OpenAI.Chat.Completions.ChatCompletionUserMessageParam
OpenAI.Chat.Completions.ChatCompletionUserMessageParam &
ChatCompletionMessageParamCacheControl

// Image content part of a chat completion
export type ChatCompletionContentPartImage =
1 change: 1 addition & 0 deletions packages/core/src/openai.ts
@@ -124,6 +124,7 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async (

const postReq = structuredClone({
...req,
messages: req.messages.map(({ cacheControl, ...rest }) => ({ ...rest })),
stream: true,
stream_options: { include_usage: true },
model,
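As an aside, the `messages` mapping added above strips the GenAIScript-only `cacheControl` field via rest destructuring before the request body is built, presumably because the OpenAI wire format does not define such a field (OpenAI's prefix caching is automatic). A standalone sketch of the same pattern, with illustrative types that are not the project's own:

```ts
// Drop a local-only field from every message before sending the request.
type LocalMessage = { role: string; content: string; cacheControl?: "ephemeral" }

const local: LocalMessage[] = [
    { role: "system", content: "You are concise.", cacheControl: "ephemeral" },
    { role: "user", content: "Summarize the file." },
]

// Rest destructuring removes cacheControl; all other properties are kept as-is.
const wire = local.map(({ cacheControl, ...rest }) => ({ ...rest }))
// wire: [{ role: "system", content: "You are concise." }, { role: "user", content: "Summarize the file." }]
```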
61 changes: 20 additions & 41 deletions packages/core/src/promptdom.ts
@@ -75,7 +75,8 @@ export interface PromptNode extends ContextExpansionOptions {
error?: unknown // Error information if present
tokens?: number // Token count for the node
/**
* This text is likely to change within 5 to 10 minutes.
* Define a prompt caching breakpoint.
* This prompt prefix (including this text) is cacheable for a short amount of time.
*/
ephemeral?: boolean

@@ -478,7 +479,7 @@ export function createDefData(
options?: DefDataOptions
) {
if (data === undefined) return undefined
let { format, headers, priority } = options || {}
let { format, headers, priority, ephemeral } = options || {}
if (
!format &&
Array.isArray(data) &&
@@ -512,7 +513,7 @@ ${trimNewlines(text)}
${trimNewlines(text)}
`
// TODO maxTokens does not work well with data
return createTextNode(value, { priority })
return createTextNode(value, { priority, ephemeral })
}

// Function to append a child node to a parent node.
@@ -616,28 +617,6 @@ export interface PromptNodeRender {
disposables: AsyncDisposable[] // Disposables
}

/**
* To optimize chat caching with openai, move defs to the back of the prompt
* @see https://platform.openai.com/docs/guides/prompt-caching
* @param mode
* @param root
*/
async function layoutPromptNode(root: PromptNode) {
let changed = false
await visitNode(root, {
node: (n) => {
// sort children
const before = n.children?.map((c) => c.preview)?.join("\n")
n.children?.sort(
(a, b) => (a.ephemeral ? 1 : -1) - (b.ephemeral ? 1 : -1)
)
const after = n.children?.map((c) => c.preview)?.join("\n")
changed = changed || before !== after
},
})
return changed
}

export function resolveFenceFormat(modelid: string): FenceFormat {
return DEFAULT_FENCE_FORMAT
}
@@ -1105,9 +1084,6 @@ export async function renderPromptNode(
if (await deduplicatePromptNode(trace, node))
await tracePromptNode(trace, node, { label: "deduplicate" })

if (await layoutPromptNode(node))
await tracePromptNode(trace, node, { label: "layout" })

if (flexTokens)
await flexPromptNode(node, {
...options,
@@ -1121,11 +1097,14 @@
if (safety) await tracePromptNode(trace, node, { label: "safety" })

const messages: ChatCompletionMessageParam[] = []
const appendSystem = (content: string) =>
appendSystemMessage(messages, content)
const appendUser = (content: string) => appendUserMessage(messages, content)
const appendAssistant = (content: string) =>
appendAssistantMessage(messages, content)
const appendSystem = (content: string, options: { ephemeral?: boolean }) =>
appendSystemMessage(messages, content, options)
const appendUser = (content: string, options: { ephemeral?: boolean }) =>
appendUserMessage(messages, content, options)
const appendAssistant = (
content: string,
options: { ephemeral?: boolean }
) => appendAssistantMessage(messages, content, options)

const images: PromptImage[] = []
const errors: unknown[] = []
@@ -1144,8 +1123,8 @@ export async function renderPromptNode(
errors.push(n.error)
},
text: async (n) => {
if (n.resolved !== undefined) appendUser(n.resolved)
else if (typeof n.value === "string") appendUser(n.value)
if (n.resolved !== undefined) appendUser(n.resolved, n)
else if (typeof n.value === "string") appendUser(n.value, n)
},
def: async (n) => {
const value = n.resolved
@@ -1162,19 +1141,19 @@ },
},
assistant: async (n) => {
const value = await n.resolved
if (value != undefined) appendAssistant(value)
if (value != undefined) appendAssistant(value, n)
},
system: async (n) => {
const value = await n.resolved
if (value != undefined) appendSystem(value)
if (value != undefined) appendSystem(value, n)
},
stringTemplate: async (n) => {
const value = n.resolved
const role = n.role || "user"
if (value != undefined) {
if (role === "system") appendSystem(value)
else if (role === "assistant") appendAssistant(value)
else appendUser(value)
if (role === "system") appendSystem(value, n)
else if (role === "assistant") appendAssistant(value, n)
else appendUser(value, n)
}
},
image: async (n) => {
@@ -1214,7 +1193,7 @@
\`\`\`${format + "-schema"}
${trimNewlines(schemaText)}
\`\`\``
appendUser(text)
appendUser(text, n)
n.tokens = estimateTokens(text, encoder)
if (trace && format !== "json")
trace.detailsFenced(

0 comments on commit eaf1c54
