From d0b1641259d52321e3957765f04aed210a6c3077 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Thu, 17 Oct 2024 10:07:09 -0700 Subject: [PATCH] feat(playground): plumb through and apply template variables (#5052) --- app/schema.graphql | 11 ++ .../components/templateEditor/constants.ts | 4 +- app/src/pages/playground/PlaygroundInput.tsx | 4 +- app/src/pages/playground/PlaygroundOutput.tsx | 16 ++- .../PlaygroundOutputSubscription.graphql.ts | 34 +++-- src/phoenix/server/api/subscriptions.py | 86 ++++++++++-- src/phoenix/utilities/template_formatters.py | 70 ++++++++++ tests/utilities/test_template_formatters.py | 132 ++++++++++++++++++ 8 files changed, 329 insertions(+), 28 deletions(-) create mode 100644 src/phoenix/utilities/template_formatters.py create mode 100644 tests/utilities/test_template_formatters.py diff --git a/app/schema.graphql b/app/schema.graphql index 4149731865..200de637cf 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -72,6 +72,7 @@ input ChatCompletionInput { model: GenerativeModelInput! invocationParameters: InvocationParameters! tools: [JSON!] + template: TemplateOptions apiKey: String = null } @@ -1459,6 +1460,16 @@ type SystemApiKey implements ApiKey & Node { id: GlobalID! } +enum TemplateLanguage { + MUSTACHE + F_STRING +} + +input TemplateOptions { + variables: JSON! + language: TemplateLanguage! +} + type TextChunk { content: String! } diff --git a/app/src/components/templateEditor/constants.ts b/app/src/components/templateEditor/constants.ts index 803266784d..0bd3212489 100644 --- a/app/src/components/templateEditor/constants.ts +++ b/app/src/components/templateEditor/constants.ts @@ -10,6 +10,6 @@ * ``` */ export const TemplateLanguages = { - FString: "f-string", // {variable} - Mustache: "mustache", // {{variable}} + FString: "F_STRING", // {variable} + Mustache: "MUSTACHE", // {{variable}} } as const; diff --git a/app/src/pages/playground/PlaygroundInput.tsx b/app/src/pages/playground/PlaygroundInput.tsx index 14ece296ab..d7bc4b6433 100644 --- a/app/src/pages/playground/PlaygroundInput.tsx +++ b/app/src/pages/playground/PlaygroundInput.tsx @@ -24,11 +24,11 @@ export function PlaygroundInput() { if (variableKeys.length === 0) { let templateSyntax = ""; switch (templateLanguage) { - case "f-string": { + case "F_STRING": { templateSyntax = "{input name}"; break; } - case "mustache": { + case "MUSTACHE": { templateSyntax = "{{input name}}"; break; } diff --git a/app/src/pages/playground/PlaygroundOutput.tsx b/app/src/pages/playground/PlaygroundOutput.tsx index f0a7f65b4d..f084fd8bb0 100644 --- a/app/src/pages/playground/PlaygroundOutput.tsx +++ b/app/src/pages/playground/PlaygroundOutput.tsx @@ -9,7 +9,11 @@ import { useCredentialsContext } from "@phoenix/contexts/CredentialsContext"; import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext"; import { useChatMessageStyles } from "@phoenix/hooks/useChatMessageStyles"; import type { ToolCall } from "@phoenix/store"; -import { ChatMessage, generateMessageId } from "@phoenix/store"; +import { + ChatMessage, + generateMessageId, + selectDerivedInputVariables, +} from "@phoenix/store"; import { assertUnreachable } from "@phoenix/typeUtils"; import { @@ -135,6 +139,7 @@ function useChatCompletionSubscription({ $model: GenerativeModelInput! $invocationParameters: InvocationParameters! $tools: [JSON!] 
+ $templateOptions: TemplateOptions $apiKey: String ) { chatCompletion( @@ -143,6 +148,7 @@ function useChatCompletionSubscription({ model: $model invocationParameters: $invocationParameters tools: $tools + template: $templateOptions apiKey: $apiKey } ) { @@ -212,6 +218,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { const instance = instances.find( (instance) => instance.id === props.playgroundInstanceId ); + const templateLanguage = usePlaygroundContext( + (state) => state.templateLanguage + ); + const templateVariables = usePlaygroundContext(selectDerivedInputVariables); const markPlaygroundInstanceComplete = usePlaygroundContext( (state) => state.markPlaygroundInstanceComplete ); @@ -239,6 +249,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { invocationParameters: { toolChoice: instance.toolChoice, }, + templateOptions: { + variables: templateVariables, + language: templateLanguage, + }, tools: instance.tools.map((tool) => tool.definition), apiKey: credentials[instance.model.provider], }, diff --git a/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts b/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts index d560c6a348..294eb02bcd 100644 --- a/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts +++ b/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<> + * @generated SignedSource<<65d35a875aab582b4ecd25a6b530ee33>> * @lightSyntaxTransform * @nogrep */ @@ -11,6 +11,7 @@ import { ConcreteRequest, GraphQLSubscription } from 'relay-runtime'; export type ChatCompletionMessageRole = "AI" | "SYSTEM" | "TOOL" | "USER"; export type GenerativeProviderKey = "ANTHROPIC" | "AZURE_OPENAI" | "OPENAI"; +export type TemplateLanguage = "F_STRING" | "MUSTACHE"; export type ChatCompletionMessageInput = { content: any; role: ChatCompletionMessageRole; @@ -28,11 +29,16 @@ export type InvocationParameters = { toolChoice?: any | null; topP?: number | null; }; +export type TemplateOptions = { + language: TemplateLanguage; + variables: any; +}; export type PlaygroundOutputSubscription$variables = { apiKey?: string | null; invocationParameters: InvocationParameters; messages: ReadonlyArray<ChatCompletionMessageInput>; model: GenerativeModelInput; + templateOptions?: TemplateOptions | null; tools?: ReadonlyArray<any> | null; }; export type PlaygroundOutputSubscription$data = { @@ -79,11 +85,16 @@ v3 = { "name": "model" }, v4 = { + "defaultValue": null, + "kind": "LocalArgument", + "name": "templateOptions" +}, v5 = { "defaultValue": null, "kind": "LocalArgument", "name": "tools" }, -v5 = [ +v6 = [ { "alias": null, "args": [ @@ -109,6 +120,11 @@ v5 = [ "name": "model", "variableName": "model" }, + { + "kind": "Variable", + "name": "template", + "variableName": "templateOptions" + }, { "kind": "Variable", "name": "tools", @@ -195,12 +211,13 @@ return { (v1/*: any*/), (v2/*: any*/), (v3/*: any*/), - (v4/*: any*/) + (v4/*: any*/), + (v5/*: any*/) ], "kind": "Fragment", "metadata": null, "name": "PlaygroundOutputSubscription", - "selections": (v5/*: any*/), + "selections": (v6/*: any*/), "type": "Subscription", "abstractKey": null }, @@ -210,24 +227,25 @@ return { (v2/*: any*/), (v3/*: any*/), (v1/*: any*/), + (v5/*: any*/), (v4/*: any*/), (v0/*: any*/) ], "kind": "Operation", "name": "PlaygroundOutputSubscription", - "selections": (v5/*: any*/) + "selections": (v6/*: any*/) }, "params": { - "cacheID": "924d84f911c5156af0abb2a371d098f2", + "cacheID": 
"187fbcd6de7ddfef0a20d669441d6d6f", "id": null, "metadata": {}, "name": "PlaygroundOutputSubscription", "operationKind": "subscription", - "text": "subscription PlaygroundOutputSubscription(\n $messages: [ChatCompletionMessageInput!]!\n $model: GenerativeModelInput!\n $invocationParameters: InvocationParameters!\n $tools: [JSON!]\n $apiKey: String\n) {\n chatCompletion(input: {messages: $messages, model: $model, invocationParameters: $invocationParameters, tools: $tools, apiKey: $apiKey}) {\n __typename\n ... on TextChunk {\n content\n }\n ... on ToolCallChunk {\n id\n function {\n name\n arguments\n }\n }\n }\n}\n" + "text": "subscription PlaygroundOutputSubscription(\n $messages: [ChatCompletionMessageInput!]!\n $model: GenerativeModelInput!\n $invocationParameters: InvocationParameters!\n $tools: [JSON!]\n $templateOptions: TemplateOptions\n $apiKey: String\n) {\n chatCompletion(input: {messages: $messages, model: $model, invocationParameters: $invocationParameters, tools: $tools, template: $templateOptions, apiKey: $apiKey}) {\n __typename\n ... on TextChunk {\n content\n }\n ... on ToolCallChunk {\n id\n function {\n name\n arguments\n }\n }\n }\n}\n" } }; })(); -(node as any).hash = "30b9973d6ea69054a907549af97c0e5f"; +(node as any).hash = "ab948845156ab09904236c95cf379b04"; export default node; diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 638eb1943a..477f7bb4ef 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -2,6 +2,7 @@ from collections import defaultdict from dataclasses import fields from datetime import datetime +from enum import Enum from itertools import chain from typing import ( TYPE_CHECKING, @@ -10,6 +11,7 @@ AsyncIterator, DefaultDict, Dict, + Iterable, Iterator, List, Optional, @@ -46,6 +48,11 @@ from phoenix.server.dml_event import SpanInsertEvent from phoenix.trace.attributes import unflatten from phoenix.utilities.json import jsonify +from phoenix.utilities.template_formatters import ( + FStringTemplateFormatter, + MustacheTemplateFormatter, + TemplateFormatter, +) if TYPE_CHECKING: from openai.types.chat import ( @@ -57,6 +64,18 @@ ToolCallIndex: TypeAlias = int +@strawberry.enum +class TemplateLanguage(Enum): + MUSTACHE = "MUSTACHE" + F_STRING = "F_STRING" + + +@strawberry.input +class TemplateOptions: + variables: JSONScalarType + language: TemplateLanguage + + @strawberry.type class TextChunk: content: str @@ -91,11 +110,12 @@ class ChatCompletionInput: model: GenerativeModelInput invocation_parameters: InvocationParameters tools: Optional[List[JSONScalarType]] = UNSET + template: Optional[TemplateOptions] = UNSET api_key: Optional[str] = strawberry.field(default=None) def to_openai_chat_completion_param( - message: ChatCompletionMessageInput, + role: ChatCompletionMessageRole, content: JSONScalarType ) -> "ChatCompletionMessageParam": from openai.types.chat import ( ChatCompletionAssistantMessageParam, @@ -103,30 +123,30 @@ def to_openai_chat_completion_param( ChatCompletionUserMessageParam, ) - if message.role is ChatCompletionMessageRole.USER: + if role is ChatCompletionMessageRole.USER: return ChatCompletionUserMessageParam( { - "content": message.content, + "content": content, "role": "user", } ) - if message.role is ChatCompletionMessageRole.SYSTEM: + if role is ChatCompletionMessageRole.SYSTEM: return ChatCompletionSystemMessageParam( { - "content": message.content, + "content": content, "role": "system", } ) - if message.role is 
ChatCompletionMessageRole.AI: + if role is ChatCompletionMessageRole.AI: return ChatCompletionAssistantMessageParam( { - "content": message.content, + "content": content, "role": "assistant", } ) - if message.role is ChatCompletionMessageRole.TOOL: + if role is ChatCompletionMessageRole.TOOL: raise NotImplementedError - assert_never(message.role) + assert_never(role) @strawberry.type @@ -140,6 +160,13 @@ async def chat_completion( client = AsyncOpenAI(api_key=input.api_key) invocation_parameters = jsonify(input.invocation_parameters) + messages: List[Tuple[ChatCompletionMessageRole, str]] = [ + (message.role, message.content) for message in input.messages + ] + if template_options := input.template: + messages = list(_formatted_messages(messages, template_options)) + openai_messages = [to_openai_chat_completion_param(*message) for message in messages] + in_memory_span_exporter = InMemorySpanExporter() tracer_provider = TracerProvider() tracer_provider.add_span_processor( @@ -154,7 +181,7 @@ async def chat_completion( _llm_span_kind(), _llm_model_name(input.model.name), _llm_tools(input.tools or []), - _llm_input_messages(input.messages), + _llm_input_messages(messages), _llm_invocation_parameters(invocation_parameters), _input_value_and_mime_type(input), ) @@ -165,7 +192,7 @@ async def chat_completion( tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list) role: Optional[str] = None async for chunk in await client.chat.completions.create( - messages=(to_openai_chat_completion_param(message) for message in input.messages), + messages=openai_messages, model=input.model.name, stream=True, tools=input.tools or NOT_GIVEN, @@ -291,10 +318,12 @@ def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]: yield OUTPUT_VALUE, safe_json_dumps(jsonify(output)) -def _llm_input_messages(messages: List[ChatCompletionMessageInput]) -> Iterator[Tuple[str, Any]]: - for i, message in enumerate(messages): - yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", message.role.value.lower() - yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", message.content +def _llm_input_messages( + messages: Iterable[Tuple[ChatCompletionMessageRole, str]], +) -> Iterator[Tuple[str, Any]]: + for i, (role, content) in enumerate(messages): + yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower() + yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content def _llm_output_messages( @@ -332,6 +361,33 @@ def _datetime(*, epoch_nanoseconds: float) -> datetime: return datetime.fromtimestamp(epoch_seconds) +def _formatted_messages( + messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions +) -> Iterator[Tuple[ChatCompletionMessageRole, str]]: + """ + Formats the messages using the given template options. + """ + template_formatter = _template_formatter(template_language=template_options.language) + roles, templates = zip(*messages) + formatted_templates = map( + lambda template: template_formatter.format(template, **template_options.variables), + templates, + ) + formatted_messages = zip(roles, formatted_templates) + return formatted_messages + + +def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter: + """ + Instantiates the appropriate template formatter for the template language. 
+ """ + if template_language is TemplateLanguage.MUSTACHE: + return MustacheTemplateFormatter() + if template_language is TemplateLanguage.F_STRING: + return FStringTemplateFormatter() + assert_never(template_language) + + JSON = OpenInferenceMimeTypeValues.JSON.value LLM = OpenInferenceSpanKindValues.LLM.value diff --git a/src/phoenix/utilities/template_formatters.py b/src/phoenix/utilities/template_formatters.py new file mode 100644 index 0000000000..610be52e44 --- /dev/null +++ b/src/phoenix/utilities/template_formatters.py @@ -0,0 +1,70 @@ +import re +from abc import ABC, abstractmethod +from string import Formatter +from typing import Any, Iterable, Set + + +class TemplateFormatter(ABC): + @abstractmethod + def parse(self, template: str) -> Set[str]: + """ + Parse the template and return a set of variable names. + """ + raise NotImplementedError + + def format(self, template: str, **variables: Any) -> str: + """ + Formats the template with the given variables. + """ + template_variable_names = self.parse(template) + if missing_template_variables := template_variable_names - set(variables.keys()): + raise ValueError(f"Missing template variables: {', '.join(missing_template_variables)}") + return self._format(template, template_variable_names, **variables) + + @abstractmethod + def _format(self, template: str, variable_names: Iterable[str], **variables: Any) -> str: + raise NotImplementedError + + +class FStringTemplateFormatter(TemplateFormatter): + """ + Regular f-string template formatter. + + Examples: + + >>> formatter = FStringTemplateFormatter() + >>> formatter.format("{hello}", hello="world") + 'world' + """ + + def parse(self, template: str) -> Set[str]: + return set(field_name for _, field_name, _, _ in Formatter().parse(template) if field_name) + + def _format(self, template: str, variable_names: Iterable[str], **variables: Any) -> str: + return template.format(**variables) + + +class MustacheTemplateFormatter(TemplateFormatter): + """ + Mustache template formatter. + + Examples: + + >>> formatter = MustacheTemplateFormatter() + >>> formatter.format("{{ hello }}", hello="world") + 'world' + """ + + PATTERN = re.compile(r"(? Set[str]: + return set(match for match in re.findall(self.PATTERN, template)) + + def _format(self, template: str, variable_names: Iterable[str], **variables: Any) -> str: + for variable_name in variable_names: + template = re.sub( + pattern=rf"(? None: + formatter = formatter_cls() + prompt = formatter.format(template, **variables) + assert prompt == expected_prompt + + +@pytest.mark.parametrize( + "formatter_cls, template", + ( + pytest.param( + MustacheTemplateFormatter, + "{{ hello }}", + id="mustache-missing-template-variables", + ), + pytest.param( + FStringTemplateFormatter, + "{hello}", + id="f-string-missing-template-variables", + ), + ), +) +def test_template_formatters_raise_expected_error_on_missing_variables( + formatter_cls: Type[TemplateFormatter], template: str +) -> None: + formatter = formatter_cls() + with pytest.raises(ValueError, match="Missing template variables"): + formatter.format(template)