Skip to content

Commit

Permalink
Refactoring assembly of messages
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Oct 24, 2024
1 parent 91be4da commit aa65e9e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 71 deletions.
21 changes: 13 additions & 8 deletions packages/core/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -564,17 +564,22 @@ async function processChatMessage(
const node = ctx.node
checkCancelled(cancellationToken)
// expand template
const { errors, userPrompt } = await renderPromptNode(
options.model,
node,
{
const { errors, messages: participantMessages } =
await renderPromptNode(options.model, node, {
flexTokens: options.flexTokens,
trace,
}
)
if (userPrompt?.trim().length) {
})
if (participantMessages?.length) {
if (
participantMessages.some(
({ role }) => role === "system"
)
)
throw new Error(
"system messages not supported for chat participants"
)
trace.detailsFenced(`💬 message`, userPrompt, "markdown")
messages.push({ role: "user", content: userPrompt })
messages.push(...participantMessages)
needsNewTurn = true
} else trace.item("no message")
if (errors?.length) {
Expand Down
37 changes: 14 additions & 23 deletions packages/core/src/expander.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ export async function callExpander(
let status: GenerationStatus = undefined
let statusText: string = undefined
let logs = ""
let text = ""
let assistantText = ""
let systemText = ""
let messages: ChatCompletionMessageParam[] = []
let images: PromptImage[] = []
let schemas: Record<string, JSONSchema> = {}
let functions: ToolCallback[] = []
Expand Down Expand Up @@ -74,9 +72,7 @@ export async function callExpander(
const node = ctx.node
if (provider !== MODEL_PROVIDER_AICI) {
const {
userPrompt,
assistantPrompt,
systemPrompt,
messages: msgs,
images: imgs,
errors,
schemas: schs,
Expand All @@ -89,9 +85,7 @@ export async function callExpander(
flexTokens: options.flexTokens,
trace,
})
text = userPrompt
assistantText = assistantPrompt
systemText = systemPrompt
messages = msgs
images = imgs
schemas = schs
functions = fns
Expand Down Expand Up @@ -127,9 +121,7 @@ export async function callExpander(
logs,
status,
statusText,
text,
assistantText,
systemText,
messages,
images,
schemas,
functions: Object.freeze(functions),
Expand Down Expand Up @@ -230,7 +222,7 @@ export async function expandTemplate(
lineNumbers,
})

const { status, statusText, text } = prompt
const { status, statusText } = prompt
const images = prompt.images.slice(0)
const schemas = structuredClone(prompt.schemas)
const functions = prompt.functions.slice(0)
Expand All @@ -240,31 +232,30 @@ export async function expandTemplate(
const fileOutputs = prompt.fileOutputs.slice(0)

if (prompt.logs?.length) trace.details("📝 console.log", prompt.logs)
if (text) trace.detailsFenced(`📝 prompt`, text, "markdown")
if (prompt.aici) trace.fence(prompt.aici, "yaml")
trace.endDetails()

if (status !== "success" || text === "")
// cancelled
if (cancellationToken?.isCancellationRequested || status === "cancelled")
return {
status,
statusText,
status: "cancelled",
statusText: "user cancelled",
messages,
}

if (cancellationToken?.isCancellationRequested)
if (status !== "success")
// cancelled
return {
status: "cancelled",
statusText: "user cancelled",
status,
statusText,
messages,
}

const systemMessage: ChatCompletionSystemMessageParam = {
role: "system",
content: "",
}
if (prompt.text)
messages.push(toChatCompletionUserMessage(prompt.text, prompt.images))
if (prompt.images)
messages.push(toChatCompletionUserMessage("", prompt.images))
if (prompt.aici) messages.push(prompt.aici)

if (systems.length)
Expand Down
74 changes: 41 additions & 33 deletions packages/core/src/promptdom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
} from "./chattypes"
import { resolveTokenEncoder } from "./encoders"
import { expandFiles } from "./fs"
Expand Down Expand Up @@ -527,9 +528,6 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {

// Interface for representing a rendered prompt node.
export interface PromptNodeRender {
userPrompt: string // User prompt content
assistantPrompt: string // Assistant prompt content
systemPrompt: string // System prompt content
images: PromptImage[] // Images included in the prompt
errors: unknown[] // Errors encountered during rendering
schemas: Record<string, JSONSchema> // Schemas included in the prompt
Expand Down Expand Up @@ -940,9 +938,37 @@ export async function renderPromptNode(
const truncated = await truncatePromptNode(model, node, options)
if (truncated) await tracePromptNode(trace, node, { label: "truncated" })

let userPrompt = ""
let assistantPrompt = ""
let systemPrompt = ""
const messages: ChatCompletionMessageParam[] = []
const appendSystem = (content: string) => {
const last = messages.find(
({ role }) => role === "system"
) as ChatCompletionSystemMessageParam
if (last) last.content += content + SYSTEM_FENCE
else
messages.push({
role: "system",
content,
} as ChatCompletionSystemMessageParam)
}
const appendUser = (content: string) => {
const last = messages.at(-1) as ChatCompletionUserMessageParam
if (last?.role === "user") last.content += content + "\n"
else
messages.push({
role: "user",
content,
} as ChatCompletionUserMessageParam)
}
const appendAssistant = (content: string) => {
const last = messages.at(-1) as ChatCompletionAssistantMessageParam
if (last?.role === "assistant") last.content += content
else
messages.push({
role: "assistant",
content,
} satisfies ChatCompletionAssistantMessageParam)
}

const images: PromptImage[] = []
const errors: unknown[] = []
const schemas: Record<string, JSONSchema> = {}
Expand All @@ -956,27 +982,27 @@ export async function renderPromptNode(
text: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
if (value != undefined) userPrompt += value + "\n"
if (value != undefined) appendUser(value)
},
def: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
if (value !== undefined) userPrompt += renderDefNode(n) + "\n"
if (value !== undefined) appendUser(renderDefNode(n))
},
assistant: async (n) => {
if (n.error) errors.push(n.error)
const value = await n.resolved
if (value != undefined) assistantPrompt += value + "\n"
if (value != undefined) appendAssistant(value)
},
system: async (n) => {
if (n.error) errors.push(n.error)
const value = await n.resolved
if (value != undefined) systemPrompt += value + SYSTEM_FENCE
if (value != undefined) appendSystem(value)
},
stringTemplate: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
if (value != undefined) userPrompt += value + "\n"
if (value != undefined) appendUser(value)
},
image: async (n) => {
if (n.error) errors.push(n.error)
Expand Down Expand Up @@ -1015,9 +1041,8 @@ export async function renderPromptNode(
const text = `${schemaName}:
\`\`\`${format + "-schema"}
${trimNewlines(schemaText)}
\`\`\`
`
userPrompt += text
\`\`\``
appendUser(text)
n.tokens = estimateTokens(text, encoder)
if (trace && format !== "json")
trace.detailsFenced(
Expand Down Expand Up @@ -1066,33 +1091,16 @@ ${trimNewlines(schemaText)}

const fods = fileOutputs?.filter((f) => !!f.description)
if (fods?.length > 0) {
systemPrompt += `
appendSystem(`
## File generation rules
When generating files, use the following rules which are formatted as "file glob: description":
${fods.map((fo) => ` ${fo.pattern}: ${fo.description}`)}
`
`)
}

const messages: ChatCompletionMessageParam[] = [
toChatCompletionUserMessage(userPrompt, images),
]
if (assistantPrompt)
messages.push({
role: "assistant",
content: assistantPrompt,
} as ChatCompletionAssistantMessageParam)
if (systemPrompt)
messages.unshift({
role: "system",
content: systemPrompt,
} as ChatCompletionSystemMessageParam)
const res = Object.freeze<PromptNodeRender>({
userPrompt,
assistantPrompt,
systemPrompt,
images,
schemas,
functions: tools,
Expand Down
9 changes: 2 additions & 7 deletions packages/core/src/runpromptcontext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
renderPromptNode,
createOutputProcessor,
createFileMerge,
createSystemNode,
} from "./promptdom"
import { MarkdownTrace } from "./trace"
import { GenerationOptions } from "./generation"
Expand Down Expand Up @@ -117,7 +118,7 @@ export function createChatTurnGenerationContext(
role === "assistant"
? createAssistantNode(body, { priority, maxTokens })
: role === "system"
? creatSystemNode(body, { priority, maxTokens })
? createSystemNode(body, { priority, maxTokens })
: createTextNode(body, { priority, maxTokens })
)
}
Expand Down Expand Up @@ -773,9 +774,3 @@ export function createChatGenerationContext(

return ctx
}
function creatSystemNode(
body: Awaitable<string>,
arg1: { priority: number; maxTokens: number }
): PromptNode {
throw new Error("Function not implemented.")
}
7 changes: 7 additions & 0 deletions packages/sample/genaisrc/writetext.genai.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Sample script exercising writeText with explicit message roles.
// Verifies that consecutive system/user/assistant writes are assembled
// into the correct chat message sequence by renderPromptNode.
script({ model: "small", tests: {} })
// Two consecutive system prompts — should merge into one system message.
writeText("This is a system prompt.", { role: "system" })
writeText("This is another system prompt.", { role: "system" })
// Alternating user/assistant turns — should produce separate messages.
writeText("This is a user prompt.", { role: "user" })
writeText("This is an assistant prompt.", { role: "assistant" })
writeText("This is another user prompt.", { role: "user" })
writeText("This is another assistant prompt.", { role: "assistant" })

0 comments on commit aa65e9e

Please sign in to comment.