support rendering prompty in importTemplate #786

Merged · 2 commits · Oct 18, 2024
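With this change, a system message declared in a .prompty file survives importTemplate as a real system-role chat message instead of being folded into the user text. A minimal sketch of the resulting request (message shapes follow the ChatCompletion*MessageParam types touched below; the .prompty contents are assumed for illustration):

    // importTemplate("src/basic.prompty", { question: "what is the capital of france?", hint: "starts with p" })
    const messages = [
        { role: "system", content: "You answer geography questions concisely." }, // from the prompty system: section
        { role: "user", content: "what is the capital of france? (hint: starts with p)" },
    ]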
17 changes: 17 additions & 0 deletions packages/core/src/expander.ts
@@ -44,6 +44,7 @@ export async function callExpander(
let logs = ""
let text = ""
let assistantText = ""
let systemText = ""
let images: PromptImage[] = []
let schemas: Record<string, JSONSchema> = {}
let functions: ToolCallback[] = []
@@ -75,6 +76,7 @@ export async function callExpander(
const {
userPrompt,
assistantPrompt,
systemPrompt,
images: imgs,
errors,
schemas: schs,
@@ -89,6 +91,7 @@
})
text = userPrompt
assistantText = assistantPrompt
systemText = systemPrompt
images = imgs
schemas = schs
functions = fns
@@ -126,6 +129,7 @@ export async function callExpander(
statusText,
text,
assistantText,
systemText,
images,
schemas,
functions: Object.freeze(functions),
@@ -363,6 +367,19 @@
${schemaTs}
}
messages.push(assistantMessage)
}
if (prompt.systemText) {
trace.detailsFenced("👾 system", prompt.systemText, "markdown")
const systemMessage: ChatCompletionSystemMessageParam = {
role: "system",
content: prompt.systemText,
}
        // insert the system message after any leading system-role messages
        // (system messages are assumed to sit at the start of `messages`)
        let li = 0
        while (li < messages.length && messages[li].role === "system") li++
        messages.splice(li, 0, systemMessage)
}

trace.endDetails()

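The insertion point matters when the script already contributes its own system prompts: the imported system text is spliced in after any leading system-role messages and before the user prompt. Intended ordering, roles only (sketch):

    // before import:  [ system (script), user ]
    // after import:   [ system (script), system (from .prompty), user ]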
66 changes: 47 additions & 19 deletions packages/core/src/promptdom.ts
@@ -11,18 +11,19 @@ import {
MARKDOWN_PROMPT_FENCE,
PROMPT_FENCE,
PROMPTY_REGEX,
SYSTEM_FENCE,
TEMPLATE_ARG_DATA_SLICE_SAMPLE,
TEMPLATE_ARG_FILE_MAX_TOKENS,
} from "./constants"
import { parseModelIdentifier } from "./models"
import { toChatCompletionUserMessage } from "./chat"
import { errorMessage } from "./error"
import { tidyData } from "./tidy"
import { inspect } from "./logging"
import { dedent } from "./indent"
import {
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
} from "./chattypes"
import { resolveTokenEncoder } from "./encoders"
import { expandFiles } from "./fs"
@@ -43,6 +44,7 @@ export interface PromptNode extends ContextExpansionOptions {
| "outputProcessor"
| "stringTemplate"
| "assistant"
| "system"
| "def"
| "chatParticipant"
| "fileOutput"
@@ -85,6 +87,12 @@ export interface PromptAssistantNode extends PromptNode {
resolved?: string // Resolved assistant content
}

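// Interface for a system message node.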
export interface PromptSystemNode extends PromptNode {
type: "system"
    value: Awaitable<string> // System-related content
    resolved?: string // Resolved system content
}

// Interface for a string template node.
export interface PromptStringTemplateNode extends PromptNode {
type: "stringTemplate"
@@ -271,6 +279,14 @@ export function createAssistantNode(
return { type: "assistant", value, ...(options || {}) }
}

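// Function to create a system message node.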
export function createSystemNode(
value: Awaitable<string>,
options?: ContextExpansionOptions
): PromptSystemNode {
assert(value !== undefined)
return { type: "system", value, ...(options || {}) }
}

// Function to create a string template node.
export function createStringTemplateNode(
strings: TemplateStringsArray,
@@ -450,6 +466,7 @@ export interface PromptNodeVisitor {
stringTemplate?: (node: PromptStringTemplateNode) => Awaitable<void> // String template node visitor
outputProcessor?: (node: PromptOutputProcessorNode) => Awaitable<void> // Output processor node visitor
assistant?: (node: PromptAssistantNode) => Awaitable<void> // Assistant node visitor
system?: (node: PromptSystemNode) => Awaitable<void> // System node visitor
chatParticipant?: (node: PromptChatParticipantNode) => Awaitable<void> // Chat participant node visitor
fileOutput?: (node: FileOutputNode) => Awaitable<void> // File output node visitor
importTemplate?: (node: PromptImportTemplate) => Awaitable<void> // Import template node visitor
@@ -486,6 +503,9 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {
case "assistant":
await visitor.assistant?.(node as PromptAssistantNode)
break
case "system":
await visitor.system?.(node as PromptSystemNode)
break
case "chatParticipant":
await visitor.chatParticipant?.(node as PromptChatParticipantNode)
break
@@ -509,6 +529,7 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {
export interface PromptNodeRender {
userPrompt: string // User prompt content
assistantPrompt: string // Assistant prompt content
systemPrompt: string // System prompt content
images: PromptImage[] // Images included in the prompt
errors: unknown[] // Errors encountered during rendering
schemas: Record<string, JSONSchema> // Schemas included in the prompt
@@ -585,6 +606,15 @@ async function resolvePromptNode(
n.error = e
}
},
system: async (n) => {
try {
const value = await n.value
n.resolved = n.preview = value
n.tokens = estimateTokens(value, encoder)
} catch (e) {
n.error = e
}
},
assistant: async (n) => {
try {
const value = await n.value
@@ -717,6 +747,8 @@ async function resolveImportPrompty(
const txt = jinjaRenderChatMessage(message, args)
if (message.role === "assistant")
n.children.push(createAssistantNode(txt))
else if (message.role === "system")
n.children.push(createSystemNode(txt))
else n.children.push(createTextNode(txt))
n.preview += txt + "\n"
}
@@ -909,6 +941,7 @@ export async function renderPromptNode(

let userPrompt = ""
let assistantPrompt = ""
let systemPrompt = ""
const images: PromptImage[] = []
const errors: unknown[] = []
const schemas: Record<string, JSONSchema> = {}
@@ -934,6 +967,11 @@
const value = await n.resolved
if (value != undefined) assistantPrompt += value + "\n"
},
system: async (n) => {
if (n.error) errors.push(n.error)
const value = await n.resolved
if (value != undefined) systemPrompt += value + SYSTEM_FENCE
},
stringTemplate: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
@@ -953,22 +991,6 @@
}
}
},
importTemplate: async (n) => {
if (n.error) errors.push(n.error)
const value = n.resolved
if (value) {
for (const [filename, content] of Object.entries(value)) {
userPrompt += content
userPrompt += "\n"
if (trace)
trace.detailsFenced(
`📦 import template ${filename}`,
content,
"markdown"
)
}
}
},
schema: (n) => {
const { name: schemaName, value: schema, options } = n
if (schemas[schemaName])
@@ -1057,13 +1079,19 @@
${fods.map((fo) => ` ${fo.pattern}: ${fo.description}`)}
toChatCompletionUserMessage(userPrompt, images),
]
if (assistantPrompt)
messages.push(<ChatCompletionAssistantMessageParam>{
messages.push({
role: "assistant",
content: assistantPrompt,
})
} as ChatCompletionAssistantMessageParam)
if (systemPrompt)
messages.unshift({
role: "system",
content: systemPrompt,
} as ChatCompletionSystemMessageParam)
const res = Object.freeze<PromptNodeRender>({
userPrompt,
assistantPrompt,
systemPrompt,
images,
schemas,
functions: tools,
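Net effect in renderPromptNode: system nodes are concatenated into a single systemPrompt (separated by SYSTEM_FENCE) and, when non-empty, emitted as one leading system message. Roughly (sketch using only the fields added in this diff):

    // systemPrompt, userPrompt and assistantPrompt are returned separately;
    // the chat request becomes: [ system message (optional), user message, assistant message (optional) ]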
2 changes: 2 additions & 0 deletions packages/core/src/prompty.ts
@@ -174,6 +174,8 @@ export function promptyToGenAIScript(doc: PromptyDocument) {
const { role, content } = msg
if (role === "assistant") {
return `assistant(parsers.jinja(${JSON.stringify(content as string)}, env.vars))`
} else if (role === "system") {
return `writeText(${JSON.stringify(content as string)}, { role: "system" })`
} else {
if (typeof content === "string") return renderJinja(content)
else if (Array.isArray(content))
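For illustration, a prompty system section such as "You are a geography tutor." (content string assumed) now compiles to a GenAIScript call of the form emitted above:

    writeText("You are a geography tutor.", { role: "system" })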
20 changes: 17 additions & 3 deletions packages/core/src/runpromptcontext.ts
@@ -92,16 +92,24 @@ export function createChatTurnGenerationContext(
writeText: (body, options) => {
if (body !== undefined && body !== null) {
const { priority, maxTokens } = options || {}
const role = options?.assistant
? "assistant"
: options?.role || "user"
appendChild(
node,
options?.assistant
role === "assistant"
? createAssistantNode(body, { priority, maxTokens })
: createTextNode(body, { priority, maxTokens })
: role === "system"
          ? createSystemNode(body, { priority, maxTokens })
: createTextNode(body, { priority, maxTokens })
)
}
},
assistant: (body, options) =>
ctx.writeText(body, { ...options, assistant: true }),
ctx.writeText(body, {
...options,
role: "assistant",
} as WriteTextOptions),
$: (strings, ...args) => {
const current = createStringTemplateNode(strings, args)
appendChild(node, current)
@@ -745,3 +753,9 @@ export function createChatGenerationContext(

return ctx
}
5 changes: 5 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
@@ -2028,8 +2028,13 @@ type ChatFunctionHandler = (args: ChatFunctionArgs) => Awaitable<ToolCallOutput>
interface WriteTextOptions extends ContextExpansionOptions {
/**
* Append text to the assistant response. This feature is not supported by all models.
* @deprecated
*/
assistant?: boolean
/**
     * Specifies the message role. Defaults to "user".
*/
role?: "user" | "assistant" | "system"
}

type PromptGenerator = (ctx: ChatGenerationContext) => Awaitable<unknown>
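On the script side, the new role option supersedes the deprecated assistant flag. A sketch of intended use inside a prompt generator (strings are illustrative):

    writeText("Paris.", { assistant: true })          // deprecated, still maps to the assistant role
    writeText("Paris.", { role: "assistant" })        // preferred
    writeText("Answer tersely.", { role: "system" })  // new: contributes to the system prompt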
8 changes: 8 additions & 0 deletions packages/sample/genaisrc/import-prompty.genai.mjs
@@ -5,6 +5,14 @@ script({
},
})

const res = await runPrompt((ctx) => {
ctx.importTemplate("src/basic.prompty", {
question: "what is the capital of france?",
hint: "starts with p",
})
})
console.log(`inline: ${res.text}`)

importTemplate("src/basic.prompty", {
question: "what is the capital of france?",
hint: "starts with p",