Skip to content

Commit

Permalink
support rendering prompty in importTemplate (#786)
Browse files Browse the repository at this point in the history
* feat: ✨ add support for system prompt messages

* feat: ✨ add system message handling to promptdom
  • Loading branch information
pelikhan authored Oct 18, 2024
1 parent aaef184 commit 5d645b3
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 22 deletions.
17 changes: 17 additions & 0 deletions packages/core/src/expander.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export async function callExpander(
let logs = ""
let text = ""
let assistantText = ""
let systemText = ""
let images: PromptImage[] = []
let schemas: Record<string, JSONSchema> = {}
let functions: ToolCallback[] = []
Expand Down Expand Up @@ -75,6 +76,7 @@ export async function callExpander(
const {
userPrompt,
assistantPrompt,
systemPrompt,
images: imgs,
errors,
schemas: schs,
Expand All @@ -89,6 +91,7 @@ export async function callExpander(
})
text = userPrompt
assistantText = assistantPrompt
systemText = systemPrompt
images = imgs
schemas = schs
functions = fns
Expand Down Expand Up @@ -126,6 +129,7 @@ export async function callExpander(
statusText,
text,
assistantText,
systemText,
images,
schemas,
functions: Object.freeze(functions),
Expand Down Expand Up @@ -363,6 +367,19 @@ ${schemaTs}
}
messages.push(assistantMessage)
}
if (prompt.systemText) {
trace.detailsFenced("👾 system", prompt.systemText, "markdown")
const systemMessage: ChatCompletionSystemMessageParam = {
role: "system",
content: prompt.systemText,
}
// insert system messages after the last system role message in messages
// assume system messages are at the start
let li = -1
for (let li = 0; li < messages.length; li++)
if (messages[li].role === "system") break
messages.splice(li, 0, systemMessage)
}

trace.endDetails()

Expand Down
66 changes: 47 additions & 19 deletions packages/core/src/promptdom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ import {
MARKDOWN_PROMPT_FENCE,
PROMPT_FENCE,
PROMPTY_REGEX,
SYSTEM_FENCE,
TEMPLATE_ARG_DATA_SLICE_SAMPLE,
TEMPLATE_ARG_FILE_MAX_TOKENS,
} from "./constants"
import { parseModelIdentifier } from "./models"
import { toChatCompletionUserMessage } from "./chat"
import { errorMessage } from "./error"
import { tidyData } from "./tidy"
import { inspect } from "./logging"
import { dedent } from "./indent"
import {
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
} from "./chattypes"
import { resolveTokenEncoder } from "./encoders"
import { expandFiles } from "./fs"
Expand All @@ -43,6 +44,7 @@ export interface PromptNode extends ContextExpansionOptions {
| "outputProcessor"
| "stringTemplate"
| "assistant"
| "system"
| "def"
| "chatParticipant"
| "fileOutput"
Expand Down Expand Up @@ -85,6 +87,12 @@ export interface PromptAssistantNode extends PromptNode {
resolved?: string // Resolved assistant content
}

export interface PromptSystemNode extends PromptNode {
type: "system"
value: Awaitable<string> // Assistant-related content
resolved?: string // Resolved assistant content
}

// Interface for a string template node.
export interface PromptStringTemplateNode extends PromptNode {
type: "stringTemplate"
Expand Down Expand Up @@ -271,6 +279,14 @@ export function createAssistantNode(
return { type: "assistant", value, ...(options || {}) }
}

export function createSystemNode(
value: Awaitable<string>,
options?: ContextExpansionOptions
): PromptSystemNode {
assert(value !== undefined)
return { type: "system", value, ...(options || {}) }
}

// Function to create a string template node.
export function createStringTemplateNode(
strings: TemplateStringsArray,
Expand Down Expand Up @@ -450,6 +466,7 @@ export interface PromptNodeVisitor {
stringTemplate?: (node: PromptStringTemplateNode) => Awaitable<void> // String template node visitor
outputProcessor?: (node: PromptOutputProcessorNode) => Awaitable<void> // Output processor node visitor
assistant?: (node: PromptAssistantNode) => Awaitable<void> // Assistant node visitor
system?: (node: PromptSystemNode) => Awaitable<void> // System node visitor
chatParticipant?: (node: PromptChatParticipantNode) => Awaitable<void> // Chat participant node visitor
fileOutput?: (node: FileOutputNode) => Awaitable<void> // File output node visitor
importTemplate?: (node: PromptImportTemplate) => Awaitable<void> // Import template node visitor
Expand Down Expand Up @@ -486,6 +503,9 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {
case "assistant":
await visitor.assistant?.(node as PromptAssistantNode)
break
case "system":
await visitor.system?.(node as PromptSystemNode)
break
case "chatParticipant":
await visitor.chatParticipant?.(node as PromptChatParticipantNode)
break
Expand All @@ -509,6 +529,7 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {
export interface PromptNodeRender {
userPrompt: string // User prompt content
assistantPrompt: string // Assistant prompt content
systemPrompt: string // System prompt content
images: PromptImage[] // Images included in the prompt
errors: unknown[] // Errors encountered during rendering
schemas: Record<string, JSONSchema> // Schemas included in the prompt
Expand Down Expand Up @@ -585,6 +606,15 @@ async function resolvePromptNode(
n.error = e
}
},
system: async (n) => {
try {
const value = await n.value
n.resolved = n.preview = value
n.tokens = estimateTokens(value, encoder)
} catch (e) {
n.error = e
}
},
assistant: async (n) => {
try {
const value = await n.value
Expand Down Expand Up @@ -717,6 +747,8 @@ async function resolveImportPrompty(
const txt = jinjaRenderChatMessage(message, args)
if (message.role === "assistant")
n.children.push(createAssistantNode(txt))
else if (message.role === "system")
n.children.push(createSystemNode(txt))
else n.children.push(createTextNode(txt))
n.preview += txt + "\n"
}
Expand Down Expand Up @@ -909,6 +941,7 @@ export async function renderPromptNode(

let userPrompt = ""
let assistantPrompt = ""
let systemPrompt = ""
const images: PromptImage[] = []
const errors: unknown[] = []
const schemas: Record<string, JSONSchema> = {}
Expand All @@ -934,6 +967,11 @@ export async function renderPromptNode(
const value = await n.resolved
if (value != undefined) assistantPrompt += value + "\n"
},
system: async (n) => {
if (n.error) errors.push(n.error)
const value = await n.resolved
if (value != undefined) systemPrompt += value + SYSTEM_FENCE
},
stringTemplate: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
Expand All @@ -953,22 +991,6 @@ export async function renderPromptNode(
}
}
},
importTemplate: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
if (value) {
for (const [filename, content] of Object.entries(value)) {
userPrompt += content
userPrompt += "\n"
if (trace)
trace.detailsFenced(
`📦 import template ${filename}`,
content,
"markdown"
)
}
}
},
schema: (n) => {
const { name: schemaName, value: schema, options } = n
if (schemas[schemaName])
Expand Down Expand Up @@ -1057,13 +1079,19 @@ ${fods.map((fo) => ` ${fo.pattern}: ${fo.description}`)}
toChatCompletionUserMessage(userPrompt, images),
]
if (assistantPrompt)
messages.push(<ChatCompletionAssistantMessageParam>{
messages.push({
role: "assistant",
content: assistantPrompt,
})
} as ChatCompletionAssistantMessageParam)
if (systemPrompt)
messages.unshift({
role: "system",
content: systemPrompt,
} as ChatCompletionSystemMessageParam)
const res = Object.freeze<PromptNodeRender>({
userPrompt,
assistantPrompt,
systemPrompt,
images,
schemas,
functions: tools,
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/prompty.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ export function promptyToGenAIScript(doc: PromptyDocument) {
const { role, content } = msg
if (role === "assistant") {
return `assistant(parsers.jinja(${JSON.stringify(content as string)}, env.vars))`
} else if (role === "system") {
return `writeText(${JSON.stringify(content as string)}, { role: "system" })`
} else {
if (typeof content === "string") return renderJinja(content)
else if (Array.isArray(content))
Expand Down
20 changes: 17 additions & 3 deletions packages/core/src/runpromptcontext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,24 @@ export function createChatTurnGenerationContext(
writeText: (body, options) => {
if (body !== undefined && body !== null) {
const { priority, maxTokens } = options || {}
const role = options?.assistant
? "assistant"
: options?.role || "user"
appendChild(
node,
options?.assistant
role === "assistant"
? createAssistantNode(body, { priority, maxTokens })
: createTextNode(body, { priority, maxTokens })
: role === "system"
? creatSystemNode(body, { priority, maxTokens })
: createTextNode(body, { priority, maxTokens })
)
}
},
assistant: (body, options) =>
ctx.writeText(body, { ...options, assistant: true }),
ctx.writeText(body, {
...options,
role: "assistant",
} as WriteTextOptions),
$: (strings, ...args) => {
const current = createStringTemplateNode(strings, args)
appendChild(node, current)
Expand Down Expand Up @@ -745,3 +753,9 @@ export function createChatGenerationContext(

return ctx
}
function creatSystemNode(
body: Awaitable<string>,
arg1: { priority: number; maxTokens: number }
): PromptNode {
throw new Error("Function not implemented.")
}
5 changes: 5 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2028,8 +2028,13 @@ type ChatFunctionHandler = (args: ChatFunctionArgs) => Awaitable<ToolCallOutput>
interface WriteTextOptions extends ContextExpansionOptions {
/**
* Append text to the assistant response. This feature is not supported by all models.
* @deprecated
*/
assistant?: boolean
/**
* Specifies the message role. Default is user
*/
role?: "user" | "assistant" | "system"
}

type PromptGenerator = (ctx: ChatGenerationContext) => Awaitable<unknown>
Expand Down
8 changes: 8 additions & 0 deletions packages/sample/genaisrc/import-prompty.genai.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ script({
},
})

const res = await runPrompt((ctx) => {
ctx.importTemplate("src/basic.prompty", {
question: "what is the capital of france?",
hint: "starts with p",
})
})
console.log(`inline: ${res.text}`)

importTemplate("src/basic.prompty", {
question: "what is the capital of france?",
hint: "starts with p",
Expand Down

0 comments on commit 5d645b3

Please sign in to comment.