diff --git a/packages/phoenix-evals/pyproject.toml b/packages/phoenix-evals/pyproject.toml index f1c5513577..a6a7920c51 100644 --- a/packages/phoenix-evals/pyproject.toml +++ b/packages/phoenix-evals/pyproject.toml @@ -53,6 +53,7 @@ test = [ "nest_asyncio", "pandas-stubs<=2.0.2.230605", "types-tqdm", + "lameenc" ] [project.urls] diff --git a/packages/phoenix-evals/src/phoenix/evals/classify.py b/packages/phoenix-evals/src/phoenix/evals/classify.py index 8d47145da4..22582e3366 100644 --- a/packages/phoenix-evals/src/phoenix/evals/classify.py +++ b/packages/phoenix-evals/src/phoenix/evals/classify.py @@ -1,11 +1,15 @@ from __future__ import annotations +import inspect import logging +import warnings from collections import defaultdict from enum import Enum +from functools import wraps from itertools import product from typing import ( Any, + Callable, DefaultDict, Dict, Iterable, @@ -14,6 +18,7 @@ NamedTuple, Optional, Tuple, + TypeVar, Union, ) @@ -63,11 +68,37 @@ class ClassificationStatus(Enum): MISSING_INPUT = "MISSING INPUT" +PROCESSOR_TYPE = TypeVar("PROCESSOR_TYPE") + + +def deprecate_dataframe_arg(func: Callable[..., Any]) -> Callable[..., Any]: + # Remove this once the `dataframe` arg in `llm_classify` is no longer supported + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + signature = inspect.signature(func) + + if "dataframe" in kwargs: + warnings.warn( + "`dataframe` argument is deprecated; use `data` instead", + DeprecationWarning, + stacklevel=2, + ) + kwargs["data"] = kwargs.pop("dataframe") + bound_args = signature.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + +@deprecate_dataframe_arg def llm_classify( - dataframe: pd.DataFrame, + data: Union[pd.DataFrame, List[Any]], model: BaseModel, template: Union[ClassificationTemplate, PromptTemplate, str], rails: List[str], + data_processor: Optional[Callable[[PROCESSOR_TYPE], PROCESSOR_TYPE]] = None, system_instruction: Optional[str] = None, verbose: bool = False, use_function_calling_if_available: bool = True, @@ -88,20 +119,28 @@ def llm_classify( `provide_explanation=True`. Args: - dataframe (pandas.DataFrame): A pandas dataframe in which each row represents - a record to be classified. All template variable names must appear as column - names in the dataframe (extra columns unrelated to the template are permitted). + data (Union[pd.DataFrame, List[Any]): A collection of data which + can contain template variables and other information necessary to generate evaluations. + If a passed a DataFrame, there must be column names that match the template variables. + If passed a list, the elements of the list will be mapped to the template variables + in the order that the template variables are defined. + + model (BaseEvalModel): An LLM model class. template (Union[ClassificationTemplate, PromptTemplate, str]): The prompt template as either an instance of PromptTemplate, ClassificationTemplate or a string. If a string, the variable names should be surrounded by curly braces so that a call to `.format` can be made to substitute variable values. - model (BaseEvalModel): An LLM model class. - rails (List[str]): A list of strings representing the possible output classes of the model's predictions. + data_processor (Optional[Callable[[T], T]]): An optional callable that is used to process + the input data before it is mapped to the template variables. 
This callable is passed + a single element of the input data and can return either a pandas.Series with indices + corresponding to the template variables or an iterable of values that will be mapped + to the template variables in the order that the template variables are defined. + system_instruction (Optional[str], optional): An optional system message. verbose (bool, optional): If True, prints detailed info to stdout such as @@ -171,8 +210,8 @@ def llm_classify( prompt_options = PromptOptions(provide_explanation=provide_explanation) - labels: Iterable[Optional[str]] = [None] * len(dataframe) - explanations: Iterable[Optional[str]] = [None] * len(dataframe) + labels: Iterable[Optional[str]] = [None] * len(data) + explanations: Iterable[Optional[str]] = [None] * len(data) printif(verbose, f"Using prompt:\n\n{eval_template.prompt(prompt_options)}") if generation_info := model.verbose_generation_info(): @@ -211,18 +250,59 @@ def _process_response(response: str) -> Tuple[str, Optional[str]]: unrailed_label, explanation = parse_openai_function_call(response) return snap_to_rail(unrailed_label, rails, verbose=verbose), explanation - async def _run_llm_classification_async(input_data: pd.Series[Any]) -> ParsedLLMResponse: + def _normalize_to_series( + data: PROCESSOR_TYPE, + ) -> pd.Series[Any]: + if isinstance(data, pd.Series): + return data + + variable_count = len(eval_template.variables) + if variable_count == 1: + return pd.Series({eval_template.variables[0]: data}) + elif variable_count > 1: + if isinstance(data, str): + raise ValueError("The data cannot be mapped to the template variables") + elif isinstance(data, Iterable): + return pd.Series( + { + template_var: input_val + for template_var, input_val in zip(eval_template.variables, data) + } + ) + return pd.Series() + + async def _run_llm_classification_async( + input_data: PROCESSOR_TYPE, + ) -> ParsedLLMResponse: with set_verbosity(model, verbose) as verbose_model: - prompt = _map_template(input_data) + if data_processor: + maybe_awaitable_data = data_processor(input_data) + if inspect.isawaitable(maybe_awaitable_data): + processed_data = await maybe_awaitable_data + else: + processed_data = maybe_awaitable_data + else: + processed_data = input_data + + prompt = _map_template(_normalize_to_series(processed_data)) response = await verbose_model._async_generate( prompt, instruction=system_instruction, **model_kwargs ) inference, explanation = _process_response(response) return inference, explanation, response, str(prompt) - def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMResponse: + def _run_llm_classification_sync( + input_data: PROCESSOR_TYPE, + ) -> ParsedLLMResponse: with set_verbosity(model, verbose) as verbose_model: - prompt = _map_template(input_data) + if data_processor: + processed_data = data_processor(input_data) + if inspect.isawaitable(processed_data): + raise ValueError("Cannot run the data processor asynchronously.") + else: + processed_data = input_data + + prompt = _map_template(_normalize_to_series(processed_data)) response = verbose_model._generate( prompt, instruction=system_instruction, **model_kwargs ) @@ -242,7 +322,17 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMRespons fallback_return_value=fallback_return_value, ) - results, execution_details = executor.run([row_tuple[1] for row_tuple in dataframe.iterrows()]) + list_of_inputs: Union[Tuple[Any], List[Any]] + if isinstance(data, pd.DataFrame): + list_of_inputs = [row_tuple[1] for row_tuple in 
data.iterrows()] + dataframe_index = data.index + elif isinstance(data, (list, tuple)): + list_of_inputs = data + dataframe_index = pd.Index(range(len(data))) + else: + raise ValueError("Invalid 'data' input type.") + + results, execution_details = executor.run(list_of_inputs) labels, explanations, responses, prompts = zip(*results) all_exceptions = [details.exceptions for details in execution_details] execution_statuses = [details.status for details in execution_details] @@ -264,7 +354,7 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMRespons **({"execution_status": [status.value for status in classification_statuses]}), **({"execution_seconds": [runtime for runtime in execution_times]}), }, - index=dataframe.index, + index=dataframe_index, ) diff --git a/packages/phoenix-evals/src/phoenix/evals/default_audio_templates.py b/packages/phoenix-evals/src/phoenix/evals/default_audio_templates.py new file mode 100644 index 0000000000..c51b114677 --- /dev/null +++ b/packages/phoenix-evals/src/phoenix/evals/default_audio_templates.py @@ -0,0 +1,133 @@ +from phoenix.evals.templates import ( + ClassificationTemplate, + PromptPartContentType, + PromptPartTemplate, +) + +EMOTION_AUDIO_BASE_TEMPLATE_PT_1 = """ +You are an AI system designed to classify emotions in audio files. + +### TASK: +Analyze the provided audio file and classify the primary emotion based on these characteristics: +- Tone: General tone of the speaker (e.g., cheerful, tense, calm). +- Pitch: Level and variability of the pitch (e.g., high, low, monotone). +- Pace: Speed of speech (e.g., fast, slow, steady). +- Volume: Loudness of the speech (e.g., loud, soft, moderate). +- Intensity: Emotional strength or expression (e.g., subdued, sharp, exaggerated). + +The classified emotion must be one of the following: +['anger', 'happiness', 'excitement', 'sadness', 'neutral', 'frustration', 'fear', 'surprise', +'disgust', 'other'] + +IMPORTANT: Choose the most dominant emotion expressed in the audio. Neutral should only be used when +no other emotion is clearly present, do your best to avoid this label. + +************ + +Here is the audio to classify: + +""" + +EMOTION_AUDIO_BASE_TEMPLATE_PT_2 = """{audio}""" + +EMOTION_AUDIO_BASE_TEMPLATE_PT_3 = """ +RESPONSE FORMAT: + +Provide a single word from the list above representing the detected emotion. + +************ + +EXAMPLE RESPONSE: excitement + +************ + +Analyze the audio and respond in this format. +""" + +EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_1 = """ +You are an AI system designed to classify emotions in audio files. + +### TASK: +First, explain in a step-by-step manner how the provided audio file based on these characteristics +and how they indicate the emotion of the speaker: +- Tone: General tone of the speaker (e.g., cheerful, tense, calm). +- Pitch: Level and variability of the pitch (e.g., high, low, monotone). +- Pace: Speed of speech (e.g., fast, slow, steady). +- Volume: Loudness of the speech (e.g., loud, soft, moderate). +- Intensity: Emotional strength or expression (e.g., subdued, sharp, exaggerated). + +Then, classify the primary emotion. The classified emotion must be one of the following: +['anger', 'happiness', 'excitement', 'sadness', 'neutral', 'frustration', 'fear', 'surprise', +'disgust', 'other'] + +IMPORTANT: Choose the most dominant emotion expressed in the audio. Neutral should only be used when +no other emotion is clearly present, do your best to avoid this label. 
+ +************ + +Here is the audio to classify: +""" + +EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_3 = """ +EXAMPLE RESPONSE FORMAT: + +************ + +EXPLANATION: An explanation of your reasoning based on the tone, pitch, pace, volume, and intensity + of the audio. +LABEL: "excitement" + +************ + +Analyze the audio and respond in the format shown above. +""" + +EMOTION_AUDIO_RAILS = [ + "anger", + "happiness", + "excitement", + "sadness", + "neutral", + "frustration", + "fear", + "surprise", + "disgust", + "other", +] + +EMOTION_PROMPT_TEMPLATE = ClassificationTemplate( + rails=EMOTION_AUDIO_RAILS, + template=[ + PromptPartTemplate( + content_type=PromptPartContentType.TEXT, + template=EMOTION_AUDIO_BASE_TEMPLATE_PT_1, + ), + PromptPartTemplate( + content_type=PromptPartContentType.AUDIO, + template=EMOTION_AUDIO_BASE_TEMPLATE_PT_2, + ), + PromptPartTemplate( + content_type=PromptPartContentType.TEXT, + template=EMOTION_AUDIO_BASE_TEMPLATE_PT_3, + ), + ], + explanation_template=[ + PromptPartTemplate( + content_type=PromptPartContentType.TEXT, + template=EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_1, + ), + PromptPartTemplate( + content_type=PromptPartContentType.AUDIO, + template=EMOTION_AUDIO_BASE_TEMPLATE_PT_2, + ), + PromptPartTemplate( + content_type=PromptPartContentType.TEXT, + template=EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_3, + ), + ], +) +""" +A template for evaluating the emotion of an audio sample. It return +an emotion and provides a detailed explanation template +to assist users in articulating their judgment on code readability. +""" diff --git a/packages/phoenix-evals/src/phoenix/evals/default_templates.py b/packages/phoenix-evals/src/phoenix/evals/default_templates.py index 98f37e2d40..9284ad27f4 100644 --- a/packages/phoenix-evals/src/phoenix/evals/default_templates.py +++ b/packages/phoenix-evals/src/phoenix/evals/default_templates.py @@ -378,6 +378,7 @@ EXPLANATION:""" + REFERENCE_LINK_CORRECTNESS_PROMPT_BASE_TEMPLATE = """ You are given a conversation that contains questions by a CUSTOMER and you are trying to determine if the documentation page shared by the ASSISTANT correctly @@ -783,6 +784,7 @@ to assist users in articulating their judgment on code readability. 
""" + REFERENCE_LINK_CORRECTNESS_PROMPT_TEMPLATE = ClassificationTemplate( rails=list(REFERENCE_LINK_CORRECTNESS_PROMPT_RAILS_MAP.values()), template=REFERENCE_LINK_CORRECTNESS_PROMPT_BASE_TEMPLATE, diff --git a/packages/phoenix-evals/src/phoenix/evals/exceptions.py b/packages/phoenix-evals/src/phoenix/evals/exceptions.py index dda7fa27c2..841eea1de5 100644 --- a/packages/phoenix-evals/src/phoenix/evals/exceptions.py +++ b/packages/phoenix-evals/src/phoenix/evals/exceptions.py @@ -8,3 +8,7 @@ class PhoenixContextLimitExceeded(PhoenixException): class PhoenixTemplateMappingError(PhoenixException): pass + + +class PhoenixUnsupportedAudioFormat(PhoenixException): + pass diff --git a/packages/phoenix-evals/src/phoenix/evals/models/openai.py b/packages/phoenix-evals/src/phoenix/evals/models/openai.py index c7efe16229..cb5f58e8b8 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/openai.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/openai.py @@ -14,10 +14,11 @@ get_origin, ) -from phoenix.evals.exceptions import PhoenixContextLimitExceeded +from phoenix.evals.exceptions import PhoenixContextLimitExceeded, PhoenixUnsupportedAudioFormat from phoenix.evals.models.base import BaseModel from phoenix.evals.models.rate_limiters import RateLimiter from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType +from phoenix.evals.utils import get_audio_format_from_base64 MINIMUM_OPENAI_VERSION = "1.0.0" MODEL_TOKEN_LIMIT_MAPPING = { @@ -35,6 +36,7 @@ "gpt-4-vision-preview": 128000, } LEGACY_COMPLETION_API_MODELS = ("gpt-3.5-turbo-instruct",) +SUPPORTED_AUDIO_FORMATS = {"mp3", "wav"} logger = logging.getLogger(__name__) @@ -282,11 +284,29 @@ def _build_messages( self, prompt: MultimodalPrompt, system_instruction: Optional[str] = None ) -> List[Dict[str, str]]: messages = [] - for parts in prompt.parts: - if parts.content_type == PromptPartContentType.TEXT: - messages.append({"role": "system", "content": parts.content}) + for part in prompt.parts: + if part.content_type == PromptPartContentType.TEXT: + messages.append({"role": "system", "content": part.content}) + elif part.content_type == PromptPartContentType.AUDIO: + format = str(get_audio_format_from_base64(part.content)) + if format not in SUPPORTED_AUDIO_FORMATS: + raise PhoenixUnsupportedAudioFormat(f"Unsupported audio format: {format}") + messages.append( + { + "role": "user", + "content": [ # type: ignore + { + "type": "input_audio", + "input_audio": { + "data": part.content, + "format": str(get_audio_format_from_base64(part.content)), + }, + } + ], + } + ) else: - raise ValueError(f"Unsupported content type: {parts.content_type}") + raise ValueError(f"Unsupported content type: {part.content_type}") if system_instruction: messages.insert(0, {"role": "system", "content": str(system_instruction)}) return messages @@ -321,7 +341,7 @@ def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str: prompt = MultimodalPrompt.from_string(prompt) invoke_params = self.invocation_params - messages = self._build_messages(prompt, kwargs.get("instruction")) + messages = self._build_messages(prompt=prompt, system_instruction=kwargs.get("instruction")) if functions := kwargs.get("functions"): invoke_params["functions"] = functions if function_call := kwargs.get("function_call"): diff --git a/packages/phoenix-evals/src/phoenix/evals/templates.py b/packages/phoenix-evals/src/phoenix/evals/templates.py index 962f21d635..dc1ca8db57 100644 --- a/packages/phoenix-evals/src/phoenix/evals/templates.py +++ 
b/packages/phoenix-evals/src/phoenix/evals/templates.py @@ -1,4 +1,5 @@ import re +from collections import OrderedDict from dataclasses import dataclass from enum import Enum from string import Formatter @@ -31,7 +32,7 @@ def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, A class PromptPartContentType(str, Enum): TEXT = "text" - AUDIO_URL = "audio_url" + AUDIO = "audio" @dataclass @@ -40,6 +41,7 @@ class PromptPart: content: str +# TODO: ask about rename to PromptTemplatePart @dataclass class PromptPartTemplate: content_type: PromptPartContentType @@ -148,6 +150,9 @@ def __init__( for _template in [self.template, self.explanation_template]: if _template: self.variables.extend(self._parse_variables(template=_template)) + # remove duplicates while preserving order + self.variables = list(OrderedDict.fromkeys(self.variables)) + self._scores = scores def __repr__(self) -> str: diff --git a/packages/phoenix-evals/src/phoenix/evals/utils.py b/packages/phoenix-evals/src/phoenix/evals/utils.py index 6721f51f7d..68201bd41a 100644 --- a/packages/phoenix-evals/src/phoenix/evals/utils.py +++ b/packages/phoenix-evals/src/phoenix/evals/utils.py @@ -1,6 +1,7 @@ +import base64 import json from io import BytesIO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple from urllib.error import HTTPError from urllib.request import urlopen from zipfile import ZipFile @@ -174,3 +175,56 @@ def _default_openai_function( def printif(condition: bool, *args: Any, **kwargs: Any) -> None: if condition: tqdm.write(*args, **kwargs) + + +def get_audio_format_from_base64( + enc_str: str, +) -> Literal["mp3", "wav", "ogg", "flac", "m4a", "aac"]: + """ + Determines the audio format from a base64 encoded string by checking file signatures. + + Args: + enc_str: Base64 encoded audio data + + Returns: + Audio format as string + + Raises: + ValueError: If the audio format is not supported or cannot be determined + """ + audio_bytes = base64.b64decode(enc_str) + + if len(audio_bytes) < 12: + raise ValueError("Audio data too short to determine format") + + # WAV check + if audio_bytes[0:4] == b"RIFF" and audio_bytes[8:12] == b"WAVE": + return "wav" + + # OGG check + if audio_bytes[0:4] == b"OggS": + return "ogg" + + # FLAC check + if audio_bytes[0:4] == b"fLaC": + return "flac" + + # M4A check (ISO Base Media File Format) + if len(audio_bytes) > 10 and (audio_bytes[4:11] == b"ftypM4A" or audio_bytes[0:4] == b"M4A "): + return "m4a" + + # AAC check + if audio_bytes[:2] in (bytearray([0xFF, 0xF1]), bytearray([0xFF, 0xF9])): + return "aac" + + # MP3 checks + if len(audio_bytes) >= 3: + if audio_bytes[0:3] == b"ID3": + return "mp3" + elif audio_bytes[0] == 0xFF and (audio_bytes[1] & 0xE0 == 0xE0): # MPEG sync + return "mp3" + + # If no match, raise an error + raise ValueError( + "Unsupported audio format. 
Supported formats are: mp3, wav, ogg, flac, m4a, aac" + ) diff --git a/packages/phoenix-evals/tests/phoenix/evals/functions/test_classify.py b/packages/phoenix-evals/tests/phoenix/evals/functions/test_classify.py index 84cb6803d7..f493c55b8c 100644 --- a/packages/phoenix-evals/tests/phoenix/evals/functions/test_classify.py +++ b/packages/phoenix-evals/tests/phoenix/evals/functions/test_classify.py @@ -22,8 +22,10 @@ run_evals, ) from phoenix.evals.default_templates import ( + RAG_RELEVANCY_PROMPT_BASE_TEMPLATE, RAG_RELEVANCY_PROMPT_TEMPLATE, TOXICITY_PROMPT_TEMPLATE, + TOXICITY_PROMPT_TEMPLATE_BASE_TEMPLATE, ) from phoenix.evals.evaluators import LLMEvaluator from phoenix.evals.executors import ExecutionStatus @@ -99,19 +101,240 @@ def classification_template(): return RAG_RELEVANCY_PROMPT_TEMPLATE +@pytest.fixture +def mock_respx_responses(respx_mock: respx.mock): + def _mock_responses(response_mapping): + for (query, reference), response in response_mapping.items(): + matcher = M(content__contains=query) & M(content__contains=reference) + payload = { + "choices": [ + { + "message": { + "content": response, + }, + } + ], + } + respx_mock.route(matcher).mock(return_value=httpx.Response(200, json=payload)) + + return _mock_responses + + @pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") def test_llm_classify( openai_api_key: str, classification_dataframe: DataFrame, - respx_mock: respx.mock, + mock_respx_responses: mock_respx_responses, ): dataframe = classification_dataframe keys = list(zip(dataframe["input"], dataframe["reference"])) responses = ["relevant", "unrelated", "\nrelevant ", "unparsable"] response_mapping = {key: response for key, response in zip(keys, responses)} - for (query, reference), response in response_mapping.items(): - matcher = M(content__contains=query) & M(content__contains=reference) + mock_respx_responses(response_mapping) + + model = OpenAIModel() + + result = llm_classify( + dataframe=dataframe, + template=RAG_RELEVANCY_PROMPT_TEMPLATE, + model=model, + rails=["relevant", "unrelated"], + verbose=True, + ) + + expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_data_processor_dataframe( + openai_api_key: str, + classification_dataframe: DataFrame, + mock_respx_responses: mock_respx_responses, +): + def row_flag_processor(row_series: pd.Series) -> pd.Series: + if "C++" in row_series["reference"] or "Python" in row_series["reference"]: + return pd.Series( + { + "input": row_series["input"], + "reference": row_series["reference"] + " - FLAGGED", + } + ) + else: + return row_series + + dataframe = classification_dataframe + expected_dataframe = pd.DataFrame( + [ + { + "input": "What is Python?", + "reference": "Python is a programming language." + " - FLAGGED", + }, + { + "input": "What is Python?", + "reference": "Ruby is a programming language.", + }, + {"input": "What is C++?", "reference": "C++ is a programming language." 
+ " - FLAGGED"}, + {"input": "What is C++?", "reference": "unrelated"}, + ], + ) + responses = ["relevant", "unrelated", "\nrelevant", "unparsable"] + processed_keys = list(zip(expected_dataframe["input"], expected_dataframe["reference"])) + processed_mapping = {key: response for key, response in zip(processed_keys, responses)} + + mock_respx_responses(processed_mapping) + + model = OpenAIModel() + + result = llm_classify( + dataframe=dataframe, + template=RAG_RELEVANCY_PROMPT_TEMPLATE, + model=model, + rails=["relevant", "unrelated"], + data_processor=row_flag_processor, + verbose=True, + include_prompt=True, + ) + + for original_row, processed_prompt in zip(dataframe.itertuples(), result["prompt"].to_list()): + inp = original_row.input + ref = original_row.reference + if "C++" in ref or "Python" in ref: + assert ( + RAG_RELEVANCY_PROMPT_BASE_TEMPLATE.format( + input=inp, + reference=ref + " - FLAGGED", + ) + == processed_prompt + ) + else: + assert ( + RAG_RELEVANCY_PROMPT_BASE_TEMPLATE.format( + input=inp, + reference=ref, + ) + == processed_prompt + ) + + expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_data_processor_list_of_tuples( + openai_api_key: str, + classification_dataframe: DataFrame, + mock_respx_responses: mock_respx_responses, +): + def tuple_flag_processor(row_tuple: tuple) -> tuple: + if "C++" in row_tuple[1] or "Python" in row_tuple[1]: + return row_tuple[0], row_tuple[1] + " - FLAGGED" + else: + return row_tuple + + list_of_tuples = [ + ("What is Python?", "Python is a programming language."), + ("What is Python?", "Ruby is a programming language."), + ("What is C++?", "C++ is a programming language."), + ("What is C++?", "unrelated"), + ] + processed_list_of_tuples = [ + ("What is Python?", "Python is a programming language." + " - FLAGGED"), + ("What is Python?", "Ruby is a programming language."), + ("What is C++?", "C++ is a programming language." 
+ " - FLAGGED"), + ("What is C++?", "unrelated"), + ] + + responses = ["unparsable", "unparsable", "unparsable", "unparsable"] + processed_keys = processed_list_of_tuples + processed_mapping = {key: response for key, response in zip(processed_keys, responses)} + + mock_respx_responses(processed_mapping) + + model = OpenAIModel() + + result = llm_classify( + dataframe=list_of_tuples, + template=RAG_RELEVANCY_PROMPT_TEMPLATE, + model=model, + rails=["relevant", "unrelated"], + data_processor=tuple_flag_processor, + verbose=True, + include_prompt=True, + ) + + for original_tuple, processed_prompt in zip(list_of_tuples, result["prompt"].to_list()): + inp = original_tuple[0] + ref = original_tuple[1] + if "C++" in ref or "Python" in ref: + assert ( + RAG_RELEVANCY_PROMPT_BASE_TEMPLATE.format( + input=inp, + reference=ref + " - FLAGGED", + ) + == processed_prompt + ) + else: + assert ( + RAG_RELEVANCY_PROMPT_BASE_TEMPLATE.format( + input=inp, + reference=ref, + ) + == processed_prompt + ) + + expected_labels = [NOT_PARSABLE, NOT_PARSABLE, NOT_PARSABLE, NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_data_processor_list_of_strings( + openai_api_key: str, + classification_dataframe: DataFrame, + respx_mock: respx.mock, +): + list_of_str = [ + "Python is a programming language.", + "Your opinion is irrelevant and you should leave", + "C++ is a programming language", + "", + ] + processed_list_of_str = [ + "Python is a programming language.", + "Your opinion is irrelevant and you should leave" + " - FLAGGED", + "C++ is a programming language", + "", + ] + responses = ["non-toxic", "toxic", "\nnon-toxic", "unparsable"] + processed_keys = processed_list_of_str + processed_mapping = {key: response for key, response in zip(processed_keys, responses)} + + def string_flag_processor(value: str) -> str: + if "irrelevant" in value: + return value + " - FLAGGED" + return value + + for query, response in processed_mapping.items(): + matcher = M(content__contains=query) payload = { "choices": [ { @@ -126,7 +349,124 @@ def test_llm_classify( model = OpenAIModel() result = llm_classify( - dataframe=dataframe, + dataframe=list_of_str, + template=TOXICITY_PROMPT_TEMPLATE, + model=model, + rails=["toxic", "non-toxic"], + data_processor=string_flag_processor, + include_prompt=True, + verbose=True, + ) + + for original_str, processed_prompt in zip(list_of_str, result["prompt"].to_list()): + if "irrelevant" in original_str: + assert ( + TOXICITY_PROMPT_TEMPLATE_BASE_TEMPLATE.format( + input=original_str + " - FLAGGED", + ) + == processed_prompt + ) + else: + assert ( + TOXICITY_PROMPT_TEMPLATE_BASE_TEMPLATE.format( + input=original_str, + ) + == processed_prompt + ) + + expected_labels = ["non-toxic", "toxic", "non-toxic", NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_data_and_no_dataframe_args( + classification_dataframe: DataFrame, + openai_api_key: str, + mock_respx_responses: mock_respx_responses, +): + dataframe = classification_dataframe + keys = list(zip(dataframe["input"], dataframe["reference"])) + responses = ["relevant", "unrelated", "\nrelevant ", "unparsable"] 
+ response_mapping = {key: response for key, response in zip(keys, responses)} + + mock_respx_responses(response_mapping) + + model = OpenAIModel() + + result = llm_classify( + data=dataframe, + template=RAG_RELEVANCY_PROMPT_TEMPLATE, + model=model, + rails=["relevant", "unrelated"], + verbose=True, + ) + + expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_positional_args( + classification_dataframe: DataFrame, + openai_api_key: str, + mock_respx_responses: mock_respx_responses, +): + dataframe = classification_dataframe + keys = list(zip(dataframe["input"], dataframe["reference"])) + responses = ["relevant", "unrelated", "\nrelevant ", "unparsable"] + response_mapping = {key: response for key, response in zip(keys, responses)} + + mock_respx_responses(response_mapping) + + model = OpenAIModel() + + result = llm_classify( + dataframe, + model, + RAG_RELEVANCY_PROMPT_TEMPLATE, + ["relevant", "unrelated"], + verbose=True, + ) + + expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_data_positional_rest_keyword_args( + classification_dataframe: DataFrame, + openai_api_key: str, + mock_respx_responses: mock_respx_responses, +): + dataframe = classification_dataframe + keys = list(zip(dataframe["input"], dataframe["reference"])) + responses = ["relevant", "unrelated", "\nrelevant ", "unparsable"] + response_mapping = {key: response for key, response in zip(keys, responses)} + + mock_respx_responses(response_mapping) + + model = OpenAIModel() + + result = llm_classify( + classification_dataframe, template=RAG_RELEVANCY_PROMPT_TEMPLATE, model=model, rails=["relevant", "unrelated"], @@ -143,6 +483,101 @@ def test_llm_classify( ) +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_positional_args_no_data( + classification_dataframe: DataFrame, + openai_api_key: str, +): + model = OpenAIModel() + + with pytest.raises( + TypeError, match=r"llm_classify\(\) missing 1 required positional argument: 'rails'" + ): + llm_classify( + model, + RAG_RELEVANCY_PROMPT_TEMPLATE, + ["relevant", "unrelated"], + verbose=True, + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_both_data_and_dataframe_args( + openai_api_key: str, + classification_dataframe: DataFrame, + mock_respx_responses: mock_respx_responses, +): + dataframe = classification_dataframe + alt_dataframe = pd.DataFrame( + [ + { + "input": "What is Go?", + "reference": "Go is a programming language.", + }, + { + "input": "What is Go?", + "reference": "C# is a programming language.", + }, + {"input": "What is Julia?", "reference": "Julia is a programming language."}, + {"input": "What is Julia?", "reference": "unrelated"}, + ], + ) + keys = list(zip(dataframe["input"], dataframe["reference"])) + responses = ["relevant", "unrelated", "\nrelevant ", "unparsable"] + response_mapping = {key: response for key, response in zip(keys, responses)} + + mock_respx_responses(response_mapping) + + model = OpenAIModel() + 
+ result = llm_classify( + dataframe=dataframe, + data=alt_dataframe, + template=RAG_RELEVANCY_PROMPT_TEMPLATE, + model=model, + rails=["relevant", "unrelated"], + include_prompt=True, + ) + + for original_row, processed_prompt in zip(dataframe.itertuples(), result["prompt"].to_list()): + inp = original_row.input + ref = original_row.reference + assert ( + RAG_RELEVANCY_PROMPT_BASE_TEMPLATE.format( + input=inp, + reference=ref, + ) + == processed_prompt + ) + + expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE] + assert result.iloc[:, 0].tolist() == expected_labels + assert_frame_equal( + result[["label"]], + pd.DataFrame( + data={"label": expected_labels}, + ), + ) + + +@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") +def test_llm_classify_no_data_and_no_dataframe_args( + classification_dataframe: DataFrame, + openai_api_key: str, +): + model = OpenAIModel() + + with pytest.raises( + TypeError, match=r"llm_classify\(\) missing 1 required positional argument: 'data'" + ): + llm_classify( + template=RAG_RELEVANCY_PROMPT_TEMPLATE, + model=model, + rails=["relevant", "unrelated"], + verbose=True, + ) + + @pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions") def test_llm_classify_with_included_prompt_and_response( openai_api_key: str, diff --git a/packages/phoenix-evals/tests/phoenix/evals/test_utils.py b/packages/phoenix-evals/tests/phoenix/evals/test_utils.py index c0bbdd093a..d80ec1ebec 100644 --- a/packages/phoenix-evals/tests/phoenix/evals/test_utils.py +++ b/packages/phoenix-evals/tests/phoenix/evals/test_utils.py @@ -1,4 +1,18 @@ -from phoenix.evals.utils import NOT_PARSABLE, snap_to_rail +import base64 +import io +import wave + +import lameenc + +from phoenix.evals.utils import NOT_PARSABLE, get_audio_format_from_base64, snap_to_rail + +PCM_ENC_STR = "cvhy+AT7BPsP/w//VgJWAhoBGgHyAPIAOgE6AdAA0ACA/4D/Nvs2+735vfkC+AL4EfYR9izy" +"LPLZ9dn1AvoC+vn3+fdy+XL5K/cr" +PCM_BYTES = base64.b64decode(PCM_ENC_STR) +SAMPLE_RATE = 44100 +CHANNELS = 1 +SAMPLE_WIDTH = 2 +BITRATE = 128 def test_snap_to_rail(): @@ -16,3 +30,94 @@ def test_snap_to_rail(): assert snap_to_rail("a", ["a", "b", "c"]) == "a" assert snap_to_rail(" abc", ["a", "ab", "abc"]) == "abc" assert snap_to_rail("abc", ["abc", "a", "ab"]) == "abc" + + +def test_get_audio_format_from_base64_wav(): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(CHANNELS) + wav_file.setsampwidth(SAMPLE_WIDTH) + wav_file.setframerate(SAMPLE_RATE) + wav_file.writeframes(PCM_BYTES) + + wav_bytes = buffer.getvalue() + wav_base64 = base64.b64encode(wav_bytes).decode("utf-8") + + assert get_audio_format_from_base64(wav_base64) == "wav" + + +def test_get_audio_format_from_base64_mp3(): + encoder = lameenc.Encoder() + encoder.set_bit_rate(BITRATE) + encoder.set_in_sample_rate(SAMPLE_RATE) + encoder.set_channels(CHANNELS) + encoder.set_quality(2) + + mp3_bytes = encoder.encode(PCM_BYTES) + mp3_bytes += encoder.flush() + + mp3_base64 = base64.b64encode(mp3_bytes).decode("utf-8") + + assert get_audio_format_from_base64(mp3_base64) == "mp3" + + +def test_get_audio_format_from_base64_ogg(): + # ogg sample file source: https://commons.wikimedia.org/wiki/File:Example.ogg + audio_base64 = ( + "T2dnUwACAAAAAAAAAABdwLNHAAAAANYQEycBHgF2b3JiaXMAAAAAAkSsAAAAAAAAA3E" + "CAAAAAAC4AU9nZ1MAAAAAAAAAAAAAXcCzRwEAAAAHafUAEjb//////////////////" + "///kQN2b3JiaXMNAAAATGF2ZjU5LjI3LjEwMAEAAAAV" + ) + + assert get_audio_format_from_base64(audio_base64) == "ogg" + + +def 
test_get_audio_format_from_base64_flac(): + # flac sample file source: https://helpguide.sony.net/high-res/sample1/v1/en/index.html + audio_base64 = ( + "ZkxhQwAAACIEgASAAAZyABfyF3ADcAA6aYDl0QDGP1GIkAxmtqagjOLrBAAAlSYAAABy" + "ZWZlcmVuY2UgbGliRkxBQyAxLjIuMSB3aW42NCAyMDA4MDcwOQUAAAAPAAAAQUxCVU09" + "QmVlIE1vdmVkDwAAAFRJVExFPUJlZSBNb3ZlZBoAAABBTEJVTUFSVElTVD1CbHVlIE1v" + ) + + assert get_audio_format_from_base64(audio_base64) == "flac" + + +def test_get_audio_format_from_base64_aac(): + # aac sample file source: https://filesamples.com/formats/aac#google_vignette + audio_base64 = ( + "//FQgAP//N4EAExhdmM1Ny4xMDcuMTAwAEIgCMEYOP/xUIBxH/whKwwBNFoeYDeQgCQQEQQCg" + "QCw7jBLFAWFQ2EiTzU874zjjPP3367qpk5b6us556pzg27ZLh4Lmq55jZ41HbEMSU42E7JU4z" + "trvCOW27IXMf6jZv3vuzszY90MIQk5pqA5B6vbfI/uH1u7xaUtkmEGivxvt+Yvvk+1SduFkxN" + "10CeFwhOgWT0E62PJ4bBE7dcnZvk8FcJzFEwmJwIMrGqj2/5d33WCxwyU/ZGzoTNFgZUDgAe6P" + "uHJEjzHYbGiWDWRxjg8Z4nmPFy0E2cadkbE9Y5sVW4+N/8li2+Czwv+3xEzD1H8+TMXiWOc7i7" + "K9g8b3t3R6nRYN+a5+P+6cA0f5J3J9LzlknK46nCZ434fojbIfoi8LnmNj4fd49Xt44JdWXd/" + ) + + assert get_audio_format_from_base64(audio_base64) == "aac" + + +def test_get_audio_format_from_base64_m4a(): + # m4a sample file source: https://filesamples.com/formats/m4a + audio_base64 = ( + "AAAAGGZ0eXBNNEEgAAACAGlzb21pc28yAAAACGZyZWUAGlhHbWRhdN4EAExhdmM1Ny4xMDcuMT" + "AwAEIgCMEYOCErDAE0Wh5gN5CAJBARBAKBALDuMEsUBYVDYSJPNTzvjOOM8/ffruqmTlvq6znnq" + "nODbtkuHguarnmNnjUdsQxJTjYTslTjO2u8I5bbshcx/qNm/e+7OzNj3QwhCTmmoDkHq9t8j+4f" + "W7vFpS2SYQaK/G+35i++T7VJ24WTE3XQJ4XCE6BZPQTrY8nhsETt1ydm+TwVwnMUTCYnAgysaqP" + "b/l3fdYLHDJT9kbOhM0WBlQOAB7o+4ckSPMdhsaJYNZHGODxnieY8XLQTZxp2RsT1jmxVbj43/y" + "WLb4LPC/7fETMPUfz5MxeJY5zuLsr2Dxve3dHqdFg35rn4/7pwDR/kncn0vOWScrjqcJnjfh+iN" + "Za6traBwazT+25On7HByDLP3vgn6H3Pc3Wn0vW3pHT/FX9DCaS/uaZ5ruC5f+em4T9fvLjR+Xx/" + ) + + assert get_audio_format_from_base64(audio_base64) == "m4a" + + +def test_get_audio_format_from_base64_unsupported(): + try: + assert get_audio_format_from_base64(PCM_ENC_STR) is None + except ValueError as e: + assert ( + str(e) + == "Unsupported audio format. Supported formats are: mp3, wav, ogg, flac, m4a, aac" + )
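
Usage sketch (not part of the diff above): the patch changes `llm_classify` to accept a `data` argument that may be a DataFrame or a plain list, plus an optional `data_processor` that transforms each element before it is mapped onto the template variables. The snippet below is a minimal, hedged illustration of how that new surface might be exercised with the audio emotion template added in `default_audio_templates.py`. The file names, the `load_audio_as_base64` helper, and the `gpt-4o-audio-preview` model name are assumptions for illustration only; an OpenAI API key is assumed to be configured.

```python
# Illustrative sketch only -- not code from this diff.
# Assumes OPENAI_API_KEY is set; file paths and the helper below are hypothetical.
import base64
from pathlib import Path

from phoenix.evals.classify import llm_classify
from phoenix.evals.default_audio_templates import (
    EMOTION_AUDIO_RAILS,
    EMOTION_PROMPT_TEMPLATE,
)
from phoenix.evals.models.openai import OpenAIModel


def load_audio_as_base64(path: str) -> str:
    # Hypothetical data_processor: read a local audio file and base64-encode it
    # so it can be substituted into the single {audio} template variable.
    return base64.b64encode(Path(path).read_bytes()).decode("utf-8")


audio_paths = ["clip_one.wav", "clip_two.mp3"]  # placeholder inputs

results = llm_classify(
    data=audio_paths,                      # a plain list instead of a DataFrame
    data_processor=load_audio_as_base64,   # applied to each element before templating
    template=EMOTION_PROMPT_TEMPLATE,
    model=OpenAIModel(model="gpt-4o-audio-preview"),  # model name is an assumption
    rails=EMOTION_AUDIO_RAILS,
    provide_explanation=True,
)
print(results[["label", "explanation"]])
```

Note that existing callers are unaffected: the `deprecate_dataframe_arg` wrapper keeps `llm_classify(dataframe=...)` working, remapping it to `data` and emitting a `DeprecationWarning`.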
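
A second sketch, mirroring the guard added to `OpenAIModel._build_messages`: `get_audio_format_from_base64` sniffs the container from the decoded bytes, and the OpenAI wrapper then rejects anything outside its supported set (mp3 and wav in this diff). Everything here is standard library plus the new utility; the in-memory WAV is a stand-in for real audio data.

```python
# Illustrative sketch only -- builds a tiny silent WAV in memory, base64-encodes it,
# and sniffs its format the same way the OpenAI wrapper does before sending audio.
import base64
import io
import wave

from phoenix.evals.utils import get_audio_format_from_base64

buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
    wav_file.setnchannels(1)
    wav_file.setsampwidth(2)                 # 16-bit samples
    wav_file.setframerate(44100)
    wav_file.writeframes(b"\x00\x00" * 441)  # ~10 ms of silence

audio_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

fmt = get_audio_format_from_base64(audio_b64)  # -> "wav"
if fmt not in {"mp3", "wav"}:
    # Same behavior as the wrapper's check against SUPPORTED_AUDIO_FORMATS:
    # other containers (ogg, flac, m4a, aac) are detected but rejected.
    raise ValueError(f"Audio format {fmt!r} not supported for input_audio parts")
print(fmt)
```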