From 45973b7c0b9798d114c3e629861933ee0b287c07 Mon Sep 17 00:00:00 2001 From: Parker Stafford <52351508+Parker-Stafford@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:22:19 -0700 Subject: [PATCH] feat(playground): parse model name and infer provider form span (#5021) * feat(playground): parse model name and infer provider form span * update azure model selector to be a text field to account for user defined deployment names * add label * move defaults into generative constants, fallback to openai not azure * update defult in test --- app/src/constants/generativeConstants.ts | 7 + .../pages/playground/ModelConfigButton.tsx | 44 +++++-- .../__tests__/playgroundUtils.test.ts | 121 +++++++++++++++++- app/src/pages/playground/constants.tsx | 22 +++- app/src/pages/playground/playgroundUtils.ts | 82 ++++++++++-- app/src/pages/playground/schemas.ts | 11 +- app/src/store/playground/playgroundStore.tsx | 8 +- app/src/store/playground/types.ts | 2 +- 8 files changed, 263 insertions(+), 34 deletions(-) diff --git a/app/src/constants/generativeConstants.ts b/app/src/constants/generativeConstants.ts index 6f08c8b7fb..837d11ccd6 100644 --- a/app/src/constants/generativeConstants.ts +++ b/app/src/constants/generativeConstants.ts @@ -6,3 +6,10 @@ export const ModelProviders: Record = { AZURE_OPENAI: "Azure OpenAI", ANTHROPIC: "Anthropic", }; + +/** + * The default model provider + */ +export const DEFAULT_MODEL_PROVIDER: ModelProvider = "OPENAI"; + +export const DEFAULT_CHAT_ROLE: ChatMessageRole = "user"; diff --git a/app/src/pages/playground/ModelConfigButton.tsx b/app/src/pages/playground/ModelConfigButton.tsx index 033a929a41..99bbc113a8 100644 --- a/app/src/pages/playground/ModelConfigButton.tsx +++ b/app/src/pages/playground/ModelConfigButton.tsx @@ -3,6 +3,7 @@ import React, { ReactNode, startTransition, Suspense, + useCallback, useState, } from "react"; import { graphql, useLazyLoadQuery } from "react-relay"; @@ -14,6 +15,7 @@ import { Flex, Form, Text, + TextField, View, } from "@arizeai/components"; @@ -95,6 +97,20 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) { `, { providerKey: instance.model.provider } ); + + const onModelNameChange = useCallback( + (modelName: string) => { + updateModel({ + instanceId: playgroundInstanceId, + model: { + provider: instance.model.provider, + modelName, + }, + }); + }, + [instance.model.provider, playgroundInstanceId, updateModel] + ); + return (
@@ -111,20 +127,20 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) { }); }} /> - { - updateModel({ - instanceId: playgroundInstanceId, - model: { - provider: instance.model.provider, - modelName, - }, - }); - }} - /> + {instance.model.provider === "AZURE_OPENAI" ? ( + + ) : ( + + )}
); diff --git a/app/src/pages/playground/__tests__/playgroundUtils.test.ts b/app/src/pages/playground/__tests__/playgroundUtils.test.ts index ce5e61b3fc..1d71fa48d8 100644 --- a/app/src/pages/playground/__tests__/playgroundUtils.test.ts +++ b/app/src/pages/playground/__tests__/playgroundUtils.test.ts @@ -1,3 +1,4 @@ +import { DEFAULT_MODEL_PROVIDER } from "@phoenix/constants/generativeConstants"; import { _resetInstanceId, _resetMessageId, @@ -5,11 +6,15 @@ import { } from "@phoenix/store"; import { - getChatRole, INPUT_MESSAGES_PARSING_ERROR, + MODEL_NAME_PARSING_ERROR, OUTPUT_MESSAGES_PARSING_ERROR, OUTPUT_VALUE_PARSING_ERROR, SPAN_ATTRIBUTES_PARSING_ERROR, +} from "../constants"; +import { + getChatRole, + getModelProviderFromModelName, transformSpanAttributesToPlaygroundInstance, } from "../playgroundUtils"; @@ -24,7 +29,7 @@ const expectedPlaygroundInstanceWithIO: PlaygroundInstance = { isRunning: false, model: { provider: "OPENAI", - modelName: "gpt-4o", + modelName: "gpt-3.5-turbo", }, input: { variableKeys: [], variablesValueCache: {} }, tools: [], @@ -70,6 +75,10 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { expect(transformSpanAttributesToPlaygroundInstance(span)).toStrictEqual({ playgroundInstance: { ...expectedPlaygroundInstanceWithIO, + model: { + provider: "OPENAI", + modelName: "gpt-4o", + }, template: defaultTemplate, output: undefined, }, @@ -85,6 +94,10 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { expect(transformSpanAttributesToPlaygroundInstance(span)).toStrictEqual({ playgroundInstance: { ...expectedPlaygroundInstanceWithIO, + model: { + provider: "OPENAI", + modelName: "gpt-4o", + }, template: defaultTemplate, output: undefined, @@ -93,6 +106,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { INPUT_MESSAGES_PARSING_ERROR, OUTPUT_MESSAGES_PARSING_ERROR, OUTPUT_VALUE_PARSING_ERROR, + MODEL_NAME_PARSING_ERROR, ], }); }); @@ -138,6 +152,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ playgroundInstance: { ...expectedPlaygroundInstanceWithIO, + output: "This is an AI Answer", }, parsingErrors: [OUTPUT_MESSAGES_PARSING_ERROR], @@ -160,6 +175,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { ...basePlaygroundSpan, attributes: JSON.stringify({ llm: { + model_name: "gpt-4o", input_messages: [ { message: { @@ -182,6 +198,10 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ playgroundInstance: { ...expectedPlaygroundInstanceWithIO, + model: { + provider: "OPENAI", + modelName: "gpt-4o", + }, template: { __type: "chat", messages: [ @@ -197,6 +217,84 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { parsingErrors: [], }); }); + + it("should correctly parse the model name and infer the provider", () => { + const openAiAttributes = JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + model_name: "gpt-3.5-turbo", + }, + }); + const anthropicAttributes = JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + model_name: "claude-3-5-sonnet-20240620", + }, + }); + const unknownAttributes = JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + model_name: "test-my-deployment", + }, + }); + + expect( + transformSpanAttributesToPlaygroundInstance({ + ...basePlaygroundSpan, + attributes: openAiAttributes, + }) + ).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + model: { + provider: "OPENAI", + modelName: "gpt-3.5-turbo", + }, + }, + parsingErrors: [], + }); + + _resetMessageId(); + _resetInstanceId(); + + expect( + transformSpanAttributesToPlaygroundInstance({ + ...basePlaygroundSpan, + attributes: anthropicAttributes, + }) + ).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + model: { + provider: "ANTHROPIC", + modelName: "claude-3-5-sonnet-20240620", + }, + }, + parsingErrors: [], + }); + + _resetMessageId(); + _resetInstanceId(); + + expect( + transformSpanAttributesToPlaygroundInstance({ + ...basePlaygroundSpan, + attributes: unknownAttributes, + }) + ).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + model: { + provider: DEFAULT_MODEL_PROVIDER, + modelName: "test-my-deployment", + }, + }, + parsingErrors: [], + }); + }); }); describe("getChatRole", () => { @@ -215,3 +313,22 @@ describe("getChatRole", () => { expect(getChatRole("invalid")).toEqual("user"); }); }); + +describe("getModelProviderFromModelName", () => { + it("should return OPENAI if the model name includes 'gpt' or 'o1'", () => { + expect(getModelProviderFromModelName("gpt-3.5-turbo")).toEqual("OPENAI"); + expect(getModelProviderFromModelName("o1")).toEqual("OPENAI"); + }); + + it("should return ANTHROPIC if the model name includes 'claude'", () => { + expect(getModelProviderFromModelName("claude-3-5-sonnet-20240620")).toEqual( + "ANTHROPIC" + ); + }); + + it(`should return ${DEFAULT_MODEL_PROVIDER} if the model name does not match any known models`, () => { + expect(getModelProviderFromModelName("test-my-model")).toEqual( + DEFAULT_MODEL_PROVIDER + ); + }); +}); diff --git a/app/src/pages/playground/constants.tsx b/app/src/pages/playground/constants.tsx index 624e0e1ca5..90be404948 100644 --- a/app/src/pages/playground/constants.tsx +++ b/app/src/pages/playground/constants.tsx @@ -1,7 +1,5 @@ export const NUM_MAX_PLAYGROUND_INSTANCES = 4; -export const DEFAULT_CHAT_ROLE = "user"; - /** * Map of {@link ChatMessageRole} to potential role values. * Used to map roles to a canonical role. @@ -12,3 +10,23 @@ export const ChatRoleMap: Record = { system: ["system"], tool: ["tool"], }; + +/** + * Parsing errors for parsing a span to a playground instance + */ +export const INPUT_MESSAGES_PARSING_ERROR = + "Unable to parse span input messages, expected messages which include a role and content."; +export const OUTPUT_MESSAGES_PARSING_ERROR = + "Unable to parse span output messages, expected messages which include a role and content."; +export const OUTPUT_VALUE_PARSING_ERROR = + "Unable to parse span output expected output.value to be present."; +export const SPAN_ATTRIBUTES_PARSING_ERROR = + "Unable to parse span attributes, attributes must be valid JSON."; +export const MODEL_NAME_PARSING_ERROR = + "Unable to parse model name, expected llm.model_name to be present."; + +export const modelProviderToModelPrefixMap: Record = { + AZURE_OPENAI: [], + ANTHROPIC: ["claude"], + OPENAI: ["gpt", "o1"], +}; diff --git a/app/src/pages/playground/playgroundUtils.ts b/app/src/pages/playground/playgroundUtils.ts index 66d8290f53..04f10cf16f 100644 --- a/app/src/pages/playground/playgroundUtils.ts +++ b/app/src/pages/playground/playgroundUtils.ts @@ -1,4 +1,8 @@ -import { PlaygroundInstance } from "@phoenix/store"; +import { + DEFAULT_CHAT_ROLE, + DEFAULT_MODEL_PROVIDER, +} from "@phoenix/constants/generativeConstants"; +import { ModelConfig, PlaygroundInstance } from "@phoenix/store"; import { ChatMessage, createPlaygroundInstance, @@ -6,13 +10,22 @@ import { } from "@phoenix/store"; import { safelyParseJSON } from "@phoenix/utils/jsonUtils"; -import { ChatRoleMap, DEFAULT_CHAT_ROLE } from "./constants"; +import { + ChatRoleMap, + INPUT_MESSAGES_PARSING_ERROR, + MODEL_NAME_PARSING_ERROR, + modelProviderToModelPrefixMap, + OUTPUT_MESSAGES_PARSING_ERROR, + OUTPUT_VALUE_PARSING_ERROR, + SPAN_ATTRIBUTES_PARSING_ERROR, +} from "./constants"; import { chatMessageRolesSchema, chatMessagesSchema, llmInputMessageSchema, llmOutputMessageSchema, MessageSchema, + modelNameSchema, outputSchema, } from "./schemas"; import { PlaygroundSpan } from "./spanPlaygroundPageLoader"; @@ -62,15 +75,6 @@ function processAttributeMessagesToChatMessage( }); } -export const INPUT_MESSAGES_PARSING_ERROR = - "Unable to parse span input messages, expected messages which include a role and content."; -export const OUTPUT_MESSAGES_PARSING_ERROR = - "Unable to parse span output messages, expected messages which include a role and content."; -export const OUTPUT_VALUE_PARSING_ERROR = - "Unable to parse span output expected output.value to be present."; -export const SPAN_ATTRIBUTES_PARSING_ERROR = - "Unable to parse span attributes, attributes must be valid JSON."; - /** * Attempts to parse the input messages from the span attributes. * @param parsedAttributes the JSON parsed span attributes @@ -93,6 +97,11 @@ function getTemplateMessagesFromAttributes(parsedAttributes: unknown) { }; } +/** + * Attempts to get llm.output_messages then output.value from the span attributes. + * @param parsedAttributes the JSON parsed span attributes + * @returns an object containing the parsed output and any parsing errors + */ function getOutputFromAttributes(parsedAttributes: unknown) { const outputParsingErrors: string[] = []; const outputMessages = llmOutputMessageSchema.safeParse(parsedAttributes); @@ -123,6 +132,48 @@ function getOutputFromAttributes(parsedAttributes: unknown) { }; } +/** + * Attempts to infer the provider of the model from the model name. + * @param modelName the model name to get the provider from + * @returns the provider of the model defaulting to {@link DEFAULT_MODEL_PROVIDER} if the provider cannot be inferred + * + * NB: Only exported for testing + */ +export function getModelProviderFromModelName( + modelName: string +): ModelProvider { + for (const provider of Object.keys(modelProviderToModelPrefixMap)) { + const prefixes = modelProviderToModelPrefixMap[provider as ModelProvider]; + if (prefixes.some((prefix) => modelName.includes(prefix))) { + return provider as ModelProvider; + } + } + return DEFAULT_MODEL_PROVIDER; +} + +/** + * Attempts to get the llm.model_name and inferred provider from the span attributes. + * @param parsedAttributes the JSON parsed span attributes + * @returns the model config if it exists or parsing errors if it does not + */ +function getModelConfigFromAttributes( + parsedAttributes: unknown +): + | { modelConfig: ModelConfig; parsingErrors: never[] } + | { modelConfig: null; parsingErrors: string[] } { + const { success, data } = modelNameSchema.safeParse(parsedAttributes); + if (success) { + return { + modelConfig: { + modelName: data.llm.model_name, + provider: getModelProviderFromModelName(data.llm.model_name), + }, + parsingErrors: [], + }; + } + return { modelConfig: null, parsingErrors: [MODEL_NAME_PARSING_ERROR] }; +} + /** * Takes a {@link PlaygroundSpan|Span} and attempts to transform it's attributes into various fields on a {@link PlaygroundInstance}. * @param span the {@link PlaygroundSpan|Span} to transform into a playground instance @@ -155,12 +206,15 @@ export function transformSpanAttributesToPlaygroundInstance( getTemplateMessagesFromAttributes(parsedAttributes); const { output, outputParsingErrors } = getOutputFromAttributes(parsedAttributes); + const { modelConfig, parsingErrors: modelConfigParsingErrors } = + getModelConfigFromAttributes(parsedAttributes); // TODO(parker): add support for tools, variables, and input / output variants // https://github.com/Arize-ai/phoenix/issues/4886 return { playgroundInstance: { ...basePlaygroundInstance, + model: modelConfig ?? basePlaygroundInstance.model, template: messages != null ? { @@ -170,7 +224,11 @@ export function transformSpanAttributesToPlaygroundInstance( : basePlaygroundInstance.template, output, }, - parsingErrors: [...messageParsingErrors, ...outputParsingErrors], + parsingErrors: [ + ...messageParsingErrors, + ...outputParsingErrors, + ...modelConfigParsingErrors, + ], }; } diff --git a/app/src/pages/playground/schemas.ts b/app/src/pages/playground/schemas.ts index 11bdde6c35..cb650ab9cf 100644 --- a/app/src/pages/playground/schemas.ts +++ b/app/src/pages/playground/schemas.ts @@ -65,7 +65,6 @@ export const llmOutputMessageSchema = z.object({ /** * The zod schema for output attributes * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} - */ export const outputSchema = z.object({ [SemanticAttributePrefixes.output]: z.object({ @@ -92,3 +91,13 @@ const chatMessageSchema = schemaForType()( * The zod schema for ChatMessages */ export const chatMessagesSchema = z.array(chatMessageSchema); + +/** + * The zod schema for llm model name + * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} + */ +export const modelNameSchema = z.object({ + [SemanticAttributePrefixes.llm]: z.object({ + [LLMAttributePostfixes.model_name]: z.string(), + }), +}); diff --git a/app/src/store/playground/playgroundStore.tsx b/app/src/store/playground/playgroundStore.tsx index ba7c1a5a1b..fb3e347b2d 100644 --- a/app/src/store/playground/playgroundStore.tsx +++ b/app/src/store/playground/playgroundStore.tsx @@ -4,6 +4,10 @@ import { devtools } from "zustand/middleware"; import { TemplateLanguages } from "@phoenix/components/templateEditor/constants"; import { getTemplateLanguageUtils } from "@phoenix/components/templateEditor/templateEditorUtils"; import { TemplateLanguage } from "@phoenix/components/templateEditor/types"; +import { + DEFAULT_CHAT_ROLE, + DEFAULT_MODEL_PROVIDER, +} from "@phoenix/constants/generativeConstants"; import { assertUnreachable } from "@phoenix/typeUtils"; import { @@ -91,7 +95,7 @@ export function createPlaygroundInstance(): PlaygroundInstance { return { id: generateInstanceId(), template: generateChatCompletionTemplate(), - model: { provider: "OPENAI", modelName: "gpt-4o" }, + model: { provider: DEFAULT_MODEL_PROVIDER, modelName: "gpt-4o" }, tools: [], toolChoice: "auto", // TODO(apowell) - use datasetId if in dataset mode @@ -226,7 +230,7 @@ export const createPlaygroundStore = ( ...instance, messages: [ ...instance.template.messages, - { role: "user", content: "{question}" }, + { role: DEFAULT_CHAT_ROLE, content: "{question}" }, ], }; } diff --git a/app/src/store/playground/types.ts b/app/src/store/playground/types.ts index 12cc4cfb3d..c7aa70caf1 100644 --- a/app/src/store/playground/types.ts +++ b/app/src/store/playground/types.ts @@ -55,7 +55,7 @@ type ManualInput = { type PlaygroundInput = DatasetInput | ManualInput; -type ModelConfig = { +export type ModelConfig = { provider: ModelProvider; modelName: string | null; };