feat: Audio evals & data processor for llm_classify() (#5616)
* change 'rails' to 'expected_eval_labels'

* wip

* wip

* revert 'classify.py'

* include MultimodalPrompt

* further changes to audio.py

* getting something now

* moving audio_classify() to classify.py

* remove print

* move data fetching to where we build the messages

* scrapped and redone - using llm_classify within audio_classify

* rename to proper

* cleanup

* merge llm_classify and audio_classify

* change back to llm_classify

* redo and test file format inference

* make data_processor apply to each data element, add a TEXT_DATA prompt part type since we're adding a data processor, and remove some test cases that still need to be implemented

* ruff + clean

* added a test... I have a feeling it needs work; dependency stuff still needs to be sorted out

* ruff cleanup

* pyright cleanup

* allow for users to pass through lists of strings and tuples to llm_classify

* unnecessary print

* add tests for data_processor, allow for multiple types to be passed thru

* revise typing - eliminate use of Sequence

* typing enhancements

* formatting is ruff

* more mypy

* dependencies

* enhance tests + address comments

* Add deprecation message when using `dataframe` arg as a kwarg

* stash

* accidentally modified something I shouldn't have

* ruff

* revert a test name

* revise tests + delete comments

* Tweak series normalization behavior

* dustin's change + tweak temp var logic

* Windows tests failing because uvloop is not Windows compliant

* wrong package

* made a mistake making template vars a set - order is not preserved

* emotion template

* issue w/ import

* trying revert

* update template

* addressing dustin's comments

* formatting is ruff

* update emotion template

* ruff

* Update default audio template

* Improve error message

* Add errors when using unsupported audio format

---------

Co-authored-by: Dustin Ngo <dustin@arize.com>
Co-authored-by: sallyannarize <sdelucia@arize.com>
3 people authored Jan 16, 2025
1 parent 036d2c9 commit 0eda8ce
Showing 10 changed files with 876 additions and 27 deletions.
1 change: 1 addition & 0 deletions packages/phoenix-evals/pyproject.toml
@@ -53,6 +53,7 @@ test = [
    "nest_asyncio",
    "pandas-stubs<=2.0.2.230605",
    "types-tqdm",
    "lameenc"
]

[project.urls]
118 changes: 104 additions & 14 deletions packages/phoenix-evals/src/phoenix/evals/classify.py
@@ -1,11 +1,15 @@
from __future__ import annotations

import inspect
import logging
import warnings
from collections import defaultdict
from enum import Enum
from functools import wraps
from itertools import product
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Expand All @@ -14,6 +18,7 @@
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
)

@@ -63,11 +68,37 @@ class ClassificationStatus(Enum):
    MISSING_INPUT = "MISSING INPUT"


PROCESSOR_TYPE = TypeVar("PROCESSOR_TYPE")


def deprecate_dataframe_arg(func: Callable[..., Any]) -> Callable[..., Any]:
    # Remove this once the `dataframe` arg in `llm_classify` is no longer supported

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        signature = inspect.signature(func)

        if "dataframe" in kwargs:
            warnings.warn(
                "`dataframe` argument is deprecated; use `data` instead",
                DeprecationWarning,
                stacklevel=2,
            )
            kwargs["data"] = kwargs.pop("dataframe")
        bound_args = signature.bind_partial(*args, **kwargs)
        bound_args.apply_defaults()
        return func(*bound_args.args, **bound_args.kwargs)

    return wrapper
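# Illustrative sketch (hypothetical usage): once this shim is applied, calling the
# wrapped function with the deprecated keyword still succeeds but emits a warning.
# `df`, `model`, and `template` below are stand-ins, not names from this diff.
#
#     import warnings
#     with warnings.catch_warnings(record=True) as caught:
#         warnings.simplefilter("always")
#         llm_classify(dataframe=df, model=model, template=template, rails=["yes", "no"])
#     assert any(issubclass(w.category, DeprecationWarning) for w in caught)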


@deprecate_dataframe_arg
def llm_classify(
    data: Union[pd.DataFrame, List[Any]],
    model: BaseModel,
    template: Union[ClassificationTemplate, PromptTemplate, str],
    rails: List[str],
    data_processor: Optional[Callable[[PROCESSOR_TYPE], PROCESSOR_TYPE]] = None,
    system_instruction: Optional[str] = None,
    verbose: bool = False,
    use_function_calling_if_available: bool = True,
@@ -88,20 +119,28 @@ def llm_classify(
        `provide_explanation=True`.

    Args:
        data (Union[pd.DataFrame, List[Any]]): A collection of data which
            can contain template variables and other information necessary to generate evaluations.
            If passed a DataFrame, its column names must match the template variables.
            If passed a list, the elements of the list will be mapped to the template variables
            in the order that the template variables are defined.
        model (BaseEvalModel): An LLM model class.
        template (Union[ClassificationTemplate, PromptTemplate, str]): The prompt template
            as either an instance of PromptTemplate, ClassificationTemplate or a string.
            If a string, the variable names should be surrounded by curly braces so that
            a call to `.format` can be made to substitute variable values.
        rails (List[str]): A list of strings representing the possible output classes
            of the model's predictions.
        data_processor (Optional[Callable[[T], T]]): An optional callable that is used to process
            the input data before it is mapped to the template variables. This callable is passed
            a single element of the input data and can return either a pandas.Series with indices
            corresponding to the template variables or an iterable of values that will be mapped
            to the template variables in the order that the template variables are defined.
        system_instruction (Optional[str], optional): An optional system message.
        verbose (bool, optional): If True, prints detailed info to stdout such as
@@ -171,8 +210,8 @@ def llm_classify(

    prompt_options = PromptOptions(provide_explanation=provide_explanation)

    labels: Iterable[Optional[str]] = [None] * len(data)
    explanations: Iterable[Optional[str]] = [None] * len(data)

    printif(verbose, f"Using prompt:\n\n{eval_template.prompt(prompt_options)}")
    if generation_info := model.verbose_generation_info():
@@ -211,18 +250,59 @@ def _process_response(response: str) -> Tuple[str, Optional[str]]:
            unrailed_label, explanation = parse_openai_function_call(response)
            return snap_to_rail(unrailed_label, rails, verbose=verbose), explanation

    def _normalize_to_series(
        data: PROCESSOR_TYPE,
    ) -> pd.Series[Any]:
        if isinstance(data, pd.Series):
            return data

        variable_count = len(eval_template.variables)
        if variable_count == 1:
            return pd.Series({eval_template.variables[0]: data})
        elif variable_count > 1:
            if isinstance(data, str):
                raise ValueError("The data cannot be mapped to the template variables")
            elif isinstance(data, Iterable):
                return pd.Series(
                    {
                        template_var: input_val
                        for template_var, input_val in zip(eval_template.variables, data)
                    }
                )
        return pd.Series()

    async def _run_llm_classification_async(
        input_data: PROCESSOR_TYPE,
    ) -> ParsedLLMResponse:
        with set_verbosity(model, verbose) as verbose_model:
            if data_processor:
                maybe_awaitable_data = data_processor(input_data)
                if inspect.isawaitable(maybe_awaitable_data):
                    processed_data = await maybe_awaitable_data
                else:
                    processed_data = maybe_awaitable_data
            else:
                processed_data = input_data

            prompt = _map_template(_normalize_to_series(processed_data))
            response = await verbose_model._async_generate(
                prompt, instruction=system_instruction, **model_kwargs
            )
            inference, explanation = _process_response(response)
            return inference, explanation, response, str(prompt)

    def _run_llm_classification_sync(
        input_data: PROCESSOR_TYPE,
    ) -> ParsedLLMResponse:
        with set_verbosity(model, verbose) as verbose_model:
            if data_processor:
                processed_data = data_processor(input_data)
                if inspect.isawaitable(processed_data):
                    raise ValueError("Cannot run the data processor asynchronously.")
            else:
                processed_data = input_data

            prompt = _map_template(_normalize_to_series(processed_data))
            response = verbose_model._generate(
                prompt, instruction=system_instruction, **model_kwargs
            )
@@ -242,7 +322,17 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMResponse:
        fallback_return_value=fallback_return_value,
    )

    list_of_inputs: Union[Tuple[Any, ...], List[Any]]
    if isinstance(data, pd.DataFrame):
        list_of_inputs = [row_tuple[1] for row_tuple in data.iterrows()]
        dataframe_index = data.index
    elif isinstance(data, (list, tuple)):
        list_of_inputs = data
        dataframe_index = pd.Index(range(len(data)))
    else:
        raise ValueError("Invalid 'data' input type.")

    results, execution_details = executor.run(list_of_inputs)
    labels, explanations, responses, prompts = zip(*results)
    all_exceptions = [details.exceptions for details in execution_details]
    execution_statuses = [details.status for details in execution_details]
@@ -264,7 +354,7 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMResponse:
**({"execution_status": [status.value for status in classification_statuses]}),
**({"execution_seconds": [runtime for runtime in execution_times]}),
},
index=dataframe.index,
index=dataframe_index,
)


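The net effect of the `data` and `data_processor` changes above is that inputs no longer have to be DataFrames. A minimal usage sketch, assuming `OpenAIModel` is available from phoenix.evals; the tuple elements map positionally onto the template variables, and the variable names `query` and `reply` (plus the model name) are made up for illustration:

from phoenix.evals import OpenAIModel, llm_classify

TEMPLATE = "Is the reply relevant to the query? Query: {query} Reply: {reply}"

def strip_whitespace(row):
    # Called once per element of `data` before template mapping; returning an
    # iterable maps its values onto the template variables in definition order.
    query, reply = row
    return (query.strip(), reply.strip())

results = llm_classify(
    data=[
        ("  What is Arize Phoenix?  ", "An open-source LLM observability library."),
        ("How do I log traces?", "  Bananas are yellow.  "),
    ],
    model=OpenAIModel(model="gpt-4o-mini"),  # illustrative model choice
    template=TEMPLATE,
    rails=["relevant", "unrelated"],
    data_processor=strip_whitespace,
)
print(results[["label", "execution_status"]])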
133 changes: 133 additions & 0 deletions packages/phoenix-evals/src/phoenix/evals/default_audio_templates.py
@@ -0,0 +1,133 @@
from phoenix.evals.templates import (
    ClassificationTemplate,
    PromptPartContentType,
    PromptPartTemplate,
)

EMOTION_AUDIO_BASE_TEMPLATE_PT_1 = """
You are an AI system designed to classify emotions in audio files.
### TASK:
Analyze the provided audio file and classify the primary emotion based on these characteristics:
- Tone: General tone of the speaker (e.g., cheerful, tense, calm).
- Pitch: Level and variability of the pitch (e.g., high, low, monotone).
- Pace: Speed of speech (e.g., fast, slow, steady).
- Volume: Loudness of the speech (e.g., loud, soft, moderate).
- Intensity: Emotional strength or expression (e.g., subdued, sharp, exaggerated).
The classified emotion must be one of the following:
['anger', 'happiness', 'excitement', 'sadness', 'neutral', 'frustration', 'fear', 'surprise',
'disgust', 'other']
IMPORTANT: Choose the most dominant emotion expressed in the audio. Neutral should only be used
when no other emotion is clearly present; do your best to avoid this label.
************
Here is the audio to classify:
"""

EMOTION_AUDIO_BASE_TEMPLATE_PT_2 = """{audio}"""

EMOTION_AUDIO_BASE_TEMPLATE_PT_3 = """
RESPONSE FORMAT:
Provide a single word from the list above representing the detected emotion.
************
EXAMPLE RESPONSE: excitement
************
Analyze the audio and respond in this format.
"""

EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_1 = """
You are an AI system designed to classify emotions in audio files.
### TASK:
First, explain in a step-by-step manner how the following characteristics of the provided audio
file indicate the emotion of the speaker:
- Tone: General tone of the speaker (e.g., cheerful, tense, calm).
- Pitch: Level and variability of the pitch (e.g., high, low, monotone).
- Pace: Speed of speech (e.g., fast, slow, steady).
- Volume: Loudness of the speech (e.g., loud, soft, moderate).
- Intensity: Emotional strength or expression (e.g., subdued, sharp, exaggerated).
Then, classify the primary emotion. The classified emotion must be one of the following:
['anger', 'happiness', 'excitement', 'sadness', 'neutral', 'frustration', 'fear', 'surprise',
'disgust', 'other']
IMPORTANT: Choose the most dominant emotion expressed in the audio. Neutral should only be used
when no other emotion is clearly present; do your best to avoid this label.
************
Here is the audio to classify:
"""

EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_3 = """
EXAMPLE RESPONSE FORMAT:
************
EXPLANATION: An explanation of your reasoning based on the tone, pitch, pace, volume, and intensity
of the audio.
LABEL: "excitement"
************
Analyze the audio and respond in the format shown above.
"""

EMOTION_AUDIO_RAILS = [
    "anger",
    "happiness",
    "excitement",
    "sadness",
    "neutral",
    "frustration",
    "fear",
    "surprise",
    "disgust",
    "other",
]

EMOTION_PROMPT_TEMPLATE = ClassificationTemplate(
    rails=EMOTION_AUDIO_RAILS,
    template=[
        PromptPartTemplate(
            content_type=PromptPartContentType.TEXT,
            template=EMOTION_AUDIO_BASE_TEMPLATE_PT_1,
        ),
        PromptPartTemplate(
            content_type=PromptPartContentType.AUDIO,
            template=EMOTION_AUDIO_BASE_TEMPLATE_PT_2,
        ),
        PromptPartTemplate(
            content_type=PromptPartContentType.TEXT,
            template=EMOTION_AUDIO_BASE_TEMPLATE_PT_3,
        ),
    ],
    explanation_template=[
        PromptPartTemplate(
            content_type=PromptPartContentType.TEXT,
            template=EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_1,
        ),
        PromptPartTemplate(
            content_type=PromptPartContentType.AUDIO,
            template=EMOTION_AUDIO_BASE_TEMPLATE_PT_2,
        ),
        PromptPartTemplate(
            content_type=PromptPartContentType.TEXT,
            template=EMOTION_AUDIO_EXPLANATION_TEMPLATE_PT_3,
        ),
    ],
)
"""
A template for evaluating the emotion of an audio sample. It return
an emotion and provides a detailed explanation template
to assist users in articulating their judgment on code readability.
"""
2 changes: 2 additions & 0 deletions packages/phoenix-evals/src/phoenix/evals/default_templates.py
@@ -378,6 +378,7 @@
EXPLANATION:"""


REFERENCE_LINK_CORRECTNESS_PROMPT_BASE_TEMPLATE = """
You are given a conversation that contains questions by a CUSTOMER and you are
trying to determine if the documentation page shared by the ASSISTANT correctly
@@ -783,6 +784,7 @@
to assist users in articulating their judgment on code readability.
"""


REFERENCE_LINK_CORRECTNESS_PROMPT_TEMPLATE = ClassificationTemplate(
    rails=list(REFERENCE_LINK_CORRECTNESS_PROMPT_RAILS_MAP.values()),
    template=REFERENCE_LINK_CORRECTNESS_PROMPT_BASE_TEMPLATE,
4 changes: 4 additions & 0 deletions packages/phoenix-evals/src/phoenix/evals/exceptions.py
@@ -8,3 +8,7 @@ class PhoenixContextLimitExceeded(PhoenixException):

class PhoenixTemplateMappingError(PhoenixException):
    pass


class PhoenixUnsupportedAudioFormat(PhoenixException):
    pass
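The new exception gives callers a specific type to catch when an audio payload is in a format the evaluator cannot handle. A hedged sketch, assuming the error propagates out of the classification call:

from phoenix.evals.exceptions import PhoenixUnsupportedAudioFormat

try:
    results = llm_classify(
        data=audio_inputs,  # as in the audio sketch above
        model=model,
        template=EMOTION_PROMPT_TEMPLATE,
        rails=EMOTION_AUDIO_RAILS,
    )
except PhoenixUnsupportedAudioFormat as err:
    # e.g., the payload was not a recognized/supported audio container
    print(f"Skipping unsupported audio format: {err}")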
