From aa65e9ec3175fbd41ef1ed21b6f85fe9a2be8acb Mon Sep 17 00:00:00 2001 From: Peli de Halleux Date: Thu, 24 Oct 2024 03:06:21 +0000 Subject: [PATCH] refcatoring assembly of messages --- packages/core/src/chat.ts | 21 +++--- packages/core/src/expander.ts | 37 ++++------ packages/core/src/promptdom.ts | 74 +++++++++++--------- packages/core/src/runpromptcontext.ts | 9 +-- packages/sample/genaisrc/writetext.genai.mjs | 7 ++ 5 files changed, 77 insertions(+), 71 deletions(-) create mode 100644 packages/sample/genaisrc/writetext.genai.mjs diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index 74625b7b66..c645be5843 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -564,17 +564,22 @@ async function processChatMessage( const node = ctx.node checkCancelled(cancellationToken) // expand template - const { errors, userPrompt } = await renderPromptNode( - options.model, - node, - { + const { errors, messages: participantMessages } = + await renderPromptNode(options.model, node, { flexTokens: options.flexTokens, trace, - } - ) - if (userPrompt?.trim().length) { + }) + if (participantMessages?.length) { + if ( + participantMessages.some( + ({ role }) => role === "system" + ) + ) + throw new Error( + "system messages not supported for chat participants" + ) trace.detailsFenced(`💬 message`, userPrompt, "markdown") - messages.push({ role: "user", content: userPrompt }) + messages.push(...participantMessages) needsNewTurn = true } else trace.item("no message") if (errors?.length) { diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts index 5f49724922..e156198deb 100644 --- a/packages/core/src/expander.ts +++ b/packages/core/src/expander.ts @@ -42,9 +42,7 @@ export async function callExpander( let status: GenerationStatus = undefined let statusText: string = undefined let logs = "" - let text = "" - let assistantText = "" - let systemText = "" + let messages: ChatCompletionMessageParam[] = [] let images: PromptImage[] = [] let schemas: Record = {} let functions: ToolCallback[] = [] @@ -74,9 +72,7 @@ export async function callExpander( const node = ctx.node if (provider !== MODEL_PROVIDER_AICI) { const { - userPrompt, - assistantPrompt, - systemPrompt, + messages: msgs, images: imgs, errors, schemas: schs, @@ -89,9 +85,7 @@ export async function callExpander( flexTokens: options.flexTokens, trace, }) - text = userPrompt - assistantText = assistantPrompt - systemText = systemPrompt + messages = msgs images = imgs schemas = schs functions = fns @@ -127,9 +121,7 @@ export async function callExpander( logs, status, statusText, - text, - assistantText, - systemText, + messages, images, schemas, functions: Object.freeze(functions), @@ -230,7 +222,7 @@ export async function expandTemplate( lineNumbers, }) - const { status, statusText, text } = prompt + const { status, statusText } = prompt const images = prompt.images.slice(0) const schemas = structuredClone(prompt.schemas) const functions = prompt.functions.slice(0) @@ -240,22 +232,21 @@ export async function expandTemplate( const fileOutputs = prompt.fileOutputs.slice(0) if (prompt.logs?.length) trace.details("📝 console.log", prompt.logs) - if (text) trace.detailsFenced(`📝 prompt`, text, "markdown") if (prompt.aici) trace.fence(prompt.aici, "yaml") trace.endDetails() - if (status !== "success" || text === "") - // cancelled + if (cancellationToken?.isCancellationRequested || status === "cancelled") return { - status, - statusText, + status: "cancelled", + statusText: "user cancelled", messages, } - if (cancellationToken?.isCancellationRequested) + if (status !== "success") + // cancelled return { - status: "cancelled", - statusText: "user cancelled", + status, + statusText, messages, } @@ -263,8 +254,8 @@ export async function expandTemplate( role: "system", content: "", } - if (prompt.text) - messages.push(toChatCompletionUserMessage(prompt.text, prompt.images)) + if (prompt.images) + messages.push(toChatCompletionUserMessage("", prompt.images)) if (prompt.aici) messages.push(prompt.aici) if (systems.length) diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts index 83b2f3dda7..0f55027196 100644 --- a/packages/core/src/promptdom.ts +++ b/packages/core/src/promptdom.ts @@ -24,6 +24,7 @@ import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam, ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, } from "./chattypes" import { resolveTokenEncoder } from "./encoders" import { expandFiles } from "./fs" @@ -527,9 +528,6 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) { // Interface for representing a rendered prompt node. export interface PromptNodeRender { - userPrompt: string // User prompt content - assistantPrompt: string // Assistant prompt content - systemPrompt: string // System prompt content images: PromptImage[] // Images included in the prompt errors: unknown[] // Errors encountered during rendering schemas: Record // Schemas included in the prompt @@ -940,9 +938,37 @@ export async function renderPromptNode( const truncated = await truncatePromptNode(model, node, options) if (truncated) await tracePromptNode(trace, node, { label: "truncated" }) - let userPrompt = "" - let assistantPrompt = "" - let systemPrompt = "" + const messages: ChatCompletionMessageParam[] = [] + const appendSystem = (content: string) => { + const last = messages.find( + ({ role }) => role === "system" + ) as ChatCompletionSystemMessageParam + if (last) last.content += content + SYSTEM_FENCE + else + messages.push({ + role: "system", + content, + } as ChatCompletionSystemMessageParam) + } + const appendUser = (content: string) => { + const last = messages.at(-1) as ChatCompletionUserMessageParam + if (last?.role === "user") last.content += content + "\n" + else + messages.push({ + role: "user", + content, + } as ChatCompletionUserMessageParam) + } + const appendAssistant = (content: string) => { + const last = messages.at(-1) as ChatCompletionAssistantMessageParam + if (last?.role === "assistant") last.content += content + else + messages.push({ + role: "assistant", + content, + } satisfies ChatCompletionAssistantMessageParam) + } + const images: PromptImage[] = [] const errors: unknown[] = [] const schemas: Record = {} @@ -956,27 +982,27 @@ export async function renderPromptNode( text: async (n) => { if (n.error) errors.push(n.error) const value = n.resolved - if (value != undefined) userPrompt += value + "\n" + if (value != undefined) appendUser(value) }, def: async (n) => { if (n.error) errors.push(n.error) const value = n.resolved - if (value !== undefined) userPrompt += renderDefNode(n) + "\n" + if (value !== undefined) appendUser(renderDefNode(n)) }, assistant: async (n) => { if (n.error) errors.push(n.error) const value = await n.resolved - if (value != undefined) assistantPrompt += value + "\n" + if (value != undefined) appendAssistant(value) }, system: async (n) => { if (n.error) errors.push(n.error) const value = await n.resolved - if (value != undefined) systemPrompt += value + SYSTEM_FENCE + if (value != undefined) appendSystem(value) }, stringTemplate: async (n) => { if (n.error) errors.push(n.error) const value = n.resolved - if (value != undefined) userPrompt += value + "\n" + if (value != undefined) appendUser(value) }, image: async (n) => { if (n.error) errors.push(n.error) @@ -1015,9 +1041,8 @@ export async function renderPromptNode( const text = `${schemaName}: \`\`\`${format + "-schema"} ${trimNewlines(schemaText)} -\`\`\` -` - userPrompt += text +\`\`\`` + appendUser(text) n.tokens = estimateTokens(text, encoder) if (trace && format !== "json") trace.detailsFenced( @@ -1066,33 +1091,16 @@ ${trimNewlines(schemaText)} const fods = fileOutputs?.filter((f) => !!f.description) if (fods?.length > 0) { - systemPrompt += ` + appendSystem(` ## File generation rules When generating files, use the following rules which are formatted as "file glob: description": ${fods.map((fo) => ` ${fo.pattern}: ${fo.description}`)} - -` +`) } - const messages: ChatCompletionMessageParam[] = [ - toChatCompletionUserMessage(userPrompt, images), - ] - if (assistantPrompt) - messages.push({ - role: "assistant", - content: assistantPrompt, - } as ChatCompletionAssistantMessageParam) - if (systemPrompt) - messages.unshift({ - role: "system", - content: systemPrompt, - } as ChatCompletionSystemMessageParam) const res = Object.freeze({ - userPrompt, - assistantPrompt, - systemPrompt, images, schemas, functions: tools, diff --git a/packages/core/src/runpromptcontext.ts b/packages/core/src/runpromptcontext.ts index 7e88c06e7b..dbb4211f11 100644 --- a/packages/core/src/runpromptcontext.ts +++ b/packages/core/src/runpromptcontext.ts @@ -16,6 +16,7 @@ import { renderPromptNode, createOutputProcessor, createFileMerge, + createSystemNode, } from "./promptdom" import { MarkdownTrace } from "./trace" import { GenerationOptions } from "./generation" @@ -117,7 +118,7 @@ export function createChatTurnGenerationContext( role === "assistant" ? createAssistantNode(body, { priority, maxTokens }) : role === "system" - ? creatSystemNode(body, { priority, maxTokens }) + ? createSystemNode(body, { priority, maxTokens }) : createTextNode(body, { priority, maxTokens }) ) } @@ -773,9 +774,3 @@ export function createChatGenerationContext( return ctx } -function creatSystemNode( - body: Awaitable, - arg1: { priority: number; maxTokens: number } -): PromptNode { - throw new Error("Function not implemented.") -} diff --git a/packages/sample/genaisrc/writetext.genai.mjs b/packages/sample/genaisrc/writetext.genai.mjs new file mode 100644 index 0000000000..1fa2b12ef0 --- /dev/null +++ b/packages/sample/genaisrc/writetext.genai.mjs @@ -0,0 +1,7 @@ +script({ model: "small", tests: {} }) +writeText("This is a system prompt.", { role: "system" }) +writeText("This is another system prompt.", { role: "system" }) +writeText("This is a user prompt.", { role: "user" }) +writeText("This is a asssitant prompt.", { role: "assistant" }) +writeText("This is a another user prompt.", { role: "user" }) +writeText("This is a another assitant prompt.", { role: "assistant" })